Pyro

Approximate maximum in the density of probabilistic programming effort

October 2, 2019 — November 25, 2021

Bayes
generative
how do science
Monte Carlo
sciml
statistics

A probabilistic programming language. pytorch + Bayesian inference = pyro (Pradhan et al. 2018).

Figure 1: Typical posterior density landscape

1 Vanilla Pyro

For rationale, see the pyro launch announcement:

We believe the critical ideas to solve AI will come from a joint effort among a worldwide community of people pursuing diverse approaches. By open sourcing Pyro, we hope to encourage the scientific world to collaborate on making AI tools more flexible, open, and easy-to-use. We expect the current (alpha!) version of Pyro will be of most interest to probabilistic modelers who want to leverage large data sets and deep networks, PyTorch users who want easy-to-use Bayesian computation, and data scientists ready to explore the ragged edge of new technology.

As a friendly, well-documented, consistent framework with less of the designed-during-interdepartmental-turf-war feel of the tensorflow frameworks, this is where much of the effort in probabilistic programming seems to be going.

The framework documentation asserts that if you can understand one file, pyro/minipyro.py, you can understand the whole system.

2 Distributed

Pytorch Lightning is not compatible with pyro, and I’m not sure how to get around this. The docs give an example of distributed training via Horovod.

3 Numpyro

Numpyro is an alternative version of pyro which uses jax as a backend instead of pytorch. In line with the general jax aesthetic it is elegant, fast, increasingly well documented, and missing some conveniences. The API is not identical to pyro’s, but they rhyme.
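For a flavour of the rhyme, here is a minimal numpyro regression sketch of my own; it is untested, and x and y are placeholder data arrays, not anything from the numpyro docs:

import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model(x, y=None):
    # priors on slope, intercept and noise scale
    a = numpyro.sample("a", dist.Normal(0., 10.))
    b = numpyro.sample("b", dist.Normal(0., 1.))
    sigma = numpyro.sample("sigma", dist.Gamma(1.0, 0.5))
    # obs=y conditions on the data when it is supplied
    numpyro.sample("y", dist.Normal(a + b * x, sigma), obs=y)

mcmc = MCMC(NUTS(model), num_warmup=200, num_samples=1000)
mcmc.run(random.PRNGKey(0), x, y)
mcmc.print_summary()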

UPDATE: Numpyro is really coming along. It has all kinds of features now, e.g. Automatic PGM diagrams. Also it turns out that the jax backend is frequently less confusing (IMO) than pytorch.

Fun tip: the render_model method will happily draw graphical model diagrams for you.
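For example, given the numpyro model sketched above, something along these lines should produce a plate diagram (assuming graphviz is installed; the filename is arbitrary):

numpyro.render_model(
    model,
    model_args=(x, y),
    render_distributions=True,   # annotate nodes with their distributions
    filename="model.svg",
)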

4 Tutorials and textbooks

5 Tips, gotchas

5.1 Regression

Regression was not (for me) obvious, and the various ways you can set it up are illustrative of how to set things up in pyro generally.

We define the model as follows: a (linear) regression model relating predictor variables (an in-Africa indicator and terrain ruggedness) plus their interaction term to a response variable (GDP).

Suppose we want to solve a posterior inference problem of the following form:

\[\begin{aligned} \text{GDP}_i &\sim \mathcal{N}(\mu, \sigma)\\ \mu &= a + b_a \cdot \operatorname{InAfrica}_i + b_r \cdot \operatorname{Ruggedness}_i + b_{ar} \cdot \operatorname{InAfrica}_i \cdot \operatorname{Ruggedness}_i \\ a &\sim \mathcal{N}(0, 10)\\ b_a &\sim \mathcal{N}(0, 1)\\ b_r &\sim \mathcal{N}(0, 1)\\ b_{ar} &\sim \mathcal{N}(0, 1)\\ \sigma &\sim \operatorname{Gamma}(1, \frac12) \end{aligned}\]

import pyro
import pyro.distributions as dist

pyro.clear_param_store()

def model():
    # priors on the regression coefficients and the noise scale
    a = pyro.sample("a", dist.Normal(0., 10.))
    b_a = pyro.sample("bA", dist.Normal(0., 1.))
    b_r = pyro.sample("bR", dist.Normal(0., 1.))
    b_ar = pyro.sample("bAR", dist.Normal(0., 1.))
    sigma = pyro.sample("sigma", dist.Gamma(1.0, 0.5))
    # placeholder distributions for the regression inputs
    is_cont_africa = pyro.sample(
        "is_cont_africa", dist.Bernoulli(0.5))  # <- overridden
    ruggedness = pyro.sample(
        "ruggedness", dist.Normal(1.0, 0.5))    # <- overridden
    mean = a + (b_a * is_cont_africa) \
        + (b_r * ruggedness) \
        + (b_ar * is_cont_africa * ruggedness)
    s = pyro.sample(
        "log_gdp", dist.Normal(mean, sigma))    # <- overridden
    return s

Note the trick here: we gave distributions even to the regression inputs. This is how we need to do it, even if that distribution will never be used. And indeed, during inference we always override the values at those sites with data.

Inference proceeds by conditioning the model on the observed data, giving us updated estimates for the unknowns. In the MCMC setting we approximate those posterior distributions with samples:

\[\begin{aligned} &p (a, b_a, b_{ar}, b_r,\sigma \mid \operatorname{GDP}, \operatorname{Ruggedness},\operatorname{InAfrica} )\\ &\quad \propto \prod_i p (\operatorname{GDP}_i \mid \operatorname{Ruggedness}_i,\operatorname{InAfrica}_i ,a, b_a, b_{ar}, b_r,\sigma)\\ & \qquad \cdot p (a, b_a, b_{ar}, b_r,\sigma) \end{aligned}\]


from pyro import poutine
from pyro.infer import MCMC, NUTS

# log_gdp, ruggedness and is_cont_africa are tensors of observed data
observed_model = poutine.condition(model, data={
    "log_gdp": log_gdp, "ruggedness": ruggedness, "is_cont_africa": is_cont_africa})
nuts_kernel = NUTS(observed_model)

mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200)
mcmc.run()
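If I recall the pyro MCMC API correctly, the resulting draws can then be inspected directly, along these lines:

mcmc.summary()                  # per-site posterior summary table
samples = mcmc.get_samples()    # dict mapping site names to tensors of draws
samples["bAR"].mean(), samples["bAR"].std()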

To actually make predictions we need to use the Predictive class, which is IMO not well explained in the docs, but you can work it out from their example. An only-slightly-confusing explanation is here. tl;dr to predict the GDP of a country NOT in Africa with a ruggedness of 3, we would do this:

import torch
from pyro.infer import Predictive

Predictive(poutine.condition(model, data={
    "ruggedness": torch.tensor(3.0),
    "is_cont_africa": torch.tensor(0.)}),
    posterior_samples=mcmc.get_samples())()['log_gdp']
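The call returns a tensor with one draw of log_gdp per posterior sample, so a point prediction and its uncertainty fall out by summarising the draws. A small follow-up sketch (the variable name preds is mine, not the tutorial's):

preds = Predictive(poutine.condition(model, data={
    "ruggedness": torch.tensor(3.0),
    "is_cont_africa": torch.tensor(0.)}),
    posterior_samples=mcmc.get_samples())()['log_gdp']
preds.mean(), preds.std()   # posterior predictive mean and spread for log GDP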

TODO: re-do this example with

  1. default arguments (a rough sketch follows the list).
  2. factory functions.
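As a first pass at option 1, the more idiomatic pattern is to pass the data through (default) arguments and condition via obs=, which makes poutine.condition unnecessary for fitting. A minimal, untested sketch (model2 is my name for it):

def model2(is_cont_africa, ruggedness, log_gdp=None):
    a = pyro.sample("a", dist.Normal(0., 10.))
    b_a = pyro.sample("bA", dist.Normal(0., 1.))
    b_r = pyro.sample("bR", dist.Normal(0., 1.))
    b_ar = pyro.sample("bAR", dist.Normal(0., 1.))
    sigma = pyro.sample("sigma", dist.Gamma(1.0, 0.5))
    mean = a + b_a * is_cont_africa + b_r * ruggedness \
        + b_ar * is_cont_africa * ruggedness
    # obs=log_gdp conditions on the data when supplied, samples otherwise
    return pyro.sample("log_gdp", dist.Normal(mean, sigma), obs=log_gdp)

mcmc2 = MCMC(NUTS(model2), num_samples=1000, warmup_steps=200)
mcmc2.run(is_cont_africa, ruggedness, log_gdp)

# prediction then needs no conditioning handler at all
Predictive(model2, posterior_samples=mcmc2.get_samples())(
    torch.tensor(0.), torch.tensor(3.0))["log_gdp"]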

5.2 Complex numbers

They currently do not work.

6 Algebraic effects

TBD

Sounds like this lands not too far from message passing ideas?
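For concreteness, pyro exposes its effect handlers through pyro.poutine. A tiny sketch, reusing the regression model from above, of the two handlers I actually reach for:

from pyro import poutine

# trace: run the model and record every sample site it hits
tr = poutine.trace(model).get_trace()
print(tr.nodes["bAR"]["value"])

# condition: intercept chosen sites and pin their values to data
pinned = poutine.condition(model, data={"bAR": torch.tensor(0.)})
pinned()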

7 Funsors

I’ve seen funsors mentioned in this context. I gather they are some kind of graphical-model-inference abstraction in the algebraic effect vein. What do they do exactly? Obermeyer et al. (2020) attempts to explain it, although I do not feel like I got it:

It is a significant challenge to design probabilistic programming systems that can accommodate a wide variety of inference strategies within a unified framework. Noting that the versatility of modern automatic differentiation frameworks is based in large part on the unifying concept of tensors, we describe a software abstraction for integration —functional tensors— that captures many of the benefits of tensors, while also being able to describe continuous probability distributions. Moreover, functional tensors are a natural candidate for generalized variable elimination and parallel-scan filtering algorithms that enable parallel exact inference for a large family of tractable modeling motifs.

…This property is extensively exploited by the Pyro probabilistic programming language (Pradhan et al. 2018) and its implementation of tensor variable elimination for exact inference in discrete latent variable models, in which each random variable in a model is associated with a distinct tensor dimension and broadcasting is used to compile a probabilistic program into a discrete factor graph. Functional tensors (hereafter “funsors”) both formalize and extend this seemingly idiosyncratic but highly successful approach to probabilistic program compilation by generalizing tensors and broadcasting to allow free variables of non-integer types that appear in probabilistic models, such as real number, real-valued vector, or real-valued matrix. Building on this, we describe a simple language of lazy funsor expressions that can serve as a unified intermediate representation for a wide variety of probabilistic programs and inference algorithms. While in general there is no finite representation of functions of real variables, we provide a funsor interface for restricted classes of functions, including lazy algebraic expressions, non-normalized Gaussian functions, and Dirac delta distributions.

Sounds like this lands not so far from message passing?

8 Incoming

9 References

Baudart, Burroni, Hirzel, et al. 2021. “Compiling Stan to Generative Probabilistic Languages and Extension to Deep Probabilistic Programming.” arXiv:1810.00873 [Cs, Stat].
Moore, and Gorinova. 2018. “Effect Handling for Composable Program Transformations in Edward2.” arXiv:1811.06150 [Cs, Stat].
Obermeyer, Bingham, Jankowiak, et al. 2020. “Functional Tensors for Probabilistic Programming.” arXiv:1910.10775 [Cs, Stat].
Pradhan, Chen, Jankowiak, et al. 2018. “Pyro: Deep Universal Probabilistic Programming.” arXiv:1810.09538 [Cs, Stat].
Ritter, and Karaletsos. 2022. “TyXe: Pyro-Based Bayesian Neural Nets for Pytorch.” Proceedings of Machine Learning and Systems.