Neural Bayes posteriors
Training a network to directly estimate a posterior quantity, meta-learning Bayes
2022-11-24 — 2025-07-10
Explicitly training NNs to predict posterior quantities from input data, i.e. in-context learning that gives us an explicit (approximation to the) Bayesian result. There is a close connection between this and the implicit Bayesian inference that transformers seem to do. Interesting to pair with predictive Bayes.
1 Neural point estimators
NeuralEstimators facilitates the user-friendly development of neural point estimators, which are neural networks that transform data into parameter point estimates. They are likelihood-free, substantially faster than classical methods, and can be designed to be approximate Bayes estimators. The package caters for any model for which simulation is feasible.
Permutation-invariant neural estimators (Sainsbury-Dale, Zammit-Mangion, and Huser 2022, 2024), which lean on deep sets to handle set-valued (i.i.d.-replicate) data; a sketch of the idea follows.
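To make the idea concrete, here is a minimal sketch of a permutation-invariant neural Bayes point estimator in that spirit. All names are illustrative and this is not the NeuralEstimators API; I assume a toy Gaussian model with a simple prior, and train on prior simulations with squared-error loss, whose minimiser is the posterior mean (the Bayes estimator under that loss).

```python
# Deep-sets neural Bayes point estimator: theta_hat = phi( mean_i psi(y_i) ).
# Likelihood-free: we only need to simulate (theta, data) pairs from the prior.
import torch
import torch.nn as nn

def simulate(batch, n):
    """Draw theta from the prior, then data from the model given theta."""
    mu = torch.randn(batch, 1)                      # mu ~ N(0, 1)
    log_sigma = 0.5 * torch.randn(batch, 1)         # log sigma ~ N(0, 0.25)
    y = mu + log_sigma.exp() * torch.randn(batch, n)
    theta = torch.cat([mu, log_sigma], dim=1)       # (batch, 2)
    return y.unsqueeze(-1), theta                   # y: (batch, n, 1)

class DeepSetEstimator(nn.Module):
    """Permutation-invariant by construction: pool psi(y_i) over the set."""
    def __init__(self, hidden=64, summary=32, out=2):
        super().__init__()
        self.psi = nn.Sequential(nn.Linear(1, hidden), nn.ReLU(),
                                 nn.Linear(hidden, summary))
        self.phi = nn.Sequential(nn.Linear(summary, hidden), nn.ReLU(),
                                 nn.Linear(hidden, out))
    def forward(self, y):
        return self.phi(self.psi(y).mean(dim=1))    # pool over the set dimension

est = DeepSetEstimator()
opt = torch.optim.Adam(est.parameters(), lr=1e-3)
for step in range(2000):
    y, theta = simulate(batch=256, n=50)
    loss = ((est(y) - theta) ** 2).mean()           # empirical Bayes risk
    opt.zero_grad(); loss.backward(); opt.step()

# Amortised inference: a new dataset in, a point estimate out, one forward pass.
y_new, theta_true = simulate(batch=1, n=50)
print(est(y_new), theta_true)
```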
Note that deep sets are a specialisation of the attention architecture (mean pooling is attention with uniform weights), which leads us to wonder whether transformers can be trained to do Bayesian inference even better. Yes. Read on.
2 Train a transformer to estimate a posterior in-context
The Prior-Data Fitted Network (PFN; Müller et al. 2021) architecture has many popular variants that can be elegantly shown to approximate Bayesian posterior prediction (Hollmann et al. 2023; Dooley et al. 2023).
Müller et al. (2021):
We present Prior-Data Fitted Networks (PFNs). PFNs leverage in-context learning in large-scale machine learning techniques to approximate a large set of posteriors. The only requirement for PFNs to work is the ability to sample from a prior distribution over supervised learning tasks (or functions). Our method restates the objective of posterior approximation as a supervised classification problem with a set-valued input: it repeatedly draws a task (or function) from the prior, draws a set of data points and their labels from it, masks one of the labels and learns to make probabilistic predictions for it based on the set-valued input of the rest of the data points. Presented with a set of samples from a new supervised learning task as input, PFNs make probabilistic predictions for arbitrary other data points in a single forward propagation, having learned to approximate Bayesian inference. We demonstrate that PFNs can near-perfectly mimic Gaussian processes and also enable efficient Bayesian inference for intractable problems, with over 200-fold speedups in multiple setups compared to current methods.
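The training recipe in that abstract is easy to sketch. Below is a minimal PFN-style loop: sample a task from a prior over functions, sample points, hold out some labels, and train a transformer to predict them from the rest of the set. Everything here (the 1-D linear-regression prior, the Gaussian output head, the `TinyPFN` name) is a simplified placeholder of my own, not the PFN reference implementation, which uses a discretised "bar" output distribution and masks attention between target points.

```python
# PFN-style meta-training: the loss minimiser is the posterior predictive
# distribution under the task prior, so the trained network does (approximate)
# Bayesian prediction in-context, in a single forward pass.
import torch
import torch.nn as nn

def sample_task(batch, n_ctx, n_tgt):
    """Prior over tasks: y = w*x + b + noise with random (w, b)."""
    w = torch.randn(batch, 1, 1)
    b = torch.randn(batch, 1, 1)
    x = torch.rand(batch, n_ctx + n_tgt, 1) * 4 - 2
    y = w * x + b + 0.1 * torch.randn_like(x)
    return x[:, :n_ctx], y[:, :n_ctx], x[:, n_ctx:], y[:, n_ctx:]

class TinyPFN(nn.Module):
    def __init__(self, d=64):
        super().__init__()
        self.embed_ctx = nn.Linear(2, d)            # embed (x, y) context pairs
        self.embed_tgt = nn.Linear(1, d)            # embed target x (label masked)
        layer = nn.TransformerEncoderLayer(d, nhead=4, dim_feedforward=128,
                                           batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=3)
        self.head = nn.Linear(d, 2)                 # predictive mean and log-scale

    def forward(self, x_ctx, y_ctx, x_tgt):
        ctx = self.embed_ctx(torch.cat([x_ctx, y_ctx], dim=-1))
        tgt = self.embed_tgt(x_tgt)
        # No positional encoding: the input is treated as a set. For brevity we
        # allow full attention; the real PFN lets targets attend only to context.
        h = self.encoder(torch.cat([ctx, tgt], dim=1))
        out = self.head(h[:, ctx.shape[1]:])        # read off the target positions
        return out[..., :1], out[..., 1:]           # mean, log std

model = TinyPFN()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
for step in range(5000):
    x_ctx, y_ctx, x_tgt, y_tgt = sample_task(batch=64, n_ctx=20, n_tgt=10)
    mean, log_std = model(x_ctx, y_ctx, x_tgt)
    # Gaussian negative log-likelihood of the held-out labels (up to a constant)
    nll = (log_std + 0.5 * ((y_tgt - mean) / log_std.exp()) ** 2).mean()
    opt.zero_grad(); nll.backward(); opt.step()

# At test time a *new* dataset goes in as context and we read off approximate
# posterior predictive distributions at arbitrary query points.
```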
They have had particular success on tabular data, notably TabPFN (Hollmann et al. 2023).
3 Neural processes
See neural processes.