Saturday 1 July 2017

Spark union & column order issue

Edit: the demonstration code is also on GitHub.

I quite like Spark, though it has some peculiar gotchas, perhaps more than most other big data tools I've used. For example, a long time ago I came across some code that had a list of ordinary-looking transformations on Datasets, but each of them ended with a .map(identity). What on earth was the point of that?

Well, it turns out that the union() method of Spark Datasets matches columns by position, not by name. A Dataset is backed by the same tabular structure as a DataFrame: it does not store case class instances, but columns in a specific order. When you read from a Dataset, Spark turns each Row into the appropriate case class using the column names, regardless of the column order underneath, so a Dataset with reordered columns still looks perfectly fine. union(), however, ignores the names and simply pairs the first column with the first, the second with the second, and so on.
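The positional behaviour is easiest to see with two plain DataFrames whose columns have the same names in a different order. This is a minimal sketch, not part of the original demonstration: the `left` and `right` names and their values are made up for illustration, and a local SparkSession is assumed (in spark-shell, getOrCreate() returns the existing session).

```scala
import org.apache.spark.sql.SparkSession

// Assumption: a local session for illustration only.
val spark = SparkSession.builder().master("local[1]").appName("union-by-position").getOrCreate()
import spark.implicits._

// Same column names, opposite order. Both columns are strings,
// so union() accepts the pair without complaint.
val left  = Seq(("id-1", "name-1")).toDF("id", "name")
val right = Seq(("name-2", "id-2")).toDF("name", "id")

// union() pairs columns by position and keeps the column names of the
// left-hand side, so right's "name" value lands in the "id" column.
val unioned = left.union(right)
unioned.show()
```

Because both columns have the same type here, nothing fails: the values just end up silently swapped, which is arguably worse than an exception.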

An example to illustrate. Say we have a case class with some counter value:

import org.apache.spark.sql.Dataset // needed for the Dataset[...] type annotations

case class Thing(id: String, count: Long, name: String)
val things1: Dataset[Thing] = sc.parallelize(Seq(
  Thing("thing1", 123, "some_thing"),
  Thing("thing2", 101, "another_thing"),
  Thing("thing2", 100, "another_thing")
)).toDS
val things2: Dataset[Thing] = sc.parallelize(Seq(
  Thing("foo", 5, "different_thing"),
  Thing("foo", 15, "different_thing"),
  Thing("bar", 6, "whatever_thing")
)).toDS

things1.union(things2).show // works as expected

So far, so good. But say we want to add up the counter values. A groupBy puts the grouping columns first and the aggregated columns after them, so depending on how we do the aggregation, the columns might end up in a different order, even though the Dataset has the same type:

val agg1: Dataset[Thing] = things1.groupBy($"id", $"name").agg(sum("count").as("count")).as[Thing]

scala> agg1.show
+------+-------------+-----+
|    id|         name|count|
+------+-------------+-----+
|thing2|another_thing|  201|
|thing1|   some_thing|  123|
+------+-------------+-----+

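Now trying to union the aggregated things with the original things will fail, even though both are of type Dataset[Thing]. The reason is the different column order underneath: union() pairs agg1's name column with things2's count column and vice versa, so, judging by the error message, the combined count column ends up as a string that cannot be read back as a Long (the error message is not the clearest):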

scala> agg1.union(things2).show
org.apache.spark.sql.AnalysisException: Cannot up cast `count` from string to bigint as it may truncate
The type path of the target object is:
- field (class: "scala.Long", name: "count")
- root class: "Thing"
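One quick way to confirm that column order is the culprit is to print the column names of both Datasets. A minimal sketch, assuming a local SparkSession (in spark-shell, getOrCreate() returns the existing session); the `things` and `agg` names are only for this illustration:

```scala
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.functions.sum

case class Thing(id: String, count: Long, name: String)

// Assumption: a local session for illustration only.
val spark = SparkSession.builder().master("local[1]").appName("column-order").getOrCreate()
import spark.implicits._

val things: Dataset[Thing] = Seq(
  Thing("thing1", 123, "some_thing"),
  Thing("thing2", 101, "another_thing")
).toDS()

// Same aggregation as in the text: grouping columns come first, the sum last.
val agg: Dataset[Thing] =
  things.groupBy($"id", $"name").agg(sum("count").as("count")).as[Thing]

// Both are Dataset[Thing], but the underlying column order differs.
println(things.columns.mkString(", ")) // id, count, name
println(agg.columns.mkString(", "))    // id, name, count
```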

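The easiest workaround is to add a .map(identity) to the end of each such aggregation. The map() deserializes every row into a Thing and serializes it again, which writes the columns back in the order of the case class fields, at the cost of that extra round trip. After this everything works as expected: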

scala> agg1.map(identity).union(things2).show
+------+-----+---------------+
|    id|count|           name|
+------+-----+---------------+
|thing2|  201|  another_thing|
|thing1|  123|     some_thing|
|   foo|    5|different_thing|
|   foo|   15|different_thing|
|   bar|    6| whatever_thing|
+------+-----+---------------+
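An alternative that avoids passing every row through the case class is to reorder the columns explicitly with a select. The sketch below is only one way to do it: `alignColumns` is a hypothetical helper, not part of Spark, and it assumes the column names are plain identifiers that match the case class fields. A local SparkSession is assumed as well.

```scala
import org.apache.spark.sql.{Dataset, Encoder, SparkSession}
import org.apache.spark.sql.functions.{col, sum}

case class Thing(id: String, count: Long, name: String)

// Assumption: a local session for illustration only.
val spark = SparkSession.builder().master("local[1]").appName("align-columns").getOrCreate()
import spark.implicits._

// Hypothetical helper (not a Spark API): select the columns in the order of the
// encoder's schema, which for a case class follows the order of its fields.
def alignColumns[T: Encoder](ds: Dataset[T]): Dataset[T] =
  ds.select(implicitly[Encoder[T]].schema.fieldNames.map(col): _*).as[T]

val things1: Dataset[Thing] = Seq(
  Thing("thing1", 123, "some_thing"),
  Thing("thing2", 101, "another_thing"),
  Thing("thing2", 100, "another_thing")
).toDS()
val things2: Dataset[Thing] = Seq(
  Thing("foo", 5, "different_thing"),
  Thing("foo", 15, "different_thing"),
  Thing("bar", 6, "whatever_thing")
).toDS()

val agg1: Dataset[Thing] =
  things1.groupBy($"id", $"name").agg(sum("count").as("count")).as[Thing]

// After aligning, both sides have the columns id, count, name, so the
// positional union() lines them up correctly.
val combined: Dataset[Thing] = alignColumns(agg1).union(things2)
combined.show()
```

Whether this is preferable to .map(identity) is a judgment call: it keeps the work in column operations, but it is one more helper to maintain.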

Note that this is a known issue, tracked as SPARK-21109. A method to do a union by name will be added in the future, as detailed in SPARK-21043.
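The union by name described in SPARK-21043 has since shipped as Dataset.unionByName, in Spark 2.3.0. A minimal sketch of how it applies to the example above, assuming Spark 2.3 or later and a local SparkSession:

```scala
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.functions.sum

case class Thing(id: String, count: Long, name: String)

// Assumption: a local session for illustration only.
val spark = SparkSession.builder().master("local[1]").appName("union-by-name").getOrCreate()
import spark.implicits._

val things1: Dataset[Thing] = Seq(
  Thing("thing1", 123, "some_thing"),
  Thing("thing2", 101, "another_thing"),
  Thing("thing2", 100, "another_thing")
).toDS()
val things2: Dataset[Thing] = Seq(
  Thing("foo", 5, "different_thing"),
  Thing("foo", 15, "different_thing"),
  Thing("bar", 6, "whatever_thing")
).toDS()

// Column order here is id, name, count, unlike things2 (id, count, name).
val agg1: Dataset[Thing] =
  things1.groupBy($"id", $"name").agg(sum("count").as("count")).as[Thing]

// unionByName() matches columns by name; the result keeps the column
// order of the left-hand side (agg1).
val byName: Dataset[Thing] = agg1.unionByName(things2)
byName.show()
```

With this, no .map(identity) is needed before the union, although the resulting Dataset still carries agg1's column order.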