Conditioning neural denoising diffusion models
Generative models that match the observations, not just the training data
2022-11-30 — 2025-07-23
Wherein conditioning of neural denoising diffusion models is surveyed, and a twisted SMC method is described that evaluates the observation likelihood at the denoiser’s Tweedie estimate of x0 to guide particles.
With neural diffusion models, we can generate samples from the unconditional distribution \(p\bigl(x_{\tau_0}\bigr)\approx p(x)\). To solve inverse problems, however, we need to sample from the posterior \(p(x_{\tau_0}\mid y)\).
There are lots of ways we might condition—sometimes differing only in emphasis.
Writing this was a great learning experience, but if you want to learn how to condition diffusions, you can speed-run by simply checking out the review article in Du et al. (2023).
1 Notation
First, let’s fix the notation. I’ll use a slight variant of the notation from the denoising diffusion SDE notebook. Because I need \(t\) for other things, we’ll use \(\tau\) for pseudo-time, writing \(\tau_0=0<\tau_1<\cdots<\tau_T=1\) for the discrete pseudo-time grid and \(x_{\tau_i}\) for the state at time \(\tau_i\).
For simplicity, we’ll assume a variance-preserving (VP) diffusion SDE. We corrupt data \(x_{\tau_0}=x_0\sim p_{\rm data}(x)\) by the variance-preserving SDE
\[ \mathrm{d}x_\tau = -\tfrac12\,\beta(\tau)\,x_\tau\,\mathrm{d}\tau + \sqrt{\beta(\tau)}\,\mathrm{d}W_\tau, \]
or, in discrete form, for each step \(\tau_i\to\tau_{i+1}\):
\[ x_{\tau_{i+1}} = \sqrt{1-\beta_{i+1}}\;x_{\tau_i} + \sqrt{\beta_{i+1}}\;\varepsilon, \quad \varepsilon\sim\mathcal N(0,I). \] Write \(\Delta\tau := \tau_i-\tau_{i-1}\) (uniform unless stated otherwise).
We also define the convenience terms \[ \bar\alpha(\tau)\;=\;\exp\!\Bigl(-\!\!\int_0^\tau\!\beta(s)\,ds\Bigr), \qquad \sigma(\tau)^2 \;=\;1-\bar\alpha(\tau), \qquad \beta_i := 1-\frac{\bar\alpha(\tau_i)}{\bar\alpha(\tau_{i-1})}. \]
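To keep things concrete, here is a minimal PyTorch sketch of these schedule quantities on a uniform \(\tau\) grid. The linear \(\beta(\tau)\) and the grid size are illustrative assumptions, not taken from any particular paper; later snippets in this notebook reuse these definitions.

```python
import torch

# A hypothetical linear beta(tau) schedule on a uniform grid of T steps;
# any positive schedule would do, this one is purely illustrative.
T = 1000
tau = torch.linspace(0.0, 1.0, T + 1)            # tau_0, ..., tau_T
beta_cont = 0.1 + 19.9 * tau                     # beta(tau)

# alpha_bar(tau) = exp(-int_0^tau beta(s) ds), via a cumulative trapezoid rule
integral = torch.cat([torch.zeros(1), torch.cumulative_trapezoid(beta_cont, tau)])
alpha_bar = torch.exp(-integral)                 # alpha_bar(tau_i)
sigma2 = 1.0 - alpha_bar                         # sigma(tau_i)^2

# per-step beta_i = 1 - alpha_bar(tau_i) / alpha_bar(tau_{i-1}); beta[i-1] stores beta_i
beta = 1.0 - alpha_bar[1:] / alpha_bar[:-1]

def corrupt(x0, i):
    """Sample x_{tau_i} | x_0 = x0 under the VP forward process."""
    eps = torch.randn_like(x0)
    return alpha_bar[i].sqrt() * x0 + sigma2[i].sqrt() * eps, eps
```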
1.1 Score Network & Training
We train \(s_\theta(x,\tau)\) to approximate the time-indexed score \(\nabla_{x_\tau}\log p_\tau(x_\tau)\) by minimizing a score-matching loss, ideally
\[ \mathcal L(\theta)=\mathbb E\bigl\|\nabla_{x_\tau}\log p_\tau(x_\tau)-s_\theta(x_\tau,\tau)\bigr\|^2, \]
where \(x_\tau = \sqrt{\bar\alpha(\tau)}\,x_0 + \sqrt{1-\bar\alpha(\tau)}\,\varepsilon\), \(\varepsilon\sim\mathcal N(0,I)\). Since the marginal score is unavailable, in practice we minimize the denoising variant, which replaces it with the conditional score \(\nabla_{x_\tau}\log p(x_\tau\mid x_0)=-\varepsilon/\sigma(\tau)\) and has the same minimizer.
Equivalently, we can parametrize the network to predict the noise. In a VP diffusion the posterior mean of the noise given the noisy point \(x_\tau\) satisfies \[ \mathbb E[\varepsilon\mid x_\tau] = -\sigma(\tau)\,\nabla_{x_\tau}\log p_\tau(x_\tau), \qquad\text{equivalently}\qquad \nabla_{x_\tau}\log p_\tau(x_\tau) = -\frac{\mathbb E[\varepsilon\mid x_\tau]}{\sigma(\tau)}. \] Hence predicting the noise (the “\(\varepsilon_\theta\)” parametrisation) or predicting the score (the “\(s_\theta\)” parametrisation) carries the same information, related by the known factor \(-1/\sigma(\tau)\): \(s_\theta(x,\tau)=-\varepsilon_\theta(x,\tau)/\sigma(\tau)\).
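Continuing the sketch above, the parametrisation swap and the practical noise-prediction loss might look as follows; `eps_net(x, tau)` is a hypothetical noise-prediction network, and sampling a single shared timestep per batch is purely a simplification.

```python
def score_from_eps(eps_hat, i):
    """Convert a noise prediction into a score estimate:
    s_theta(x, tau_i) = -eps_theta(x, tau_i) / sigma(tau_i)."""
    return -eps_hat / sigma2[i].sqrt()

def denoising_loss(eps_net, x0):
    """One Monte Carlo term of the noise-prediction loss, with a single
    random timestep shared across the batch for simplicity."""
    i = int(torch.randint(1, T + 1, ()).item())
    x_tau, eps = corrupt(x0, i)
    eps_hat = eps_net(x_tau, tau[i].expand(x0.shape[0]))
    return ((eps_hat - eps) ** 2).mean()
```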
1.2 Reverse-Time Sampling
To sample, we integrate the reverse SDE from \(\tau_T=1\) down to \(\tau_0=0\):
\[ \mathrm{d}x_\tau = \Bigl[-\tfrac12\,\beta(\tau)\,x_\tau \;-\;\beta(\tau)\,\nabla_{x_\tau}\log p_\tau(x_\tau)\Bigr]\mathrm{d}\tau + \sqrt{\beta(\tau)}\,\mathrm{d}\bar W_\tau, \] We use \(s_\theta(x,\tau)\approx\nabla\log p_\tau(x)\). (\(\bar W_\tau\) is the time reversal of an independent Wiener process \(W_\tau\).)
Alternatively, we can use the deterministic / probability-flow ODE:
\[ \mathrm{d}x_\tau = \Bigl[-\tfrac12\,\beta(\tau)\,x_\tau \;-\;\tfrac12\,\beta(\tau)\,\nabla_{x_\tau}\log p_\tau(x_\tau)\Bigr] \,\mathrm{d}\tau, \]
With an initial draw \(x_{\tau_T}\sim\mathcal N(0,I)\), we obtain the same marginals \(p_\tau\) without injecting extra noise at each step, which is insane and wasn’t obvious to me.
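As a sketch of the probability-flow sampler, reusing the schedule code above and assuming a hypothetical score network `score_net(x, tau)`:

```python
def probability_flow_sample(score_net, shape):
    """Euler-integrate the probability-flow ODE from tau_T = 1 down to tau_0 = 0.
    No noise is injected; randomness enters only through the initial draw."""
    x = torch.randn(shape)
    for i in range(T, 0, -1):
        d_tau = tau[i] - tau[i - 1]
        drift = -0.5 * beta_cont[i] * (x + score_net(x, tau[i]))
        x = x - d_tau * drift        # step backward in tau
    return x
```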
On our \(\tau\) grid, the discrete-time DDPM reverse update becomes, for \(i=T,\dots,1\):
\[ x_{\tau_{i-1}} = \frac{1}{\sqrt{1-\beta_{i}}}\Bigl( x_{\tau_i} + \beta_{i}\,s_\theta\bigl(x_{\tau_i},\tau_i\bigr) \Bigr) + \sqrt{\tilde\beta_{i}}\,\zeta, \quad \zeta\sim\mathcal N(0,I), \] where \(\tilde\beta_i:=\frac{1-\bar\alpha(\tau_{i-1})}{1-\bar\alpha(\tau_i)}\,\beta_i\) is the usual “posterior” variance.
The DDIM variant removes the noise \(\zeta\), giving a deterministic sampler (and hence an invertible map between \(x_{\tau_T}\) and \(x_0\)).
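And a matching ancestral (DDPM-style) sampler under the same assumptions; setting `deterministic=True` merely zeroes the injected noise, which is a crude stand-in for a proper DDIM update rather than a faithful one:

```python
def ddpm_sample(score_net, shape, deterministic=False):
    """DDPM reverse updates on the tau grid (recall beta[i-1] stores beta_i)."""
    x = torch.randn(shape)
    for i in range(T, 0, -1):
        b_i = beta[i - 1]
        mean = (x + b_i * score_net(x, tau[i])) / torch.sqrt(1.0 - b_i)
        if deterministic or i == 1:
            x = mean
        else:
            beta_tilde = (1 - alpha_bar[i - 1]) / (1 - alpha_bar[i]) * b_i
            x = mean + beta_tilde.sqrt() * torch.randn_like(x)
    return x
```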
2 Generic conditioning
Here’s a quick rewrite of Rozet and Louppe (2023b). Note I’ve updated the notation to match the rest of this notebook.
We could train a conditional score network \(s_\phi\bigl(x_{\tau_i},\tau_i\mid y\bigr)\) to approximate the posterior score \(\nabla_{x_{\tau_i}}\log p\bigl(x_{\tau_i}\mid y\bigr)\) and plug it into the reverse SDE. But this requires \((x,y)\) pairs during training and re‐training whenever the observation model \(p(y\mid x)\) changes.
Instead, many have observed (Song, Sohl-Dickstein, et al. 2022; Adam et al. 2022; Chung et al. 2023; Kawar, Vaksman, and Elad 2021; Song, Shen, et al. 2022) that by Bayes’ rule the posterior score decomposes as \[ \nabla_{x_{\tau_i}}\log p\bigl(x_{\tau_i}\mid y\bigr) = \nabla_{x_{\tau_i}}\log p\bigl(x_{\tau_i}\bigr) + \nabla_{x_{\tau_i}}\log p\bigl(y\mid x_{\tau_i}\bigr)\,. \] Since the prior score \(\nabla_{x_{\tau_i}}\log p(x_{\tau_i})\) is well-approximated by the unconditional score network \(s_\theta(x_{\tau_i},\tau_i)\), the remaining task is to estimate the likelihood score \(\nabla_{x_{\tau_i}}\log p(y\mid x_{\tau_i})\).
Assuming a differentiable measurement operator \(\mathcal A\) and Gaussian observations \(p(y\mid x)=\mathcal N\bigl(y\mid\mathcal A(x),\Sigma_y\bigr)\), Chung et al. (2023) propose approximating \[ p\bigl(y\mid x_{\tau_i}\bigr) = \int p(y\mid x)\,p(x\mid x_{\tau_i})\,\mathrm{d}x \;\approx\; \mathcal N\bigl(y\mid\mathcal A(\hat x(x_{\tau_i})),\,\Sigma_y\bigr), \] where the denoised mean \(\hat x(x_{\tau_i})=\mathbb{E}\bigl[\,x\mid x_{\tau_i}\bigr]\) is given by Tweedie’s formula (Efron 2011; Kim and Ye 2021): \[ \mathbb E[x_0\mid x_\tau] = \frac{x_\tau + \bigl(1-\bar\alpha(\tau)\bigr)\,s_\theta(x_\tau,\tau)}{\sqrt{\bar\alpha(\tau)}} = \frac{x_\tau - \sqrt{1-\bar\alpha(\tau)}\,\varepsilon_\theta(x_\tau,\tau)}{\sqrt{\bar\alpha(\tau)}}. \] Because the log-likelihood of a multivariate Gaussian is analytic and \(s_\theta(x_{\tau_i},\tau_i)\) is differentiable, we can compute \(\nabla_{x_{\tau_i}}\log p\bigl(y\mid x_{\tau_i}\bigr)\) in a zero-shot fashion, without training any network beyond the unconditional score model \(s_\theta\).
Note that this last approximation (replacing \(p(x\mid x_{\tau_i})\) by a point mass at its Tweedie mean) is strong; it’s probably too strong for the models I would bother using diffusions for. Don’t worry, we can get fancier and more effective.
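Here is one way the guidance term could look in code, continuing the sketch above: a hedged reading of the recipe rather than a reference implementation of Chung et al. (2023). It assumes a differentiable measurement operator `A`, scalar observation noise `sigma_y` (i.e. \(\Sigma_y=\sigma_y^2 I\)), and the same hypothetical `score_net` as before; autograd differentiates through both the Tweedie estimate and the score network.

```python
def guided_score(score_net, x, i, y, A, sigma_y):
    """Approximate posterior score at tau_i:
    prior score + grad of the Gaussian log-likelihood at the Tweedie estimate."""
    x = x.detach().requires_grad_(True)
    s = score_net(x, tau[i])
    x0_hat = (x + (1 - alpha_bar[i]) * s) / alpha_bar[i].sqrt()   # Tweedie estimate
    log_lik = -0.5 * ((y - A(x0_hat)) ** 2).sum() / sigma_y**2
    (grad_ll,) = torch.autograd.grad(log_lik, x)
    return (s + grad_ll).detach()
```

Substituting `guided_score` for the unconditional score inside the reverse samplers above (passing the step index rather than the raw \(\tau\) value) gives the zero-shot conditional sampler.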
3 Ensemble Score Conditioning
A simple trick that sometimes works (F. Bao, Zhang, and Zhang 2024a), but it’s biased. TBC.
4 Sequential Monte Carlo
This seems to be SOTA?
LLM-aided summary of Wu et al. (2024):
We recall standard SMC / Particle Filtering:
Goal: sample from a sequence of distributions \(\{\nu_i\}_{i=0}^T\), ending at some target \(\nu_0\).
Particles: maintain \(K\) samples (particles) \(\{x^k_{i}\}_{k=1}^K\) with weights \(\{w^k_i\}\).
Iterate for \(i=T\to0\):
Resample particles according to \(w^k_{i+1}\) to focus on high-probability regions.
Propose \(x^k_i \sim r_i(x_i \mid x^k_{i+1})\).
Weight each by
\[ w^k_i \;=\; \underbrace{\frac{\text{target density at }x^k_i}{\text{proposal density at }x^k_i}}_{\text{importance weight}} \;. \]
Convergence: as \(K\to\infty\), the weighted particle ensemble converges to the true target.
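A bare-bones skeleton of that recipe, with the problem-specific pieces (`init`, `propose`, `log_weight`) left as placeholder callables:

```python
def smc(init, propose, log_weight, K, n_steps):
    """Bare-bones SMC with multinomial resampling at every step.
    init(K) -> particles; propose(x, i) -> new particles;
    log_weight(x_new, x_old, i) -> per-particle log incremental weights."""
    x = init(K)
    logw = torch.zeros(K)
    for i in range(n_steps, 0, -1):
        idx = torch.multinomial(torch.softmax(logw, 0), K, replacement=True)
        x = x[idx]                      # resample
        x_new = propose(x, i)           # propose
        logw = log_weight(x_new, x, i)  # reweight
        x = x_new
    return x, torch.softmax(logw, 0)
```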
In a diffusion model, the reverse (denoising) chain
\[ p_\theta(x_{\tau_0:\tau_T}) = p(x_{\tau_T})\,\prod_{i=1}^T p_\theta(x_{\tau_{i-1}}\mid x_{\tau_i}) \]
is exactly such a sequential model over \(x_{\tau_T}\to x_{\tau_0}\), with \(\nu_i\) the marginal law of \(x_{\tau_i}\) under this chain.
To sample from the conditional \(p_\theta(x_0\mid y)\), we treat the conditioning as part of the final target and apply SMC. However, suppose we naïvely run SMC with the unconditional transition kernels as proposals,
\[ r_i\bigl(x_{\tau_{i-1}}\mid x_{\tau_i}\bigr) = p_\theta\bigl(x_{\tau_{i-1}}\mid x_{\tau_i}\bigr), \]
and only tack on a final weight \(w_0\propto p(y\mid x_0)\). Then we need an astronomical number of particles, since most particles get near-zero weight whenever the prior \(p_\theta(x_0)\) places little mass where the conditional \(p_\theta(x_0\mid y)\) concentrates.
Twisting is a classic SMC technique that addresses this by introducing a sequence of auxiliary functions \(\{\tilde p_\theta(y\mid x_{\tau_i})\}_{i=0}^T\) to re-weight proposals at every time step, not just at the end. The ideal choice at step \(i\) would be
\[ r_i^\ast(x_{\tau_{\,i-1}}\!\mid\!x_{\tau_i}) \;\propto\; p_\theta\bigl(x_{\tau_{\,i-1}}\mid x_{\tau_i}\bigr)\; p_\theta\bigl(y\mid x_{\tau_{\,i-1}}\bigr), \]
which, if we could sample from it, would make the incremental weights constant, so that even a single particle would be an exact draw from the target. However, \(p_\theta(y\mid x_{\tau_{\,i-1}})\) is itself intractable.
TDS (the twisted diffusion sampler of Wu et al. 2024) replaces the optimal twisting \(p_\theta(y\mid x_{\tau_i})\) with a tractable surrogate built from the denoiser’s estimate \(\hat x_0(x_{\tau_i})\) of the clean \(x_0\) from the noisy \(x_{\tau_i}\):
\[ \tilde p_\theta(y\mid x_{\tau_i}) \;=\; p\bigl(y \mid \hat x_0(x_{\tau_i})\bigr), \]
i.e. we evaluate the observation-likelihood at the diffusion denoiser’s one-step posterior mean estimate \(\hat x_0\). Since \(\hat x_0(x_{\tau})\approx \mathbb E[x_0\!\mid\!x_{\tau}]\), this becomes increasingly accurate as \(\tau\to0\). Define \(\sigma_i^2:=\sigma(\tau_i)^2=1-\bar\alpha(\tau_i).\)
Twisted proposal from \(\tau_i\to\tau_{i-1}\):
\[ \tilde r_i\bigl(x_{\tau_{\,i-1}}\!\mid\!x_{\tau_i},y\bigr) = \mathcal N\Bigl(x_{\tau_{\,i-1}};\, \underbrace{x_{\tau_i} \;+\;\underbrace{\sigma_i^2\,s_i(x_{\tau_i},y)}_{\text{“guided” drift}}}_{\text{mean}},\; \sigma_i^2\,I\Bigr), \]
where
\[ s_i(x_{\tau_i},y) = s_\theta(x_{\tau_i},\tau_i) + \nabla_{x_{\tau_i}}\log\tilde p_\theta\bigl(y\mid x_{\tau_i}\bigr). \]
Twisted weight for each particle:
\[ w_{\tau_{\,i-1}} \;=\; \frac{ p_\theta\bigl(x_{\tau_{\,i-1}}\!\mid\!x_{\tau_i}\bigr)\; \tilde p_\theta\bigl(y\mid x_{\tau_{\,i-1}}\bigr) } { \tilde p_\theta\bigl(y\mid x_{\tau_i}\bigr)\; \tilde r_i\bigl(x_{\tau_{\,i-1}}\!\mid\!x_{\tau_i},y\bigr) }. \]
This weight corrects for the surrogate twisting and ensures asymptotic exactness as \(K\to\infty\).
In early steps (\(i\approx T\)), the surrogate \(\tilde p_\theta(y\mid x_{\tau_i})\) may be very broad — twisting is mild. Then, in late steps (\(i\to0\)), \(\hat x_0(x_{\tau_i})\) is accurate, so \(\tilde p_\theta(y\mid x_{\tau_i})\approx p_\theta(y\mid x_{\tau_i})\) and the proposals are nearly optimal. Resampling in between keeps the particle cloud focused on regions consistent with both the diffusion prior and the conditioning \(y\).
In practice, a surprisingly small number of particles suffices; even 2–8 often outperform heuristic conditional samplers (like plain classifier guidance or “replacement” inpainting).
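Putting the pieces together, here is a sketch of one twisted proposal-and-reweight step, reusing the schedule code and the Tweedie-style surrogate from earlier. It follows the simplified reverse kernel written above (mean \(x+\sigma_i^2 s\), variance \(\sigma_i^2 I\)) rather than the exact parametrisation in Wu et al. (2024), assumes particles stacked as a `(K, D)` tensor, and uses the same hypothetical `score_net`, `A`, and `sigma_y` as before.

```python
from torch.distributions import Normal

def log_twist(score_net, x, i, y, A, sigma_y):
    """log p~(y | x_{tau_i}): Gaussian likelihood at the Tweedie estimate.
    Returns the per-particle log twist and the score (autograd graph retained)."""
    s = score_net(x, tau[i])
    x0_hat = (x + (1 - alpha_bar[i]) * s) / alpha_bar[i].sqrt()
    log_pt = -0.5 * ((y - A(x0_hat)) ** 2).reshape(x.shape[0], -1).sum(-1) / sigma_y**2
    return log_pt, s

def tds_step(score_net, x, i, y, A, sigma_y):
    """One twisted proposal tau_i -> tau_{i-1} and its log incremental weight."""
    sig2 = 1.0 - alpha_bar[i]
    x = x.detach().requires_grad_(True)
    log_pt, s = log_twist(score_net, x, i, y, A, sigma_y)
    (grad,) = torch.autograd.grad(log_pt.sum(), x)
    mean_twisted = x + sig2 * (s + grad)               # guided drift
    mean_uncond = x + sig2 * s                         # unconditional kernel mean
    x_new = (mean_twisted + sig2.sqrt() * torch.randn_like(x)).detach()
    log_p = Normal(mean_uncond, sig2.sqrt()).log_prob(x_new).reshape(x.shape[0], -1).sum(-1)
    log_r = Normal(mean_twisted, sig2.sqrt()).log_prob(x_new).reshape(x.shape[0], -1).sum(-1)
    log_pt_new, _ = log_twist(score_net, x_new, i - 1, y, A, sigma_y)
    return x_new, (log_p + log_pt_new - log_pt - log_r).detach()
```

This plays the role of the propose-and-reweight step in the generic SMC skeleton above, fused into one function so that the twist need not be recomputed.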
5 (Conditional) Schrödinger Bridge
Shi et al. (2022) introduced the Conditional Schrödinger Bridge (CSB), which is a natural extension of the Schrödinger Bridge.
We seek a path-measure \(\pi^*(x_{0:T}\mid y)\) minimizing
\[ \mathrm{KL}\bigl(\pi(\,\cdot\mid y)\,\|\;p(x_{0:T})\bigr) \]
subject to:
- Start at \(\tau_T\): \(\pi_{\tau_T}(x_T\mid y)=\mathcal N(x_T;0,I)\).
- End at \(\tau_0\): \(\pi_{\tau_0}(x_0\mid y)=p(x_0\mid y)\).
Here, \(p(x_{0:T}) = p(x_T)\prod_i p(x_{\tau_{i-1}}\mid x_{\tau_i})\) is the unconditional forward noising chain.
5.1 Amortized IPF Algorithm
We parameterize two families of drift networks that take \(y\) as input:
\[ B^n_i(x, y) \quad\text{and}\quad F^n_i(x, y) \quad\text{for }i=1,\dots,T,\;n=0,\dots,L. \]
We alternate two KL-projection steps:
Backward half-step (\(\tau_T\to\tau_0\), enforce prior):
\[ \pi^{2n+1}(\cdot\mid y) = \arg\min_{\pi}\; \mathrm{KL}\bigl(\pi(\cdot\mid y)\,\|\;\pi^{2n}(\cdot\mid y)\bigr) \quad\text{s.t.}\quad \pi_{\tau_T}(x_T\mid y)=\mathcal N(0,I). \]
Fit \(B^{n+1}_i(x,y)\) by matching the backward SDE induced by \(\pi^{2n+1}\).
Forward half-step (\(\tau_0\to\tau_T\), enforce posterior):
\[ \pi^{2n+2}(\cdot\mid y) = \arg\min_{\pi}\; \mathrm{KL}\bigl(\pi(\cdot\mid y)\,\|\;\pi^{2n+1}(\cdot\mid y)\bigr) \quad\text{s.t.}\quad \pi_{\tau_0}(x_0\mid y)=p(x_0\mid y). \]
Fit \(F^{n+1}_i(x,y)\) by matching the forward SDE induced by \(\pi^{2n+2}\).
After \(L\) IPF iterations, we obtain networks \(B^L_i(x,y),\,F^L_i(x,y)\) whose composed bridge \(\pi^*(\cdot\mid y)\) exactly matches both endpoints for any \(y\).
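The outer alternation is simple enough to sketch, though deliberately thin: the `fit_backward` and `fit_forward` callables stand in for the drift-matching regressions (e.g. DSB-style mean matching), whose details I am not reproducing here; amortization over \(y\) happens simply because \(y\) is an input to every drift net.

```python
def amortized_ipf(fit_backward, fit_forward, F_init, L_iters):
    """Alternate the two KL half-projections L_iters times.
    fit_backward(F) returns B^{n+1} trained on trajectories simulated with F;
    fit_forward(B) returns F^{n+1} trained on trajectories simulated with B."""
    F, B = F_init, None
    for _ in range(L_iters):
        B = fit_backward(F)   # half-step 2n+1: pin N(0, I) at tau_T
        F = fit_forward(B)    # half-step 2n+2: pin p(x_0 | y) at tau_0
    return B, F
```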
We can think of the backward step as pinning the Gaussian prior, and the forward step as pinning the conditional \(p(x_0\mid y)\). We now define the composed conditional bridge as the concatenation of the two halves mentioned above:
Initial draw at \(\tau_T\):
\[ x_{\tau_T}\sim p(x_{\tau_T})=\mathcal N(0,I). \]
Backward transitions for \(i=T,T-1,\dots,1\):
\[ x_{\tau_{i-1}} \;\sim\; \underbrace{\mathcal N\bigl(x_{\tau_i} + \Delta\tau\,B^L_i(x_{\tau_i},y),\;\Delta\tau\,I\bigr)}_{\pi^*_{\text{back},i}\!(x_{\tau_{i-1}}\mid x_{\tau_i},y)}. \]
Forward transitions for \(i=1,2,\dots,T\):
\[ x_{\tau_i} \;\sim\; \underbrace{\mathcal N\bigl(x_{\tau_{i-1}} + \Delta\tau\,F^L_i(x_{\tau_{i-1}},y),\;\Delta\tau\,I\bigr)}_{\pi^*_{\text{for},i}\!(x_{\tau_i}\mid x_{\tau_{i-1}},y)}. \]
Hence the full path-measure is
\[ \boxed{ \pi^*(x_{0:T}\mid y) = p(x_{\tau_T})\;\prod_{i=T}^1 \pi^*_{\text{back},i}(x_{\tau_{i-1}}\mid x_{\tau_i},y) \;\times\!\!\! \prod_{i=1}^T \pi^*_{\text{for},i}(x_{\tau_i}\mid x_{\tau_{i-1}},y)\,, } \]
which, by construction, satisfies both endpoint constraints for any \(y\). CSB finds the “smoothest” (in the sense of being entropy-regularized) stochastic flow between noise and the data posterior. This seems intuitively fine, but my reasoning here is vibes-based. I need to read the paper more carefully to get a proper grip on it.
5.2 Sampling with the Learned Conditional Bridge
To draw \(x_0\sim p(x_0\mid y)\):
Initialize \(x_{\tau_T}\sim\mathcal N(0,I)\).
Integrate the learned SDE (or its probability-flow ODE) backward from \(\tau_T=1\) to \(\tau_0=0\):
\[ \mathrm{d}x_\tau = B^L_i\bigl(x_\tau,y\bigr)\,\mathrm{d}\tau + \mathrm{d}W_\tau \quad \text{for }\tau\in[\tau_{i-1},\tau_i]. \]
(Optionally) Forward integrate with \(F^L_i(x,y)\) to refine or compute likelihoods.
Because \(B^L_i\) and \(F^L_i\) both depend on \(y\), the same trained model applies to arbitrary observations.
Amortized CSB encodes the observation \(y\) directly into every drift net \(B_i(x,y)\) and \(F_i(x,y)\). There’s no per-instance retraining or importance weighting — once the joint IPF training over \((x_0,y)\) is done, we can plug in any new \(y\) and run the sampler.
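A sketch of that sampler, assuming the learned backward drifts are available as an indexable collection `B_nets` with `B_nets[i](x, y)` playing the role of \(B^L_i\), and the uniform step \(\Delta\tau=1/T\) from the schedule code above:

```python
def csb_sample(B_nets, y, shape):
    """Euler-Maruyama integration of the learned backward bridge:
    x_{tau_{i-1}} ~ N(x_{tau_i} + dtau * B_i(x_{tau_i}, y), dtau * I)."""
    d_tau = 1.0 / T
    x = torch.randn(shape)
    for i in range(T, 0, -1):
        x = x + d_tau * B_nets[i](x, y) + (d_tau ** 0.5) * torch.randn_like(x)
    return x
```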
6 Computational Trade-offs of Those Last Two
| Aspect | Twisted SMC (TDS) | Amortized CSB |
|---|---|---|
| Training | Train only the denoiser \(\hat x_0\) | Train \(2\times L\) drift nets on \((x,y)\) pairs |
| Inference cost | \(K\) particles × \(T\) steps | Single trajectory over \(T\) steps |
| Exactness | Exact as \(K\to\infty\) | Exact in the limit of perfectly converged IPF |
7 Consistency models
Song et al. (2023)
8 Inpainting
If we want coherence with part of an existing image, we call that inpainting, and there are specialized methods for it (Ajay et al. 2023; Grechka, Couairon, and Cord 2024; Liu, Niepert, and Broeck 2023; Lugmayr et al. 2022; Sharrock et al. 2022; Wu et al. 2024; Zhang et al. 2023).
9 Reconstruction/inversion
Perturbed and partial observations; misc methods for them (Choi et al. 2021; Kawar et al. 2022; Nair, Mei, and Patel 2023; Peng et al. 2024; Xie and Li 2022; Zhao et al. 2023; Song, Shen, et al. 2022; Zamir et al. 2021; Chung et al. 2023; Sui et al. 2024).