import jax
import jax.numpy as jnp
from jax import random


def cg_solver(b, mv_prod_fn, tol=1e-5, max_iter=100):
    """
    Conjugate Gradient method for solving Σv = b,
    where mv_prod_fn(v) returns Σv without ever forming Σ.
    """
    x = jnp.zeros_like(b)
    r = b.copy()
    p = r.copy()
    rs_old = jnp.dot(r, r)
    for i in range(max_iter):
        Ap = mv_prod_fn(p)
        alpha = rs_old / jnp.dot(p, Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = jnp.dot(r, r)
        if jnp.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x
When dealing with high-dimensional Gaussian distributions, sampling can become computationally expensive, especially when the covariance matrix is large. Traditional methods like the Cholesky decomposition become impractical: in dimension $D$ the factorisation costs $O(D^3)$ time, and even storing the matrix costs $O(D^2)$ memory. However, if we can efficiently compute the product of the covariance matrix with arbitrary vectors, we can leverage Langevin dynamics to sample from the distribution without forming the full covariance matrix.
I have been doing this recently in the setting where the covariance matrix is too large to store or factorise, but matrix-vector products with it are cheap.
Spoiler alert: While this is an interesting exercise, at the end we discover that it would have been easier to do something even simpler, making this whole essay into an impractical educational exercise.
1 Problem Setting
We aim to sample from a multivariate Gaussian distribution
$$x \sim \mathcal{N}(\mu, \Sigma), \qquad x \in \mathbb{R}^D,$$
where:
- $\mu \in \mathbb{R}^D$ is the known mean vector.
- $\Sigma \in \mathbb{R}^{D \times D}$ is the notional known covariance matrix, which might be too large to actually compute, let alone factorise for sampling in the usual way.
2 Langevin Dynamics for Sampling
Langevin dynamics provide a way to sample from a target distribution by simulating a stochastic differential equation (SDE) whose stationary distribution is the desired distribution. For a Gaussian distribution, the SDE simplifies due to the properties of the normal distribution (i.e. Gaussians all the way down).
The continuous-time Langevin equation is
$$\mathrm{d}x_t = -\nabla U(x_t)\,\mathrm{d}t + \sqrt{2}\,\mathrm{d}W_t,$$
where
- $U(x)$ is the potential function, related to the target distribution $\pi$ via $U(x) = -\log \pi(x)$ (up to an additive constant).
- $\mathrm{d}W_t$ represents the increment of a Wiener process (standard Brownian motion).
For our Gaussian distribution, the potential function is
$$U(x) = \tfrac{1}{2}(x - \mu)^\top \Sigma^{-1} (x - \mu),$$
so the gradient that drives the dynamics is $\nabla U(x) = \Sigma^{-1}(x - \mu)$.
We discretize the Langevin equation using the Euler-Maruyama method with time step $\epsilon$:
$$x_{n+1} = x_n - \epsilon\,\Sigma^{-1}(x_n - \mu) + \sqrt{2\epsilon}\,\xi_n,$$
where $\xi_n \sim \mathcal{N}(0, I)$.
Instead of computing $\Sigma^{-1}$ explicitly, which is exactly what we want to avoid, we solve the linear system $\Sigma v_n = x_n - \mu$ for $v_n$ at each step, which requires only products of $\Sigma$ with vectors.
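To make that update concrete, here is a single Euler-Maruyama step for a toy dense $\Sigma$, using jnp.linalg.solve as a stand-in for the matrix-free solver we introduce next. The names (mu_demo, Sigma_demo) and the numbers are mine, purely for illustration:

import jax.numpy as jnp
from jax import random

D, epsilon = 5, 1e-2
mu_demo = jnp.zeros(D)
Sigma_demo = 2.0 * jnp.eye(D)        # toy dense covariance; in practice Σ is too big for this
x = jnp.ones(D)                      # current state x_n

key = random.PRNGKey(0)
xi = random.normal(key, shape=(D,))  # ξ_n ~ N(0, I)
v = jnp.linalg.solve(Sigma_demo, x - mu_demo)  # v_n = Σ^{-1}(x_n - μ), done densely here
x_next = x - epsilon * v + jnp.sqrt(2 * epsilon) * xi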
3 How to solve that linear equation
To solve $\Sigma v = r$ without ever forming $\Sigma$, we use the conjugate gradient (CG) method, which needs only the matrix-vector product $p \mapsto \Sigma p$.
Given $r$ and a routine mv_prod_fn that computes $\Sigma p$:
- Initialize $v_0 = 0$, $r_0 = r$, $p_0 = r_0$.
- For $k = 0, 1, 2, \ldots$:
  - Compute $\alpha_k = \dfrac{r_k^\top r_k}{p_k^\top \Sigma p_k}$, then set $v_{k+1} = v_k + \alpha_k p_k$ and $r_{k+1} = r_k - \alpha_k \Sigma p_k$.
  - If $\lVert r_{k+1} \rVert < \mathrm{tol}$, stop.
  - Set $p_{k+1} = r_{k+1} + \dfrac{r_{k+1}^\top r_{k+1}}{r_k^\top r_k}\, p_k$.

The CG method converges in at most $D$ iterations in exact arithmetic, and usually far sooner when the eigenvalues of $\Sigma$ are clustered; for the low-rank-plus-diagonal $\Sigma$ we use below, there are at most $\operatorname{rank} + 1$ distinct eigenvalues, hence at most $\operatorname{rank} + 1$ iterations.
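As a quick sanity check (my addition, not part of the original post), we can compare the cg_solver defined at the top of the post against a dense solve on a small random symmetric positive-definite matrix:

key = random.PRNGKey(3)
M = random.normal(key, shape=(20, 20))
Sigma_small = M @ M.T + jnp.eye(20)      # small SPD test matrix
b = jnp.ones(20)

v_cg = cg_solver(b, lambda p: Sigma_small @ p)
v_dense = jnp.linalg.solve(Sigma_small, b)
print(jnp.max(jnp.abs(v_cg - v_dense)))  # should be tiny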
4 Plug the bits together
We have the following algorithm:
- Start with $x_0 = \mu$ or any arbitrary vector.
- For $n = 0, 1, 2, \ldots$:
  - Compute $r_n = x_n - \mu$.
  - Solve $\Sigma v_n = r_n$ using CG to get $v_n$.
  - Update $x_{n+1} = x_n - \epsilon v_n + \sqrt{2\epsilon}\,\xi_n$, where $\xi_n \sim \mathcal{N}(0, I)$.
- After a burn-in period, keep the iterates $x_n$ as (correlated) samples.
5 Python Implementation
Here is an implementation in Python that I got an LLM to construct for me from the above algorithm.
5.1 Conjugate Gradient Solver
A teensy CG solver, which looks OK: it is the cg_solver function listed in the code at the top of this post.
5.2 Langevin Dynamics Sampler
def sample_mvn_langevin(mu, num_samples=1000, epsilon=1e-2, burn_in=1000):
    """
    Samples from N(mu, Σ) using Langevin dynamics.

    Uses the module-level mv_prod_fn (defined below) to compute Σv inside the CG solve.

    Parameters:
    - mu: Mean vector (JAX array of shape [D])
    - num_samples: Number of samples to collect after burn-in
    - epsilon: Time step size
    - burn_in: Number of initial iterations to discard

    TODO: set num_iter based on condition number of Σ
    """
    D = mu.shape[0]
    key = random.PRNGKey(0)
    # Random initialization
    key, subkey = random.split(key)
    x = mu + random.normal(subkey, shape=(D,)) * 0.1  # Small perturbation around mu
    samples = []
    total_steps = num_samples + burn_in
    for n in range(total_steps):
        # Generate the driving noise ξ_n
        key, subkey = random.split(key)
        noise = random.normal(subkey, shape=(D,))
        # Compute gradient: v = Σ^{-1} (x - μ)
        r = x - mu
        v = cg_solver(r, mv_prod_fn, tol=1e-5, max_iter=100)
        # Langevin update
        x_new = x - epsilon * v + jnp.sqrt(2 * epsilon) * noise
        x = x_new
        if n >= burn_in:
            samples.append(x.copy())
    return jnp.stack(samples) if samples else None
5.3 Example
First, we need a function to compute $\Sigma v$ for an arbitrary vector $v$.
(An aside: in the low-rank setting, would we actually need a Langevin sampler, or is there something better? We come back to this at the end.)
# Basic noisy sinewave
D = 100
mu = jnp.cos(jnp.linspace(0, 2 * jnp.pi, D)) * 5
rank = 10
sig2 = 0.1


def get_matrix_A(D, rank, scale=1.0, mean=None):
    """
    Example: generate a random low-rank matrix.
    Replace this with the kernel or actual structure of Σ.
    """
    key = random.PRNGKey(42)
    A = random.normal(key, shape=(D, rank)) * scale
    if mean is not None:
        A += mean.reshape(-1, 1)
    return A


A = get_matrix_A(D, rank, mean=mu, scale=2.0)
# Centre and normalize A to generate the empirical covariance factor
L = A - A.mean(1).reshape(-1, 1)
L = L / jnp.sqrt(rank)


def mv_prod_fn(v):
    """
    Efficient computation of Σv using a known kernel or factorization.
    For demonstration, we assume Σ = L L^T + σ²I,
    i.e. a low-rank part plus a diagonal part.
    """
    return L @ (L.T @ v) + sig2 * v
Now we sample from the Gaussian distribution using Langevin dynamics.
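The call that generated the samples validated below is not shown; a minimal sketch of what it presumably looks like (the step size, chain length, and burn-in here are my guesses, not the original settings):

samples = sample_mvn_langevin(mu, num_samples=1000, epsilon=1e-2, burn_in=1000)
print(samples.shape)  # (num_samples, D)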
6 Validation
After sampling, it’s wise to verify that the samples approximate the target distribution. We can run some basic diagnostics by checking how well we’ve matched the marginal distributions.
import matplotlib.pyplot as plt
import numpy as np
from livingthing.matplotlib_style import set_livingthing_style, reset_default_style

set_livingthing_style()


def plot_validation(
    samples,
    mu,
    sigma_diag,
    reference_samples=None,
    max_traces=rank,
    base_alpha=0.5,
    title="",
):
    """
    Plot validation for the sampler against the ideal mean and variance,
    and optionally include reference samples.

    Parameters:
    - samples: Samples generated from the sampler (numpy array of shape [num_samples, D]).
    - mu: Ideal mean vector (numpy array of shape [D]).
    - sigma_diag: Ideal per-dimension standard deviation, i.e. the square root of
      the diagonal of the covariance matrix (numpy array of shape [D]).
    - reference_samples: Reference samples to plot (numpy array of shape [num_ref_samples, D]).
    - max_traces: Maximum number of traces to plot (int).
    - base_alpha: Base transparency value (float, used when only a few traces are plotted).
    - title: Plot title (str).
    """
    D = mu.shape[0]
    num_samples = samples.shape[0]
    # Randomly select a subset of samples (up to max_traces)
    subset_size = min(max_traces, num_samples)
    rng = np.random.default_rng()
    subset_indices = rng.choice(num_samples, subset_size, replace=False)
    sample_subset = samples[subset_indices]
    # Adjust alpha to decrease with the square root of the number of traces
    trace_alpha = base_alpha / np.sqrt(subset_size)
    # Generate x-axis for dimensions
    x = np.arange(D)
    # Plot the ideal mean and standard deviation bands
    # plt.figure(figsize=(12, 7))
    plt.plot(x, mu, label="Ideal Mean", color="gray", linestyle="--")
    plt.fill_between(
        x,
        mu - 2 * sigma_diag,
        mu + 2 * sigma_diag,
        color="gray",
        alpha=0.2,
        label="Ideal ±2 Std. Dev.",
    )
    # Plot individual samples with adjusted transparency
    for sample in sample_subset:
        plt.plot(x, sample, color="red", alpha=trace_alpha)
    # If reference samples are provided, plot them
    if reference_samples is not None:
        ref_subset_size = min(max_traces, reference_samples.shape[0])
        ref_subset_indices = rng.choice(
            reference_samples.shape[0], ref_subset_size, replace=False
        )
        ref_sample_subset = reference_samples[ref_subset_indices]
        for ref_sample in ref_sample_subset:
            plt.plot(x, ref_sample, color="green", alpha=trace_alpha, linestyle="-")
    # Create proxy artists for the legend
    from matplotlib.lines import Line2D

    sample_proxy = Line2D([0], [0], color="red", alpha=1.0, label="Samples")
    ref_proxy = Line2D(
        [0], [0], color="green", alpha=1.0, linestyle="-", label="Reference Samples"
    )
    # Add labels and legend
    handles = [
        plt.Line2D([], [], color="grey", linestyle="--", label="Ideal Mean"),
        plt.Line2D([], [], color="grey", alpha=0.2, label="Ideal ±2 Std. Dev."),
        sample_proxy,
    ]
    if reference_samples is not None:
        handles.append(ref_proxy)
    plt.legend(handles=handles)
    plt.title(title)
    plt.xlabel("Dimension")
    plt.ylabel("Value")
    plt.grid(True, linestyle="--", alpha=0.5)
    plt.tight_layout()
    plt.show()
# Example usage:
# Compute the diagonal standard deviation of the ideal covariance matrix
ideal_std_diag = np.sqrt((L * L).sum(1) + sig2)

# Call the plot function
plot_validation(
    samples=np.array(samples),
    mu=np.array(mu),
    sigma_diag=ideal_std_diag,
    reference_samples=A.T,
    title="Simple Langevin samples",
)
plt.show()
That’s … mediocre. Our sampling is in the right ballpark but not playing the same game as the ideal distribution.
7 Metropolis Adjustment
When using Langevin dynamics for sampling from a probability distribution, discretization errors can lead to sampling from an incorrect distribution. We can fix this in a couple of ways. Since we know that the Gaussian target distribution is actually log-concave, we could use the implicit Langevin trick (Hodgkinson, Salomone, and Roosta 2019).
But let us try a more generic and slightly easier (but not as elegant) trick: a Metropolis-Hastings rejection step, resulting in the Metropolis-Adjusted Langevin Algorithm (MALA). This adjustment ensures that the Markov chain has the desired target distribution as its stationary distribution.
The reason we might need to do this is that when we discretized time using the Euler-Maruyama method, we introduced errors that can accumulate, causing the chain to sample from an incorrect distribution. In MALA, the proposal distribution is based on the discretized Langevin dynamics. The acceptance probability is chosen to ensure detailed balance and correct for the discretization error (thought: are these actually the same thing?).
The proposal is drawn from one step of the discretized dynamics,
$$y = x - \epsilon \nabla U(x) + \sqrt{2\epsilon}\,\xi, \qquad \xi \sim \mathcal{N}(0, I),$$
so the proposal density is $q(y \mid x) = \mathcal{N}\bigl(y;\, x - \epsilon \nabla U(x),\, 2\epsilon I\bigr)$.
The acceptance probability is
$$\alpha(x, y) = \min\!\left(1,\ \frac{\pi(y)\, q(x \mid y)}{\pi(x)\, q(y \mid x)}\right) = \min\!\left(1,\ \exp\bigl(U(x) - U(y) + \log q(x \mid y) - \log q(y \mid x)\bigr)\right).$$
This ratio adjusts for the asymmetry introduced by the drift term $-\epsilon \nabla U(x)$ in the proposal.
We might be worried about finding the potential $U(x) = \tfrac{1}{2}(x - \mu)^\top \Sigma^{-1}(x - \mu)$ itself, since it also involves $\Sigma^{-1}$; but the CG solve already gives us $v = \Sigma^{-1}(x - \mu)$ for the drift, so $U(x) = \tfrac{1}{2}(x - \mu)^\top v$ comes essentially for free.
import jax
import jax.numpy as jnp
from jax import random


def log_q(x_new, x_old, grad_old, epsilon):
    """
    Compute the log of the proposal density q(y | x).

    Parameters:
    - x_new: Proposed state (JAX array).
    - x_old: Current state (JAX array).
    - grad_old: Gradient of U at x_old (JAX array).
    - epsilon: Time step size (float).

    Returns:
    - log_q: Log of the proposal density (float).
    """
    diff = x_new - x_old + epsilon * grad_old
    return -0.25 / epsilon * jnp.dot(diff, diff)


def metropolis_adjustment(x, y, U_x, U_y, grad_x, grad_y, epsilon, key):
    """
    Perform Metropolis-Hastings acceptance step.

    Parameters:
    - x: Current state (JAX array).
    - y: Proposed state (JAX array).
    - U_x: Potential at current state (float).
    - U_y: Potential at proposed state (float).
    - grad_x: Gradient of U at x (JAX array).
    - grad_y: Gradient of U at y (JAX array).
    - epsilon: Time step size (float).
    - key: JAX PRNGKey.

    Returns:
    - x_new: Accepted or retained state (JAX array).
    """
    log_q_y_given_x = log_q(y, x, grad_x, epsilon)
    log_q_x_given_y = log_q(x, y, grad_y, epsilon)
    log_alpha = -U_y + U_x + log_q_x_given_y - log_q_y_given_x
    log_alpha = jnp.minimum(log_alpha, 0.0)  # Ensure log_alpha <= 0
    u = random.uniform(key)  # u ~ Uniform(0, 1)
    accept = u < jnp.exp(log_alpha)
    if accept:
        return y
    else:
        return x


def mala_step(x, mu, L, sig2, epsilon, key):
    """
    Perform one step of the Metropolis-Adjusted Langevin Algorithm (MALA).

    Parameters:
    - x: Current state (JAX array).
    - mu: Mean vector (JAX array).
    - L: Low-rank matrix defining Σ = L L^T + sig2 * I.
    - sig2: Variance term in Σ.
    - epsilon: Time step size (float).
    - key: JAX PRNGKey.

    Returns:
    - x_new: New state (JAX array).
    - key: Updated PRNGKey.
    """
    # Split key
    key, subkey1, subkey2 = random.split(key, num=3)
    # Compute gradient and potential at x
    r_x = x - mu
    grad_x = cg_solver(r_x, mv_prod_fn, tol=1e-5, max_iter=100)
    U_x = 0.5 * jnp.dot(r_x, grad_x)
    # Propose new state y
    noise = random.normal(subkey1, shape=x.shape)
    y = x - epsilon * grad_x + jnp.sqrt(2 * epsilon) * noise
    # Compute gradient and potential at y
    r_y = y - mu
    grad_y = cg_solver(r_y, mv_prod_fn, tol=1e-5, max_iter=100)
    U_y = 0.5 * jnp.dot(r_y, grad_y)
    # Perform Metropolis adjustment
    x_new = metropolis_adjustment(x, y, U_x, U_y, grad_x, grad_y, epsilon, subkey2)
    return x_new, key


def sample_mvn_mala(mu, L, sig2, num_samples=1000, epsilon=1e-3, burn_in=100):
    """
    Samples from N(mu, Σ) using the Metropolis-Adjusted Langevin Algorithm (MALA).

    Parameters:
    - mu: Mean vector (JAX array of shape [D]).
    - L: Low-rank matrix defining Σ = L L^T + sig2 * I.
    - sig2: Variance term in Σ.
    - num_samples: Number of samples to collect after burn-in.
    - epsilon: Time step size.
    - burn_in: Number of initial iterations to discard.

    Returns:
    - samples: Samples from the target distribution (JAX array of shape [num_samples, D]).
    """
    D = mu.shape[0]
    key = random.PRNGKey(0)
    # Random initialization
    key, subkey = random.split(key)
    # Small perturbation around mu
    x = mu + random.normal(subkey, shape=(D,)) * 0.1
    samples = []
    total_steps = burn_in + num_samples
    for n in range(total_steps):
        x, key = mala_step(x, mu, L, sig2, epsilon, key)
        # Collect samples after burn-in
        if n >= burn_in:
            samples.append(x)
    return jnp.stack(samples)
OK, now let us generate some samples using this shiny new MALA sampler.
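The original call is folded away; a minimal sketch of what it might look like, reusing plot_validation and ideal_std_diag from the validation section (the step size and chain length are my assumptions):

samples_mala = sample_mvn_mala(mu, L, sig2, num_samples=1000, epsilon=1e-3, burn_in=100)

plot_validation(
    samples=np.array(samples_mala),
    mu=np.array(mu),
    sigma_diag=ideal_std_diag,
    reference_samples=A.T,
    title="MALA samples",
)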
We could investigate the convergence of the MALA sampler by looking at trace plots of the samples and thinking about the trade-offs of step size, burn-in and so on. I’m not inclined to do that, for a reason I’ll expand on next: this is a silly way of sampling from a Gaussian distribution.
We should ask ourselves when this method is worth trying at all. It can sample powerfully from tricky Gaussians when we can compute $\Sigma v$ cheaply for arbitrary $v$ but cannot afford to form or factorise $\Sigma$. In the low-rank-plus-diagonal setting of our example, though, where $\Sigma = LL^\top + \sigma^2 I$, there is a much simpler exact recipe (sketched in code below):
- Sample $z \sim \mathcal{N}(0, I_{\mathrm{rank}})$ and $e \sim \mathcal{N}(0, I_D)$.
- Then $x = \mu + Lz + \sigma e$ is a sample from $\mathcal{N}(\mu, LL^\top + \sigma^2 I)$.

To see this, just calculate the mean and covariance of $x$: $\mathbb{E}[x] = \mu$, and $\operatorname{Cov}(x) = L\,\mathbb{E}[zz^\top]\,L^\top + \sigma^2\,\mathbb{E}[ee^\top] = LL^\top + \sigma^2 I = \Sigma$.
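Here is a minimal sketch of that direct sampler, reusing mu, L and sig2 from the example above; the function name sample_mvn_lowrank and the sample count are mine:

def sample_mvn_lowrank(key, mu, L, sig2, num_samples=1000):
    """Exact samples from N(mu, L L^T + sig2 * I), no iteration required."""
    D, r = L.shape
    key_z, key_e = random.split(key)
    z = random.normal(key_z, shape=(num_samples, r))  # z ~ N(0, I_r)
    e = random.normal(key_e, shape=(num_samples, D))  # e ~ N(0, I_D)
    return mu + z @ L.T + jnp.sqrt(sig2) * e


direct_samples = sample_mvn_lowrank(random.PRNGKey(1), mu, L, sig2)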
So this Langevin example per se is useless (but fun, I hope). But the question might be: are there other occasions when we can calculate $\Sigma v$ cheaply, yet cannot exploit such a convenient factorisation of $\Sigma$?
That is not to say Langevin-type dynamics are not useful, but I think we need to be cleverer. Probably score diffusions, which have a very sophisticated way of sampling from punishing (and not even Gaussian) distributions, are the interesting case.