Functional programming and Spark: do they mix?

December 22, 2017

So, Spark. A very popular option for large-scale distributed data processing on the JVM, doubly so when working with Scala. Spark has many advantages (rich ML library, great infrastructure for SQL-like data transformations), but also many problems: its entire API is side-effecting, exceptions are liberally thrown as an error-reporting mechanism, and type safety is practically non-existent in the DataFrame API.

Can we do anything to alleviate those pains? It’s certainly worth trying as there are currently no obviously better alternatives for doing in-memory data processing or machine learning on the JVM. We should discuss what we’re aiming for, though, as there are several “levels” of safety we can gain:

  1. We can make do without any sort of safety entirely. Use Spark imperatively as the documentation suggests, throw and catch exceptions and live with it. I’ll show one concrete issue with that in a moment.
  2. We can use a functional programming library, such as cats or scalaz, and an IO monad, such as Monix Task or the upcoming IO in scalaz 8, to adapt the Spark API as needed.
  3. We can use frameless, a library that wraps the Spark API wholesale and provides an entirely type-safe and functional API.

Option 3 is of course the most desirable choice, but frameless unfortunately does not cover all of the Spark API yet. I encourage you to contribute, though; the project has an ambitious but entirely achievable goal, and if you use Spark, you can easily dogfood it and gain the benefits of type safety.

We’ll discuss option 2 in this post. I think there is much benefit to be had by adapting the Spark API with an IO monad, as that allows you to reason about your job and its behavior and use all of the wonderful abstractions that functional programming brings.

This post is typechecked with tut, so let’s get some of the required imports and infrastructure out of the way:

import cats._, cats.data._, cats.implicits._
import monix.eval.{ Coeval, Task }
import monix.execution.Scheduler
import monix.execution.Scheduler.Implicits.global
import org.apache.spark.sql.{ DataFrame, Dataset, SparkSession }
import org.apache.spark.sql.{ functions => f }
import scala.concurrent.{ Await, Future }
import scala.concurrent.duration._

implicit val session = (SparkSession
  .builder()
  .appName("Retaining Sanity with Spark")
  .master("local")
  .getOrCreate())

import session.implicits._

A motivating example

Spark is pretty straightforward to use, if you just want to churn out a job that runs a couple of data transformations. Here’s a sample that computes the average of a DataFrame of numbers:

import scala.util.Random
// import scala.util.Random

val df = session.sparkContext.parallelize(List.fill(100)(Random.nextLong)).toDF
// df: org.apache.spark.sql.DataFrame = [value: bigint]

df.agg(f.avg("value")).head()
// res2: org.apache.spark.sql.Row = [-2.49656330274346528E17]

What happens when we’d like to run two of these operations in parallel? Spark will block when we call head(), so we’ll need to run this in a Future:

def computeAvg(df: DataFrame) = Future(df.agg(f.avg("value")).head())
// computeAvg: (df: org.apache.spark.sql.DataFrame)scala.concurrent.Future[org.apache.spark.sql.Row]

val df2 = session.sparkContext.parallelize(List.fill(100)(Random.nextLong)).toDF
// df2: org.apache.spark.sql.DataFrame = [value: bigint]

val result = Await.result(computeAvg(df) zip computeAvg(df2), 30.seconds)
// result: (org.apache.spark.sql.Row, org.apache.spark.sql.Row) = ([-2.49656330274346528E17],[-9.6780802359920627E17])

The problem with this approach, while not apparent in the example, is that Spark will actually run the two actions sequentially, as by default jobs will occupy all the cores available (assuming there are enough partitions in the underlying RDD). If we set spark.scheduler.mode to FAIR, we can use thread-locals to run the actions on different scheduler pools:

session.stop()

implicit val session = (SparkSession
  .builder()
  .appName("Retaining Sanity with Spark")
  .master("local")
  .config("spark.scheduler.mode", "FAIR")
  .getOrCreate())

def computeAvg(df: DataFrame, pool: String)(implicit session: SparkSession) =
  Future {
    session.sparkContext.setLocalProperty("spark.scheduler.pool", pool)
    // Keep the result; otherwise the block evaluates to the Unit returned by the reset below
    val result = df.agg(f.avg("value")).head()
    session.sparkContext.setLocalProperty("spark.scheduler.pool", null)

    result
  }

There’s more info on this in the job scheduling section of the Spark documentation. It’s interesting to note that there are now two levels of execution being controlled here: the ExecutionContext on which the Future is being executed and the scheduler pool on which the job is being executed.

In any case, we can’t use this approach to sensibly build combinators for composing larger programs; writing an onPool(f: Future[A], p: String): Future[A] combinator that runs an existing thunk on another scheduler pool is impossible because this is done using thread-locals and the f thunk might be executed on an entirely different thread.
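
To make the thread-local problem concrete, here’s a small sketch that uses only the standard library. ThreadLocalDemo and its pool field are made up for illustration; a plain java.lang.ThreadLocal stands in for Spark’s local properties, and a single-thread executor plays the role of the ExecutionContext:

```scala
import java.util.concurrent.Executors
import scala.concurrent.{ Await, ExecutionContext, Future }
import scala.concurrent.duration._

object ThreadLocalDemo {
  // Hypothetical stand-in for the "spark.scheduler.pool" local property
  private val pool = new ThreadLocal[String]

  // Returns what the calling thread sees and what the Future's thread sees
  def run(): (String, Option[String]) = {
    val executor = Executors.newSingleThreadExecutor()
    implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(executor)

    try {
      pool.set("pool") // set on the calling thread only
      val seenByCaller = pool.get()

      // The body runs on the executor's thread, which has its own, unset copy
      val seenByFuture = Await.result(Future(Option(pool.get())), 5.seconds)

      (seenByCaller, seenByFuture)
    } finally {
      pool.remove()
      executor.shutdown()
    }
  }
}
```

Calling ThreadLocalDemo.run() gives back ("pool", None): the value set by the caller never reaches the thread that runs the Future’s body, which is exactly what would defeat an onPool combinator over Future.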

Another issue is that Future is eager: we have to be very strict (pun not intended) about how we use the computeAvg method if we want to control when and where the Spark actions are executed.
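
Here’s a minimal sketch of what eagerness means in practice, again with just the standard library. EagerFutureDemo and the counter are hypothetical; the counter increment stands in for a Spark action:

```scala
import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.{ Await, Future }
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

object EagerFutureDemo {
  // Returns how many "actions" had started at three points in time
  def run(): (Int, Int, Int) = {
    val started = new AtomicInteger(0)

    // Eager: the body is scheduled as soon as the Future is constructed
    val eager = Future(started.incrementAndGet())
    Await.result(eager, 5.seconds) // we only wait here, nothing is triggered
    val afterEager = started.get()

    // Deferred: wrapping the Future in a function delays the work
    val deferred: () => Future[Int] = () => Future(started.incrementAndGet())
    val afterDefining = started.get()

    // The second action starts only once we call the function
    Await.result(deferred(), 5.seconds)
    val afterRunning = started.get()

    (afterEager, afterDefining, afterRunning)
  }
}
```

EagerFutureDemo.run() returns (1, 1, 2): defining the deferred version starts nothing, while merely constructing the plain Future already did.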

And lastly, we glossed over the fact that all (yes, all) of Spark’s API calls may throw exceptions at any given point, crashing our program. It is literally the antithesis of a safe API.

Regaining our functional sanity

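So, having motivated why we’re spending time on this, let’s discuss what we’d like to achieve. I’ll show in this post how we can use Monix to control execution, time operations in a principled and composable way, and pass state around ergonomically.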

Let’s get started!

Using Monix to control execution

We first need to get into the habit of working exclusively in Task. All Spark calls need to be wrapped in the Task.eval / Task.apply constructors. Our program now looks like this:

def buildSession: Task[SparkSession] = Task.eval {
  SparkSession
    .builder()
    .appName("Retaining Sanity with Spark")
    .master("local")
    .config("spark.scheduler.mode", "FAIR")
    .getOrCreate() 
}

def createDF(data: List[Int])(implicit session: SparkSession): Task[DataFrame] = Task.eval {
  import session.implicits._
  val rdd = session.sparkContext.parallelize(data)

  rdd.toDF
}

def computeAvg(df: DataFrame, pool: String)(implicit session: SparkSession): Task[Double] = 
  Task.eval {
    session.sparkContext.setLocalProperty("spark.scheduler.pool", pool)
    val result = df.agg(f.avg("value")).head().getDouble(0)
    session.sparkContext.setLocalProperty("spark.scheduler.pool", null)

    result
  }

Doesn’t look too different so far. When we compose everything together, we get back a Task representing the result of our program:

def program: Task[Double] = for {
  sparkSession <- buildSession
  result <- {
    implicit val session = sparkSession
    import scala.util.Random

    val data = List.fill(100)(Random.nextInt)

    for {
      df <- createDF(data)
      avg <- computeAvg(df, "pool")
    } yield avg
  }
} yield result

Let’s turn to creating the onPool combinator. Here’s a naive version that composes the actual task with two tasks that set/unset the thread-local:

def onPool[A](task: Task[A], pool: String)(implicit session: SparkSession): Task[A] = 
  for {
    _ <- Task.eval(session.sparkContext.setLocalProperty("spark.scheduler.pool", pool))
    result <- task
    _ <- Task.eval(session.sparkContext.setLocalProperty("spark.scheduler.pool", null))
  } yield result

The problem with this naive version is that if task has an asynchronous boundary (such as the one produced by Task.apply or Task.fork), setLocalProperty will set the thread-local on an irrelevant thread:

// Works properly for Task.eval:
val test = Task.eval(println(session.sparkContext.getLocalProperty("spark.scheduler.pool")))
// test: monix.eval.Task[Unit] = Task.Eval(<function0>)

onPool(test, "pool").runAsync
// pool
// res9: monix.execution.CancelableFuture[Unit] = monix.execution.CancelableFuture$Pure@583739e2

// ... but not for Task.fork:
val forked = Task.fork(test)
// forked: monix.eval.Task[Unit] = Task.FlatMap(Task@180218614, <function1>)

onPool(forked, "pool").runAsync
// null
// res11: monix.execution.CancelableFuture[Unit] = Async(Success(()),monix.execution.cancelables.StackedCancelable@745061b9)

A possible solution here is to use monix.eval.Coeval: a data type that represents synchronous evaluation. Coeval has a Monad instance, so we can create a Coeval that wraps our Spark API call, decorate it with setting/clearing the scheduler pool thread-local, and convert the resulting Coeval to a Task:

def onPool[A](task: Coeval[A], pool: String)(implicit session: SparkSession): Task[A] = 
  (for {
    _ <- Coeval(session.sparkContext.setLocalProperty("spark.scheduler.pool", pool))
    result <- task
    _ <- Coeval(session.sparkContext.setLocalProperty("spark.scheduler.pool", null))
  } yield result).task

And this should now work properly:

val test = Coeval(println(session.sparkContext.getLocalProperty("spark.scheduler.pool")))
// test: monix.eval.Coeval[Unit] = <function0>

val forked = Task.fork(onPool(test, "pool"))
// forked: monix.eval.Task[Unit] = Task.FlatMap(Task@180218614, <function1>)

forked.runAsync
// res12: monix.execution.CancelableFuture[Unit] = Async(List(),monix.execution.cancelables.StackedCancelable@486765f1)
// pool

The last thing we need to take care of is error handling. We have to clear the scheduler pool setting if the inner task fails, or we’ll leak that setting to other tasks. This means we need to slightly modify onPool. Along the way, we’ll move it to an implicit class so we can get infix syntax:

implicit class CoevalOps[A](thunk: Coeval[A]) {
  def onPool(pool: String)(implicit session: SparkSession): Task[A] = 
    Coeval(session.sparkContext.setLocalProperty("spark.scheduler.pool", pool))
      .flatMap(_ => thunk)
      .doOnFinish(_ => Coeval(session.sparkContext.setLocalProperty("spark.scheduler.pool", null)))
      .task
}

And we can use it as such:

scala> val test = Coeval { 
     |   println(session.sparkContext.getLocalProperty("spark.scheduler.pool")) 
     | }.onPool("pool")
test: monix.eval.Task[Unit] = Task.Eval(<function0>)

scala> test.runAsync
pool
res13: monix.execution.CancelableFuture[Unit] = monix.execution.CancelableFuture$Pure@3c0e7718
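
To see why the reset has to run on failure as well, here’s a plain-Scala sketch with no Monix or Spark involved. PoolCleanupDemo, naive and safe are hypothetical names, and a ThreadLocal again stands in for the scheduler pool property; try/finally plays the role that doOnFinish plays above:

```scala
import scala.util.Try

object PoolCleanupDemo {
  // Hypothetical stand-in for the "spark.scheduler.pool" local property
  private val pool = new ThreadLocal[String]

  def current: Option[String] = Option(pool.get())

  // Naive: the reset is skipped when the thunk throws
  def naive[A](name: String)(thunk: => A): A = {
    pool.set(name)
    val result = thunk
    pool.remove()
    result
  }

  // Safe: the reset runs whether the thunk succeeds or fails
  def safe[A](name: String)(thunk: => A): A = {
    pool.set(name)
    try thunk finally pool.remove()
  }

  // What is left in the thread-local after a failing thunk
  def leakedAfterNaive(): Option[String] = {
    pool.remove()
    Try(naive[Int]("pool")(sys.error("boom")))
    val leaked = current
    pool.remove() // clean up after the demonstration
    leaked
  }

  def leakedAfterSafe(): Option[String] = {
    pool.remove()
    Try(safe[Int]("pool")(sys.error("boom")))
    current
  }
}
```

leakedAfterNaive() returns Some("pool"), so the next job on that thread would silently run on the wrong pool, while leakedAfterSafe() returns None.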

Principled and composable timing

Timing individual operations in a big Spark job is important for tracing performance issues. Commonly, the instrumentation is so intrusive that it obscures the actual code. This is something we’d like to avoid.

Other than unintrusive instrumentation, we should be able to inspect the durations of individual operations and aggregate them. We can do this nicely with the WriterT monad transformer.

WriterT[F[_], L, A] is a data type isomorphic to F[(L, A)] - a log of type L and a result of type A together in an effect F[_]. In our case, L would be something describing the timings; A would be the result of an operation and F would be Task.
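
If Writer is new to you, this hand-rolled sketch may help. It is not cats’ WriterT: the effect F is left out, the log is fixed to a Vector[String], and WriterSketch, Writer and step are names I made up for illustration:

```scala
object WriterSketch {
  // A value paired with a log; the effect F is omitted for brevity
  final case class Writer[A](log: Vector[String], value: A) {
    def map[B](f: A => B): Writer[B] = Writer(log, f(value))

    // Sequencing two steps concatenates their logs
    def flatMap[B](f: A => Writer[B]): Writer[B] = {
      val next = f(value)
      Writer(log ++ next.log, next.value)
    }
  }

  // A step that records its own name in the log
  def step(name: String, value: Int): Writer[Int] = Writer(Vector(name), value)

  val program: Writer[Int] = for {
    a <- step("load", 1)
    b <- step("transform", a + 1)
  } yield a + b
}
```

WriterSketch.program ends up with the log Vector("load", "transform") and the value 3; the for comprehension never mentions the log, yet both entries are kept.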

To describe the timings, we can use a simple map:

case class Timings(data: Map[String, FiniteDuration])

We’ll name this Writer monad TimedTask:

type TimedTask[A] = WriterT[Task, Timings, A]

When we compose the tasks together using for comprehensions, the Writer monad will concatenate the maps to keep the timings. For that to work, we need a Monoid instance for Timings:

implicit val timingsMonoid: Monoid[Timings] = new Monoid[Timings] {
  def empty: Timings = Timings(Map.empty)
  def combine(x: Timings, y: Timings): Timings = Timings(x.data ++ y.data)
}
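
One thing to keep in mind with ++ is that when both maps contain the same key, the right-hand entry wins. If the same key might be timed more than once, a variant that adds the durations up could look like the sketch below. TimingsMergeDemo and the sample data are hypothetical, and it works on the raw Map rather than on Timings to stay self-contained:

```scala
import scala.concurrent.duration._

object TimingsMergeDemo {
  type Data = Map[String, FiniteDuration]

  // Same behaviour as `x.data ++ y.data`: the right-hand entry wins
  def overwrite(x: Data, y: Data): Data = x ++ y

  // Alternative: add up the durations recorded under the same key
  def sum(x: Data, y: Data): Data =
    y.foldLeft(x) { case (acc, (key, duration)) =>
      acc.updated(key, acc.getOrElse(key, Duration.Zero) + duration)
    }

  // Made-up sample data with a duplicate "load" key
  val first: Data = Map("load" -> 10.millis)
  val second: Data = Map("load" -> 5.millis, "avg" -> 1.milli)

  val overwritten: Data = overwrite(first, second) // "load" -> 5 milliseconds
  val summed: Data = sum(first, second)            // "load" -> 15 milliseconds
}
```

Both versions are associative and have the empty map as identity, so either can back a Monoid; which one you want depends on whether repeated keys should be replaced or accumulated.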

Let’s now add a combinator to lift a Task[A] to TimedTask[A]. The combinator will measure the time before and after the task and add that entry to the Timings map. We’ll also add a combinator that marks a task as untimed:

implicit class TaskOps[A](task: Task[A]) {
  def timed(key: String): TimedTask[A] = 
    WriterT {
      for {
        startTime <- Task.eval(System.currentTimeMillis().millis)
        result <- task
        endTime <- Task.eval(System.currentTimeMillis().millis)
      } yield (Timings(Map(key -> (endTime - startTime))), result)
    }

  def untimed: TimedTask[A] = 
    WriterT(task.map((Monoid[Timings].empty, _)))
}

And let’s see how both onPool and timed can be used to combine operations together:

def createDF(data: List[Int])(implicit session: SparkSession): Task[DataFrame] = Task.eval {
  import session.implicits._
  val rdd = session.sparkContext.parallelize(data)

  rdd.toDF
}
// createDF: (data: List[Int])(implicit session: org.apache.spark.sql.SparkSession)monix.eval.Task[org.apache.spark.sql.DataFrame]

def computeAvg(df: DataFrame)(implicit session: SparkSession): Coeval[Double] = 
  Coeval(df.agg(f.avg("value")).head().getDouble(0))
// computeAvg: (df: org.apache.spark.sql.DataFrame)(implicit session: org.apache.spark.sql.SparkSession)monix.eval.Coeval[Double]

def program(data: List[Int])(implicit session: SparkSession): TimedTask[Double] = 
  for {
    df <- createDF(data).timed("DataFrame creation")
    avg <- computeAvg(df).onPool("pool").timed("Average computation")
  } yield avg
// program: (data: List[Int])(implicit session: org.apache.spark.sql.SparkSession)TimedTask[Double]

Await.result(
  program(List.fill(100)(Random.nextInt)).run.runAsync,
  30.seconds
)
// res14: (Timings, Double) = (Timings(Map(DataFrame creation -> 11 milliseconds, Average computation -> 100 milliseconds)),-1.219367541E7)

We receive the result of the program along with the accumulated timings when we run the program. However, oftentimes throughout long running jobs, we’d like to unpack the intermediate timing data and log it. We can’t do this in the middle of a for-comprehension that constructs a TimedTask, as the Writer monad cannot observe the rest of the log it is being combined with. We need to do this once the timed program is constructed.

We can design this as a combinator of the form TimedTask[A] => Task[A] that decorates the resulting Task with a logging effect. For simplicity, we’ll use println, but a logger can be captured implicitly or passed to logTimings:

implicit class TimedTaskOps[A](task: TimedTask[A]) {
  def logTimings(heading: String): Task[A] = 
    for {
      resultAndLog <- task.run
      (log, result) = resultAndLog
      _ <- Task.eval {
        println {
          List(
            s"${heading}:",
            log.data
              .map {
                case (entry, duration) => 
                  s"\t${entry}: ${duration.toMillis.toString}ms"
              }
              .toList
              .mkString("\n"),
            s"\tTotal: ${log.data.values.map(_.toMillis).sum}ms" 
          ).mkString("\n")
        }
      }
    } yield result
}
// defined class TimedTaskOps

Now, we can compose two programs that log their timings when they complete:

val composed = 
  for {
    fst <- program(List.fill(100)(Random.nextInt)).logTimings("First Program")
    snd <- program(List.fill(100)(Random.nextInt)).logTimings("Second Program")
  } yield ()
// composed: monix.eval.Task[Unit] = Task.FlatMap(Task@1186157294, <function1>)

Await.result(composed.runAsync, 30.seconds)
// First Program:
// 	DataFrame creation: 7ms
// 	Average computation: 68ms
// 	Total: 75ms
// Second Program:
// 	DataFrame creation: 7ms
// 	Average computation: 56ms
// 	Total: 63ms

Passing around state ergonomically

Jobs (and most programs in general) often involve some accumulation and transformation of state. In our case, it can be the intermediate DataFrame being transformed and auxiliary data and results obtained throughout the job. It can get pretty tedious to pass around the data we need:

def loadKeys: Task[List[String]] = Task(List.fill(10)(Random.nextString(5)))

def pruneSomeKeys(keys: List[String]): Task[List[String]] = Task(keys take 3)

def pruneMoreKeys(keys: List[String]): Task[List[String]] = Task(keys drop 1)

def createDF(keys: List[String])(implicit spark: SparkSession): Task[DataFrame] = 
  Task(keys.toDF)

def transformDF(df: DataFrame)(implicit spark: SparkSession): Task[DataFrame] = 
  Task(df limit 3)

val program = for {
  keys          <- loadKeys
  prunedKeys    <- pruneSomeKeys(keys)
  pruneMoreKeys <- pruneMoreKeys(prunedKeys)
  df            <- createDF(prunedKeys)
  transformed   <- transformDF(df)
} yield transformed
// program: monix.eval.Task[org.apache.spark.sql.DataFrame] = Task.FlatMap(Task@709429437, <function1>)

What happens if we need to introduce an intermediate step? For example:

def pruneEvenMoreKeys(keys: List[String]): Task[List[String]] = Task(keys)
// pruneEvenMoreKeys: (keys: List[String])monix.eval.Task[List[String]]

val amendedProgram = for {
  keys          <- loadKeys
  prunedKeys    <- pruneSomeKeys(keys)
  prunedAgain   <- pruneEvenMoreKeys(prunedKeys)
  pruneMoreKeys <- pruneMoreKeys(prunedKeys)
  df            <- createDF(prunedKeys)
  transformed   <- transformDF(df)
} yield transformed
// amendedProgram: monix.eval.Task[org.apache.spark.sql.DataFrame] = Task.FlatMap(Task@915798586, <function1>)

Unfortunately, we introduced a bug, as we forgot to update the data that pruneMoreKeys and createDF are operating on. This example might seem contrived, but I assure you that this has happened to me in real, production code ;-)

One possible solution is to stop explicitly passing around the intermediate state entirely. The StateT monad transformer can help us with that. StateT[F[_], S, A] is isomorphic to S => F[(S, A)] - a function that receives an initial state and outputs a resulting state with a resulting value in an effect F.
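
As a rough illustration of that shape, here’s a hand-rolled State without the effect F. This is not cats’ StateT; StateSketch, State and modify are names I’m assuming for the sketch:

```scala
object StateSketch {
  // S => (S, A): take a state, return the next state and a value
  final case class State[S, A](run: S => (S, A)) {
    def map[B](f: A => B): State[S, B] = State { s =>
      val (s1, a) = run(s)
      (s1, f(a))
    }

    // The state produced by the first step is fed into the second
    def flatMap[B](f: A => State[S, B]): State[S, B] = State { s =>
      val (s1, a) = run(s)
      f(a).run(s1)
    }
  }

  // Update the state without producing a meaningful value
  def modify[S](f: S => S): State[S, Unit] = State(s => (f(s), ()))

  val program: State[List[Int], Unit] = for {
    _ <- modify[List[Int]](_.take(3))
    _ <- modify[List[Int]](_.drop(1))
  } yield ()
}
```

Running StateSketch.program.run(List(1, 2, 3, 4, 5)) yields (List(2, 3), ()). No step names its input, so there is no intermediate variable to mix up.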

Similarly to how we worked with the WriterT monad, we first define the state that we use in our program, and partially apply the State monad along with it:

case class JobState(keys: List[String], df: DataFrame)

type StateAction[A] = StateT[Task, JobState, A]

We now redefine the methods above in terms of this monad:

def loadKeys: StateAction[Unit] = StateT.modifyF { s => 
  Task(s.copy(keys = List.fill(10)(Random.nextString(5))))
}

def pruneSomeKeys: StateAction[Unit] = StateT.modifyF { s =>
  Task(s.copy(keys = s.keys take 3))
}

def pruneMoreKeys: StateAction[Unit] = StateT.modifyF { s =>
  Task(s.copy(keys = s.keys drop 1))
}

def createDF(implicit spark: SparkSession): StateAction[Unit] = 
  StateT.modifyF { s =>
    Task(s.copy(df = s.keys.toDF))
  }

def transformDF(implicit spark: SparkSession): StateAction[Unit] = 
  StateT.modifyF { s =>
    Task(s.copy(df = s.df limit 3))
  }

The program composition is now much cleaner:

val stateProgram: StateAction[Unit] = for {
  _ <- loadKeys
  _ <- pruneSomeKeys
  _ <- pruneMoreKeys
  _ <- createDF
  _ <- transformDF
} yield ()
// stateProgram: StateAction[Unit] = cats.data.IndexedStateT@24654e04

Await.result(
  stateProgram
    .run(JobState(Nil, session.sqlContext.emptyDataFrame))
    .runAsync,
  30.seconds
)
// res25: (JobState, Unit) = (JobState(List(둁捎ዙ쓝텊, 嗵栓蛤綆⮺),[value: string]),())

The only issue that we might take with this design is that we shoved all the data into the state, while the keys aren’t needed when running transformDF. Additionally, we had to introduce an artificial empty state; this goes against a good practice of making illegal states unrepresentable.

We can use IndexedStateT to model this more accurately; this is a data type similar to StateT that differs by having different types for input and output states. Formally, it is a function of the form SA => F[(SB, A)], where SA and SB represent the input and output states.
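
The same hand-rolled approach shows how the state type can change from step to step. Again, this is only a sketch without the effect F, and IndexedStateSketch, IxState, transition and the three state types are hypothetical names:

```scala
object IndexedStateSketch {
  // SA => (SB, A): the input and output state types may differ
  final case class IxState[SA, SB, A](run: SA => (SB, A)) {
    def map[B](f: A => B): IxState[SA, SB, B] = IxState { sa =>
      val (sb, a) = run(sa)
      (sb, f(a))
    }

    // The output state type of this step must match the input of the next
    def flatMap[SC, B](f: A => IxState[SB, SC, B]): IxState[SA, SC, B] =
      IxState { sa =>
        val (sb, a) = run(sa)
        f(a).run(sb)
      }
  }

  def transition[SA, SB](f: SA => SB): IxState[SA, SB, Unit] =
    IxState(sa => (f(sa), ()))

  case object Start
  final case class Keys(keys: List[String])
  final case class Count(n: Int)

  val load: IxState[Start.type, Keys, Unit] =
    transition(_ => Keys(List("a", "b", "c")))
  val prune: IxState[Keys, Keys, Unit] =
    transition(k => Keys(k.keys.take(2)))
  val count: IxState[Keys, Count, Unit] =
    transition(k => Count(k.keys.size))

  val program: IxState[Start.type, Count, Unit] = for {
    _ <- load
    _ <- prune
    _ <- count
  } yield ()
}
```

IndexedStateSketch.program.run(IndexedStateSketch.Start) returns (Count(2), ()). Swapping prune and count would not compile, because count leaves a Count behind while prune expects Keys. IndexedStateT in cats follows the same idea, with the result additionally wrapped in F.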

To use it, we’ll define separate states for our program:

case object Empty
case class ProcessingKeys(keys: List[String])
case class ProcessingDF(df: DataFrame)
case class Done(df: DataFrame)

And we will redefine our functions again to model how the state transitions:

def loadKeys: IndexedStateT[Task, Empty.type, ProcessingKeys, Unit] = 
  IndexedStateT.setF {
    Task(ProcessingKeys(List.fill(10)(Random.nextString(5))))
  }

def pruneSomeKeys: StateT[Task, ProcessingKeys, Unit] = 
  StateT.modifyF { s =>
    Task(s.copy(keys = s.keys take 3))
  }

def pruneMoreKeys: StateT[Task, ProcessingKeys, Unit] = 
  StateT.modifyF { s =>
    Task(s.copy(keys = s.keys drop 1))
  }

def createDF(implicit spark: SparkSession): IndexedStateT[Task, ProcessingKeys, ProcessingDF, Unit] = 
  IndexedStateT.modifyF { s =>
    Task(ProcessingDF(s.keys.toDF))
  }

def transformDF(implicit spark: SparkSession): IndexedStateT[Task, ProcessingDF, Done, Unit] = 
  IndexedStateT.modifyF { s =>
    Task(Done(s.df limit 3))
  }

Note how functions that stay within the same state type are still using the plain StateT. This is because StateT is actually an alias for IndexedStateT[F, S, S, A] - a state transition that does not change the state type.

We can now launch our program with an accurate empty, uninitialized state and get back the Done state:

val indexedStateProgram: IndexedStateT[Task, Empty.type, Done, Unit] = for {
  _ <- loadKeys
  _ <- pruneSomeKeys
  _ <- pruneMoreKeys
  _ <- createDF
  _ <- transformDF
} yield ()
// indexedStateProgram: cats.data.IndexedStateT[monix.eval.Task,Empty.type,Done,Unit] = cats.data.IndexedStateT@2fcf5261

Await.result(
  indexedStateProgram
    .run(Empty)
    .runAsync,
  30.seconds
)
// res30: (Done, Unit) = (Done([value: string]),())

Note that we would get a compilation failure if we mix up the order of the actions:

scala> val fail: IndexedStateT[Task, Empty.type, Done, Unit] = for {
     |   _ <- loadKeys
     |   _ <- pruneSomeKeys
     |   _ <- createDF
     |   _ <- pruneMoreKeys
     |   _ <- transformDF
     | } yield ()
<console>:64: error: type mismatch;
 found   : cats.data.IndexedStateT[monix.eval.Task,ProcessingDF,Done,Unit]
 required: cats.data.IndexedStateT[monix.eval.Task,ProcessingKeys,?,?]
         _ <- transformDF
           ^
<console>:63: error: polymorphic expression cannot be instantiated to expected type;
 found   : [B, SC]cats.data.IndexedStateT[monix.eval.Task,ProcessingKeys,SC,B]
 required: cats.data.IndexedStateT[monix.eval.Task,ProcessingDF,Done,Unit]
         _ <- pruneMoreKeys
           ^

If we want to combine the timing functionality from the previous section, that’s also entirely possible; we’d need to define a monad stack of IndexedStateT[TimedTask, SA, SB, A] and define the timed combinators for this stack. To be honest, though, working with concrete transformers in Scala is pretty boilerplate-heavy. A much more ergonomic approach is using tagless final to abstract over the transformers; that’s a subject for an entirely different post, though.

Summary

Through these few combinators we’ve defined, we saw how we can easily and lightly wrap some of Spark’s API to regain properties we like about programs written with functional programming. It is admittedly not a complete approach: type safety is still missing from many of the DataFrame methods and nulls may creep in here and there. But it’s definitely an improvement!

I hope you found these examples useful. Feel free to hit me up on Twitter if you’ve got any questions.

This post was typechecked by tut on Scala 2.11.11 with Spark 2.2.0, cats 1.0.0-RC1 and Monix 3.0.0-M2.