Variational Bayes Neural Nets and Graphical Models
Bitter Lesson Propagation
January 18, 2024 — April 16, 2025
Bayes NNs in the light of graphical models for ML. What’s the best approximation we can make, given our data and our compute budget?
This is a very popular question right now; see e.g. Bayes-by-backprop. My own interest is that I’d like it to work with inverse problems, not just the input-conditional posterior predictive, which makes things tougher.
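For concreteness, here is a minimal sketch of the Bayes-by-backprop objective (a diagonal-Gaussian variational posterior over weights, reparameterised sampling, KL-plus-negative-log-likelihood loss) on a toy linear-Gaussian model. The toy model, variable names, and sample counts are my illustrative assumptions, not anything from this post; in practice you would differentiate `neg_elbo` with an autodiff framework and descend on `mu` and `rho` rather than just evaluating it.

```python
# Sketch of the Bayes-by-backprop objective for a toy Bayesian linear model.
# The data-generating model here is an assumption for illustration only.
import numpy as np

rng = np.random.default_rng(0)

# toy data: y = X w_true + noise
n, d = 200, 5
X = rng.normal(size=(n, d))
w_true = rng.normal(size=d)
y = X @ w_true + 0.1 * rng.normal(size=n)

def neg_elbo(mu, rho, n_samples=8, obs_sd=0.1, prior_sd=1.0):
    """Few-sample estimate of KL(q||p) - E_q[log p(y|X,w)] (up to a constant)."""
    sigma = np.log1p(np.exp(rho))      # softplus keeps posterior scales positive
    # closed-form KL between diagonal Gaussians q = N(mu, sigma^2), p = N(0, prior_sd^2)
    kl = 0.5 * np.sum(
        (sigma**2 + mu**2) / prior_sd**2 - 1.0 - 2.0 * np.log(sigma / prior_sd)
    )
    ll = 0.0
    for _ in range(n_samples):
        eps = rng.normal(size=mu.shape)
        w = mu + sigma * eps           # reparameterised weight sample
        resid = y - X @ w
        ll += -0.5 * np.sum(resid**2) / obs_sd**2
    return kl - ll / n_samples

mu, rho = np.zeros(d), -3.0 * np.ones(d)
print(neg_elbo(mu, rho))
```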
NB: I’m very slammed for time right now and I don’t know when I’ll have a moment to return to this.
Let’s suppose we want to use these methods for NNs. Two things are commonly difficult (relative to classical inference) in this setting:
- The data is very big, in the sense that it contains many observations, perhaps enough that we’re happy to “waste” data on “suboptimal” approximation methods, trading statistical efficiency for computational efficiency (see the minibatch sketch after this list).
- The dimensions of some variables are big even in isolation (e.g. images), so we need to think about that too.
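On the first point, the standard move is to make the likelihood term of whatever objective we optimise stochastic: estimate it from a minibatch of size B and rescale by N/B so the estimate stays unbiased, so each evaluation touches only B observations. A minimal sketch, with a toy model and function names that are my own illustrative assumptions:

```python
# Unbiased minibatch estimate of a full-data log-likelihood term.
import numpy as np

rng = np.random.default_rng(1)

# toy regression data, stand-in for a "very big" dataset
N, d = 100_000, 3
X = rng.normal(size=(N, d))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=N)

def batch_log_lik(w, Xb, yb, obs_sd=0.1):
    """Gaussian log-likelihood of one minibatch (up to an additive constant)."""
    resid = yb - Xb @ w
    return -0.5 * np.sum(resid**2) / obs_sd**2

def noisy_full_data_log_lik(w, batch_size=256):
    """Unbiased estimate of the full-data log-likelihood from one minibatch."""
    idx = rng.choice(N, size=batch_size, replace=False)
    return (N / batch_size) * batch_log_lik(w, X[idx], y[idx])

print(noisy_full_data_log_lik(np.zeros(d)))
```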
1 When to optimise
2 How much rank
Gradient versus ensemble steps.
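One concrete reading of the rank question: how large should k be in a low-rank-plus-diagonal Gaussian approximation over the weights? A minimal sketch under that assumption (the parameterisation and sizes are illustrative, not the post’s): sampling costs O(d_w k) and never forms the full d_w × d_w covariance, so k is the knob trading fidelity for compute.

```python
# Sampling from a Gaussian with covariance diag(diag) + V V^T without
# ever materialising the full covariance matrix.
import numpy as np

rng = np.random.default_rng(2)

d_w, k = 10_000, 5                      # many weights, small rank
mu = np.zeros(d_w)
diag = 1e-2 * np.ones(d_w)              # diagonal part of the covariance
V = 0.1 * rng.normal(size=(d_w, k))     # low-rank factor

def sample_weights(n_samples=1):
    """Draw w ~ N(mu, diag(diag) + V V^T) by summing two cheap noise terms."""
    eps_d = rng.normal(size=(n_samples, d_w))
    eps_k = rng.normal(size=(n_samples, k))
    return mu + eps_d * np.sqrt(diag) + eps_k @ V.T

w = sample_weights(3)
print(w.shape)  # (3, 10000)
```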
3 Variational belief propagation (speed run)
We say we can solve one of these problems if we can propagate information through the nodes of the graph and find the target marginal distribution, in some sense. This might be the target marginal in the sense that we can sample from it (a generative model) or in the sense that we can compute some density or expectation (a discriminative model). For now, we’re agnostic about that.
Even with that relaxed notion of “marginalisation” we don’t necessarily have the power to solve these. Propagating information through these nodes is not generally tractable. Some special cases are tractable, but actually, let’s leap into the void and just seek approximate solutions.
Normally at this point we laboriously introduce many different inference methods, junction trees and identifiability criteria and so on. Then we say “OH NO BUT THAT DOES NOT ACTUALLY WORK ON MY PROBLEM!!!1!”.
There are two classes of difficulties with classical algorithms that seem ubiquitous:
- Our favoured algorithm is only exact on tree graphs, but we have a graph with cycles.
- Our favoured algorithm only works on nodes whose joint distributions induce exponential families at each edge.
Few interesting problems satisfy both of the corresponding conditions (a tree-structured graph, exponential-family messages at every edge), or indeed either; a sketch of the easy case follows.
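Here is that easy case, as a minimal illustration of my own: sum-product belief propagation on a tiny discrete chain, where the graph is a tree and the messages are finite tables, so the computed marginal is exact. The toy factors are random placeholders, not anything from this post.

```python
# Sum-product belief propagation on a discrete chain x1 - x2 - x3.
import numpy as np

rng = np.random.default_rng(3)
K = 4                                    # states per variable

# unary and pairwise factors (unnormalised, positive)
f1, f2, f3 = (rng.random(K) for _ in range(3))
f12, f23 = rng.random((K, K)), rng.random((K, K))

# forward and backward messages along the chain
m12 = f12.T @ f1                         # message into x2 from the x1 side
m32 = f23 @ f3                           # message into x2 from the x3 side
belief_x2 = f2 * m12 * m32
belief_x2 /= belief_x2.sum()             # exact marginal p(x2), because the graph is a tree

# brute-force check against the full joint
joint = np.einsum('i,j,k,ij,jk->ijk', f1, f2, f3, f12, f23)
print(np.allclose(belief_x2, joint.sum(axis=(0, 2)) / joint.sum()))
```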
So let’s speed run through the hand-wringing bit where we worry that we are only solving integrals approximately, and just go to some approximate inference methods that do pretty well in practical problems; we can sweat the details later.
I refer to this as…
4 YOLO belief propagation
Let’s partition a node into blocks of variates and work out what we can do with those blocks. Suppose we have some factorisation of the density over it, so that the joint over all variates is
\[ p(X)=\frac{1}{Z} \prod_a f_a\left(X_a\right)=\frac{1}{Z} \prod_a e^{-E_a\left(X_a\right)} . \]
We’ll start out doing this heuristically, to build intuitions.
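To make the notation concrete, here is a minimal sketch of how such a factorised density might be represented in code: each factor touches a subset \(X_a\) of the variables and contributes a negative energy \(-E_a(X_a)\) to the unnormalised log density. The toy energies are made-up placeholders, and \(Z\) is left alone.

```python
# Unnormalised log density of a small factor graph: -sum_a E_a(X_a).
import numpy as np

# factors as (variable indices, energy function on those variables)
factors = [
    ((0,),   lambda x0: 0.5 * x0**2),             # unary Gaussian-like energy on x0
    ((1,),   lambda x1: 0.5 * (x1 - 1.0)**2),     # unary energy on x1
    ((0, 1), lambda x0, x1: 0.5 * (x0 - x1)**2),  # pairwise coupling energy
]

def log_unnormalised_density(x):
    """log prod_a exp(-E_a(X_a)) = -sum_a E_a(X_a); the normaliser Z is ignored."""
    return -sum(E(*[x[i] for i in idx]) for idx, E in factors)

x = np.array([0.3, 0.7])
print(log_unnormalised_density(x))
```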
5 Noise outsourcing formulation
TBD. See noise outsourcing.