- Run vanilla stochastic gradient descent (SGD) with momentum and a fixed learning rate
- Wait for the loss to stop improving
- Reduce the learning rate
- Go back to step 1 or stop if the learning rate is really small
Many papers reporting state-of-the-art results do this. There have been a lot of other methods proposed, like ADAM, but I've always found the above procedure to work best. This is a common finding. The only fiddly part of this procedure is the "wait for the loss to stop improving" step. A lot of people just eyeball a plot of the loss and manually intervene when it looks like it's flattened out. Or worse, they pick a certain number of iterations ahead of time and blindly stop when that limit is reached. Both of these ways of deciding when to reduce the learning rate suck.
Fortunately, there is a simple method from classical statistics we can use to decide if the loss is still improving, and thus, when to reduce the learning rate. With this method it's trivial to fully automate the above procedure. In fact, it's what I've used to train all the public DNN models in dlib over the last few years: e.g. face detection, face recognition, vehicle detection, and imagenet classification. It's the default solving strategy used by dlib's DNN solver. The rest of this blog post explains how it works.
Fundamentally, what we need is a method that takes a noisy time series of $n$ loss values, $Y=\{y_0,y_1,y_2,\dots,y_{n-1}\}$, and tells us if the time series is trending down or not. To do this, we model the time series as a noisy line corrupted by Gaussian noise:
\[
\newcommand{\N} {\mathcal{N} } y_i = m\times i + b + \epsilon
\] Here, $m$ and $b$ are the unknown true slope and intercept parameters of the line, and $\epsilon$ is a Gaussian noise term with mean 0 and variance $\sigma^2$. Let's also define the function $\text{slope}(Y)$ that takes in a time series, performs an ordinary least squares (OLS) fit, and outputs the OLS estimate of $m$, the slope of the line. You can then ask the following question: what is the probability that a time series sampled from our noisy line model will have a negative slope according to OLS? That is, what is the value of the following quantity?
\[
P(\text{slope}(Y) < 0)
\]If we could compute an estimate of $P(\text{slope}(Y)<0)$ we could use it to test if the loss is still decreasing. Fortunately, computing the above quantity turns out to be easy. In fact, $\text{slope}(Y)$ is a Gaussian random variable with this distribution:
\[
\text{slope}(Y) \sim \N\left(m, \frac{12 \sigma^2}{n^3-n}\right)
\]We don't know the true values of $m$ and $\sigma^2$, but they are easily estimated from data. We can obviously use $\text{slope}(Y)$ to estimate $m$. As for $\sigma^2$, it's customary to estimate it like this:
\[ \hat\sigma^2 = \frac{1}{n-2} \sum_{i=0}^{n-1} (y_i - \hat y_i)^2 \] which gives an unbiased estimate of the true $\sigma^2$. Here $y_i - \hat y_i$ is the difference between the observed time series value at time $i$ and the value predicted by the OLS fitted line at time $i$. I should point out that none of this is new stuff, in fact, these properties of OLS are discussed in detail on the Wikipedia page about OLS.
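For reference, with $x_i = i$ as the time index, the OLS estimates used above are just the standard ones:
\[
\text{slope}(Y) = \hat m = \frac{\sum_{i=0}^{n-1} (x_i - \bar x)(y_i - \bar y)}{\sum_{i=0}^{n-1} (x_i - \bar x)^2}, \qquad \hat b = \bar y - \hat m \bar x, \qquad \hat y_i = \hat m x_i + \hat b
\] where $\bar x$ and $\bar y$ are the sample means of the time indices and the loss values.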
So let's recap. We need a method to decide if the loss is trending downward or not. I'm suggesting that you use $P(\text{slope}(Y) < 0)$, the probability that a line fit to your loss curve will have negative slope. Moreover, as discussed above, this probability is easy to compute since it's just a question about a simple Gaussian variable and the two parameters of the Gaussian variable are given by a straightforward application of OLS.
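To make that concrete, here is a minimal Python sketch of the whole computation. The function name probability_slope_is_negative() is my own, not part of dlib, and the sketch assumes numpy and scipy are available: it fits the line with OLS, estimates $\sigma^2$ from the residuals, and evaluates the Gaussian CDF at zero.

import numpy as np
from scipy.stats import norm

def probability_slope_is_negative(Y):
    # Estimate P(slope(Y) < 0) for a time series of loss values Y.
    # Needs at least 3 points so the variance estimate below is defined.
    Y = np.asarray(Y, dtype=float)
    n = len(Y)
    x = np.arange(n)
    # OLS fit: slope and intercept of the best fitting line.
    m_hat, b_hat = np.polyfit(x, Y, 1)
    # Unbiased estimate of the noise variance from the residuals.
    residuals = Y - (m_hat * x + b_hat)
    sigma2 = np.sum(residuals**2) / (n - 2)
    # The OLS slope estimate is Gaussian with variance 12*sigma^2/(n^3 - n).
    slope_std = np.sqrt(12 * sigma2 / (n**3 - n))
    # Probability mass of that Gaussian below zero.
    return norm.cdf(0.0, loc=m_hat, scale=slope_std)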
You should also note that the variance of $\text{slope}(Y)$ decays at the very quick rate of $O(1/n^3)$, where $n$ is the number of loss samples. So it becomes very accurate as the length of the time series grows. To illustrate just how accurate this is, let's look at some examples. The figure below shows four different time series plots, each consisting of $n=4000$ points. Each plot is a draw from our noisy line model with parameters: $b=0$, $\sigma^2=1$, and $m \in \{0.001, 0.00004, -0.00004, -0.001\}$. For each of these noisy plots I've computed $P(\text{slope}(Y) < 0)$ and included it in the title.
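You can reproduce this kind of experiment with a few lines of code. Continuing from the sketch above, this draws one series from the noisy line model for each slope used in the figure and prints the estimated probability (exact values will vary from draw to draw):

rng = np.random.default_rng(0)
n, b, sigma = 4000, 0.0, 1.0
for m in [0.001, 0.00004, -0.00004, -0.001]:
    # One draw from y_i = m*i + b + Gaussian noise with variance sigma^2.
    Y = m * np.arange(n) + b + rng.normal(0.0, sigma, size=n)
    print(f"m = {m:+.5f}   P(slope < 0) = {probability_slope_is_negative(Y):.3f}")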
I find that a nice way to parameterize this in actual code is to count the number of mini-batches that executed while $P(\text{slope}(Y) < 0) < 0.51$. That is, find out how many loss values you have to look at before there is evidence the loss has been decreasing. To be very clear, this bit of pseudo-code implements the idea:
def count_steps_without_decrease(Y):
    steps_without_decrease = 0
    n = len(Y)
    # Consider every window Y[i:n] ending at the most recent loss value. The
    # result is the size of the largest such window that shows no evidence of
    # a downward trend, i.e. where P(slope < 0) is below 0.51.
    for i in reversed(range(n)):
        if P(slope(Y[i:n]) < 0) < 0.51:
            steps_without_decrease = n-i
    return steps_without_decrease
This gives a count of how many mini-batches have executed without any evidence that the loss is trending down, and once that count exceeds some threshold you reduce the learning rate. One wrinkle remains: real loss curves contain occasional transient spikes, so it also helps to run the same check on a copy of Y with the largest 10% of loss values discarded; call that version count_steps_without_decrease_robust(). Then reduce the learning rate when:
count_steps_without_decrease(Y) > threshold and count_steps_without_decrease_robust(Y) > threshold
That is, perform the check with and without outlier discarding. You need both checks because the 10% largest loss values might have occurred at the very beginning of Y. For example, maybe you are waiting for 1000 (i.e. threshold=1000) mini-batches to execute without showing evidence of the loss going down. And maybe the first 100 all showed a dropping loss while the last 900 were flat. The check that discarded the top 10% would erroneously indicate that the loss was NOT dropping. So you want to perform both checks, and if both agree that the loss isn't dropping then you can be confident it's not dropping.

It should be emphasized that this method isn't substantively different from what a whole lot of people already do when training deep neural networks. The only difference here is that the "look at the loss and see if it's decreasing" step is being done by a computer. The point of this blog post is to point out that this check is trivially automatable with boring old simple statistics. There is no reason to do it by hand. Let the computer do it and find something more productive to do with your time than babysitting SGD. The test is simple to implement yourself, but if you want to just call a function you can call dlib's count_steps_without_decrease() and count_steps_without_decrease_robust() routines from C++ or Python.
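Here is one way the robust check and the combined test might look in Python. This is just a sketch, not dlib's implementation: it assumes count_steps_without_decrease() has been made concrete with the probability_slope_is_negative() helper from earlier, and losses, threshold, and the factor-of-10 reduction are placeholders you'd pick for your own training loop.

def count_steps_without_decrease_robust(Y, quantile_discard=0.10):
    # Drop the largest 10% of loss values so transient spikes can't hide a
    # downward trend, then run the same check on what remains.
    Y = np.asarray(Y, dtype=float)
    cutoff = np.quantile(Y, 1.0 - quantile_discard)
    return count_steps_without_decrease(Y[Y <= cutoff])

# Reduce the learning rate only when both checks agree the loss has plateaued.
if (count_steps_without_decrease(losses) > threshold and
        count_steps_without_decrease_robust(losses) > threshold):
    learning_rate *= 0.1
    losses = []  # start counting afresh at the new learning rate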
Finally, one more useful thing you can do is the following: you can periodically check if $P(\text{slope}(Y) > 0) \gt 0.99$, that is, check if we are really certain that the loss is going up, rather than down. This can happen; I've had training runs that were going fine until the loss suddenly shot up, stayed high for a really long time, and basically ruined the training run. This doesn't seem to be too much of an issue with simple losses like the log-loss. However, structured loss functions that perform some kind of hard negative mining inside a mini-batch will sometimes go haywire if they hit a very bad mini-batch. You can fix this problem by simply reloading from an earlier network state, saved before the loss increased. But to do this you need a reliable way to measure "the loss is going up", and $P(\text{slope}(Y) > 0) \gt 0.99$ is excellent for this task. This idea is called backtracking and has a long history in numerical optimization. Backtracking significantly increases solver robustness in many cases and is well worth using.
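A minimal sketch of that periodic check, again reusing the probability_slope_is_negative() helper from above; window, losses, and the checkpoint reloading call are placeholders for whatever checkpointing mechanism your training code already has:

# P(slope(Y) > 0) is just 1 - P(slope(Y) < 0) for this Gaussian model.
recent = losses[-window:]
if len(recent) >= 3 and 1.0 - probability_slope_is_negative(recent) > 0.99:
    # We're quite sure the loss is trending up: backtrack by reloading the
    # last known-good network state and continue training from there.
    reload_network_state(last_good_checkpoint)  # placeholder for your own checkpointing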