Conditioning neural denoising diffusion models
Generative modes that match the observations, not the training data
2022-11-30 — 2025-05-27
With neural diffusion models, we can generate samples from the unconditional distribution \(p\bigl(x_{\tau_0}\bigr)\approx p(x)\). To solve inverse problems, however, we need to sample from the posterior \(p(x_{\tau_0}\mid y)\).
There are lots of ways we might try to condition, differing sometimes only in emphasis.
1 Notation
First, let us fix notation. I’ll use a slight variant of the notation from the denoising diffusion SDE notebook. Because I need \(t\) for other things, we’ll write \(\tau\) for pseudo-time and index the discrete pseudo-time grid as \(\tau_0=0<\tau_1<\cdots<\tau_T=1\). We write \(x_{\tau_i}\) for the state at time \(\tau_i\).
For simplicity, we’ll assume a variance-preserving (VP) diffusion SDE. We corrupt data \(x_{\tau_0}=x_0\sim p_{\rm data}(x)\) by the variance-preserving SDE
\[ \mathrm{d}x_\tau = -\tfrac12\,\beta(\tau)\,x_\tau\,\mathrm{d}\tau + \sqrt{\beta(\tau)}\,\mathrm{d}W_\tau, \]
or in discrete form for each step \(\tau_i\to\tau_{i+1}\):
\[ x_{\tau_{i+1}} = \sqrt{1-\beta_{i+1}}\;x_{\tau_i} + \sqrt{\beta_{i+1}}\;\varepsilon, \quad \varepsilon\sim\mathcal N(0,I). \] Write \(\Delta\tau := \tau_i-\tau_{i-1}\) (uniform unless stated otherwise).
We also define the convenience terms \[ \bar\alpha(\tau)\;=\;\exp\!\Bigl(-\!\!\int_0^\tau\!\beta(s)\,ds\Bigr), \quad\text{so}\quad \sigma(\tau)^2 \;=\;1-\bar\alpha(\tau). \] and \[ \beta_i := 1-\frac{\bar\alpha(\tau_i)}{\bar\alpha(\tau_{i-1})}. \]
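A minimal numerical sketch of these definitions, assuming a linear schedule \(\beta(\tau)=\beta_{\min}+\tau(\beta_{\max}-\beta_{\min})\) with \(\beta_{\min}=0.1\), \(\beta_{\max}=20\) (a common VP choice, but an assumption here; the text does not fix \(\beta\)). It builds the grid quantities and checks that the per-step \(\beta_i\) telescope back to \(\bar\alpha\).

```python
import numpy as np

# Assumed linear schedule beta(tau) = BETA_MIN + tau * (BETA_MAX - BETA_MIN); the text does not fix beta.
BETA_MIN, BETA_MAX = 0.1, 20.0


def alpha_bar(tau):
    """abar(tau) = exp(-int_0^tau beta(s) ds), integrated in closed form for the linear schedule."""
    integral = BETA_MIN * tau + 0.5 * (BETA_MAX - BETA_MIN) * tau**2
    return np.exp(-integral)


def discretize(T):
    """Uniform grid tau_0 = 0 < ... < tau_T = 1 with abar, sigma and per-step beta_i."""
    tau = np.linspace(0.0, 1.0, T + 1)
    abar = alpha_bar(tau)
    sigma = np.sqrt(1.0 - abar)
    beta = np.full(T + 1, np.nan)              # beta_0 is unused; steps are i = 1..T
    beta[1:] = 1.0 - abar[1:] / abar[:-1]      # beta_i := 1 - abar(tau_i) / abar(tau_{i-1})
    return tau, abar, sigma, beta


tau, abar, sigma, beta = discretize(T=1000)
# prod_i (1 - beta_i) telescopes to abar(tau_T) / abar(tau_0) = abar(1).
print(np.prod(1.0 - beta[1:]), abar[-1])
```

Defining \(\beta_i\) through the ratio of \(\bar\alpha\) values, rather than as \(\beta(\tau_i)\Delta\tau\), makes the discrete chain’s marginals at the grid points exactly \(\sqrt{\bar\alpha(\tau_i)}\,x_0+\sigma(\tau_i)\,\varepsilon\).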
1.1 Score Network & Training
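We train \(s_\theta(x,\tau)\) to approximate the time-indexed score \(\nabla_{x_\tau}\log p_\tau(x_\tau)\). That marginal score is unavailable, so we minimize the denoising score-matching loss, which regresses on the score of the Gaussian corruption kernel instead:
\[ \mathcal L(\theta)=\mathbb E_{\tau,\,x_0,\,\varepsilon}\bigl\|\nabla_{x_\tau}\log p_{\tau\mid 0}(x_\tau\mid x_0)-s_\theta(x_\tau,\tau)\bigr\|^2 =\mathbb E_{\tau,\,x_0,\,\varepsilon}\Bigl\|-\frac{\varepsilon}{\sigma(\tau)}-s_\theta(x_\tau,\tau)\Bigr\|^2, \]
where \(x_\tau = \sqrt{\bar\alpha(\tau)}\,x_0 + \sigma(\tau)\,\varepsilon\) with \(x_0\sim p_{\rm data}\) and \(\varepsilon\sim\mathcal N(0,I)\). Its minimizer is still the marginal score, because \(\mathbb E\bigl[\nabla_{x_\tau}\log p_{\tau\mid 0}(x_\tau\mid x_0)\mid x_\tau\bigr]=\nabla_{x_\tau}\log p_\tau(x_\tau)\).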
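To see why regressing on \(-\varepsilon/\sigma(\tau)\) is enough, here is a sketch on an assumed 1-D Gaussian toy, \(x_0\sim\mathcal N(m,s^2)\), where the marginal score is known in closed form because \(p_\tau=\mathcal N(\sqrt{\bar\alpha}\,m,\ \bar\alpha s^2+\sigma^2)\). A Monte Carlo estimate of the denoising loss should be smaller at the exact marginal score than at a deliberately shifted one. The values of \(m\), \(s\), \(\bar\alpha\) and the shift are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
m, s = 2.0, 0.5               # assumed toy data x_0 ~ N(m, s^2)
abar = 0.6                    # abar(tau) at one fixed pseudo-time (arbitrary)
sig = np.sqrt(1.0 - abar)     # sigma(tau)


def marginal_score(x):
    """Exact score of p_tau = N(sqrt(abar) m, abar s^2 + sigma^2) for the toy."""
    var = abar * s**2 + sig**2
    return -(x - np.sqrt(abar) * m) / var


def dsm_loss(score_fn, n=200_000):
    """Monte Carlo estimate of E || -eps / sigma - s(x_tau) ||^2, x_tau = sqrt(abar) x_0 + sigma eps."""
    x0 = m + s * rng.standard_normal(n)
    eps = rng.standard_normal(n)
    x_tau = np.sqrt(abar) * x0 + sig * eps
    return np.mean((-eps / sig - score_fn(x_tau)) ** 2)


loss_exact = dsm_loss(marginal_score)
loss_shifted = dsm_loss(lambda x: marginal_score(x) + 0.5)   # deliberately wrong score
print(loss_exact, loss_shifted)
```

The minimum is not zero (about 0.68 for these settings): what remains is the irreducible variance of \(-\varepsilon/\sigma(\tau)\) given \(x_\tau\), which is why raw denoising-loss values are hard to read as errors.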
Equivalently, we can parametrize the network to predict the noise. In a VP diffusion, Tweedie’s formula gives the posterior mean of the noise \(\varepsilon\) given the noisy point \(x_\tau\) as \[ \begin{aligned} \mathbb E[\varepsilon\mid x_\tau] &=-\sigma(\tau)\,\nabla_{x_\tau}\log p_\tau(x_\tau),\\ \nabla_{x_\tau}\log p_\tau(x_\tau) &=-\frac{\mathbb E[\varepsilon\mid x_\tau]}{\sigma(\tau)}. \end{aligned} \] Hence predicting the noise (the “\(\varepsilon_\theta\)” parametrisation) or predicting the score (the “\(s_\theta\)” parametrisation) carries the same information up to the known factor \(\sigma(\tau)\): \(s_\theta(x_\tau,\tau)=-\varepsilon_\theta(x_\tau,\tau)/\sigma(\tau)\).
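A quick check of the noise/score relation on an assumed 1-D Gaussian toy \(x_0\sim\mathcal N(m,s^2)\). Both sides are closed-form: \(\mathbb E[\varepsilon\mid x_\tau]\) follows from the joint Gaussianity of \(\varepsilon\) and \(x_\tau\) (with \(\operatorname{Cov}(\varepsilon,x_\tau)=\sigma\)), and the marginal score follows from \(p_\tau=\mathcal N(\sqrt{\bar\alpha}\,m,\ \bar\alpha s^2+\sigma^2)\). The variable names are illustrative.

```python
import numpy as np

m, s = 2.0, 0.5                  # assumed toy data x_0 ~ N(m, s^2)
abar = 0.3                       # abar(tau) at an arbitrary pseudo-time
sig = np.sqrt(1.0 - abar)        # sigma(tau)
var_x = abar * s**2 + sig**2     # Var(x_tau)

x_tau = np.linspace(-3.0, 3.0, 7)

# Marginal score of p_tau = N(sqrt(abar) m, var_x).
score = -(x_tau - np.sqrt(abar) * m) / var_x

# E[eps | x_tau] by Gaussian conditioning: Cov(eps, x_tau) / Var(x_tau) * (x_tau - E[x_tau]).
eps_mean = sig * (x_tau - np.sqrt(abar) * m) / var_x

# eps-parametrisation -> score parametrisation: s = -eps / sigma.
score_from_eps = -eps_mean / sig
print(np.max(np.abs(score_from_eps - score)))
```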
1.2 Reverse-Time Sampling
To sample, integrate the reverse SDE from \(\tau_T=1\) down to \(\tau_0=0\):
\[ \mathrm{d}x_\tau = \Bigl[-\tfrac12\,\beta(\tau)\,x_\tau \;-\;\beta(\tau)\,\nabla_{x_\tau}\log p_\tau(x_\tau)\Bigr]\mathrm{d}\tau + \sqrt{\beta(\tau)}\,\mathrm{d}\bar W_\tau, \] using \(s_\theta(x,\tau)\approx\nabla\log p_\tau(x)\). Here \(\mathrm{d}\tau\) is a negative time step and \(\bar W_\tau\) is a standard Wiener process running backwards in time, from \(\tau=1\) to \(\tau=0\).
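Alternatively, we can use the deterministic probability-flow ODE, which has the same linear drift but only half the score term:
\[ \mathrm{d}x_\tau = \Bigl[-\tfrac12\,\beta(\tau)\,x_\tau \;-\;\tfrac12\,\beta(\tau)\,\nabla_{x_\tau}\log p_\tau(x_\tau)\Bigr] \,\mathrm{d}\tau, \]
with initial draw \(x_{\tau_T}\sim\mathcal N(0,I)\). This yields the same marginals \(p_\tau\) without injecting extra noise at each step (which is insane and not at all obvious to me). The mechanism is that both dynamics induce the same Fokker–Planck equation for \(p_\tau\): the probability flux \(-\tfrac12\beta(\tau)\,p_\tau\nabla\log p_\tau=-\tfrac12\beta(\tau)\nabla p_\tau\) of the halved score term is exactly the flux that the diffusion term would produce.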
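A minimal Euler–Maruyama sketch of both samplers, assuming the linear schedule \(\beta_{\min}=0.1\), \(\beta_{\max}=20\) and a 1-D Gaussian toy \(x_0\sim\mathcal N(m,s^2)\) whose exact score stands in for a trained \(s_\theta\). Both runs should end with samples whose mean and standard deviation are close to \(m\) and \(s\); the step count and sample size are illustrative.

```python
import numpy as np

BETA_MIN, BETA_MAX = 0.1, 20.0   # assumed linear VP schedule
m, s = 2.0, 0.5                  # assumed toy data x_0 ~ N(m, s^2)


def beta(tau):
    return BETA_MIN + tau * (BETA_MAX - BETA_MIN)


def alpha_bar(tau):
    return np.exp(-(BETA_MIN * tau + 0.5 * (BETA_MAX - BETA_MIN) * tau**2))


def score(x, tau):
    """Exact score of p_tau = N(sqrt(abar) m, abar s^2 + 1 - abar); stands in for s_theta."""
    ab = alpha_bar(tau)
    return -(x - np.sqrt(ab) * m) / (ab * s**2 + 1.0 - ab)


def sample(stochastic, n=20_000, steps=1000, seed=0):
    rng = np.random.default_rng(seed)
    h = 1.0 / steps
    x = rng.standard_normal(n)                     # x_{tau_T} ~ N(0, I)
    for k in range(steps, 0, -1):                  # tau = 1, 1 - h, ..., h
        tau = k * h
        b = beta(tau)
        if stochastic:
            # Reverse SDE: step the drift -b x / 2 - b * score backwards by h, add fresh noise.
            x = x + h * (0.5 * b * x + b * score(x, tau)) + np.sqrt(b * h) * rng.standard_normal(n)
        else:
            # Probability-flow ODE: half the score term, no noise.
            x = x + h * (0.5 * b * x + 0.5 * b * score(x, tau))
    return x


results = {"SDE": sample(stochastic=True), "ODE": sample(stochastic=False)}
for name, xs in results.items():
    print(name, xs.mean(), xs.std())
```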
On our \(\tau\) grid, the discrete time DDPM reverse update becomes, for \(i=T,\dots,1\):
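\[ x_{\tau_{i-1}} = \frac{1}{\sqrt{1-\beta_{i}}}\Bigl( x_{\tau_i} + \beta_{i}\,s_\theta\bigl(x_{\tau_i},\tau_i\bigr) \Bigr) + \sqrt{\tilde\beta_{i}}\,\zeta, \quad \zeta\sim\mathcal N(0,I). \] Here \(\tilde\beta_i:=\frac{1-\bar\alpha(\tau_{i-1})}{1-\bar\alpha(\tau_i)}\,\beta_i\) is the variance of the forward-process posterior \(q(x_{\tau_{i-1}}\mid x_{\tau_i},x_{\tau_0})\). The plus sign is the familiar \(-\beta_i\,\varepsilon_\theta/\sigma(\tau_i)\) correction rewritten with \(\varepsilon_\theta=-\sigma(\tau_i)\,s_\theta\).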
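The DDIM variant (with its noise parameter set to zero) drops \(\zeta\) and uses a different, deterministic update, \(x_{\tau_{i-1}}=\sqrt{\bar\alpha(\tau_{i-1})}\,\hat x_0+\sigma(\tau_{i-1})\,\hat\varepsilon\), where \(\hat\varepsilon=-\sigma(\tau_i)\,s_\theta(x_{\tau_i},\tau_i)\) and \(\hat x_0=\bigl(x_{\tau_i}-\sigma(\tau_i)\,\hat\varepsilon\bigr)/\sqrt{\bar\alpha(\tau_i)}\). Because this map is deterministic, it can also be run forwards to invert data into noise.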
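Both discrete updates, sketched on an assumed setup: the linear schedule \(\beta_{\min}=0.1\), \(\beta_{\max}=20\) on a uniform grid with \(T=1000\), \(\beta_i=1-\bar\alpha(\tau_i)/\bar\alpha(\tau_{i-1})\), \(\tilde\beta_i=\beta_i\,(1-\bar\alpha(\tau_{i-1}))/(1-\bar\alpha(\tau_i))\), and a 1-D Gaussian toy \(x_0\sim\mathcal N(m,s^2)\) whose exact score replaces \(s_\theta\). DDPM uses the score-form update \(x\leftarrow(x+\beta_i s)/\sqrt{1-\beta_i}+\sqrt{\tilde\beta_i}\,\zeta\); DDIM uses the deterministic \(\hat x_0,\hat\varepsilon\) update.

```python
import numpy as np

BETA_MIN, BETA_MAX = 0.1, 20.0   # assumed linear VP schedule
m, s = 2.0, 0.5                  # assumed toy data x_0 ~ N(m, s^2)
T = 1000

tau = np.linspace(0.0, 1.0, T + 1)
abar = np.exp(-(BETA_MIN * tau + 0.5 * (BETA_MAX - BETA_MIN) * tau**2))
beta = np.concatenate([[0.0], 1.0 - abar[1:] / abar[:-1]])                    # beta_i, i >= 1
beta_tilde = np.concatenate([[0.0], (1.0 - abar[:-1]) / (1.0 - abar[1:]) * beta[1:]])


def score(x, i):
    """Exact score of p_{tau_i} for the toy; stands in for s_theta(x, tau_i)."""
    return -(x - np.sqrt(abar[i]) * m) / (abar[i] * s**2 + 1.0 - abar[i])


def ddpm(n=20_000, seed=0):
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(n)
    for i in range(T, 0, -1):
        mean = (x + beta[i] * score(x, i)) / np.sqrt(1.0 - beta[i])
        x = mean + np.sqrt(beta_tilde[i]) * rng.standard_normal(n)
    return x


def ddim(n=20_000, seed=0):
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(n)
    for i in range(T, 0, -1):
        sig_i, sig_prev = np.sqrt(1.0 - abar[i]), np.sqrt(1.0 - abar[i - 1])
        eps_hat = -sig_i * score(x, i)                          # noise estimate
        x0_hat = (x - sig_i * eps_hat) / np.sqrt(abar[i])       # Tweedie estimate of x_0
        x = np.sqrt(abar[i - 1]) * x0_hat + sig_prev * eps_hat  # deterministic, no fresh noise
    return x


x_ddpm, x_ddim = ddpm(), ddim()
print(x_ddpm.mean(), x_ddpm.std(), x_ddim.mean(), x_ddim.std())
```

Note that \(\tilde\beta_1=0\) because \(\bar\alpha(\tau_0)=1\), so the last DDPM step is noise-free, just as every DDIM step is.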
2 Generic conditioning
Here is a quick rewrite of Rozet and Louppe (2023b). Note I have updated the notation to match the rest of this notebook.
We could train a conditional score network \(s_\phi\bigl(x_{\tau_i},\tau_i\mid y\bigr)\) to approximate the posterior score \(\nabla_{x_{\tau_i}}\log p\bigl(x_{\tau_i}\mid y\bigr)\) and plug it into the reverse SDE. But this requires \((x,y)\) pairs during training and re‐training whenever the observation model \(p(y\mid x)\) changes.
Instead, many have observed (Song, Sohl-Dickstein, et al. 2022; Adam et al. 2022; Chung et al. 2023; Kawar, Vaksman, and Elad 2021; Song, Shen, et al. 2022) that by Bayes’ rule the posterior score decomposes as \[ \nabla_{x_{\tau_i}}\log p\bigl(x_{\tau_i}\mid y\bigr) = \nabla_{x_{\tau_i}}\log p\bigl(x_{\tau_i}\bigr) + \nabla_{x_{\tau_i}}\log p\bigl(y\mid x_{\tau_i}\bigr)\,, \] since the evidence \(\log p(y)\) does not depend on \(x_{\tau_i}\) and drops out of the gradient. Since the prior score \(\nabla_{x_{\tau_i}}\log p(x_{\tau_i})\) is well-approximated by the unconditional score network \(s_\theta(x_{\tau_i},\tau_i)\), the remaining task is to estimate the likelihood score \(\nabla_{x_{\tau_i}}\log p(y\mid x_{\tau_i})\).
Assuming a differentiable measurement operator \(\mathcal A\) and Gaussian observations \(p(y\mid x)=\mathcal N\bigl(y\mid\mathcal A(x),\Sigma_y\bigr)\), Chung et al. (2023) propose approximating \[ p\bigl(y\mid x_{\tau_i}\bigr) = \int p(y\mid x)\,p(x\mid x_{\tau_i})\,\mathrm{d}x \;\approx\; \mathcal N\bigl(y\mid\mathcal A(\hat x(x_{\tau_i})),\,\Sigma_y\bigr), \] where the denoised mean \(\hat x(x_{\tau_i})=\mathbb{E}\bigl[\,x\mid x_{\tau_i}\bigr]\) is given by Tweedie’s formula (Efron 2011; Kim and Ye 2021). Writing \(\tau=\tau_i\), \[ \mathbb{E}[x\mid x_{\tau}] =\mathbb E[x_0\mid x_\tau] =\frac{x_\tau+\sigma(\tau)^2\,\nabla_{x_\tau}\log p_\tau(x_\tau)}{\sqrt{\bar\alpha(\tau)}} \approx\frac{x_\tau+\sigma(\tau)^2\,s_\theta(x_\tau,\tau)}{\sqrt{\bar\alpha(\tau)}}. \] Because the log-likelihood of a multivariate Gaussian is analytic and \(s_\theta(x_{\tau_i},\tau_i)\) is differentiable, we can compute \(\nabla_{x_{\tau_i}}\log p\bigl(y\mid x_{\tau_i}\bigr)\) by backpropagating through \(\mathcal A\) and the score network, in a zero-shot fashion, without training any network beyond the unconditional score model \(s_\theta\).
Note that this last approximation, which replaces \(p(x\mid x_{\tau_i})\) by a point mass at its mean \(\hat x(x_{\tau_i})\), is strong; probably too strong for the models I would bother using diffusions on. Don’t worry, we can get fancier and more effective.
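The cost of the point-mass approximation is visible on an assumed 1-D Gaussian toy: prior \(x_0\sim\mathcal N(m,s^2)\), linear measurement \(y=A\,x_0+\mathcal N(0,\sigma_y^2)\). Here the Tweedie mean is linear in \(x_\tau\), so the chain rule can be written by hand (with a real network one would autodiff through \(s_\theta\) instead), and the exact answer \(p(y\mid x_\tau)=\mathcal N\bigl(y;A\hat x,\ A^2\operatorname{Var}(x_0\mid x_\tau)+\sigma_y^2\bigr)\) is available for comparison. The approximation drops the \(A^2\operatorname{Var}(x_0\mid x_\tau)\) term, so its gradient is too large at high noise and becomes accurate as \(\bar\alpha\to1\). All constants are illustrative.

```python
import numpy as np

m, s = 0.0, 1.0          # assumed Gaussian prior x_0 ~ N(m, s^2)
A, sigma_y = 2.0, 0.3    # assumed linear measurement y = A x_0 + N(0, sigma_y^2)
y = 1.5                  # an arbitrary observed value


def posterior_stats(x_tau, abar):
    """Tweedie mean x_hat(x_tau), its derivative, and Var(x_0 | x_tau) for the toy."""
    a, sig2 = np.sqrt(abar), 1.0 - abar
    v = abar * s**2 + sig2                  # Var(x_tau)
    score = -(x_tau - a * m) / v            # exact marginal score, standing in for s_theta
    x_hat = (x_tau + sig2 * score) / a      # Tweedie: (x_tau + sigma^2 * score) / sqrt(abar)
    dx_hat = a * s**2 / v                   # d x_hat / d x_tau (x_hat is linear in x_tau here)
    post_var = s**2 * sig2 / v              # Var(x_0 | x_tau)
    return x_hat, dx_hat, post_var


def likelihood_scores(x_tau, abar):
    """d/dx_tau log p(y | x_tau): point-mass approximation vs the exact Gaussian answer."""
    x_hat, dx_hat, post_var = posterior_stats(x_tau, abar)
    resid = y - A * x_hat
    approx = A * resid / sigma_y**2 * dx_hat                       # N(y | A x_hat, sigma_y^2)
    exact = A * resid / (A**2 * post_var + sigma_y**2) * dx_hat    # N(y | A x_hat, A^2 Var + sigma_y^2)
    return approx, exact


for abar in (0.999, 0.5, 0.05):
    approx, exact = likelihood_scores(0.3, abar)
    print(f"abar={abar}: approx {approx:.3f}, exact {exact:.3f}, ratio {approx / exact:.2f}")
```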
3 Ensemble Score Conditioning
A simple trick that sometimes works (F. Bao, Zhang, and Zhang 2024a) but is biased. TBC.
4 Sequential Monte Carlo
This seems to be SOTA?
LLM-aided summary of Wu et al. (2024):
We recall standard SMC / Particle Filtering:
- Goal: sample from a sequence of distributions \(\{\nu_i\}_{i=0}^T\), ending in some target \(\nu_0\).
- Particles: maintain \(K\) samples (particles) \(\{x^k_{i}\}_{k=1}^K\) with weights \(\{w^k_i\}\).
- Initialize \(x^k_T\sim r_T\) with weights \(w^k_T=\nu_T(x^k_T)/r_T(x^k_T)\), then iterate for \(i=T-1,\dots,0\):
  - Resample particles according to \(w^k_{i+1}\) to focus on high-probability regions.
  - Propose \(x^k_i \sim r_i(x_i \mid x^k_{i+1})\).
  - Weight each by the incremental importance weight
    \[ w^k_i \;=\; \frac{\nu_i\bigl(x^k_{i:T}\bigr)}{\nu_{i+1}\bigl(x^k_{i+1:T}\bigr)\,r_i\bigl(x^k_i\mid x^k_{i+1}\bigr)}, \]
    i.e. target density over proposal density, accumulated one step at a time.
- Convergence: as \(K\to\infty\), weighted averages over the ensemble converge to expectations under the target \(\nu_0\).
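The propose, weight and resample steps can be sketched in a few lines. This is a minimal illustration of one round, assuming a proposal \(\mathcal N(0,2^2)\) and a target \(\mathcal N(1,1)\) chosen only because the answer (mean 1) is known; the helper names are hypothetical, and `numpy.random.Generator.choice` does the multinomial draw.

```python
import numpy as np


def normalize(log_w):
    """Unnormalized log-weights -> probabilities, shifting by the max for numerical stability."""
    w = np.exp(log_w - np.max(log_w))
    return w / w.sum()


def effective_sample_size(log_w):
    w = normalize(log_w)
    return 1.0 / np.sum(w**2)


def multinomial_resample(particles, log_w, rng):
    """Draw K ancestors with probability proportional to weight; weights reset to uniform."""
    K = len(particles)
    idx = rng.choice(K, size=K, p=normalize(log_w))
    return particles[idx], np.zeros(K)


def log_normal(x, mean, var):
    return -0.5 * (x - mean) ** 2 / var - 0.5 * np.log(2.0 * np.pi * var)


# One propose -> weight -> resample round: proposal N(0, 2^2), target N(1, 1).
rng = np.random.default_rng(0)
K = 100_000
x = 2.0 * rng.standard_normal(K)                               # propose
log_w = log_normal(x, 1.0, 1.0) - log_normal(x, 0.0, 4.0)      # target / proposal
weighted_mean = np.sum(normalize(log_w) * x)
ess = effective_sample_size(log_w)
x_res, log_w_res = multinomial_resample(x, log_w, rng)         # resample
print(weighted_mean, x_res.mean(), ess / K)
```

The effective sample size \(1/\sum_k (w^k)^2\) (with normalized weights) is the usual diagnostic for how many particles are really contributing; here it is a bit over half of \(K\).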
In a diffusion model, we can view the reverse (denoising) chain
\[ p_\theta(x_{0:T}) =\;p(x_{\tau_T})\,\prod_{i=1}^T p_\theta(x_{\tau_{i-1}}\mid x_{\tau_i}) \]
as exactly such a sequential model over \(x_{\tau_T}\to x_{\tau_0}\), where \(\nu_{i}\) is the joint law under \(p_\theta\) of the partial path \((x_{\tau_T},\dots,x_{\tau_i})\).
To sample conditionally \(p_\theta(x_0\mid y)\), we treat the conditioning as part of the final target and apply SMC. However, if we naïvely run SMC with the unconditional transition kernels
\[ r_i\bigl(x_{\tau_{i-1}}\mid x_{\tau_i}\bigr) = p_\theta\bigl(x_{\tau_{i-1}}\mid x_{\tau_i}\bigr) \]
and only tack on a final weight \(w_0\propto p(y\mid x_0)\), we need an astronomical number of particles: most particles receive near-zero weight whenever the posterior \(p_\theta(x_0\!\mid\!y)\) concentrates where the prior \(p_\theta(x_0)\) puts little mass.
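The collapse is easy to reproduce with an assumed 1-D Gaussian prior \(\mathcal N(0,1)\) standing in for \(p_\theta(x_0)\) and a Gaussian likelihood \(p(y\mid x_0)=\mathcal N(y;x_0,\sigma_y^2)\): when \(y\) sits in the prior’s tail, the final-step weights concentrate on a handful of particles. The particle count and \(\sigma_y\) are illustrative.

```python
import numpy as np


def ess_fraction(y, K=10_000, sigma_y=0.1, seed=0):
    """ESS / K when prior samples x_0 ~ N(0, 1) are weighted only by p(y | x_0) = N(y; x_0, sigma_y^2)."""
    rng = np.random.default_rng(seed)
    x0 = rng.standard_normal(K)                      # unconditional samples
    log_w = -0.5 * (y - x0) ** 2 / sigma_y**2        # final-step log-weights, up to a constant
    w = np.exp(log_w - log_w.max())
    w /= w.sum()
    return 1.0 / np.sum(w**2) / K


for y in (0.0, 2.0, 4.0):
    print(y, ess_fraction(y))
```

Even with \(y\) at the prior mode only about 14% of the particles are effectively used; at \(y=4\) essentially one or two survive.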
Twisting is a classic SMC technique that solves this problem by introducing a sequence of auxiliary functions \(\{\tilde p_\theta(y\mid x_{\tau_i})\}_{i=0}^T\) that re-weight proposals at every time step, not just at the end. The optimal proposal at step \(i\) would be
\[ r_i^\ast(x_{\tau_{\,i-1}}\!\mid\!x_{\tau_i}) \;\propto\; p_\theta\bigl(x_{\tau_{\,i-1}}\mid x_{\tau_i}\bigr)\; p_\theta\bigl(y\mid x_{\tau_{\,i-1}}\bigr), \]
which—if we could sample it—would make SMC exact with a single particle. However, \(p_\theta(y\mid x_{\tau_{\,i-1}})\) is itself intractable.
The twisted diffusion sampler (TDS) of Wu et al. (2024) replaces the optimal twisting \(p_\theta(y\mid x_{\tau_i})\) with a tractable surrogate based on the denoiser \(\hat x_0(x_{\tau_i})\), which estimates the denoised \(x_0\) from the noisy \(x_{\tau_i}\):
\[ \tilde p_\theta(y\mid x_{\tau_i}) \;=\; p\bigl(y \mid \hat x_0(x_{\tau_i})\bigr), \]
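i.e. we evaluate the observation likelihood at the diffusion denoiser’s one-step posterior-mean estimate \(\hat x_0\). Since \(\hat x_0(x_{\tau})\approx \mathbb E[x_0\!\mid\!x_{\tau}]\), this becomes increasingly accurate as \(\tau\to0\). Write the DDPM reverse kernel of Section 1.2 as \(p_\theta(x_{\tau_{i-1}}\mid x_{\tau_i})=\mathcal N\bigl(x_{\tau_{i-1}};\,\mu_i\bigl(x_{\tau_i},s_\theta(x_{\tau_i},\tau_i)\bigr),\,v_i I\bigr)\), where \(\mu_i(x,s):=(x+\beta_i s)/\sqrt{1-\beta_i}\) and the step variance is \(v_i=\tilde\beta_i\) (or \(v_i=\beta_i\), the other common choice).
Twisted proposal from \(\tau_i\to\tau_{i-1}\), which keeps the variance and swaps in the guided score:
\[ \tilde r_i\bigl(x_{\tau_{\,i-1}}\!\mid\!x_{\tau_i},y\bigr) = \mathcal N\Bigl(x_{\tau_{\,i-1}};\, \underbrace{\frac{1}{\sqrt{1-\beta_i}}\Bigl(x_{\tau_i} \;+\;\underbrace{\beta_i\,s_i(x_{\tau_i},y)}_{\text{“guided” drift}}\Bigr)}_{\text{mean}},\; v_i\,I\Bigr), \]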
where
\[ s_i(x_{\tau_i},y) = s_\theta(x_{\tau_i},\tau_i) + \nabla_{x_{\tau_i}}\log\tilde p_\theta\bigl(y\mid x_{\tau_i}\bigr). \]
Twisted weight for each particle:
\[ w_{\tau_{\,i-1}} \;=\; \frac{ p_\theta\bigl(x_{\tau_{\,i-1}}\!\mid\!x_{\tau_i}\bigr)\; \tilde p_\theta\bigl(y\mid x_{\tau_{\,i-1}}\bigr) } { \tilde p_\theta\bigl(y\mid x_{\tau_i}\bigr)\; \tilde r_i\bigl(x_{\tau_{\,i-1}}\!\mid\!x_{\tau_i},y\bigr) }. \]
Twisted Sister:
This corrects for using the surrogate twisting and ensures asymptotic exactness as \(K\to\infty\).
In early steps (\(i\approx T\)) the surrogate \(\tilde p_\theta(y\mid x_{\tau_i})\) is nearly flat as a function of \(x_{\tau_i}\), because \(\hat x_0(x_{\tau_i})\) barely depends on \(x_{\tau_i}\) while the signal is buried in noise, so twisting is mild. In late steps (\(i\to0\)), \(\hat x_0(x_{\tau_i})\) is accurate, so \(\tilde p_\theta(y\mid x_{\tau_i})\approx p_\theta(y\mid x_{\tau_i})\) and the proposals are nearly optimal. Resampling in between keeps the particle cloud focused on regions consistent with both the diffusion prior and the conditioning \(y\).
In practice, we often need a surprisingly small number of particles; even 2–8 particles often suffice to outperform heuristic conditional samplers (like plain classifier guidance or “replacement” inpainting).
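A compact sketch of the TDS recursion (twisted proposal, twisted weight, multinomial resampling at every step) on an assumed all-Gaussian toy: prior \(x_0\sim\mathcal N(0,1)\), for which the VP marginals stay \(\mathcal N(0,1)\), the exact score is \(-x_\tau\) and the Tweedie mean is \(\sqrt{\bar\alpha}\,x_\tau\); observation \(y=x_0+\mathcal N(0,\sigma_y^2)\); linear schedule \(\beta_{\min}=0.1\), \(\beta_{\max}=20\). The reverse kernel is \(\mathcal N\bigl((x+\beta_i s)/\sqrt{1-\beta_i},\ \beta_i\bigr)\), i.e. the step-variance choice \(\beta_i\), which for this toy is the exact time reversal and keeps the last step non-degenerate (with \(\tilde\beta_1=0\) the final kernel would be a point mass). The weighted particles should approximate the exact posterior \(\mathcal N\bigl(y/(1+\sigma_y^2),\ \sigma_y^2/(1+\sigma_y^2)\bigr)\). This is an illustration under these assumptions, not Wu et al.’s implementation.

```python
import numpy as np

BETA_MIN, BETA_MAX = 0.1, 20.0   # assumed linear VP schedule
T, K = 200, 5000                 # grid steps and particles (illustrative)
y, sigma_y = 2.0, 0.5            # assumed observation model y = x_0 + N(0, sigma_y^2)

tau = np.linspace(0.0, 1.0, T + 1)
abar = np.exp(-(BETA_MIN * tau + 0.5 * (BETA_MAX - BETA_MIN) * tau**2))
beta = np.concatenate([[0.0], 1.0 - abar[1:] / abar[:-1]])


def log_normal(x, mean, var):
    return -0.5 * (x - mean) ** 2 / var - 0.5 * np.log(2.0 * np.pi * var)


def log_twist(x, i):
    """log p~(y | x_{tau_i}) = log N(y; x0_hat, sigma_y^2), with Tweedie mean x0_hat = sqrt(abar_i) x here."""
    return log_normal(y, np.sqrt(abar[i]) * x, sigma_y**2)


def grad_log_twist(x, i):
    a = np.sqrt(abar[i])
    return a * (y - a * x) / sigma_y**2


def normalized(log_w):
    w = np.exp(log_w - log_w.max())
    return w / w.sum()


def tds(seed=0):
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(K)                        # x_{tau_T} ~ N(0, 1)
    log_w = log_twist(x, T)                           # initial twisted weight
    for i in range(T, 0, -1):
        x = x[rng.choice(K, size=K, p=normalized(log_w))]   # multinomial resampling
        score = -x                                    # exact prior score for the N(0, 1) toy
        guided = score + grad_log_twist(x, i)         # s_i(x, y)
        scale = 1.0 / np.sqrt(1.0 - beta[i])
        mean_prior = scale * (x + beta[i] * score)    # unconditional kernel mean
        mean_twist = scale * (x + beta[i] * guided)   # twisted proposal mean
        x_new = mean_twist + np.sqrt(beta[i]) * rng.standard_normal(K)
        # Twisted weight: p(x'|x) p~(y|x') / (p~(y|x) r~(x'|x, y)).
        log_w = (log_normal(x_new, mean_prior, beta[i]) + log_twist(x_new, i - 1)
                 - log_twist(x, i) - log_normal(x_new, mean_twist, beta[i]))
        x = x_new
    return x, normalized(log_w)


x_post, w_post = tds()
post_mean = np.sum(w_post * x_post)
post_var = np.sum(w_post * (x_post - post_mean) ** 2)
exact_mean, exact_var = y / (1 + sigma_y**2), sigma_y**2 / (1 + sigma_y**2)
print(post_mean, exact_mean, post_var, exact_var)
```

At the last step \(\bar\alpha(\tau_0)=1\), so the twist \(\tilde p_\theta(y\mid x_{\tau_0})\) is the true likelihood and the weights target \(p_\theta(x_0\mid y)\). Resampling at every step is the simplest choice; resampling only when the effective sample size drops below a threshold is a common refinement.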
5 (Conditional) Schrödinger Bridge
Shi et al. (2022) introduced the Conditional Schrödinger Bridge (CSB), which is a natural extension of the Schrödinger Bridge.
We seek a path-measure \(\pi^*(x_{0:T}\mid y)\) minimising
\[ \mathrm{KL}\bigl(\pi(\,\cdot\mid y)\,\|\;p(x_{0:T})\bigr) \]
subject to
- Start at \(\tau_T\): \(\pi_{\tau_T}(x_T\mid y)=\mathcal N(x_T;0,I)\),
- End at \(\tau_0\): \(\pi_{\tau_0}(x_0\mid y)=p(x_0\mid y)\).
Here \(p(x_{0:T}) = p(x_{\tau_0})\prod_{i=1}^T p(x_{\tau_i}\mid x_{\tau_{i-1}})\) is the unconditional forward noising chain, which serves as the reference measure.
5.1 Amortized IPF Algorithm
We parameterize two families of drift networks that take \(y\) as input:
\[ B^n_i(x, y) \quad\text{and}\quad F^n_i(x, y) \quad\text{for }i=1,\dots,T,\;n=0,\dots,L. \]
We alternate two KL-projection steps:
Backward half-step (\(\tau_T\to\tau_0\), enforce prior):
\[ \pi^{2n+1}(\cdot\mid y) = \arg\min_{\pi}\; \mathrm{KL}\bigl(\pi(\cdot\mid y)\,\|\;\pi^{2n}(\cdot\mid y)\bigr) \quad\text{s.t.}\quad \pi_{\tau_T}(x_T\mid y)=\mathcal N(0,I). \]
Fit \(B^{n+1}_i(x,y)\) by matching the backward SDE induced by \(\pi^{2n+1}\).
Forward half-step (\(\tau_0\to\tau_T\), enforce posterior):
\[ \pi^{2n+2}(\cdot\mid y) = \arg\min_{\pi}\; \mathrm{KL}\bigl(\pi(\cdot\mid y)\,\|\;\pi^{2n+1}(\cdot\mid y)\bigr) \quad\text{s.t.}\quad \pi_{\tau_0}(x_0\mid y)=p(x_0\mid y). \]
Fit \(F^{n+1}_i(x,y)\) by matching the forward SDE induced by \(\pi^{2n+2}\).
After \(L\) IPF iterations we obtain networks \(B^L_i(x,y),\,F^L_i(x,y)\) whose composed bridge \(\pi^*(\cdot\mid y)\) approximately matches both endpoints for any \(y\); the match becomes exact only in the limit \(L\to\infty\) with perfectly fitted drifts.
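The alternating projections are easiest to see in a static, two-time-point analogue, where IPF reduces to Sinkhorn scaling of a discrete coupling: the KL projection of a coupling onto “row sums equal \(\mu\)” is a row rescaling, and onto “column sums equal \(\nu\)” a column rescaling. This is a swapped-in illustration on an assumed 1-D grid with a Gaussian reference kernel, not the path-space CSB algorithm with drift networks; \(\mu\) and \(\nu\) merely play the roles of the two pinned endpoints.

```python
import numpy as np

x = np.linspace(-3.0, 3.0, 20)                         # assumed 1-D state grid
# Reference coupling: a Gaussian transition kernel (variance 4), normalized to a joint.
ref = np.exp(-((x[:, None] - x[None, :]) ** 2) / (2 * 4.0))
ref /= ref.sum()

mu = np.exp(-0.5 * x**2)                               # marginal pinned at one end ("prior")
mu /= mu.sum()
nu = np.exp(-((x - 1.0) ** 2) / 0.5)                   # marginal pinned at the other end ("posterior")
nu /= nu.sum()

pi = ref.copy()
for _ in range(2000):
    # KL projection onto {row sums = mu}: rescale each row.
    pi *= (mu / pi.sum(axis=1))[:, None]
    # KL projection onto {column sums = nu}: rescale each column.
    pi *= (nu / pi.sum(axis=0))[None, :]

print(np.abs(pi.sum(axis=1) - mu).max(), np.abs(pi.sum(axis=0) - nu).max())
```

Each half-step fixes one marginal exactly and perturbs the other; only the fixed point satisfies both, which mirrors why finitely many IPF iterations give an approximate bridge.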
We can imagine that the backward step “pins” the Gaussian prior; the forward step “pins” the conditional \(p(x_0\mid y)\). At convergence, the backward and forward transitions are two descriptions of the same conditional bridge.
Backward description, starting from the initial draw at \(\tau_T\):
\[ x_{\tau_T}\sim p(x_{\tau_T})=\mathcal N(0,I), \]
then backward transitions for \(i=T,T-1,\dots,1\):
\[ x_{\tau_{i-1}} \;\sim\; \underbrace{\mathcal N\bigl(x_{\tau_i} + \Delta\tau\,B^L_i(x_{\tau_i},y),\;\Delta\tau\,I\bigr)}_{\pi^*_{\text{back},i}\!(x_{\tau_{i-1}}\mid x_{\tau_i},y)}. \]
Forward description, starting from \(x_{\tau_0}\sim p(x_0\mid y)\), then forward transitions for \(i=1,2,\dots,T\):
\[ x_{\tau_i} \;\sim\; \underbrace{\mathcal N\bigl(x_{\tau_{i-1}} + \Delta\tau\,F^L_i(x_{\tau_{i-1}},y),\;\Delta\tau\,I\bigr)}_{\pi^*_{\text{for},i}\!(x_{\tau_i}\mid x_{\tau_{i-1}},y)}. \]
Hence the full path-measure factorizes in either direction,
\[ \boxed{ \pi^*(x_{0:T}\mid y) = p(x_{\tau_T})\;\prod_{i=1}^{T} \pi^*_{\text{back},i}(x_{\tau_{i-1}}\mid x_{\tau_i},y) = p(x_{\tau_0}\mid y)\;\prod_{i=1}^T \pi^*_{\text{for},i}(x_{\tau_i}\mid x_{\tau_{i-1}},y)\,, } \]
so that, at convergence, it has the Gaussian prior at \(\tau_T\) and the posterior \(p(x_0\mid y)\) at \(\tau_0\) for any \(y\). CSB finds the stochastic flow between noise and data posterior that stays closest in KL to the reference noising process; this is the entropy-regularized sense in which it is the “smoothest”. This seems intuitively fine, but my reasoning here is vibes-based. I need to read the paper better to get a proper grip on it.
5.2 Sampling with the Learned Conditional Bridge
To draw \(x_0\sim p(x_0\mid y)\):
Initialize \(x_{\tau_T}\sim\mathcal N(0,I)\).
Backward integrate the learned SDE (or its probability-flow ODE) from \(\tau_T=1\) down to \(\tau_0=0\). On the grid, this is the Euler–Maruyama step matching the backward transitions above:
\[ x_{\tau_{i-1}} = x_{\tau_i} + \Delta\tau\,B^L_i\bigl(x_{\tau_i},y\bigr) + \sqrt{\Delta\tau}\,\xi_i, \quad \xi_i\sim\mathcal N(0,I), \quad i=T,\dots,1. \]
(Optionally) Forward integrate with \(F^L_i(x,y)\) to refine or to compute likelihoods.
Because \(B^L_i\) and \(F^L_i\) both depend on \(y\), the same trained model applies to arbitrary observations.
Amortized CSB encodes the observation \(y\) directly into every drift net \(B_i(x,y)\) and \(F_i(x,y)\). There is no per-instance retraining and there are no importance weights: once the joint IPF training over \((x_0,y)\) is done, we can plug in any new \(y\) and run the sampler.
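The sampling loop is just the backward transition \(x_{\tau_{i-1}}\sim\mathcal N\bigl(x_{\tau_i}+\Delta\tau\,B^L_i(x_{\tau_i},y),\ \Delta\tau\,I\bigr)\) iterated with a \(y\)-dependent drift. A minimal sketch in which the hypothetical `drift(x, y, tau)` stands in for a trained \(B^L_i\): here the closed-form Brownian-bridge drift \((y-x)/\tau\), with \(\tau\) the remaining pseudo-time, which pins every path to the point \(y\) at \(\tau=0\). It is not a learned CSB drift, but it exercises the same loop and shows one function serving any \(y\).

```python
import numpy as np


def drift(x, y, tau):
    """Hypothetical stand-in for a trained B^L_i(x, y): Brownian-bridge drift towards y,
    with tau the remaining pseudo-time. NOT a learned CSB drift."""
    return (y - x) / tau


def sample_bridge(y, n=10_000, steps=1000, seed=0):
    rng = np.random.default_rng(seed)
    dtau = 1.0 / steps
    x = rng.standard_normal(n)                          # x_{tau_T} ~ N(0, I)
    for i in range(steps, 0, -1):                       # tau_i = 1, ..., dtau
        tau_i = i * dtau
        x = x + dtau * drift(x, y, tau_i) + np.sqrt(dtau) * rng.standard_normal(n)
    return x


for y in (-1.0, 2.5):                                   # same sampler, different observations
    xs = sample_bridge(y)
    print(y, xs.mean(), xs.std())
```

With this stand-in the last step lands at \(y+\sqrt{\Delta\tau}\,\xi_1\), so the leftover spread is pure discretization noise; a trained \(B^L_i\) would instead leave the spread of \(p(x_0\mid y)\).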
6 Computational Trade-offs Of Those Last Two
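Aspect | Twisted SMC (TDS) | Amortized CSB |
---|---|---|
Training | None beyond the unconditional denoiser \(\hat x_0\) (equivalently \(s_\theta\)) | \(2\times L\) rounds of drift-net fitting (\(B^n_i\), \(F^n_i\)) on \((x,y)\) pairs |
Inference cost | \(K\) particles × \(T\) steps, each needing a gradient through the denoiser | Single trajectory over \(T\) steps |
Exactness | Asymptotically exact as \(K\to\infty\) | Exact only if IPF converges (\(L\to\infty\)) and the drifts are perfectly trained |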
7 Consistency models
Song et al. (2023)
8 Inpainting
If we want coherence with part of an existing image, we call that inpainting and there are specialized methods for it (Ajay et al. 2023; Grechka, Couairon, and Cord 2024; Liu, Niepert, and Broeck 2023; Lugmayr et al. 2022; Sharrock et al. 2022; Wu et al. 2024; Zhang et al. 2023).
9 Reconstruction/inversion
Perturbed and partial observations; misc methods therefor (Choi et al. 2021; Kawar et al. 2022; Nair, Mei, and Patel 2023; Peng et al. 2024; Xie and Li 2022; Zhao et al. 2023; Song, Shen, et al. 2022; Zamir et al. 2021; Chung et al. 2023; Sui et al. 2024).