A soft introduction to VAEs
Variational Auto-encoders (VAEs) are generative models introduced by Kingma and Welling way back in 2014.
In the ML domain that’s ancient history. But VAEs are still useful and arguably underpin many of the latest generative modelling achievements.
This post will cover:
- A light primer to VAEs, no maffs
- Good applications
- VAEs today
- Lowlights
Hopefully, along the way, we will elicit some new understanding and appreciation, or even love, of VAEs.
Primer
When you first google/prompt/ask about VAEs, you will be randomly assigned one of two routes: the easy way or the hard way.
The hard way, or textbook way, starts with some patter about variational Bayesian methods, progresses to a definition of the Evidence Lower Bound (ELBO), before finally degenerating into alternative decompositions of the resulting loss function. This is a difficult path that is best trodden later.
The Auto-encoder
The easy way starts with a description of the lesser Auto-encoder (AE), a simple design that compresses data through a latent bottleneck. Data is passed into the model, squeezed down by an encoder into a compressed latent vector (the bottleneck), then un-squeezed by a decoder back to the original data. The encoder and decoder can be more or less arbitrary neural nets.
The AE is then trained on this compression and decompression task by minimising the difference between its input and output. This difference is called the reconstruction loss: the smaller, the better. If you’re thinking this sounds like some kind of compression algorithm, then yes, that’s fine.
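To make the encode, decode and reconstruction-loss loop concrete, here is a deliberately tiny sketch using only the Python standard library. It assumes 2D inputs, a 1D bottleneck, and a hand-picked linear encoder and decoder (the hypothetical unit vector `W`) in place of trained neural nets; mean squared error stands in for the reconstruction loss, which is one common choice among several.

```python
import math

# Hypothetical toy auto-encoder: 2D point -> 1D latent -> 2D point.
# W is a fixed, hand-picked direction, not a trained neural net.
W = (1 / math.sqrt(2), 1 / math.sqrt(2))


def encode(x):
    """Squeeze a 2D point down to a single latent number z (the bottleneck)."""
    return x[0] * W[0] + x[1] * W[1]


def decode(z):
    """Un-squeeze the latent number back into a 2D point."""
    return (z * W[0], z * W[1])


def reconstruction_loss(x, x_hat):
    """Mean squared error between input and output: the smaller, the better."""
    return sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)


for x in [(1.0, 1.0), (2.0, 2.0), (1.0, -1.0)]:
    x_hat = decode(encode(x))
    print(x, "->", tuple(round(v, 3) for v in x_hat),
          "loss:", round(reconstruction_loss(x, x_hat), 3))
```

Points lying along the direction of `W` survive the round trip essentially untouched, while (1, -1), which is perpendicular to it, is squashed to zero and comes back as (0, 0). Training a real AE amounts to adjusting the encoder and decoder so that this loss is small across the whole dataset.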
Consider a dataset of images of dogs. During training, each image is passed into the encoder, compressed down to a smaller-dimensional vector, then un-compressed back to the original image (approximately) via the decoder. Once trained, the encoder provides a mapping from images into the latent representation, typically called \(Z\), and the decoder provides a mapping from \(Z\) back to the (approximate) original image.
┌─────────────┐ ┌─────────────┐
│ \ │ ─ ─ │ │
│ (. .)\ │ ── ┌────-─┐ ── │ /(. .)\ / │
│ (*)____ │ ─── | z0 | ─── │ (o)___/ │
│ / |\ │ ── | z1 | ── │ / | │
│ / |--\ | \│-Encoder─> | z2 | -Decoder─>│ / |--\ | │
│ (_)(_) (_) │ ── | .. | ── │ (_)(_) (_) │
│ │ ─── | zn | ─── │ │
│ │ ── └─────-┘ ── │ │
│ │ ─ ─ │ │
└─────────────┘ └─────────────┘
By squeezing images, which have many dimensions, through a small number of latent dimensions, we force the AE to make efficient use of the latent space. A consequence of this is that similar images are encoded close to each other in the latent space. Consider an MNIST example, i.e. passing images of handwritten digits (zero to nine) through a 2D vector \((z_0, z_1)\). The digits might end up distributed across the latent space something like this:
z1|
| 00 111
| 000 1111
| 0000 6666 111
| 6666
| 99 444
| 99999 4444
| 9 88 33 44
| 88888 333
| 88 333 77
| 555 77
| 555 22222 77
| 222
|________________________________
z0
Similar digits are clustered together. For example, the ones are close to other ones. But maybe also sevens that look a bit like fours end up close to fours that look like sevens, and so on.
This general idea of finding and using compressed representations of data is pretty fundamental to both (i) the core usefulness of VAEs, and (ii) their use in the state-of-the-art generative approaches of today. So keep it in mind.
Our AE is pretty cool, but it’s not a very good generative model. It’s more like an elaborate way of clustering or compressing. We could take some random samples from the latent space and decode them into images, and some of them might look realistic, but most would not. The best we could probably manage is to interpolate between the latent values of some training data, ideally two that are close together, and hope the decoder spits out something similar.
For example, we could sample between where we know the 6s and 8s are and get a digit that looks kind of like a 6 and kind of like an 8:
z1|
| 00 111
| 000 1111
| 0000 6666 111
| 6666
| 99 | 444
| 99999 X 4444
| 9 | 33 44
| 88888 333
| 88 333 77
| 555 77
| 555 22222 77
| 222
|________________________________
z0
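In its simplest form, interpolating in the latent space is just a weighted average of two latent vectors. A minimal sketch, assuming straight-line interpolation and made-up 2D codes for a 6 and an 8 (`z_six` and `z_eight` are hypothetical, not taken from a real model):

```python
def interpolate(z_a, z_b, t):
    """Linearly interpolate between two latent vectors.

    t=0 returns z_a, t=1 returns z_b, t=0.5 gives the midpoint.
    """
    return [(1 - t) * a + t * b for a, b in zip(z_a, z_b)]


# Hypothetical 2D latent codes for a "6" and an "8" from the training data.
z_six = [0.4, 0.7]
z_eight = [0.2, 0.3]

# Walk from the 6 to the 8 in five steps; each point would go to the decoder.
path = [interpolate(z_six, z_eight, i / 4) for i in range(5)]
midpoint = interpolate(z_six, z_eight, 0.5)  # roughly the "X" in the sketch above
```

Each point on `path` would be handed to the decoder. With a plain AE there is no guarantee that the intermediate points decode into anything that looks like a digit.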
Similarly, back to our doggy AE, we could sample between an image of a poodle and a labrador to get a labradoodle:
z1 ─── | | _ _ _ |
| ─── | | ( ) ( )( ) |
| P ─── | | | \__| | / |
| \ ──-Decoder-> | | __ | |
| X ─── | | |/ |(*)| |
| \ ─── | | \ /(. .)\ |
| L ─── | | |
└──────────z0 └─────────────┘
But more likely we’d get nonsense because the model hasn’t been trained to use the latent space this way. This general idea of using the whole latent space and maybe even making it meaningful — more up could be a happier dog, more right could be a smaller dog for example — is good intuition and a nice segue into VAEs.
More formally, we can think of the decoder as providing a conditional probability distribution \(p(x \mid z)\), where \(p(x)\) is our target distribution and \(p(z)\) the latent distribution. The decoder is not a good generative model because \(p(z)\) isn’t defined and, worse, is probably a nightmare landscape of sparse and discontinuous probability densities. So if we try to sample from it, we usually get nonsense from the decoder, and we certainly can’t recover the target distribution \(p(x)\).
The Variational Auto-encoder
The VAE imposes a prior distribution on the latent space, traditionally a standard (multivariate) normal distribution, \(p(z) = \mathcal{N}(0, I)\), so that the latent space becomes denser, smoother and more useful to sample from.
This prior is imposed during training by adding a component to the loss function that encourages the latent to be normally distributed. So now our loss function has two parts: (i) the reconstruction loss, and (ii) a regularisation loss that tries to keep the latent normally distributed. This regularisation is typically the Kullback–Leibler divergence (\(D_{KL}\)) between the distribution the encoder assigns to each input and the prior.
\[\mathcal{L} = \mathcal{L}_{\text{reconstruction}} + D_{KL}\]
If training goes well, the decoder learns \(p(x \mid z)\), i.e. our target distribution \(p(x)\) conditional on the latent \(z\), but now with \(z\) following the normal prior. So we can now sample new \(z\) values from \(\mathcal{N}(0, I)\), pass them into the decoder, and out comes a new sample from (approximately) \(p(x)\).
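Here is a sketch of the two-part loss and of generation, assuming the standard setup from Kingma and Welling in which the encoder outputs a mean `mu` and a log-variance `log_var` for each latent dimension (a diagonal normal distribution). Under that assumption the \(D_{KL}\) against \(\mathcal{N}(0, I)\) has a closed form, \(\tfrac{1}{2}\sum_j (\mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1)\). Mean squared error as the reconstruction loss is again an illustrative choice, and the encoder and decoder networks themselves are left out.

```python
import math
import random


def kl_to_standard_normal(mu, log_var):
    """Closed-form D_KL( N(mu, diag(exp(log_var))) || N(0, I) ), summed over latent dims."""
    return 0.5 * sum(m ** 2 + math.exp(lv) - lv - 1 for m, lv in zip(mu, log_var))


def mse(x, x_hat):
    """Reconstruction loss: mean squared error between input and decoder output."""
    return sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)


def vae_loss(x, x_hat, mu, log_var):
    """L = L_reconstruction + D_KL, the two-part VAE loss."""
    return mse(x, x_hat) + kl_to_standard_normal(mu, log_var)


def sample_prior(latent_dim, rng=random):
    """Draw z ~ N(0, I); after training, these are what we feed to the decoder."""
    return [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]


on_prior = kl_to_standard_normal([0.0, 0.0], [0.0, 0.0])    # 0.0
off_prior = kl_to_standard_normal([2.0, 0.0], [0.0, -2.0])  # shifted mean, narrow variance
z = sample_prior(2)  # pass z to a trained decoder to generate a new sample
```

An encoding that sits exactly on the prior (zero mean, unit variance) costs nothing, while one shifted away from the origin or squeezed too narrow is penalised.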
Applications
Density representation
VAEs are good at density estimation
But what is density estimation versus sample quality and why does it matter?
Well, you should really read about evaluating generative models in my previous post. But in a nutshell, generative models tend not to be able to have their cake and eat it. They can be very good at making realistic images (high sample quality) but not at covering the full distribution of all possible images (density estimation), for example. In the case of image generation this is no problem: sample quality dominates. Does this image of a person have the correct number of fingers?
But for my applications, where I want to represent realistic distributions of populations of people, density estimation is key.
VAEs are trained to maximise (a lower bound on) the likelihood of every training sample, so every sample has to be accounted for somewhere in the latent space. By keeping the latent space smooth and dense, a VAE is forced to encode the full breadth of the data, not just the easy parts. This makes VAEs well suited to applications where diversity matters, where you need your model to represent the full distribution, not just its greatest hits.
Efficiency
VAEs are great (the best I think) for when you don’t have much data or compute
The smooth latent space also makes VAEs very data efficient compared to alternatives. Because the encoder and decoder share a structured, continuous latent, the model generalises well from limited data. Where a GAN or auto-regressive approach (think LLM) might need to see millions of examples to learn a good distribution, the VAE’s regularisation provides a strong inductive bias that does some of that work for free.
We can abstractly think of a VAE as spreading its training data as points across a normal distribution. We can then sample infinitely from the gaps between the training data and approximately maintain the original distribution of data. It might be helpful to think of a VAE as learning to interpolate sensibly between data.
This simplification comes at a cost. Firstly, we need to be careful during training to make sure the data really is well distributed across the latent. Secondly, the spaces between the data are not necessarily mapped well by the decoder to nice (high-quality) samples. So the outputs of VAEs are typically considered low quality; generated images, for example, tend to be a bit blurry.
VAEs today
The VAE is arguably more relevant now than it was in 2014, just not always under its own name.
VAE-style encoders appear in many models wherever an efficient, structured latent representation is useful. The latent diffusion model (Rombach et al., 2022), the architecture behind Stable Diffusion and many of its descendants, uses a VAE to compress images into a compact latent space before the diffusion process operates on them. The diffusion model never touches pixels directly; it works in the VAE’s compressed latent space. This is computationally far cheaper and, empirically, produces better results. The VAE here is not the generative model, but it is load-bearing infrastructure.
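For a rough sense of how much work the VAE saves the diffusion model, here is the arithmetic for Stable Diffusion v1-style settings, assuming its autoencoder maps a 512×512 RGB image to a latent that is 8 times smaller on each side with 4 channels (other versions and configurations differ):

```python
def num_values(height, width, channels):
    """How many numbers a model has to process for a tensor of this shape."""
    return height * width * channels


# Assumed Stable Diffusion v1-style settings: 512x512 RGB input,
# spatial downsampling factor 8, 4 latent channels.
image_size, factor, latent_channels = 512, 8, 4

pixels = num_values(image_size, image_size, 3)
latent = num_values(image_size // factor, image_size // factor, latent_channels)
print(pixels, latent, pixels // latent)  # 786432 16384 48
```

Under these assumptions the diffusion model denoises 48 times fewer values at every step than it would in pixel space.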
The VAE is perhaps best understood not as a generative model in its own right, but as a very good way of building latent spaces. That turns out to be useful almost everywhere.
The nasty bits of VAEs
Tuning a VAE
VAEs require a balance between the reconstruction and regularisation (\(D_{KL}\)) losses. This is controlled by a hyperparameter \(\beta\) that weights the \(D_{KL}\) term, formalised as the \(\beta\)-VAE by Higgins et al. (2017). High \(\beta\) prioritises regularisation, pushing the latent hard towards the prior (smooth, dense and correctly distributed), but at the cost of reconstruction quality. Low \(\beta\) prioritises reconstruction but produces a messier latent that is harder to sample from.
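A small sketch of the \(\beta\)-weighted loss, \(\mathcal{L} = \mathcal{L}_{\text{reconstruction}} + \beta D_{KL}\), and of how \(\beta\) shifts which model looks best. The two "models" and their loss values are invented purely for illustration:

```python
def beta_vae_loss(recon_loss, kl, beta):
    """L = L_reconstruction + beta * D_KL; beta = 1 gives the plain VAE loss."""
    return recon_loss + beta * kl


# Two hypothetical trained models: (reconstruction loss, D_KL). Values are made up.
MODELS = {
    "sharp": (0.10, 3.0),   # good reconstructions, latent drifts away from the prior
    "smooth": (0.40, 0.5),  # blurrier reconstructions, latent hugs the prior
}


def preferred(beta, models=MODELS):
    """Name of the model with the lowest beta-weighted loss."""
    return min(models, key=lambda name: beta_vae_loss(*models[name], beta))


for beta in (0.1, 1.0, 4.0):
    print(f"beta={beta}: prefers {preferred(beta)}")
```

With a small \(\beta\) the sharp but poorly regularised model wins; as \(\beta\) grows, the smoother, better-regularised one is preferred. In practice the loss alone won’t tell you which is "right", which is why choosing \(\beta\) involves inspecting generated samples.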
This forms a trade-off between two failure modes: (i) poor reconstruction by the decoder, and (ii) poor regularisation of the latent. Finding a good balance typically requires plenty of search and evaluation of generated outputs.
It’s worth noting that \(D_{KL}\) is computed per sample, i.e. it is a local measure of how well each individual encoding matches the prior. A model can have a low average \(D_{KL}\) while still having pockets of the latent that are poorly regularised and will produce nonsense when sampled. The Wasserstein Auto-encoder (WAE) (Tolstikhin et al., 2018) addresses this by replacing the per-sample \(D_{KL}\) with a global distributional penalty, which is more directly aligned with the actual generative objective.
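One option in the WAE paper is a maximum mean discrepancy (MMD) penalty, which compares a whole batch of encoded latents with samples from the prior instead of scoring each encoding on its own. Below is a minimal 1D sketch of that idea; the Gaussian (RBF) kernel, its bandwidth, the batch sizes and the "good" and "bad" latent batches are all assumptions for illustration, not the paper’s settings.

```python
import math
import random


def rbf(a, b, bandwidth=1.0):
    """Gaussian (RBF) kernel between two 1D points; the bandwidth is an assumed setting."""
    return math.exp(-((a - b) ** 2) / (2 * bandwidth ** 2))


def mmd(xs, ys, kernel=rbf):
    """Biased estimate of the squared MMD between two sets of samples.

    Close to 0 when both sets look like they come from the same distribution.
    """
    k_xx = sum(kernel(a, b) for a in xs for b in xs) / len(xs) ** 2
    k_yy = sum(kernel(a, b) for a in ys for b in ys) / len(ys) ** 2
    k_xy = sum(kernel(a, b) for a in xs for b in ys) / (len(xs) * len(ys))
    return k_xx + k_yy - 2 * k_xy


rng = random.Random(0)
prior = [rng.gauss(0, 1) for _ in range(200)]
# Hypothetical encoded batches: one spread like the prior, one bunched up off to the side.
good_latents = [rng.gauss(0, 1) for _ in range(200)]
bad_latents = [rng.gauss(2, 0.3) for _ in range(200)]
print(mmd(good_latents, prior), mmd(bad_latents, prior))
```

The batch spread like the prior scores close to zero and the bunched-up batch scores high, even though the penalty never looks at any individual encoding.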
Conditionality is hard
Conditional VAEs (Sohn et al., 2015) — where you want to generate some output conditional on some input (in addition to the random sample of \(z\)) — are tricky. Structurally they are straightforward. Pass \(y\) into both encoder and decoder, and the model learns to generate \(x\) given \(y\) in addition to getting random variation from \(z\).
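A common, simple way to "pass \(y\) in" is to concatenate a one-hot encoding of it onto the encoder’s input and onto the decoder’s latent input. A minimal sketch of that wiring, assuming a discrete label such as a digit class; the 4-pixel image and the latent values are made up, and the encoder and decoder networks are omitted:

```python
def one_hot(label, num_classes):
    """Encode a class label y (e.g. which digit) as a one-hot vector."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


def encoder_input(x, y, num_classes):
    """The CVAE encoder sees the data x and the condition y."""
    return list(x) + one_hot(y, num_classes)


def decoder_input(z, y, num_classes):
    """The decoder sees the latent z and the same condition y."""
    return list(z) + one_hot(y, num_classes)


# Hypothetical 4-pixel "image" of a 7, a 2D latent, and 10 digit classes.
x = [0.0, 0.9, 0.8, 0.1]
z = [0.3, -1.2]
enc_in = encoder_input(x, 7, 10)  # 4 pixels + 10 label entries
dec_in = decoder_input(z, 7, 10)  # 2 latent dims + 10 label entries
```

To generate a specific digit you would fix `y`, draw `z` from \(\mathcal{N}(0, I)\) and decode `decoder_input(z, y, 10)`. Nothing in this wiring forces the decoder to actually pay attention to the label entries.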
But in practice, VAEs have a habit of ignoring \(y\), especially if \(y\) is weak or noisy: the encoder, which sees \(x\), can smuggle the same information through \(z\) instead. This is a close cousin of the posterior collapse problem (where the decoder ignores \(z\) altogether), and both are specific cases of a more general issue: the VAE will always find the laziest path to minimising its loss, and that path often involves not using information you wanted it to use.
Various fixes exist — \(D_{KL}\) annealing, free bits, weakening the decoder — but they are all patches on a fundamental tension. The reconstruction loss wants to use every available signal. The \(D_{KL}\) wants to forget everything that isn’t in the prior. Threading between these two pressures is the core difficulty of training VAEs, and conditionality makes it harder.
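Two of those patches are easy to sketch. \(D_{KL}\) annealing ramps the weight on the \(D_{KL}\) term up from zero during training, and free bits (Kingma et al., 2016) stops penalising \(D_{KL}\) below a floor. This is a simplified version under stated assumptions: a linear warm-up, a per-dimension floor (the original applies it to groups of latent dimensions averaged over a minibatch), and arbitrary defaults for the hypothetical `warmup_steps` and `free_bits` parameters.

```python
def kl_weight(step, warmup_steps):
    """Linear KL annealing: the D_KL weight ramps from 0 to 1 over warmup_steps.

    Early in training the model is free to use the latent for reconstruction;
    the pull towards the prior is switched on gradually.
    """
    return min(1.0, step / warmup_steps)


def free_bits_kl(kl_per_dim, free_bits):
    """Free bits: clamp each latent dimension's D_KL from below at free_bits.

    Below the floor the term is constant, so there is no gradient pushing
    that dimension any further towards the prior.
    """
    return sum(max(free_bits, kl) for kl in kl_per_dim)


def annealed_loss(recon_loss, kl_per_dim, step, warmup_steps=10_000, free_bits=0.5):
    """Combine both patches: L = L_reconstruction + w(step) * free-bits D_KL."""
    return recon_loss + kl_weight(step, warmup_steps) * free_bits_kl(kl_per_dim, free_bits)
```

Both work by temporarily or partially switching off the pull towards the prior, which is exactly why they are patches rather than cures.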