Conjugate priors
June 26, 2024 — July 23, 2024
Assumed audience:
Data scientists who must pretend they can remember statistics
A conjugate prior is one that is closed under sampling given its matched likelihood function: after conditioning on data from that likelihood, the posterior lies in the same family as the prior. I occasionally see people talk about this as if it usefully applies to non-exponential-family likelihoods, but I am familiar with it only in the case of exponential families, so we restrict ourselves to that case here.
The idea seems to have arisen in the 1960s (DeGroot 2005; Raiffa and Schlaifer 2000) and to have been re-interpreted in the 1970s (Diaconis and Ylvisaker 1979). A pragmatic intro is Fink (1997). Robert (2007) chapter 3 is gentler.
Exponential families have tractable conjugate priors: the posterior distribution is in the same family as the prior, and moreover there is a simple formula for updating its parameters. This is deliciously easy, and it also misleads one into thinking that Bayesian inference is much easier in general than it actually is, because it is so easy in this one case.
We are going to observe lots of i.i.d. realisations of some variate \(X\sim p(x|\theta)\) and would like a consistent procedure for updating our beliefs about \(\theta\).
1 Exponential family likelihood
Our observation \(X\) is assumed to arise from an exponential family likelihood. That is, given a (vector) parameter \(\theta\), \(X\) has a density of the following form: \[ p(x \mid \theta) = h(x) \exp\left( \eta(\theta)^T T(x) - A(\theta) \right) \] Here:
- \(\eta(\theta)\) is the natural (canonical) parameter, which is some transform of the naive parameter \(\theta\). In the natural parameterization, the exponent couples parameter and data only through the inner product \(\eta^T T(x)\), which is linear in each of them. In fact, it is so much simpler to use the natural parameters that we drop \(\theta\) and just work with \(\eta\) hereafter.
- \(T(x)\) is the sufficient statistic (may be vector-valued).
- \(A(\theta)\) is the log-partition function.
- \(h(x)\) is the base measure.
Rewriting in natural parameters, we have \[ p(x \mid \eta) = h(x) \exp\left( \eta^T T(x) - A(\eta) \right). \]
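To make this concrete, here is a minimal sketch assuming the Bernoulli family (my choice of example, not something the text commits to). With natural parameter \(\eta = \log(p/(1-p))\), \(T(x) = x\), \(h(x) = 1\) and \(A(\eta) = \log(1 + e^\eta)\), the natural form reproduces the usual pmf. The helper names are hypothetical.

```python
import math

# Hypothetical helpers illustrating Bernoulli(p) in natural exponential-family form:
#   p(x | eta) = h(x) * exp(eta * T(x) - A(eta)),  x in {0, 1},
# with eta = log(p / (1 - p)), T(x) = x, h(x) = 1, A(eta) = log(1 + e^eta).

def softplus(eta: float) -> float:
    """Numerically stable log(1 + e^eta): the Bernoulli log-partition A(eta)."""
    return max(eta, 0.0) + math.log1p(math.exp(-abs(eta)))

def bernoulli_natural_pmf(x: int, eta: float) -> float:
    h = 1.0        # base measure h(x)
    t = float(x)   # sufficient statistic T(x)
    return h * math.exp(eta * t - softplus(eta))

p = 0.3
eta = math.log(p / (1 - p))  # natural parameter: the log-odds
print(bernoulli_natural_pmf(1, eta), bernoulli_natural_pmf(0, eta))  # approximately 0.3 and 0.7
```

The stable `softplus` avoids overflow in \(e^\eta\) for large \(|\eta|\); otherwise it is just \(A(\eta)\).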
If we knew \(\eta\) we would now have a distribution for \(X\). In practice, we are not sure about \(\eta\), so we have a prior distribution for \(\eta\). Things will go well for us if we choose this prior to have a particular, and particularly convenient, form.
2 Conjugate prior
The conjugate prior for \(\eta\) is designed to ensure that the posterior distribution remains within the same family after we observe a realization from the likelihood we just introduced.
The standard conjugate prior looks like this: \[ p(\eta \mid \lambda, \nu) = f(\lambda, \nu) \exp\left( \eta^T \lambda - \nu A(\eta) \right) \] where \(\lambda\) means something like ‘accumulated sufficient statistics from prior knowledge’ and \(\nu\) the ‘weight’ of the prior or the ‘number of prior observations’. These are effectively hyperparameters encoding how certain we are. Here \(f(\lambda, \nu)\) is the normalizing factor that makes the density integrate to 1 over \(\eta\), not an \(h\)-like base measure. The scaling of the log-partition function by \(\nu\) looks weird, and a little like tempering, but the prior is in fact itself an exponential family in \(\eta\), with sufficient statistic \((\eta, -A(\eta))\) and natural parameter \((\lambda, \nu)\).
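As a sanity check on \(f(\lambda, \nu)\) being a normalizing factor, here is a sketch for an assumed Bernoulli likelihood, where \(A(\eta) = \log(1 + e^\eta)\). It computes \(f\) once by brute-force quadrature over \(\eta\) and once in closed form. Substituting \(p = \sigma(\eta)\) shows the prior is a Beta\((\lambda, \nu - \lambda)\) distribution on \(p\), so \(1/f(\lambda, \nu) = B(\lambda, \nu - \lambda)\), proper only when \(0 < \lambda < \nu\). The function names and the integration window \([-40, 40]\) are illustrative choices.

```python
import math

def softplus(eta: float) -> float:
    """Stable log(1 + e^eta): the Bernoulli log-partition function A(eta)."""
    return max(eta, 0.0) + math.log1p(math.exp(-abs(eta)))

def log_beta(a: float, b: float) -> float:
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)

def f_closed_form(lam: float, nu: float) -> float:
    # Bernoulli case: substituting p = sigmoid(eta) turns exp(eta*lam - nu*A(eta)) d(eta)
    # into p**(lam-1) * (1-p)**(nu-lam-1) dp, so 1/f(lam, nu) = B(lam, nu - lam).
    # Only proper (normalizable) when 0 < lam < nu.
    return math.exp(-log_beta(lam, nu - lam))

def f_numerical(lam: float, nu: float, lo: float = -40.0, hi: float = 40.0,
                n: int = 20_000) -> float:
    # Trapezoid rule for the integral over eta of the unnormalized conjugate density;
    # truncating to [lo, hi] assumes the integrand is negligible outside that window.
    step = (hi - lo) / n
    total = 0.0
    for i in range(n + 1):
        eta = lo + i * step
        weight = 0.5 if i in (0, n) else 1.0
        total += weight * math.exp(eta * lam - nu * softplus(eta))
    return 1.0 / (total * step)

print(f_closed_form(2.0, 5.0))  # 1/B(2, 3) = 12, up to rounding
print(f_numerical(2.0, 5.0))    # should agree closely with the line above
```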
3 Prior predictive distribution
The prior predictive distribution for a new observation \(x\) is obtained by integrating the product of the likelihood and the prior over the natural parameter \(\eta\): \[ \begin{aligned} p(x) &= \int p(x \mid \eta) p(\eta \mid \lambda, \nu) \, d\eta\\ & = \int h(x) \exp(\eta^T T(x) - A(\eta)) f(\lambda, \nu) \exp(\eta^T \lambda - \nu A(\eta)) \, d\eta \\ &= h(x) f(\lambda, \nu) \int \exp\left(\eta^T (T(x) + \lambda) - (\nu + 1) A(\eta)\right) \, d\eta. \end{aligned} \]
This integral is the total mass of the unnormalized conjugate density with parameters updated to \(\lambda' = \lambda + T(x)\) and \(\nu' = \nu + 1\). Since \(f(\lambda', \nu')\) is exactly the factor that makes that density integrate to 1, the integral equals \(1/f(\lambda', \nu')\). Hence, the prior predictive distribution becomes: \[ p(x) = \frac{h(x) f(\lambda, \nu)}{f(\lambda + T(x), \nu + 1)} \]
The prior predictive distribution \(p(x)\) is the marginal density of an observation \(x\) before any data have been observed. It depends only on the hyperparameters \(\lambda\) and \(\nu\): it averages the likelihood \(p(x \mid \eta)\) over all values of the natural parameter \(\eta\), weighted by the prior, and so shows what the beliefs encoded in the prior imply about future data.
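Continuing with an assumed Bernoulli likelihood (\(T(x) = x\), \(h(x) = 1\), \(1/f(\lambda, \nu) = B(\lambda, \nu - \lambda)\)), the ratio formula for \(p(x)\) can be evaluated directly in log space. It recovers the familiar Beta-Bernoulli result \(p(x = 1) = \lambda/\nu\). Helper names are hypothetical.

```python
import math

# Assumed example: Bernoulli likelihood, T(x) = x, h(x) = 1, 1/f(lam, nu) = B(lam, nu - lam).

def log_beta(a: float, b: float) -> float:
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)

def log_f(lam: float, nu: float) -> float:
    # log of the conjugate normalizing factor; requires 0 < lam < nu
    return -log_beta(lam, nu - lam)

def prior_predictive(x: int, lam: float, nu: float) -> float:
    # p(x) = h(x) f(lam, nu) / f(lam + T(x), nu + 1), evaluated in log space
    return math.exp(log_f(lam, nu) - log_f(lam + x, nu + 1))

lam, nu = 2.0, 5.0
p1 = prior_predictive(1, lam, nu)
p0 = prior_predictive(0, lam, nu)
print(p1, p0)  # approximately 0.4 and 0.6, i.e. lam/nu and 1 - lam/nu
```

Working with \(\log f\) rather than \(f\) keeps the ratio stable when \(\lambda\) and \(\nu\) are large.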
4 Updating the prior
Let us suppose an observation \(x\sim X\) arrives. We would like a conjugate posterior update that incorporates the new information. The update to the conjugate prior’s parameters is:
- \(\lambda_{\text{posterior}} \gets \lambda + T(x)\), incorporating the new data’s sufficient statistic into the prior accumulated statistics, and
- \(\nu_{\text{posterior}} \gets \nu + 1\), an increment in the effective number of observations.
The posterior distribution of \(\eta\), after observing \(x\), in full, is thus: \[ p(\eta \mid \lambda + T(x), \nu + 1) = f(\lambda + T(x), \nu + 1) \exp\left( \eta^T (\lambda + T(x)) - (\nu + 1) A(\eta) \right) \]
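Because the update only touches \((\lambda, \nu)\), it amounts to a couple of additions. A sketch, assuming Bernoulli data so \(T(x) = x\); the `ConjugateParams` record is hypothetical, and the batch form simply iterates the single-observation rule over \(n\) i.i.d. observations.

```python
from dataclasses import dataclass
from typing import Sequence

@dataclass(frozen=True)
class ConjugateParams:
    lam: float  # accumulated sufficient statistics
    nu: float   # effective number of observations

def update(prior: ConjugateParams, t_x: float) -> ConjugateParams:
    # One observation with sufficient statistic T(x) = t_x.
    return ConjugateParams(prior.lam + t_x, prior.nu + 1)

def update_batch(prior: ConjugateParams, ts: Sequence[float]) -> ConjugateParams:
    # n i.i.d. observations: lam + sum of T(x_i), nu + n.
    return ConjugateParams(prior.lam + sum(ts), prior.nu + len(ts))

prior = ConjugateParams(lam=2.0, nu=5.0)
data = [1, 0, 1, 1, 0, 1]  # Bernoulli draws, so T(x) = x

post = prior
for x in data:
    post = update(post, x)

print(post)                       # ConjugateParams(lam=6.0, nu=11.0)
print(update_batch(prior, data))  # same result
```

In the Bernoulli case these hyperparameters correspond to a Beta\((\lambda, \nu - \lambda)\) distribution on \(p\), here Beta\((6, 5)\).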
5 Posterior predictive
As with the prior predictive, we need to integrate out the natural parameters of the likelihood. The posterior predictive distribution for a new observation \(x'\) given the observed data \(x\) is obtained by integrating over the posterior distribution of \(\eta\): \[ p(x' \mid x) = \int p(x' \mid \eta) p(\eta \mid x) \, d\eta \] Expanding this using the forms we derived above for \(p(x' \mid \eta)\) and \(p(\eta \mid x)\), we find: \[\begin{aligned} p(x' \mid x) &= \int h(x') \exp\left(\eta^T T(x') - A(\eta)\right) f(\lambda + T(x), \nu + 1) \exp\left(\eta^T (\lambda + T(x)) - (\nu + 1) A(\eta)\right) \, d\eta \\ &= h(x') f(\lambda + T(x), \nu + 1) \int \exp\left(\eta^T (T(x') + \lambda + T(x)) - (\nu + 2) A(\eta)\right) \, d\eta \end{aligned} \] This integral is the total mass of the unnormalized conjugate density with parameters \(\lambda' = \lambda + T(x) + T(x')\) and \(\nu' = \nu + 2\), so it equals \(1/f(\lambda', \nu')\). Hence, \[ p(x' \mid x) = \frac{h(x') f(\lambda + T(x), \nu + 1)}{f(\lambda + T(x) + T(x'), \nu + 2)} \]
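A sketch of the posterior predictive for an assumed Bernoulli likelihood, plus a check that \(p(x)\, p(x' \mid x) = p(x')\, p(x \mid x')\), as it must be for exchangeable observations. Multiplying the two ratio formulas shows both sides equal \(h(x) h(x') f(\lambda, \nu) / f(\lambda + T(x) + T(x'), \nu + 2)\). Helper names are hypothetical.

```python
import math

# Assumed example: Bernoulli likelihood, T(x) = x, h(x) = 1, 1/f(lam, nu) = B(lam, nu - lam).

def log_beta(a: float, b: float) -> float:
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)

def log_f(lam: float, nu: float) -> float:
    return -log_beta(lam, nu - lam)  # requires 0 < lam < nu

def prior_predictive(x: int, lam: float, nu: float) -> float:
    return math.exp(log_f(lam, nu) - log_f(lam + x, nu + 1))

def posterior_predictive(x_new: int, x_old: int, lam: float, nu: float) -> float:
    # p(x' | x) = h(x') f(lam + T(x), nu + 1) / f(lam + T(x) + T(x'), nu + 2)
    return math.exp(log_f(lam + x_old, nu + 1) - log_f(lam + x_old + x_new, nu + 2))

lam, nu = 2.0, 5.0
print(posterior_predictive(1, 1, lam, nu))  # approximately 0.5 = (lam + 1) / (nu + 1)

# Exchangeability check: the joint p(x, x') must not depend on the order of observation.
lhs = prior_predictive(1, lam, nu) * posterior_predictive(0, 1, lam, nu)
rhs = prior_predictive(0, lam, nu) * posterior_predictive(1, 0, lam, nu)
print(lhs, rhs)  # both approximately 0.2
```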
6 Updates
The conjugate prior for \(\eta\) is designed to ensure that the posterior distribution remains within the same family. We partition the parameter of the prior into \(\begin{bmatrix}\lambda \\ \nu \end{bmatrix}\). A conjugate prior has the following form: \[ \begin{aligned} p(\eta \mid \lambda, \nu) = f(\lambda, \nu) \exp\left( \eta^T \lambda - \nu A(\eta) \right) \end{aligned} \] where \(\lambda\) (something like ‘accumulated sufficient statistics from prior knowledge’) and \(\nu\) (the ‘weight’ of the prior or the ‘number of prior observations’) are effectively hyperparameters encoding how certain we are.
The prior predictive distribution for a new observation \(x\) is obtained by integrating the product of the likelihood and the prior over the natural parameter \(\eta\): \[ \begin{aligned} p(x) &= \int p(x \mid \eta) p(\eta \mid \lambda, \nu) \, d\eta\\ &= \int h(x) \exp(\eta^T T(x) - A(\eta)) f(\lambda, \nu) \exp(\eta^T \lambda - \nu A(\eta)) \, d\eta\\ &= h(x) f(\lambda, \nu) \int \exp\left(\eta^T (T(x) + \lambda) - (\nu + 1) A(\eta)\right) \, d\eta\\ &= \frac{h(x) f(\lambda, \nu)}{f(\lambda + T(x), \nu + 1)} \end{aligned} \] The last line follows from the observation that the integral represents the normalization constant of an exponential family distribution with parameters updated to \(\lambda' = \lambda + T(x)\) and \(\nu' = \nu + 1\). Thus, the integral simplifies to \(1/f(\lambda', \nu')\), where \(f\) is the normalizing factor.
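The ratio formula works for any family whose \(f\) we can compute. As an assumed second example, take a Poisson likelihood: \(\eta = \log \mu\), \(T(x) = x\), \(h(x) = 1/x!\), \(A(\eta) = e^\eta\). Substituting \(\mu = e^\eta\) gives a Gamma(shape \(\lambda\), rate \(\nu\)) prior on \(\mu\) and \(f(\lambda, \nu) = \nu^\lambda / \Gamma(\lambda)\) for \(\lambda, \nu > 0\). The sketch checks numerically that the resulting prior predictive is negative binomial with \(r = \lambda\) and success probability \(\nu/(\nu + 1)\). Function names are hypothetical.

```python
import math

# Assumed example: Poisson likelihood in natural form,
#   eta = log(mu), T(x) = x, h(x) = 1/x!, A(eta) = exp(eta),
# whose conjugate prior has 1/f(lam, nu) = Gamma(lam) / nu**lam (lam > 0, nu > 0).

def log_f(lam: float, nu: float) -> float:
    return lam * math.log(nu) - math.lgamma(lam)

def log_h(x: int) -> float:
    return -math.lgamma(x + 1)  # log(1/x!)

def prior_predictive(x: int, lam: float, nu: float) -> float:
    # p(x) = h(x) f(lam, nu) / f(lam + T(x), nu + 1)
    return math.exp(log_h(x) + log_f(lam, nu) - log_f(lam + x, nu + 1))

def neg_binom_pmf(x: int, r: float, q: float) -> float:
    # Negative binomial: number of failures x before r successes, success probability q.
    return math.exp(math.lgamma(r + x) - math.lgamma(r) - math.lgamma(x + 1)
                    + r * math.log(q) + x * math.log1p(-q))

lam, nu = 3.0, 2.0
for x in range(4):
    print(x, prior_predictive(x, lam, nu), neg_binom_pmf(x, lam, nu / (nu + 1)))
```

The predictive mean is \(\lambda/\nu\), the prior mean of \(\mu\), as one would hope.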
Let us suppose an observation \(x\sim X\) arrives. The posterior distribution of \(\eta\), after observing \(x\), continues to belong to the same family and is given by: \[ p(\eta \mid \lambda + T(x), \nu + 1) = f(\lambda + T(x), \nu + 1) \exp\left( \eta^T (\lambda + T(x)) - (\nu + 1) A(\eta) \right) \] i.e. the updated parameters are \(\lambda_{\text{new}} = \lambda + T(x)\) and \(\nu_{\text{new}} = \nu + 1\).
The posterior predictive distribution for a new observation \(x'\) given the observed data \(x\) is obtained by integrating over the posterior distribution of \(\eta\) \[ \begin{aligned} &p(x' \mid x) \\ &= \int p(x' \mid \eta) p(\eta \mid x) \, d\eta\\ &= \int h(x') \exp\left(\eta^T T(x') - A(\eta)\right) f(\lambda + T(x), \nu + 1) \exp\left(\eta^T (\lambda + T(x)) - (\nu + 1) A(\eta)\right) \, d\eta\\ & = h(x') f(\lambda + T(x), \nu + 1) \int \exp\left(\eta^T (T(x') + \lambda + T(x)) - (\nu + 2) A(\eta)\right) \, d\eta\\ &= \frac{h(x') f(\lambda + T(x), \nu + 1)}{f(\lambda + T(x) + T(x'), \nu + 2)} \end{aligned} \]
The last line follows from the observation that the integral term represents the normalizing constant of an updated exponential family distribution with parameters \(\lambda' = \lambda + T(x) + T(x')\) and \(\nu' = \nu + 2\). Thus, the integral simplifies to \(1/f(\lambda', \nu')\).
7 Mixtures
The underrated part of the conjugate prior story is that, while the priors themselves are not that flexible, some very interesting priors can be constructed as mixtures of conjugate priors, and such mixtures are themselves closed under sampling.
TBC. See Dalal and Hall (1983), O’Hagan (2010),…
Farrow’s tutorial introduction:
Consider what happens when we update our beliefs using Bayes’ theorem. Suppose we have a prior density \(f_j^{(0)}(\theta)\) for a parameter \(\theta\) and suppose the likelihood is \(L(\theta)\). Then our posterior density is \[ f_j^{(1)}(\theta)=\frac{f_j^{(0)}(\theta) L(\theta)}{C_j} \] where \[ C_j=\int_{-\infty}^{\infty} f_j^{(0)}(\theta) L(\theta) d \theta \]
Now let our prior density for a parameter \(\theta\) be \[ f^{(0)}(\theta)=\sum_{j=1}^J k_j^{(0)} f_j^{(0)}(\theta) . \] Our posterior density is \[ \begin{aligned} f^{(1)}(\theta) & =\frac{\sum_{j=1}^J k_j^{(0)} f_j^{(0)}(\theta) L(\theta)}{C} \\ & =\frac{\sum_{j=1}^J k_j^{(0)} C_j f_j^{(0)}(\theta) L(\theta) / C_j}{C} \\ & =\frac{\sum_{j=1}^J k_j^{(0)} C_j f_j^{(1)}(\theta)}{C} \end{aligned} \]
Hence we require \[ \frac{\sum_{j=1}^J k_j^{(0)} C_j}{C}=1 \] so \[ C=\sum_{j=1}^J k_j^{(0)} C_j \] and the posterior density is \[ f^{(1)}(\theta)=\sum_{j=1}^J k_j^{(1)} f_j^{(1)}(\theta) \] where \[ k_j^{(1)}=\frac{k_j^{(0)} C_j}{\sum_{i=1}^J k_i^{(0)} C_i} . \]
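Here is a sketch of that recipe for an assumed Bernoulli likelihood, where each component \(f_j^{(0)}\) is a conjugate prior with hyperparameters \((\lambda_j, \nu_j)\). Each component updates conjugately, and \(C_j\) is that component's marginal likelihood of the data, the \(n\)-observation analogue of the prior predictive, \(f(\lambda_j, \nu_j)/f(\lambda_j + \sum_i x_i, \nu_j + n)\). The two-component prior and the data are made up for illustration.

```python
import math

# Assumed example: Bernoulli likelihood, T(x) = x, h(x) = 1, 1/f(lam, nu) = B(lam, nu - lam).

def log_beta(a: float, b: float) -> float:
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)

def log_f(lam: float, nu: float) -> float:
    return -log_beta(lam, nu - lam)  # requires 0 < lam < nu

def update_mixture(weights, components, data):
    """Posterior of a mixture of conjugate priors given Bernoulli data.

    weights:    prior mixture weights k_j^(0)
    components: list of (lam_j, nu_j) conjugate hyperparameters
    data:       list of 0/1 observations
    """
    s, n = sum(data), len(data)
    new_components = [(lam + s, nu + n) for lam, nu in components]
    # log C_j: marginal likelihood of all the data under component j.
    log_c = [log_f(lam, nu) - log_f(lam2, nu2)
             for (lam, nu), (lam2, nu2) in zip(components, new_components)]
    # k_j^(1) proportional to k_j^(0) C_j; normalize in log space for stability.
    logs = [math.log(k) + lc for k, lc in zip(weights, log_c)]
    m = max(logs)
    unnorm = [math.exp(v - m) for v in logs]
    z = sum(unnorm)
    return [u / z for u in unnorm], new_components

# Hypothetical prior: one component with mean p = 0.2, one with mean p = 0.8.
weights, comps = update_mixture([0.5, 0.5], [(2.0, 10.0), (8.0, 10.0)], [1, 1, 1, 0, 1, 1])
print(weights, comps)
```

With mostly ones in the data, the weight shifts heavily to the component centred near \(p = 0.8\) (here by a factor of 33), while each component is updated exactly as a single conjugate prior would be.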
8 In nonparametrics
See (Broderick, Wilson, and Jordan 2018; Orbanz 2011).
9 Incoming