Quasi-gradients of discrete parameters
2022-12-20 — 2025-07-25
Wherein the justification of straight-through estimators is presented via mirror descent, and quantized neural-network updates are shown to be performed in a dual space induced by the projection
Notes on taking gradients through functions that seem to have no gradients in the standard sense because their arguments are discrete. We cover lots of loosely related concepts here that might not form a single coherent theme. TBC.
1 Let’s all worry about the Straight-through Estimator
When we train neural networks with quantized weights or activations, we introduce a non-differentiable projection
\[ P:\mathbb R^r\to X \]
(e.g. \(P(x)=\mathrm{sign}(x)\) for binary networks). A common trick is to ignore the true Jacobian \(\partial P/\partial x\) and back-propagate gradients as if \(P\) were the identity:
\[ \tilde x_{k+1} = \tilde x_k - \eta\nabla f\bigl(P(\tilde x_k)\bigr), \quad x_{k+1}=P(\tilde x_{k+1}). \]
This is the Straight-Through Estimator (STE) (Bengio, Léonard, and Courville 2013) — a crude but surprisingly effective heuristic to avoid vanishing gradients through hard quantization. People use this a lot in practice, but then we fret about it. There’s a miniature industry trying to fix it.
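To make the recipe concrete, here is a minimal numpy sketch of STE training on a toy problem. The quadratic loss, dimension, and learning rate are placeholders of my own, not taken from any of the papers:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy problem: recover a target sign pattern with binary weights.
# f(w) = 0.5 * ||w - target||^2 is a placeholder for a real network loss.
r = 8
target = np.sign(rng.standard_normal(r))

def grad_f(w):
    return w - target            # gradient of the placeholder loss

P = np.sign                      # hard quantizer
x_tilde = 0.1 * rng.standard_normal(r)  # latent real-valued weights
eta = 0.1

for _ in range(100):
    g = grad_f(P(x_tilde))       # gradient evaluated at the quantized weights
    x_tilde = x_tilde - eta * g  # ...applied to the latent weights, as if
                                 # dP/dx were the identity (the STE step)

print((P(x_tilde) == target).all())  # True on this toy problem
```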
1.1 Mirror Descent to STE
Ajanthan et al. (2021) tries to justify STE as a mirror descent trick. Mirror Descent, to recap, generalizes gradient descent to settings where the feasible set \(X\subset\mathbb R^r\) is not naturally Euclidean. It relies on:

- a mirror map \(\Phi:C\to\mathbb R\), a strictly convex, differentiable function whose gradient \(\nabla\Phi\) maps the primal space \(X\) into the dual space \(\mathbb R^r\);
- the associated Bregman divergence
  \[ D_\Phi(p,q)=\Phi(p)-\Phi(q)-\langle\nabla\Phi(q),\,p-q\rangle. \]

Starting from \(x_0\in X\), each iteration does
\[ \nabla\Phi(y_{k+1}) \;=\;\nabla\Phi(x_k)-\eta\,g_k,\quad x_{k+1} = \arg\min_{x\in X}\;D_\Phi\bigl(x,y_{k+1}\bigr), \]
where \(g_k\in\partial f(x_k)\). Equivalently,
\[ x_{k+1} =\arg\min_{x\in X}\;\langle\eta\,g_k,\,x\rangle + D_\Phi(x,x_k). \]
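For intuition, here is the classic concrete instance: the negative-entropy mirror map on the simplex, whose dual-space update is the exponentiated-gradient rule. The quadratic loss is a placeholder of mine:

```python
import numpy as np

# Negative-entropy mirror map on the simplex: Phi(x) = sum_i x_i log x_i.
# The dual step nabla Phi(y) = nabla Phi(x) - eta * g works out to a
# multiplicative (exponentiated-gradient) update, and the Bregman
# projection back onto the simplex is just normalization.
target = np.array([0.7, 0.2, 0.1])

def grad(x):
    return x - target            # gradient of 0.5 * ||x - target||^2

x = np.ones(3) / 3               # start at the simplex barycentre
eta = 0.5
for _ in range(200):
    y = x * np.exp(-eta * grad(x))   # step in the dual, map back
    x = y / y.sum()                  # Bregman projection = normalize

print(x.round(3))                # ~ [0.7, 0.2, 0.1]
```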
Ajanthan et al. show that any strictly monotone projection \(P\) yields a valid mirror map
\[ \Phi(x) \;=\;\int_{x_0}^x P^{-1}(y)\,dy, \]
because \(\nabla\Phi(x)=P^{-1}(x)\) is strictly increasing, so \(\Phi\) is strictly convex. With this choice:
- MD in the dual space amounts to storing auxiliary variables \(\tilde x=P^{-1}(x)\).
- The MD update \(\nabla\Phi(y_{k+1})=\nabla\Phi(x_k)-\eta g_k\) becomes simply
  \[ \tilde x_{k+1} = \tilde x_k - \eta\,g_k, \]
  and mapping back to the primal,
  \[ x_{k+1}=P(\tilde x_{k+1}). \]
This is exactly the STE procedure.
In this view, STE is a numerically stable implementation of MD under the mirror map induced by the projection \(P\). Mirror Descent thereby offers a theoretical account of why STE works as well as it does: it is gradient descent in a dual space tailored to the geometry of the quantization.
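We can check the correspondence numerically with a strictly monotone soft quantizer, \(P=\tanh\), so that \(P^{-1}=\operatorname{arctanh}\). The quadratic loss is again a placeholder; the point is only that the two update rules trace identical iterates:

```python
import numpy as np

# P = tanh is strictly monotone, so P^{-1} = arctanh and nabla Phi = arctanh.
# Run the STE update and the MD update side by side and check they coincide.
rng = np.random.default_rng(1)
target = np.tanh(rng.standard_normal(4))

def grad(x):
    return x - target                    # gradient of 0.5 * ||x - target||^2

eta = 0.05
x_tilde = 0.5 * rng.standard_normal(4)   # STE: latent weights
x = np.tanh(x_tilde)                     # MD: primal iterate, matched start

for _ in range(50):
    # STE: gradient at the quantized point, applied to the latents
    x_tilde = x_tilde - eta * grad(np.tanh(x_tilde))
    # MD: step in the dual (arctanh) coordinates, map back through P
    x = np.tanh(np.arctanh(x) - eta * grad(x))

print(np.allclose(np.tanh(x_tilde), x))  # True: identical trajectories
```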
This is kind of wild. It’s not at all clear to me that mirror descent should have generalized to discrete spaces; maybe I’m thinking about it wrong. One wrinkle: a hard quantizer like \(\mathrm{sign}\) is not strictly monotone and has no inverse, so the construction strictly applies to soft, invertible surrogates (e.g. a \(\tanh\)-type projection), with the hard quantizer as a limiting case.

We’ll be violating many of the other mirror descent assumptions in the NN setting too (the loss is certainly not convex). I wonder whether we could recover mirror descent theory in the vicinity of an optimum or something? This dual space perspective looks handy.
1.2 STE as a Bayes procedure
Meng, Bachmann, and Khan (2020) is another attempt to justify the STE; it derives a Bayesian update rule for quantized networks that corresponds to the STE as used in NN quantization.
2 Stochastic gradients via REINFORCE
The classic generic REINFORCE/score function method for estimating gradients of expectations can handle functions of discrete random variables as a special case. In practice, extra tricks are needed to keep the variance of the estimator under control for discrete variables; see, e.g., (Grathwohl et al. 2018; Liu et al. 2019; Mnih and Gregor 2014; Tucker et al. 2017).
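A minimal sketch of the score-function estimator for a single Bernoulli variable, where \(f\) is a black box we never differentiate (everything here is a toy placeholder of mine):

```python
import numpy as np

rng = np.random.default_rng(0)

def f(b):
    return (b - 0.3) ** 2        # placeholder black-box objective

theta = 0.4
p = 1.0 / (1.0 + np.exp(-theta))          # p = sigmoid(theta)

b = rng.binomial(1, p, size=200_000).astype(float)
score = b - p                    # d/dtheta log Bernoulli(b; sigmoid(theta))
grad_est = np.mean(f(b) * score)
# Subtracting a baseline c from f(b) leaves this unbiased (E[score] = 0)
# but can slash the variance; that is where many practical tricks live.

grad_exact = p * (1 - p) * (f(1.0) - f(0.0))   # closed form for this toy
print(grad_est, grad_exact)      # close, up to Monte Carlo noise
```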
3 Gumbel-(soft)max
A.k.a. the concrete distribution. See Gumbel-max.
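A minimal sketch of a Gumbel-softmax sample in numpy (the logits and temperature \(\tau\) are placeholders); as \(\tau\to 0\) it recovers a hard Gumbel-max sample:

```python
import numpy as np

rng = np.random.default_rng(0)

def gumbel_softmax(logits, tau=0.5):
    # Perturb logits with Gumbel(0, 1) noise and take a tempered softmax.
    # tau -> 0 approaches a one-hot argmax (Gumbel-max) sample; tau > 0
    # keeps the sample differentiable with respect to the logits.
    g = -np.log(-np.log(rng.uniform(size=logits.shape)))
    z = (logits + g) / tau
    z = z - z.max()              # numerical stabilization
    e = np.exp(z)
    return e / e.sum()

print(gumbel_softmax(np.log(np.array([0.5, 0.3, 0.2])), tau=0.1))
```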
4 Gradients of other weird things
Differentiable sorting? See, e.g., Grover et al. (2018) and Prillo and Eisenschlos (2020).
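As I read Grover et al. (2018), their NeuralSort construction relaxes the permutation matrix that sorts a score vector \(s\) into a temperature-controlled row-stochastic matrix. A sketch of my understanding:

```python
import numpy as np

def neuralsort(s, tau=1.0):
    # Relaxed (row-stochastic) permutation matrix whose rows softly select
    # the 1st, 2nd, ... largest entries of the score vector s.
    n = len(s)
    A = np.abs(s[:, None] - s[None, :])        # pairwise |s_j - s_k|
    i = np.arange(1, n + 1)[:, None]           # row index, 1-based
    logits = ((n + 1 - 2 * i) * s[None, :] - A.sum(axis=1)[None, :]) / tau
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)    # row-wise softmax

s = np.array([3.0, 1.0, 2.0])
P = neuralsort(s, tau=0.1)
print(P.round(2))    # ~ hard permutation matrix at low temperature
print(P @ s)         # ~ s sorted in descending order
```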
5 Avoiding the need for gradients
Famously, Expectation Maximization can handle some of the same optimization problems as gradient-based methods, but without gradients: the discrete latent variable is averaged over in the E-step rather than differentiated through. There are presumably more gradient-free variants in this vein.
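For instance, here is a minimal EM sketch for the means of a two-component 1-D Gaussian mixture (equal weights and unit variances assumed, for brevity); the discrete assignment variable is handled by expectation, not differentiation:

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(-2, 1, 500), rng.normal(3, 1, 500)])

mu = np.array([-1.0, 1.0])    # initial means
for _ in range(50):
    # E-step: posterior responsibility of each component for each point
    # (equal mixture weights and unit variances, so only means matter)
    logp = -0.5 * (x[:, None] - mu[None, :]) ** 2
    r = np.exp(logp - logp.max(axis=1, keepdims=True))
    r = r / r.sum(axis=1, keepdims=True)
    # M-step: closed-form re-estimate of the means; no gradients anywhere
    mu = (r * x[:, None]).sum(axis=0) / r.sum(axis=0)

print(mu.round(2))            # ~ [-2., 3.]
```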
6 Other methods
What even are (Grathwohl et al. 2021; Zhang, Liu, and Liu 2022)? I think they work for quantized continuous variables, or possibly ordinal variables?