Variational Bayes NNs and Graphical Models

Bitter Lesson Propagation

January 18, 2024 — October 31, 2024

algebra
graphical models
hidden variables
hierarchical models
how do science
machine learning
networks
neural nets
probability
statistics
Figure 1

Bayesian neural networks (Bayes NNs) viewed in the light of graphical models for ML. What is the best approximation we can make, given our data and our compute budget? This is a very popular question right now; see e.g. Bayes-by-backprop. My own interest is in making these methods work for inverse problems, not just for the input-conditional posterior predictive, which makes things tougher.
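To make the Bayes-by-backprop idea concrete, here is a minimal sketch of its core ingredient: a reparameterised Monte Carlo estimate of the evidence lower bound (ELBO) for a Gaussian variational posterior \(q(w)=\mathcal{N}(\mu,\sigma^2)\) with \(\sigma=\log(1+e^{\rho})\). It assumes a hypothetical one-weight linear-Gaussian model (unit noise, standard normal prior) and omits the gradient step, which in practice an autodiff framework would take through the sampled weights. All function names are illustrative.

```python
import math
import random

def softplus(rho):
    # sigma = log(1 + exp(rho)) keeps the variational standard deviation positive
    return math.log1p(math.exp(rho))

def gauss_kl(mu, sigma, prior_sigma=1.0):
    # KL( N(mu, sigma^2) || N(0, prior_sigma^2) ) in closed form
    return (math.log(prior_sigma / sigma)
            + (sigma ** 2 + mu ** 2) / (2 * prior_sigma ** 2) - 0.5)

def log_lik(w, xs, ys):
    # Hypothetical model: y = w * x + unit-variance Gaussian noise
    return sum(-0.5 * (y - w * x) ** 2 - 0.5 * math.log(2 * math.pi)
               for x, y in zip(xs, ys))

def elbo_mc(mu, rho, xs, ys, n_samples, rng):
    # Reparameterisation: w = mu + sigma * eps with eps ~ N(0, 1), so for a fixed draw
    # of eps the estimate is a differentiable function of (mu, rho)
    sigma = softplus(rho)
    draws = (mu + sigma * rng.gauss(0.0, 1.0) for _ in range(n_samples))
    expected_ll = sum(log_lik(w, xs, ys) for w in draws) / n_samples
    return expected_ll - gauss_kl(mu, sigma)

def elbo_exact(mu, rho, xs, ys):
    # Closed form, available only because this toy model is linear-Gaussian
    sigma = softplus(rho)
    expected_ll = sum(-0.5 * ((y - mu * x) ** 2 + sigma ** 2 * x ** 2)
                      - 0.5 * math.log(2 * math.pi) for x, y in zip(xs, ys))
    return expected_ll - gauss_kl(mu, sigma)

xs, ys = [0.5, 1.0, 1.5, 2.0], [1.1, 1.9, 3.2, 3.9]
approx = elbo_mc(1.5, -1.0, xs, ys, n_samples=2000, rng=random.Random(0))
exact = elbo_exact(1.5, -1.0, xs, ys)
```

Because the toy model is linear-Gaussian, `elbo_exact` gives the same quantity in closed form, which is a handy check on the estimator. In a real network only the Monte Carlo version is available.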

Figure 2

Suppose we want to use these graphical-model methods for NNs. Relative to classical inference, two things are commonly difficult in this setting:

  1. The data is very big, in the sense that it contains many observations; possibly so many that we are happy to “waste” data on “suboptimal” approximation methods, trading statistical efficiency against computational efficiency.
  2. Some variables are high-dimensional even in isolation (e.g. images), so our methods must cope with that too.
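One standard way to make the first trade-off is minibatching, as in stochastic variational inference: estimate the full-data log-likelihood from a random batch of size \(B\) and rescale it by \(N/B\), so the estimate stays unbiased while each step touches only \(B\) observations. A minimal sketch, assuming a hypothetical unit-noise linear-Gaussian likelihood; the helper names are illustrative.

```python
import math
import random

def log_lik_term(w, x, y):
    # Hypothetical per-observation log-likelihood: y = w * x + unit-variance Gaussian noise
    return -0.5 * (y - w * x) ** 2 - 0.5 * math.log(2 * math.pi)

def full_log_lik(w, data):
    # Exact but expensive: touches all N observations
    return sum(log_lik_term(w, x, y) for x, y in data)

def minibatch_log_lik(w, batch, n_total):
    # Cheap, noisy, unbiased: rescale the batch sum by N / B
    return (n_total / len(batch)) * sum(log_lik_term(w, x, y) for x, y in batch)

def epoch_batches(data, batch_size, rng):
    # Shuffle once and split into disjoint minibatches (assumes batch_size divides len(data))
    idx = list(range(len(data)))
    rng.shuffle(idx)
    return [[data[i] for i in idx[k:k + batch_size]]
            for k in range(0, len(data), batch_size)]

rng = random.Random(0)
data = [(0.1 * i, 0.2 * i + 0.05 * (i % 3)) for i in range(100)]
batches = epoch_batches(data, 10, rng)
estimates = [minibatch_log_lik(2.0, b, len(data)) for b in batches]
```

Averaged over one epoch of disjoint batches, the rescaled estimates recover the full-data value up to rounding. Any single batch is noisier, and that noise is the statistical efficiency we give up in exchange for cheaper steps.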

1 When to optimise

2 How much rank

Gradient versus ensemble steps.

3 Variational belief propagation (speed run)

We say we can solve one of these problems if we can propagate information through the nodes of the graph and find the target marginal distribution, in some sense. This might be the target marginal in the sense that we can sample from it (a generative model) or in the sense that we can compute some density or expectation (a discriminative model). For now, we are agnostic about that.
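The two senses of “finding the marginal” can be illustrated on a tiny discrete model. The sketch below assumes a hypothetical chain of three spins \(s_i\in\{-1,+1\}\) with an energy function. It computes \(p(s_0=+1)\) by enumerating and normalising all \(2^n\) states (the density sense), and estimates the same number with a Gibbs sampler, which needs only energy differences and never the normaliser (the sampling sense). The enumeration cost grows as \(2^n\), which is one face of the intractability discussed next. Model parameters and function names are illustrative.

```python
import itertools
import math
import random

# Hypothetical toy model: three spins s_i in {-1, +1} on a chain, p(s) proportional to exp(-energy(s))
EDGES = {(0, 1): 0.8, (1, 2): -0.5}
FIELDS = [0.3, 0.0, -0.2]

def energy(s):
    pair = -sum(j * s[a] * s[b] for (a, b), j in EDGES.items())
    return pair - sum(h * si for h, si in zip(FIELDS, s))

def exact_marginal(i):
    # Density sense: enumerate all 2^n states, normalise by Z, sum out the other spins
    weights = {s: math.exp(-energy(s))
               for s in itertools.product((-1, 1), repeat=len(FIELDS))}
    z = sum(weights.values())
    return sum(w for s, w in weights.items() if s[i] == 1) / z

def gibbs_marginal(i, n_sweeps, rng):
    # Sampling sense: Gibbs sampling uses only energy differences, never Z
    # (no burn-in, for brevity)
    s = [1] * len(FIELDS)
    hits = 0
    for _ in range(n_sweeps):
        for k in range(len(s)):
            s[k] = 1
            e_up = energy(s)
            s[k] = -1
            e_down = energy(s)
            # p(s_k = +1 | other spins) = 1 / (1 + exp(e_up - e_down))
            s[k] = 1 if rng.random() < 1.0 / (1.0 + math.exp(e_up - e_down)) else -1
        hits += 1 if s[i] == 1 else 0
    return hits / n_sweeps

rng = random.Random(1)
exact = exact_marginal(0)
approx = gibbs_marginal(0, 20000, rng)
```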

Even with that relaxed notion of “marginalisation”, we cannot in general solve these problems: propagating information through these nodes is not generally tractable. Some special cases are tractable, but let us leap into the void and just seek approximate solutions.

Normally at this point we laboriously introduce many different inference methods, junction trees and identifiability criteria and so on. Then we say “OH NO BUT THAT DOES NOT ACTUALLY WORK ON MY PROBLEM!!!1!”.

There are two classes of difficulties with classical algorithms that seem ubiquitous:

  1. Our favoured algorithm is only exact on tree graphs, but we have a graph with cycles.
  2. Our favoured algorithm only works when the joint distributions at each node induce exponential-family messages on each edge.

Few interesting problems satisfy both of these conditions (tree structure and exponential-family messages), or indeed either.
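Difficulty 1 is easy to see numerically. The following minimal sketch runs standard sum-product belief propagation (parallel message updates) on a hypothetical three-spin model with unary fields \(h_i\) and couplings \(J_{ij}\), once on a chain (a tree) and once on a triangle (one cycle), and compares the beliefs with brute-force marginals. Model, parameters and function names are illustrative assumptions.

```python
import itertools
import math

STATES = (-1, 1)

def exact_marginals(n, fields, couplings):
    # Brute-force reference: p(s_i = +1) for each i by enumerating all 2^n states
    mass = [0.0] * n
    z = 0.0
    for s in itertools.product(STATES, repeat=n):
        w = math.exp(sum(h * si for h, si in zip(fields, s))
                     + sum(j * s[a] * s[b] for (a, b), j in couplings.items()))
        z += w
        for i in range(n):
            if s[i] == 1:
                mass[i] += w
    return [m / z for m in mass]

def bp_marginals(n, fields, couplings, n_iters=100):
    # Sum-product belief propagation with parallel message updates
    nbrs = {i: [] for i in range(n)}
    coupling = {}
    for (a, b), j in couplings.items():
        nbrs[a].append(b)
        nbrs[b].append(a)
        coupling[(a, b)] = coupling[(b, a)] = j
    # msgs[(i, k)][t] is the message from node i to node k, evaluated at s_k = STATES[t]
    msgs = {(i, k): [1.0, 1.0] for i in range(n) for k in nbrs[i]}
    for _ in range(n_iters):
        new = {}
        for (i, k) in msgs:
            out = []
            for sk in STATES:
                total = 0.0
                for t, si in enumerate(STATES):
                    incoming = math.prod(msgs[(m, i)][t] for m in nbrs[i] if m != k)
                    total += math.exp(fields[i] * si + coupling[(i, k)] * si * sk) * incoming
                out.append(total)
            norm = sum(out)
            new[(i, k)] = [v / norm for v in out]
        msgs = new
    beliefs = []
    for i in range(n):
        b = [math.exp(fields[i] * si) * math.prod(msgs[(k, i)][t] for k in nbrs[i])
             for t, si in enumerate(STATES)]
        beliefs.append(b[1] / sum(b))  # b[1] corresponds to s_i = +1
    return beliefs

fields = [0.2, 0.2, 0.2]
chain = {(0, 1): 1.0, (1, 2): 1.0}
triangle = {(0, 1): 1.0, (1, 2): 1.0, (2, 0): 1.0}
tree_gap = max(abs(a - b) for a, b in zip(bp_marginals(3, fields, chain),
                                          exact_marginals(3, fields, chain)))
loop_gap = max(abs(a - b) for a, b in zip(bp_marginals(3, fields, triangle),
                                          exact_marginals(3, fields, triangle)))
```

On the chain the beliefs match the exact marginals to rounding error. On the triangle the same updates converge to overconfident beliefs (here roughly 0.91 for \(p(s_i=+1)\) against an exact 0.76), because evidence circulates around the loop and is counted more than once.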

So let us speed run through the hand-wringing bit where we worry that we are only solving integrals approximately, and just go to some approximate inference methods that do pretty well in practical problems; we can sweat the details later.

I refer to this as…

4 YOLO belief propagation

Let us partition a node into groups of variables \(X_a\) and work out what we can do with those partitions. Suppose the density over the node factorises over these groups, so that the joint over all variates is

\[ p(X)=\frac{1}{Z} \prod_a f_a\left(X_a\right)=\frac{1}{Z} \prod_a e^{-E\left(X_a\right)} . \]

We will start out doing this heuristically, to build intuitions.
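As a first heuristic for working with this energy form, here is a sketch of naive mean-field, a variational scheme swapped in for illustration (it is not belief propagation). It restricts \(q\) to a fully factorised family and minimises the variational free energy \(F(q)=\mathbb{E}_q[\sum_a E(X_a)]-H(q)\), which upper-bounds \(-\log Z\). It assumes a hypothetical triangle of spins with unary energies \(-h_i s_i\) and pairwise energies \(-J_{ij}s_is_j\); all names are illustrative.

```python
import itertools
import math

# Hypothetical model in the p(X) = (1/Z) prod_a exp(-E(X_a)) form, spins s_i in {-1, +1}:
# unary energies -h_i * s_i and pairwise energies -J_ij * s_i * s_j
FIELDS = [0.2, 0.2, 0.2]
COUPLINGS = {(0, 1): 1.0, (1, 2): 1.0, (2, 0): 1.0}

def total_energy(s):
    return (-sum(h * si for h, si in zip(FIELDS, s))
            - sum(j * s[a] * s[b] for (a, b), j in COUPLINGS.items()))

def log_z():
    # Exact log partition function by enumeration (only feasible for tiny models)
    return math.log(sum(math.exp(-total_energy(s))
                        for s in itertools.product((-1, 1), repeat=len(FIELDS))))

def entropy(m):
    # Entropy of one spin with mean m, i.e. q(s = +1) = (1 + m) / 2
    h = 0.0
    for p in ((1 + m) / 2, (1 - m) / 2):
        if p > 0:
            h -= p * math.log(p)
    return h

def free_energy(ms):
    # F(q) = E_q[total energy] - H(q); under a factorised q, E[s_a s_b] = m_a * m_b
    expected = (-sum(h * m for h, m in zip(FIELDS, ms))
                - sum(j * ms[a] * ms[b] for (a, b), j in COUPLINGS.items()))
    return expected - sum(entropy(m) for m in ms)

def mean_field(n_sweeps=200):
    # Coordinate updates m_i = tanh(h_i + sum_j J_ij m_j), each minimising F over q_i
    ms = [0.0] * len(FIELDS)
    for _ in range(n_sweeps):
        for i in range(len(ms)):
            local = FIELDS[i] + sum(j * ms[b if a == i else a]
                                    for (a, b), j in COUPLINGS.items() if i in (a, b))
            ms[i] = math.tanh(local)
    return ms

ms = mean_field()
```

Each coordinate update minimises \(F\) exactly in one factor, so \(F\) never increases. The bound \(F(q)\ge-\log Z\) holds for any \(q\), and the gap is \(\mathrm{KL}(q\,\|\,p)\).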

5 Noise outsourcing formulation

TBD. See noise outsourcing.

6 Variational representations