Quasi-gradients of discrete parameters
2022-12-20 — 2025-07-25
Notes on taking gradients through things that look like they have no gradients in the standard sense because their arguments are discrete. There are a lot of loosely related concepts in here that may not reflect an actual common theme. TBC.
1 Let’s all worry about the Straight-through Estimator
When training neural networks with quantized weights or activations, we introduce a non-differentiable projection
\[ P:\mathbb R^r\to X \]
(e.g. \(P(x)=\mathrm{sign}(x)\) for binary networks). A common “trick” is to ignore the true Jacobian \(\partial P/\partial x\) and simply back-propagate gradients as if \(P\) were the identity:
\[ \tilde x_{k+1} = \tilde x_k - \eta\nabla f\bigl(P(\tilde x_k)\bigr), \quad x_{k+1}=P(\tilde x_{k+1}). \]
This is the Straight-Through Estimator (STE) (Bengio, Léonard, and Courville 2013): a crude but surprisingly effective heuristic for avoiding vanishing gradients through hard quantization. People use this a lot in practice, but then fret about it. There is a miniature industry devoted to fixing it.
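A minimal numpy sketch of the update above, on a deliberately easy toy problem (a linear loss over binary codes, whose discrete minimiser is \(\mathrm{sign}(c)\)); the loss, its gradient, and all constants here are stand-ins:

```python
import numpy as np

# Toy problem: minimise f(x) = -<c, x> over binary codes x in {-1, +1}^r;
# the discrete minimiser is sign(c). c and f are stand-ins for a real loss.
rng = np.random.default_rng(0)
r = 5
c = rng.normal(size=r)

grad_f = lambda x: -c          # gradient of f (constant here, for simplicity)
P = np.sign                    # the non-differentiable projection

x_tilde = rng.normal(size=r)   # latent "full-precision" weights
eta = 0.1
for _ in range(500):
    g = grad_f(P(x_tilde))           # forward pass sees the quantized weights...
    x_tilde = x_tilde - eta * g      # ...backward pass pretends P is the identity
print(P(x_tilde), np.sign(c))        # both equal sign(c)
```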
1.1 Mirror Descent to STE
Ajanthan et al. (2021) try to justify STE as a mirror descent trick. Mirror Descent, to recap, generalizes gradient descent to settings where the feasible set \(X\subset\mathbb R^r\) is not naturally Euclidean. It relies on:
A mirror map \(\Phi:C\to\mathbb R\), a strictly convex, differentiable function whose gradient \(\nabla\Phi\) maps the primal space \(X\) into the dual space \(\mathbb R^r\).
The associated Bregman divergence
\[ D_\Phi(p,q)=\Phi(p)-\Phi(q)-\langle\nabla\Phi(q),\,p-q\rangle. \]
Starting from \(x_0\in X\), each iteration does
\[ \nabla\Phi(y_{k+1}) \;=\;\nabla\Phi(x_k)-\eta\,g_k,\quad x_{k+1} = \arg\min_{x\in X}\;D_\Phi\bigl(x,y_{k+1}\bigr), \]
where \(g_k\in\partial f(x_k)\). Equivalently,
\[ x_{k+1} =\arg\min_{x\in X}\;\langle\eta\,g_k,\,x\rangle + D_\Phi(x,x_k). \]
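For concreteness, here is the textbook instance (not specific to the quantization story): the negative-entropy mirror map on the probability simplex, for which the update above reduces to multiplicative weights; the linear loss \(c\) is a placeholder:

```python
import numpy as np

# Mirror descent with the negative-entropy mirror map Phi(x) = sum_i x_i log x_i
# on the simplex. nabla Phi(x) = 1 + log(x); the Bregman projection back onto
# the simplex is plain normalisation, so the update is multiplicative weights.
def md_step(x, g, eta):
    y = x * np.exp(-eta * g)   # dual-space step: log y = log x - eta * g
    return y / y.sum()         # Bregman projection onto the simplex

# Minimise a linear loss f(x) = <c, x> over the simplex; the optimum puts
# all mass on argmin(c).
c = np.array([0.3, 0.1, 0.6])
x = np.ones(3) / 3
for _ in range(200):
    x = md_step(x, c, eta=0.5)
print(x.round(3))              # concentrates on index 1
```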
Ajanthan et al. show that any strictly monotone projection \(P\) yields a valid mirror map
\[ \Phi(x) \;=\;\int_{x_0}^x P^{-1}(y)\,dy, \]
because \(\nabla\Phi(x)=P^{-1}(x)\) and \(\Phi\) is strictly convex. With this choice:
MD in the dual space amounts to storing auxiliary variables \(\tilde x=P^{-1}(x)\).
The MD update \(\nabla\Phi(y_{k+1})=\nabla\Phi(x_k)-\eta g_k\) becomes simply
\[ \tilde x_{k+1} = \tilde x_k - \eta\,g_k, \]
and mapping back to the primal:
\[ x_{k+1}=P(\tilde x_{k+1}). \]
This is exactly the STE procedure.
In this view, STE is nothing other than a numerically stable implementation of MD under the mirror map induced by the projection \(P\). Mirror Descent supplies the missing theoretical foundation for why STE works so well in practice: it is simply gradient descent in a dual space tailored to the geometry of quantization.
This is kinda wild. It is not clear at all to me that mirror descent should have generalised to discrete spaces, which shows I am thinking about it wrong.
We will be violating a lot of the mirror descent assumptions in the NN setting (surely some kind of convexity violation). I wonder if you could recover mirror descent theory in the vicinity of an optimum or something? This dual space perspective looks handy.
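A quick numerical sanity check of the claimed equivalence, using \(P=\tanh\) as a smooth, strictly monotone stand-in for a hard quantizer; scipy solves the MD argmin directly and we compare it against the STE update on the latent variable:

```python
import numpy as np
from scipy.optimize import minimize_scalar

# For P = tanh we have P^{-1} = arctanh, so the induced mirror map is
# Phi(x) = int arctanh = x*arctanh(x) + 0.5*log(1 - x^2), nabla Phi = arctanh.
Phi = lambda x: x * np.arctanh(x) + 0.5 * np.log1p(-x**2)
dPhi = np.arctanh
bregman = lambda p, q: Phi(p) - Phi(q) - dPhi(q) * (p - q)

x_k, g, eta = 0.3, 0.8, 0.1    # current iterate, gradient, step size (placeholders)

# MD update, solved numerically: argmin_x <eta g, x> + D_Phi(x, x_k).
obj = lambda x: eta * g * x + bregman(x, x_k)
x_md = minimize_scalar(obj, bounds=(-0.999, 0.999), method="bounded").x

# STE update on the latent variable x_tilde = P^{-1}(x_k).
x_ste = np.tanh(np.arctanh(x_k) - eta * g)

print(x_md, x_ste)             # agree to solver tolerance
```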
1.2 STE as a Bayes procedure
Meng, Bachmann, and Khan (2020) is another attempt at justification; it devises a Bayesian update rule corresponding to the STE as used in NN quantization.
2 Stochastic gradients via REINFORCE
The classic generic REINFORCE/score-function method for estimating gradients of expectations of functions of random variables handles discrete random variables as a special case. In practice, extra tricks (baselines, control variates) are needed to keep the variance manageable in the discrete case; see e.g. (Grathwohl et al. 2018; Liu et al. 2019; Mnih and Gregor 2014; Tucker et al. 2017).
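A minimal sketch of the score-function estimator for a Bernoulli variable, where the exact gradient \(f(1)-f(0)\) is available for comparison; the function \(f\) and the baseline choice are stand-ins:

```python
import numpy as np

# Score-function (REINFORCE) estimate of d/d theta of E_{b~Bernoulli(theta)} f(b);
# the exact value is f(1) - f(0).
rng = np.random.default_rng(0)
f = lambda b: (b - 0.7) ** 2
theta = 0.4

b = (rng.random(100_000) < theta).astype(float)   # samples in {0., 1.}
score = b / theta - (1 - b) / (1 - theta)         # d/d theta of log p(b; theta)
baseline = f(b).mean()                            # crude baseline for variance
                                                  # reduction (same-sample
                                                  # estimate, slightly biased)
est = ((f(b) - baseline) * score).mean()
print(est, f(1) - f(0))                           # both close to -0.4
```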
3 Gumbel-(soft)max
A.k.a. the concrete distribution. See Gumbel-max.
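A minimal numpy sketch of both the exact Gumbel-max sample and its softmax relaxation; the logits and temperature are placeholders:

```python
import numpy as np

# Gumbel-max: argmax(logits + Gumbel noise) is an exact categorical sample.
# Gumbel-softmax / Concrete: replace the argmax with a temperature-tau softmax
# to get a differentiable relaxation living on the simplex.
rng = np.random.default_rng(0)
logits = np.log(np.array([0.1, 0.6, 0.3]))

def sample_gumbel(shape):
    u = rng.random(shape)
    return -np.log(-np.log(u))           # standard Gumbel noise

g = sample_gumbel(logits.shape)
hard = np.argmax(logits + g)             # exact discrete sample

tau = 0.5                                # temperature; -> one-hot as tau -> 0
z = np.exp((logits + g) / tau)
soft = z / z.sum()                       # relaxed, "almost one-hot" sample
print(hard, soft.round(3))
```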
4 Gradients of other weird things
Differentiable sorting? See, e.g. Grover et al. (2018) and Prillo and Eisenschlos (2020).
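A rough sketch of the SoftSort relaxation, as I read Prillo and Eisenschlos (2020); the details (descending-sort convention, absolute-difference kernel) are my reconstruction of the paper, so treat them as assumptions:

```python
import numpy as np

# SoftSort: a row-stochastic matrix that approaches the permutation matrix
# sorting s (in descending order) as the temperature tau goes to 0.
def softsort(s, tau=0.1):
    s_sorted = np.sort(s)[::-1]                        # descending
    pairwise = -np.abs(s_sorted[:, None] - s[None, :]) / tau
    e = np.exp(pairwise - pairwise.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)            # row-wise softmax

s = np.array([0.3, 1.2, -0.5])
P_hat = softsort(s, tau=0.05)
print(P_hat.round(2))    # ~ permutation matrix: rows pick out 1.2, 0.3, -0.5
print(P_hat @ s)         # ~ s sorted in descending order
```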
5 Avoiding the need for gradients
Famously, Expectation Maximization can handle some of the same optimisation problems as gradient-based methods, but without needing gradients. There are presumably more variants.
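For instance, a minimal EM loop for a two-component Gaussian mixture: the discrete component assignments are marginalized out in the E-step rather than differentiated through; all specifics here (unit variances, equal weights, means only) are illustrative simplifications:

```python
import numpy as np

# EM for a 1-d mixture of two unit-variance Gaussians with equal weights;
# only the means are unknown. No gradients with respect to the discrete
# assignments are ever needed.
rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(-2, 1, 500), rng.normal(3, 1, 500)])

mu = np.array([-1.0, 1.0])                       # initial guesses
for _ in range(50):
    # E-step: posterior responsibility of each component for each point
    log_lik = -0.5 * (x[:, None] - mu[None, :]) ** 2
    r = np.exp(log_lik - log_lik.max(axis=1, keepdims=True))
    r /= r.sum(axis=1, keepdims=True)
    # M-step: closed-form re-estimate of the means
    mu = (r * x[:, None]).sum(axis=0) / r.sum(axis=0)
print(mu.round(2))                               # ~ [-2, 3]
```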
6 Other methods
What even are Grathwohl et al. (2021) and Zhang, Liu, and Liu (2022)? I think they work for quantised continuous variables, or possibly ordinal variables?