import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
try:
    from livingthing.matplotlib_style import set_livingthing_style
    set_livingthing_style()
except ImportError:
    pass  # optional house style; matplotlib defaults are fine
rng = np.random.default_rng(11)
# Dimensions
d, m, N = 60, 10, 300
# Prior: squared-exponential on a 1D grid
def se_cov(n, ell=12.0, sigma=1.0, jitter=1e-8):
    """Squared-exponential covariance on the integer grid 0..n-1; jitter keeps the Cholesky stable."""
    x = np.arange(n)[:, None]
    S = sigma**2 * np.exp(-0.5 * (x - x.T)**2 / ell**2)
    return S + jitter * np.eye(n)
m0 = np.zeros(d)
Sigma = se_cov(d, ell=12.0, sigma=1.0)
L = np.linalg.cholesky(Sigma)
# Observation operator (m point picks) and noise
idx = np.sort(rng.choice(d, size=m, replace=False))
H = np.zeros((m, d))
H[np.arange(m), idx] = 1.0
rho = 0.15
R = (rho**2) * np.eye(m)
# One truth and its noisy observations
x_true = m0 + L @ rng.standard_normal(d)
y_star = H @ x_true + rho * rng.standard_normal(m)
# Prior ensemble
X = m0[:, None] + L @ rng.standard_normal((d, N))
# Choose EnKF flavour
stochastic = True  # True: perturbed observations; False: un-perturbed updates (mean-only; see note at the end)
if stochastic:
    Eps = rho * rng.standard_normal((m, N))
    Y = H @ X + Eps
    gamma_diag = 1e-9 * np.ones(m)  # tiny ridge only: R already enters the sample covariance via Eps
else:
    Y = H @ X
    gamma_diag = np.diag(R) + 1e-9
# Centered ensembles
Xc = X - X.mean(axis=1, keepdims=True) # d x N
Yc = Y - Y.mean(axis=1, keepdims=True) # m x N
# Woodbury factors
sqrtN1 = np.sqrt(N - 1.0)
U = Yc / sqrtN1 # m x N
Gamma_inv = 1.0 / gamma_diag # length-m
# residuals in obs space
Res = (y_star[:, None] - Y) # m x N
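# Identity behind the N x N route (a worked note, with Γ = diag(gamma_diag)):
#   (Γ + U Uᵀ)⁻¹ = Γ⁻¹ − Γ⁻¹ U (I_N + Uᵀ Γ⁻¹ U)⁻¹ Uᵀ Γ⁻¹        (Woodbury)
# hence
#   Uᵀ (Γ + U Uᵀ)⁻¹ Res = (I_N + A)⁻¹ B,  with A = Uᵀ Γ⁻¹ U and B = Uᵀ Γ⁻¹ Res,
# so ΔX = (Xc / √(N−1)) (I_N + A)⁻¹ B costs an N x N solve instead of an m x m one,
# which pays off when m ≫ N.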
# Build A = U^T Γ^{-1} U (N x N) and B = U^T Γ^{-1} Res (N x N) efficiently
U_over_Gamma = U * Gamma_inv[:, None] # m x N (diag multiply)
A = U.T @ U_over_Gamma # N x N
B = U.T @ (Res * Gamma_inv[:, None]) # N x N
# Solve (I_N + A) Q = B for Q, then ΔX = (1/sqrt(N-1)) Xc @ Q
Q = np.linalg.solve(np.eye(N) + A, B)  # N x N; solve() is cheaper and more stable than inv() followed by @
dX = (Xc @ Q) / sqrtN1                 # d x N
X_post_enkf = X + dX
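# Optional cross-check (a sketch; S_emp and dX_direct are ad-hoc names): since m is
# small here, form the m x m system directly and confirm the Woodbury route matches.
S_emp = U @ U.T + np.diag(gamma_diag)  # m x m: empirical C_yy plus the ridge Γ
dX_direct = (Xc / sqrtN1) @ (U.T @ np.linalg.solve(S_emp, Res))
assert np.allclose(dX, dX_direct, atol=1e-4), "Woodbury route disagrees with direct m x m solve"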
# Analytic posterior ensemble for comparison
S = H @ Sigma @ H.T + R
K = Sigma @ H.T @ np.linalg.inv(S)
mu_post = m0 + K @ (y_star - H @ m0)
Sigma_post = Sigma - K @ H @ Sigma
L_post = np.linalg.cholesky(Sigma_post + 1e-9*np.eye(d))
X_post_analytic = mu_post[:, None] + L_post @ rng.standard_normal((d, N))
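# Reference (a sketch; X_post_matheron is an ad-hoc name): the exact Matheron rule pushes
# the same prior draws through the true gain K, reusing the perturbations Eps, so each
# column is an exact posterior sample to set against the empirical update above.
if stochastic:
    X_post_matheron = X + K @ (y_star[:, None] - (H @ X + Eps))
    gap = np.linalg.norm(X_post_matheron.mean(axis=1) - X_post_enkf.mean(axis=1))
    print("Mean gap, exact Matheron vs EnKF:", f"{gap:.3e}")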
# Plotting (single axes)
alpha = float(N)**(-0.75)
alpha = max(min(alpha, 0.25), 0.003)
x_axis = np.arange(d)
c_enkf = 'tab:blue'
c_analytic = 'tab:orange'
c_obs = 'black'
y_min = min(X_post_enkf.min(), X_post_analytic.min(), y_star.min())
y_max = max(X_post_enkf.max(), X_post_analytic.max(), y_star.max())
pad = 0.05*(y_max - y_min + 1e-12)
ylims = (y_min - pad, y_max + pad)
plt.figure(figsize=(9, 4.8), dpi=120)
for i in range(N):
    plt.plot(x_axis, X_post_enkf[:, i], color=c_enkf, alpha=alpha, linewidth=1.0)
    plt.plot(x_axis, X_post_analytic[:, i], color=c_analytic, alpha=alpha, linewidth=1.0)
plt.scatter(idx, y_star, s=32, color=c_obs, marker='o', alpha=0.9, zorder=3)
proxy_enkf = Line2D([0], [0], color=c_enkf, lw=2, alpha=0.9, label='Empirical Matheron / EnKF (Woodbury, ΔX)')
proxy_an = Line2D([0], [0], color=c_analytic, lw=2, alpha=0.9, label='Analytic Gaussian')
proxy_obs = Line2D([0], [0], color=c_obs, marker='o', lw=0, label=f'Observations (m={m})')
plt.legend(handles=[proxy_enkf, proxy_an, proxy_obs], loc='upper right', frameon=False)
plt.title(f'Posterior samples (N={N}, {"stochastic" if stochastic else "deterministic"} EnKF)')
plt.xlabel('state index')
plt.ylabel('value')
plt.ylim(*ylims)
plt.xlim(0, d-1)
plt.tight_layout()
plt.show()
# Quick numeric check
mean_rel_err = np.linalg.norm(X_post_enkf.mean(axis=1) - mu_post) / (np.linalg.norm(mu_post) + 1e-12)
cov_rel_err = np.linalg.norm(np.cov(X_post_enkf, bias=False) - Sigma_post, 'fro') / (np.linalg.norm(Sigma_post, 'fro') + 1e-12)
print("Relative mean error (EnKF vs analytic):", f"{mean_rel_err:.3e}")
print("Relative cov error (EnKF vs analytic):", f"{cov_rel_err:.3e}")
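# Note (flip the flag above to check): with stochastic=False the mean error stays small,
# but the covariance error grows, because un-perturbed member updates contract the spread
# to (I − KH) Σ (I − KH)ᵀ rather than the correct (I − KH) Σ.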