Draft

Score diffusion

August 17, 2023 — December 10, 2024

Figure 1

Langevin dynamics plus score functions to sample from tricky distributions via clever use of stochastic differential equations. Made famous by neural diffusions.

Short version: Classic Langevin MCMC samples from a target distribution by creating an indefinitely long pseudo-time axis, \(\tau\in[0,\infty)\), and a stationary Markov transition kernel that is reversible with respect to the target distribution. As this virtual time grows large, the process samples from the desired target distribution with the invocation of some ergodic theory.

Score diffusions, by contrast, set up a slightly different equation on a finite time interval \(\tau\in[0,T]\). The transition kernel is not stationary, but is configured so that the marginal distribution at pseudo-time \(\tau=T\) is some “easy” distribution, and the marginal at pseudo-time \(\tau=0\) is some challenging distribution of interest, typically a posterior or some other nasty distribution. Sampling is more involved, and requires walking backwards and forwards through pseudo-time.

In both cases, we inject noise in a clever way and perturb it with the score function.
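
To make the classic Langevin side of this comparison concrete, here is a minimal sketch of the unadjusted Langevin algorithm in JAX. The Gaussian target, the step size, and the names `target_score` and `langevin_sample` are illustrative assumptions, not part of the method described in this notebook; the only ingredient that matters is that the sampler touches the target solely through its score.

```python
import jax
import jax.numpy as jnp
from jax import random

def target_score(x, mean=2.0, std=0.5):
    # Score of an assumed Gaussian target N(mean, std^2)
    return -(x - mean) / std**2

def langevin_sample(key, num_chains=5000, num_steps=2000, step=1e-2):
    # Euler discretisation of dX = score(X) dt + sqrt(2) dW, run for a long pseudo-time
    def body(x, k):
        noise = random.normal(k, x.shape)
        x = x + step * target_score(x) + jnp.sqrt(2.0 * step) * noise
        return x, None

    x0 = jnp.zeros(num_chains)           # all chains start far from the target mean
    keys = random.split(key, num_steps)  # one key per step
    x, _ = jax.lax.scan(body, x0, keys)
    return x

samples = langevin_sample(random.PRNGKey(0))
print(samples.mean(), samples.std())
```

The discretisation leaves a small bias in the stationary variance that shrinks with the step size; a Metropolis correction would remove it, but that is not shown here.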

1 Introduction

The idea is explained in small enough words for me to understand in (Sohl-Dickstein et al. 2015). See also the popularised articles for ML people.

All of those introductions include more work than we need to do in this notebook, because they want to get to the hyped NN bit. Here we are given the score function, unlike in the above references, where it must be learned. The fundamental insight of neural diffusions is that we can learn an approximate score function by score matching, which is an amazing trick, but out of scope for this notebook.

In many statistical and machine learning applications, we aim to infer the posterior distribution \(p_{X|Y}(x|y)\) of a latent variable \(X\) given observations \(Y\). Consider an observation model of the form:

\[ Y = g(X) + \varepsilon \]

  • \(g(\cdot)\): A known, differentiable function mapping \(X\) to the observation space.
  • \(\varepsilon\): Observation noise, typically assumed to follow a Gaussian distribution \(\varepsilon \sim \mathcal{N}(0, \sigma^2 I)\).

Given a prior distribution \(p_X(x)\) over \(X\), our objective is to update this prior with the observed data \(Y = y\) to obtain the posterior distribution \(p_{X|Y}(x|y)\).

There are many ways we can sample or estimate the posterior with the score function (Langevin dynamics, HMC, maximum likelihood, …) but this one has some unusual and helpful properties.

The score function is a useful concept in statistics, defined as the gradient of the log-probability density function:

\[ S(x) = \nabla_x \log p_X(x) \]

For the posterior distribution, the score function is

\[ S_{\text{posterior}}(x) = \nabla_x \log p_{X|Y}(x|y) = S_{\text{prior}}(x) + \nabla_x \log p_{Y|X}(y|x) \]

where

  • \(S_{\text{prior}}(x)\): Score function of the prior distribution.
  • \(\nabla_x \log p_{Y|X}(y|x)\): Gradient of the log-likelihood, incorporating the observation model.

This decomposition allows us to update the prior score with information from the likelihood to obtain the posterior score.
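
A quick way to convince ourselves of this decomposition is a conjugate case where the posterior is available in closed form. The sketch below assumes a standard normal prior and the identity observation map \(g(x) = x\) (both chosen purely for illustration), and compares the sum of prior score and likelihood gradient with the score of the exact Gaussian posterior.

```python
import jax

sigma2 = 0.1  # observation noise variance (illustrative value)
y = 0.5       # observed value (illustrative value)

def log_prior(x):
    # X ~ N(0, 1), up to an additive constant
    return -0.5 * x**2

def log_likelihood(x):
    # Y | X = x ~ N(x, sigma2), i.e. g(x) = x, up to an additive constant
    return -0.5 * (y - x) ** 2 / sigma2

prior_score = jax.grad(log_prior)
likelihood_grad = jax.grad(log_likelihood)

def posterior_score(x):
    # Prior score plus gradient of the log-likelihood
    return prior_score(x) + likelihood_grad(x)

# Closed-form Gaussian posterior for this conjugate model
post_var = sigma2 / (1.0 + sigma2)
post_mean = y / (1.0 + sigma2)

def exact_posterior_score(x):
    return -(x - post_mean) / post_var

x = 0.3
print(posterior_score(x), exact_posterior_score(x))
```

The normalising constant of the posterior never appears, which is the practical appeal of working with scores.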

Terminology

The score function as used here is sometimes called the Stein score, and distinguished from the Fisher score, \[ S(\boldsymbol{\theta}) = \nabla_{\boldsymbol{\theta}} \ln f(\mathbf{X}; \boldsymbol{\theta}). \]

In the frequentist context, they involve different assumptions about what is random (\(\mathbf{X}\)) and what is not (\(\boldsymbol{\theta}\)). I am not convinced that these are so distinct in the Bayesian context, where everything is potentially random, i.e. I do not care if it is the “data” or the “parameter” that is random. You could persuade me this attitude is wrong-headed though.

The example under score function gradient estimators maybe shows an edge case, because it usually gets used to approximate an unknown random distribution with a notionally deterministic one, and it is rather popular in variational Bayes. YMMV.
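
For a concrete feel of the Stein/Fisher distinction, the two scores are gradients of the same log-density with respect to different arguments. The sketch below uses a unit-variance Gaussian location model as an assumed toy example; the function names are mine.

```python
import jax

def log_density(x, theta):
    # log N(x | theta, 1), up to an additive constant
    return -0.5 * (x - theta) ** 2

stein_score = jax.grad(log_density, argnums=0)   # gradient with respect to the data x
fisher_score = jax.grad(log_density, argnums=1)  # gradient with respect to the parameter theta

x, theta = 1.5, 0.2
print(stein_score(x, theta), fisher_score(x, theta))
```

In a location family like this one the two differ only by sign, which is part of why the distinction feels thin to me; in general they are different objects living in different spaces.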

2 The Backward SDE and the Target Distribution at \(\tau = 0\)

2.1 The Backward SDE and Its Drift Term

The backward SDE is given by:

\[ dZ_{\tau} = \left[ b(\tau) Z_{\tau} - \sigma^2(\tau) S_{\text{posterior}}(Z_{\tau}, \tau) \right] d\tau + \sigma(\tau) d\overleftarrow{W}_{\tau} \]

  • \(S_{\text{posterior}}(Z_{\tau}, \tau) = \nabla_z \log q_{Z_{\tau}}(Z_{\tau})\) is the score function of the distribution at time \(\tau\).
  • The drift term is \(\mu(z, \tau) = b(\tau) z - \sigma^2(\tau) S_{\text{posterior}}(z, \tau)\).

Our goal is to show that with this drift term, the distribution \(q_{Z_{0}}(z)\) at \(\tau = 0\) is the desired posterior distribution \(p_{X|Y}(x|y)\).

2.2 The Fokker-Planck Equation and Density Evolution

The Fokker-Planck equation describes how the probability density \(q_{Z_{\tau}}(z)\) evolves over time under the influence of the SDE:

\[ \frac{\partial q_{Z_{\tau}}(z)}{\partial \tau} = -\nabla_z \cdot \left[ \mu(z, \tau) q_{Z_{\tau}}(z) \right] + \frac{1}{2} \sigma^2(\tau) \Delta_z q_{Z_{\tau}}(z) \]

  • \(\nabla_z \cdot\) denotes the divergence operator.
  • \(\Delta_z\) is the Laplacian operator with respect to \(z\).
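
As a sanity check on the form of this equation, we can verify it numerically in one dimension for a case with a known solution. The sketch below assumes the forward drift \(\mu(z,\tau) = b(\tau) z\) with the coefficient choice used in the implementation section, \(b(\tau) = -1/(1-\tau)\) and \(\sigma^2(\tau) = 1 + 2\tau/(1-\tau)\), and a hypothetical fixed start \(Z_0 = x_0\), for which the marginal should be \(\mathcal{N}((1-\tau)x_0, \tau)\). Autodiff supplies the derivatives.

```python
import jax
import jax.numpy as jnp

def b_tau(tau):
    return -1.0 / (1.0 - tau)

def sigma2_tau(tau):
    return 1.0 + 2.0 * tau / (1.0 - tau)

x0 = 1.0  # hypothetical fixed starting point Z_0 = x0

def q(z, tau):
    # Candidate marginal density N((1 - tau) * x0, tau)
    mean = (1.0 - tau) * x0
    return jnp.exp(-0.5 * (z - mean) ** 2 / tau) / jnp.sqrt(2.0 * jnp.pi * tau)

def flux(z, tau):
    # mu(z, tau) * q(z, tau) with the forward drift mu = b(tau) * z
    return b_tau(tau) * z * q(z, tau)

dq_dtau = jax.grad(q, argnums=1)                           # time derivative
div_flux = jax.grad(flux, argnums=0)                       # divergence (1-d)
laplacian_q = jax.grad(jax.grad(q, argnums=0), argnums=0)  # Laplacian (1-d)

def fokker_planck_residual(z, tau):
    rhs = -div_flux(z, tau) + 0.5 * sigma2_tau(tau) * laplacian_q(z, tau)
    return dq_dtau(z, tau) - rhs

print(fokker_planck_residual(0.3, 0.4))
```

A residual near zero (up to floating point error) at arbitrary interior points indicates that the candidate density solves the equation for this drift and diffusion.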

2.3 Incorporating the Posterior Score Function

The posterior score function can be expressed as:

\[ S_{\text{posterior}}(z, \tau) = S_{\text{prior}}(z, \tau) + h(\tau) \nabla_z \log p_{Y|X}(y | z) \]

  • \(S_{\text{prior}}(z, \tau)\) is the prior score function.
  • \(h(\tau)\) is a damping function, typically chosen as \(h(\tau) = 1 - \tau\).
  • \(\nabla_z \log p_{Y|X}(y | z)\) is the gradient of the log-likelihood.

By setting the drift term using \(S_{\text{posterior}}(z, \tau)\), we adjust the evolution of \(q_{Z_{\tau}}(z)\) to incorporate the influence of the observation \(y\).

2.4 Evolution Towards the Posterior Distribution

At \(\tau = 1\):

  • The distribution \(q_{Z_{1}}(z)\) is known (e.g., standard normal distribution).

For \(\tau \in [0, 1]\):

  • The backward SDE evolves the distribution \(q_{Z_{\tau}}(z)\) backward in time.
  • The drift term includes the posterior score function, guiding the density towards the posterior distribution.

As \(\tau \rightarrow 0\):

  • The damping function \(h(\tau) \rightarrow 1\), so the influence of the likelihood becomes fully incorporated.
  • The distribution \(q_{Z_{0}}(z)\) approaches the posterior distribution \(p_{X|Y}(x|y)\).

2.5 Mathematical Justification Using the Fokker-Planck Equation

2.5.1 Substituting the Drift Term

The drift term in the Fokker-Planck equation becomes:

\[ \mu(z, \tau) = b(\tau) z - \sigma^2(\tau) \left[ S_{\text{prior}}(z, \tau) + h(\tau) \nabla_z \log p_{Y|X}(y | z) \right] \]

Substituting this into the Fokker-Planck equation:

\[ \frac{\partial q_{Z_{\tau}}(z)}{\partial \tau} = -\nabla_z \cdot \left\{ \left[ b(\tau) z - \sigma^2(\tau) \left( S_{\text{prior}} + h(\tau) \nabla_z \log p_{Y|X}(y | z) \right) \right] q_{Z_{\tau}}(z) \right\} + \frac{1}{2} \sigma^2(\tau) \Delta_z q_{Z_{\tau}}(z) \]

2.5.2 Rearranging Terms

Recognizing that \(S_{\text{prior}} = \nabla_z \log q_{\text{prior}, Z_{\tau}}(z)\):

\[ \frac{\partial q_{Z_{\tau}}(z)}{\partial \tau} = -\nabla_z \cdot \left\{ \left[ b(\tau) z - \sigma^2(\tau) \nabla_z \log \left( q_{\text{prior}, Z_{\tau}}(z) p_{Y|X}(y | z)^{h(\tau)} \right) \right] q_{Z_{\tau}}(z) \right\} + \frac{1}{2} \sigma^2(\tau) \Delta_z q_{Z_{\tau}}(z) \]
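
The rearrangement uses only the fact that \(h(\tau)\) does not depend on \(z\), so it can be moved inside the logarithm as an exponent:

\[ S_{\text{prior}}(z, \tau) + h(\tau) \nabla_z \log p_{Y|X}(y | z) = \nabla_z \log q_{\text{prior}, Z_{\tau}}(z) + \nabla_z \log p_{Y|X}(y | z)^{h(\tau)} = \nabla_z \log \left( q_{\text{prior}, Z_{\tau}}(z) \, p_{Y|X}(y | z)^{h(\tau)} \right) \]

Read this way, the drift at pseudo-time \(\tau\) points along the score of a tempered product of the diffused prior and the likelihood. When \(h(\tau) = 1\) that product is the unnormalised posterior formed from the diffused prior and the likelihood, and when \(h(\tau) = 0\) it is the diffused prior alone.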

2.5.3 Desired Density Evolution

We want \(q_{Z_{0}}(z) \propto p_{X}(z) p_{Y|X}(y | z)\), i.e., the posterior distribution.

By appropriately choosing \(h(\tau)\) and ensuring that the evolution of \(q_{Z_{\tau}}(z)\) follows the modified Fokker-Planck equation, we guide the density towards the posterior.

3 Example

Let’s apply this to the specific example where \(g(x) = \sin(x)\) and \(\varepsilon \sim \mathcal{N}(0, \sigma^2)\).

3.1 The Likelihood and Its Gradient

  • Likelihood:

    \[ p_{Y|X}(y | x) = \frac{1}{\sqrt{2\pi \sigma^2}} \exp\left( -\frac{(y - \sin(x))^2}{2\sigma^2} \right) \]

  • Gradient of Log-Likelihood:

    \[ \nabla_x \log p_{Y|X}(y | x) = \frac{y - \sin(x)}{\sigma^2} \cos(x) \]
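
Hand-derived gradients are easy to get wrong, so here is a small check of the formula above against automatic differentiation. The function names and the test point are arbitrary choices of mine.

```python
import jax
import jax.numpy as jnp

def log_likelihood(x, y, sigma2):
    # log N(y | sin(x), sigma2)
    return -0.5 * jnp.log(2.0 * jnp.pi * sigma2) - 0.5 * (y - jnp.sin(x)) ** 2 / sigma2

def grad_log_likelihood(x, y, sigma2):
    # Closed-form gradient in x, as derived in the text
    return (y - jnp.sin(x)) / sigma2 * jnp.cos(x)

autodiff_grad = jax.grad(log_likelihood, argnums=0)

x, y, sigma2 = 1.2, 0.5, 0.1
print(autodiff_grad(x, y, sigma2), grad_log_likelihood(x, y, sigma2))
```

For a more complicated \(g\), the autodiff version could simply replace the hand-written one.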

3.2 Posterior Score Function

The posterior score function is:

\[ S_{\text{posterior}}(x, \tau) = S_{\text{prior}}(x, \tau) + h(\tau) \frac{y - \sin(x)}{\sigma^2} \cos(x) \]

  • As \(\tau \rightarrow 0\), \(h(\tau) \rightarrow 1\), and the influence of the likelihood is fully incorporated.

3.3 Evolution of the Density

By including this posterior score function in the drift term of the backward SDE, we adjust the evolution of \(q_{Z_{\tau}}(z)\) to move towards the posterior distribution \(p_{X|Y}(x|y)\).

  • Drift Term in Backward SDE:

    \[ \mu(z, \tau) = b(\tau) z - \sigma^2(\tau) \left( S_{\text{prior}}(z, \tau) + h(\tau) \frac{y - \sin(z)}{\sigma^2} \cos(z) \right) \]

At \(\tau = 0\):

  • The evolution governed by the Fokker-Planck equation with the above drift term is designed so that \(q_{Z_{0}}(z) = p_{X|Y}(z|y)\).
  • This means that the samples \(Z_{0}\) obtained from the backward SDE at \(\tau = 0\) should be distributed according to the desired posterior distribution.

4 Implementation

Below is a numerical experiment in JAX that simulates both the forward and backward SDEs with a non-trivial observation model. Specifically, we use the observation model \(y = g(x) + \varepsilon\) with \(g(x) = \sin(x)\), which makes the score function more challenging.

4.1 Import Libraries

Code
import jax
import jax.numpy as jnp
import numpy as np
from jax import random, grad
import matplotlib.pyplot as plt
from livingthing.matplotlib_style import set_livingthing_style, reset_default_style

set_livingthing_style()
Matplotlib style set to use Alegreya font: Alegreya

4.2 Define SDE Coefficients

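We set \(\alpha_{\tau} = 1 - \tau\) and \(\beta_{\tau}^2 = \tau\), leading to \(b(\tau) = \frac{d}{d\tau} \log \alpha_{\tau} = -\frac{1}{1 - \tau}\) and \(\sigma^2(\tau) = \frac{d \beta_{\tau}^2}{d\tau} - 2 b(\tau) \beta_{\tau}^2 = 1 + \frac{2\tau}{1 - \tau}\). Both coefficients diverge as \(\tau \rightarrow 1\).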

Code
def b_tau(tau):
    return -1.0 / (1.0 - tau)

def sigma_tau(tau):
    return jnp.sqrt(1.0 + 2.0 * tau / (1.0 - tau))
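
As a cross-check, these coefficients can be recovered from the schedule by automatic differentiation. This sketch assumes the usual linear-SDE relations \(b = \frac{d}{d\tau}\log\alpha_\tau\) and \(\sigma^2 = \frac{d\beta_\tau^2}{d\tau} - 2 b \beta_\tau^2\), which hold when the forward marginal is \(\alpha_\tau Z_0\) plus Gaussian noise of variance \(\beta_\tau^2\); the helper names are mine.

```python
import jax
import jax.numpy as jnp

def alpha(tau):
    return 1.0 - tau

def beta2(tau):
    return tau

def b_from_schedule(tau):
    # b(tau) = d/dtau log alpha(tau)
    return jax.grad(lambda t: jnp.log(alpha(t)))(tau)

def sigma2_from_schedule(tau):
    # sigma^2(tau) = d/dtau beta^2(tau) - 2 b(tau) beta^2(tau)
    return jax.grad(beta2)(tau) - 2.0 * b_from_schedule(tau) * beta2(tau)

# Hand-written coefficients, repeated here so the block is self-contained
def b_tau(tau):
    return -1.0 / (1.0 - tau)

def sigma_tau(tau):
    return jnp.sqrt(1.0 + 2.0 * tau / (1.0 - tau))

tau = 0.3
print(b_from_schedule(tau), b_tau(tau))
print(sigma2_from_schedule(tau), sigma_tau(tau) ** 2)
```

Swapping in a different schedule then only requires changing `alpha` and `beta2`.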

4.3 Define the Observation Model and Score Function

Let \(g(x) = \sin(x)\) and assume Gaussian noise \(\varepsilon \sim \mathcal{N}(0, \sigma^2)\). The likelihood is:

\[ p(Y=y|X=x) = \mathcal{N}(y | \sin(x), \sigma^2) \]

Thus, the gradient of the log-likelihood is:

\[ \nabla_x \log p(Y=y|X=x) = \frac{y - \sin(x)}{\sigma^2} \cos(x) \]

Define the score function incorporating the posterior:

Code
def posterior_score(z, tau, y, sigma2):
    # Score of the diffused prior. This assumes q(Z_tau) = N(0, tau), for which
    # S_prior(z, tau) = -z / tau. Note that it is singular at tau = 0.
    score_prior = -z / tau

    # Gradient of log-likelihood
    grad_log_likelihood = (y - jnp.sin(z)) / sigma2 * jnp.cos(z)

    # Damping function h(tau) = 1 - tau
    h_tau = 1.0 - tau

    # Posterior score
    return score_prior + h_tau * grad_log_likelihood

4.4 Forward SDE Simulation

Simulate the forward diffusion process, adding noise to the data:

Code
def simulate_forward_sde(Z0, key, N=1000):
    dtau = 1.0 / N
    tau_values = jnp.linspace(0.0, 1.0, N + 1)
    Z_tau = Z0
    Z_tau_list = [Z_tau]

    for i in range(N):
        tau = tau_values[i]
        b = b_tau(tau)
        sigma = sigma_tau(tau)
        key, subkey = random.split(key)
        dW = random.normal(subkey, shape=Z0.shape) * jnp.sqrt(dtau)
        Z_tau = Z_tau + b * Z_tau * dtau + sigma * dW
        Z_tau_list.append(Z_tau)

    return jnp.array(Z_tau_list)
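
Because the forward SDE is linear, its marginal at any single pseudo-time can also be sampled directly, without time stepping. The sketch below assumes the schedule \(\alpha_\tau = 1 - \tau\), \(\beta_\tau^2 = \tau\) from above and the representation \(Z_\tau = \alpha_\tau Z_0 + \beta_\tau \xi\) with standard normal \(\xi\); `sample_forward_marginal` is a name I made up. It is a handy reference for checking the Euler scheme, which becomes stiff near \(\tau = 1\).

```python
import jax.numpy as jnp
from jax import random

def sample_forward_marginal(Z0, tau, key):
    # Z_tau = alpha_tau * Z0 + beta_tau * xi, with alpha_tau = 1 - tau, beta_tau^2 = tau
    alpha = 1.0 - tau
    beta = jnp.sqrt(tau)
    xi = random.normal(key, shape=Z0.shape)
    return alpha * Z0 + beta * xi

key = random.PRNGKey(0)
key_data, key_noise = random.split(key)

# Same initial distribution as the example usage: X ~ N(5, 1)
Z0 = 5.0 + random.normal(key_data, shape=(100_000,))
Z_half = sample_forward_marginal(Z0, 0.5, key_noise)

# Theory for tau = 0.5: mean 0.5 * 5 = 2.5, variance 0.25 * 1 + 0.5 = 0.75
print(Z_half.mean(), Z_half.var())
```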

4.5 Backward SDE Simulation

Simulate the backward denoising process, incorporating the posterior score:

Code
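def simulate_backward_sde(Z1, y, sigma2, key, N=1000, eps=1e-3):
    # b(tau) and sigma(tau) are infinite at tau = 1, so start the integration
    # at tau = 1 - eps (eps is a small hand-picked offset) to avoid NaNs.
    tau_values = jnp.linspace(1.0 - eps, 0.0, N + 1)
    dtau = -(1.0 - eps) / N  # Negative time step for backward integration
    Z_tau = Z1
    Z_tau_list = [Z_tau]

    for i in range(N):
        # tau runs from 1 - eps down to (1 - eps) / N, so it never reaches 0,
        # where the prior score -z / tau is singular.
        tau = tau_values[i]
        b = b_tau(tau)
        sigma = sigma_tau(tau)
        S_posterior = posterior_score(Z_tau, tau, y, sigma2)
        drift = b * Z_tau - sigma**2 * S_posterior
        key, subkey = random.split(key)
        dW = random.normal(subkey, shape=Z1.shape) * jnp.sqrt(-dtau)
        Z_tau = Z_tau + drift * dtau + sigma * dW
        Z_tau_list.append(Z_tau)

    return jnp.array(Z_tau_list)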
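
Before trusting the backward simulation with a likelihood term, it is worth checking the machinery on a case with a known answer. The sketch below is such a check under assumptions of my own: it drops the likelihood entirely, takes the prior to be \(X \sim \mathcal{N}(5, 1)\) as in the example usage, and uses the exact score of the diffused prior, \(\mathcal{N}((1-\tau) m_0, (1-\tau)^2 v_0 + \tau)\). The offset `eps` keeps the integration away from the singular endpoints. If the reverse-time integrator is right, the recovered samples should look like the prior again.

```python
import jax
import jax.numpy as jnp
from jax import random

M0, V0 = 5.0, 1.0  # assumed Gaussian prior X ~ N(M0, V0)

def b_tau(tau):
    return -1.0 / (1.0 - tau)

def sigma2_tau(tau):
    return 1.0 + 2.0 * tau / (1.0 - tau)

def diffused_prior_score(z, tau):
    # Score of the marginal of Z_tau = (1 - tau) X + sqrt(tau) xi
    mean = (1.0 - tau) * M0
    var = (1.0 - tau) ** 2 * V0 + tau
    return -(z - mean) / var

def reverse_sample(key, num_samples=5000, N=1000, eps=1e-2):
    taus = jnp.linspace(1.0 - eps, eps, N + 1)
    dtau = taus[1] - taus[0]  # negative step
    key, sub = random.split(key)
    # Start from the exact marginal at tau = 1 - eps
    start_mean = eps * M0
    start_var = eps**2 * V0 + (1.0 - eps)
    z = start_mean + jnp.sqrt(start_var) * random.normal(sub, (num_samples,))

    def step(z, inputs):
        tau, k = inputs
        drift = b_tau(tau) * z - sigma2_tau(tau) * diffused_prior_score(z, tau)
        dW = random.normal(k, z.shape) * jnp.sqrt(-dtau)
        return z + drift * dtau + jnp.sqrt(sigma2_tau(tau)) * dW, None

    keys = random.split(key, N)
    z, _ = jax.lax.scan(step, z, (taus[:-1], keys))
    return z

z_rec = reverse_sample(random.PRNGKey(1))
print(z_rec.mean(), z_rec.std())  # compare with roughly N(5, 1)
```

Using `jax.lax.scan` rather than a Python loop is only a speed choice; the update rule is the same Euler-Maruyama step as in the loop version.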

4.6 Example Usage

Simulate both processes with a non-trivial observation model and visualize the results.

Code
# Initialize random keys
key = random.PRNGKey(0)
key, subkey = random.split(key)

# Number of samples and time steps
num_samples = 10000
N = 1000

# Observation parameters
y = 0.5  # Observed value
sigma2 = 0.1  # Variance of observation noise

# Initial data distribution: X ~ N(5, 1)
Z0 = random.normal(subkey, shape=(num_samples,)) * 1.0 + 5.0

# Simulate forward SDE
key, subkey = random.split(key)
Z_forward = simulate_forward_sde(Z0, subkey, N=N)

# Simulate backward SDE starting from Z1
Z1 = Z_forward[-1]
key, subkey = random.split(key)
Z_backward = simulate_backward_sde(Z1, y, sigma2, subkey, N=N)

# Convert to numpy for plotting
Z0_np = np.array(Z0)
Z1_np = np.array(Z1)
Z_backward_np = np.array(Z_backward[-1])

# Plotting the distributions
plt.figure(figsize=(18, 5))

# Initial distribution
plt.subplot(1, 3, 1)
plt.hist(Z0_np, bins=100, density=True, alpha=0.6, color='blue', label='Initial Z0')
plt.title('Initial Distribution at τ=0')
plt.xlabel('Z0')
plt.ylabel('Density')
plt.legend()

# After Forward SDE
plt.subplot(1, 3, 2)
plt.hist(Z1_np, bins=100, density=True, alpha=0.6, color='green', label='After Forward SDE')
plt.title('Distribution at τ=1')
plt.xlabel('Z1')
plt.ylabel('Density')
plt.legend()

# # After Backward SDE
# plt.subplot(1, 3, 3)
# plt.hist(Z_backward_np, bins=100, density=True, alpha=0.6, color='red', label='Recovered Z0')
# plt.title('Recovered Distribution at τ=0')
# plt.xlabel('Z0 Recovered')
# plt.ylabel('Density')
# plt.legend()

plt.tight_layout()
plt.show()
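
To judge whether recovered samples are plausible, a reference posterior is useful. In one dimension we can get one by brute force on a grid. The sketch below treats the \(\mathcal{N}(5, 1)\) initial distribution as the prior (my reading of the example, not something the text states outright) and reuses the observation parameters from above; the grid bounds and resolution are arbitrary.

```python
import jax.numpy as jnp

y, sigma2 = 0.5, 0.1             # observation and noise variance, as in the example
prior_mean, prior_var = 5.0, 1.0  # assumed prior X ~ N(5, 1)

x_grid = jnp.linspace(-5.0, 15.0, 4001)
dx = x_grid[1] - x_grid[0]

# Unnormalised log posterior: log prior + log likelihood with g(x) = sin(x)
log_post = (
    -0.5 * (x_grid - prior_mean) ** 2 / prior_var
    - 0.5 * (y - jnp.sin(x_grid)) ** 2 / sigma2
)

# Normalise with a simple Riemann sum, subtracting the max for numerical stability
post = jnp.exp(log_post - log_post.max())
post = post / (post.sum() * dx)

x_mode = x_grid[jnp.argmax(post)]
print(x_mode)
```

Overlaying `post` against `x_grid` on a histogram of backward samples would show directly whether the sampler finds the multiple modes that the sine observation induces.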

5 References

Anderson. 1982. “Reverse-Time Diffusion Equation Models.” Stochastic Processes and Their Applications.
Bao, Cao, Meir, et al. 2016. “A First Order Scheme for Backward Doubly Stochastic Differential Equations.” SIAM/ASA Journal on Uncertainty Quantification.
Bao, Chipilski, Liang, et al. 2024. “Nonlinear Ensemble Filtering with Diffusion Models: Application to the Surface Quasi-Geostrophic Dynamics.”
Bao, Zhang, and Zhang. 2024. “An Ensemble Score Filter for Tracking High-Dimensional Nonlinear Dynamical Systems.”
Dockhorn, Vahdat, and Kreis. 2022. “Score-Based Generative Modeling with Critically-Damped Langevin Diffusion.”
Holzschuh, Vegetti, and Thuerey. 2022. “Score Matching via Differentiable Physics.”
Rozet and Louppe. 2023. “Score-Based Data Assimilation.”
Sohl-Dickstein, Weiss, Maheswaranathan, et al. 2015. “Deep Unsupervised Learning Using Nonequilibrium Thermodynamics.”