Gradient flows, sometimes stochastic
Infinitesimal optimization, SDEs and generalisation
January 29, 2020 — October 10, 2024
ODE models (i.e. continuous limits) of optimisation trajectories, especially stochastic gradient descent. There are various flavours here, but I do not attempt to put them all into cunning order; I am stamp-collecting the approaches I find interesting. Very closely related: designing SDEs that sample from some desired density, which is more-or-less what every Markov chain Monte Carlo algorithm aims to do. Sometimes stochastic gradient flows sample from a tricky distribution, or nearly so, and then they are just another MCMC sampler. Sometimes this sampling distribution is even one we want, basically by accident.
1 Deterministic
We can think of gradient flows as a continuous limit of gradient descent. There is a (deterministic) ODE corresponding to training dynamics at an infinitesimal learning rate as the number of steps goes to infinity.
I generally do not use these for anything useful, so I will leave this one aside for now.
TBD.
2 Stochastic gradient flows
Useful. I should probably write an intro? But in fact, I think these whatsits are way simpler to understand if we simply dive in.
2.1 Classic stochastic gradient descent
The classic SGD optimizer updates the parameters by taking steps proportional to the negative of the gradient of the loss function with respect to the parameters. Here’s the discrete-time update rule:

$$\theta_{k+1} = \theta_k - \eta\, \hat{g}_k \tag{1}$$

where $\theta_k$ is the parameter vector at iteration $k$, $\eta$ is the learning rate, controlling the step size, and $\hat{g}_k$ is the stochastic gradient at iteration $k$, typically estimated using a mini-batch of data.
The stochastic gradient decomposes as

$$\hat{g}_k = \nabla L(\theta_k) + \xi_k$$

where $\nabla L(\theta_k)$ is the true gradient of the loss function, and $\xi_k$ is the gradient noise, assumed to follow a Gaussian distribution $\xi_k \sim \mathcal{N}(0, \Sigma)$.
To transition from the discrete-time SGD updates to a continuous-time framework, we introduce a small time step $\Delta t = \eta$, so that

$$\theta_{k+1} = \theta_k - \Delta t\, \nabla L(\theta_k) - \Delta t\, \xi_k.$$

Taking the limit as $\Delta t \to 0$, we obtain the SDE

$$d\theta_t = -\nabla L(\theta_t)\, dt + \sigma\, dW_t.$$

We model the stochastic term as $\xi_k\, \Delta t \approx \sigma\, \Delta W_k$, where $\sigma$ defines the noise intensity, satisfying $\sigma\sigma^\top = \eta\, \Sigma$, and $\Delta W_k$ is the Wiener increment, satisfying $\Delta W_k \sim \mathcal{N}(0, \Delta t\, I)$.
This is a weird step. Should we assume that the Wiener noise is independent of model state, and constant in time? This is a common assumption in SDEs, but it is obviously a brutal approximation; why would the noisiness of the gradient not vary in time? We kinda deal with this by estimating an empirical, diagonal gradient noise variance online in the second moment estimator, but that doesn’t change the constant, state-independent noise baked into the SDE itself.
And why would it be Gaussian at all? It would not be, of course, but as with many systems we can hope that this approximation is not too bad because of some nice central limit theorem.
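As a sanity check on that central-limit hope, here is a quick simulation of mini-batch gradient noise for a toy linear least-squares model (the model, names, and values here are all illustrative, not from any particular paper): the noise is centred and its scale shrinks like $1/\sqrt{\text{batch}}$, consistent with an approximately Gaussian $\xi$.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical setup: linear model with squared loss, so the per-example
# gradient at theta is g_i = x_i * (x_i . theta - y_i).
n, d, batch = 50_000, 2, 128
X = rng.standard_normal((n, d))
y = X @ np.array([1.0, -1.0]) + rng.standard_normal(n)
theta = np.zeros(d)

per_example = X * ((X @ theta - y)[:, None])   # (n, d) per-example gradients
full_grad = per_example.mean(axis=0)

# Mini-batch gradient noise xi = g_batch - full gradient, over many batches.
noise = np.array([
    per_example[rng.choice(n, batch, replace=False)].mean(axis=0) - full_grad
    for _ in range(5_000)
])
print(noise.mean(axis=0))   # near zero
print(noise.std(axis=0))    # roughly per-example std / sqrt(batch), CLT-style
```

A histogram of `noise` would look convincingly bell-shaped here, but that is partly because the per-example gradients of this toy model are already well-behaved; heavy-tailed per-example gradients would weaken the Gaussian approximation.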
Example: Consider a quadratic loss function defined via $L(\theta) = \frac{1}{2}\theta^\top A \theta$, where $A$ is a symmetric positive-definite matrix, so $\nabla L(\theta) = A\theta$.

Substituting into the SDE,

$$d\theta_t = -A\theta_t\, dt + \sigma\, dW_t.$$

This SDE describes an Ornstein-Uhlenbeck process, whose stationary distribution is Gaussian with mean $0$ and covariance $\Sigma_\infty$ solving the Lyapunov equation $A\Sigma_\infty + \Sigma_\infty A^\top = \sigma\sigma^\top$.
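For a diagonal $A$ with isotropic noise the Lyapunov equation decouples, giving per-coordinate stationary variance $\sigma^2 / (2 a_i)$. A minimal simulation check of that claim, with assumed illustrative values:

```python
import numpy as np

rng = np.random.default_rng(0)

# Quadratic loss L(theta) = 0.5 * theta^T A theta with diagonal A (illustrative values)
a = np.array([1.0, 4.0])   # curvatures, i.e. the diagonal of A
sigma = 0.5                # constant isotropic noise intensity
dt, n_steps, n_chains = 1e-2, 2_000, 2_000

theta = np.zeros((n_chains, 2))
for _ in range(n_steps):
    z = rng.standard_normal((n_chains, 2))
    # Euler-Maruyama step for d(theta) = -A theta dt + sigma dW
    theta = theta - theta * a * dt + sigma * np.sqrt(dt) * z

emp_var = theta.var(axis=0)
lyap_var = sigma**2 / (2 * a)   # stationary variance from the Lyapunov equation
print(emp_var, lyap_var)        # should roughly agree
```

Note the tidy inverse relationship: sharper curvature means a tighter stationary distribution, which is one intuition for why SGD noise is supposed to prefer flat minima.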
From the SDE, we can re-discretize the update using the Euler-Maruyama method, which is a quick-and-dirty option but OK for exposition. This should get us back to something like our original SGD rule (Equation 1). For an SDE of the form

$$dX_t = f(X_t)\, dt + g(X_t)\, dW_t,$$

the Euler-Maruyama discretization over a time step $\Delta t$ is

$$X_{k+1} = X_k + f(X_k)\, \Delta t + g(X_k)\, \sqrt{\Delta t}\, z_k,$$

where $z_k \sim \mathcal{N}(0, I)$.

Pattern matching to our case,

- Drift function: $f(\theta) = -\nabla L(\theta)$
- Diffusion coefficient: $g(\theta) = \sigma$ (assumed constant)

so the equation becomes

$$\theta_{k+1} = \theta_k - \nabla L(\theta_k)\, \Delta t + \sigma\, \sqrt{\Delta t}\, z_k,$$

where $z_k$ is a vector of standard normal random variables and $\sigma$ denotes a matrix such that $\sigma\sigma^\top = \eta\, \Sigma$.
That is indeed the classic SGD update rule, but now we have a natural way of re-scaling it in time, and taking continuous limits etc.
For illustrative purposes, consider the quadratic loss function from earlier. Substituting into the state-space model:

$$\theta_{k+1} = (I - A\,\Delta t)\,\theta_k + \sigma\,\sqrt{\Delta t}\, z_k.$$
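One payoff of the continuous-time view is exactly this rescaling: step size and number of steps become exchangeable as long as we hold the continuous time $t = k\,\Delta t$ fixed and scale the noise by $\sqrt{\Delta t}$. A sketch on a scalar quadratic, with illustrative parameter values:

```python
import numpy as np

def simulate(dt, t_end, a=2.0, sigma=0.5, n_chains=4000, seed=1):
    """Euler-Maruyama for d(theta) = -a theta dt + sigma dW, started at theta = 1."""
    rng = np.random.default_rng(seed)
    theta = np.ones(n_chains)
    for _ in range(int(round(t_end / dt))):
        z = rng.standard_normal(n_chains)
        theta = theta - a * theta * dt + sigma * np.sqrt(dt) * z
    return theta

# Different step sizes, same continuous time horizon: similar distributions.
coarse = simulate(dt=0.05, t_end=2.0)
fine = simulate(dt=0.005, t_end=2.0, seed=2)
print(coarse.mean(), fine.mean())   # both near exp(-a * t_end)
print(coarse.var(), fine.var())
```

The two runs take 40 and 400 steps respectively but land on (approximately) the same distribution, up to discretization bias of order $\Delta t$.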
2.2 Adam
Adam is a popular optimizer that combines the benefits of adaptive learning rates and a flavour of momentum. The Adam optimizer updates the parameters using adaptive estimates of the first and second moments of the gradients. An SDE for Adam is surely published somewhere, but I couldn’t find it for some reason, so I derive it here.
cf. Adam as an optimisation method, which is what it is designed for.
The resulting SDEs resemble underdamped Langevin dynamics (as seen in Langevin MCMC), because there is momentum and a friction term. They are a little weird though, because the adaptive mean estimates of the gradient drift over time.
We start from the original Adam equations, build the continuous-time version of the stochastic differential equations (SDEs), introduce the Wiener noise to model gradient noise, and then write the SDEs. Finally we can discretize using Euler-Maruyama.
The Adam optimizer updates the parameters using adaptive estimates of the first and second moments of the gradients. Here are the discrete-time update rules:

First moment update:

$$m_{k+1} = \beta_1 m_k + (1 - \beta_1)\, \hat{g}_k$$

where $m_k$ is the first moment estimate (momentum) at step $k$, $\hat{g}_k$ is the noisy gradient estimate at step $k$, and $\beta_1$ is the decay rate for the momentum (typically close to 1).

Second moment update:

$$v_{k+1} = \beta_2 v_k + (1 - \beta_2)\, \hat{g}_k^2$$

where $v_k$ is the second moment estimate (variance), $\beta_2$ is the decay rate for the second moment, and the square is element-wise.

Parameter update (ignoring the bias-correction terms for simplicity):

$$\theta_{k+1} = \theta_k - \eta\, \frac{m_{k+1}}{\sqrt{v_{k+1}} + \epsilon}$$

where $\theta_k$ is the parameter vector at step $k$, $\eta$ is the learning rate, and $\epsilon$ is a small constant to prevent division by zero.
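For concreteness, here are those update rules as code, run on a noisy quadratic (bias correction omitted, and all parameter values are just the usual illustrative defaults):

```python
import numpy as np

def adam_step(theta, m, v, grad, eta=1e-2, beta1=0.9, beta2=0.999, eps=1e-8):
    """One step of the discrete Adam updates (no bias correction)."""
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad**2
    theta = theta - eta * m / (np.sqrt(v) + eps)
    return theta, m, v

# Toy run: quadratic loss L(theta) = 0.5 * theta^T A theta with noisy gradients.
rng = np.random.default_rng(0)
a = np.array([1.0, 4.0])
theta, m, v = np.array([2.0, -2.0]), np.zeros(2), np.zeros(2)
for _ in range(5_000):
    grad = a * theta + 0.1 * rng.standard_normal(2)   # true gradient + Gaussian noise
    theta, m, v = adam_step(theta, m, v, grad)
print(theta)   # hovers near the optimum at zero
```

Because $v$ tracks the squared gradients, the effective per-coordinate step is roughly $\eta$ times a unit-scale direction, which is why the continuous-time limit below behaves differently from plain SGD.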
Next, we approximate the Adam update equations as continuous-time SDEs by introducing a small time step $\Delta t$ and decay rates per unit time, $\gamma_1 = (1 - \beta_1)/\Delta t$ and $\gamma_2 = (1 - \beta_2)/\Delta t$.

The first moment (momentum) evolves as a stochastic process influenced by the true gradient $\nabla L(\theta_t)$.

First moment update SDE:

$$dm_t = \gamma_1\left(\nabla L(\theta_t) - m_t\right) dt + \gamma_1\, \sigma\, dW_t$$

where $\gamma_1$ is the decay rate for momentum, $\nabla L(\theta_t)$ is the true gradient, and $dW_t$ represents the Wiener noise capturing gradient stochasticity.
Second moment update SDE:

The second moment (variance) evolves based on the square of the gradient and also involves the stochastic noise term. Expanding $\hat{g}^2 = (\nabla L + \xi)^2 = (\nabla L)^2 + 2\,\nabla L\, \xi + \xi^2$, the cross term has zero mean, and the squared gradient noise term simplifies to a deterministic contribution due to Itô calculus, since $(dW_t)^2 = dt$:

$$dv_t = \gamma_2\left(\left(\nabla L(\theta_t)\right)^2 + \operatorname{diag}(\sigma\sigma^\top) - v_t\right) dt$$

where $\gamma_2$ is the decay rate for the second moment and the squares are element-wise. The noise term arising from the squared Wiener process has collapsed into the deterministic $\operatorname{diag}(\sigma\sigma^\top)$ contribution.
Position update SDE:

The position update depends on the momentum and variance estimates, with the learning rate $\eta$ setting the drift scale:

$$d\theta_t = -\eta\, \frac{m_t}{\sqrt{v_t} + \epsilon}\, dt.$$

The division here is element-wise, and so is the square root.

In this setup, the Wiener noise enters only through the momentum equation; given $m_t$ and $v_t$, the position update is deterministic.
We transform the system into state-space form, stacking the position, momentum, and variance into a state vector. We express the system as a matrix operation involving deterministic and stochastic parts.
We stack the position $\theta_t$, momentum $m_t$, and variance $v_t$ into a single state vector $x_t = (\theta_t, m_t, v_t)^\top$.
Based on the SDEs, we now write the discrete-time process updates using the Euler-Maruyama discretisation.
Position update:

$$\theta_{k+1} = \theta_k - \eta\, \frac{m_k}{\sqrt{v_k} + \epsilon}\, \Delta t$$

Momentum update (with Wiener noise):

$$m_{k+1} = m_k + \gamma_1\left(\nabla L(\theta_k) - m_k\right)\Delta t + \gamma_1\, \sigma\, \sqrt{\Delta t}\, z_k$$

Variance update (with simplified Wiener noise contribution):

$$v_{k+1} = v_k + \gamma_2\left(\left(\nabla L(\theta_k)\right)^2 + \operatorname{diag}(\sigma\sigma^\top) - v_k\right)\Delta t$$
We can now express the process update equations in state-space matrix update form. Define the system matrix

$$A_k = \begin{pmatrix} I & -\eta\, \Delta t\, \operatorname{diag}\!\big(1/(\sqrt{v_k} + \epsilon)\big) & 0 \\ 0 & (1 - \gamma_1 \Delta t)\, I & 0 \\ 0 & 0 & (1 - \gamma_2 \Delta t)\, I \end{pmatrix}.$$

Define the gradient-dependent vector

$$b(\theta_k) = \begin{pmatrix} 0 \\ \gamma_1\, \nabla L(\theta_k) \\ \gamma_2\left(\left(\nabla L(\theta_k)\right)^2 + \operatorname{diag}(\sigma\sigma^\top)\right) \end{pmatrix}.$$

Define the noise matrix

$$B = \begin{pmatrix} 0 \\ \gamma_1\, \sigma \\ 0 \end{pmatrix}.$$

The final state-space form for the Adam-type optimizer, including the Wiener noise, is

$$x_{k+1} = A_k x_k + b(\theta_k)\, \Delta t + B\, \sqrt{\Delta t}\, z_k$$

where $x_k = (\theta_k, m_k, v_k)^\top$ is the state vector, $A_k$ is the system matrix governing the deterministic part of the update, $b(\theta_k)$ is a gradient-dependent vector, and $B$ maps the stochastic noise into the system.
This is a little misleading to my eye. It looks like a linear update, and in fact like two uncoupled systems. But we inject the state-dependent gradient $\nabla L(\theta_k)$ through $b(\theta_k)$, and the system matrix itself depends on $v_k$, so the dynamics are genuinely nonlinear and coupled.
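To make the dynamics concrete, here is the whole coupled system integrated by Euler-Maruyama on the quadratic loss. The rates $\gamma_1, \gamma_2$ and the other scales below are assumed, illustrative values, not derived from any particular discrete Adam configuration:

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative rates and scales for the Adam-type SDE on L(theta) = 0.5 * theta^T A theta
a = np.array([1.0, 4.0])                                 # diagonal of A
eta, gamma1, gamma2, sigma, eps = 0.5, 10.0, 1.0, 0.3, 1e-8
dt, n_steps = 1e-3, 20_000

theta, m, v = np.array([2.0, -2.0]), np.zeros(2), np.zeros(2)
for _ in range(n_steps):
    grad = a * theta
    z = rng.standard_normal(2)
    # Euler-Maruyama for the coupled (theta, m, v) system, using last step's state
    theta = theta - eta * m / (np.sqrt(v) + eps) * dt
    m = m + gamma1 * (grad - m) * dt + gamma1 * sigma * np.sqrt(dt) * z
    v = v + gamma2 * (grad**2 + sigma**2 - v) * dt
print(theta)   # settles into a noisy neighbourhood of the optimum
```

Even with noise injected only into $m$, the position ends up fluctuating, because the momentum channel feeds the noise through to $\theta$, in the underdamped-Langevin style mentioned above.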
2.3 Nesterov momentum
See Q. Li, Tai, and Weinan (2019).
3 Stochastic DE for early stage training
We are not near any optimum, let alone a quadratic one. We do not know much about the geometry of the loss surface in this regime. TBD.
4 Near the optimum
Typically we imagine that the loss is locally well-approximated by a quadratic near the optimum, so the dynamics look like the Ornstein-Uhlenbeck process above.
The limiting diffusion is another term I have seen, which seems to describe diffusion around an optimum (Gu et al. 2022; Z. Li, Wang, and Arora 2021; Lyu, Li, and Arora 2023; Wang et al. 2023). I do not know if it is the same as the nearly-quadratic case. It seems to be regarded as useful for understanding generalisation. TBD.