# (Kernelized) Stein variational gradient descent

KSD, SVGD, other computational Stein discrepancy methods

November 2, 2022 — May 30, 2024

approximation
Bayes
functional analysis
Markov processes
measure
metrics
Monte Carlo
optimization
probabilistic algorithms
probability
statistics

Stein’s method meets variational inference via kernels and probability measures. The result is a method of inference that maintains an ensemble of particles which notionally collectively sample from some target distribution. I should learn about this, as it is one of the methods I might use for low-assumption Bayes inference. This seems to have been invented in Q. Liu, Lee, and Jordan (2016) and Chwialkowski, Strathmann, and Gretton (2016), and weaponized in Q. Liu (2016b).

There seems to be a standard way of introducing the tools, which I find very confusing. Here I work through that standard with laborious worked examples, so that I can internalise the necessary intuitions for all this.

For a more comprehensive introduction (albeit brusquer), see Anastasiou et al. (2023), which combines a whole bunch of recent developments with consistent notation.

For what it is worth, I found Chwialkowski, Strathmann, and Gretton (2016) to be the easiest read, although none of them was pedagogically ideal, which is why I wrote this note.

Let us introduce the bits we need.

## 1 Stein operators

We start with the classic Stein identity, which turns out to be a useful trick for quantifying how well we have approximated some density.

Spoiler: later on it turns out that we can even use this as a target loss in order to improve how well we have approximated some density.

We care about a target density $$p$$ and another density $$q$$ (which will end up approximating it). $$p$$ needs to be differentiable for this to work. By assumption, both are densities over $$\mathcal{X}\subseteq\mathbb{R}^d$$. We also introduce a family $$\mathcal{F}$$ of test functions from $$\mathbb{R}^d$$ to $$\mathbb{R}^d$$. We require that $$\lim_{t \to \pm \infty} p(tb)f(tb) = 0$$ for every unit vector $$b$$ (i.e. $$pf$$ vanishes as we head off to infinity in any direction), and some other stuff which we get to in a moment.

I should say more about the generic class $$\mathcal{F}$$; Q. Liu, Lee, and Jordan (2016) do.

Next, we choose a Stein operator $$\mathcal{A}_{x}: \mathcal{F}\to \mathcal{G}$$. $$\mathcal{F}$$ and $$\mathcal{G}$$ are spaces of functions from $$\mathbb{R}^d$$ to $$\mathbb{R}$$. I gave them different names because it is not clear to me that they are necessarily the same space, but we can probably ignore that detail for now. We do not go into the details of the requirements of the spaces, but they should be smooth functions that are square-integrable (with respect to the target distribution $$p$$? or Lebesgue measure? something else?) whose derivatives also go to zero at infinity. For an operator $$\mathcal{A}_{x}$$ to be a Stein operator for a target distribution $$p(x)$$, it must satisfy the following key property:

Stein operators: $$\mathcal{A}_{x}$$ is a Stein operator with respect to a suitable class of test functions $$\mathcal{F}$$ and target distribution $$p(x)$$ if the expectation of $$\mathcal{A}_{x} f(x)$$ under $$p$$ is zero, i.e. for all $$f$$ in $$\mathcal{F}$$,

$\mathbb{E}_{X\sim p}[\mathcal{A}_{X} f(X)] = 0. \tag{1}$

For $$\mathcal{F}$$ which include a non-trivial linear subspace, we can see that $$\mathcal{A}_{x}$$ must be linear, because expectation is linear, and otherwise we could make linear changes to an $$f$$ and end up violating the equality. A popular choice is to set the Stein operator, for $$f:\mathbb{R}^d\to\mathbb{R}^d$$, to

$\mathcal{A}_{x} f(x):=f(x) \cdot \nabla_x \log p(x)+\nabla_x \cdot f(x). \tag{2}$

Anastasiou et al. (2023) call this the Langevin Stein operator, and it seems that if you do not otherwise specify, this is the one you get. The Langevin Stein operator makes this into a score-based method; see C. Liu et al. (2019) for some deep theory about that.

Example time! Let us make this more concrete by choosing a specific $$p$$ which is not trivial but not baffling either. I reckon a 2d Gaussian with mean 0 and unit marginal standard deviations will do the trick. Let us give it a correlation $$\rho$$, which we leave unspecified, to keep things spicy. This implies mean $$\mu = (0, 0)$$ and covariance $$\Sigma = \left[\begin{smallmatrix} 1 & \rho \\ \rho & 1 \end{smallmatrix}\right]$$, and thence inverse covariance $$\Sigma^{-1} = \tfrac{1}{1-\rho^2}\left[\begin{smallmatrix} 1 & -\rho \\ -\rho & 1 \end{smallmatrix}\right]$$.

The pdf for this distribution is

\begin{aligned} p(\boldsymbol{x}) &= \frac{1}{2\pi \sqrt{|\Sigma|}} \exp\left(-\frac{1}{2} (x-\mu)^{\top} \Sigma^{-1} (x-\mu)\right)\\ &= \frac{1}{2\pi \sqrt{1-\rho^2}} \exp\left(-\frac{1}{2} \begin{bmatrix} x_1 & x_2 \end{bmatrix} \frac{1}{1-\rho^2} \begin{bmatrix} 1 & -\rho \\ -\rho & 1 \end{bmatrix} \begin{bmatrix} x_1 \\ x_2 \end{bmatrix}\right)\\ &= \frac{1}{2\pi \sqrt{1-\rho^2}} \exp\left(-\frac{1}{2(1-\rho^2)}\left(x_1^2 - 2\rho x_1 x_2 + x_2^2\right)\right). \end{aligned}

We can simplify the Langevin Stein operator for this choice of $$p$$, since

\begin{aligned} \nabla_x \log p(x) &= \nabla_x \log \left(\frac{1}{2\pi \sqrt{1-\rho^2}} \exp\left(-\frac{1}{2(1-\rho^2)}\left(x_1^2 - 2\rho x_1 x_2 + x_2^2\right)\right)\right)\\ &= \nabla_x \left( -\frac{1}{2(1-\rho^2)} \left( x_1^2 - 2\rho x_1 x_2 + x_2^2\right)\right)\\ &= -\frac{1}{2(1-\rho^2)}\nabla_x\left(x_1^2 - 2\rho x_1 x_2 + x_2^2\right)\\ &= -\frac{1}{1-\rho^2}\begin{bmatrix} x_1 - \rho x_2\\ x_2 - \rho x_1 \end{bmatrix}, \end{aligned}

i.e. $$\nabla_x \log p(x) = -\Sigma^{-1}x$$, as it should be for a zero-mean Gaussian. To feed a scalar $$f$$ into Equation 2, we apply the operator to the vector-valued test function $$\boldsymbol{f}(x) = f(x)\,(1, 1)^{\top}$$, so that $$\boldsymbol{f}\cdot\nabla_x\log p = f\,(\partial_{x_1}\log p + \partial_{x_2}\log p)$$ and $$\nabla_x\cdot\boldsymbol{f} = \partial_{x_1}f + \partial_{x_2}f$$. Equation 1 then comes out to

\begin{aligned} \mathbb{E}_{X\sim p}[\mathcal{A}_{X} \boldsymbol{f}(X)] &= \mathbb{E}_{X\sim p}[\boldsymbol{f}(X) \cdot \nabla_X \log p(X)+\nabla_X \cdot \boldsymbol{f}(X)]\\ &= \mathbb{E}_{X\sim p}\left[ -\frac{f(X)}{1-\rho^2} (X_1 - \rho X_2 + X_2 - \rho X_1) + \partial_{X_1}f(X) + \partial_{X_2}f(X) \right]\\ &= \mathbb{E}_{X\sim p}\left[ -\frac{f(X)(1-\rho)}{1-\rho^2} (X_1 + X_2 ) + \partial_{X_1}f(X) + \partial_{X_2}f(X) \right]\\ &= \mathbb{E}_{X\sim p}\left[ -f(X)\frac{X_1 + X_2}{1+\rho} + \partial_{X_1}f(X) + \partial_{X_2}f(X) \right] \end{aligned}

We can choose some simple $$\mathcal{F}$$ for the purposes of intuition building, e.g. the linear set $$\mathcal{F} := \{ f(x) = a x_1 + b x_2 + c;\ a,b, c\in \mathbb{R}\}$$; we would normally use something a bit more interesting. The expectation of the Stein operator for our bivariate Gaussian for this function class is then

\begin{aligned} \mathbb{E}_{X\sim p}[\mathcal{A}_{X} \boldsymbol{f}(X)] &=\mathbb{E}_{X\sim p}\left[ -f(X)\frac{X_1 + X_2}{1+\rho} + \partial_{X_1}f(X) + \partial_{X_2}f(X) \right]\\ &=\mathbb{E}_{X\sim p}\left[ -(a X_1 + b X_2 + c)\frac{X_1 + X_2}{1+\rho} + \partial_{X_1}(a X_1 + b X_2 + c) + \partial_{X_2}(a X_1 + b X_2 + c) \right]\\ &= -\frac{a\,\mathbb{E}[X_1(X_1+X_2)] + b\,\mathbb{E}[X_2(X_1+X_2)] + c\,\mathbb{E}[X_1+X_2]}{1+\rho} + a + b\\ &= -\frac{a(1+\rho) + b(1+\rho) + 0}{1+\rho} + a + b\\ &= 0, \end{aligned}

using $$\mathbb{E}[X_1^2]=\mathbb{E}[X_2^2]=1$$, $$\mathbb{E}[X_1X_2]=\rho$$ and $$\mathbb{E}[X_1]=\mathbb{E}[X_2]=0$$.

Phew! OK, that worked: the expectation vanishes for every $$a, b, c$$, as Equation 1 promised. I itch to plot these functions. I think there are two quantities of interest: the first is the function $$f$$ itself, and the second is the Stein operator applied to $$f$$, weighted by the density $$p$$.

```python
import jax
import jax.numpy as jnp
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio

pio.templates.default = "none"

# Plot params
n = 61

# Define the scalar function f
def f(x, a, b, c):
    return a * x[..., 0] + b * x[..., 1] + c

# Define the log-density of the correlated bivariate Gaussian p
def log_p(x, rho):
    return -0.5 * (
        x[..., 0]**2 + x[..., 1]**2
        - 2 * rho * x[..., 0] * x[..., 1]
    ) / (
        1 - rho**2
    ) - jnp.log(
        2 * jnp.pi * jnp.sqrt(1 - rho**2)
    )

# Define the Stein operator applied to some f and p: the Langevin operator acting on
# the vector test function f(x) * (1, 1), evaluated at each row of x (shape (m, 2))
def A_x_f(x, f, log_p):
    x = jnp.asarray(x)
    score = jax.vmap(jax.grad(log_p))(x)  # grad log p at each point, shape (m, 2)
    grad_f = jax.vmap(jax.grad(f))(x)     # grad f at each point, shape (m, 2)
    return f(x) * score.sum(axis=-1) + grad_f.sum(axis=-1)

# Fix specific values for rho and the parameters of f
rho = 0.3
a, b, c = 0.4, 0.25, -0.5
x1min, x1max = -3, 3
x2min, x2max = -3, 3

f_specific = lambda x: f(x, a, b, c)
log_p_specific = lambda x: log_p(x, rho)

# Create a grid of points
x1, x2 = np.meshgrid(
    np.linspace(x1min, x1max, n, endpoint=True),
    np.linspace(x2min, x2max, n, endpoint=True)
)
x = np.stack([x1, x2], axis=-1).reshape(-1, 2)

# Compute the function f and the density p at each point in x
f_x = f_specific(x).reshape(x1.shape)
p_x = np.exp(np.asarray(log_p_specific(x))).reshape(x1.shape)

# Compute the Stein operator for f at each point in x, then weight it by p
A_x_f_x = np.asarray(A_x_f(x, f_specific, log_p_specific)).reshape(x1.shape)
p_A_x_f_x = A_x_f_x * p_x

# Determine the z range with a margin
z_min = float(np.min(p_A_x_f_x)) - 0.1
z_max = float(np.max(p_A_x_f_x)) + 0.1

# Create the 3D surface plot
fig = go.Figure()

# Add the surface for the p-weighted Stein operator, coloured by the density p_x
fig.add_trace(go.Surface(
    z=p_A_x_f_x,
    x=x1,
    y=x2,
    surfacecolor=p_x,
    colorscale='Viridis',
    showscale=False,  # Remove the color bar
    opacity=0.9,  # slightly transparent
    name='<i>p A<sub>x</sub> f</i>'  # Add name for legend
))

# Add the surface for f on the same axes, with a different color scheme, semi-transparent, with contour lines
fig.add_trace(go.Surface(
    z=f_x,
    x=x1,
    y=x2,
    colorscale='Cividis',
    showscale=False,
    opacity=0.5,  # make this semi-transparent
    name='f',  # Add name for legend
    contours={
        "z": {
            "show": True,
            "start": float(np.min(f_x)),
            "end": float(np.max(f_x)),
            "size": float(np.max(f_x) - np.min(f_x)) / 10,
            "color": "white",
        }
    }
))

# Set the layout with an initial camera view closer to the z=0 plane
fig.update_layout(
    title='<i>p A<sub>x</sub> f</i>&nbsp;and&nbsp;<i>f</i>',
    scene=dict(
        xaxis=dict(title='x<sub>1</sub>'),
        yaxis=dict(title='x<sub>2</sub>'),
        zaxis=dict(
            # title='p A<sub>x</sub> f, f',
            range=[z_min, z_max]),
        camera=dict(
            eye=dict(x=1.25, y=-1.25, z=0.5)  # Lower down closer to the z=0 plane
        )
    ),
    width=800,
    height=800,
    font=dict(family="Alegreya, serif"),
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(0,0,0,0)',
    ## legends don't work on 3d contours
    # showlegend=True,  # Show legend
    # legend=dict(
    #     x=0.02,  # Position the legend on the left
    #     y=0.98,
    #     bgcolor='rgba(255,255,255,0.7)',  # Semi-transparent background for better visibility
    #     bordercolor='Black',
    #     borderwidth=1
    # )
)
# Show the plot
fig.show()
```

Did that help us? Well, kinda. It is not really clear to me that I should trust that the $$p\mathcal{A}_{x} f$$ surface actually integrates to 0. Does it? A crude Riemann sum over the plotting grid:

```python
p_A_x_f_x.sum().item()*(x1max-x1min)*(x2max-x2min)/(p_A_x_f_x.size)
```

```
0.01856179405239057
```

Hm, not convincingly exactly 0, but not so far off that we cannot persuade ourselves that it is simply a truncation problem.
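Another way to sidestep the truncation is to estimate Equation 1 by Monte Carlo, drawing samples from $$p$$ itself rather than summing over a finite grid. The following is a minimal NumPy sketch for the linear $$f$$ above. It hard-codes the analytic score sum $$-(x_1+x_2)/(1+\rho)$$ derived earlier and reuses the plot's values of $$\rho, a, b, c$$. The sample size and seed are arbitrary choices.

```python
import numpy as np

def stein_op_linear(x, a, b, c, rho):
    """Langevin Stein operator of the correlated Gaussian p, applied to f(x) * (1, 1)
    with linear f(x) = a*x1 + b*x2 + c. x has shape (m, 2); returns shape (m,)."""
    f = a * x[:, 0] + b * x[:, 1] + c
    score_sum = -(x[:, 0] + x[:, 1]) / (1 + rho)  # d/dx1 log p + d/dx2 log p
    return f * score_sum + a + b                  # plus df/dx1 + df/dx2

rho = 0.3
a, b, c = 0.4, 0.25, -0.5
cov = np.array([[1.0, rho], [rho, 1.0]])

rng = np.random.default_rng(0)
samples = rng.multivariate_normal(np.zeros(2), cov, size=200_000)  # X ~ p

vals = stein_op_linear(samples, a, b, c, rho)
mc_mean = vals.mean()
mc_stderr = vals.std() / np.sqrt(vals.size)
print(f"E_p[A f] ~ {mc_mean:.4f} +/- {mc_stderr:.4f}")
```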

## 2 Stein discrepancy

We make Equation 1 into a quantity that depends on two, potentially-different densities by taking the expectation over a different density $$q$$ than the one that generated the operator $$\mathcal{A}_{x},$$ and seeing if that does something useful:

$\mathbb{E}_{X\sim q}[\mathcal{A}_{X} f(X)] \tag{3}$

Spoiler: it turns out that this does do something useful.

\begin{aligned} \mathbb{E}_{X\sim q}[\mathcal{A}_{X} f(X)] &=\mathbb{E}_{X\sim q}[\mathcal{A}_{X} f(X)] - \overbrace{\mathbb{E}_{X\sim q}[\mathcal{A}_q f(X)]}^{=0}\\ &=\mathbb{E}_{X\sim q}\big[f(X) \cdot \nabla_X \log p(X)+\nabla_X \cdot f(X)\\ &\qquad-f(X) \cdot \nabla_X \log q(X)-\nabla_X \cdot f(X)\big]\\ &=\mathbb{E}_{X\sim q}\left[f(X) \cdot \nabla_X \log p(X) -f(X) \cdot \nabla_X \log q(X)\right]\\ &=\mathbb{E}_{X\sim q}\left[f(X) \cdot \delta_{p,q}(X) \right] \end{aligned} where $$\mathcal{A}_q$$ is the Langevin Stein operator built from $$q$$ instead of $$p$$ (its expectation under $$q$$ vanishes by Equation 1 applied to $$q$$), and $$\delta_{p,q}(x):= \nabla_x \log p(x) -\nabla_x \log q(x)$$ is the difference in score function between $$p$$ and $$q$$.

By choosing an $$f$$ from some sufficiently rich $$\mathcal{F}$$, we can make this non-zero unless $$p=q$$ a.e. So this quantity tells us something about how distinct two densities $$p$$ and $$q$$ are, in this slightly weird but credible-seeming sense where we care about the difference in their score functions. That is, this is some kind of score matching method.
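Here is a quick numerical sanity check of this derivation, as a sketch. Keep $$p$$ as the correlated Gaussian from the example. For illustration, pick $$q=\mathcal{N}(\mu, I)$$ with $$\mu=(1,-0.5)$$ and the bounded vector test function $$f(x)=(\tanh x_1, \tanh x_2)$$; both are arbitrary choices. Both ends of the derivation are estimated by Monte Carlo from the same samples of $$q$$. They should agree up to Monte Carlo error, and be clearly non-zero.

```python
import numpy as np

rho = 0.5
Sigma_inv = np.linalg.inv(np.array([[1.0, rho], [rho, 1.0]]))
mu = np.array([1.0, -0.5])

def score_p(x):
    # grad log p(x) = -Sigma^{-1} x for the zero-mean correlated Gaussian (Sigma_inv is symmetric)
    return -x @ Sigma_inv

def score_q(x):
    # grad log q(x) for q = N(mu, I)
    return -(x - mu)

def f(x):
    return np.tanh(x)  # vector test function, applied coordinate-wise

def div_f(x):
    return (1.0 - np.tanh(x) ** 2).sum(axis=-1)

rng = np.random.default_rng(1)
xq = rng.normal(size=(200_000, 2)) + mu  # samples from q

# Left-hand side: E_q[A_p f] using the Langevin Stein operator of p
lhs = np.mean((f(xq) * score_p(xq)).sum(axis=-1) + div_f(xq))
# Right-hand side: E_q[f . delta_{p,q}] with delta_{p,q} = score_p - score_q
rhs = np.mean((f(xq) * (score_p(xq) - score_q(xq))).sum(axis=-1))
print(lhs, rhs)
```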

This looks neat. How can we calculate it in practice? Obstacle: we have not specified $$f$$. We could fix some $$f$$ and use it to measure how different $$p$$ and $$q$$ are in some sense. Or we could choose some stochastic process which generates random $$f$$s and estimate it over many $$f$$s, I guess? I assume that has been done.

The Stein discrepancy takes a strong approach to controlling those $$f$$s: we take the supremum of that difference over all $$f$$ in some function class $$\mathcal{F}$$. Once we have found this Stein discrepancy, we know how bad the mismatch between $$p$$ and $$q$$ is for the worst $$f$$, and hence that it is no worse for any other $$f$$ in the class:

$\sqrt{S(q, p)} = \sup_{f \in \mathcal{F}} \left| \mathbb{E}_{x \sim q}[\operatorname{trace}(\mathcal{A}_{x} f(X))] \right|$

Notice we snuck in a trace there as well to make it a scalar? This ended up being the most confusing thing for me; how many dimensions even is anything in this equation?

Let us consider how we might find this ‘worst’ $$f$$ which gives us this most powerful guarantee of the difference between $$p$$ and $$q$$. There are a few steps.

First, we use the linearity of that Stein operator $$\mathcal{A}_{x}$$, mentioned earlier. Suppose that $$f$$ can be represented as a finite linear combination $$f(x)=\sum_i w_i f_i(x)$$ of a set of basis functions $$f_i(x)$$ for some coefficients $$w_i$$ s.t. $$\|w\| \leq 1$$. Then we can define the ‘violation of Stein-ness’ by $\mathbb{E}_q\left[\mathcal{A}_{x} f\right]=\mathbb{E}_q\left[\mathcal{A}_{x} \sum_i w_i f_i(x)\right]=\sum_i w_i \beta_i,$ where $\beta_i=\mathbb{E}_{x \sim q}\left[\mathcal{A}_{x} f_i(x)\right] .$

This only works for univariate densities, so far. To make the discrepancy a scalar even for multivariate problems (in the sense of densities over multidimensional spaces), we define the violation as $\mathbb{E}_{X\sim q}\left[\operatorname{trace}\left(\mathcal{A}_{X} \boldsymbol{f}(X)\right)\right]=\mathbb{E}_{X\sim q}\left[\left(\boldsymbol{s}_p(X)-\boldsymbol{s}_q(X)\right)^{\top} \boldsymbol{f}(X)\right]$ where $$\boldsymbol{s}_p(x):=\nabla_x\log p(x)$$ and $$\boldsymbol{s}_q(x):=\nabla_x\log q(x)$$ are the score functions of $$p$$ and $$q$$.

The optimal (i.e. greatest _mis_match) coefficients $$w_i$$ would then be $\max _w \sum_i w_i \beta_i, \quad \text { s.t. } \quad\|w\| \leq 1$

OK, so this is notionally an optimisation problem we can solve, choosing the $$w_i$$ values to be as terrible as possible, and then seeing how bad the most-terrible values are.
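For a finite basis this maximisation has a closed form by Cauchy–Schwarz: the maximum of $$\sum_i w_i\beta_i$$ over $$\|w\|\le 1$$ is $$\|\beta\|$$, attained at $$w=\beta/\|\beta\|$$. The sketch below is one-dimensional, under assumptions of my own choosing:

- target $$p=\mathcal{N}(0,1)$$, with score $$-x$$;
- approximation $$q=\mathcal{N}(m,1)$$;
- polynomial basis $$f_i(x)=x^i$$ for $$i=0,1,2$$.

Then $$\mathcal{A}_x f_i(x) = -x f_i(x) + f_i'(x)$$. Working the Gaussian moments by hand gives $$\beta=(-m,\,-m^2,\,-m^3-m)$$, which the Monte Carlo estimate should reproduce.

```python
import numpy as np

def stein_basis_1d(x):
    """Langevin Stein operator of p = N(0, 1) (score -x) applied to f_i(x) = x**i, i = 0, 1, 2.
    Returns A f_i(x) = -x f_i(x) + f_i'(x) with shape (m, 3)."""
    fs = np.stack([np.ones_like(x), x, x**2], axis=-1)
    dfs = np.stack([np.zeros_like(x), np.ones_like(x), 2 * x], axis=-1)
    return -x[:, None] * fs + dfs

m = 0.5
rng = np.random.default_rng(2)
xq = rng.normal(loc=m, scale=1.0, size=500_000)  # samples from q = N(m, 1)

beta_mc = stein_basis_1d(xq).mean(axis=0)        # beta_i = E_q[A f_i]
beta_exact = np.array([-m, -m**2, -m**3 - m])

w_star = beta_mc / np.linalg.norm(beta_mc)       # worst-case unit-norm coefficients
discrepancy = w_star @ beta_mc                   # equals ||beta_mc||
print(beta_mc, beta_exact, discrepancy)
```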

The distances arising from these are apparently integral probability metrics according to .

However, it looks like a nested optimisation problem, which can be tedious. Can we do better?

## 3 Kernelized Stein Discrepancy

When we see a challenge of this kind, where we wish we could use ‘more tricks’ in our function space, it typically suggests that the trick we are looking for might be the kernel trick. This entails choosing the tricky function class $$\mathcal{F}$$ to be a reproducing kernel Hilbert space (“RKHS”) and seeing what that does to the problem, which we call kernelizing. Frequently that makes life easier. Spoiler: it helps here too.

So, kernelized Stein discrepancy works as follows. $$\mathcal{H}$$ is the RKHS with associated kernel $$k:\mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}$$, which we require to be a positive definite kernel. The RKHS $$\mathcal{H}$$ with kernel $$k$$ includes functions of the form $$f(x)=\sum_i w_i k\left(x, x_i\right)$$, equipped with the RKHS inner product $$\langle f, g\rangle_{\mathcal{H}}=\sum_{i j} w_i v_j k\left(x_i, x_j\right)$$ for $$g=\sum_j v_j k\left(x, x_j\right)$$ and the RKHS norm $$\|f\|_{\mathcal{H}}^2=\sum_{i j} w_i w_j k\left(x_i, x_j\right)$$.

Now we have some extra structure: \begin{aligned} f(x)&=\langle f(\cdot), k(x, \cdot)\rangle_{\mathcal{H}} && \text{reproducing property}\\ \nabla_x f(x)&=\left\langle f(\cdot), \nabla_x k(x, \cdot)\right\rangle_{\mathcal{H}} && \text{gradient property} \end{aligned}
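To make the RKHS bookkeeping concrete, here is a small sketch. I choose an RBF kernel $$k(x,x')=\exp(-\|x-x'\|^2/2h^2)$$ with an arbitrary bandwidth $$h$$, random centres $$x_i$$, and random coefficients $$w_i, v_j$$. Inner products and norms are quadratic forms in the Gram matrix $$K_{ij}=k(x_i,x_j)$$. The gradient property can be checked against central finite differences of $$f$$.

```python
import numpy as np

h = 0.8  # RBF bandwidth (arbitrary)

def k(x, y):
    """RBF kernel exp(-||x - y||^2 / (2 h^2)), broadcasting over leading dimensions."""
    return np.exp(-np.sum((x - y) ** 2, axis=-1) / (2 * h**2))

def grad_x_k(x, y):
    """Gradient of k(x, y) with respect to its first argument."""
    return -(x - y) / h**2 * k(x, y)[..., None]

rng = np.random.default_rng(3)
centres = rng.normal(size=(5, 2))  # the points x_i
w = rng.normal(size=5)             # f = sum_i w_i k(., x_i)
v = rng.normal(size=5)             # g = sum_j v_j k(., x_j)

def f(x):
    return k(x[None, :], centres) @ w

K = k(centres[:, None, :], centres[None, :, :])  # Gram matrix K_ij = k(x_i, x_j)
inner_fg = w @ K @ v                             # <f, g>_H
norm_f = np.sqrt(w @ K @ w)                      # ||f||_H
norm_g = np.sqrt(v @ K @ v)                      # ||g||_H

# Gradient property: grad f(x0) = <f, grad_x k(x0, .)>_H = sum_i w_i grad_x k(x0, x_i)
x0 = np.array([0.3, -0.2])
grad_rkhs = w @ grad_x_k(x0[None, :], centres)
eps = 1e-6
grad_fd = np.array([(f(x0 + eps * e) - f(x0 - eps * e)) / (2 * eps) for e in np.eye(2)])
print(inner_fg, norm_f, grad_rkhs, grad_fd)
```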

The kernel $$k$$ is one of two kernels that we invoke; in fact it gets “Steinalized” using our operator $$\mathcal{A}_{x}$$ to give us another kernel:

\begin{aligned} k_{\mathcal{A}}(x, x') &:=\operatorname{trace}[\mathcal{A}_{x}\mathcal{A}_{x'} k(x, x')]. \end{aligned}

If we plug in Equation 2, we get a special form,

\begin{aligned} k_{\mathcal{A}}(x, x') =&\nabla_{x}\cdot \nabla_{x'}k(x,x')\\ &+ \nabla_{x}k(x,x')\cdot \nabla_{x'}\log p(x')\\ &+ \nabla_{x'}k(x,x')\cdot \nabla_{x}\log p(x)\\ &+ k(x,x')(\nabla_{x}\log p(x))\cdot (\nabla_{x'}\log p(x')) \end{aligned} Woof! Look at those score functions everywhere!

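Moreover, write $$\mathcal{A}^{i}_{x} k(x,\cdot) := k(x,\cdot)\,\partial_{x_i}\log p(x) + \partial_{x_i} k(x,\cdot)$$ for the $$i$$th component of the Stein operator applied to the kernel. The reproducing and gradient properties then give \begin{aligned} \mathbb{E}_{x \sim q}\left[\operatorname{trace}\mathcal{A}_{x} \boldsymbol{f}(x)\right] &=\sum_{i=1}^{d} \left\langle f_i(\cdot), \mathbb{E}_{x \sim q}\left[\mathcal{A}^{i}_{x} k(x, \cdot)\right]\right\rangle_{\mathcal{H}}\\ &=\left\langle \boldsymbol{f}(\cdot), \mathbb{E}_{x \sim q}\left[\mathcal{A}_{x} k(x, \cdot)\right]\right\rangle_{\mathcal{H}^d} \end{aligned}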

We take $$\mathcal{F}$$ to be the unit ball in that RKHS, i.e. $$\mathcal{F}:=\{\boldsymbol{f};\|\boldsymbol{f}\|_{\mathcal{H}^d} \leq 1 \}$$.

Then we can write

$\sqrt{S(q, p)}=\sup _{f \in {\color{red}\mathcal{H}^d},\|f\|_{\color{red}\mathcal{H}^d} \leq 1}\left\{\mathbb{E}_{x \sim q}\left[\operatorname{trace}\mathcal{A}_{x} f(x)\right]\right\} .$ i.e. it is just the same, but we have restricted the function class to be an RKHS.

Define $\beta_{q, p}(\cdot)=\mathbb{E}_{x \sim q}\left[\mathcal{A}_{x} k(x, \cdot)\right].$

Finding that supremum is then equivalent to solving $\sup _f\left\langle f, \beta_{q, p}\right\rangle_{\mathcal{H}^d}, \quad \text { s.t. }\|f\|_{\mathcal{H}^d} \leq 1 .$

From this we get the maximiser $$f^*=\phi_{q, p}^* /\left\|\phi_{q, p}^*\right\|_{\mathcal{H}^d}$$, where $\phi_{q, p}^*(\cdot)=\mathbb{E}_{x \sim q}\left[\mathcal{A}_{x} k(x, \cdot)\right], \quad \text { for which we have } \quad S(q, p)=\left\|\phi_{q, p}^*\right\|_{\mathcal{H}^d}^2.$

We maximise this, I assert, if we set $$f=\beta_{q, p} /\left\|\beta_{q, p}\right\|_{\mathcal{H}^d}$$. This normalises $$f$$ to lie on the boundary of the unit ball, at the point that maximises the expectation. It cannot lie in the interior, because the objective is linear in $$f$$: scaling any $$f$$ with a positive objective out towards the boundary increases it. Thus \begin{aligned} S(q, p) &=\left\|\beta_{q, p}\right\|_{\mathcal{H}^d}^2\\ &=\mathbb{E}_{x, x' \sim q}\left[\kappa_p\left(x, x'\right)\right] \end{aligned} where $\kappa_p\left(x, x'\right):=\operatorname{trace}\left[\mathcal{A}_{x} \mathcal{A}_{x'} k\left(x, x'\right)\right],$ with $$\mathcal{A}_{x}$$ and $$\mathcal{A}_{x'}$$ denoting the Stein operator acting on the variables $$x$$ and $$x'$$, respectively. $$\kappa_p\left(x, x'\right)$$ is the “Steinalized” kernel obtained by applying the Stein operator to $$k\left(x, x'\right)$$ twice, i.e. the same object as $$k_{\mathcal{A}}$$ above.

Equivalently, $S(q, p)=\mathbb{E}_{x, x' \sim q}\left[\boldsymbol{\delta}_{p, q}(x)^{\top} k\left(x, x'\right) \boldsymbol{\delta}_{p, q}\left(x'\right)\right],$ where $$\boldsymbol{\delta}_{p, q}(x)=\nabla_x\log p(x)-\nabla_x\log q(x)$$ is the score difference between $$p$$ and $$q$$, and $$x, x'$$ are i.i.d. draws from $$q(x)$$.

It is a mess to write out in full though.
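Messy, but mechanical. Here is a sketch of the Langevin Stein kernel for an RBF base kernel $$k(x,x')=\exp(-\|x-x'\|^2/2h^2)$$ and the correlated Gaussian target. The kernel derivatives are derived by hand:

- $$\nabla_x k = -(x-x')\,k/h^2$$;
- $$\nabla_x\cdot\nabla_{x'}k = k\,(d/h^2 - \|x-x'\|^2/h^4)$$.

It also computes an unbiased U-statistic estimate of $$S(q,p)=\mathbb{E}_{x,x'\sim q}[\kappa_p(x,x')]$$, which drops the diagonal $$i=j$$ terms. As a check, I take $$q$$ to be $$p$$ shifted by $$\mu$$. Then the score difference is the constant $$-\Sigma^{-1}\mu$$, and the score-difference form gives $$S(q,p)=\|\Sigma^{-1}\mu\|^2\,\det(I+2\Sigma/h^2)^{-1/2}$$. The values of $$\rho$$, $$h$$ and $$\mu$$, and the sample size, are arbitrary.

```python
import numpy as np

rho, h = 0.3, 1.0
Sigma = np.array([[1.0, rho], [rho, 1.0]])
Sigma_inv = np.linalg.inv(Sigma)

def score_p(x):
    """grad log p(x) = -Sigma^{-1} x for the zero-mean correlated Gaussian; x is (n, d)."""
    return -x @ Sigma_inv  # Sigma_inv is symmetric

def stein_kernel_matrix(x):
    """kappa_p(x_i, x_j) for all pairs with an RBF base kernel; x has shape (n, d)."""
    n, d = x.shape
    s = score_p(x)                                  # (n, d)
    diff = x[:, None, :] - x[None, :, :]            # x_i - x_j, (n, n, d)
    sqdist = np.sum(diff**2, axis=-1)               # (n, n)
    k = np.exp(-sqdist / (2 * h**2))
    grad_x_k = -diff / h**2 * k[..., None]          # gradient w.r.t. x_i
    grad_xp_k = -grad_x_k                           # gradient w.r.t. x_j
    div_div = k * (d / h**2 - sqdist / h**4)        # nabla_x . nabla_x' k
    term2 = np.einsum("ijd,jd->ij", grad_x_k, s)    # grad_x k . s_p(x_j)
    term3 = np.einsum("ijd,id->ij", grad_xp_k, s)   # grad_x' k . s_p(x_i)
    term4 = k * (s @ s.T)                           # k s_p(x_i) . s_p(x_j)
    return div_div + term2 + term3 + term4

def ksd_u(x):
    """Unbiased U-statistic estimate of S(q, p) from samples x ~ q."""
    n = x.shape[0]
    U = stein_kernel_matrix(x)
    return (U.sum() - np.trace(U)) / (n * (n - 1))

rng = np.random.default_rng(4)
n = 1000
x_p = rng.normal(size=(n, 2)) @ np.linalg.cholesky(Sigma).T  # samples from p
mu = np.array([1.0, 1.0])
x_q = x_p + mu                                               # samples from q = p shifted by mu

c = Sigma_inv @ mu
ksd_exact = (c @ c) / np.sqrt(np.linalg.det(np.eye(2) + 2 * Sigma / h**2))
print(ksd_u(x_p), ksd_u(x_q), ksd_exact)
```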

## 4 Stein Variational Gradient Descent

The next bit comes from Q. Liu and Wang (2019). It turns out that we can use this Stein trick to sample from some interesting distributions, by using the Stein discrepancy as a loss function. Interestingly, this works on posterior distributions in particular. Only the score $$\nabla_x \log p(x)$$ enters, and that does not depend on the normalising constant of $$p$$.

We manufacture an empirical $$q$$ by using a set of particles $$\{x_i\}_{i=1}^n$$.

The gradient descent here is not SGD, where we take gradient steps by looking at examples. It is rather a gradient descent on the particle positions themselves, which moves the particles towards a good approximation of the posterior.
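Concretely, the step uses the optimal direction $$\phi^*_{q,p}(\cdot)=\mathbb{E}_{x\sim q}[\mathcal{A}_{x}k(x,\cdot)]$$ from the kernelized Stein discrepancy section. It lets $$q$$ be the empirical distribution of the current particles $$\{x_i\}_{i=1}^n$$, and nudges every particle a small step $$\epsilon$$ along that direction. With the Langevin Stein operator, this is, as I understand the standard SVGD update (the step size $$\epsilon$$ and the kernel $$k$$ are free choices):

\begin{aligned} \hat\phi^*(x) &= \frac{1}{n}\sum_{i=1}^{n}\Big[k(x_i, x)\,\nabla_{x_i}\log p(x_i) + \nabla_{x_i}k(x_i, x)\Big],\\ x_j &\leftarrow x_j + \epsilon\,\hat\phi^*(x_j), \qquad j=1,\dots,n. \end{aligned}

The first term moves particles up a kernel-smoothed score of $$p$$. The second acts as a repulsive force, which keeps the particles from all collapsing onto the mode.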

A worked example will sort this out.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import grad, jit, vmap
import plotly.graph_objects as go
import plotly.io as pio

pio.templates.default = "none"

# Define the target distribution (zero-mean bivariate Gaussian, unit variances, correlation rho)
def log_p(x, rho):
    return -0.5 * (
        x[..., 0]**2 + x[..., 1]**2
        - 2 * rho * x[..., 0] * x[..., 1]
    ) / (
        1 - rho**2
    ) - jnp.log(
        2 * jnp.pi * jnp.sqrt(1 - rho**2)
    )

# Define the RBF kernel
def rbf_kernel(x, y, h):
    return jnp.exp(-jnp.sum((x - y)**2) / (2 * h**2))

# Compute the Stein kernel k_A(x, x') for the Langevin Stein operator (used below for monitoring)
@jit
def stein_kernel(x, x_prime, h, rho):
    k = rbf_kernel(x, x_prime, h)
    s_x = grad(log_p)(x, rho)          # score at x
    s_xp = grad(log_p)(x_prime, rho)   # score at x'
    grad_x_k = grad(rbf_kernel, argnums=0)(x, x_prime, h)
    grad_xp_k = grad(rbf_kernel, argnums=1)(x, x_prime, h)
    # nabla_x . nabla_x' k, i.e. the trace of the mixed second derivative
    term1 = jnp.trace(jax.jacfwd(grad(rbf_kernel, argnums=0), argnums=1)(x, x_prime, h))
    term2 = grad_x_k @ s_xp
    term3 = grad_xp_k @ s_x
    term4 = k * (s_x @ s_xp)
    return term1 + term2 + term3 + term4

# Kernelized Stein discrepancy (V-statistic) of a set of particles
def ksd(particles, h, rho):
    pairs = vmap(lambda x_i: vmap(lambda x_j: stein_kernel(x_i, x_j, h, rho))(particles))(particles)
    return jnp.mean(pairs)

# Define the SVGD update:
# phi(x_j) = mean_i [k(x_i, x_j) grad log p(x_i) + grad_{x_i} k(x_i, x_j)]
@jit
def svgd_update(particles, h, rho, lr=0.05):
    n_particles = particles.shape[0]
    scores = vmap(grad(log_p), in_axes=(0, None))(particles, rho)  # (n, d)
    K = vmap(lambda x_i: vmap(lambda x_j: rbf_kernel(x_i, x_j, h))(particles))(particles)  # (n, n)
    grad_K = vmap(lambda x_i: vmap(lambda x_j: grad(rbf_kernel)(x_i, x_j, h))(particles))(particles)  # (n, n, d)
    # Attractive term (kernel-smoothed score) plus repulsive term (kernel gradient)
    phi = (K.T @ scores + grad_K.sum(axis=0)) / n_particles
    return particles + lr * phi

# Initialize particles
key = jax.random.PRNGKey(0)
n_particles = 20
particles = jax.random.normal(key, (n_particles, 2))  # 2D particles

# Set kernel bandwidth comparable to the inter-particle spacing,
# so that the repulsive term does not vanish
h = 0.7
rho = 0.8  # Correlation parameter

# Set the number of iterations
n_iterations = 1000

# Store particle locations at different stages
initial_particles = np.asarray(particles)
mid_particles = None
final_particles = None

# Run SVGD
for i in range(n_iterations):
    particles = svgd_update(particles, h, rho)

    # Detect NaN values
    if jnp.isnan(particles).any():
        num_nan_particles = int(jnp.isnan(particles).any(axis=1).sum())
        raise ValueError(f"Detected {num_nan_particles} NaN values at iteration {i + 1}. Diagnostic Info: particles shape {particles.shape}, step number {i + 1}")

    if i == n_iterations // 2:
        mid_particles = np.asarray(particles)

final_particles = np.asarray(particles)
print("KSD, initial particles:", float(ksd(jnp.asarray(initial_particles), h, rho)))
print("KSD, final particles:", float(ksd(particles, h, rho)))

# Create a mesh grid for plotting the density
x = np.linspace(-3, 3, 100)
y = np.linspace(-3, 3, 100)
X, Y = np.meshgrid(x, y)
XY = np.stack([X.ravel(), Y.ravel()], axis=-1)

# Compute the density on the mesh grid, for the same rho that the particles target
Z = np.exp(np.asarray(log_p(jnp.asarray(XY), rho))).reshape(X.shape)

fig = go.Figure()

fig.add_trace(go.Contour(
    x=x,
    y=y,
    z=Z,
    colorscale='Viridis',
    opacity=0.5
))

# Add scatter plots of the initial, middle and final particles
for pts, colour, label in [
    (initial_particles, 'red', 'Initial Particles'),
    (mid_particles, 'green', 'Middle Particles'),
    (final_particles, 'blue', 'Final Particles'),
]:
    fig.add_trace(go.Scatter(
        x=pts[:, 0],
        y=pts[:, 1],
        mode='markers',
        marker=dict(size=5, color=colour, opacity=0.8),
        name=label
    ))

fig.update_layout(
    title='SVGD Particles at Different Stages with Target Density',
    xaxis_title='x1',
    yaxis_title='x2',
    font=dict(family="Alegreya, serif"),
    template=pio.templates.default,
    legend=dict(
        x=0,
        y=1,
        traceorder='normal',
        font=dict(size=12),
        bgcolor='rgba(255, 255, 255, 0.5)',
        bordercolor='Black',
        borderwidth=1
    )
)

fig.show()
```

## 5 For mixtures

Mixtures in general are helpful in variational inference. See ELBO-within-Stein, Nonlinear Stein, Stein Mixtures.

## 7 As moment matching

See Q. Liu and Wang (2018).

## 8 By message passing

Define a kernel over factors and now the Stein messages may be passed locally. Discovered simultaneously in 2018 by D. Wang, Zeng, and Liu (2018) and Zhuo et al. (2018).

To read: Zhou and Qiu (2023).

## 10 References

Abbasi-Yadkori, Pacchiano, and Phan. 2020. arXiv.org.
Alsup, Venturi, and Peherstorfer. 2022. In Proceedings of the 2nd Mathematical and Scientific Machine Learning Conference.
Ambrogioni, Güçlü, Güçlütürk, et al. 2018. In Proceedings of the 32nd International Conference on Neural Information Processing Systems. NIPS’18.
Anastasiou, Barp, Briol, et al. 2023. Statistical Science.
Chakraborty, Bedi, Koppel, et al. 2023. In Proceedings of the 40th International Conference on Machine Learning.
Chen, and Ghattas. 2020. In Proceedings of the 34th International Conference on Neural Information Processing Systems. NIPS ’20.
Chu, Minami, and Fukumizu. 2022. In.
Chwialkowski, Strathmann, and Gretton. 2016. In Proceedings of the 33rd International Conference on International Conference on Machine Learning - Volume 48. ICML’16.
Detommaso, Cui, Spantini, et al. 2018. In Proceedings of the 32nd International Conference on Neural Information Processing Systems. NIPS’18.
Detommaso, Hoitzing, Cui, et al. 2019. arXiv:1901.07987 [Cs, Stat].
Feng, Wang, and Liu. 2017. In UAI 2017.
Gong, Peng, and Liu. 2019. In Proceedings of the 36th International Conference on Machine Learning.
Gorham, and Mackey. 2015. In Proceedings of the 28th International Conference on Neural Information Processing Systems - Volume 1. NIPS’15.
———. 2017. In Proceedings of the 34th International Conference on Machine Learning.
Gorham, Raj, and Mackey. 2020. arXiv:2007.02857 [Cs, Math, Stat].
Han, Ding, Liu, et al. 2020. In Proceedings of the Twenty Third International Conference on Artificial Intelligence and Statistics.
Han, and Liu. 2018. In Proceedings of the 35th International Conference on Machine Learning.
Huggins, Campbell, Kasprzak, et al. 2018. arXiv:1806.10234 [Cs, Stat].
Ley, Reinert, and Swan. 2017. Probability Surveys.
Li, Li, Liu, et al. 2020. Communications in Applied Mathematics and Computational Science.
Liu, Qiang. 2016a.
———. 2017.
Liu, Qiang, Lee, and Jordan. 2016. In Proceedings of The 33rd International Conference on Machine Learning.
Liu, Qiang, and Wang. 2018. In Proceedings of the 32nd International Conference on Neural Information Processing Systems. NIPS’18.
———. 2019. In Advances In Neural Information Processing Systems.
Liu, Chang, and Zhu. 2018. Proceedings of the AAAI Conference on Artificial Intelligence.
Liu, Chang, Zhuo, Cheng, et al. 2019. In Proceedings of the 36th International Conference on Machine Learning.
Liu, Xing, Zhu, Ton, et al. 2022. In Proceedings of The 25th International Conference on Artificial Intelligence and Statistics.
Markatou, Karlis, and Ding. 2021. Annual Review of Statistics and Its Application.
Matsubara, Knoblauch, Briol, et al. 2022. Journal of the Royal Statistical Society Series B: Statistical Methodology.
Nalisnick, and Smyth. 2017. In NIPS2017 (Workshop).
Oates, Girolami, and Chopin. 2017. Journal of the Royal Statistical Society Series B: Statistical Methodology.
Pielok, Bischl, and Rügamer. 2023. In.
Pulido, and van Leeuwen. 2019. Journal of Computational Physics.
Pulido, Van Leeuwen, and Posselt. 2019. In Computational Science – ICCS 2019. ICCS 2019. Lecture Notes in Computer Science.
Ranganath, Tran, and Blei. 2016. In PMLR.
Rønning. 2023. “A Probabilistic Approach to the Protein Folding Problem Using Stein-Based Variational Inference.”
Rønning, Al-Sibahi, Ley, et al. 2021.
Stordal, Moraes, Raanes, et al. 2021. Mathematical Geosciences.
Tamang, Ebtehaj, van Leeuwen, et al. 2021. Nonlinear Processes in Geophysics.
Wang, Dilin, and Liu. 2019. In Proceedings of the 36th International Conference on Machine Learning.
Wang, Ziyu, Ren, Zhu, et al. 2018. In.
Wang, Dilin, Tang, Bajaj, et al. 2019. In Proceedings of the 33rd International Conference on Neural Information Processing Systems.
Wang, Dilin, Zeng, and Liu. 2018.
Wen, and Li. 2022. Statistics and Computing.
Xu, and Matsuda. 2021. arXiv:2103.00895 [Stat].
Yang, Liu, Rao, et al. 2018. In Proceedings of the 35th International Conference on Machine Learning.
Zhang, Zhang, Carin, et al. 2020. In International Conference on Artificial Intelligence and Statistics.
Zhao, Wang, Zhu, et al. 2023. Information Sciences.
Zhou, and Qiu. 2023.
Zhuo, Liu, Shi, et al. 2018. In Proceedings of the 35th International Conference on Machine Learning.