Saturday 13 August 2016

Normalised Discounted Cumulative Gain (NDCG) for Spark DataFrames, using a UserDefinedAggregateFunction

(Edit: The code is available on GitHub.)

Recently I was messing around with a small free-time project to improve search results with machine learning. As part of this I needed a way of evaluating the quality of a given set of search results. An important goal for any search engine is to display the most relevant results first, and as always, one needs a metric to measure how well that goal is being met.

Suitable evaluation metrics for search results include, for example, Normalised Discounted Cumulative Gain (NDCG) and Mean Average Precision (MAP). For this project I decided to go with NDCG, since it intuitively felt more suitable given how sparse accurate relevance scores were in the data set I was working with. I used Spark with DataFrames. Of course Spark already has an NDCG implementation built in (`RankingMetrics` in MLlib), but it works on RDDs rather than directly on DataFrames, and I felt like educating myself in DataFrame usage and extensibility.

What is NDCG then? Well, suppose you're given a bunch of search result data, something like this:

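// Assumes spark-shell, which provides `sc` and `sqlContext`
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._  // sum, log etc. used in the queries below
import sqlContext.implicits._            // enables the $"column" syntax

val schema = new StructType(Array(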
  StructField("searchId", LongType),
  StructField("timestamp", LongType),
  StructField("resultUrl", StringType),
  StructField("position", IntegerType),
  StructField("clicked", IntegerType),
  StructField("converted", IntegerType),
  StructField("relevanceScore", DoubleType)))
val data = sc.parallelize(Seq(
  Row(123L, 1471097840569L, "https://some.site/",        1, 1, 0, 1.28),
  Row(123L, 1471097840569L, "https://another.site/",     2, 0, 0, 2.3001),
  Row(123L, 1471097840569L, "https://yet.another.site/", 3, 0, 0, 0.792),
  Row(123L, 1471097840569L, "https://a.relevant.site/",  4, 1, 1, 1.51),
  Row(456L, 1471102902205L, "https://another.search/",   1, 0, 0, 0.07),
  Row(456L, 1471102902205L, "https://another.result/",   2, 0, 0, 0.04),
  Row(456L, 1471102902205L, "https://another.site/",     3, 1, 0, 0.02)
))
val df = sqlContext.createDataFrame(data, schema)

Now the non-normalised Discounted Cumulative Gain is easy to calculate directly:

df.groupBy($"searchId").agg(sum($"relevanceScore"/log(2.0, $"position"+1)).as("DCG")).show

// +--------+------------------+
// |searchId|               DCG|
// +--------+------------------+
// |     456|0.1052371901428583|
// |     123|3.7775231288805324|
// +--------+------------------+
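The query above computes DCG as the sum of relevance / log2(position + 1) over each search's results. As a sanity check of that formula, here is a minimal plain-Scala sketch (no Spark involved) that reproduces the two DCG values from the example data. The `DcgCheck` object and its `dcg` helper are hypothetical names used only for illustration.

```scala
// Plain-Scala sketch (no Spark): DCG = sum of relevance / log2(position + 1).
object DcgCheck {
  private def log2(x: Double): Double = math.log(x) / math.log(2.0)

  // (position, relevance) pairs for one search
  def dcg(results: Seq[(Int, Double)]): Double =
    results.map { case (pos, rel) => rel / log2(pos + 1.0) }.sum
}

// The two searches from the example DataFrame
val search123 = Seq(1 -> 1.28, 2 -> 2.3001, 3 -> 0.792, 4 -> 1.51)
val search456 = Seq(1 -> 0.07, 2 -> 0.04, 3 -> 0.02)

println(DcgCheck.dcg(search123)) // about 3.7775
println(DcgCheck.dcg(search456)) // about 0.1052
```

Note how search 123 gets a much larger DCG simply because its relevance scores are larger overall.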

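The problem, of course, is that because the DCG is not normalised, it is difficult to use it to compare different searches: a search with more results, or with higher relevance scores overall, gets a higher DCG regardless of how well its results are ordered. To solve this, we can normalise the DCG by calculating the ideal (i.e. best possible) DCG for each set of search results, which is the DCG obtained when the results are sorted by decreasing relevance, and then dividing the actual DCG by it. This gives the NDCG (normalised DCG), which for non-negative relevance scores lies between 0 and 1.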
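Written out, with rel_i the relevance of the result at rank i and n results per search, the definitions are as below. IDCG is the DCG of the same results reordered by decreasing relevance; the worked numbers are for search 123 from the example data, whose DCG of about 3.7775 is shown in the table above.

```latex
\mathrm{DCG} = \sum_{i=1}^{n} \frac{rel_i}{\log_2(i+1)}
\qquad
\mathrm{NDCG} = \frac{\mathrm{DCG}}{\mathrm{IDCG}}

\mathrm{IDCG}_{123} = \frac{2.3001}{\log_2 2} + \frac{1.51}{\log_2 3} + \frac{1.28}{\log_2 4} + \frac{0.792}{\log_2 5}
  \approx 2.3001 + 0.9527 + 0.6400 + 0.3411 = 4.2339

\mathrm{NDCG}_{123} \approx \frac{3.7775}{4.2339} \approx 0.8922
```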

Unlike DCG, the NDCG is difficult to calculate directly with SQL or the SQL-like language supported by Spark DataFrames. Luckily, defining your own aggregate function for DataFrames is easy with a UserDefinedAggregateFunction:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

object NDCG extends UserDefinedAggregateFunction {
  // Input: each result's position and relevance score
  // (the IntegerType position column from the example is cast to DoubleType)
  def inputSchema = new StructType()
    .add("position", DoubleType)
    .add("relevance", DoubleType)
  // Buffer: every position and relevance seen so far in the group
  def bufferSchema = new StructType()
    .add("positions", ArrayType(DoubleType, false))
    .add("relevances", ArrayType(DoubleType, false))
  def dataType = DoubleType
  def deterministic = true
  def initialize(buffer: MutableAggregationBuffer) = {
    buffer(0) = IndexedSeq[Double]()
    buffer(1) = IndexedSeq[Double]()
  }
  // Append one row to the buffer, skipping rows with a null position or relevance
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    if(!input.isNullAt(0) && !input.isNullAt(1)) {
      val (position, relevance) = (input.getDouble(0), input.getDouble(1))
      buffer(0) = buffer.getAs[IndexedSeq[Double]](0) :+ position
      buffer(1) = buffer.getAs[IndexedSeq[Double]](1) :+ relevance
    }
  }
  // Combine two partial buffers by concatenating their arrays
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    if(!buffer2.isNullAt(0) && !buffer2.isNullAt(1)) {
      buffer1(0) = buffer1.getAs[IndexedSeq[Double]](0) ++
                   buffer2.getAs[IndexedSeq[Double]](0)
      buffer1(1) = buffer1.getAs[IndexedSeq[Double]](1) ++
                   buffer2.getAs[IndexedSeq[Double]](1)
    }
  }
  // Sum of score / log2(rank + 1) over scores already in ranked order. The rank starts
  // at 1 (undiscounted, since log2(2) = 1); non-positive scores add nothing but use up a rank.
  private def totalGain(scores: Seq[(Double, Double)]): Double = {
    val (_, gain) = scores.foldLeft((1, 0.0))(
      (fa, tuple) => tuple match { case (_, score) =>
        if(score <= 0.0) (fa._1+1, fa._2)
        else if(fa._1 == 1) (fa._1+1, fa._2+score)
        else (fa._1+1, fa._2+score/(Math.log(fa._1+1)/Math.log(2.0)))
      })
    gain
  }
  // Order the results by position, build the ideal ordering (positive relevances,
  // highest first) and return actual gain / ideal gain, or 0.0 if there is no ideal gain
  def evaluate(buffer: Row) = {
    val (positions, relevances) =
      (buffer.getAs[IndexedSeq[Double]](0), buffer.getAs[IndexedSeq[Double]](1))
    val scores = (positions, relevances).zipped.toList.sorted
    val ideal = scores.map(_._2).filter(_ > 0).sortWith(_ > _)
      .zipWithIndex.map { case (s, i0) => (i0 + 1.0, s) }
    val (thisScore, idealScore) = (totalGain(scores), totalGain(ideal))
    if(idealScore == 0.0) 0.0 else thisScore / idealScore
  }
}

And using it is easy:

df.groupBy($"searchId").agg(NDCG($"position", $"relevanceScore").as("NDCG")).show

// +--------+------------------+
// |searchId|              NDCG|
// +--------+------------------+
// |     456|               1.0|
// |     123|0.8922089188046599|
// +--------+------------------+
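To see where the 0.8922 for search 123 comes from, here is a minimal plain-Scala sketch that mirrors the logic of `totalGain` and `evaluate` above without Spark. `NdcgCheck` is a hypothetical name, and positions are taken as `Int` rather than `Double` for simplicity; ordering by position alone is equivalent to the UDAF's tuple sort as long as positions are unique.

```scala
// Plain-Scala mirror of the UDAF's totalGain/evaluate (no Spark); positions as Int for simplicity.
object NdcgCheck {
  private def log2(x: Double): Double = math.log(x) / math.log(2.0)

  // Relevances in ranked order; rank r (1-based) is discounted by log2(r + 1).
  // Non-positive relevances add nothing but still take up a rank, as in totalGain.
  def totalGain(ranked: Seq[Double]): Double =
    ranked.zipWithIndex.collect { case (rel, i0) if rel > 0.0 => rel / log2(i0 + 2.0) }.sum

  // (position, relevance) pairs for one search, in any order.
  def ndcg(results: Seq[(Int, Double)]): Double = {
    val actual = results.sortBy(_._1).map(_._2)                  // order shown to the user
    val ideal  = results.map(_._2).filter(_ > 0).sortWith(_ > _) // best possible order
    val idealGain = totalGain(ideal)
    if (idealGain == 0.0) 0.0 else totalGain(actual) / idealGain
  }
}

println(NdcgCheck.ndcg(Seq(1 -> 1.28, 2 -> 2.3001, 3 -> 0.792, 4 -> 1.51))) // about 0.8922
println(NdcgCheck.ndcg(Seq(1 -> 0.07, 2 -> 0.04, 3 -> 0.02)))              // 1.0: already ideally ordered
```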

The code may not be production-quality, but it works as expected. The idea is simple: for each group of rows, each containing at least a position and a relevance, the aggregate function collects these values into arrays. Once the last row has been read, it calculates both the actual and the ideal score from the arrays and returns their quotient. Note that the position values are only used to order the results; the discount is based on each result's rank within the group, which matches the DCG query above as long as positions run 1, 2, 3 and so on without gaps. A typical search engine returns tens or perhaps hundreds of results for each query, so the temporary arrays never grow to an unmanageable size and performance is good.
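Because `merge` simply concatenates arrays and `evaluate` sorts by position only at the very end, the result should not depend on how Spark splits a group's rows across partitions or in which order partial buffers are merged. The sketch below models that buffer lifecycle in plain Scala to illustrate the point. `Buffer`, `NdcgBuffer` and `fill` are hypothetical names, not Spark API, and SQL nulls are modelled as `None`.

```scala
// Hypothetical plain-Scala model of the UDAF buffer (not Spark API): two parallel arrays.
final case class Buffer(positions: Vector[Int], relevances: Vector[Double])

object NdcgBuffer {
  val empty: Buffer = Buffer(Vector.empty, Vector.empty)

  // Like update(): append one row, skipping it if either value is missing (SQL null -> None).
  def update(b: Buffer, position: Option[Int], relevance: Option[Double]): Buffer =
    (position, relevance) match {
      case (Some(p), Some(r)) => Buffer(b.positions :+ p, b.relevances :+ r)
      case _                  => b
    }

  // Like merge(): concatenate two partial buffers.
  def merge(b1: Buffer, b2: Buffer): Buffer =
    Buffer(b1.positions ++ b2.positions, b1.relevances ++ b2.relevances)

  private def log2(x: Double): Double = math.log(x) / math.log(2.0)

  // Sum of relevance / log2(rank + 1) over positive relevances, rank starting at 1.
  private def totalGain(ranked: Seq[Double]): Double =
    ranked.zipWithIndex.collect { case (rel, i0) if rel > 0.0 => rel / log2(i0 + 2.0) }.sum

  // Like evaluate(): ordering happens only here, so row and merge order do not matter.
  def evaluate(b: Buffer): Double = {
    val actual = b.positions.zip(b.relevances).sortBy(_._1).map(_._2)
    val idealGain = totalGain(b.relevances.filter(_ > 0).sortWith(_ > _))
    if (idealGain == 0.0) 0.0 else totalGain(actual) / idealGain
  }
}

// Fill a fresh buffer from (position, relevance) rows, as one partition would.
def fill(rows: Seq[(Int, Double)]): Buffer =
  rows.foldLeft(NdcgBuffer.empty) { case (b, (p, r)) => NdcgBuffer.update(b, Some(p), Some(r)) }

// Search 123's rows, deliberately out of order
val rows123 = Seq(4 -> 1.51, 2 -> 2.3001, 1 -> 1.28, 3 -> 0.792)
val single = fill(rows123)                               // all rows in one partition
val (partA, partB) = rows123.splitAt(2)
val merged = NdcgBuffer.merge(fill(partB), fill(partA))  // two partitions, merged "backwards"

println(NdcgBuffer.evaluate(single) == NdcgBuffer.evaluate(merged)) // true
```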

And of course it is easy to use any other formula for the relevance scores, provided you have the data. For example, counting a click as 1 and a conversion as 3:

df.groupBy($"searchId").agg(NDCG($"position", $"clicked"+$"converted".cast(DoubleType)*3.0).as("NDCG")).show
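As an illustration of what that formula does on the example data, here is a plain-Scala sketch (a hypothetical `ClickNdcg` object, no Spark) that computes relevance as clicked + 3 × converted and then applies the same NDCG logic as the UDAF. In search 123 the only converted result sits at position 4, so the score drops well below 1.

```scala
// Plain-Scala sketch: relevance = clicked + 3 * converted, then NDCG as in the UDAF.
object ClickNdcg {
  private def log2(x: Double): Double = math.log(x) / math.log(2.0)

  // Sum of relevance / log2(rank + 1) over positive relevances, rank starting at 1.
  private def totalGain(ranked: Seq[Double]): Double =
    ranked.zipWithIndex.collect { case (rel, i0) if rel > 0.0 => rel / log2(i0 + 2.0) }.sum

  // rows: (position, clicked, converted) for one search
  def ndcg(rows: Seq[(Int, Int, Int)]): Double = {
    val rels = rows.sortBy(_._1).map { case (_, clicked, converted) => clicked + converted * 3.0 }
    val idealGain = totalGain(rels.filter(_ > 0).sortWith(_ > _))
    if (idealGain == 0.0) 0.0 else totalGain(rels) / idealGain
  }
}

val clicks123 = Seq((1, 1, 0), (2, 0, 0), (3, 0, 0), (4, 1, 1))
val clicks456 = Seq((1, 0, 0), (2, 0, 0), (3, 1, 0))

println(ClickNdcg.ndcg(clicks123)) // roughly 0.588: the conversion is ranked last
println(ClickNdcg.ndcg(clicks456)) // about 0.5: the only click is at position 3
```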

The code is also available on GitHub.

Wednesday 16 March 2016

Spark & streaming: first impressions

I recently participated in a 24-hour company hackathon with two colleagues. We used Spark Streaming to do near real-time processing of production data, with plain old MySQL as a real-time "session" data store. We even managed to bolt on a machine learning algorithm using Spark MLlib, courtesy of yours truly.

Spark turned out to be amazingly easy to use, and it performed really well for our use case. We used ten-second micro-batches, which gave us near real-time data processing, and with statsd, Graphite and d3.js we also got near real-time metrics and statistics.

For the processing, our division of labour was fairly standard, with three separate Spark workflows:

1) The main processor continuously reads production data from a suitable data source, in real time, and processes it in ten-second chunks. Data "sessions" are created and updated, using a single MySQL table for storage. As soon as we have a minimum amount of data available, we also calculate a prediction for the future "outcome" of each data bunch, based on a previously learned model (see below). Real-time stats and metrics are sent to statsd as data comes in.

2) The history processor runs every X seconds. It connects to MySQL, selects all recently completed data bunches from the main MySQL table and processes them. A completed bunch is simply one whose data has not been updated for some specific amount of time, say 15 minutes.

For each completed bunch the final stats (sums, counts, etc.) are calculated and sent on to statsd, and the bunch is then moved into the history table. This way we can keep the live table at a manageable size. In addition, for each finished bunch we also check how well our earlier prediction held up, i.e. whether our prediction matched the actual final outcome.

3) The model builder runs every Z minutes. It connects to MySQL, selects a random sample of recent historical data bunches from the history table, and trains a machine learning model on them. We used a random forest for our predictions. Basically, I wrote a bunch of code to turn our historical samples into vectors of doubles (taking categorical variables into account as well), configured the MLlib random forest learner with proper parameters, and that was that. After the model is trained, it is simply saved to disk and then used in step 1) by the main processor. With very little work I was able to get a prediction accuracy that was many times better than random guessing, which made it clearly worthwhile.
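The "completed bunch" rule of the history processor and its prediction check from step 2) are simple enough to sketch in plain Scala. This is not the hackathon code: the `Bunch` fields, the `HistoryRules` object and the outcome labels are hypothetical, and in the pipeline described above the bunches are selected from MySQL rather than filtered in memory. It only illustrates the idle-timeout rule, using the 15-minute example, and the accuracy check.

```scala
// Hypothetical model of one row in the live MySQL "session" table.
final case class Bunch(id: Long, lastUpdatedMillis: Long, prediction: Option[String], outcome: String)

object HistoryRules {
  // A bunch counts as completed once it has not been updated for this long (15 minutes here).
  val idleTimeoutMillis: Long = 15L * 60L * 1000L

  // Split bunches into (completed, stillLive) as of `nowMillis`.
  def splitCompleted(bunches: Seq[Bunch], nowMillis: Long): (Seq[Bunch], Seq[Bunch]) =
    bunches.partition(b => nowMillis - b.lastUpdatedMillis >= idleTimeoutMillis)

  // Fraction of completed bunches whose earlier prediction matched the final outcome;
  // bunches that never received a prediction are left out. None if nothing can be judged.
  def predictionAccuracy(completed: Seq[Bunch]): Option[Double] = {
    val hits = completed.collect { case Bunch(_, _, Some(predicted), outcome) => predicted == outcome }
    if (hits.isEmpty) None else Some(hits.count(identity).toDouble / hits.size)
  }
}

val minute = 60L * 1000L
val now = 1000L * minute // arbitrary "current time" for the example
val bunches = Seq(
  Bunch(1L, now - 20 * minute, Some("A"), "A"), // completed, prediction held up
  Bunch(2L, now - 16 * minute, Some("A"), "B"), // completed, prediction was wrong
  Bunch(3L, now - 1 * minute, Some("B"), "B"),  // still live: updated a minute ago
  Bunch(4L, now - 30 * minute, None, "B")       // completed, but never got a prediction
)

val (completed, live) = HistoryRules.splitCompleted(bunches, now)
println(completed.map(_.id))                        // List(1, 2, 4)
println(live.map(_.id))                             // List(3)
println(HistoryRules.predictionAccuracy(completed)) // Some(0.5)
```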

All in all, my initial impressions of Spark were very positive. It made building a relatively non-trivial pipeline like the one above super easy. Our processing doesn't need any joins or other complicated operations, which does make things a bit easier; nevertheless, with Spark you get scaling, redundancy and failover out of the box, which will help a lot with future-proofing. The ease of use of MLlib, and in general the attention paid to making all the tools and libraries work properly and scale, is just really nice. Spark Streaming works very nicely as well, and according to several smart people it seems to be a very good solution for streaming in general.

MySQL, in contrast, works really well on just one beefy machine (if you have enough RAM and a proper SSD), in our case handling up to 1000 read-write requests per second with a live table of around 300-600k items at any given time. But high availability and failover are a bit harder to do, as is scaling. Since this was just a random hackathon, we "solved" those problems by simply ignoring them. The value of MySQL was that we could have multiple indices on a table and thus do lookups based both on the key and on the recency of the latest data, which is more annoying to do with something like Cassandra, or indeed with Spark itself. So MySQL worked well for us in this limited use case, but future-proofing it would be noticeably harder than with Spark.

Need to do more things with Spark in the future.