The denoising diffusion SDE
Stochastic diffusions that are reversible in a computationally useful sense
2023-12-09 — 2024-12-10
Sample from tricky distributions via clever use of stochastic differential equations. Diffusion models use time-evolving stochastic differential equations to transform between distributions we have and distributions we want. Made famous by neural diffusion models, and score-based diffusion more generally.
I am publishing this because I don’t have time to finish it up, so it’s now or never, but be warned that there are serious errors in this presentation. Maybe your best use for this document is to be all like “Hey, this bit is wrong,” and tell me so.
Conceptually, this requires two pieces of infrastructure:
- Forward Process (“Diffusion”): Gradually adds noise to the data, transforming it into a known reference distribution.
- Reverse Process (“Denoising”): Removes the noise, transforming the reference distribution back into the target distribution (in this case, the posterior).
The fact we can set up the forward process is not surprising. The weird thing is that the reverse process works too, which is a neat result from the theory of SDEs. The unintuitive nature of that, which apparently Stratonovich himself did not notice (Anderson 1982), is what this post exists to explore.
You probably do not wish to read this; it was written to help me understand the denoising diffusion SDE, and I run down some blind alleys.
1 Setup
The forward SDE characterises that forward process, describing how to add noise to the data until it matches the reference distribution. We start our quest to understand this guy with an \(\mathbb{R}^d\)-valued stochastic differential equation (SDE), with scalar diffusion coefficient, of the form
\[ dZ_{\tau} = b(\tau, Z_{\tau}) \, d\tau + \sigma(\tau) \, dW_{\tau} \tag{1}\]
- \(Z_{\tau}\): State at time \(\tau \in [0, 1]\).
- \(b(\tau, z)\): Drift coefficient.
- \(\sigma(\tau)\): Diffusion coefficient.
- \(dW_{\tau}\): Increment of a standard Wiener process (Brownian motion).
To reverse the diffusion process and construct the backward (denoising) process, we use the backward SDE. To make this guy go we have to introduce some extra notation. Suppose the density of \(Z_{\tau}\) is \(q_{Z_{\tau}}(z)\); in machine learning notation, \(Z_{\tau} \sim q_{Z_{\tau}}\). Then we can write
\[ \begin{aligned} dZ_{\tau} &= -\left[ b(\tau, Z_{\tau}) - \sigma^2(\tau) \nabla_z \log q_{Z_{\tau}}(Z_{\tau}) \right] \, d\tau + \sigma(\tau) \, d\overleftarrow{W}_{\tau} \\ &= -\tilde{b}(\tau, Z_{\tau}) \, d\tau + \sigma(\tau) \, d\overleftarrow{W}_{\tau} \end{aligned} \tag{2}\]
- \(\nabla_z \log q_{Z_{\tau}}(Z_{\tau})\): Score function of the distribution at time \(\tau\).
- \(\tilde{b}(\tau, z) = b(\tau, z) - \sigma^2(\tau) \nabla_z \log q_{Z_{\tau}}(z)\): Adjusted drift term.
- \(d\overleftarrow{W}_{\tau}\): Increment of the “backward Wiener process”. Let’s not think too hard about that and just say “filtration” and “\(\sigma\)-algebra” until we feel like we are being mathematically sophisticated.
What in blazes happened there? Where did this \(q\) come from? What even is this backward Wiener process? Why are these terms like they are? etc. There is a claim about Fokker-Planck equations that is traditional, but it is not very intuitive, and introduces some heavy machinery which doesn’t really help, at least for me. So I moved that to the end.
We run into ALL KINDS OF CRAZY when it comes to working out which direction time is running here. I’ve put a negative sign in front of the drift term in the backward SDE to make it look like the forward SDE, but this is a slightly weird way of doing things, because it implies that time is running forward, rather than backward, and we are looking at a backward SDE. I think I am being consistent about this, but feel free to call me out if I am not. Alternatively, you might want to read this upside-down.
Let us understand this in a more intuitive way, i.e., by being dumb about it and seeing what does not work.
2 Let us discretize
OK, first simplification: we consider the Euler-Maruyama discretization of the forward SDE, Equation 1, and take that as ground truth for now. Here we ignore continuous time and instead think about small time steps \(\Delta \tau\) over which the behaviour is approximately linear.
With this discretization, let us think about the distribution of \(Z_{\tau{+}\Delta\tau}\) given \(Z_{\tau}\). For a small step \(\Delta\tau\),
\[ Z_{\tau+\Delta \tau} = Z_{\tau} + b(\tau, Z_{\tau}) \, \Delta\tau + \sigma(\tau) \Delta W_{\tau} \tag{3}\]
The forward transition kernel \(p_{\text{forward}}(z_{\tau{+}\Delta\tau} | z_{\tau})\) is the probability density of \(Z_{\tau{+}\Delta\tau}\) given \(Z_{\tau}=z_{\tau}\). We have
\[ Z_{\tau{+}\Delta\tau} | Z_{\tau}=z_{\tau} \sim \mathcal{N}\bigl(z_{\tau} + b(\tau,z_{\tau})\Delta\tau,\; \sigma^2(\tau)\Delta\tau I_d\bigr). \]
Hence, the forward transition kernel is \[ p_{\text{forward}}(z_{\tau{+}\Delta\tau}|z_{\tau}) = \frac{\exp\left(-\frac{\| z_{\tau{+}\Delta\tau}-z_{\tau}-b(\tau,z_{\tau})\Delta\tau \|^2}{2\sigma^2(\tau)\Delta\tau}\right)}{(2\pi \sigma^2(\tau)\Delta\tau)^{d/2}}. \tag{4}\]
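To have something to poke at later, here is Equation 4 as code; a minimal sketch, assuming \(b\) and \(\sigma\) are passed in as callables (that interface is mine, not anything fixed by the post).
Code
import jax.numpy as jnp

def log_forward_kernel(z_next, z, tau, dtau, b, sigma):
    # log N(z_next | z + b(tau, z) * dtau, sigma(tau)^2 * dtau * I_d), i.e. Equation 4
    d = z.shape[-1]
    mean = z + b(tau, z) * dtau
    var = sigma(tau) ** 2 * dtau
    resid = z_next - mean
    return -0.5 * jnp.sum(resid ** 2) / var - 0.5 * d * jnp.log(2 * jnp.pi * var)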
Discretization of the backward SDE Equation 2 proceeds similarly,
\[ Z_{\tau} = Z_{\tau{+}\Delta\tau} - \tilde{b}(\tau{+}\Delta\tau, Z_{\tau{+}\Delta\tau})\,\Delta\tau + \sigma(\tau{+}\Delta\tau)\Delta\overleftarrow{W}_{\tau}. \tag{5}\]
Since \(\Delta\overleftarrow{W}_{\tau} \sim \mathcal{N}(0,\Delta\tau I_d)\), the conditional distribution \(Z_{\tau}|Z_{\tau{+}\Delta\tau}=z_{\tau{+}\Delta\tau}\) is also Gaussian.
Hence the backward transition kernel is \[ \begin{aligned} p_{\text{backward}}(z_{\tau}|z_{\tau{+}\Delta\tau}) &= \frac{\exp\left(-\frac{\| z_{\tau}-z_{\tau{+}\Delta\tau}+\tilde{b}(\tau{+}\Delta\tau,z_{\tau{+}\Delta\tau})\Delta\tau \|^2}{2\sigma^2(\tau{+}\Delta\tau)\Delta\tau}\right)}{(2\pi \sigma^2(\tau{+}\Delta\tau)\Delta\tau)^{d/2}}\\ &= \frac{\exp\left(-\frac{\| z_{\tau}-z_{\tau{+}\Delta\tau}+\left[ b(\tau{+}\Delta\tau, z_{\tau{+}\Delta\tau}) - \sigma^2(\tau{+}\Delta\tau) \nabla_z \log q_{Z_{\tau{+}\Delta\tau}}(z_{\tau{+}\Delta\tau}) \right]\Delta\tau \|^2}{2\sigma^2(\tau{+}\Delta\tau)\Delta\tau}\right)}{(2\pi \sigma^2(\tau{+}\Delta\tau)\Delta\tau)^{d/2}}. \end{aligned} \tag{6}\]
Writing it this way is a chaotic move, but it gives us the tools to play with these distributions and see what happens. Also it puts lots of Greek letters on the page, which makes us look clever.
3 Pathwise Euler-Maruyama update
Now, let us consider a single path of the forward and backward SDEs, and see what happens when we apply a forward step and then a backward step to it.
We consider a variate \(Z_{\tau}'\), which arises from applying the forward then backward steps to a variate \(Z_{\tau}\), i.e. \[ \begin{aligned} Z_{\tau}' &= Z_{\tau{+}\Delta\tau} - \tilde{b}(\tau{+}\Delta\tau, Z_{\tau{+}\Delta\tau})\Delta\tau + \sigma(\tau{+}\Delta\tau)\Delta\overleftarrow{W}_{\tau}\\ &= Z_{\tau} + b(\tau, Z_{\tau}) \, \Delta\tau + \sigma(\tau) \Delta W_{\tau} - \tilde{b}(\tau{+}\Delta\tau, Z_{\tau{+}\Delta\tau})\Delta\tau + \sigma(\tau{+}\Delta\tau)\Delta\overleftarrow{W}_{\tau}\\ &= Z_{\tau} + b(\tau, Z_{\tau}) \, \Delta\tau + \sigma(\tau) \Delta W_{\tau} - \left[ b(\tau{+}\Delta\tau, Z_{\tau{+}\Delta\tau}) - \sigma^2(\tau{+}\Delta\tau) \nabla_z \log q_{Z_{\tau{+}\Delta\tau}}(Z_{\tau{+}\Delta\tau}) \right]\Delta\tau + \sigma(\tau{+}\Delta\tau)\Delta\overleftarrow{W}_{\tau}. \end{aligned} \]
Now, suppose that reversibility of the Euler discretization means precisely that \(\Delta\overleftarrow{W}_{\tau} = -\Delta W_{\tau}\). Under this interpretation,
\[ \begin{aligned} Z_{\tau}' &= Z_{\tau} + b(\tau, Z_{\tau}) \, \Delta\tau + \sigma(\tau) \Delta W_{\tau} - \left[ b(\tau{+}\Delta\tau, Z_{\tau{+}\Delta\tau}) - \sigma^2(\tau{+}\Delta\tau) \nabla_z \log q_{Z_{\tau{+}\Delta\tau}}(Z_{\tau{+}\Delta\tau}) \right]\Delta\tau - \sigma(\tau{+}\Delta\tau)\Delta W_{\tau}\\ &= Z_{\tau} + \bigl(b(\tau, Z_{\tau}) - b(\tau{+}\Delta\tau, Z_{\tau{+}\Delta\tau})\bigr)\Delta\tau + \sigma^2(\tau{+}\Delta\tau) \nabla_z \log q_{Z_{\tau{+}\Delta\tau}}(Z_{\tau{+}\Delta\tau}) \, \Delta\tau + \bigl(\sigma(\tau) - \sigma(\tau{+}\Delta\tau)\bigr)\Delta W_{\tau}. \end{aligned} \]
So that doesn’t look like it will return a given point to its original state, in general.
Even if we make the coefficients time-invariant, the drift terms need not cancel, and the score term certainly does not vanish; only the noise cancels: \[ \begin{aligned} Z_{\tau}' &= Z_{\tau} + b(Z_{\tau}) \, \Delta\tau + \sigma \Delta W_{\tau} - \left[ b(Z_{\tau{+}\Delta\tau}) - \sigma^2 \nabla_z \log q_{Z_{\tau{+}\Delta\tau}}(Z_{\tau{+}\Delta\tau}) \right]\Delta\tau - \sigma\Delta W_{\tau}\\ &= Z_{\tau} + \bigl(b(Z_{\tau}) - b(Z_{\tau{+}\Delta\tau})\bigr)\Delta\tau + \sigma^2 \nabla_z \log q_{Z_{\tau{+}\Delta\tau}}(Z_{\tau{+}\Delta\tau}) \, \Delta\tau. \end{aligned} \]
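Here is a tiny numerical illustration of that (a toy of my own devising, not from any reference): pure Brownian motion, \(b = 0\), \(\sigma = 1\), started from \(Z_{\tau} \sim \mathcal{N}(0, v)\), so the marginal score is available in closed form. We take the reversed-noise interpretation \(\Delta\overleftarrow{W}_{\tau} = -\Delta W_{\tau}\) and check whether a forward step followed by a backward step returns to the starting point.
Code
import jax.numpy as jnp
from jax import random

v, dtau = 1.0, 0.01                   # marginal variance of Z_tau; step size
k1, k2 = random.split(random.PRNGKey(0))
z0 = random.normal(k1) * jnp.sqrt(v)  # Z_tau ~ N(0, v)
dW = random.normal(k2) * jnp.sqrt(dtau)

z1 = z0 + dW                          # forward Euler step with b = 0, sigma = 1
score = -z1 / (v + dtau)              # exact score of q_{Z_{tau+dtau}} = N(0, v + dtau)
b_tilde = 0.0 - 1.0 * score           # adjusted drift: b - sigma^2 * score
z0_prime = z1 - b_tilde * dtau - dW   # backward step, with reversed noise -dW

print(z0_prime - z0)  # leftover O(dtau) score term; not zero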
It seems we are looking at this the wrong way.
4 Markov kernels
OTOH, we could wonder whether the measure \(q_{Z_\tau}(z)\) is preserved by applying the forward step and then the backward step. Let’s see how to approach this.
The resulting probability density arises from marginalising over the intermediate state,
\[ p(Z_{\tau}' = z \mid Z_{\tau} = z_{\tau}) = \int p_{\text{backward}}(z \mid z_{\tau{+}\Delta\tau}) \, p_{\text{forward}}(z_{\tau{+}\Delta\tau} \mid z_{\tau}) \, dz_{\tau{+}\Delta\tau}. \]
This integral represents the composition of the forward transition from \(z_{\tau}\) to \(z_{\tau{+}\Delta\tau}\) followed by the backward transition from \(z_{\tau{+}\Delta\tau}\) to \(Z_{\tau}'\). Our goal is to evaluate this integral to understand how applying the forward and backward steps affects the original state \(z_{\tau}\).
We know from the theory of time-reversed Markov processes that the backward kernel should be the Bayes reversal of the forward kernel, i.e. the forward kernel weighted by the ratio of the marginals:
\[ p_{\text{backward}}(z_{\tau}|z_{\tau{+}\Delta\tau}) = p_{\text{forward}}(z_{\tau{+}\Delta\tau}|z_{\tau}) \, \frac{q_{Z_{\tau}}(z_{\tau})}{q_{Z_{\tau{+}\Delta\tau}}(z_{\tau{+}\Delta\tau})}. \]
This relationship arises from Bayes’ rule applied to transition densities:
\[ p_{\text{backward}}(z_{\tau}|z_{\tau{+}\Delta\tau}) = \frac{p_{\text{forward}}(z_{\tau{+}\Delta\tau}|z_{\tau})q_{Z_{\tau}}(z_{\tau})}{q_{Z_{\tau{+}\Delta\tau}}(z_{\tau{+}\Delta\tau})}. \]
To ensure the backward step is the inverse of the forward step (in distribution, for small \(\Delta\tau\)), the drift \(\tilde{b}\) must be chosen so that the Euler-Maruyama backward kernel, Equation 6, agrees with this Bayes-rule reversal to first order in \(\Delta\tau\). In fact, identifying:
\[ \tilde{b}(\tau,z) = b(\tau,z) - \sigma^2(\tau)\nabla_z \log q_{Z_{\tau}}(z) \]
is exactly the condition that makes this inverse relationship work. This ensures:
\[ \int p_{\text{backward}}(z|z_{\tau{+}\Delta\tau})\, p_{\text{forward}}(z_{\tau{+}\Delta\tau}|z_{\tau})\, q_{Z_{\tau}}(z_{\tau}) \, dz_{\tau{+}\Delta\tau}\, dz_{\tau} = q_{Z_{\tau}}(z) \quad (\text{up to first order in }\Delta\tau), \]
meaning that a forward step followed by a backward step returns us to the distribution we started from, consistent with the idea of time reversal.
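We can check that numerically too, in the same toy setting as before (\(b = 0\), \(\sigma = 1\), exactly Gaussian marginals; all assumptions of mine, chosen for convenience): push samples of \(q_{Z_\tau} = \mathcal{N}(0,1)\) through a forward step and then a backward step with independent noise, and compare variances.
Code
import jax.numpy as jnp
from jax import random

dtau, n = 0.01, 1_000_000
k0, k1, k2 = random.split(random.PRNGKey(0), 3)

z = random.normal(k0, (n,))                            # q_{Z_tau} = N(0, 1)
z_fwd = z + jnp.sqrt(dtau) * random.normal(k1, (n,))   # forward step
score = -z_fwd / (1.0 + dtau)                          # exact score of N(0, 1 + dtau)
z_back = z_fwd - (0.0 - score) * dtau + jnp.sqrt(dtau) * random.normal(k2, (n,))

print(jnp.var(z))       # ~1.0
print(jnp.var(z_back))  # ~1.0 (analytically 1 + O(dtau^2)); the marginal survives even though paths do not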
We consider a small time interval \(\Delta \tau\) and ask in what sense the forward and backward steps invert each other. I.e., take the Euler-Maruyama discretization as given, and ignore questions about limits and so on. The forward and backward steps are each an affine update of the state plus some Gaussian noise.
Forward Step:
Starting from \(Z_{\tau}\), the forward SDE updates \(Z_{\tau + \Delta \tau}\):
\[ Z_{\tau + \Delta \tau} = Z_{\tau} + b(\tau, Z_{\tau}) \Delta \tau + \sigma(\tau) \Delta W_{\tau} \]
Backward Step:
Starting from \(Z_{\tau + \Delta \tau}\), the backward SDE updates \(Z_{\tau}\):
\[ Z_{\tau} = Z_{\tau{+}\Delta\tau} - \tilde{b}(\tau{+}\Delta\tau, Z_{\tau{+}\Delta\tau}) \Delta \tau + \sigma(\tau{+}\Delta\tau) \Delta \overleftarrow{W}_{\tau} \]
If you are like me, you did not find that especially helpful for intuition-building, which is probably why they do not bother with it in typical blog posts about score diffusions. Let us try to make it concrete by actually working through the steps of the reverse SDE.
Forward step plus backward step:
We can substitute the forward step into the backward step to get a concrete expression for the composition of these steps. Since everything is Gaussian, we can track the evolution in terms of the mean and variance of the distribution.
Let us fix \(Z_{\tau}=z_{\tau}\) and consider the evolution of the conditional distribution, i.e.
\[ \begin{aligned} (Z_{\tau + \Delta \tau}\mid Z_{\tau}=z_{\tau} ) &= z_{\tau} + b(\tau, z_{\tau}) \Delta \tau + \sigma(\tau) \Delta W_{\tau} \end{aligned} \]
- If \(\Delta W_{\tau}\) and \(\Delta \overleftarrow{W}_{\tau}\) are such that \(\Delta \overleftarrow{W}_{\tau} = -\Delta W_{\tau}\), then the stochastic terms cancel out.
- The score function \(\nabla_z \log q_{Z_{\tau}}(Z_{\tau})\) adjusts the drift to compensate for the probability flow induced by the diffusion.
\[ Z_{\tau} = Z_{\tau + \Delta \tau} + \left[ -b(\tau{+}\Delta\tau, Z_{\tau + \Delta \tau}) + \sigma^2(\tau+ \Delta \tau) \nabla_z \log q_{Z_{\tau + \Delta \tau}}(Z_{\tau + \Delta \tau}) \right] \Delta \tau + \sigma(\tau{+}\Delta\tau) \Delta \overleftarrow{W}_{\tau+ \Delta \tau} \]
5 Implementation
Below is a numerical experiment that simulates both the forward and backward SDEs using JAX, with a non-trivial observation model: \(y = g(x) + \varepsilon\) where \(g(x) = \sin(x)\), which makes the score function more challenging.
5.1 Import Libraries
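(The original imports cell has gone missing; these are the ones the cells below assume.)
Code
import jax.numpy as jnp
from jax import random
import numpy as np
import matplotlib.pyplot as plt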
6 SDE Coefficients
We set \(\alpha_{\tau} = 1 - \tau\) and \(\beta_{\tau}^2 = \tau\), leading to the coefficient functions below.
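(This cell also went missing; here is a stand-in of mine, consistent with how b_tau and sigma_tau are used below, i.e. drift \(b(\tau)z\) and scalar diffusion \(\sigma(\tau)\). Matching the marginals \(Z_\tau \sim \mathcal{N}(\alpha_\tau Z_0, \beta_\tau^2)\) gives \(b = \alpha_\tau'/\alpha_\tau\) and \(\sigma^2 = (\beta_\tau^2)' - 2b\beta_\tau^2\); the clamp near \(\tau = 1\) is my own hack to avoid the blow-up there.)
Code
EPS = 1e-3  # avoid division by zero as tau -> 1

def b_tau(tau):
    # alpha'/alpha with alpha_tau = 1 - tau
    return -1.0 / jnp.maximum(1.0 - tau, EPS)

def sigma_tau(tau):
    # sqrt((beta^2)' - 2 b beta^2) with beta_tau^2 = tau
    return jnp.sqrt((1.0 + tau) / jnp.maximum(1.0 - tau, EPS))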
7 Observation Model and Score Function
Let \(g(x) = \sin(x)\) and assume Gaussian noise \(\varepsilon \sim \mathcal{N}(0, \sigma^2)\). The likelihood is:
\[ p(Y=y|X=x) = \mathcal{N}(y | \sin(x), \sigma^2) \]
Thus, the gradient of the log-likelihood is:
\[ \nabla_x \log p(Y=y|X=x) = \frac{y - \sin(x)}{\sigma^2} \cos(x) \]
Define the score function incorporating the posterior:
Code
def posterior_score(z, tau, y, sigma2):
    # Score of the prior (assuming q(Z_tau) = N(0, tau) for simplicity)
    score_prior = -z / tau
    # Gradient of the log-likelihood
    grad_log_likelihood = (y - jnp.sin(z)) / sigma2 * jnp.cos(z)
    # Damping function h(tau) = 1 - tau
    h_tau = 1.0 - tau
    # Posterior score
    return score_prior + h_tau * grad_log_likelihood
7.1 Forward SDE Simulation
Simulate the forward diffusion process, adding noise to the data:
Code
def simulate_forward_sde(Z0, key, N=1000):
    dtau = 1.0 / N
    tau_values = jnp.linspace(0.0, 1.0, N + 1)
    Z_tau = Z0
    Z_tau_list = [Z_tau]
    for i in range(N):
        tau = tau_values[i]
        b = b_tau(tau)
        sigma = sigma_tau(tau)
        key, subkey = random.split(key)
        dW = random.normal(subkey, shape=Z0.shape) * jnp.sqrt(dtau)
        Z_tau = Z_tau + b * Z_tau * dtau + sigma * dW
        Z_tau_list.append(Z_tau)
    return jnp.array(Z_tau_list)
7.2 Backward SDE Simulation
Simulate the backward denoising process, incorporating the posterior score:
Code
def simulate_backward_sde(Z1, y, sigma2, key, N=1000):
    dtau = -1.0 / N  # negative time step for backward integration
    tau_values = jnp.linspace(1.0, 0.0, N + 1)
    Z_tau = Z1
    Z_tau_list = [Z_tau]
    for i in range(N):
        tau = tau_values[i]
        b = b_tau(tau)
        sigma = sigma_tau(tau)
        S_posterior = posterior_score(Z_tau, tau, y, sigma2)
        drift = b * Z_tau - sigma**2 * S_posterior
        key, subkey = random.split(key)
        dW = random.normal(subkey, shape=Z1.shape) * jnp.sqrt(-dtau)
        Z_tau = Z_tau + drift * dtau + sigma * dW
        Z_tau_list.append(Z_tau)
    return jnp.array(Z_tau_list)
7.3 Example Usage
Simulate both processes and visualise the results.
Code
# Initialize random keys
key = random.PRNGKey(0)
key, subkey = random.split(key)
# Number of samples and time steps
num_samples = 1000
N_timesteps = 1000
# Observation parameters
y = 0.5 # Observed value
sigma2 = 0.1 # Variance of observation noise
# Initial data distribution: X ~ N(5, 1)
Z0 = random.normal(subkey, shape=(num_samples,)) * 1.0 + 5.0
# Simulate forward SDE
key, subkey = random.split(key)
Z_forward = simulate_forward_sde(Z0, subkey, N=N_timesteps)
# Simulate backward SDE starting from Z1
Z1 = Z_forward[-1]
key, subkey = random.split(key)
Z_backward = simulate_backward_sde(Z1, y, sigma2, subkey, N=N_timesteps)
# Convert to numpy for plotting
Z0_np = np.array(Z0)
Z1_np = np.array(Z1)
Z_backward_np = np.array(Z_backward[-1])
# Plotting the distributions
plt.figure(figsize=(18, 5))
# Initial distribution
plt.subplot(1, 3, 1)
plt.hist(Z0_np, bins=100, density=True, alpha=0.6, color='blue', label='Initial Z0')
plt.title('Initial Distribution at τ=0')
plt.xlabel('Z0')
plt.ylabel('Density')
plt.legend()
# After Forward SDE
plt.subplot(1, 3, 2)
plt.hist(Z1_np, bins=100, density=True, alpha=0.6, color='green', label='After Forward SDE')
plt.title('Distribution at τ=1')
plt.xlabel('Z1')
plt.ylabel('Density')
plt.legend()
# # After Backward SDE
# plt.subplot(1, 3, 3)
# plt.hist(Z_backward_np, bins=100, density=True, alpha=0.6, color='red', label='Recovered Z0')
# plt.title('Recovered Distribution at τ=0')
# plt.xlabel('Z0 Recovered')
# plt.ylabel('Density')
# plt.legend()
plt.tight_layout()
plt.show()
8 Fokker-Planck derivation
Now it might not be obvious what \(b\) and \(\sigma\) would produce the desired result in the forward case, let alone how the reverse bit works. It was not obvious to me. There is a traditional explanation that involves saying Fokker-Planck a lot. Let us give that.
We consider the time-reversed version of Equation 1. We aim to find an SDE that, when integrated backward in time, describes the evolution of \(Z_{\tau}\) such that it recovers the original distribution at \(\tau = 0\).
To derive the backward SDE, we use the theory of time reversal for diffusion processes via the Fokker-Planck equation.
\[ \frac{\partial q_{Z_{\tau}}(z)}{\partial \tau} = -\nabla_z \cdot [b(\tau, z) q_{Z_{\tau}}(z)] + \frac{1}{2} \sigma^2(\tau) \Delta_z q_{Z_{\tau}}(z) \tag{7}\]
- \(\nabla_z \cdot\) denotes the divergence with respect to \(z\).
- \(\Delta_z\) is the Laplacian with respect to \(z\).
Let us examine all those symbols to make sure we are doing what we expect. For a vector field \(\mathbf{F}(z)\) defined on \(\mathbb{R}^d\), the divergence is a scalar function that measures the “outflow” or “spreading” of the vector field at a point:
\[ \nabla_z \cdot \mathbf{F}(z) = \frac{\partial F_1(z)}{\partial z_1} + \frac{\partial F_2(z)}{\partial z_2} + \dots + \frac{\partial F_d(z)}{\partial z_d} \]
where
- \(\mathbf{F}(z) = [F_1(z), F_2(z), \dots, F_d(z)]^\top\) is the vector field, and
- \(z = [z_1, z_2, \dots, z_d]^\top\) are the coordinates in \(\mathbb{R}^d\).
In the Fokker-Planck equation, the divergence \(\nabla_z \cdot (\mathbf{J}(z, \tau))\) measures the net flux \(\mathbf{J}\) of probability density leaving a region. \(\mathbf{J}(z, \tau)\), the probability current, is typically expressed as
\[ \mathbf{J}(z, \tau) = \mu(z, \tau) q(z, \tau) \tag{8}\]
Where \(\mu(z, \tau)\) is the drift term from the SDE.
The intuitions here are:
- If \(\nabla_z \cdot \mathbf{J}(z, \tau) > 0\): there is net outflow of probability at \(z\); the drift carries mass away, and the density there falls.
- If \(\nabla_z \cdot \mathbf{J}(z, \tau) < 0\): there is net inflow; probability density is concentrating at \(z\).
The Laplacian \(\Delta_z\) of a scalar function \(f(z)\) in \(\mathbb{R}^d\) is the divergence of the gradient of \(f(z)\):
\[ \Delta_z f(z) = \nabla_z \cdot (\nabla_z f(z)) \]
In Cartesian coordinates, it is expressed as
\[ \Delta_z f(z) = \frac{\partial^2 f(z)}{\partial z_1^2} + \frac{\partial^2 f(z)}{\partial z_2^2} + \dots + \frac{\partial^2 f(z)}{\partial z_d^2} \]
The term \(\frac{1}{2} \sigma^2 \Delta_z q(z, \tau)\) in the Fokker-Planck equation represents the diffusion or “spreading” of probability density due to the stochastic component of the SDE. \(\Delta_z q(z, \tau)\) measures the curvature of \(q(z, \tau)\).
The intuition is
- If \(\Delta_z q(z, \tau) > 0\): the density at \(z\) is locally convex (a trough, below its neighbourhood average), and diffusion increases the density there.
- If \(\Delta_z q(z, \tau) < 0\): the density at \(z\) is locally concave (a bump, such as at a mode), and diffusion decreases the density there.
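If we want to poke at these operators concretely, autodiff hands them to us directly. Here is an illustrative sketch; none of these names or examples are from the post, just my own playthings.
Code
import jax
import jax.numpy as jnp

def divergence(F):
    # div F(z) = trace of the Jacobian of the vector field F
    return lambda z: jnp.trace(jax.jacfwd(F)(z))

def laplacian(f):
    # Laplacian of scalar f = trace of its Hessian
    return lambda z: jnp.trace(jax.hessian(f)(z))

# Standard Gaussian density on R^2
q = lambda z: jnp.exp(-0.5 * z @ z) / (2 * jnp.pi)
# Probability current for an OU-style drift mu(z) = -z
J = lambda z: -z * q(z)

z0 = jnp.zeros(2)
print(laplacian(q)(z0))   # negative at the mode: diffusion flattens the bump
print(divergence(J)(z0))  # negative at the mode: the drift concentrates mass there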
If you believe the various theorems about Fokker-Planck equations, this might be useful. But if you would prefer some actual intuition-building, you might want something more.
For the time-reversed process, we define \(\tilde{Z}_{\tau} = Z_{T - \tau}\), where \(T\) is the total time (in our case, \(T = 1\)).
The Fokker-Planck machinery is what gives us Equation 2, in terms of probability current \(J(\tau, z)\), defined as
\[ J(\tau, z) = b(\tau, z) q_{Z_{\tau}}(z) - \frac{1}{2} \sigma^2(\tau) \nabla_z q_{Z_{\tau}}(z) \]
The drift of the time-reversed process is closely related to the forward current normalised by the density; grinding through the algebra (see the derivations below) gives
\[ \tilde{b}(\tau, z) = b(\tau, z) - \sigma^2(\tau) \nabla_z \log q_{Z_{\tau}}(z) \]
Intuitively, the term \(-\sigma^2(\tau) \nabla_z \log q_{Z_{\tau}}(Z_{\tau})\) adjusts the drift to account for the probability flow induced by diffusion.
Putting these together we get the expected backward SDE from Equation 2.
\[ dZ_{\tau} = -\tilde{b}(\tau, Z_{\tau}) d\tau + \sigma(\tau) d\overleftarrow{W}_{\tau} \]
Below is a more explicit and careful derivation that shows where the logarithm of the density comes from. The key idea is that, to identify the drift of the reversed process, we compare the partial differential equations (PDEs) governing the densities. When we try to match these PDEs, we naturally end up dividing by the density \(p\), which introduces the \(\log p\) term because \(\nabla_x \log p = \frac{\nabla_x p}{p}\).
9 Fokker-Planck derivation 2
9.1 Forward SDE and Density
Suppose the forward SDE is: \[ dZ_{\tau} = b(\tau, Z_{\tau})\,d\tau + \sigma(\tau)\,dW_{\tau}, \quad \tau \in [0,T]. \] Let \(p(t,x)\) be the probability density of \(Z_t\). Then \(p(t,x)\) satisfies the forward Kolmogorov (Fokker-Planck) equation: \[ \partial_t p(t,x) = -\nabla_x \cdot [b(t,x)p(t,x)] + \tfrac{1}{2}\nabla_x \cdot [\sigma(t)\sigma(t)^T \nabla_x p(t,x)]. \]
This equation describes how the density of the process evolves forward in time.
9.2 The Reversed-Time Process and Its Density
Now define the reversed-time process \(\overleftarrow{Z}_\tau := Z_{T-\tau}\) for \(\tau \in [0,T].\) We would like to find an SDE of the form \[ d\overleftarrow{Z}_{\tau} = \overleftarrow{b}(\tau,\overleftarrow{Z}_{\tau})\,d\tau + \sigma(T-\tau)\, d\overleftarrow{W}_{\tau} \] for some drift \(\overleftarrow{b}(\tau,x)\).
The law of \(\overleftarrow{Z}_\tau\) at time \(\tau\) is the same as the law of \(Z_{T-\tau}\). Hence its density at time \(\tau\) should be \(p(T-\tau,x)\).
To summarize:
- Forward process \(Z_t\) has density \(p(t,x)\).
- Reversed process \(\overleftarrow{Z}_\tau := Z_{T-\tau}\) has density \(\overleftarrow{p}(\tau,x) = p(T-\tau,x)\).
9.3 Forward Equation for the Reversed Density
If \(\overleftarrow{Z}_\tau\) satisfies \[ d\overleftarrow{Z}_{\tau} = \overleftarrow{b}(\tau,x)\,d\tau + \sigma(T-\tau)\,d\overleftarrow{W}_{\tau}, \] then \(\overleftarrow{p}(\tau,x)\) must satisfy its own forward Kolmogorov equation in \(\tau\): \[ \partial_{\tau}\overleftarrow{p}(\tau,x) = -\nabla_x \cdot[\overleftarrow{b}(\tau,x)\overleftarrow{p}(\tau,x)] + \tfrac{1}{2}\nabla_x \cdot[\sigma(T-\tau)\sigma(T-\tau)^T \nabla_x \overleftarrow{p}(\tau,x)]. \]
Since \(\overleftarrow{p}(\tau,x) = p(T-\tau,x)\), we can rewrite this in terms of \(p\): \[ \partial_{\tau} p(T-\tau,x) = -\nabla_x \cdot[\overleftarrow{b}(\tau,x)p(T-\tau,x)] + \tfrac{1}{2}\nabla_x \cdot[\sigma(T-\tau)\sigma(T-\tau)^T \nabla_x p(T-\tau,x)]. \]
9.4 Back to the Original Forward Equation
From the original PDE for \(p(t,x)\), we have: \[ \partial_t p(t,x) = -\nabla_x \cdot[b(t,x)p(t,x)] + \tfrac{1}{2}\nabla_x \cdot[\sigma(t)\sigma(t)^T \nabla_x p(t,x)]. \]
Replace \(t\) by \(T-\tau\): \[ \partial_{T-\tau} p(T-\tau,x) = -\nabla_x \cdot[b(T-\tau,x)p(T-\tau,x)] + \tfrac{1}{2}\nabla_x \cdot[\sigma(T-\tau)\sigma(T-\tau)^T \nabla_x p(T-\tau,x)]. \]
Taking into account that \(\partial_{T-\tau} = -\partial_{\tau}\), we get: \[ -\partial_{\tau} p(T-\tau,x) = -\nabla_x \cdot[b(T-\tau,x)p(T-\tau,x)] + \tfrac{1}{2}\nabla_x \cdot[\sigma(T-\tau)\sigma(T-\tau)^T \nabla_x p(T-\tau,x)]. \]
Multiply through by \(-1\): \[ \partial_{\tau} p(T-\tau,x) = \nabla_x \cdot[b(T-\tau,x)p(T-\tau,x)] - \tfrac{1}{2}\nabla_x \cdot[\sigma(T-\tau)\sigma(T-\tau)^T \nabla_x p(T-\tau,x)]. \]
9.5 Solving the Two PDEs
We have two expressions for \(\partial_{\tau} p(T-\tau,x)\):
From the reversed SDE perspective: \[ \partial_{\tau} p(T-\tau,x) = -\nabla_x \cdot[\overleftarrow{b}(\tau,x)p(T-\tau,x)] + \tfrac{1}{2}\nabla_x \cdot[\sigma(T-\tau)\sigma(T-\tau)^T \nabla_x p(T-\tau,x)]. \]
From the original PDE transformed in time: \[ \partial_{\tau} p(T-\tau,x) = \nabla_x \cdot[b(T-\tau,x)p(T-\tau,x)] - \tfrac{1}{2}\nabla_x \cdot[\sigma(T-\tau)\sigma(T-\tau)^T \nabla_x p(T-\tau,x)]. \]
Set these equal to each other: \[ -\nabla_x \cdot[\overleftarrow{b}(\tau,x)p(T-\tau,x)] + \tfrac{1}{2}\nabla_x \cdot[\sigma(T-\tau)\sigma(T-\tau)^T \nabla_x p(T-\tau,x)] = \nabla_x \cdot[b(T-\tau,x)p(T-\tau,x)] - \tfrac{1}{2}\nabla_x \cdot[\sigma(T-\tau)\sigma(T-\tau)^T \nabla_x p(T-\tau,x)]. \]
Move all terms involving \(\nabla_x p(T-\tau,x)\) together: \[ -\nabla_x \cdot[\overleftarrow{b}(\tau,x)p(T-\tau,x)] = \nabla_x \cdot[b(T-\tau,x)p(T-\tau,x)] - \nabla_x \cdot[\sigma(T-\tau)\sigma(T-\tau)^T \nabla_x p(T-\tau,x)]. \]
9.6 Factor and Divide by the Density
The key step: to isolate \(\overleftarrow{b}\), we divide through by \(p(T-\tau,x)\). Before doing that, we rewrite the equation explicitly. Notice that: \[ \nabla_x \cdot[\overleftarrow{b}(\tau,x)p(T-\tau,x)] = p(T-\tau,x)\nabla_x \cdot \overleftarrow{b}(\tau,x) + \overleftarrow{b}(\tau,x)\cdot \nabla_x p(T-\tau,x). \]
Similarly, \[ \nabla_x \cdot[b(T-\tau,x)p(T-\tau,x)] = p(T-\tau,x)\nabla_x \cdot b(T-\tau,x) + b(T-\tau,x)\cdot \nabla_x p(T-\tau,x). \]
And \[ \nabla_x \cdot[\sigma(T-\tau)\sigma(T-\tau)^T \nabla_x p(T-\tau,x)] = \sigma(T-\tau)\sigma(T-\tau)^T : \nabla_x^2 p(T-\tau,x), \] where “:” denotes the Frobenius inner product. But it’s more intuitive to keep it in divergence form.
Now divide every term by \(p(T-\tau,x)\). Ratios of the form \(\frac{\nabla_x p(T-\tau,x)}{p(T-\tau,x)}\) appear, and each of these is precisely \(\nabla_x \log p(T-\tau,x)\).
The tidiest way to finish is to collect the whole equation under a single divergence: \[ \nabla_x \cdot\Bigl[\bigl(\overleftarrow{b}(\tau,x) + b(T-\tau,x)\bigr)\,p(T-\tau,x) - \sigma(T-\tau)\sigma(T-\tau)^T \nabla_x p(T-\tau,x)\Bigr] = 0. \]
A sufficient condition for this to hold is that the bracketed vector field vanishes identically (we could add any divergence-free field, but this is the choice that yields a Markovian reverse drift). Setting it to zero, dividing through by \(p(T-\tau,x)\), and isolating \(\overleftarrow{b}\), one finds the well-known formula: \[ \overleftarrow{b}(\tau,x) = -b(T-\tau,x) + \sigma(T-\tau)\sigma(T-\tau)^T \nabla_x \log p(T-\tau,x). \]
The appearance of the score \(\log p(T-\tau,x)\) is due precisely to this division by \(p(T-\tau,x)\). When we go from equations involving \(p\) to equations involving velocities or drifts that must be expressed in terms of more fundamental gradients, we get the ratio \(\frac{\nabla_x p}{p}\), which is \(\nabla_x \log p\).
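As a sanity check on that formula, consider a case where everything is available in closed form: a stationary Ornstein-Uhlenbeck process, \(dZ_t = -\theta Z_t\,dt + \sigma\,dW_t\) with \(Z_0 \sim \mathcal{N}(0, \sigma^2/(2\theta))\), so that \(p(t,x) = \mathcal{N}(x \mid 0, \sigma^2/(2\theta))\) for all \(t\) and \(\nabla_x \log p(t,x) = -\frac{2\theta}{\sigma^2}x\). The formula gives
\[ \overleftarrow{b}(\tau,x) = -(-\theta x) + \sigma^2 \left(-\frac{2\theta}{\sigma^2}x\right) = \theta x - 2\theta x = -\theta x, \]
i.e. the reversed process has exactly the same drift as the forward one. That is as it should be: a stationary OU process is reversible, so its time reversal is statistically indistinguishable from the original.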
10 Propagator version
Let’s try again from the top, using the forward propagator.
Consider discrete time steps \(t_0, t_1, \ldots, t_N\) with \(t_{k+1} = t_k + \Delta t\). Let \(p_k(x) = p(t_k, x)\) denote the probability density of \(Z_{t_k}\).
The forward propagation operator \(\mathcal{F}_{k \to k+1}\) maps \(p_k\) to \(p_{k+1}\):
\[ p_{k+1}(x') = \mathcal{F}_{k \to k+1} p_k(x') = \int p(t_{k+1}, x' \mid t_k, x) \, p_k(x) \, dx. \]
Here, \(p(t_{k+1}, x' \mid t_k, x)\) is the transition probability from \(x\) at time \(t_k\) to \(x'\) at \(t_{k+1}\).
For small \(\Delta t\), assuming the SDE has coefficients \(b\) and \(\sigma\) that are smooth enough, the transition probability can be approximated using the Fokker-Planck (a.k.a. Kolmogorov forward) equation.
Similarly, the reverse propagation operator \(\mathcal{R}_{k+1 \to k}\) maps \(p_{k+1}\) back to \(p_k\):
\[ p_k(x) = \mathcal{R}_{k+1 \to k} p_{k+1}(x) = \int p(t_k, x \mid t_{k+1}, x') \, p_{k+1}(x') \, dx'. \]
Here, \(p(t_k, x \mid t_{k+1}, x')\) is the reverse transition probability density from \(x'\) at \(t_{k+1}\) to \(x\) at \(t_k\).
The relationship between \(\mathcal{F}\) and \(\mathcal{R}\) that gives us time-reversibility of the process is sometimes loosely called a detailed balance condition. Specifically, we require:
\[ p(t_k, x) \, p(t_{k+1}, x' \mid t_k, x) = p(t_{k+1}, x') \, p(t_k, x \mid t_{k+1}, x'). \]
This equality (strictly, it is just the two factorisations of the joint density of \((Z_{t_k}, Z_{t_{k+1}})\)) ensures that the flow of probability from \(x\) to \(x'\) forward in time is matched by the flow from \(x'\) to \(x\) backward in time.
To construct the backward SDE, we need to express the reverse dynamics in terms of an SDE similar to the forward one but with adjusted drift.
Using Bayes’ theorem, the reverse transition probability can be expressed in terms of the forward transition probability and the marginal densities:
\[ p(t_k, x \mid t_{k+1}, x') = \frac{p(t_{k+1}, x' \mid t_k, x) \, p(t_k, x)}{p(t_{k+1}, x')}. \]
For small \(\Delta t\), the forward transition probability \(p(t_{k+1}, x' \mid t_k, x)\) can be approximated as:
\[ p(t_{k+1}, x' \mid t_k, x) \approx \mathcal{N}\left(x' \, \bigg| \, x + b(t_k, x) \Delta t, \, \sigma(t_k) \sigma(t_k)^T \Delta t \right), \]
where \(\mathcal{N}(\cdot \mid \mu, \Sigma)\) denotes the Gaussian (normal) distribution with mean \(\mu\) and covariance \(\Sigma\).
To find the drift \(\overleftarrow{b}\) in the backward SDE, consider the infinitesimal generator of the reversed process. Specifically, in the backward SDE:
\[ d\overleftarrow{Z}_\tau = \overleftarrow{b}(\tau, \overleftarrow{Z}_\tau) \, d\tau + \sigma(T - \tau) \, d\overleftarrow{W}_\tau, \]
we need to determine \(\overleftarrow{b}(\tau, x)\).
Using the expression for the reverse transition probability and expanding it for small \(\Delta t\), we can identify the effective drift.
Using the Gaussian approximation and Bayes’ theorem:
\[ p(t_k, x \mid t_{k+1}, x') \approx \frac{\mathcal{N}\left(x' \, \bigg| \, x + b(t_k, x) \Delta t, \, \sigma(t_k) \sigma(t_k)^T \Delta t \right) p(t_k, x)}{p(t_{k+1}, x')}. \]
Taking the logarithm to simplify expressions:
\[ \log p(t_k, x \mid t_{k+1}, x') \approx \log \mathcal{N}\left(x' \, \bigg| \, x + b(t_k, x) \Delta t, \, \sigma(t_k) \sigma(t_k)^T \Delta t \right) + \log p(t_k, x) - \log p(t_{k+1}, x'). \]
To find the drift \(\overleftarrow{b}\), we consider the score function, which is the gradient of the log-density:
\[ \nabla_{x} \log p(t_k, x \mid t_{k+1}, x'). \]
Getting an LLM to grind through the calculations, we find that the reverse drift \(\overleftarrow{b}\) consists of two parts:
- Negative Forward Drift: \(-b(T - \tau, x)\).
- Fokker-Planck Adjustment: \(\sigma(T - \tau) \sigma(T - \tau)^T \nabla_x \log p(T - \tau, x)\).
The total reverse drift is then
\[ \overleftarrow{b}(\tau, x) = -b(T - \tau, x) + \sigma(T - \tau) \sigma(T - \tau)^T \nabla_x \log p(T - \tau, x). \]
With the reverse drift identified, the backward SDE becomes:
\[ d\overleftarrow{Z}_\tau = \left[ -b(T - \tau, \overleftarrow{Z}_\tau) + \sigma(T - \tau) \sigma(T - \tau)^T \nabla_x \log p(T - \tau, \overleftarrow{Z}_\tau) \right] d\tau + \sigma(T - \tau) \, d\overleftarrow{W}_\tau. \]
This should look familiar as it matches the form of the backward SDE we derived earlier.
10.1 Let’s try that in continuous time
While the above derivation uses discrete time steps, we can extend the intuition to continuous time by considering infinitesimal operators.
Forward Infinitesimal Generator \(\mathcal{L}\):
For the forward SDE,
\[ \mathcal{L} f(x) = b(t, x) \cdot \nabla_x f(x) + \frac{1}{2} \text{Tr} \left[ \sigma(t) \sigma(t)^T \nabla_x^2 f(x) \right], \]
where \(f\) is a sufficiently smooth test function.
Reverse Infinitesimal Generator \(\overleftarrow{\mathcal{L}}\):
For the reversed SDE,
\[ \overleftarrow{\mathcal{L}} f(x) = \overleftarrow{b}(\tau, x) \cdot \nabla_x f(x) + \frac{1}{2} \text{Tr} \left[ \sigma(T - \tau) \sigma(T - \tau)^T \nabla_x^2 f(x) \right]. \]
To ensure that the reversed process has the correct marginal distributions \(p(T - \tau, x)\), the generators must satisfy a duality condition that we copy-paste from the Fokker-Planck equation:
\[ \partial_\tau p(T - \tau, x) = \overleftarrow{\mathcal{L}}^* p(T - \tau, x), \]
where \(\overleftarrow{\mathcal{L}}^*\) is the adjoint (Fokker-Planck operator) of \(\overleftarrow{\mathcal{L}}\).
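Concretely (writing it out, since “adjoint” is doing a lot of work here), for a generator of this form the adjoint acts on densities as
\[ \overleftarrow{\mathcal{L}}^* p(x) = -\nabla_x \cdot\bigl[\overleftarrow{b}(\tau,x)\,p(x)\bigr] + \tfrac{1}{2}\nabla_x \cdot\bigl[\sigma(T-\tau)\sigma(T-\tau)^T \nabla_x p(x)\bigr], \]
which is just the right-hand side of the Fokker-Planck equation with the reversed coefficients plugged in.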
Simultaneously, the original forward generator satisfies:
\[ \partial_t p(t, x) = \mathcal{L}^* p(t, x). \]
By substituting \(t = T - \tau\), we relate the forward and reverse dynamics:
\[ - \partial_\tau p(T - \tau, x) = \mathcal{L}^* p(T - \tau, x). \]
Equating the two expressions for \(\partial_\tau p(T - \tau, x)\) leads to the identification of \(\overleftarrow{\mathcal{L}}\) in terms of \(\mathcal{L}\) and the density \(p\), ultimately resulting in the expression for the reverse drift \(\overleftarrow{b}\).

