Pyro
Approximate maximum in the density of probabilistic programming effort
October 2, 2019 — May 3, 2024
A probabilistic programming language built on modern NN frameworks (pytorch, and latterly jax), implementing many fashionable algorithms from the probabilistic programming literature.
1 Classic Pyro
pytorch + Bayesian inference = pyro (Pradhan et al. 2018).
For rationale, see the pyro launch announcement:
We believe the critical ideas to solve AI will come from a joint effort among a worldwide community of people pursuing diverse approaches. By open sourcing Pyro, we hope to encourage the scientific world to collaborate on making AI tools more flexible, open, and easy-to-use. We expect the current (alpha!) version of Pyro will be of most interest to probabilistic modellers who want to leverage large data sets and deep networks, PyTorch users who want easy-to-use Bayesian computation, and data scientists ready to explore the ragged edge of new technology.
As a friendly, well-documented, consistent framework with less of the designed-during-an-interdepartmental-turf-war feel of the tensorflow frameworks, this seems to be where much of the effort in probabilistic programming is concentrating.
Framework documentation asserts that if you can understand one file, pyro/minipyro.py, you can understand the whole system.
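To give the flavour, here is a minimal sketch of the Pyro idiom (my own toy example, not from the docs): a model with a latent rate, a variational guide, and an SVI loop.
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def model(data):
    # latent Poisson rate with a Gamma prior, observed counts
    rate = pyro.sample("rate", dist.Gamma(2.0, 2.0))
    with pyro.plate("data", len(data)):
        pyro.sample("x", dist.Poisson(rate), obs=data)

def guide(data):
    # variational posterior over the rate, with learnable parameters
    alpha = pyro.param("alpha", torch.tensor(2.0),
                       constraint=dist.constraints.positive)
    beta = pyro.param("beta", torch.tensor(2.0),
                      constraint=dist.constraints.positive)
    pyro.sample("rate", dist.Gamma(alpha, beta))

data = torch.tensor([1., 2., 0., 3., 1.])
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
for step in range(1000):
    svi.step(data)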
2 Numpyro
Numpyro is an alternative version of pyro which uses jax as a backend instead of pytorch. In line with the general jax aesthetic it is elegant, fast, increasingly well documented, and missing some conveniences. The API is not identical to pyro's, but they rhyme.
UPDATE: Numpyro is really coming along. It has all kinds of features now, e.g. Automatic PGM diagrams. Also, it turns out that the jax backend is frequently less confusing (IMO) than pytorch.
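To give a sense of how they rhyme, here is the same kind of toy model in numpyro, fitted by NUTS (again my own sketch, not from the docs):
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model(data):
    # latent Poisson rate with a Gamma prior, observed counts
    rate = numpyro.sample("rate", dist.Gamma(2.0, 2.0))
    with numpyro.plate("data", data.shape[0]):
        numpyro.sample("x", dist.Poisson(rate), obs=data)

data = jnp.array([1, 2, 0, 3, 1])
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), data)
mcmc.print_summary()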
Fun tip: the render_model method will automatically draw graphical model diagrams.
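For example, continuing the toy model above (keyword names from memory, so check the current docs):
# renders a graphviz diagram of the model's plate structure
numpyro.render_model(model, model_args=(data,), render_distributions=True)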
3 Tutorials and textbooks
- With Tom Blau I recently wrote a tutorial introduction to pyro, csiro-mlai/hackfest-ppl
- Rob Salomone’s course is excellent and starts with great examples.
- Bayes for Hackers
- see also generic tutorials on probabilistic programming
4 Tips, gotchas
4.1 Distributed
Multi-GPU Pyro is not necessarily obvious, since many of the implied inference methods are not just plain SGD, so they do not parallelize in the same way a simple neural network might. The docs give an example of distributed training via Horovod.
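A condensed sketch of the shape of that example (from memory and untested; the HorovodOptimizer wrapper is the key piece, and the linked docs example has the full version with data loaders and checkpointing):
import horovod.torch as hvd
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam, HorovodOptimizer

hvd.init()
if torch.cuda.is_available():
    torch.cuda.set_device(hvd.local_rank())

def model(data):
    loc = pyro.sample("loc", dist.Normal(0., 10.))
    with pyro.plate("data", len(data)):
        pyro.sample("x", dist.Normal(loc, 1.), obs=data)

def guide(data):
    q_loc = pyro.param("q_loc", torch.tensor(0.))
    q_scale = pyro.param("q_scale", torch.tensor(1.),
                         constraint=dist.constraints.positive)
    pyro.sample("loc", dist.Normal(q_loc, q_scale))

# each worker trains on its own shard of the data;
# HorovodOptimizer averages gradients across workers
full_data = torch.randn(10000) + 3.0
shard = full_data[hvd.rank()::hvd.size()]
svi = SVI(model, guide, HorovodOptimizer(Adam({"lr": 0.01})), Trace_ELBO())
for step in range(1000):
    svi.step(shard)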
4.2 Regression
Regression was not (for me) obvious, and the various ways you can set it up are illustrative of how to set things up in pyro in general.
We define the model as follows: a linear regression capturing two predictor variables (an in-Africa indicator and terrain ruggedness), their interaction term, and a response variable (log GDP).
Suppose we want to solve a posterior inference problem of the following form:
\[\begin{aligned} \text{GDP}_i &\sim \mathcal{N}(\mu, \sigma)\\ \mu &= a + b_a \cdot \operatorname{InAfrica}_i + b_r \cdot \operatorname{Ruggedness}_i + b_{ar} \cdot \operatorname{InAfrica}_i \cdot \operatorname{Ruggedness}_i \\ a &\sim \mathcal{N}(0, 10)\\ b_a &\sim \mathcal{N}(0, 1)\\ b_r &\sim \mathcal{N}(0, 1)\\ b_{ar} &\sim \mathcal{N}(0, 1)\\ \sigma &\sim \operatorname{Gamma}(1, \frac12) \end{aligned}\]
import torch
import pyro
import pyro.distributions as dist

pyro.clear_param_store()

def model():
    # regression coefficients and noise scale
    a = pyro.sample("a", dist.Normal(0., 10.))
    b_a = pyro.sample("bA", dist.Normal(0., 1.))
    b_r = pyro.sample("bR", dist.Normal(0., 1.))
    b_ar = pyro.sample("bAR", dist.Normal(0., 1.))
    sigma = pyro.sample("sigma", dist.Gamma(1.0, 0.5))
    # the predictors also get (dummy) priors; these sites are overridden with data
    is_cont_africa = pyro.sample(
        "is_cont_africa", dist.Bernoulli(0.5))  # <- overridden
    ruggedness = pyro.sample(
        "ruggedness", dist.Normal(1.0, 0.5))  # <- overridden
    mean = a + (b_a * is_cont_africa) \
        + (b_r * ruggedness) \
        + (b_ar * is_cont_africa * ruggedness)
    s = pyro.sample(
        "log_gdp", dist.Normal(mean, sigma))  # <- overridden
    return s
Note the trick here: we gave distributions even to the regression inputs. We need to do it this way, even though those distributions will never be used; during inference we always override the values at those sites with data.
Inference proceeds by conditioning the model on the observed data, giving us updated estimates for the unknowns. In the MCMC setting we approximate those posterior distributions with samples:
\[\begin{aligned} &p (a, b_a, b_{ar}, b_r,\sigma \mid \operatorname{GDP}, \operatorname{Ruggedness},\operatorname{InAfrica} )\\ &\quad \propto \prod_i p (\operatorname{GDP}_i \mid \operatorname{Ruggedness}_i,\operatorname{InAfrica}_i ,a, b_a, b_{ar}, b_r,\sigma)\\ & \qquad \cdot p (a, b_a, b_{ar}, b_r,\sigma) \end{aligned}\]
from pyro import poutine
from pyro.infer import MCMC, NUTS

# condition the data sites on the observed tensors
observed_model = poutine.condition(model, data={
    "log_gdp": log_gdp, "ruggedness": ruggedness, "is_cont_africa": is_cont_africa})
nuts_kernel = NUTS(observed_model)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200)
mcmc.run()
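Once that has run, the posterior draws are available as a plain dict of tensors, and there is a built-in summary table:
samples = mcmc.get_samples()  # dict: site name -> tensor of draws
samples["bAR"].mean()         # e.g. posterior mean of the interaction term
mcmc.summary()                # prints means, quantiles, n_eff, r_hat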
To actually make predictions we need to use the Predictive class, which is IMO not well explained in the docs, but you can work it out from their example. An only-slightly-confusing explanation is here. tl;dr to predict the GDP of a country not in Africa with a ruggedness of 3, we would do this:
from pyro.infer import Predictive

Predictive(poutine.condition(model, data={
        "ruggedness": torch.tensor(3.0),
        "is_cont_africa": torch.tensor(0.)}),
    posterior_samples=mcmc.get_samples())()["log_gdp"]
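The call returns one posterior-predictive draw of log GDP per MCMC sample, so a point prediction and its uncertainty are just summary statistics of those draws, e.g. (same call, captured into a variable):
draws = Predictive(poutine.condition(model, data={
        "ruggedness": torch.tensor(3.0),
        "is_cont_africa": torch.tensor(0.)}),
    posterior_samples=mcmc.get_samples())()["log_gdp"]
draws.mean(), draws.std()  # posterior-predictive mean and spread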
TODO: re-do this example with
- default arguments.
- factory functions.
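For the first of those, here is a sketch of what a default-argument version might look like (my own hypothetical rewrite, not from the original tutorial): the covariates enter as function arguments, the likelihood site uses obs=, and no poutine.condition is needed.
def model(is_cont_africa, ruggedness, log_gdp=None):
    a = pyro.sample("a", dist.Normal(0., 10.))
    b_a = pyro.sample("bA", dist.Normal(0., 1.))
    b_r = pyro.sample("bR", dist.Normal(0., 1.))
    b_ar = pyro.sample("bAR", dist.Normal(0., 1.))
    sigma = pyro.sample("sigma", dist.Gamma(1.0, 0.5))
    mean = a + b_a * is_cont_africa + b_r * ruggedness \
        + b_ar * is_cont_africa * ruggedness
    with pyro.plate("data", len(is_cont_africa)):
        # at prediction time log_gdp is None, so this is an ordinary sample site
        return pyro.sample("log_gdp", dist.Normal(mean, sigma), obs=log_gdp)

mcmc = MCMC(NUTS(model), num_samples=1000, warmup_steps=200)
mcmc.run(is_cont_africa, ruggedness, log_gdp)

# prediction: pass covariates, leave log_gdp unobserved
Predictive(model, posterior_samples=mcmc.get_samples())(
    torch.tensor([0.]), torch.tensor([3.]))["log_gdp"]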
4.3 Complex numbers
Currently do not work.
5 Algebraic effects
TBD
Sounds like this lands not too far from message passing ideas?
6 Funsors
I’ve seen funsors mentioned in this context. I gather they are some kind of graphical-model-inference abstraction in the algebraic effects vein. What do they do, exactly? Obermeyer et al. (2020) attempts to explain it, although I do not feel like I got it:
It is a significant challenge to design probabilistic programming systems that can accommodate a wide variety of inference strategies within a unified framework. Noting that the versatility of modern automatic differentiation frameworks is based in large part on the unifying concept of tensors, we describe a software abstraction for integration —functional tensors— that captures many of the benefits of tensors, while also being able to describe continuous probability distributions. Moreover, functional tensors are a natural candidate for generalized variable elimination and parallel-scan filtering algorithms that enable parallel exact inference for a large family of tractable modeling motifs.
…This property is extensively exploited by the Pyro probabilistic programming language (Pradhan et al. 2018) and its implementation of tensor variable elimination for exact inference in discrete latent variable models, in which each random variable in a model is associated with a distinct tensor dimension and broadcasting is used to compile a probabilistic program into a discrete factor graph. Functional tensors (hereafter “funsors”) both formalize and extend this seemingly idiosyncratic but highly successful approach to probabilistic program compilation by generalizing tensors and broadcasting to allow free variables of non-integer types that appear in probabilistic models, such as real number, real-valued vector, or real-valued matrix. Building on this, we describe a simple language of lazy funsor expressions that can serve as a unified intermediate representation for a wide variety of probabilistic programs and inference algorithms. While in general there is no finite representation of functions of real variables, we provide a funsor interface for restricted classes of functions, including lazy algebraic expressions, non-normalized Gaussian functions, and Dirac delta distributions.
Sounds like this lands not so far from message passing?