Saturday 13 August 2016

Normalised Discounted Cumulative Gain (NDCG) for Spark DataFrames, using a UserDefinedAggregateFunction

(Edit: The code is available on GitHub.)

Recently I was messing around with a small free-time project to improve search results with machine learning. As part of this I needed a way of evaluating the quality of a given set of search results. An important goal for any search engine is to display the most relevant search results first, and, as always, one needs a metric to measure how well that goal is being met.

Suitable evaluation metrics for search results include, for example, Normalised Discounted Cumulative Gain and Mean Average Precision. For this project I decided to go with the former (NDCG), since it intuitively felt more suitable given how sparse accurate relevance scores were in the data set I was working with. I used Spark with DataFrames. Spark MLlib does of course already include an NDCG implementation (in RankingMetrics), but it works on RDDs rather than directly on DataFrames, and I felt like educating myself in DataFrame usage and extensibility.

What is NDCG then? Well, suppose you're given a bunch of search results data, something like this:

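import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

// Assumes a spark-shell session, where sc and sqlContext are predefined.
val schema = new StructType(Array(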
  StructField("searchId", LongType),
  StructField("timestamp", LongType),
  StructField("resultUrl", StringType),
  StructField("position", IntegerType),
  StructField("clicked", IntegerType),
  StructField("converted", IntegerType),
  StructField("relevanceScore", DoubleType)))
val data = sc.parallelize(Seq(
  Row(123L, 1471097840569L, "https://some.site/",        1, 1, 0, 1.28),
  Row(123L, 1471097840569L, "https://another.site/",     2, 0, 0, 2.3001),
  Row(123L, 1471097840569L, "https://yet.another.site/", 3, 0, 0, 0.792),
  Row(123L, 1471097840569L, "https://a.relevant.site/",  4, 1, 1, 1.51),
  Row(456L, 1471102902205L, "https://another.search/",   1, 0, 0, 0.07),
  Row(456L, 1471102902205L, "https://another.result/",   2, 0, 0, 0.04),
  Row(456L, 1471102902205L, "https://another.site/",     3, 1, 0, 0.02)
))
val df = sqlContext.createDataFrame(data, schema)

Now the non-normalised Discounted Cumulative Gain is easy to calculate directly:

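import org.apache.spark.sql.functions._
import sqlContext.implicits._  // for the $"column" syntax

// DCG per search: each relevance divided by log2(position + 1), then summed.
df.groupBy($"searchId").agg(sum($"relevanceScore"/log(2.0, $"position"+1)).as("DCG")).show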

// +--------+------------------+
// |searchId|               DCG|
// +--------+------------------+
// |     456|0.1052371901428583|
// |     123|3.7775231288805324|
// +--------+------------------+
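To double-check these numbers without a Spark cluster, here is a minimal plain-Scala sketch of the same sum. The object name DcgCheck is illustrative only, and the (position, relevanceScore) pairs are copied from the sample rows above.

```scala
// Plain-Scala sanity check of the DCG figures (no Spark needed).
// Each search is a list of (position, relevanceScore) pairs from the sample rows.
object DcgCheck {
  private def log2(x: Double): Double = math.log(x) / math.log(2.0)

  // DCG = sum over results of relevance / log2(position + 1),
  // the same expression the groupBy/agg query evaluates.
  def dcg(results: Seq[(Int, Double)]): Double =
    results.map { case (pos, rel) => rel / log2(pos + 1.0) }.sum

  val search123: Seq[(Int, Double)] = Seq(1 -> 1.28, 2 -> 2.3001, 3 -> 0.792, 4 -> 1.51)
  val search456: Seq[(Int, Double)] = Seq(1 -> 0.07, 2 -> 0.04, 3 -> 0.02)
}
```

DcgCheck.dcg(DcgCheck.search123) should come out at about 3.7775 and DcgCheck.dcg(DcgCheck.search456) at about 0.1052, matching the DataFrame output.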

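The problem, of course, is that DCG is not normalised. Its value depends on how many results a search returned and on the scale of their relevance scores, so it is difficult to use it to compare one set of search results with another. To solve this, we can normalise the DCG score. For each set of search results we calculate the ideal (i.e. best possible) score, which is the DCG obtained when the same results are sorted by descending relevance, and then divide the actual DCG by it. This gives the NDCG (normalised DCG), which for non-negative relevance scores lies between 0 and 1.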
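For reference, here are the definitions written out as formulas, in the form used by the query above and the code below. Take rel_i to be the relevance of the result at rank i among n results, and rel*_1 ≥ rel*_2 ≥ … ≥ rel*_n the same relevances sorted in descending order. IDCG is just shorthand introduced here for the ideal DCG:

```latex
\mathrm{DCG} = \sum_{i=1}^{n} \frac{\mathrm{rel}_i}{\log_2(i+1)}, \qquad
\mathrm{IDCG} = \sum_{i=1}^{n} \frac{\mathrm{rel}^{*}_i}{\log_2(i+1)}, \qquad
\mathrm{NDCG} =
\begin{cases}
  \mathrm{DCG} / \mathrm{IDCG} & \text{if } \mathrm{IDCG} > 0 \\
  0 & \text{otherwise}
\end{cases}
```

Some references use the exponential gain 2^rel_i - 1 in the numerator instead. The code in this post uses rel_i directly.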

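Unlike DCG, NDCG is difficult to calculate directly with SQL or with the SQL-like language supported by Spark DataFrames. The ideal score requires each group to be re-sorted by relevance before the discounts are applied. Luckily, defining your own aggregate function for DataFrames, by extending UserDefinedAggregateFunction, is easy: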

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// Collects (position, relevance) pairs per group and returns DCG / ideal DCG.
object NDCG extends UserDefinedAggregateFunction {
  // Both inputs are declared as Double; integer columns such as position are cast by Spark.
  def inputSchema = new StructType()
    .add("position", DoubleType)
    .add("relevance", DoubleType)
  // The buffer simply accumulates every position and relevance seen so far.
  def bufferSchema = new StructType()
    .add("positions", ArrayType(DoubleType, false))
    .add("relevances", ArrayType(DoubleType, false))
  def dataType = DoubleType
  def deterministic = true
  def initialize(buffer: MutableAggregationBuffer) = {
    buffer(0) = IndexedSeq[Double]()
    buffer(1) = IndexedSeq[Double]()
  }
  // Append one input row, skipping rows where either value is null.
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    if(!input.isNullAt(0) && !input.isNullAt(1)) {
      val (position, relevance) = (input.getDouble(0), input.getDouble(1))
      buffer(0) = buffer.getAs[IndexedSeq[Double]](0) :+ position
      buffer(1) = buffer.getAs[IndexedSeq[Double]](1) :+ relevance
    }
  }
  // Concatenate two partial buffers; their order does not matter because evaluate() sorts.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    if(!buffer2.isNullAt(0) && !buffer2.isNullAt(1)) {
      buffer1(0) = buffer1.getAs[IndexedSeq[Double]](0) ++
                   buffer2.getAs[IndexedSeq[Double]](0)
      buffer1(1) = buffer1.getAs[IndexedSeq[Double]](1) ++
                   buffer2.getAs[IndexedSeq[Double]](1)
    }
  }
  // DCG of scores that are already in rank order. The rank is the running counter
  // fa._1 (1, 2, 3, ...), not the first tuple element; non-positive scores add nothing.
  private def totalGain(scores: Seq[(Double, Double)]): Double = {
    val (_, gain) = scores.foldLeft((1, 0.0))(
      (fa, tuple) => tuple match { case (_, score) =>
        if(score <= 0.0) (fa._1+1, fa._2)
        else if(fa._1 == 1) (fa._1+1, fa._2+score)
        else (fa._1+1, fa._2+score/(Math.log(fa._1+1)/Math.log(2.0)))
      })
    gain
  }
  def evaluate(buffer: Row) = {
    val (positions, relevances) = (buffer.getAs[IndexedSeq[Double]](0), buffer.getAs[IndexedSeq[Double]](1))
    // Actual ordering: (position, relevance) pairs sorted by position.
    val scores = (positions, relevances).zipped.toList.sorted
    // Ideal ordering: positive relevances sorted descending and ranked 1, 2, 3, ...
    val ideal = scores.map(_._2).filter(_>0).sortWith(_>_).zipWithIndex.map { case (s,i0) => (i0+1.0,s) }
    val (thisScore, idealScore) = (totalGain(scores), totalGain(ideal))
    // If nothing in the group is relevant, define NDCG as 0.
    if(idealScore == 0.0) 0.0 else thisScore / idealScore
  }
}

And using it is easy:

df.groupBy($"searchId").agg(NDCG($"position", $"relevanceScore").as("NDCG")).show

// +--------+------------------+
// |searchId|              NDCG|
// +--------+------------------+
// |     456|               1.0|
// |     123|0.8922089188046599|
// +--------+------------------+
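As a worked check of the 123 row, compare the two orderings. The actual order has the score 2.3001 at position 2 and 1.51 at position 4. The ideal order sorts the four scores descending (2.3001, 1.51, 1.28, 0.792):

```latex
\begin{aligned}
\mathrm{DCG}_{123}  &= \frac{1.28}{\log_2 2} + \frac{2.3001}{\log_2 3} + \frac{0.792}{\log_2 4} + \frac{1.51}{\log_2 5}
                      \approx 1.28 + 1.4512 + 0.396 + 0.6503 \approx 3.7775 \\
\mathrm{IDCG}_{123} &= \frac{2.3001}{\log_2 2} + \frac{1.51}{\log_2 3} + \frac{1.28}{\log_2 4} + \frac{0.792}{\log_2 5}
                      \approx 2.3001 + 0.9527 + 0.64 + 0.3411 \approx 4.2339 \\
\mathrm{NDCG}_{123} &\approx 3.7775 / 4.2339 \approx 0.8922
\end{aligned}
```

For search 456 the results are already in descending order of relevance, so the actual and ideal scores coincide and NDCG is exactly 1.0.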

The code may not be production-quality, but it works as expected. The idea is simple. Given a set of rows, each containing at least a "position" and a "relevance", the custom aggregate function saves these in arrays. After the last row has been read, it calculates both the actual and the ideal score from the arrays and returns their quotient (actual divided by ideal). Note that the positions are only used to order the results. The discount applied to each result depends on its rank after sorting, so positions are assumed to run 1, 2, 3, ... without gaps. A typical search engine will return tens or maybe hundreds of search results for each query, so the temporary arrays do not grow to an unmanageable size and performance is good.
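To make the buffer lifecycle concrete, here is a minimal plain-Scala model of it that runs without Spark. It appends pairs, concatenates partial buffers and then sorts and scores once at the end. It is an illustration under the assumption that it mirrors the UDAF's update/merge/evaluate logic; the names NdcgModel and Buffer are hypothetical and not part of Spark.

```scala
// Plain-Scala model (no Spark) of the NDCG aggregate's buffer handling.
// NdcgModel and Buffer are illustrative names, not Spark APIs.
object NdcgModel {
  type Buffer = (Vector[Double], Vector[Double]) // (positions, relevances)

  val empty: Buffer = (Vector.empty, Vector.empty)

  // Mirrors update(): append one (position, relevance) row.
  def update(b: Buffer, position: Double, relevance: Double): Buffer =
    (b._1 :+ position, b._2 :+ relevance)

  // Mirrors merge(): concatenate two partial buffers.
  def merge(b1: Buffer, b2: Buffer): Buffer = (b1._1 ++ b2._1, b1._2 ++ b2._2)

  // Mirrors totalGain(): relevances are in rank order (index 0 = rank 1);
  // non-positive scores add nothing but still occupy a rank.
  private def totalGain(relevances: Seq[Double]): Double =
    relevances.zipWithIndex.collect {
      case (rel, i) if rel > 0.0 => rel / (math.log(i + 2.0) / math.log(2.0))
    }.sum

  // Mirrors evaluate(): sort by position (ties broken by relevance, like
  // sorting tuples), build the ideal order, and return actual / ideal.
  def evaluate(b: Buffer): Double = {
    val actual = b._1.zip(b._2)
      .sortWith { case ((p1, r1), (p2, r2)) => p1 < p2 || (p1 == p2 && r1 < r2) }
      .map(_._2)
    val ideal = b._2.filter(_ > 0.0).sortWith(_ > _)
    val idealScore = totalGain(ideal)
    if (idealScore == 0.0) 0.0 else totalGain(actual) / idealScore
  }
}
```

Suppose the four rows of search 123 are fed into two partial buffers, which are then merged in either order. Both merge orders should give the same value, about 0.8922. This is why merge() can simply concatenate: evaluate() re-sorts by position before scoring, so the order in which Spark combines partitions does not matter.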

It is of course just as easy to use any other formula for the relevance scores, provided you have the data. For example, the following counts a click as 1 and a conversion as 3:

df.groupBy($"searchId").agg(NDCG($"position", $"clicked"+$"converted".cast(DoubleType)*3.0).as("NDCG")).show
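As a rough check of what this formula does to the sample data, here is a plain-Scala sketch without Spark. It computes relevance = clicked + 3 × converted per row and then the NDCG as defined above. The object name ClickConversionNdcg is hypothetical, and the (position, clicked, converted) triples are copied from the sample rows.

```scala
// Plain-Scala check (no Spark) of the clicked + 3 * converted relevance formula.
// Rows are (position, clicked, converted); ClickConversionNdcg is an illustrative name.
object ClickConversionNdcg {
  private def log2(x: Double): Double = math.log(x) / math.log(2.0)

  // DCG of relevances given in rank order; non-positive values contribute nothing.
  private def dcg(rels: Seq[Double]): Double =
    rels.zipWithIndex.collect { case (r, i) if r > 0.0 => r / log2(i + 2.0) }.sum

  def ndcg(rows: Seq[(Int, Int, Int)]): Double = {
    val rels = rows.sortBy(_._1).map { case (_, clicked, converted) => clicked + converted * 3.0 }
    val ideal = dcg(rels.sortWith(_ > _))
    if (ideal == 0.0) 0.0 else dcg(rels) / ideal
  }
}
```

Search 123 has a click at position 1 and a click plus conversion at position 4, giving relevances 1, 0, 0, 4. Its score therefore drops to roughly 0.588, because the most valuable result sits last. Search 456 has its only click at position 3 and scores 0.5.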

The code is also available on GitHub.