Disentangled representation learning aims to factor a data point’s latent encoding so that each dimension (or chunk) aligns with one underlying generative factor, like object pose, scale, or lighting. Further, we want each such dimension to be decoupled from the others: statistically independent or orthogonal, and (roughly) insensitive to changes in the other factors. What exactly we require depends on the architecture, AFAICT.
Whatever the precise requirement, the upshot is the same. Imagine encoding face images: one latent should “turn the head,” another “brighten the cheek,” and another “open the mouth.” When we achieve this, downstream tasks become easier, since we can e.g. turn someone’s head without changing their expression.
Justifications for this include:
- Interpretability: disentangled representations are easier to understand and visualise.
- Control: disentangled representations allow for more precise control over the generated samples.
- Generalization: disentangled representations can improve the generalization of models to unseen data. Or at least that’s what people claim. I’m sceptical of this as a blanket statement but think it might be interesting in causal settings.
Disentangling was big business in early generative AI, when we weren’t sure how to condition GANs or VAEs on specific features. We now use other tools to condition diffusion models, but disentangling may still be relevant today for interpretability / robustness.
1 Grandaddy example: β-VAE
The β-VAE augments the standard variational autoencoder objective by weighting the KL term with a factor β > 1:

$$\mathcal{L}_{\beta\text{-VAE}} = \mathbb{E}_{q_\phi(z\mid x)}\left[\log p_\theta(x\mid z)\right] - \beta\, D_{\mathrm{KL}}\left(q_\phi(z\mid x)\,\|\,p(z)\right)$$
This stronger bottleneck (larger β) encourages each latent dimension to carry only the minimal information needed, which pushes them toward independence—and often yields interpretable axes like “rotate” or “zoom.”
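For concreteness, here is a minimal PyTorch-style sketch of that loss, assuming a diagonal-Gaussian encoder that returns (mu, logvar) and a Bernoulli decoder that returns logits; the function name and conventions are illustrative, not from any particular codebase:

```python
import torch
import torch.nn.functional as F

def beta_vae_loss(x, x_logits, mu, logvar, beta=4.0):
    # Reconstruction term: negative log-likelihood under a Bernoulli decoder,
    # averaged over the batch.
    recon = F.binary_cross_entropy_with_logits(x_logits, x, reduction="sum") / x.size(0)
    # KL(q(z|x) || N(0, I)) in closed form for a diagonal-Gaussian posterior.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
    # beta > 1 tightens the information bottleneck on the latent code.
    return recon + beta * kl
```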
2 Fancier: total-correlation in β-TCVAE
β-TCVAE decomposes the VAE’s KL into three parts (mutual information, dimension-wise KL, and total correlation, TC) and then penalizes TC more heavily. Concretely,

$$\mathbb{E}_{p(x)}\left[D_{\mathrm{KL}}\left(q(z\mid x)\,\|\,p(z)\right)\right] = \underbrace{I_q(z;x)}_{\text{mutual information}} + \underbrace{D_{\mathrm{KL}}\left(q(z)\,\Big\|\,\textstyle\prod_j q(z_j)\right)}_{\text{total correlation}} + \underbrace{\textstyle\sum_j D_{\mathrm{KL}}\left(q(z_j)\,\|\,p(z_j)\right)}_{\text{dimension-wise KL}}$$

and the objective becomes

$$\mathcal{L}_{\beta\text{-TCVAE}} = \mathbb{E}_{q(z\mid x)}\left[\log p(x\mid z)\right] - \alpha\, I_q(z;x) - \beta\, D_{\mathrm{KL}}\left(q(z)\,\Big\|\,\textstyle\prod_j q(z_j)\right) - \gamma \sum_j D_{\mathrm{KL}}\left(q(z_j)\,\|\,p(z_j)\right)$$
By dialling up β on the TC term, we more directly push the joint q(z) toward a product of its marginals, which sharpens disentanglement.
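A minimal sketch of a minibatch estimate of the TC term, in the spirit of the β-TCVAE estimator but omitting the dataset-size correction (so it is biased); shapes and names are assumptions:

```python
import math
import torch

def log_normal_diag(z, mu, logvar):
    # Elementwise log-density of a diagonal Gaussian, one value per dimension.
    return -0.5 * (math.log(2 * math.pi) + logvar + (z - mu) ** 2 / logvar.exp())

def total_correlation(z, mu, logvar):
    # Naive minibatch estimate of TC = KL(q(z) || prod_j q(z_j)),
    # where q(z) is the aggregate posterior estimated over the batch.
    B, D = z.shape
    # log q(z_i | x_j) for every pair (i, j), per dimension: shape (B, B, D)
    log_qz_ij = log_normal_diag(z.unsqueeze(1), mu.unsqueeze(0), logvar.unsqueeze(0))
    # log q(z_i): aggregate posterior, joint over dimensions
    log_qz = torch.logsumexp(log_qz_ij.sum(dim=2), dim=1) - math.log(B)
    # log prod_j q(z_ij): product of the aggregate marginals
    log_qz_marginals = (torch.logsumexp(log_qz_ij, dim=1) - math.log(B)).sum(dim=1)
    return (log_qz - log_qz_marginals).mean()
```

With the `beta_vae_loss` sketch above, the objective for α = γ = 1 is then roughly `beta_vae_loss(..., beta=1.0) + (beta - 1.0) * total_correlation(z, mu, logvar)`, i.e. the plain ELBO plus an extra (β − 1)·TC penalty.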
3 InfoGAN
InfoGAN (Chen et al. 2016) is a GAN-based disentangling method in the same family. It augments the usual GAN min–max game with a mutual-information term to coax one part of the latent code to line up with a single factor:

$$\min_{G,Q}\,\max_{D}\; V_{\mathrm{GAN}}(D,G) - \lambda\, I\left(c;\, G(z,c)\right)$$

where $V_{\mathrm{GAN}}(D,G)$ is the standard GAN loss, $z$ is noise, $c$ is a “code” you hope will become interpretable (e.g. rotation, thickness), and the mutual-information term encourages $c$ to actually control something in the output. By maximizing a variational lower bound on $I(c; G(z,c))$, estimated with an auxiliary network $Q(c\mid x)$ that tries to recover the code from the generated image, InfoGAN learns codes whose traversals produce interpretable changes in the output.
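A minimal sketch of that mutual-information term for a categorical code; the networks `G`, `D`, `Q` and the helper `generator_gan_loss` are assumed to exist and are named only for illustration:

```python
import torch
import torch.nn.functional as F

def infogan_mi_loss(q_logits, c_idx, lam=1.0):
    # Variational lower bound on I(c; G(z, c)) for a categorical code:
    # the auxiliary head Q tries to recover the code c that was fed to G.
    # q_logits: Q's predicted logits for the code, shape (B, K)
    # c_idx:    the sampled code indices, shape (B,)
    return lam * F.cross_entropy(q_logits, c_idx)

# Hypothetical use inside the generator update (G, D, Q assumed defined):
#   z      = torch.randn(B, z_dim)
#   c_idx  = torch.randint(0, K, (B,))
#   c      = F.one_hot(c_idx, K).float()
#   fake   = G(torch.cat([z, c], dim=1))
#   g_loss = generator_gan_loss(D(fake)) + infogan_mi_loss(Q(fake), c_idx)
```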
4 Other examples
In general, it seems you pick a backbone (VAE, GAN, diffusion, …) and add one of these disentangling penalties. You have succeeded if traversing a single latent smoothly transforms just one aspect of the output (see the sketch below). A handful of dimensions are then devoted to “rotation”, “colour”, “thickness” etc., and the rest of the model behaves as before, except that now our face generator has a colour knob.
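A tiny sketch of that traversal check, assuming a decoder network and a batch of encoded latents `z`; names are illustrative:

```python
import torch

@torch.no_grad()
def traverse_latent(decoder, z, dim, values):
    # Sweep one latent dimension while holding the rest fixed; if the model is
    # disentangled, only one factor (pose, colour, ...) should change in the output.
    frames = []
    for v in values:
        z_mod = z.clone()
        z_mod[:, dim] = v
        frames.append(decoder(z_mod))
    return torch.stack(frames)

# e.g. traverse_latent(decoder, z, dim=3, values=torch.linspace(-3, 3, 9))
```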
5 Causal
Interesting, and why I got involved in this field in the first place, after seeing Yang et al. (2021). See Causal abstraction for more.
6 Questions
- How many factors can we disentangle?
- How much can this be made unsupervised?
- Are there underexploited tools in this toolkit for Mechanistic interpretability?
- Are there underexploited tools in this toolkit for Developmental interpretability?