Neural Bayes posteriors
Training a network to directly estimate a posterior quantity, meta-learning Bayes
2022-11-24 — 2025-07-10
Wherein transformers are trained as Prior-Data Fitted Networks to approximate Bayesian posteriors in-context, are shown to mimic Gaussian processes, are reported to yield over two-hundredfold speedups over conventional inference, and have found particular success on tabular tasks.
We explicitly train NNs to predict posteriors from input data, i.e. to do in-context learning that gives us an explicit (approximation to the) Bayesian result. There’s a close connection between this and the implicit Bayesian inference that transformers seem to do. It’s interesting to pair this with predictive Bayes.
1 Neural point estimators
NeuralEstimators facilitates the user-friendly development of neural point estimators, which are neural networks that transform data into parameter point estimates. They are likelihood-free, substantially faster than classical methods, and can be designed to be approximate Bayes estimators. The package caters for any model for which simulation is feasible.
Permutation-invariant neural estimators (Sainsbury-Dale, Zammit-Mangion, and Huser 2022, 2024) lean on the deep sets architecture.
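To make the idea concrete, here is a minimal sketch of a deep-sets point estimator in PyTorch. It is not the NeuralEstimators API; the toy Gaussian model, the prior, and the network sizes are all illustrative assumptions. Training against draws from the prior with squared-error loss pushes the network towards the posterior mean, which is one sense in which such estimators can be approximate Bayes estimators.

```python
# Minimal sketch of a permutation-invariant (deep sets) neural point estimator.
# NOT the NeuralEstimators package API; model, prior, and sizes are illustrative.
import torch
import torch.nn as nn

class DeepSetEstimator(nn.Module):
    """phi encodes each replicate; mean pooling gives a permutation-invariant
    summary; psi maps the summary to a parameter point estimate."""
    def __init__(self, x_dim=1, hidden=64, theta_dim=2):
        super().__init__()
        self.phi = nn.Sequential(nn.Linear(x_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, hidden), nn.ReLU())
        self.psi = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(),
                                 nn.Linear(hidden, theta_dim))

    def forward(self, x):                      # x: (batch, n_replicates, x_dim)
        pooled = self.phi(x).mean(dim=1)       # invariant to replicate order
        return self.psi(pooled)

def simulate_batch(batch=256, n=100):
    """Toy model: x_i ~ Normal(mu, sigma), with (mu, log sigma) drawn from a
    simple prior. Any simulable model could be substituted here."""
    mu = torch.randn(batch, 1)
    log_sigma = 0.5 * torch.randn(batch, 1)
    theta = torch.cat([mu, log_sigma], dim=1)
    x = mu[:, None, :] + log_sigma.exp()[:, None, :] * torch.randn(batch, n, 1)
    return x, theta

estimator = DeepSetEstimator()
opt = torch.optim.Adam(estimator.parameters(), lr=1e-3)
for step in range(2000):
    x, theta = simulate_batch()
    loss = nn.functional.mse_loss(estimator(x), theta)  # squared loss targets
    opt.zero_grad(); loss.backward(); opt.step()        # the posterior mean
```

Because the summary is a mean over encoded replicates, the same trained network applies to data sets of any size, which is the appeal of the deep sets construction here.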
Note that deep sets are a specialization of the attention architecture, and that leads us to wonder whether transformers can be trained to do Bayesian inference even better. Yes — read on.
2 Train a transformer to estimate a posterior in-context
The PFN architecture (Müller et al. 2021) has spawned many popular variants (Hollmann et al. 2023; Dooley et al. 2023), and there is an elegant argument that these models perform (approximate) Bayesian inference.
Müller et al. (2021):
We present Prior-Data Fitted Networks (PFNs). PFNs leverage in-context learning in large-scale machine learning techniques to approximate a large set of posteriors. The only requirement for PFNs to work is the ability to sample from a prior distribution over supervised learning tasks (or functions). Our method restates the objective of posterior approximation as a supervised classification problem with a set-valued input: it repeatedly draws a task (or function) from the prior, draws a set of data points and their labels from it, masks one of the labels and learns to make probabilistic predictions for it based on the set-valued input of the rest of the data points. When presented with a set of samples from a new supervised learning task, PFNs make probabilistic predictions for arbitrary data points in a single forward pass, because they’ve learned to approximate Bayesian inference. We demonstrate that PFNs can nearly perfectly mimic Gaussian processes and enable efficient Bayesian inference for intractable problems, achieving over 200-fold speedups in several setups compared to current methods.
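As a concrete, much-simplified illustration of that training recipe, here is a toy PFN-style sketch in PyTorch: tasks are drawn from a prior (a random linear-regression prior standing in for the paper's GP and BNN priors), the labels of held-out points are masked, and a transformer is trained by cross-entropy over a binned output distribution to predict them from the rest of the set. Everything here (the prior, the sizes, the unmasked attention pattern) is an illustrative assumption rather than the authors' code.

```python
# Toy PFN-style training sketch: posterior approximation as supervised
# classification over a set-valued input. Illustrative only.
import torch
import torch.nn as nn

N_BINS = 64                                  # discretise y -> classification
BINS = torch.linspace(-3, 3, N_BINS + 1)

def sample_tasks(batch=32, n_ctx=20, n_query=10):
    """Draw tasks from the prior: here y = a*x + b + noise with random a, b."""
    n = n_ctx + n_query
    a, b = torch.randn(batch, 1), torch.randn(batch, 1)
    x = torch.rand(batch, n, 1) * 2 - 1
    y = a[:, :, None] * x + b[:, :, None] + 0.1 * torch.randn(batch, n, 1)
    return x, y.clamp(-2.99, 2.99)

class TinyPFN(nn.Module):
    def __init__(self, d_model=128):
        super().__init__()
        self.embed_ctx = nn.Linear(2, d_model)      # (x, y) for observed points
        self.embed_qry = nn.Linear(1, d_model)      # x only for masked points
        layer = nn.TransformerEncoderLayer(d_model, nhead=4,
                                           dim_feedforward=256, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=3)
        self.head = nn.Linear(d_model, N_BINS)      # predictive dist. over y bins

    def forward(self, x_ctx, y_ctx, x_qry):
        ctx = self.embed_ctx(torch.cat([x_ctx, y_ctx], dim=-1))
        qry = self.embed_qry(x_qry)
        h = self.encoder(torch.cat([ctx, qry], dim=1))
        # NB: this sketch uses full attention; the actual PFN masks attention so
        # that held-out points attend only to the observed (context) points.
        return self.head(h[:, x_ctx.shape[1]:])

model = TinyPFN()
opt = torch.optim.Adam(model.parameters(), lr=1e-4)
for step in range(5000):
    x, y = sample_tasks()
    x_ctx, y_ctx, x_qry, y_qry = x[:, :20], y[:, :20], x[:, 20:], y[:, 20:]
    logits = model(x_ctx, y_ctx, x_qry)                      # (B, n_qry, N_BINS)
    target = torch.bucketize(y_qry.squeeze(-1), BINS[1:-1])  # bin index of y
    loss = nn.functional.cross_entropy(logits.reshape(-1, N_BINS),
                                       target.reshape(-1))
    opt.zero_grad(); loss.backward(); opt.step()
```

At deployment, a single forward pass over a new context set plus query inputs yields predictive distributions for the queries, which is the in-context "Bayesian inference" the abstract describes.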
PFNs have had particular success on tabular data (Hollmann et al. 2023).
3 Neural processes
Neural processes are a closely related family: they meta-learn a map from observed context sets to predictive distributions over new targets, so they can also be read as amortised approximate Bayesian prediction.
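For flavour, a minimal conditional-neural-process-style sketch in PyTorch: a deep-set encoder summarises the context points, and a decoder maps the summary plus a target input to a Gaussian predictive distribution, trained by maximum likelihood over tasks drawn from a toy prior. All particulars here are illustrative assumptions, not any specific published model.

```python
# Minimal conditional-neural-process-style sketch. Illustrative only.
import torch
import torch.nn as nn

class CNP(nn.Module):
    """Encode context (x, y) pairs into a mean-pooled representation, then
    decode (representation, target x) into a Gaussian predictive."""
    def __init__(self, hidden=64):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(2, hidden), nn.ReLU(),
                                     nn.Linear(hidden, hidden))
        self.decoder = nn.Sequential(nn.Linear(hidden + 1, hidden), nn.ReLU(),
                                     nn.Linear(hidden, 2))  # (mean, raw scale)

    def forward(self, x_ctx, y_ctx, x_tgt):
        r = self.encoder(torch.cat([x_ctx, y_ctx], dim=-1)).mean(1, keepdim=True)
        r = r.expand(-1, x_tgt.shape[1], -1)
        mu, raw_scale = self.decoder(torch.cat([r, x_tgt], dim=-1)).chunk(2, -1)
        return mu, nn.functional.softplus(raw_scale) + 1e-3

def sample_tasks(batch=32, n=30):
    """Toy prior over tasks: random sinusoids with observation noise."""
    amp, phase = torch.rand(batch, 1, 1) * 2, torch.rand(batch, 1, 1) * 3
    x = torch.rand(batch, n, 1) * 4 - 2
    return x, amp * torch.sin(x + phase) + 0.1 * torch.randn(batch, n, 1)

model = CNP()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
for step in range(3000):
    x, y = sample_tasks()
    mu, sigma = model(x[:, :20], y[:, :20], x[:, 20:])       # condition on 20 pts
    loss = -torch.distributions.Normal(mu, sigma).log_prob(y[:, 20:]).mean()
    opt.zero_grad(); loss.backward(); opt.step()
```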