Tuesday, February 13, 2018

Automatic Learning Rate Scheduling That Really Works

Training deep learning models can be a pain. In particular, there is a perception that one reason it's a pain is that you have to fiddle with learning rates. For example, arguably the most popular strategy for setting learning rates looks like this:
  1. Run vanilla stochastic gradient descent with momentum and a fixed learning rate
  2. Wait for the loss to stop improving
  3. Reduce the learning rate
  4. Go back to step 1 or stop if the learning rate is really small
Many papers reporting state-of-the-art results do this. Many other methods have been proposed, like ADAM, but I've always found the above procedure to work best. This is a common finding. The only fiddly part of this procedure is the "wait for the loss to stop improving" step. A lot of people just eyeball a plot of the loss and manually intervene when it looks like it's flattened out. Or worse, they pick a certain number of iterations ahead of time and blindly stop when that limit is reached. Both of these ways of deciding when to reduce the learning rate suck. 

Fortunately, there is a simple method from classical statistics we can use to decide whether the loss is still improving, and thus when to reduce the learning rate. With this method it's trivial to fully automate the above procedure. In fact, it's what I've used to train all the public DNN models in dlib over the last few years, e.g. face detection, face recognition, vehicle detection, and ImageNet classification. It's the default solving strategy used by dlib's DNN solver. The rest of this blog post explains how it works.

Fundamentally, what we need is a method that takes a noisy time series of $n$ loss values, $Y=\{y_0,y_1,y_2,\dots,y_{n-1}\}$, and tells us whether the time series is trending down. To do this, we model the time series as a line corrupted by Gaussian noise:
\[ \newcommand{\N}{\mathcal{N}} y_i = m\times i + b + \epsilon \]
Here, $m$ and $b$ are the unknown true slope and intercept parameters of the line, and $\epsilon$ is a Gaussian noise term with mean 0 and variance $\sigma^2$, drawn independently for each $i$. Let's also define the function $\text{slope}(Y)$ that takes in a time series, performs ordinary least squares (OLS) regression of $y_i$ on $i$, and outputs the OLS estimate of $m$, the slope of the line. You can then ask the following question: what is the probability that a time series sampled from our noisy line model will have a negative slope according to OLS? That is, what is the value of the following quantity?
\[ P(\text{slope}(Y) < 0) \]
If we could compute an estimate of $P(\text{slope}(Y)<0)$, we could use it to test whether the loss is still decreasing. Fortunately, computing this quantity turns out to be easy. The OLS slope is a linear combination of the $y_i$, and a linear combination of independent Gaussian variables is itself Gaussian, so $\text{slope}(Y)$ is a Gaussian random variable with this distribution:
\[ \text{slope}(Y) \sim \mathcal{N}\left(m, \frac{12 \sigma^2}{n^3-n}\right) \]
We don't know the true values of $m$ and $\sigma^2$, but they are easily estimated from data. We can obviously use $\text{slope}(Y)$ to estimate $m$. As for $\sigma^2$, it's customary to estimate it like this:
\[ \hat\sigma^2 = \frac{1}{n-2} \sum_{i=0}^{n-1} (y_i - \hat y_i)^2 \] which gives an unbiased estimate of the true $\sigma^2$. Here $y_i - \hat y_i$ is the difference between the observed time series value at time $i$ and the value predicted by the OLS fitted line at time $i$. None of this is new; these properties of OLS are discussed in detail on the Wikipedia page about OLS.
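The closed form $12\sigma^2/(n^3-n)$ may look unfamiliar next to the textbook variance of the OLS slope, $\sigma^2 / \sum_i (x_i - \bar x)^2$. It is the same formula once the sample times $x_i = i$, for $i = 0, \dots, n-1$, are plugged in:

\[
\begin{aligned}
\bar x &= \frac{1}{n}\sum_{i=0}^{n-1} i = \frac{n-1}{2}, \\
\sum_{i=0}^{n-1} (i - \bar x)^2 &= \sum_{i=0}^{n-1} i^2 - n\bar x^2 = \frac{(n-1)n(2n-1)}{6} - \frac{n(n-1)^2}{4} = \frac{n^3-n}{12}, \\
\operatorname{Var}[\text{slope}(Y)] &= \frac{\sigma^2}{\sum_{i=0}^{n-1} (i - \bar x)^2} = \frac{12\sigma^2}{n^3-n}.
\end{aligned}
\]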

So let's recap. We need a method to decide if the loss is trending downward or not. I'm suggesting that you use $P(\text{slope}(Y) < 0)$, the probability that a line fit to your loss curve will have negative slope. Moreover, as discussed above, this probability is easy to compute since it's just a question about a simple Gaussian variable and the two parameters of the Gaussian variable are given by a straightforward application of OLS.
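Putting the pieces together, here is a minimal pure-Python sketch of the test, assuming the noisy line model above. It fits the line in closed form (which is possible because $x_i = i$), uses the $n-2$ estimate of $\sigma^2$, and evaluates the Gaussian CDF through the identity $\Phi(z) = \tfrac{1}{2}\left(1 + \operatorname{erf}(z/\sqrt{2})\right)$. The function name `probability_slope_is_negative` is hypothetical, not dlib's API.

```python
import math


def probability_slope_is_negative(y):
    """Estimate P(slope(Y) < 0) for a sequence y of loss values.

    Sketch: closed-form OLS on x = 0, 1, ..., n-1, the unbiased n-2 estimate
    of sigma^2, then the Gaussian CDF of slope(Y) evaluated at 0.
    """
    n = len(y)
    if n < 3:
        raise ValueError("need at least 3 points to estimate sigma^2")
    x_mean = (n - 1) / 2.0
    y_mean = sum(y) / n
    sxx = (n ** 3 - n) / 12.0                      # sum of (i - x_mean)^2
    m = sum((i - x_mean) * (v - y_mean) for i, v in enumerate(y)) / sxx
    b = y_mean - m * x_mean
    sigma2 = sum((v - (m * i + b)) ** 2 for i, v in enumerate(y)) / (n - 2)
    slope_std = math.sqrt(sigma2 / sxx)            # sqrt(12 sigma^2 / (n^3 - n))
    if slope_std == 0.0:                           # noiseless line: sign of m decides
        return 1.0 if m < 0 else 0.0
    z = -m / slope_std                             # slope ~ N(m, std^2), so P = Phi(-m / std)
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))
```

In other words, the whole test reduces to a z-score: the estimated slope divided by its estimated standard deviation.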

You should also note that the variance of $\text{slope}(Y)$ decays at the very quick rate of $O(1/n^3)$, where $n$ is the number of loss samples. So it becomes very accurate as the length of the time series grows. To illustrate just how accurate this is, let's look at some examples. The figure below shows four different time series plots, each consisting of $n=4000$ points. Each plot is a draw from our noisy line model with parameters: $b=0$, $\sigma^2=1$, and $m \in \{0.001, 0.00004, -0.00004, -0.001\}$. For each of these noisy plots I've computed $P(\text{slope}(Y) < 0)$ and included it in the title.

From looking at these plots it should be obvious that $P(\text{slope}(Y) < 0)$ is quite good at detecting the slope. In particular, I doubt you can tell the difference between the two middle plots (the ones with slopes -0.00004 and 0.00004). But as you can see, the test statistic I'm suggesting, $P(\text{slope}(Y) < 0)$, has no trouble at all correctly identifying one as sloping up and the other as sloping down.
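To see why even the two middle plots can be told apart, one can evaluate $P(\text{slope}(Y)<0)$ using the *true* parameters of the four models ($n=4000$, $\sigma^2=1$). These are the values an estimate would cluster around, not the numbers in the figure titles, which come from single random draws. The function names below are illustrative.

```python
import math


def slope_std(n, sigma2=1.0):
    """Standard deviation of the OLS slope for n points at x = 0, 1, ..., n-1."""
    return math.sqrt(12.0 * sigma2 / (n ** 3 - n))


def prob_negative_slope(m, n, sigma2=1.0):
    """P(slope(Y) < 0) = Phi(-m / std) when the true m and sigma^2 are known."""
    z = -m / slope_std(n, sigma2)
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


n = 4000
print(f"std of slope for n={n}: {slope_std(n):.3g}")
for m in (0.001, 0.00004, -0.00004, -0.001):
    print(f"m = {m:+.5f}  P(slope < 0) = {prob_negative_slope(m, n):.4f}")
```

With $n=4000$ the slope's standard deviation is about $1.37\times10^{-5}$, so $m=\pm 0.00004$ sits roughly 2.9 standard deviations away from zero. Doubling $n$ shrinks that standard deviation by a factor of about $2^{3/2}$, which is the $O(1/n^3)$ variance decay mentioned above.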

I find that a nice way to parameterize this in actual code is to count the number of mini-batches that executed while $P(\text{slope}(Y) < 0) < 0.51$. That is, find out how many loss values you have to look at before there is evidence the loss has been decreasing. To be very clear, this bit of pseudo-code implements the idea:
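    def count_steps_without_decrease(Y):
        steps_without_decrease = 0
        n = len(Y)
        for i in reversed(range(n)):
            # Y[i:n] grows backwards from the newest loss value. It needs at
            # least 3 points, since the sigma^2 estimate divides by n-2.
            if n - i > 2 and P(slope(Y[i:n]) < 0) < 0.51:
                steps_without_decrease = n - i
        return steps_without_decrease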
You can then use a rule like: "if the steps without decrease is 1000 I will lower the learning rate by 10x". However, there is one more issue that needs to be addressed. This is the fact that loss curves sometimes have really large transient spikes, where, for one reason or another (e.g. maybe a bad mini-batch) the loss will suddenly become huge for a moment. Not all models or datasets have this problem during training, but some do. In these cases, count_steps_without_decrease() might erroneously return a very large value. You can deal with this problem by discarding the top 10% of loss values inside count_steps_without_decrease(). This makes the entire procedure robust to these noisy outliers. Note, however, that the final test you would want to use is:
    count_steps_without_decrease(Y) > threshold and count_steps_without_decrease_robust(Y) > threshold
That is, perform the check with and without outlier discarding. You need both checks because the 10% largest loss values might have occurred at the very beginning of Y. For example, maybe you are waiting for 1000 (i.e. threshold=1000) mini-batches to execute without showing evidence of the loss going down. And maybe the first 100 all showed a dropping loss while the last 900 were flat. The check that discarded the top 10% would erroneously indicate that the loss was NOT dropping. So you want to perform both checks and if both agree that the loss isn't dropping then you can be confident it's not dropping.
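For reference, here is a runnable pure-Python sketch of both counters and the combined test. A few details are my own assumptions rather than dlib's. It walks backwards over `Y` while keeping running OLS sums, so each call is linear in `len(Y)`. Outliers are values above the 90th-percentile value of `Y`, and discarded values still count as steps but are not fed to the fit. The names mirror the ones used above, and `should_lower_learning_rate` is a hypothetical helper; this is not dlib's implementation.

```python
import math


def _normal_cdf(z):
    """Standard normal CDF via the error function."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def count_steps_without_decrease(Y, probability_of_decrease=0.51, discard_above=None):
    """Largest k such that the last k values of Y show no evidence
    (P < probability_of_decrease) that the loss has been decreasing.

    Walks backwards with running OLS sums, so x = 0 is the newest value and a
    *positive* fitted slope means the loss was going down. Values above
    `discard_above` (if given) count as steps but are left out of the fit.
    """
    n = sx = sy = sxx = sxy = syy = 0.0
    count = 0
    for steps, y in enumerate(reversed(Y), start=1):
        if discard_above is None or y <= discard_above:
            x = n  # position among the values kept so far
            n += 1
            sx += x
            sy += y
            sxx += x * x
            sxy += x * y
            syy += y * y
        if n > 2:  # the sigma^2 estimate divides by n - 2
            cxx = sxx - sx * sx / n
            cxy = sxy - sx * sy / n
            cyy = syy - sy * sy / n
            m = cxy / cxx
            sigma2 = max(cyy - m * cxy, 0.0) / (n - 2)  # residual sum of squares / (n-2)
            std = math.sqrt(sigma2 / cxx)
            if std == 0.0:
                p_decrease = 1.0 if m > 0 else 0.0
            else:
                p_decrease = _normal_cdf(m / std)  # P(backwards slope > 0)
            if p_decrease < probability_of_decrease:
                count = steps
    return count


def count_steps_without_decrease_robust(Y, probability_of_decrease=0.51, quantile_discard=0.10):
    """Like count_steps_without_decrease, but ignores roughly the largest
    `quantile_discard` fraction of the values in Y (transient loss spikes)."""
    if len(Y) == 0:
        return 0
    ordered = sorted(Y, reverse=True)
    thresh = ordered[int(len(Y) * quantile_discard)]  # keep values <= thresh
    return count_steps_without_decrease(Y, probability_of_decrease, discard_above=thresh)


def should_lower_learning_rate(Y, threshold):
    """Both checks must agree that the loss has been flat for `threshold` steps."""
    return (count_steps_without_decrease(Y) > threshold
            and count_steps_without_decrease_robust(Y) > threshold)
```

Walking backwards flips the sign of the slope, which is why the sketch uses $\Phi(\hat m/\hat s)$ rather than $\Phi(-\hat m/\hat s)$. Here $\hat s$ is the estimated standard deviation of the slope.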

It should be emphasized that this method isn't substantively different from what a whole lot of people already do when training deep neural networks. The only difference here is that the "look at the loss and see if it's decreasing" step is being done by a computer. The point of this blog post is to point out that this check is trivially automatable with boring old simple statistics. There is no reason to do it by hand. Let the computer do it and find something more productive to do with your time than babysitting SGD. The test is simple to implement yourself, but if you want to just call a function you can call dlib's count_steps_without_decrease() and count_steps_without_decrease_robust() routines from C++ or Python.

Finally, one more useful thing you can do is the following: you can periodically check if $P(\text{slope}(Y) > 0) \gt 0.99$, that is, check if we are really certain that the loss is going up, rather than down. This can happen and I've had training runs that were going fine and then suddenly the loss shot up and stayed high for a really long time, basically ruining the training run. This doesn't seem to be too much of an issue with simple losses like the log-loss. However, structured loss functions that perform some kind of hard negative mining inside a mini-batch will sometimes go haywire if they hit a very bad mini-batch. You can fix this problem by simply reloading from an earlier network state before the loss increased. But to do this you need a reliable way to measure "the loss is going up" and $P(\text{slope}(Y) > 0) \gt 0.99$ is excellent for this task. This idea is called backtracking and has a long history in numerical optimization. Backtracking significantly increases solver robustness in many cases and is well worth using.
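Here is a minimal sketch of the backtracking check. It assumes the training state can be deep-copied; for a real network you would save and load checkpoints instead. `BacktrackingGuard`, its 100-loss `window`, and the snapshot policy are hypothetical choices for illustration. Only the $P(\text{slope}(Y) > 0) > 0.99$ test comes from the text.

```python
import copy
import math


def probability_slope_is_positive(y):
    """P(slope(Y) > 0) under the noisy line model (closed-form OLS, x = 0..n-1)."""
    n = len(y)
    x_mean, y_mean = (n - 1) / 2.0, sum(y) / n
    sxx = (n ** 3 - n) / 12.0
    m = sum((i - x_mean) * (v - y_mean) for i, v in enumerate(y)) / sxx
    b = y_mean - m * x_mean
    sigma2 = sum((v - (m * i + b)) ** 2 for i, v in enumerate(y)) / (n - 2)
    std = math.sqrt(sigma2 / sxx)  # sqrt(12 sigma^2 / (n^3 - n))
    if std == 0.0:
        return 1.0 if m > 0 else 0.0
    return 0.5 * (1.0 + math.erf(m / std / math.sqrt(2.0)))  # Phi(m / std)


class BacktrackingGuard:
    """Hypothetical helper: roll back to a saved state when the loss is clearly rising."""

    def __init__(self, window=100, confidence=0.99):
        self.window = window          # how many recent losses the test looks at
        self.confidence = confidence  # backtrack when P(slope > 0) exceeds this
        self.losses = []
        self.snapshot = None

    def save(self, state):
        """Remember a known-good state (call periodically while training goes well)."""
        self.snapshot = copy.deepcopy(state)

    def step(self, state, loss):
        """Record one loss; return the state to continue from (possibly a restored one)."""
        self.losses.append(loss)
        del self.losses[:-self.window]  # keep only the most recent `window` losses
        if (self.snapshot is not None and len(self.losses) == self.window
                and probability_slope_is_positive(self.losses) > self.confidence):
            self.losses.clear()  # the old upward trend no longer describes the restored state
            return copy.deepcopy(self.snapshot)
        return state
```

The guard requires a full window before testing so that a handful of noisy points cannot trigger a rollback. Clearing the history after a reload keeps the stale upward trend from immediately triggering another one.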


  1. Why is it better than a predefined learning rate policy, like steps or polynomial decay?

  2. Some models take days to train while others take minutes. You don't want to just always run for days, that's a waste of time. Conversely, if you don't let the solver run long enough you will underfit. How are you going to know, ahead of time, how long to run the solver? You don't. You need to measure progress while the solver is running and do something reasonable. Just running blind to the problem you are optimizing is never going to be a good idea.

  3. Pardon me if this question seems elementary, but I can't understand why we have to calculate P(Slope(Y) < 0)...

    Can't we just do OLS and get an approximate value for `m`, then whether it's negative or positive we take action ?

    What I wonder is two things:

    1. What benefit does calculating the probability have when we can have an approximate `m`?

    2. How exactly is the probability computed? I can understand the OLS but I couldn't get how you went from `Slope(Y)` to P(Slope(Y) < 0)... how did you calculate the probability?

    Thanks for this great Blog post...

  4. Yes, you could do that and it wouldn't be awful. However, there are cases where it would put you into an infinite loop. Consider the case where the slope of the loss curve is asymptotically approaching zero but never goes positive. Simply thresholding m at 0 will never terminate but the test I suggested would. Most real world problems probably don't exhibit that kind of behavior very often, but these kinds of corner cases are things you need to be concerned with if you want to make a robust numerical solver.

    Slope(Y) is a Gaussian variable. So you just need to know its mean and variance and then you call a function that computes the CDF of a Gaussian. Every platform has such functions already available. You don't need to implement it yourself.

  5. Very interesting and seamless automation indeed.
    I notice that the dlib Python module does contain the count_steps_without_decrease function. However, regarding your last comment about resetting the training to an earlier state, I don't see the P(slope > X) function being exposed.
    It is in the C++ source, and I wonder if it could be exposed on the Python side.

  6. There is dlib.probability_that_sequence_is_increasing()

  7. Hi, thanks for the post. I'm in the process of implementing this. My current method is: if the average loss of the last n batches has not decreased for the last m times, halve the learning rate.

    Just want to confirm one detail - it looks like you calculate the n slope values of Y using the values Y[n-i:n], e.g. there will be 2,3,4,5 -> n values used for the calculation each time. Why do it like this, instead of for example keeping a running score of P always calculated using the last n values of Y?

  8. There aren't N slope values, there is one slope value. You find it using OLS. You can do a recursive OLS if you want. It doesn't matter how you find it so long as you find it.

  9. Ok I think I got it: You calculate P(slope < 0) for the last n loss values, for n in 1:N (N being the container size) and chose the maximum value of n (n_max) where P(slope < 0) < 0.51. That is, for all n in n_max+1:N, P(slope < 0) >= 0.51. (Or should it be minimum n where P(slope < 0) >= 0.51?)

  10. By the way, I find that using an adaptive learning rate schedule with ADAM makes it work even better.

  11. Yes, you have the right idea. You run backwards over the data until P(slope<0) is >= 0.51. If you have to go back really far to see that the slope is decreasing then you have probably converged.

    Huh, well, it depends on the problem. Most of the time I find that ADAM makes things worse.

  12. On my datasets (greyscale images) I periodically train SGD to see if I can get a better result, but validation is always a little worse than ADAM with adaptive training rate. Vanilla ADAM is worse though.

    Have you looked into any ways of choosing the threshold? I wonder if it can also be automated.

    With regards to the code, count_steps_without_decrease loops until it processes the entire container, should it instead return as soon as it has found P(slope < 0) >= 0.51?

    Thanks again, I am using your method now and it works great.

  13. I don't think the particular setting of the threshold matters. That degree of freedom is taken up by the "how many steps threshold".

    The loop needs to run the entire length. You want to use the most confident estimate you will see. So if you get to the end and it's not clear that things are decreasing then you should report that. Anyway, try it and see what happens if what I'm saying isn't clear. You will see that you want to run over the whole array.

  14. Very nice work Mr. King!

    I played a bit with your learning rate scheduler. I think that setting the 'num_steps_without_decrease' threshold should be somehow bound to the dataset size and/or the mini-batch size. Do you have any recommendations?

    How many loss values do I have to save to get an accurate estimate? Is 2 * num_steps_without_decrease sufficient?

  15. I don't think it has much to do with dataset size. If it's related to anything in the dataset it's the underlying difficulty of the problem, which is more a function of the signal to noise ratio than anything related to the size. I would simply set it to a large value, the largest value that is tolerable and then not worry about it.

    You only need num_steps_without_decrease (the threshold) loss values since you don't look at any values beyond those.

  16. At present, the count_steps_without_decrease function returns the *maximum* count (n) for which there is *no* evidence of decreasing, regardless of whether or not smaller values of n had evidence. I.e. if for n = 1:100 you have evidence, but for n = 101 you do not, it would return 101. Sorry to bring this up again, but I'm not sure if it is by design or not (or am I reading the code wrongly?), because that is slightly different from what you write in the blog post. It would be different if count_steps_without_decrease instead returned as soon as it found n with P >= 0.51, i.e. the minimum n for which there is evidence.

    Anyway, it means that if you save the last N loss values and *also* set the threshold to N, there is no point calculating P for the last n < N loss values... Just calculate P once for n = N, and if it is less than 0.51 (no evidence) the threshold has been met, because it doesn't matter what P is for n < N. My current code keeps the last 2 * threshold loss values.

  17. P.S. if it did exit the loop as soon as P >= 0.51 this could also have problems, e.g. imagine if the last 3 loss values just happened to line up nicely and give P>=0.51...

  18. That's not what happens. It doesn't just always output N. You should run the function and see what it does. I think that will help understand it.

  19. Yes, it doesn't always output N. I've been using the function for the last week and plotting the count to help choose a good threshold. What I'm saying is that if you use N = threshold, all you need to do is calculate P one time for n=N, and if P < 0.51, drop the learning rate.

  20. Yes, if that's all you do with it then you just need to compute it for the longest length. But I find that logging the point at which the loss stopped decreasing is a very useful diagnostic.

  21. Are you confident that you have the correct sampling distribution for the slope? 12 *sigma**2 / (n ** 3 - n) is not an expression for the variance of the sampling distribution of the slope that I am familiar with.

    For example, these slides outline the derivation of the sampling distribution of the OLS slope when we assume a normal distribution for the errors. The derivation for the variance of the sampling distribution of the slope is given as sigma**2 / sum(x_i - bar(x))^2.

    Is there something else going on in this procedure which yields the variance estimator 12 *sigma**2 / (n ** 3 - n) ?

  22. Yes, I'm quite confident it is correct. If you take the textbook formula and plug in the specific case here (in particular, x_i is defined on the integers 0,1,2,3,...) you end up with the formula in this post. Just plug in some values and you will see you get the same results.

  23. Ah, yes, I didn't realize the substitution had already been made. I see now that this result is easily derived from that assumption.
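Following the suggestion to plug in some values, this short Python check compares the textbook variance $\sigma^2/\sum_i (x_i - \bar x)^2$ with $x_i = 0, 1, \dots, n-1$ against $12\sigma^2/(n^3-n)$. It uses exact rational arithmetic (`fractions.Fraction`), so the comparison is an equality rather than a floating-point tolerance. The function names are illustrative.

```python
from fractions import Fraction


def textbook_slope_variance(n, sigma2=Fraction(1)):
    """sigma^2 / sum_i (x_i - x_bar)^2 with x_i = 0, 1, ..., n-1, in exact arithmetic."""
    xs = range(n)
    x_bar = Fraction(sum(xs), n)
    return sigma2 / sum((x - x_bar) ** 2 for x in xs)


def closed_form_slope_variance(n, sigma2=Fraction(1)):
    """The formula from the post: 12 sigma^2 / (n^3 - n)."""
    return 12 * sigma2 / (n ** 3 - n)


for n in (2, 3, 10, 100, 4000):
    print(n, textbook_slope_variance(n) == closed_form_slope_variance(n))
```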

  24. Really simple question, but I can't find the answer in 15 minutes of Googling: why is the denominator in the variance calculation n-2 and not n-1?

  25. Because there are 2 degrees of freedom when fitting a line, the slope and the intercept. But the best way to see this is to look at the equation for standard error (see https://en.wikipedia.org/wiki/Ordinary_least_squares), apply it to our specific case, and find that you end up with this equation. When in doubt, do the algebra :)
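A quick Monte Carlo sketch makes the degrees-of-freedom argument concrete. It simulates many short noisy lines with $\sigma^2 = 1$, fits each by OLS, and averages the two candidate estimates of $\sigma^2$. The parameters ($n=10$, slope 0.5, intercept 3, 20000 trials, seed 0) are arbitrary choices for illustration.

```python
import random


def residual_variance_estimates(n=10, trials=20000, seed=0):
    """Average the n-2 and n-1 estimates of sigma^2 (true value 1) over simulated lines."""
    rng = random.Random(seed)
    x_mean = (n - 1) / 2.0
    sxx = (n ** 3 - n) / 12.0
    total_n2 = total_n1 = 0.0
    for _ in range(trials):
        y = [0.5 * i + 3.0 + rng.gauss(0.0, 1.0) for i in range(n)]  # m=0.5, b=3, sigma^2=1
        y_mean = sum(y) / n
        m = sum((i - x_mean) * (v - y_mean) for i, v in enumerate(y)) / sxx
        b = y_mean - m * x_mean
        sse = sum((v - (m * i + b)) ** 2 for i, v in enumerate(y))  # residual sum of squares
        total_n2 += sse / (n - 2)
        total_n1 += sse / (n - 1)
    return total_n2 / trials, total_n1 / trials


print(residual_variance_estimates())
```

The residual sum of squares has expected value $(n-2)\sigma^2$, so dividing by $n-1$ should average about $(n-2)/(n-1) = 8/9$ here instead of 1.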

  26. I have a question about the selection of Y in your procedure: how many values of Y do you keep? For instance, if we set the steps_without_decrease threshold to 1000, how long should Y be? Do we use the whole history of Y since training started, or just Y[-2000:] or Y[-10000:]? I found that calculating steps_without_decrease over a very large Y is a bit time consuming.

  27. If your stopping threshold is "stop when it's been flat for at least 1000 steps" then you only need to keep the last 1000 loss values in Y. Any more history is irrelevant, by definition, since you only care about the last 1000.
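In code, a bounded buffer such as `collections.deque(maxlen=threshold)` keeps exactly that history. The sketch below wraps it in a hypothetical `PlateauScheduler` class that applies the rule from the post: divide the learning rate by 10 once the loss has been flat for `threshold` steps. The counting function is passed in, for example dlib's Python `count_steps_without_decrease` or your own implementation. The defaults (`lr=0.1`, `min_lr=1e-5`) are made up.

```python
from collections import deque


class PlateauScheduler:
    """Hypothetical sketch of the automated procedure from the start of the post."""

    def __init__(self, count_steps_without_decrease, lr=0.1, threshold=1000,
                 factor=10.0, min_lr=1e-5):
        self.count_steps_without_decrease = count_steps_without_decrease
        self.lr = lr
        self.threshold = threshold
        self.factor = factor
        self.min_lr = min_lr
        # Only the last `threshold` losses matter, so older ones fall off automatically.
        self.losses = deque(maxlen=threshold)

    def step(self, loss):
        """Record one mini-batch loss. Returns False once the learning rate has
        dropped below min_lr, i.e. when training should stop."""
        self.losses.append(loss)
        if (len(self.losses) == self.threshold
                and self.count_steps_without_decrease(list(self.losses)) >= self.threshold):
            self.lr /= self.factor
            self.losses.clear()  # judge progress afresh at the new learning rate
        return self.lr >= self.min_lr
```

The buffer never holds more than `threshold` values, so the count can be at most `threshold`, and the check therefore uses `>=`. Clearing the buffer after each reduction makes the solver wait another `threshold` steps before judging progress at the new learning rate.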

  28. And what about the weight decay, as the weight decay will affect the optimization process. It seems that you did not include weight decay in the calculation of the steps_without_decrease. So I think if the weight decay is a bit large, the calculation may give a wrong result?

  29. I don't see how the weight decay has anything to do with the calculations discussed in this blog post. Here, we are talking about detecting when a noisy time series has become asymptotically flat. You could do this with any time series, it doesn't even have to come from a deep learning training procedure. The time series could come from a noisy sensor measuring the water level in a tank and the math would be exactly the same.

  30. Isn't the variance of the slope directly readable in the covariance matrix (R in your code) instead of computing 12 *sigma**2 / (n ** 3 - n)?

  31. Eh, the variance of the slope is also sigma**2 * R(0,0). So yeah, could have done that too. The deeper question is why in the world did I use recursive least squares for this rather than just inv(XX)*XY. This code is more complex than it needs to be in that way. There is some story that makes sense of how the code evolved to this, I forget what it is though.
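For comparison, here is what the simpler batch version mentioned in that reply might look like. It solves the normal equations with an explicit 2×2 inverse of $X^\top X$, where each row of $X$ is $[i, 1]$ (putting the slope first is an assumption about ordering), and reads the slope variance off as $\hat\sigma^2 (X^\top X)^{-1}_{00}$. For $x_i = i$ that entry equals $12/(n^3-n)$, so this matches the formula in the post. This is an illustrative sketch, not dlib's code.

```python
def ols_slope_and_variance(y):
    """Batch OLS for y_i = m*i + b via inv(X^T X) X^T y, with rows of X = [i, 1].

    Returns (m, var_m), where var_m = sigma^2 * inv(X^T X)[0][0] and sigma^2 is
    the n-2 estimate from the residuals.
    """
    n = len(y)
    sxx = sum(i * i for i in range(n))   # X^T X = [[sxx, sx], [sx, n]]
    sx = sum(range(n))
    det = sxx * n - sx * sx
    inv = [[n / det, -sx / det], [-sx / det, sxx / det]]
    xty = [sum(i * v for i, v in enumerate(y)), sum(y)]   # X^T y
    m = inv[0][0] * xty[0] + inv[0][1] * xty[1]
    b = inv[1][0] * xty[0] + inv[1][1] * xty[1]
    sigma2 = sum((v - (m * i + b)) ** 2 for i, v in enumerate(y)) / (n - 2)
    return m, sigma2 * inv[0][0]
```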

  32. True, I was wondering exactly the same :) since you always process the whole batch sequentially anyway. Anyway, thanks! Good job!

  33. There are, however, some parts which I don't understand and cannot figure out from the literature:
    - why is the variance of the slope not simply R(0, 0)?
    - where does the formula for computing sigma^2 come from? In your code, you do something like this: `residual_squared = residual_squared + std::pow((y - trans(x)*w),2.0)*temp;`
    Any pointers on relevant literature would be helpful, thanks! What I looked at is recursive least squares and Kalman filter.

  34. See this section https://en.wikipedia.org/wiki/Ordinary_least_squares#Finite_sample_properties for notes on the variance computation. I'm not sure where a reference is for the sigma computation. I think I worked it out with pencil and paper.