Gradient flows, sometimes stochastic
Infinitesimal optimization, SDEs and generalisation
January 30, 2020 — October 10, 2024
ODE models (i.e. continuous limits) of optimisation trajectories, especially stochastic gradient descent. There are various flavours here, but I do not attempt to put them all into cunning order; I am stamp-collecting the approaches I find interesting. Very closely related: designing SDEs that sample from some desired density, which is more-or-less what every Markov chain Monte Carlo algorithm aims to do. Sometimes stochastic gradient flows sample from a tricky distribution, or nearly so, and then they are just another MCMC sampler. Sometimes this sampling distribution is even one we want, basically by accident.
1 Deterministic
We can think of gradient flows as a continuous limit of gradient descent: there is a (deterministic) ODE describing the training dynamics in the limit of an infinitesimal learning rate, as the number of steps goes to infinity.
I generally do not use these for anything useful, so I will leave this one aside for now.
TBD.
2 Stochastic gradient flows
Useful. I should probably write an intro, but in fact I think these whatsits are much simpler to understand if we simply dive in.
2.1 Classic stochastic gradient descent
The classic SGD optimizer updates the parameters by taking steps proportional to the negative of the gradient of the loss function with respect to the parameters. Here’s the discrete-time update rule:
\[ \mathbf{x}_{k+1} = \mathbf{x}_k - \eta \nabla \hat{U}(\mathbf{x}_k) \tag{1}\]
where
- \(\mathbf{x}_k\) is the parameter vector at iteration \(k\),
- \(\eta\) is the learning rate, controlling the step size, and
- \(\nabla \hat{U}(\mathbf{x}_k)\) is the stochastic gradient at iteration \(k\), typically estimated using a mini-batch of data.
\(\mathbf{x}\) is usually some set of interesting parameters, such as the weights of a neural network.
The stochastic gradient \(\nabla \hat{U}(\mathbf{x}_k)\) can be modelled, for want of a better option, as a noisy version of the true gradient \(\nabla U(\mathbf{x}_k)\),
\[ \nabla \hat{U}(\mathbf{x}_k) = \nabla U(\mathbf{x}_k) + \boldsymbol{\xi}_k \]
where
- \(\nabla U(\mathbf{x}_k)\) is the true gradient of the loss function, and
- \(\boldsymbol{\xi}_k\) is the gradient noise, assumed to follow a Gaussian distribution \(\boldsymbol{\xi}_k \sim \mathcal{N}(0, \Sigma)\).
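To make the noise model concrete, here is a minimal sketch on a toy least-squares problem (all names and constants here are mine, purely for illustration) that compares mini-batch gradients to the full-batch gradient and estimates \(\Sigma\) empirically:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, batch = 10_000, 3, 32
X = rng.normal(size=(n, d))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=n)

def grad(w, idx):
    """Gradient of the mean-squared-error loss over the rows in idx."""
    residual = X[idx] @ w - y[idx]
    return X[idx].T @ residual / len(idx)

w = np.zeros(d)
full_grad = grad(w, np.arange(n))               # stand-in for ∇U(x)
noise = np.stack([grad(w, rng.choice(n, batch, replace=False)) - full_grad
                  for _ in range(2_000)])       # samples of ξ_k
Sigma_hat = np.cov(noise.T)                     # empirical estimate of Σ
print(np.round(Sigma_hat, 4))
```

On a problem this tame the noise components look satisfyingly Gaussian, which is the central-limit hope invoked below.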
To transition from the discrete-time SGD updates to a continuous-time framework, we introduce a small time step \(\Delta t\) such that \(k \Delta t = t\). The goal is to express the parameter updates as differential equations by letting the time step size shrink to zero. Starting from the discrete update,
\[ \begin{aligned} \mathbf{x}_{k+1} &= \mathbf{x}_k - \eta \nabla \hat{U}(\mathbf{x}_k)\\ \frac{\mathbf{x}_{k+1} - \mathbf{x}_k}{\Delta t} &= -\frac{\eta}{\Delta t} \nabla \hat{U}(\mathbf{x}_k)\\ \frac{d\mathbf{x}(t)}{dt} &\approx -\frac{\eta}{\Delta t} \left( \nabla U(\mathbf{x}(t)) + \boldsymbol{\xi}(t) \right) \end{aligned} \]
We model the stochastic term \(\boldsymbol{\xi}(t)\) using a standard Wiener process \(\mathbf{W}(t)\),
\[ \boldsymbol{\xi}(t) \, dt = \Sigma^{1/2} \, d\mathbf{W}(t) \]
where
- \(\Sigma^{1/2}\) sets the noise intensity, and
- \(d\mathbf{W}(t)\) is the Wiener increment, satisfying \(d\mathbf{W}(t) \sim \mathcal{N}(\mathbf{0}, dt \, \mathbf{I})\).
Taking the limit as \(\Delta t \to 0\), and being deliberately cavalier about how the noise magnitude scales with the step size (see Mandt, Hoffman, and Blei (2017) for a careful accounting), we obtain
\[ d\mathbf{x}(t)= -\eta \nabla U(\mathbf{x}(t)) \, dt +\sqrt{\eta \Sigma} \, d\mathbf{W}(t) \]
This is a weird step. Should we assume that the Wiener noise is independent of the model state, and constant in time? This is a common assumption in SDE modelling, but it is obviously a brutal approximation; why would the noisiness of the gradient not vary over the course of training? Adam-style optimizers kinda deal with this by estimating an empirical, diagonal gradient-noise variance online in the second-moment estimator, but that does not change the \(d\mathbf{W}(t)\) term directly.
And why would the noise be Gaussian at all? It would not be, of course, but as with many systems we can hope that this approximation is not too bad because of some nice central limit theorem.
Example: Consider a quadratic loss function, where \(A\) is a positive definite matrix and \(\mathbf{b}\) a vector:
\[ \begin{aligned} U(\mathbf{x}) &= \frac{1}{2} \mathbf{x}^\top A \mathbf{x} +\mathbf{b}^\top\mathbf{x}\\ \nabla U(\mathbf{x}) &= A \mathbf{x} + \mathbf{b} \end{aligned} \]
Substituting into the SDE,
\[ d\mathbf{x}(t) = -\eta A \mathbf{x}(t) dt -\eta \mathbf{b} dt +\sqrt{\eta \Sigma} \, d\mathbf{W}(t) \]
This SDE describes an Ornstein-Uhlenbeck process, whose stationary distribution is Gaussian with mean \(-A^{-1}\mathbf{b}\) and whose variance is… annoying to calculate but occasionally useful (search for Lyapunov equation). There is a whole mini-industry based on that realisation; see Mandt, Hoffman, and Blei (2017).
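For the record, under the constant-noise assumption the calculation is standard: centring at the minimiser reduces the SDE to a textbook Ornstein-Uhlenbeck process, whose stationary covariance \(P\) solves a Lyapunov equation. Note that \(\eta\) cancels under this particular noise scaling, though it would not under others:
\[ \mathbf{x}(\infty) \sim \mathcal{N}\left( -A^{-1} \mathbf{b}, \, P \right), \qquad A P + P A^\top = \Sigma. \]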
From the SDE, we can re-discretize the update using the Euler-Maruyama method, which is a quick-and-dirty option but OK for exposition. This should get us back to something like our original SGD rule (Equation 1). For an SDE of the form
\[ d\mathbf{x}(t) = f(\mathbf{x}(t), t) dt + G(\mathbf{x}(t), t) d\mathbf{W}(t) \]
the Euler-Maruyama discretization over a time step \(\Delta t\) is
\[ \mathbf{x}_{k+1} = \mathbf{x}_k + f(\mathbf{x}_k, t_k) \Delta t + G(\mathbf{x}_k, t_k) \Delta \mathbf{W}_k \]
where
- \(\mathbf{x}_k = \mathbf{x}(t_k)\)
- \(t_k = k \Delta t\)
- \(\Delta \mathbf{W}_k = \mathbf{W}(t_{k+1}) - \mathbf{W}(t_k)\)
\(\Delta \mathbf{W}_k\) is a Gaussian random variable with zero mean and covariance \(\Delta t \, \mathbf{I}\) (increments of a standard Wiener process are normally distributed with variance proportional to the time increment).
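In code, a minimal Euler-Maruyama stepper for this general form might look as follows (a sketch with names of my own invention, not any particular library's API; it assumes \(G\) returns a matrix):

```python
import numpy as np

def euler_maruyama(f, G, x0, dt, n_steps, rng):
    """Integrate dx = f(x, t) dt + G(x, t) dW with fixed step size dt."""
    x = np.asarray(x0, dtype=float).copy()
    traj = [x.copy()]
    for k in range(n_steps):
        t = k * dt
        dW = rng.normal(scale=np.sqrt(dt), size=x.shape)  # ΔW_k ~ N(0, Δt I)
        x = x + f(x, t) * dt + G(x, t) @ dW
        traj.append(x.copy())
    return np.array(traj)
```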
Pattern matching to our case,
- Drift Function: \(f(\mathbf{x}(t), t) = -\eta \nabla U(\mathbf{x}(t))\)
- Diffusion Coefficient: \(G(\mathbf{x}(t), t) = \sqrt{\eta \Sigma}\) (assumed constant)
so the equation becomes
\[ \begin{aligned} \mathbf{x}_{k+1} &= \mathbf{x}_k - \eta \nabla U(\mathbf{x}_k) \Delta t + \sqrt{\eta \Sigma} \, \Delta \mathbf{W}_k\\ &= \mathbf{x}_k - \eta \nabla U(\mathbf{x}_k) \Delta t + \sqrt{\eta \Delta t} \, \Sigma^{1/2} \boldsymbol{\epsilon}_k \end{aligned} \]
where
- \(\Delta \mathbf{W}_k \sim \mathcal{N}(\mathbf{0}, \Delta t \, \mathbf{I})\)
- \(\boldsymbol{\epsilon}_k\) is a vector of standard normal random variables, \(\boldsymbol{\epsilon}_k \sim \mathcal{N}(\mathbf{0}, \mathbf{I})\)
- \(\Sigma^{1/2}\) denotes a matrix such that \(\Sigma^{1/2} (\Sigma^{1/2})^\top = \Sigma\).
That is indeed the classic SGD update rule back again (take \(\Delta t = 1\)), but now we have a natural way of re-scaling it in time, taking continuous limits, and so on.
For illustrative purposes, consider the quadratic loss function from earlier. Substituting into the discretized update:
\[ \begin{aligned} \mathbf{x}_{k+1} &= \mathbf{x}_k - \eta \Delta t \left( A \mathbf{x}_k + \mathbf{b} \right) + \sqrt{\eta \Sigma} \, \Delta \mathbf{W}_k\\ &= \left( \mathbf{I} - \eta \Delta t \, A \right) \mathbf{x}_k - \eta \Delta t \, \mathbf{b} + \sqrt{\eta \Sigma} \, \Delta \mathbf{W}_k. \end{aligned} \]
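Putting the pieces together, here is a sketch that simulates the quadratic case with the euler_maruyama stepper above and compares the empirical stationary mean to \(-A^{-1}\mathbf{b}\) (all constants arbitrary):

```python
A = np.array([[2.0, 0.5],
              [0.5, 1.0]])                  # positive definite
b = np.array([1.0, -1.0])
eta = 0.1
Sigma = 0.05 * np.eye(2)
L = np.linalg.cholesky(Sigma)               # a valid Σ^{1/2}

f = lambda x, t: -eta * (A @ x + b)         # drift: -η ∇U(x)
G = lambda x, t: np.sqrt(eta) * L           # constant diffusion: √η Σ^{1/2}

traj = euler_maruyama(f, G, np.zeros(2), dt=0.01, n_steps=200_000,
                      rng=np.random.default_rng(1))
print(traj[100_000:].mean(axis=0))          # empirical stationary mean
print(-np.linalg.solve(A, b))               # exact: -A⁻¹ b
```

The empirical covariance of the tail of the trajectory should likewise (approximately) solve the Lyapunov equation above.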
2.2 Adam
Adam is a popular optimizer that combines adaptive per-parameter learning rates with a flavour of momentum, updating the parameters using running estimates of the first and second moments of the gradients. An SDE for Adam is surely published somewhere, but I could not find it for some reason, so I derive one here.
Cf. Adam as an optimisation method, which is what it was designed for.
The resulting SDEs resemble underdamped Langevin dynamics (as seen in Langevin MCMC), because there is momentum and a friction term. They are a little weird though, because the adaptive mean estimates of the gradient drift over time.
We start from the original Adam update equations, introduce Wiener noise to model the gradient noise, pass to continuous time to obtain the stochastic differential equations (SDEs), and finally re-discretize using Euler-Maruyama.
The Adam optimizer updates the parameters using adaptive estimates of the first and second moments of the gradients. Here are the discrete-time update rules:
First moment update: \[ \mathbf{m}_k = \beta_1 \mathbf{m}_{k-1} + (1 - \beta_1) \nabla \hat{U}(\mathbf{x}_k) \tag{2}\]
where
- \(\mathbf{m}_k\) is the first moment estimate (momentum) at step \(k\),
- \(\nabla \hat{U}(\mathbf{x}_k)\) is the noisy gradient estimate at step \(k\),
- \(\beta_1\) is the decay rate for the momentum (typically close to 1).
Second moment update (the square is taken elementwise): \[ \mathbf{s}_k = \beta_2 \mathbf{s}_{k-1} + (1 - \beta_2) (\nabla \hat{U}(\mathbf{x}_k))^2 \tag{3}\]
where
- \(\mathbf{s}_k\) is the second moment estimate (variance),
- \(\beta_2\) is the decay rate for the second moment.
Parameter update, in which the square root and division are elementwise (we drop Adam’s bias-correction factors for simplicity): \[ \mathbf{x}_{k+1} = \mathbf{x}_k - \eta \frac{\mathbf{m}_k}{\sqrt{\mathbf{s}_k} + \epsilon} \tag{4}\]
where
- \(\mathbf{x}_k\) is the parameter vector at step \(k\),
- \(\eta\) is the learning rate,
- \(\epsilon\) is a small constant to prevent division by zero.
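For concreteness, a minimal sketch of Equations 2–4 in code (illustrative names only; real Adam implementations also apply the bias corrections we dropped):

```python
import numpy as np

def adam_step(x, m, s, g, eta=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update from a (noisy) gradient g, per Equations 2-4,
    without the usual bias-correction factors."""
    m = beta1 * m + (1 - beta1) * g        # first moment, Eq. 2
    s = beta2 * s + (1 - beta2) * g**2     # second moment, Eq. 3 (elementwise)
    x = x - eta * m / (np.sqrt(s) + eps)   # parameter update, Eq. 4
    return x, m, s
```

Iterating adam_step on the quadratic example above, with gradients A @ x + b plus synthetic noise, reproduces the familiar behaviour.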
Next, we approximate the Adam update equations as continuous-time SDEs by introducing a small time step \(\Delta t\) and converting the discrete update rules into differential equations. We introduce Wiener noise to model the stochasticity in the gradient estimates. And we are very loose! We do not bother to show the SDE is well-posed.
The first moment (momentum) evolves as a stochastic process influenced by the true gradient \(\nabla U(\mathbf{x}(t))\) and stochastic noise \(\mathbf{W}(t)\).
First moment update SDE:
\[ d\mathbf{m}(t) = -\gamma_m \mathbf{m}(t) \, dt + (1 - \beta_1) \nabla U(\mathbf{x}(t)) \, dt + \sigma \, d\mathbf{W}(t) \tag{5}\]
where
- \(\gamma_m \approx -\frac{\log(\beta_1)}{\Delta t}\) is the decay rate for momentum (note \(-\log \beta_1 \approx 1 - \beta_1\) for \(\beta_1\) close to 1),
- \(\nabla U(\mathbf{x}(t))\) is the true gradient,
- \(\sigma \, d\mathbf{W}(t)\) represents the Wiener noise capturing gradient stochasticity.
Second moment update SDE:
The second moment (variance) evolves based on the square of the gradient and also involves the stochastic noise term. The squared gradient noise term simplifies to a deterministic contribution due to Itô calculus.
\[ d\mathbf{s}(t) = -\gamma_s \mathbf{s}(t) \, dt + (1 - \beta_2) (\nabla U(\mathbf{x}(t)))^2 \, dt + \sigma^2 \, dt \tag{6}\]
where
- \(\gamma_s \approx -\frac{\log(\beta_2)}{\Delta t}\),
- The noise term \(\sigma^2 \, dt\) arises from the squared Wiener process \(d\mathbf{W}(t)^2\) and simplifies to a deterministic term due to Itô calculus.
Position update SDE:
The position update depends on the momentum and variance estimates, with the learning rate \(\eta\) controlling the step size.
\[ d\mathbf{x}(t) = - \eta \frac{\mathbf{m}(t)}{\sqrt{\mathbf{s}(t)} + \epsilon} \, dt \tag{7}\]
The division here is element-wise, and \(\epsilon\) is a small constant (vector) to prevent division by zero.
In this setup, the Wiener noise \(d\mathbf{W}(t)\) enters the momentum update directly, while its squared form appears in the variance update. The shared noise term reflects the fact that both the first and second moment estimates are based on the same underlying noisy gradient information.
We transform the system into state-space form, stacking the position, momentum, and variance into a state vector. We express the system as a matrix operation involving deterministic and stochastic parts.
We stack the position \(\mathbf{x}_k\), momentum \(\mathbf{m}_k\), and variance \(\mathbf{s}_k\) into a single state vector \(\mathbf{z}_k\),
\[ \mathbf{z}_k = \begin{bmatrix} \mathbf{x}_k \\ \mathbf{m}_k \\ \mathbf{s}_k \end{bmatrix} \]
Based on the SDEs, we now write the discrete-time process updates using the Euler-Maruyama discretisation.
Position Update:
\[ \mathbf{x}_{k+1} = \mathbf{x}_k - \eta \frac{\mathbf{m}_k}{\sqrt{\mathbf{s}_k} + \epsilon} \Delta t \]
Momentum Update (with Wiener noise):
\[ \mathbf{m}_{k+1} = \mathbf{m}_k - \gamma_m \mathbf{m}_k \Delta t + (1 - \beta_1) \nabla U(\mathbf{x}_k) \Delta t + \sigma \sqrt{\Delta t} \, \boldsymbol{\xi}_k \] where \(\boldsymbol{\xi}_k \sim \mathcal{N}(0, \mathbf{I})\) is the Gaussian noise driving the stochastic gradient estimates.
Variance Update (with simplified Wiener noise contribution):
\[ \mathbf{s}_{k+1} = \mathbf{s}_k - \gamma_s \mathbf{s}_k \Delta t + (1 - \beta_2) (\nabla U(\mathbf{x}_k))^2 \Delta t + \sigma^2 \Delta t \]
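These three updates are straightforward to iterate; here is a sketch on the quadratic loss from earlier (all constants arbitrary, with \(\gamma_m, \gamma_s\) computed from \(\beta_1, \beta_2\) as above):

```python
import numpy as np

rng = np.random.default_rng(2)
A = np.array([[2.0, 0.5],
              [0.5, 1.0]])
b = np.array([1.0, -1.0])
grad_U = lambda x: A @ x + b

dt, eta, sigma, eps = 0.01, 0.5, 0.1, 1e-8
beta1, beta2 = 0.9, 0.999
gamma_m = -np.log(beta1) / dt
gamma_s = -np.log(beta2) / dt

x, m, s = np.zeros(2), np.zeros(2), np.zeros(2)
for _ in range(50_000):
    g = grad_U(x)
    x = x - eta * m / (np.sqrt(s) + eps) * dt
    m = (m - gamma_m * m * dt + (1 - beta1) * g * dt
         + sigma * np.sqrt(dt) * rng.normal(size=2))
    s = s - gamma_s * s * dt + (1 - beta2) * g**2 * dt + sigma**2 * dt
print(x)                          # where the iterate ends up
print(-np.linalg.solve(A, b))     # the minimiser, for comparison
```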
We can now express the process update equations in state-space matrix update form. Define the system matrix \(\mathbf{A}\):
\[ \mathbf{A} = \begin{bmatrix} \mathbf{I} & -\eta \Delta t \operatorname{diag}\!\left(\frac{1}{\sqrt{\mathbf{s}_k} + \epsilon}\right) & \mathbf{0} \\ \mathbf{0} & (1 - \gamma_m \Delta t) \mathbf{I} & \mathbf{0} \\ \mathbf{0} & \mathbf{0} & (1 - \gamma_s \Delta t) \mathbf{I} \end{bmatrix} \]
Define the gradient-dependent vector \(\mathbf{b}_k\),
\[ \mathbf{b}_k = \begin{bmatrix} \mathbf{0} \\ (1 - \beta_1) \nabla U(\mathbf{x}_k) \Delta t \\ (1 - \beta_2) (\nabla U(\mathbf{x}_k))^2 \Delta t + \sigma^2 \Delta t \end{bmatrix} \]
Define the noise matrix \(\mathbf{G}\),
\[ \mathbf{G} = \begin{bmatrix} \mathbf{0} \\ \sigma \sqrt{\Delta t} \mathbf{I} \\ \mathbf{0} \end{bmatrix} \]
The final state-space form for the Adam-type optimizer, including the Wiener noise, is
\[ \mathbf{z}_{k+1} = \mathbf{A} \mathbf{z}_k + \mathbf{b}_k + \mathbf{G} \boldsymbol{\xi}_k \]
where
- \(\mathbf{z}_k = \begin{bmatrix} \mathbf{x}_k \\ \mathbf{m}_k \\ \mathbf{s}_k \end{bmatrix}\) is the state vector,
- \(\mathbf{A}\) is the system matrix governing the deterministic part of the update,
- \(\mathbf{b}_k\) is a gradient-dependent vector,
- \(\mathbf{G}\) maps the stochastic noise \(\boldsymbol{\xi}_k \sim \mathcal{N}(0, \mathbf{I})\) into the system.
This is a little misleading to my eye. It looks like a linear update, and in fact like two uncoupled systems. But we inject \(\mathbf{s}_k\) back into the position update, so the systems are both nonlinear and coupled. This does make things more complicated, though how much depends on what we wanted the state-space form for.
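To make that coupling explicit, here is the same iteration written in the “linear-looking” state-space form, with the system matrix rebuilt from \(\mathbf{s}_k\) at every step (a sketch of my own, following the block structure defined above):

```python
import numpy as np

def adam_state_space_step(z, grad_U, d, dt, eta, sigma,
                          beta1, beta2, eps, rng):
    """One z_{k+1} = A z_k + b_k + G ξ_k step. A depends on s_k,
    so the update is not actually linear."""
    x, m, s = z[:d], z[d:2*d], z[2*d:]
    gamma_m, gamma_s = -np.log(beta1) / dt, -np.log(beta2) / dt
    g = grad_U(x)

    D = np.diag(1.0 / (np.sqrt(s) + eps))     # the s_k-dependent block
    A = np.block([
        [np.eye(d), -eta * dt * D, np.zeros((d, d))],
        [np.zeros((d, d)), (1 - gamma_m * dt) * np.eye(d), np.zeros((d, d))],
        [np.zeros((d, d)), np.zeros((d, d)), (1 - gamma_s * dt) * np.eye(d)],
    ])
    b_k = np.concatenate([np.zeros(d),
                          (1 - beta1) * g * dt,
                          (1 - beta2) * g**2 * dt + sigma**2 * dt])
    G = np.concatenate([np.zeros((d, d)),
                        sigma * np.sqrt(dt) * np.eye(d),
                        np.zeros((d, d))])
    return A @ z + b_k + G @ rng.normal(size=d)
```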
2.3 Nesterov momentum
See Q. Li, Tai, and Weinan (2019).
3 Stochastic DEs for early-stage training
We are not near any optimum, let alone a quadratic one. We do not know much about \(U\); it could be anything at all. I have an inkling that this regime is what matters for choosing scaling rules for model training (Q. Li, Tai, and Weinan 2019; Z. Li, Malladi, and Arora 2021; Malladi et al. 2022).
4 Near the optimum
Typically we imagine that the loss \(U\) near an optimum is quadratic-ish (which it is in many useful cases), in which case the diffusion might look nice (see the examples above). Many SDE dynamics then sample from a Gaussian (cf. the Bernstein–von Mises theorem). This is particularly useful in giving us interpretations in terms of sampling from a Bayes posterior. See Bayes by Backprop for some applications.
“Limiting diffusion” is another term I have seen, which seems to describe diffusion around an optimum (Gu et al. 2022; Z. Li, Wang, and Arora 2021; Lyu, Li, and Arora 2023; Wang et al. 2023). I do not know if it is the same as the nearly-quadratic case. It seems to be regarded as useful for understanding generalisation. TBD.