## Code

```
import jax.numpy as jnp
import numpy as np
from jax import grad, vmap
import plotly.graph_objects as go
import plotly.io as pio
# Use plotly's empty "none" template so no theme defaults are applied.
pio.templates.default = "none"
# Plot params: number of grid points per axis (the functions are evaluated on
# an n x n mesh built with np.linspace/np.meshgrid below).
n = 61
# Define the scalar function f
def f(x, a, b, c):
    """Evaluate the affine test function f(x) = a * x_1 + b * x_2 + c.

    Args:
        x: array whose last axis holds the two coordinates (x_1, x_2); any
            leading axes are treated as batch axes and preserved.
        a: coefficient of x_1.
        b: coefficient of x_2.
        c: constant offset.

    Returns:
        Array of shape ``x.shape[:-1]``.
    """
    first_coord, second_coord = x[..., 0], x[..., 1]
    return a * first_coord + b * second_coord + c
# Define the log-density of the generic probability density function p
def log_p(x, rho):
    """Return the log-density of a zero-mean, unit-variance bivariate normal.

    The target is N(0, [[1, rho], [rho, 1]]), whose log-density is

        log p(x) = -(x_1^2 + x_2^2 - 2 rho x_1 x_2) / (2 (1 - rho^2))
                   - log(2 pi sqrt(1 - rho^2)).

    Args:
        x: array whose last axis holds (x_1, x_2); leading axes are batch axes.
        rho: correlation coefficient; the density is only finite for |rho| < 1.

    Returns:
        JAX array of shape ``x.shape[:-1]``.
    """
    u = x[..., 0]
    v = x[..., 1]
    one_minus_rho_sq = 1 - rho**2
    quadratic_form = u**2 + v**2 - 2 * rho * u * v
    log_normaliser = jnp.log(2 * jnp.pi * jnp.sqrt(one_minus_rho_sq))
    return -0.5 * quadratic_form / one_minus_rho_sq - log_normaliser
# Define the Stein operator applied to some f and p
def A_x_f(x, f, log_p):
    """Apply the summed Langevin Stein operator to a scalar function ``f``.

    For every point x in the batch this returns

        sum_i [ d/dx_i log p(x) * f(x) + d/dx_i f(x) ],

    i.e. the Langevin Stein operator applied to the vector field
    g(x) = f(x) * (1, ..., 1). By the Stein identity its expectation under p
    is zero whenever p * f decays suitably at infinity.

    Args:
        x: array of shape (N, d), a batch of N points.
        f: scalar-valued function of a single point of shape (d,).
        log_p: scalar-valued log-density of a single point of shape (d,).

    Returns:
        JAX array of shape (N,).
    """
    grad_log_p = vmap(grad(log_p))(x)  # (N, d)
    grad_f = vmap(grad(f))(x)  # (N, d)
    # Evaluate f point-by-point, exactly like its gradient above, so that f
    # only has to accept a single (d,) point rather than the whole batch.
    f_x = vmap(f)(x)  # (N,)
    return grad_log_p.sum(axis=-1) * f_x + grad_f.sum(axis=-1)
# Correlation of the bivariate-normal target density p.
rho = 0.3
# Coefficients of the affine test function f(x) = a*x_1 + b*x_2 + c.
a, b, c = 0.4, 0.25, -0.5
# The plotting window is the square [-3, 3] x [-3, 3].
x1min = x2min = -3
x1max = x2max = 3


def f_specific(x):
    """Evaluate f with the module-level coefficients a, b, c."""
    return f(x, a, b, c)


def log_p_specific(x):
    """Evaluate log p with the module-level correlation rho."""
    return log_p(x, rho)
# Evaluate everything on an n x n grid covering the plotting window.
x1, x2 = np.meshgrid(
    np.linspace(x1min, x1max, n, endpoint=True),
    np.linspace(x2min, x2max, n, endpoint=True),
)
# Flatten the grid into a batch of (x1, x2) points of shape (n*n, 2).
x = np.stack([x1, x2], axis=-1).reshape(-1, 2)
# The JAX-based functions return jax.Array objects. Convert them to NumPy up
# front so every grid below is a plain ndarray: NumPy reductions then return
# NumPy scalars (instead of dispatching back to jax) and plotly receives
# ordinary arrays and numbers.
# Test function f and density p on the grid.
f_x = np.asarray(f_specific(x)).reshape(x1.shape)
p_x = np.asarray(np.exp(log_p_specific(x))).reshape(x1.shape)
# Stein operator A_x f on the grid, and its density-weighted form p * A_x f,
# whose integral over the whole plane is zero by the Stein identity.
A_x_f_x = np.asarray(A_x_f(x, f_specific, log_p_specific)).reshape(x1.shape)
p_A_x_f_x = A_x_f_x * p_x
# z-axis limits for the figure: the range of p * A_x f plus a fixed 0.1 margin.
z_min = float(p_A_x_f_x.min()) - 0.1
z_max = float(p_A_x_f_x.max()) + 0.1
# Surface of p * A_x f, coloured by the density p rather than by height.
stein_surface = go.Surface(
    x=x1,
    y=x2,
    z=p_A_x_f_x,
    surfacecolor=p_x,
    colorscale='Viridis',
    showscale=False,  # no colour bar
    opacity=0.9,  # slightly see-through
    name='<i>p A<sub>x</sub> f</i>',
)
# Surface of f itself: more transparent, a different colour scale, and white
# z-contour lines every tenth of f's range on the grid.
f_lo, f_hi = np.min(f_x), np.max(f_x)
f_surface = go.Surface(
    x=x1,
    y=x2,
    z=f_x,
    colorscale='Cividis',
    showscale=False,
    opacity=0.5,  # semi-transparent so the other surface shows through
    name='f',
    contours=dict(
        z=dict(
            show=True,
            start=f_lo,
            end=f_hi,
            size=(f_hi - f_lo) / 10,
            color="white",
        )
    ),
)
# Assemble the 3D figure: p * A_x f first, then f (same trace order as before).
fig = go.Figure(data=[stein_surface, f_surface])
# 3D scene settings: axis titles, a z-range fitted to p * A_x f (plus margin),
# and a low camera eye so the view sits close to the z = 0 plane.
# NOTE(review): the z-range is derived from p * A_x f only, so the f surface is
# clipped to it; presumably intentional to zoom in near z = 0. Confirm.
scene_settings = dict(
    xaxis=dict(title='x<sub>1</sub>'),
    yaxis=dict(title='x<sub>2</sub>'),
    zaxis=dict(range=[z_min, z_max]),
    camera=dict(eye=dict(x=1.25, y=-1.25, z=0.5)),
)
# No legend is configured: the author found legends don't work on these 3D
# surfaces, so the legend settings were dropped.
fig.update_layout(
    title='<i>p A<sub>x</sub> f</i> and <i>f</i>',
    scene=scene_settings,
    width=800,
    height=800,
    font=dict(family="Alegreya, serif"),
    # Fully transparent paper and plot backgrounds.
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(0,0,0,0)',
)
# Render the figure.
fig.show()
```