Friday 29 December 2017

Various gradient descent methods

Edit: fixed a bug in the code (was using total gradient instead of average). As a result, the first two graphs are slightly different, and the scale of the gradient comparison graphs is different.

After the previous post I wanted to solve the same problem with gradient descent. Turns out this particular problem illuminates rather well the somewhat finicky nature of gradient descent, while also being simple enough to understand.

The three points $(10,37)$, $(243,328)$ and $(341,451)$ are of course collinear, or very nearly so, and thus fitting a straight line through them is not the hardest problem I've ever solved. The correct answer, as determined in the previous post, is $y = 24.416 + 1.250 x$. But finding this out with ordinary gradient descent is harder than I thought.

By "ordinary gradient descent", I mean one where the weights are initialised randomly to "small" values and then the gradient of the loss function is followed, with a constant learning rate for each weight component, until convergence occurs. The basic problem with this is that the optimal values 24.416 and 1.250 are far enough apart in terms of magnitude that, since we're using the same learning rate for both weight components, either the gradient for the latter ($w_1$) will oscillate wildly and possibly diverge (learning rate too big), or the gradient for the former ($w_0$) will converge to the correct value very slowly (learning rate too small). To illustrate, here's a picture of completely standard gradient descent with a small learning rate:
Recall that the optimal value for $w_0$ is 24.416; yet after ten thousand iterations, standard gradient descent has only managed to increase it to about 3. Increasing the learning rate from $10^{-5}$ to, say, $10^{-4}$ doesn't help, since that makes the algorithm diverge: the gradient for the other weight, $w_1$, is so large that it cannot abide a bigger step.
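
To make the setup concrete, here's a minimal self-contained sketch of what "ordinary gradient descent" means for this problem. This is plain Scala with made-up names, not the ND4J code used for the actual experiments, and the weights are initialised at zero rather than randomly, purely for simplicity:

object PlainGradientDescent extends App {
  // the three points from the previous post
  val xs = Array(10.0, 243.0, 341.0)
  val ys = Array(37.0, 328.0, 451.0)

  val learningRate = 1e-5
  var w0 = 0.0 // intercept; optimum is about 24.416
  var w1 = 0.0 // slope; optimum is about 1.250

  for (_ <- 0 until 10000) {
    // average gradient of the squared-error loss over the three points
    val diffs = xs.zip(ys).map { case (x, y) => (w0 + w1 * x) - y }
    val g0 = diffs.sum / xs.length
    val g1 = diffs.zip(xs).map { case (d, x) => d * x }.sum / xs.length
    // same fixed learning rate for both components
    w0 -= learningRate * g0
    w1 -= learningRate * g1
  }

  // the slope settles quickly, while the intercept has barely moved
  println(s"w0 = $w0, w1 = $w1")
}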

There are various ways to try to remedy this. One thing I'm not sure would help much is better initialisation (I went for (unscaled) standard normal for what it's worth - being careful to keep the same seed across all my experiments). The optimal values for the weights are only about an order of magnitude apart, and it's not clear whether there's any general method that could guess more suitable initial values "in one go": if you get it wrong by about an order of magnitude, that doesn't help at all, and if you could somehow do better than that, then gradient descent wouldn't even be needed: just use your magical guessing method instead.

A better idea is to normalise the data, so that the gradients are of similar magnitude in each direction. I'll come back to this point later.

One common remedy is an adaptive learning rate of some kind. These seem to be the go-to methods in practice, judging both by Andrew Ng's Coursera courses and by what's commonly available in APIs such as TensorFlow; of course, the more difficult the problem, the more valuable these more advanced methods become. So I implemented some of them from scratch.

Momentum


The momentum method is a straightforward optimisation that maintains a moving average of previous gradients and uses it in place of the current gradient at each step. The idea is that averaging the gradients cancels out oscillations, and thus gives a smaller effective step for any component that would otherwise oscillate, while non-oscillating components are barely affected by the averaging and remain as large as they were. This permits a larger learning rate, since the likelihood of divergence is lower. (In practice the averaging is done with a simple-to-implement exponentially decaying average; one could instead keep an "exact" average over the previous $k$ gradients, but the exponential decay works just fine.) And indeed, with momentum (all other things being equal), a higher learning rate becomes feasible and the learning goes much faster, though still not amazingly fast:
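
In code, the momentum update itself is only a couple of lines. Here's a sketch in the same style (and with the same ND4J calls) as the Adam code further down, where dw is the current gradient and betaMomentum is a decay factor such as 0.9:

      // exponentially decaying average of past gradients, used in place of dw
      momentum = momentum.mul(betaMomentum).add(dw.mul(1.0 - betaMomentum))
      // step along the averaged gradient with a fixed learning rate
      w = w.sub(momentum.mul(learningRate))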


RMSProp


RMSProp also keeps a decaying average, but of the squared gradients, and divides each step by the square root of this average; in other words, it adapts to the magnitudes of the gradients and ignores their signs. This can work better in some cases I guess, but in my case it behaved rather strangely: it allowed me to use a larger learning rate than momentum and converged more quickly, but after that it started oscillating a bit. This makes some sense, since a method built on (square roots of) squared gradients has thrown away the sign information and so cannot really damp back-and-forth oscillations.
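
The corresponding update, again sketched in the style of the Adam code below (dw is the current gradient, betaSquares a decay factor such as 0.999, and epsilon a small constant to avoid division by zero):

      // exponentially decaying average of squared gradients: magnitude only, sign is lost
      squares = squares.mul(betaSquares).add(Transforms.pow(dw, 2.0).mul(1.0 - betaSquares))
      // divide each gradient component by its typical recent magnitude
      w = w.sub(dw.div(Transforms.sqrt(squares).add(epsilon)).mul(learningRate))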


Adam


Finally, Adam is essentially momentum and RMSProp combined: it keeps a decaying average of the gradients (to choose the direction) and a decaying average of the squared gradients (to scale the step), so it should both damp oscillations and converge fairly quickly, even with a fairly high base learning rate. And this indeed seems to be the case:
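
Written out, each step performs the following updates, where the square, square root and division are taken element-wise; this is the variant implemented in the code further down, which skips the bias-correction terms of the original Adam paper:

$m \leftarrow \beta_1 m + (1 - \beta_1)\,\nabla L$
$s \leftarrow \beta_2 s + (1 - \beta_2)\,(\nabla L)^2$
$w \leftarrow w - \eta\, m / (\sqrt{s} + \epsilon)$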


Gradient comparisons


As a further comparison of the various methods, I also had a look at the magnitudes of the gradients. Now as you recall, we want each component of the gradient to reach zero (and stay there), at which point the algorithm has converged. Plotting the gradient components directly basically corroborates the story above: plain gradient descent is stable but extremely slow; momentum is better; RMSProp is faster but less stable than momentum; and Adam is the best overall, achieving both speed and stability.


Code used


Here's the quick-and-dirty code I used for Adam; it's written in Scala, using the ND4J library for the matrix computations. (The code for the other methods is similar but less interesting.)
case object SquaredError extends Error {

  // average squared error, (1/2m) * sum((z - y)^2), returned as a 1x1 matrix
  override def averageError(z: INDArray, y: INDArray): INDArray = {
    val diff = z.sub(y)
    diff.mmul(diff.transpose()).div(2.0*diff.shape()(1))
  }

  // gradient of the average error with respect to the weights: (1/m) * (z - y) * X^T
  override def errorGradient(z: INDArray, y: INDArray, X: INDArray): INDArray = {
    val diff = z.sub(y)
    diff.mmul(X.transpose()).mul(1.0/diff.shape()(1))
  }
}

case class Adam(
                 betaMomentum: Double = 0.9,
                 betaSquares: Double = 0.999,
                 epsilon: Double = 1E-8
               ) extends Optimiser {
  override def runForNRounds(n: Int, wInitial: INDArray, data: Data, learningRate: Double, error: Error,
                             printTooMuchCrap: Boolean): (INDArray, Seq[Double]) = {
    var w = wInitial
    val errors = ArrayBuffer[Double]()

    // two running averages: one of the gradients, one of the squared gradients
    var momentum = Nd4j.zeros(w.shape()(1))
    var squares = Nd4j.zeros(w.shape()(1))

    for (i <- 0 until n) {
      val z = w.mmul(data.X)
      val avgErr = error.averageError(z, data.y)
      val dw = error.errorGradient(z, data.y, data.X)

      // momentum part: exponentially decaying average of the gradients
      momentum = momentum.mul(betaMomentum).add(dw.mul(1.0 - betaMomentum))
      // RMSProp part: exponentially decaying average of the squared gradients
      val dw2 = Transforms.pow(dw, 2.0)
      squares = squares.mul(betaSquares).add(dw2.mul(1.0 - betaSquares))

      // Adam step: averaged gradient, scaled per component by its typical recent
      // magnitude (note: no bias correction of the running averages here)
      w = w.sub(momentum.div(Transforms.sqrt(squares).add(epsilon)).mul(learningRate))

      val avgErrDbl = avgErr.getDouble(0)
      GradientDescent.maybePrintTooMuchCrap(printTooMuchCrap, w, dw, avgErrDbl)
      errors.append(avgErrDbl)
    }
    (w, errors)
  }
}


Normalising the data


Finally, a few words about normalising the data. I'm not quite sure what to think of it. On the one hand, normalisation is elegant and straightforward, and it solves the uneven-gradient problem effectively. On the other hand, it's somewhat unwieldy to apply to any problem where you don't have all of the data to hand to begin with, such as online learning or reinforcement learning, or a real-world project where after six months you might get a batch of brand-new data. In actual applications it seems one can usually get away with only centering the data (i.e. subtracting the mean), or even forgo normalisation altogether. I suspect normalisation is one of those techniques where I will eventually come across a problem that genuinely requires it; this simplest of examples doesn't.

In any case, here's what happens to any collinear or nearly collinear data set when you normalise each input and output dimension by subtracting the mean and dividing by the standard deviation:
Pictured: the three points, normalised, and the line $y = x$. Needless to say, this line is not difficult to find even with the most primitive gradient descent method!
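
For concreteness, here's a minimal sketch of that normalisation for the three points (plain Scala, not part of the experiment code above):

object NormaliseExample extends App {
  val xs = Array(10.0, 243.0, 341.0)
  val ys = Array(37.0, 328.0, 451.0)

  // subtract the mean and divide by the standard deviation, per dimension
  def normalise(v: Array[Double]): Array[Double] = {
    val mean = v.sum / v.length
    val std = math.sqrt(v.map(x => (x - mean) * (x - mean)).sum / v.length)
    v.map(x => (x - mean) / std)
  }

  // for (nearly) collinear data the normalised points lie (nearly) on y = x
  println(normalise(xs).zip(normalise(ys)).mkString(", "))
}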
