## Code

```python
import jax.numpy as jnp
import numpy as np
from jax import grad, vmap, jacfwd, jacrev
import plotly.figure_factory as ff
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio
from _plotly_styles import textbook
# Make Plotly's built-in "none" template the default for every figure.
pio.templates.default = "none"

# Shared plotting scales: quiver length multipliers for the Stein-operator
# panel and for the f panel, plus the arrow-head size used by all quivers.
A_scale_factor, f_scale_factor, arrow_scale_factor = 1.5, 0.1, 0.3
# Number of samples per axis; the evaluation grid has n * n points.
n = 25
# Define the generic function f
def f(x, a1, b1, a2, b2):
    """Apply the 2x2 linear map [[a1, b1], [a2, b2]] along the last axis of ``x``.

    Args:
        x: array whose last axis has length 2 (a single point of shape (2,)
           or any batch of points of shape (..., 2)).
        a1, b1: coefficients of the first output component.
        a2, b2: coefficients of the second output component.

    Returns:
        Array with the same leading shape as ``x`` and a last axis of
        length 2: ``(a1*x0 + b1*x1, a2*x0 + b2*x1)``.
    """
    first, second = x[..., 0], x[..., 1]
    out_first = a1 * first + b1 * second
    out_second = a2 * first + b2 * second
    return jnp.stack((out_first, out_second), axis=-1)
# Define the log-density of the generic probability density function p
def log_p(x, rho):
    """Log-density of a zero-mean bivariate Gaussian with unit variances.

    Computes

        -0.5 * (x0**2 + x1**2 - 2*rho*x0*x1) / (1 - rho**2)
        - log(2*pi*sqrt(1 - rho**2))

    Args:
        x: array whose last axis has length 2 (one point or a batch).
        rho: correlation between the two coordinates; the formula is only
             finite for |rho| < 1.

    Returns:
        Array of log-density values with the leading shape of ``x``.
    """
    u, v = x[..., 0], x[..., 1]
    one_minus_rho_sq = 1 - rho**2
    quadratic = u**2 + v**2 - 2 * rho * u * v
    log_normaliser = jnp.log(2 * jnp.pi * jnp.sqrt(one_minus_rho_sq))
    return -0.5 * quadratic / one_minus_rho_sq - log_normaliser
# Fix specific values for rho and the parameters of f
rho = 0.5
a1, b1 = 1, 0.25
a2, b2 = -0.75, -1


def f_specific(x):
    """``f`` with the module-level coefficients (a1, b1, a2, b2) plugged in."""
    return f(x, a1, b1, a2, b2)


def log_p_specific(x):
    """``log_p`` with the module-level correlation ``rho`` plugged in."""
    return log_p(x, rho)
# Define the Stein operator applied to a specific f and p
def A_p_f(x, f, log_p):
    """Apply the Stein operator of ``p`` to the vector field ``f`` at a batch of points.

    For every point x_n the result is the matrix

        M_n[i, j] = f_i(x_n) * d/dx_j log p(x_n) + d/dx_j f_i(x_n),

    i.e. the outer product of f with grad log p plus the Jacobian of f.
    Equivalently, M_n[i, j] = d/dx_j (p * f_i)(x_n) / p(x_n) wherever p > 0.

    Args:
        x: array of shape (N, d); the leading axis is mapped over.
        f: callable mapping a single point of shape (d,) to shape (m,).
           It is evaluated point by point (via ``vmap``), so it does not
           need to accept batched input.
        log_p: callable mapping a single point of shape (d,) to a scalar
           log-density; it must be differentiable with ``jax.grad``.

    Returns:
        Array of shape (N, m, d) holding M_n for each point.
    """
    # (N, d): gradient of the log-density at each point.
    grad_log_p = vmap(grad(log_p))(x)
    # (N, m): evaluate f point by point, exactly as its Jacobian is taken
    # below, instead of assuming f also handles a whole batch at once.
    f_x = vmap(f)(x)
    # (N, m, d): per-point Jacobian, jac_f[n, i, j] = d f_i / d x_j.
    jac_f = vmap(jacfwd(f))(x)
    return grad_log_p[:, None, :] * f_x[:, :, None] + jac_f
# Create a grid of points: an n x n lattice over [-3, 3] in both coordinates,
# flattened into an (n*n, 2) batch in row-major order.
_grid_ticks = np.linspace(-3, 3, n, endpoint=True)
x1, x2 = np.meshgrid(_grid_ticks, _grid_ticks)
x = np.column_stack((x1.ravel(), x2.ravel()))

# Vector field f and density p at every lattice point.
f_x = f_specific(x)
p_x = np.exp(log_p_specific(x))

# Stein operator of p applied to f at every point (one matrix per point),
# then weighted pointwise by the density p(x).
A_f_p_x = A_p_f(x, f_specific, log_p_specific)
p_A_f_p_x = A_f_p_x * p_x[:, None, None]
# Two stacked panels: f on top, the density-weighted Stein operator below.
fig = make_subplots(
    rows=2,
    cols=1,
    subplot_titles=('<i>f</i>', '<i>p A<sub>p</sub> f</i>'),
    vertical_spacing=0.1,
    row_heights=[0.5, 0.5],
)
textbook(fig)

# Top panel: quiver of the vector field f on the grid.
fig.add_trace(
    ff.create_quiver(
        x1,
        x2,
        f_x[:, 0].reshape(x1.shape),
        f_x[:, 1].reshape(x2.shape),
        scale=f_scale_factor,
        arrow_scale=arrow_scale_factor,
        name='f',
        line_width=1,
    ).data[0],
    row=1,
    col=1,
)

# Bottom panel: one quiver per column j of p * A_p f, i.e. the vectors
# (M[0, j], M[1, j]) -- j = 0 drawn in blue, j = 1 drawn in red.
fig.add_traces(
    [
        ff.create_quiver(
            x1,
            x2,
            p_A_f_p_x[:, 0, column].reshape(x1.shape),
            p_A_f_p_x[:, 1, column].reshape(x2.shape),
            scale=A_scale_factor,
            arrow_scale=arrow_scale_factor,
            name=label,
            line_width=1,
            line_color=colour,
        ).data[0]
        for column, (label, colour) in enumerate(
            [('Component 1', 'blue'), ('Component 2', 'red')]
        )
    ],
    rows=[2, 2],
    cols=[1, 1],
)
# Set the layout: fixed portrait canvas, transparent backgrounds, no legend.
fig.update_layout(
    width=400,
    height=800,
    font=dict(family="Alegreya, serif"),
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(0,0,0,0)',
    showlegend=False,
)
# Update axis ranges and lock a 1:1 aspect ratio for both subplots.
axis_range = [-3, 3]
for i in range(1, 3):
    # Plotly names the first x-axis 'x' (not 'x1'); later ones are 'x2',
    # 'x3', ...  Passing 'x1' to scaleanchor fails Plotly's validation.
    x_anchor = 'x' if i == 1 else f'x{i}'
    fig.update_xaxes(range=axis_range, row=i, col=1)
    fig.update_yaxes(
        range=axis_range, row=i, col=1,
        scaleanchor=x_anchor, scaleratio=1)
# Show the plot
fig.show()
```