Adaptive stochastic gradient descent
January 30, 2020 — October 8, 2024
Modern SGD algorithms are often of the “adaptive” flavour, which means that the learning rate is adaptively tuned for each parameter during the learning process, putting them somewhere between first-order and second-order methods.
Justifications for the adaptation are some mixture of theoretical, empirical, and most of all economic: many very big companies are built on the most famous method, Adam, so it is presumably pretty good. From 2016 to 2020, Sebastian Ruder maintained an overview of gradient descent optimization algorithms, which is my favourite introduction to the topic. Read that before hoping to get anything new from me.
There are also explicitly second-order flavours (Shampoo springs to mind).
1 Adam
The Adam optimizer (Kingma and Ba 2017) updates parameters using estimates of the first and second moments of the gradients. As the most popular optimizer in deep learning, it is worth understanding; by the same token, it has been analyzed to oblivion. There are many competent introductions to Adam. Here are two:
- Adam - Cornell University Computational Optimization Open Textbook
- Rahul Agarwal, Complete Guide to the Adam Optimization Algorithm
The update rule for the parameter \(\phi_t\) is
\[ \begin{aligned} \phi_{t+1} = \phi_t - \lambda_t \frac{m_t}{\sqrt{v_t} + \epsilon}, \end{aligned} \]
where
- \(\lambda_t\) is the learning rate at time \(t\).
- \(m_t\) is the exponentially decaying average of past gradients (first moment estimate).
- \(v_t\) is the exponentially decaying average of past squared gradients (second moment estimate).
- \(\epsilon\) is a small constant to prevent division by zero.
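To make this concrete, here is a minimal NumPy sketch of one step of the simplified update (the function name and hyperparameter defaults are my own; the bias corrections discussed below are omitted):

```python
import numpy as np

def adam_step(phi, grad, m, v, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One step of the simplified Adam update (bias correction omitted)."""
    m = beta1 * m + (1 - beta1) * grad       # first moment estimate m_t
    v = beta2 * v + (1 - beta2) * grad**2    # second moment estimate v_t
    phi = phi - lr * m / (np.sqrt(v) + eps)  # the update rule above
    return phi, m, v
```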
This is the workhorse of modern ML. Much has been written about it. Do I have anything new to add? Honestly, I do not know, because so much has been written about it that searching through it all is overwhelming, so I write out the bits I need here.
We can interpret the Adam update as solving a regularized least-squares problem at each time step. Specifically, we consider the following optimization problem:
\[ \begin{aligned} \phi_{t+1} = \arg\min_{\phi} \left[ \frac{1}{2} (\phi - \phi_t)^\top \mathbf{H}_t (\phi - \phi_t) + \lambda_t (\nabla L(\phi_t))^\top (\phi - \phi_t) \right], \end{aligned} \]
where
- \(\mathbf{H}_t\) is a positive definite matrix (regularization term).
- \(\nabla L(\phi_t)\) is the gradient of the loss function at \(\phi_t\).
- The first term penalizes deviations from the current parameter estimate \(\phi_t\).
- The second term encourages movement in the direction of the negative gradient.
To find \(\phi_{t+1}\), we take the gradient of the objective function with respect to \(\phi\) and set it to zero,
\[ \begin{aligned} \nabla_\phi \left[ \frac{1}{2} (\phi - \phi_t)^\top \mathbf{H}_t (\phi - \phi_t) + \lambda_t (\nabla L(\phi_t))^\top (\phi - \phi_t) \right] &= 0, \\ \mathbf{H}_t (\phi - \phi_t) + \lambda_t \nabla L(\phi_t) &= 0. \end{aligned} \]
Solving for \(\phi_{t+1}\),
\[ \begin{aligned} \phi_{t+1} = \phi_t - \lambda_t \mathbf{H}_t^{-1} \nabla L(\phi_t). \end{aligned} \]
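As a quick sanity check of this closed form, here is a sketch that minimizes the objective numerically and compares (the diagonal \(\mathbf{H}_t\), stand-in gradient, and step size are arbitrary):

```python
import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(0)
d = 4
H = np.diag(rng.random(d) + 0.5)   # positive definite H_t (diagonal for simplicity)
g = rng.normal(size=d)             # stand-in for grad L(phi_t)
phi_t = rng.normal(size=d)
lam = 0.1                          # learning rate lambda_t

def objective(phi):
    step = phi - phi_t
    return 0.5 * step @ H @ step + lam * g @ step

# The numerical minimizer should agree with phi_t - lam * H^{-1} g
res = minimize(objective, phi_t)
closed_form = phi_t - lam * np.linalg.solve(H, g)
assert np.allclose(res.x, closed_form, atol=1e-5)
```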
Adam uses some extra tricks I have not mentioned here, like bias correction terms for \(m_t\) and \(v_t\), and there are many details in how we actually estimate \(m_t\) and \(v_t\). The complications do not add anything to this introduction, so I omit them.
In the Adam optimizer, the parameters are updated using the first moment estimate \(m_t\) instead of the exact gradient \(\nabla L(\phi_t)\). Additionally, the update is scaled by the inverse square root of the second moment estimate \(v_t\). We can make the following substitutions:
- Replace \(\nabla L(\phi_t)\) with \(m_t\) (assuming \(m_t \approx \nabla L(\phi_t)\)).
- Let \(\mathbf{H}_t = \operatorname{diag}(\sqrt{v_t} + \epsilon)\).
With these substitutions, the update becomes
\[ \begin{aligned} \phi_{t+1} = \phi_t - \lambda_t \operatorname{diag}\left( \frac{1}{\sqrt{v_t} + \epsilon} \right) m_t, \end{aligned} \]
which matches the Adam update rule.
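We can check numerically that this preconditioned step coincides with the elementwise Adam update (a sketch with arbitrary stand-in values):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 5
m_t = rng.normal(size=d)     # stand-in first moment estimate
v_t = rng.random(d) + 0.1    # stand-in second moment estimate
phi_t = rng.normal(size=d)
lam, eps = 1e-3, 1e-8

# Closed-form solution with H_t = diag(sqrt(v_t) + eps)
H = np.diag(np.sqrt(v_t) + eps)
phi_ls = phi_t - lam * np.linalg.solve(H, m_t)

# Elementwise Adam update
phi_adam = phi_t - lam * m_t / (np.sqrt(v_t) + eps)

assert np.allclose(phi_ls, phi_adam)
```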
1.1 As a Bayes procedure 1
Least squares is only ever one Bayesian saying “well, actually…” away from being a Gaussian posterior update.
Suppose the parameter vector at time \(t\), \(\phi_t\), is a random variable, and let us see what kind of Bayesian interpretation we can give to the Adam update rule. Assume a Gaussian prior over the parameters \(\phi\),
\[ \begin{aligned} p(\phi) = \mathcal{N}(\phi; \bar{\phi}_t, \mathbf{H}_t^{-1}), \end{aligned} \]
where
- \(\bar{\phi}_t\) is the mean of the prior distribution.
- \(\mathbf{H}_t^{-1}\) is the covariance matrix of the prior, reflecting our confidence in the current estimate \(\bar{\phi}_t\).
The likelihood function is based on observing a gradient \(m_t\), which provides information about the parameters,
\[ \begin{aligned} p(m_t | \phi) = \mathcal{N}(m_t; 0, \lambda_t^{-1} \mathbf{H}_t), \end{aligned} \]
where the likelihood encourages \(\phi\) to move in the direction that reduces the loss function. Applying Bayes’ theorem,
\[ \begin{aligned} p(\phi | m_t) &\propto p(m_t | \phi) p(\phi)\\ &= \mathcal{N}(\phi; \bar{\phi}_{t+1}, \mathbf{H}_t^{-1}), \end{aligned} \]
where \(\bar{\phi}_{t+1}\) is the mean of the posterior distribution, given by
\[ \begin{aligned} \bar{\phi}_{t+1} = \bar{\phi}_t - \lambda_t \mathbf{H}_{t}^{-1} m_t. \end{aligned} \]
This matches the update rule derived from the regularized least-squares problem and the Adam optimizer. Let’s consider when the likelihood function
\[ \begin{aligned} p(m_t | \phi) = \mathcal{N}(m_t; 0, \lambda_t^{-1} \mathbf{H}_t) \end{aligned} \]
might be a meaningful choice. In those terms, what likelihood function have we just assumed? We need to find a likelihood function \(p(\text{data} | \phi)\) such that its negative log-likelihood corresponds to the linear term \(\lambda_t (\nabla L(\phi_t))^\top (\phi - \phi_t)\) in our optimization problem.
However, a standard Gaussian likelihood leads to a quadratic term in \(\phi\), not a linear one. What are we even doing, then?
Let’s define a synthetic observation \(y_t\) such that \[ \begin{aligned} y_t = A (\phi - \bar{\phi}_t) + \epsilon_t \end{aligned} \]
where
- \(A\) is a matrix we will define.
- \(\epsilon_t\) is Gaussian noise: \(\epsilon_t \sim \mathcal{N}(0, \mathbf{V}_t)\).

Our goal is to choose \(A\) and \(\mathbf{V}_t\) such that the negative log-likelihood corresponds to the linear term in our optimization problem.
The likelihood function is
\[ \begin{aligned} p(y_t | \phi) = \mathcal{N}(y_t; A (\phi - \bar{\phi}_t), \mathbf{V}_t) \end{aligned} \]
The negative log-likelihood (up to a constant) is:
\[ \begin{aligned} - \log p(y_t | \phi) = \frac{1}{2} (y_t - A (\phi - \bar{\phi}_t))^\top \mathbf{V}_t^{-1} (y_t - A (\phi - \bar{\phi}_t)) \end{aligned} \]
If we fix \(A = I\) (the identity matrix), the negative log-likelihood becomes
\[ \begin{aligned} - \log p(y_t | \phi) &= \frac{1}{2} (y_t - (\phi - \bar{\phi}_t))^\top \mathbf{V}_t^{-1} (y_t - (\phi - \bar{\phi}_t))\\ &= \frac{1}{2} (\phi - \bar{\phi}_t - y_t)^\top \mathbf{V}_t^{-1} (\phi - \bar{\phi}_t - y_t) \end{aligned} \]
Expanding the square, the term \(\frac{1}{2} y_t^\top \mathbf{V}_t^{-1} y_t\) is constant in \(\phi\); if we further assume that \(\phi - \bar{\phi}_t\) is small but \(y_t\) is not, the term quadratic in \(\phi - \bar{\phi}_t\) is negligible, and we can approximate the negative log-likelihood (up to a constant) as
\[ \begin{aligned} - \log p(y_t | \phi) \approx - (\phi - \bar{\phi}_t)^\top \mathbf{V}_t^{-1} y_t \end{aligned} \]
This is an assumption that gives us a linear term in \(\phi\). (It’s not very satisfying, though.) The negative log of the posterior is (up to a constant)
\[ \begin{aligned} - \log p(\phi | y_t) &= - \log p(y_t | \phi) - \log p(\phi) \\ & \approx \frac{1}{2} (\phi - \bar{\phi}_t)^\top \mathbf{H}_t (\phi - \bar{\phi}_t) - (\phi - \bar{\phi}_t)^\top \mathbf{V}_t^{-1} y_t \end{aligned} \]
To find the posterior mean, which is also the posterior mode, \(\bar{\phi}_{t+1}=\operatorname{argmin}_\phi \left[ -\log p(\phi|y_t) \right]\), we set the gradient of the negative log-posterior to zero,
\[ \begin{aligned} 0&=\nabla_\phi \left[ \frac{1}{2} (\phi - \bar{\phi}_t)^\top \mathbf{H}_t (\phi - \bar{\phi}_t) - (\phi - \bar{\phi}_t)^\top \mathbf{V}_t^{-1} y_t \right]\\ &=\mathbf{H}_t (\phi - \bar{\phi}_t) - \mathbf{V}_t^{-1} y_t\\ \Rightarrow \bar{\phi}_{t+1} &= \bar{\phi}_t + \mathbf{H}_t^{-1} \mathbf{V}_t^{-1} y_t \end{aligned} \]
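A quick numerical check of that stationarity condition (a sketch; the diagonal matrices and the observation are arbitrary stand-ins):

```python
import numpy as np

rng = np.random.default_rng(1)
d = 4
H = np.diag(rng.random(d) + 0.5)   # prior precision H_t (diagonal, positive definite)
V = np.diag(rng.random(d) + 0.5)   # observation covariance V_t
y = rng.normal(size=d)             # synthetic observation y_t
phi_bar = rng.normal(size=d)       # prior mean

# Candidate posterior mode from the derivation above
phi_next = phi_bar + np.linalg.solve(H, np.linalg.solve(V, y))

# The gradient of the negative log-posterior should vanish there
grad = H @ (phi_next - phi_bar) - np.linalg.solve(V, y)
assert np.allclose(grad, 0)
```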
If we now identify the synthetic observation with the negatively scaled gradient, \(y_t = -\lambda_t \nabla L(\bar{\phi}_t) + \epsilon_t\), then, taking the noise to have zero mean,
\[ \begin{aligned} \bar{\phi}_{t+1} = \bar{\phi}_t - \lambda_t \mathbf{H}_t^{-1} \mathbf{V}_t^{-1} \nabla L(\bar{\phi}_t). \end{aligned} \] To recover the Adam update rule, we need to choose \(\mathbf{H}_t^{-1} \mathbf{V}_t^{-1} = \operatorname{diag}\left( \frac{1}{\sqrt{v_t} + \epsilon} \right)\).
While that was coherent, I am not sure I learned anything. This still seems arbitrary; for one thing, the Adam update uses a square root of a variance matrix, which does not naturally arise here.
1.2 As a Bayes procedure 2
Let us try something different and see if we can get a more plausible rationale. We propose that the uncertainty in the parameters is proportional to the square root of the gradient variance,
\[ \begin{aligned} \operatorname{Var}[\phi_t] \propto \sqrt{v_t}. \end{aligned} \]
This means that the variance of each parameter \(\phi_t^i\) is proportional to \(\sqrt{v_t^i}\), where \(v_t^i\) is the estimated variance of the \(i\)-th parameter’s gradient. Assume a prior distribution over the parameters \(\phi\),
\[ \begin{aligned} p(\phi) = \mathcal{N}(\phi; \bar{\phi}_t, \Sigma_t), \end{aligned} \]
where
- \(\bar{\phi}_t\) is the mean of the prior distribution (current estimate).
- \(\Sigma_t\) is the prior covariance matrix, reflecting our uncertainty about \(\phi_t\).

We set the prior covariance matrix to
\[ \begin{aligned} \Sigma_t = \sigma^2 \operatorname{diag}(\sqrt{v_t}), \end{aligned} \]
where \(\sigma^2\) is a scaling factor. The observed gradient \(m_t\) provides information about the parameters,
\[ \begin{aligned} m_t = \nabla L(\bar{\phi}_t) + \epsilon_t, \end{aligned} \]
where \(\nabla L(\bar{\phi}_t)\) is the true gradient of the loss function at \(\bar{\phi}_t\), and \(\epsilon_t\) represents the noise in the gradient estimation due to stochastic sampling, with \(\epsilon_t \sim \mathcal{N}(0, V_t),\) and \(V_t = \operatorname{diag}(v_t)\). Now, we want the posterior mean update for \(\phi\),
\[ \begin{aligned} \bar{\phi}_{t+1} = \bar{\phi}_t - K_t m_t, \end{aligned} \]
where \(K_t\) is a gain matrix that adjusts the update based on the uncertainties in the parameters and the observations (gradients). The gain matrix \(K_t\) is the usual one, \[ \begin{aligned} K_t = \Sigma_t (V_t + \Sigma_t)^{-1}. \end{aligned} \]
Given that \(\Sigma_t\) and \(V_t\) are diagonal matrices in Adam, the inversion and multiplication are straightforward.
Assuming \(\Sigma_t \ll V_t\) (i.e., parameter uncertainty is small compared to gradient noise), we can approximate
\[ \begin{aligned} V_t + \Sigma_t &\approx V_t,\\ K_t &\approx \Sigma_t V_t^{-1}. \end{aligned} \]
Substituting \(\Sigma_t = \sigma^2 \operatorname{diag}(\sqrt{v_t})\) and \(V_t = \operatorname{diag}(v_t)\), and identifying \(\lambda_t = \sigma^2\), we have
\[ \begin{aligned} K_t &= \sigma^2 \operatorname{diag}(\sqrt{v_t}) \operatorname{diag}(v_t)^{-1} = \sigma^2 \operatorname{diag}\left( \frac{1}{\sqrt{v_t}} \right),\\ \bar{\phi}_{t+1} &= \bar{\phi}_t - K_t m_t = \bar{\phi}_t - \sigma^2 \operatorname{diag}\left( \frac{1}{\sqrt{v_t}} \right) m_t\\ &= \bar{\phi}_t - \lambda_t \frac{m_t}{\sqrt{v_t}}. \end{aligned} \]
This matches the Adam update rule (neglecting the small constant \(\epsilon\)).
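A quick numeric check that the small-\(\Sigma_t\) approximation of the gain is accurate (a sketch with arbitrary stand-in values; \(\sigma^2\) is taken much smaller than the entries of \(v_t\)):

```python
import numpy as np

rng = np.random.default_rng(2)
d = 6
v_t = rng.random(d) + 0.5                # second moment estimates v_t
sigma2 = 1e-4                            # sigma^2, chosen so Sigma_t << V_t
Sigma = sigma2 * np.diag(np.sqrt(v_t))   # prior covariance Sigma_t
V = np.diag(v_t)                         # gradient-noise covariance V_t

K_exact = Sigma @ np.linalg.inv(V + Sigma)       # K_t = Sigma_t (V_t + Sigma_t)^{-1}
K_approx = sigma2 * np.diag(1.0 / np.sqrt(v_t))  # K_t ≈ sigma^2 diag(1/sqrt(v_t))

print(np.max(np.abs(K_exact - K_approx)))  # tiny when sigma^2 << v_t
```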
By setting the parameter uncertainty proportional to the square root of the gradient variance, we adjust the parameter updates based on how much we “trust” the gradient information:
- High Gradient Variance (\(v_t^i\) large): Indicates high uncertainty in the gradient estimate for parameter \(\phi_t^i\), leading to a small update due to larger \(\sqrt{v_t^i}\) in the denominator.
- Low Gradient Variance (\(v_t^i\) small): Indicates more reliable gradient information, leading to a larger update for that parameter.
The coincidence of the prior uncertainty happening to be proportional to the square root of the gradient variance still feels a little arbitrary. We are using those second-moment estimates twice, once to invent the prior and then again for the likelihood, and in a strange way.
1.3 As a Bayes procedure 3
OK, that was weirder than I expected, especially because of that danged square root. It can be eliminated, but with a more involved method (Lin et al. 2024). Maybe we should follow M. E. Khan and Rue (2024) and interpret it as a natural gradient method. TBC
1.4 As a gradient flow
See gradient flows
1.5 AdamW
TBD.
2 Nadam
TBD
3 Adagrad
TBD
4 RMSprop
TBD
5 As a natural gradient method
6 Sparse variants
TBD