The Score Function Estimator

a.k.a. REINFORCE; a gradient estimator for expectations

September 13, 2024


The score function estimator, a.k.a. log-derivative trick, a.k.a. REINFORCE (all-caps, for some reason?), is a generic gradient estimator that applies to both discrete and continuous random variables; it has notoriously high variance if applied naïvely. It is usually credited to (Williams 1992), though the idea is surely older than that.

This is the fundamental insight, which combines the identity \(\frac{\partial}{\partial \theta} p(x;\theta) = p(x;\theta) \frac{\partial}{\partial \theta} \log p(x;\theta)\) with an interchange of differentiation and integration (assumed valid):

\[ \begin{aligned} g(\ell) &= \frac{\partial}{\partial \theta} \mathbb{E}_{\mathsf{x}\sim p(\mathsf{x};\theta)} \ell(\mathsf{x}) \\ &= \frac{\partial}{\partial \theta} \int \ell(x) p(x;\theta) \mathrm{d} x\\ &= \int \ell(x) \frac{\partial}{\partial \theta} p(x;\theta) \mathrm{d} x\\ &= \int \ell(x) p(x;\theta) \frac{\partial}{\partial \theta} \log p(x;\theta) \mathrm{d} x\\ &= \mathbb{E}_{\mathsf{x}\sim p(\mathsf{x};\theta)} \ell(\mathsf{x}) \frac{\partial}{\partial \theta}\log p(\mathsf{x};\theta) \end{aligned} \]

This suggests a simple and obvious Monte Carlo estimate of the gradient: draw samples \(x_i\sim p(x;\theta)\), \(i=1,\dots,N\), and compute

\[ \begin{aligned} \hat{g}_{\text{REINFORCE}}(\ell) &= \frac{1}{N}\sum_{i=1}^{N} \ell(x_i) \frac{\partial}{\partial \theta}\log p(x_i;\theta) \end{aligned} \]
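
As a quick sanity check, here is a minimal sketch of the estimator in about the simplest possible setting: a Bernoulli variable with logit \(\theta\) and \(\ell(x)=x\), so that \(\mathbb{E}[\ell(\mathsf{x})] = \sigma(\theta)\) and the exact gradient \(\sigma(\theta)(1-\sigma(\theta))\) is available in closed form to compare against (the logit value here is an arbitrary toy choice):

import torch

torch.manual_seed(0)

theta = torch.tensor(0.3, requires_grad=True)  # an arbitrary toy logit
dist = torch.distributions.Bernoulli(logits=theta)

# Exact gradient of E[x] = sigmoid(theta), for comparison
p = torch.sigmoid(theta)
exact = (p * (1.0 - p)).item()

# Score-function estimate: average ell(x_i) * d/dtheta log p(x_i; theta)
x = dist.sample((100_000,))  # the samples carry no gradient themselves
surrogate = (x * dist.log_prob(x)).mean()
estimate = torch.autograd.grad(surrogate, theta)[0]
print(f"exact: {exact:.4f}, score-function estimate: {estimate.item():.4f}")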

For unifying overviews, see (Mohamed et al. 2020; Schulman et al. 2015; van Krieken, Tomczak, and Teije 2021) and the Storchastic docs.

It is annoyingly hard to find a clear example of this method online, despite its simplicity; all the code examples I see wrap it up with reinforcement learning or some other unnecessarily specific complexity.

Laurence Davies and I put together this demo, in which we try to find the parameters that minimize the difference between the categorical distribution we sample from and some target distribution.

import torch

# True target distribution probabilities
true_probs = torch.tensor([0.1, 0.6, 0.3])

# Optimization parameters
n_batch = 1000
n_iter = 3000
lr = 0.01

def loss(x):
    """
    The target loss, a negative log-likelihood for a
    categorical distribution with the given probabilities.
    """
    return -torch.distributions.Multinomial(
        total_count=1, probs=true_probs).log_prob(x)


# Set the seed for reproducibility
torch.manual_seed(42)

# Initialize the parameter estimates
theta_hat = torch.nn.Parameter(torch.tensor([0., 0., 0.]))
optimizer = torch.optim.Adam([theta_hat], lr=lr)

for epoch in range(n_iter):
    optimizer.zero_grad()
    # Sample from the estimated distribution
    x_sample = torch.distributions.Multinomial(
        1, logits=theta_hat).sample((n_batch,))
    # Evaluate the log density at the sample points
    log_p_theta_x = torch.distributions.Multinomial(
        1, logits=theta_hat).log_prob(x_sample)
    # Evaluate the target function at the sample points
    f_hat = loss(x_sample)

    # Compute the gradient of the log density w.r.t. the parameters,
    # weighted per sample point by the loss: passing `f_hat` as
    # `grad_outputs` folds the multiplication into the backward pass,
    # yielding sum_i f_hat_i * grad log p(x_i; theta) in one call.
    weighted_grad = torch.autograd.grad(
        outputs=log_p_theta_x,
        inputs=theta_hat,
        grad_outputs=f_hat.detach())[0]

    # The final gradient estimate is the average over the sample points
    theta_hat.grad = weighted_grad / n_batch

    optimizer.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Estimated Probs:"
          f"{torch.softmax(theta_hat, dim=0).detach().numpy()}")

# Display the final estimated probabilities
estimated_final_probs = torch.softmax(theta_hat, dim=0)
print("Final Estimated Probabilities: "
  f" {estimated_final_probs.detach().numpy()}"
  f" (True Probabilities: {true_probs.detach().numpy()}")

Note that the batch size there is very large. If we make it much smaller, the variance of the estimator is too high to be useful; the sketch below makes that concrete.
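
Here is a rough sketch (reusing loss from the demo above; the batch sizes and repetition count are arbitrary) that draws repeated gradient estimates at a fixed parameter value and prints their per-coordinate standard deviation, which should shrink roughly like \(1/\sqrt{n_{\text{batch}}}\):

def grad_estimate(theta, n_batch):
    """One score-function gradient estimate at fixed parameters."""
    dist = torch.distributions.Multinomial(1, logits=theta)
    x = dist.sample((n_batch,))
    surrogate = (loss(x).detach() * dist.log_prob(x)).mean()
    return torch.autograd.grad(surrogate, theta)[0]

theta0 = torch.nn.Parameter(torch.zeros(3))
for n in [10, 100, 1000]:
    grads = torch.stack([grad_estimate(theta0, n) for _ in range(200)])
    print(f"n_batch={n:4d}, per-coordinate std: {grads.std(dim=0).numpy()}")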

Classically we might address such noise with a diminishing learning-rate schedule, as in the classical SGD convergence analyses, but I have lazily not done that here; a sketch of what that might look like follows.
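
For the record, here is roughly what that would look like, swapping in plain SGD plus PyTorch's LambdaLR scheduler for a \(1/t\)-style decay (the constants here are made up, not tuned):

optimizer = torch.optim.SGD([theta_hat], lr=0.5)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda t: 1.0 / (1.0 + 0.01 * t))

# ... then, inside the training loop, after each optimizer.step():
scheduler.step()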

1 Rao-Blackwellization

Rao-Blackwellization (Casella and Robert 1996) seems like a natural extension to reduce the variance. How would it work? Liu et al. (2019) is a contemporary example; I have a vague feeling that I saw something similar in Rubinstein and Kroese (2016). TODO: follow up.
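
I have not chased this down, but the zero-variance limiting case is at least easy to sketch for a small discrete support like the demo above: skip sampling entirely, enumerate all \(K=3\) categories, and compute the expectation (and hence its gradient) exactly. This toy calculation reuses loss and theta_hat from the demo; it is not the estimator of Liu et al. (2019), which, as I understand it, sums analytically over only the high-probability part of the support and samples the rest.

# Exact (zero-variance) gradient by enumerating the whole support.
outcomes = torch.eye(3)  # all K = 3 one-hot outcomes
dist = torch.distributions.Multinomial(1, logits=theta_hat)
probs = dist.log_prob(outcomes).exp()           # p_k for each category
expected_loss = (probs * loss(outcomes)).sum()  # exact E[loss(x)]
exact_grad = torch.autograd.grad(expected_loss, theta_hat)[0]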

2 References

Casella, and Robert. 1996. “Rao-Blackwellisation of Sampling Schemes.” Biometrika.
Fu. 2005. “Stochastic Gradient Estimation.”
Grathwohl, Choi, Wu, et al. 2018. “Backpropagation Through the Void: Optimizing Control Variates for Black-Box Gradient Estimation.” In Proceedings of ICLR.
Hyvärinen. 2005. “Estimation of Non-Normalized Statistical Models by Score Matching.” The Journal of Machine Learning Research.
Kool, van Hoof, and Welling. 2019. “Estimating Gradients for Discrete Random Variables by Sampling Without Replacement.”
Kool, van Hoof, and Welling. 2019. “Buy 4 Reinforce Samples, Get a Baseline for Free!”
Liu, Regier, Tripuraneni, et al. 2019. “Rao-Blackwellized Stochastic Gradients for Discrete Distributions.”
Mohamed, Rosca, Figurnov, et al. 2020. “Monte Carlo Gradient Estimation in Machine Learning.” Journal of Machine Learning Research.
Richter, Boustati, Nüsken, et al. 2020. “VarGrad: A Low-Variance Gradient Estimator for Variational Inference.”
Rubinstein, and Kroese. 2016. Simulation and the Monte Carlo Method. Wiley Series in Probability and Statistics.
Ruiz, Titsias, and Blei. 2016. “The Generalized Reparameterization Gradient.” In Advances in Neural Information Processing Systems.
Schulman, Heess, Weber, et al. 2015. “Gradient Estimation Using Stochastic Computation Graphs.” In Proceedings of the 28th International Conference on Neural Information Processing Systems - Volume 2. NIPS’15.
Stoker. 1986. “Consistent Estimation of Scaled Coefficients.” Econometrica.
Tucker, Mnih, Maddison, et al. 2017. “REBAR: Low-Variance, Unbiased Gradient Estimates for Discrete Latent Variable Models.” In Proceedings of the 31st International Conference on Neural Information Processing Systems. NIPS’17.
van Krieken, Tomczak, and Teije. 2021. “Storchastic: A Framework for General Stochastic Automatic Differentiation.” arXiv:2104.00428 [cs, stat].
Weber, Heess, Buesing, et al. 2019. “Credit Assignment Techniques in Stochastic Computation Graphs.”
Williams. 1992. “Simple Statistical Gradient-Following Algorithms for Connectionist Reinforcement Learning.” Machine Learning.
Xu, Quiroz, Kohn, et al. 2019. “Variance Reduction Properties of the Reparameterization Trick.” In Proceedings of the Twenty-Second International Conference on Artificial Intelligence and Statistics.
Yin, and Zhou. 2018. “ARM: Augment-REINFORCE-Merge Gradient for Stochastic Binary Networks.”