An introduction to variational autoencoders

November 04, 2023

We are all latent variable models

Here's one way of looking at learning. We interact with the world through observing (hearing, seeing) and acting (speaking, doing). We encode our observations about the world into some representation in our brain – and refine it as we observe more. Our actions reflect this representation.

Encoding & decoding

Imitation is an effective way to learn that engages both observation and action. For example, babies repeat the words of their parents. As they make mistakes and get corrected, they hone their internal representation of the words they hear (the encoder) as well as the way they create their own words from that representation (the decoder).

The baby tries to reconstruct the input via his internal representation. In this case, he incorrectly reconstructs "Dog" as "Dah".

Crudely casting this in machine learning terms, the representation is a vector $\mathbf{z}$ called a latent variable, which lives in the latent space. The baby is a latent variable model engaged in a task called reconstruction.

A note on notation: when talking about probability, I find it helpful to make explicit whether something is fixed or a variable in a distribution by making fixed things bold. For example, $\mathbf{z} = [0.12, -0.25, -0.05, 0.33, 0.02]$ is a fixed vector, and $p(x|\mathbf{z})$ is a conditional distribution over possible values of $x$. $p(\mathbf{x})$ is a number between $0$ and $1$ (a probability), while $p(x)$ is a distribution, i.e. a function of $x$.

Given observation $\mathbf{x}$, the encoder is a distribution $q(z|\mathbf{x})$ over the latent space; knowing $\mathbf{x} = \text{``Dog''}$, the encoder tells us which latent variables are probable. To obtain some $\mathbf{z}$, we sample from $q(z|\mathbf{x})$.

Similarly, given some latent variable $\mathbf{z}$, the decoder is a distribution $p(x|\mathbf{z})$. Sampling from the decoder produces a reconstruction $\mathbf{\tilde{x}}$.

The latent variable is a vector $\mathbf{z}$. The encoder and decoder are both conditional distributions.
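To make the two conditional distributions concrete, here is a minimal NumPy sketch. It assumes a diagonal Gaussian encoder and a unit-variance Gaussian decoder, with hypothetical random linear maps (`W_mu`, `W_logvar`, `W_dec`) standing in for the neural networks; the sizes are illustrative too. Sampling $\mathbf{z}$ as a mean plus scaled noise is the usual reparameterization trick.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: a 6-dimensional input compressed into a 2-dimensional latent space.
INPUT_DIM, LATENT_DIM = 6, 2

# Hypothetical weights standing in for trained encoder and decoder networks.
W_mu = rng.normal(size=(LATENT_DIM, INPUT_DIM))
W_logvar = rng.normal(scale=0.1, size=(LATENT_DIM, INPUT_DIM))
W_dec = rng.normal(size=(INPUT_DIM, LATENT_DIM))


def encode(x):
    """Parameters of q(z | x), assumed to be a diagonal Gaussian: (mean, log-variance)."""
    return W_mu @ x, W_logvar @ x


def sample_z(x, n_samples=1):
    """Draw z ~ q(z | x) as mean + std * noise (the reparameterization trick)."""
    mu, logvar = encode(x)
    eps = rng.standard_normal(size=(n_samples, LATENT_DIM))
    return mu + np.exp(0.5 * logvar) * eps


def sample_x(z):
    """Draw x_tilde ~ p(x | z), assumed here to be a unit-variance Gaussian around W_dec z."""
    mean = z @ W_dec.T
    return mean + rng.standard_normal(size=mean.shape)


x = rng.normal(size=INPUT_DIM)  # stand-in for an observation such as "Dog"
zs = sample_z(x, n_samples=5)   # five different latent vectors for the same x
x_tilde = sample_x(zs)          # one reconstruction per latent vector
print(zs.shape, x_tilde.shape)  # (5, 2) (5, 6)
```

Each call to `sample_z` gives different latent vectors for the same input, which is exactly what it means for the encoder to be a distribution rather than a single function value.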

The variational autoencoder

When neural networks are used as both the encoder and the decoder, the latent variable model is called a variational autoencoder (VAE).

Variational autoencoders are a type of encoder-decoder model. Figure from this blog post.

The latent space has fewer dimensions than the inputs, so encoding can be viewed as a form of data compression. The baby doesn't retain all the details of each syllable heard – the intricate patterns of each sound wave – only their compressed, salient features.

Evaluating reconstruction

A model that is good at reconstruction often gets it exactly right: $\mathbf{\tilde{x}} = \mathbf{x}$. Given some input $\mathbf{x}$, let's pick some random $\mathbf{z_{rand}}$ and look at $p(\mathbf{x}|\mathbf{z_{rand}})$: the probability of reconstructing the input perfectly. We want this number to be big.

But that's not really fair: what if we picked a $\mathbf{z_{rand}}$ that the encoder would never choose? After all, the decoder only sees the latent variables produced by the encoder. Ideally, we want to assign more weight to $\mathbf{z}$'s that the encoder is more likely to produce:

$$\sum_{\mathbf{z} \in \text{latent space}} q(\mathbf{z}|\mathbf{x}) p(\mathbf{x} | \mathbf{z})$$

This weighted average is an expectation over $q(z|\mathbf{x})$, written $\mathbb E_{\mathbf{z} \sim q(z|\mathbf{x})}$. Switching to the log probability for mathematical convenience (note 1), we get:

$$P_{\text{perfect reconstruction}}(\mathbf{x}) = \mathbb E_{\mathbf{z} \sim q(z|\mathbf{x})}[\log p(\mathbf{x} | \mathbf{z})]$$

If $P_{\text{perfect reconstruction}}(\mathbf{x})$ is high, we can tell our model that it did a good job. (Strictly, this is an expected log-probability rather than a probability, but since $\log$ is increasing, higher still means better.)
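In practice this expectation is estimated by drawing samples of $\mathbf{z}$ from the encoder and averaging $\log p(\mathbf{x} | \mathbf{z})$. The sketch below assumes a hypothetical one-dimensional model, $q(z|\mathbf{x}) = Normal(\mu_q, \sigma_q^2)$ and decoder $p(x|\mathbf{z}) = Normal(\mathbf{z}, 1)$, chosen only because the expectation then has a closed form to check against.

```python
import numpy as np

rng = np.random.default_rng(0)


def log_normal_pdf(x, mean, std):
    """log of the Normal(mean, std^2) density evaluated at x."""
    return -0.5 * np.log(2 * np.pi * std**2) - 0.5 * ((x - mean) / std) ** 2


def reconstruction_term_mc(x, mu_q, std_q, n_samples=100_000):
    """Monte Carlo estimate of E_{z ~ q(z|x)}[log p(x|z)].

    Assumes q(z|x) = Normal(mu_q, std_q^2) and a toy decoder p(x|z) = Normal(z, 1).
    """
    z = rng.normal(mu_q, std_q, size=n_samples)       # samples from the encoder
    return log_normal_pdf(x, mean=z, std=1.0).mean()  # average over the samples


def reconstruction_term_exact(x, mu_q, std_q):
    """Closed form for the same toy model, used to check the estimate."""
    return -0.5 * np.log(2 * np.pi) - 0.5 * ((x - mu_q) ** 2 + std_q**2)


x = 1.5
good = reconstruction_term_mc(x, mu_q=1.4, std_q=0.1)  # encoder centered near x
bad = reconstruction_term_mc(x, mu_q=-2.0, std_q=0.1)  # encoder far from x
print(good > bad)  # True
```

An encoder that places its mass on latent variables the decoder can turn back into $\mathbf{x}$ earns the higher score.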

Regularization

Neural networks tend to overfit. Imagine that our encoder learns to give each input it sees during training its own corner of the latent space, and the decoder exploits this obvious signal:

$$\mathbf{x} = \text{``Dog''} \xrightarrow{\text{encoder}} \mathbf{z} = [1, 0, 0, 0, 0] \xrightarrow{\text{decoder}} \mathbf{\tilde{x}} = \text{``Dog''}$$
$$\mathbf{x} = \text{``Doggy''} \xrightarrow{\text{encoder}} \mathbf{z} = [0, 1, 0, 0, 0] \xrightarrow{\text{decoder}} \mathbf{\tilde{x}} = \text{``Doggy''}$$

We would get perfect reconstruction! But we don't want this. The model failed to capture the close relationship between "Dog" and "Doggy". A good, generalizable model should treat them similarly by assigning them similar latent variables. In other words, we don't want our model to merely memorize and regurgitate the inputs.

While a baby's brain is exceptionally good at dealing with this problem, neural networks need a helping hand. One approach is to guide the distribution of the latent variable to be something simple and nice, like the standard normal:

$$p(z) = Normal(0, 1)$$

We talked previously about KL divergence, a measure of how different two probability distributions are; $D_{KL}(q(z | \mathbf{x}) || p(z))$ tells us how far the encoder has strayed from the standard normal.
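When the encoder outputs a diagonal Gaussian (a mean and a log-variance per latent dimension, the usual choice and an assumption here) and $p(z)$ is a standard normal in every dimension, this KL divergence has a closed form. A small NumPy sketch:

```python
import numpy as np


def kl_to_standard_normal(mu, log_var):
    """D_KL( Normal(mu, diag(exp(log_var))) || Normal(0, I) ), summed over latent dimensions.

    Closed form for two Gaussians; assumes the encoder outputs a diagonal Gaussian.
    """
    mu, log_var = np.asarray(mu, dtype=float), np.asarray(log_var, dtype=float)
    return 0.5 * np.sum(mu**2 + np.exp(log_var) - 1.0 - log_var)


print(kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))  # 0.0: the encoder matches the prior
print(kl_to_standard_normal([2.0, 0.0], [0.0, 0.0]))  # 2.0: the mean has drifted away from 0
```

Both a shifted mean and a variance far from 1 increase the penalty, which is how this term pulls the encoder back toward the standard normal.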

The loss function

Putting everything together, let's write down the intuition that we want the model to 1) reconstruct well and 2) have an encoder distribution close to the standard normal:

$$ELBO(\mathbf{x}) = \mathbb E_{\mathbf{z} \sim q(z|\mathbf{x})}[\log p(\mathbf{x} | \mathbf{z})] - D_{KL}(q(\mathbf{z} | \mathbf{x}) || p(\mathbf{z}))$$

This is the Evidence Lower BOund (ELBO), a quantity we want to maximize (we'll explain the name later!). The expectation captures our drive for perfect reconstruction, while the KL divergence term penalizes complex, nonstandard encoder distributions. This technique for preventing overfitting is called regularization.

In machine learning, we're used to minimizing things, so let's define a loss function whose minimization is equivalent to maximizing ELBO:

$$Loss(\mathbf{x}) = -ELBO(\mathbf{x})$$
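Putting the two terms together, here is a hedged sketch of the loss for a single observation. It assumes binary data with an independent Bernoulli decoder (common for black-and-white images, chosen here only for illustration), a diagonal Gaussian encoder, and a hypothetical fixed linear-plus-sigmoid map in place of a trained decoder network. The reconstruction term is estimated with a few samples of $\mathbf{z}$; the KL term uses its closed form.

```python
import numpy as np

rng = np.random.default_rng(0)


def vae_loss(x, mu, log_var, decode_probs, n_samples=1):
    """Loss(x) = -ELBO(x) for one observation, estimated with n_samples draws of z.

    x            : binary observation, shape (D,)
    mu, log_var  : parameters of the diagonal Gaussian encoder q(z | x), shape (K,)
    decode_probs : maps z of shape (S, K) to Bernoulli probabilities of shape (S, D)
    """
    std = np.exp(0.5 * log_var)
    z = mu + std * rng.standard_normal((n_samples, mu.size))  # z ~ q(z | x)
    p = np.clip(decode_probs(z), 1e-7, 1 - 1e-7)               # avoid log(0)
    # log p(x | z) for independent Bernoulli outputs, averaged over the samples of z
    log_px_given_z = np.sum(x * np.log(p) + (1 - x) * np.log(1 - p), axis=1).mean()
    # Closed-form D_KL(q(z | x) || Normal(0, I))
    kl = 0.5 * np.sum(mu**2 + np.exp(log_var) - 1.0 - log_var)
    return -(log_px_given_z - kl)


# Hypothetical decoder: a fixed random linear map followed by a sigmoid.
W = rng.normal(size=(8, 2))


def decode_probs(z):
    return 1.0 / (1.0 + np.exp(-z @ W.T))


x = rng.integers(0, 2, size=8)  # a made-up 8-pixel binary observation
loss = vae_loss(x, mu=np.zeros(2), log_var=np.zeros(2), decode_probs=decode_probs, n_samples=10)
print(loss)
```

With one or a few samples per observation this is a noisy estimate; it is typically averaged over a mini-batch and minimized by gradient descent.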

Some notes

Forcing $p(z)$ to be standard normal might seem strange. Don't we want the distribution of $z$ to be something informative learned by the model? I think about it like this: the encoder and decoder are complex functions with many parameters (they're neural networks!) and they have all the power. Under a sufficiently complex function, $p(z) = Normal(0,1)$ can be transformed into anything you want. The art is in this transformation.

On the left are samples from a standard normal distribution. On the right are those samples mapped through the function $g(z) = z/10 + z/\lVert z \rVert$. VAEs work in a similar way: they learn functions like $g$ that create arbitrarily complex distributions. Figure from reference 1.
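The mapping in the figure is easy to reproduce. This NumPy sketch pushes 2-D standard normal samples through $g$. Since $g(z) = z\,(1/10 + 1/\lVert z \rVert)$, every output lies at distance $1 + \lVert z \rVert / 10$ from the origin, which turns the round blob of samples into a ring.

```python
import numpy as np

rng = np.random.default_rng(0)


def g(z):
    """g(z) = z/10 + z/||z||, applied row-wise to 2-D points."""
    norms = np.linalg.norm(z, axis=1, keepdims=True)
    return z / 10 + z / norms


z = rng.standard_normal((1000, 2))  # samples from a standard normal
gz = g(z)                           # the same samples after the transformation

# Each mapped point sits at radius 1 + ||z||/10: a ring of width set by the 1/10 term.
radii = np.linalg.norm(gz, axis=1)
print(np.allclose(radii, 1 + np.linalg.norm(z, axis=1) / 10))  # True
```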

So far, we talked about variational autoencoders purely through the lens of machine learning. Some of the formulations might feel unnatural, e.g. why do we regularize in this weird way?

Variational autoencoders are actually deeply rooted in a field of statistics called variational inference, which supplies the first principles behind these decisions. That is the subject of the next section.

Variational Inference

Here's another way to look at the reconstruction problem. The baby has some internal distribution $p(z)$ over the latent space: his mental model of the world. Every time he hears and repeats a word, he makes some update to this distribution. Learning is nothing but a series of these updates.

Given some word $\mathbf{x} = \text{``Dog''}$, the baby performs the update:

$$p(z) \leftarrow p(z | \mathbf{x})$$

$p(z)$ is the prior distribution (before the update) and $p(z | \mathbf{x})$ is the posterior distribution (after the update). With each observation, the baby computes the posterior and uses it as the prior for the next observation. This approach is called Bayesian inference because to compute the posterior, we use Bayes' rule:

$$p(z | \mathbf{x}) = \frac{p(\mathbf{x} | z) p(z)}{p(\mathbf{x})}$$

This formula follows directly from manipulating the symbols (note 2), but I've always found it hard to understand what it actually means. In the rest of this section, I will try to provide an intuitive explanation.

The evidence

One quick aside before we dive in. $p(\mathbf{x})$, called the evidence, is a weighted average of the probabilities of $\mathbf{x}$ conditional on every possible latent variable $\mathbf{z}$:

$$p(\mathbf{x}) = \sum_{\mathbf{z} \in \text{latent space}} p(\mathbf{z})p(\mathbf{x} | \mathbf{z})$$

$p(\mathbf{x})$ is an averaged opinion across all $\mathbf{z}$'s that represents our best guess at how probable $\mathbf{x}$ is.

When the latent space is massive, as in our case, $p(\mathbf{x})$ is infeasible to compute: the sum has one term for every possible $\mathbf{z}$.
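With a tiny latent space the evidence is just a weighted sum. The five latent values and every probability below are made up for illustration:

```python
import numpy as np

# A made-up latent space with only 5 possible values of z.
prior = np.array([0.1, 0.2, 0.4, 0.2, 0.1])        # p(z), sums to 1
likelihood = np.array([0.9, 0.6, 0.2, 0.1, 0.05])  # p(x | z) for one fixed observation x

# Evidence: p(x) = sum over z of p(z) p(x | z), a prior-weighted average of the likelihoods.
evidence = np.sum(prior * likelihood)
print(evidence)

# With a latent vector of 10 coordinates, each allowed only 10 values, this sum would
# already have 10**10 terms; with a continuous z it becomes an integral.
```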

Bayesian updates

Let's look at Bayes' rule purely through the lens of the distribution update: $p(z) \leftarrow p(z | \mathbf{x})$.

  1. I have some preconception (prior), $p(z)$
  2. I see some $\mathbf{x}$ (e.g. "Dog")
  3. Now I have some updated mental model (posterior), $p(z | \mathbf{x})$

How should the new observation $\mathbf{x}$ influence my mental model? At the very least, we should increase $p(\mathbf{x})$, the probability we assign to observing $\mathbf{x}$, since we literally just observed it!

Under the hood, we have a long vector $p(z)$ with a probability value for each possible $\mathbf{z}$ in the latent space. With each observation, we update every value in $p(z)$.

Figure: a bar chart of $p(z)$ with one bar per $\mathbf{z}$; each update based on some observed $\mathbf{x}$ adjusts every bar. The probabilities are made up.

We can think of these bars (probabilities) as knobs we can tweak to adjust our mental model to better fit each new observation (without losing sight of previous ones).

Understanding the fraction

Let's take some random $\mathbf{z}$. Suppose $\mathbf{z}$ leads me to think that $\mathbf{x}$ is likely, say 60% ($p(\mathbf{x} | \mathbf{z}) = 0.6$), while the averaged opinion is only 20% ($p(\mathbf{x}) = 0.2$). Given that we just observed $\mathbf{x}$, $\mathbf{z}$ did better than average. Let's promote it by multiplying its assigned probability by:

$$\frac{p(\mathbf{x}|\mathbf{z})}{p(\mathbf{x})} = \frac{0.6}{0.2} = 3$$

The posterior is:

$$p(\mathbf{z} | \mathbf{x}) = 3 \cdot p(\mathbf{z})$$

Conversely, if $\mathbf{z}$ leads me to think that $\mathbf{x}$ is unlikely, say 20% ($p(\mathbf{x} | \mathbf{z}) = 0.2$), while the averaged opinion is 60% ($p(\mathbf{x}) = 0.6$), then $\mathbf{z}$ did worse than average. Let's decrease its assigned probability:

$$\frac{p(\mathbf{x}|\mathbf{z})}{p(\mathbf{x})} = \frac{0.2}{0.6} = \frac{1}{3} \implies p(\mathbf{z} | \mathbf{x}) = \frac{1}{3} \cdot p(\mathbf{z})$$

Either by promoting an advocate of $\mathbf{x}$ or demoting a naysayer, we 1) adjust the latent distribution $p(z)$ to better fit $\mathbf{x}$ and 2) bring up the average opinion, $p(\mathbf{x})$.

That's the essence of the update rule: it's all controlled by the fraction $\frac{p(\mathbf{x}|\mathbf{z})}{p(\mathbf{x})}$.
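The update rule can be sketched on a made-up five-value latent space. Multiplying each prior probability by the fraction $p(\mathbf{x}|\mathbf{z}) / p(\mathbf{x})$ gives the posterior, which still sums to 1, and the new averaged opinion of $\mathbf{x}$ (computed with the posterior in place of the prior) is higher:

```python
import numpy as np

prior = np.array([0.1, 0.2, 0.4, 0.2, 0.1])        # p(z), made up
likelihood = np.array([0.9, 0.6, 0.2, 0.1, 0.05])  # p(x | z), made up

evidence = np.sum(prior * likelihood)  # p(x), the averaged opinion
ratio = likelihood / evidence          # > 1 promotes an advocate of x, < 1 demotes a naysayer
posterior = ratio * prior              # Bayes' rule: p(z | x) = p(x | z) p(z) / p(x)

print(np.round(ratio, 2))
print(posterior.sum())                             # still a probability distribution
print(np.sum(posterior * likelihood) >= evidence)  # the averaged opinion of x went up: True
```

The last check is not a coincidence of these numbers: the new average equals $\mathbb E[p(\mathbf{x}|\mathbf{z})^2] / \mathbb E[p(\mathbf{x}|\mathbf{z})]$ under the prior, which is never smaller than $\mathbb E[p(\mathbf{x}|\mathbf{z})]$.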

Approximating the posterior

As we mentioned, the evidence $p(\mathbf{x})$ is infeasible to compute because it is a sum over all possible latent variables. Since $p(\mathbf{x})$ is the denominator of the Bayesian update, we can't actually compute the posterior distribution; we need to approximate it.

The two most popular methods for approximating complex distributions are Markov Chain Monte Carlo (MCMC) and variational inference. We talked about MCMC previously in various contexts. It uses a trial-and-error approach to generate samples from which we can then learn about the underlying complex distribution.

In contrast, variational inference looks at a family of distributions and tries to pick the best one. For illustration, we assume the observations follow a normal distribution and consider all distributions we get by varying the mean and variance.

Figure: adjusting the mean and variance of a normal distribution to fit the observations (blue dots). In essence, variational inference is all about making these adjustments.

Variational inference is a principled way to vary these parameters of the distribution (hence the name!) and find a setting of them that best explains the observations. Of course, in practice the distributions are much more complex.
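Here is a minimal sketch of the fitting game in the figure: gradient ascent on the mean and (log) standard deviation of a normal distribution. The data are made up, and "best" is taken to mean the highest average log-likelihood of the observations, an assumption for this illustration; in variational inference proper, the criterion is the KL divergence introduced next. For this family the optimum is known (the sample mean and standard deviation), which makes the result easy to check.

```python
import numpy as np

rng = np.random.default_rng(0)
observations = rng.normal(loc=3.0, scale=0.5, size=200)  # made-up data (the "blue dots")


def avg_log_likelihood(mean, log_std, data):
    """Average log Normal(mean, exp(log_std)^2) density of the data."""
    std = np.exp(log_std)
    return np.mean(-0.5 * np.log(2 * np.pi) - log_std - 0.5 * ((data - mean) / std) ** 2)


# Vary the two parameters of the family by gradient ascent on the average log-likelihood.
mean, log_std, lr = 0.0, 0.0, 0.1
for _ in range(5000):
    std = np.exp(log_std)
    resid = (observations - mean) / std
    grad_mean = np.mean(resid) / std        # derivative with respect to the mean
    grad_log_std = np.mean(resid**2) - 1.0  # derivative with respect to log_std
    mean += lr * grad_mean
    log_std += lr * grad_log_std

print(mean, np.exp(log_std))
print(avg_log_likelihood(mean, log_std, observations) > avg_log_likelihood(0.0, 0.0, observations))
```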

In our case, let's try to use some distribution $q(z | \mathbf{x})$ to approximate $p(z | \mathbf{x})$. We want $q(z | \mathbf{x})$ to be as similar to $p(z | \mathbf{x})$ as possible, which we can enforce by minimizing the KL divergence between them:

$$D_{KL}(q(z | \mathbf{x}) || p(z | \mathbf{x}))$$

If the KL divergence is $0$, then $q(z | \mathbf{x})$ perfectly approximates the posterior $p(z | \mathbf{x})$.

The Evidence Lower Bound (ELBO)

If you're not interested in the mathematical details, this section can be skipped entirely. TLDR: expanding out $D_{KL}(q(z | \mathbf{x}) || p(z | \mathbf{x}))$ yields the foundational equation of variational inference at the end of the section.

By definition of KL divergence and applying log rules:

\begin{align*}
D_{KL}(q(z | \mathbf{x}) || p(z | \mathbf{x})) &= \mathbb E_{\mathbf{z} \sim q(z|\mathbf{x})}\left[\log \frac{q(\mathbf{z} | \mathbf{x})}{p(\mathbf{z} | \mathbf{x})}\right] \\
&= \mathbb E_{\mathbf{z} \sim q(z|\mathbf{x})} \left[\log q(\mathbf{z} | \mathbf{x}) - \log p(\mathbf{z} | \mathbf{x}) \right]
\end{align*}

Apply Bayes' rule and log rules:

\begin{align*}
D_{KL}(q(z | \mathbf{x}) || p(z | \mathbf{x})) &= \mathbb E_{\mathbf{z} \sim q(z|\mathbf{x})} \left[\log q(\mathbf{z} | \mathbf{x}) - \log \frac{p(\mathbf{x} | \mathbf{z})p(\mathbf{z})}{p(\mathbf{x})} \right] \\
&= \mathbb E_{\mathbf{z} \sim q(z|\mathbf{x})} \left[\log q(\mathbf{z} | \mathbf{x}) - (\log p(\mathbf{x} | \mathbf{z}) + \log p(\mathbf{z}) - \log p(\mathbf{x}))\right] \\
&= \mathbb E_{\mathbf{z} \sim q(z|\mathbf{x})} \left[\log q(\mathbf{z} | \mathbf{x}) - \log p(\mathbf{x} | \mathbf{z}) - \log p(\mathbf{z}) + \log p(\mathbf{x})\right]
\end{align*}

Move $\log p(\mathbf{x})$ out of the expectation because it doesn't depend on $\mathbf{z}$:

$$D_{KL}(q(z | \mathbf{x}) || p(z | \mathbf{x})) = \mathbb E_{\mathbf{z} \sim q(z|\mathbf{x})} \left[\log q(\mathbf{z} | \mathbf{x}) - \log p(\mathbf{x} | \mathbf{z}) - \log p(\mathbf{z})\right] + \log p(\mathbf{x})$$

Separate terms into 2 expectations and group with log rules:

$$D_{KL}(q(z | \mathbf{x}) || p(z | \mathbf{x})) = \mathbb E_{\mathbf{z} \sim q(z|\mathbf{x})} \left[ \log \frac{q(\mathbf{z} | \mathbf{x})}{p(\mathbf{z})} \right] - \mathbb E_{\mathbf{z} \sim q(z|\mathbf{x})} \left[\log p(\mathbf{x} | \mathbf{z})\right] + \log p(\mathbf{x})$$

The first expectation is a KL divergence: $D_{KL}(q(z | \mathbf{x}) || p(z))$. Rewriting and rearranging:

$$\log p(\mathbf{x}) - D_{KL}(q(z | \mathbf{x}) || p(z | \mathbf{x})) = \mathbb E_{\mathbf{z} \sim q(z|\mathbf{x})} \left[\log p(\mathbf{x} | \mathbf{z})\right] - D_{KL}(q(z | \mathbf{x}) || p(z))$$

This is the central equation in variational inference. The right hand side is exactly what we have called the evidence lower bound (ELBO).
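The central equation can be checked numerically on a toy model small enough that $p(\mathbf{x})$ and the posterior are computable. The five-value latent space, the probabilities, and the randomly drawn $q(z|\mathbf{x})$ are all made up:

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up discrete model with 5 latent values.
prior = np.array([0.1, 0.2, 0.4, 0.2, 0.1])        # p(z)
likelihood = np.array([0.9, 0.6, 0.2, 0.1, 0.05])  # p(x | z) for a fixed x


def kl(q, p):
    """D_KL(q || p) for discrete distributions with full support."""
    return np.sum(q * np.log(q / p))


evidence = np.sum(prior * likelihood)      # p(x)
posterior = likelihood * prior / evidence  # p(z | x), via Bayes' rule

q = rng.dirichlet(np.ones(5))              # an arbitrary approximate posterior q(z | x)
elbo = np.sum(q * np.log(likelihood)) - kl(q, prior)

lhs = np.log(evidence) - kl(q, posterior)
print(np.isclose(lhs, elbo))  # True: log p(x) - KL(q || posterior) = ELBO
```

The right-hand side never touches `posterior` or `evidence`, which is the whole point: in a real model only the ELBO side can be computed.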

Interpreting ELBO

From expanding $D_{KL}(q(z | \mathbf{x}) || p(z | \mathbf{x}))$, we got:

$$\log p(\mathbf{x}) - D_{KL}(q(z | \mathbf{x}) || p(z | \mathbf{x})) = ELBO(\mathbf{x})$$

Since $D_{KL}(q(z | \mathbf{x}) || p(z | \mathbf{x}))$ cannot be negative (note 3), $ELBO(\mathbf{x})$ is a lower bound on the (log-)evidence, $\log p(\mathbf{x})$. That's why it's called the evidence lower bound!

Figure: maximizing ELBO, a lower bound on the (log-)evidence. Since $D_{KL}(q(z | \mathbf{x}) || p(z | \mathbf{x}))$ is the "distance" between ELBO and $\log p(\mathbf{x})$, our original goal of minimizing it brings ELBO closer to $\log p(\mathbf{x})$.
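The idea in the figure can be mimicked in a few lines. Using a made-up five-value latent space, this sketch slides $q(z|\mathbf{x})$ along a straight line from a poor guess to the true posterior. Because KL divergence is convex in its first argument, the gap $\log p(\mathbf{x}) - ELBO$ shrinks steadily and reaches zero when $q$ equals the posterior.

```python
import numpy as np

prior = np.array([0.1, 0.2, 0.4, 0.2, 0.1])        # p(z), made up
likelihood = np.array([0.9, 0.6, 0.2, 0.1, 0.05])  # p(x | z), made up
evidence = np.sum(prior * likelihood)              # p(x)
posterior = likelihood * prior / evidence          # p(z | x)


def elbo(q):
    """E_q[log p(x | z)] - D_KL(q || p(z)) for a discrete q."""
    return np.sum(q * np.log(likelihood)) - np.sum(q * np.log(q / prior))


q_start = np.array([0.05, 0.05, 0.1, 0.3, 0.5])  # a poor initial guess for q(z | x)
for t in np.linspace(0.0, 1.0, 6):
    q = (1 - t) * q_start + t * posterior        # slide q toward the true posterior
    gap = np.log(evidence) - elbo(q)             # equals D_KL(q || posterior) >= 0
    print(f"t={t:.1f}  ELBO={elbo(q):.4f}  gap={gap:.4f}")
```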

Let's think about the left hand side of the equation. Maximizing ELBO has two desired effects:

  1. increase $\log p(\mathbf{x})$. This is our basic requirement: since we just observed $\mathbf{x}$, $p(\mathbf{x})$ should go up!

  2. minimize $D_{KL}(q(z | \mathbf{x}) || p(z | \mathbf{x}))$, which satisfies our goal of approximating the posterior.

VAEs are neural networks that do variational inference

The machine learning motivations for VAEs we started with (encoder-decoder, reconstruction loss, regularization) are grounded in the statistics of variational inference (Bayesian updates, evidence maximization, posterior approximation). Let's explore the connections:

Each term below plays a role in both views.

$q(z | \mathbf{x})$

Variational inference: We couldn't directly compute the posterior $p(z | \mathbf{x})$ in the Bayesian update, so we try to approximate it with $q(z | \mathbf{x})$.

VAEs (machine learning): $q(z | \mathbf{x})$ is the encoder. Using a neural network as the encoder gives us the flexibility to do this approximation well.

$p(x | \mathbf{z})$

Variational inference: $\mathbb E_{\mathbf{z} \sim q(z|\mathbf{x})} \left[\log p(\mathbf{x} | \mathbf{z})\right]$ fell out as a term in ELBO, whose maximization accomplishes the dual goal of maximizing the intractable evidence, $\log p(\mathbf{x})$, and bringing $q(z | \mathbf{x})$ close to $p(z | \mathbf{x})$.

VAEs (machine learning): $p(x | \mathbf{z})$ is the decoder, also a neural network. $\mathbb E_{\mathbf{z} \sim q(z|\mathbf{x})} \left[\log p(\mathbf{x} | \mathbf{z})\right]$ is the (log-)probability of perfect reconstruction. It makes sense to strive for perfect reconstruction and maximize this quantity.

$p(z)$

Variational inference: $p(z)$ is the prior we use before seeing any observations. $p(z) = Normal(0, 1)$ is a reasonable choice: a starting point. It would take many observations that disobey $Normal(0, 1)$ to convince us, via Bayesian updates, of a drastically different latent distribution.

VAEs (machine learning): Our encoder and decoder are both neural networks. They're just black-box learners of complex distributions with no concept of priors. They can easily conjure up a wildly complex distribution, nothing like $Normal(0, 1)$, that merely memorizes the observations, a problem called overfitting.

To prevent this, we constantly nudge the encoder $q(z | \mathbf{x})$ towards $Normal(0, 1)$, as a reminder of where it would have started if we were using traditional Bayesian updates. When viewed this way, $D_{KL}(q(z | \mathbf{x}) || p(z))$ is a regularization term.

Modeling protein sequences

Pair-wise models are limiting

In a previous post, we talked about ways to extract the information hidden in Multiple Sequence Alignments (MSAs): the co-evolutionary data of proteins. For example, amino acid positions that co-vary in the MSA tend to interact with each other in the folded structure, often via direct 3D contact.

Figure: an MSA contains different variants of a sequence. The structure sketches how the amino acid chain might fold in space, with each MSA position mapped to an amino acid in the structure; the blue link connects contacting positions.

We talked about position-wise models that look at each position and pair-wise models that consider all possible pairs of positions. But what about interactions among 3 positions? Or even more? Those higher-order interactions are commonplace in natural proteins, but modeling them explicitly is computationally infeasible: the number of position combinations, and of amino acid combinations at each, grows rapidly with the order of the interaction.
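A rough count shows why. Assume, for illustration only, a model that gives every set of $k$ positions its own table of parameters with one entry per combination of symbols at those positions, as the pair-wise models do for $k = 2$. For a hypothetical 100-residue alignment with 21 symbols (20 amino acids plus a gap), the count is $\binom{100}{k} \cdot 21^k$:

```python
from math import comb


def n_interaction_params(length, alphabet, order):
    """Parameters if every set of `order` positions gets one entry per symbol combination."""
    return comb(length, order) * alphabet**order


# Hypothetical sizes: a 100-residue alignment with 21 symbols.
for order in (1, 2, 3, 4):
    print(order, f"{n_interaction_params(100, 21, order):,}")
```

Each extra order multiplies the count by roughly another factor of $21 \cdot 100 / k$, so explicit tables beyond pairs quickly outgrow both memory and the amount of data in a typical MSA.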

Variational autoencoders for proteins

Let's imagine that some latent variable vector $\mathbf{z}$ explains all interactions, including higher-order ones.

Applying latent variable models like VAEs to MSAs. Figure from reference 2.

Like the mysterious representation hidden in the baby's brain, we don't need to understand exactly how it encodes these higher-order interactions; we let the neural networks, guided by the reconstruction task, figure it out.
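To feed an MSA to a VAE, each aligned sequence has to become a vector $\mathbf{x}$. A common choice, assumed here, is one-hot encoding: one 0/1 indicator per (position, symbol) pair. The alphabet ordering and the tiny MSA below are made up:

```python
import numpy as np

ALPHABET = "ACDEFGHIKLMNPQRSTVWY-"  # 20 amino acids plus the alignment gap
INDEX = {aa: i for i, aa in enumerate(ALPHABET)}


def one_hot(aligned_seq):
    """Encode one MSA row as a flat 0/1 vector of length len(aligned_seq) * 21."""
    x = np.zeros((len(aligned_seq), len(ALPHABET)))
    for pos, aa in enumerate(aligned_seq):
        x[pos, INDEX[aa]] = 1.0
    return x.ravel()


msa = ["MKV-LA", "MRV-LA", "MKIDLG"]  # made-up MSA rows of equal aligned length
X = np.stack([one_hot(s) for s in msa])
print(X.shape)  # (3, 126)
```

Each row of `X` plays the role of an observation $\mathbf{x}$; the encoder maps it to a latent $\mathbf{z}$, and the decoder outputs a distribution over the 21 symbols at every position.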

In this work, researchers from the Marks lab did exactly this to create a VAE model called DeepSequence. I will do a deep dive on this model – and variants of it – in the next post!

Further reading

I am inspired by this blog post by Jaan Altosaar and this blog post by Lilian Weng, both of which are superb and go into more technical details.

Also, check out the cool paper from the Marks lab applying VAEs to protein sequences. You should have the theoretical tools to understand it well.

References

  1. Doersch, C. Tutorial on Variational Autoencoders. arXiv (2016).

  2. Riesselman, A.J. et al. Deep generative models of genetic variation capture the effects of mutations. Nat Methods 15, 816–822 (2018).

Notes

  1. We also use log probability for mathematical convenience.
  2. $p(x,z) = p(z|x)p(x) = p(x|z)p(z)$, from the definition of conditional probability. To see Bayes' rule, simply take $p(z|x)p(x) = p(x|z)p(z)$ and divide both sides by $p(x)$.
  3. This is a property of KL divergence: it is never negative, and it is zero exactly when the two distributions are equal.



Written by Liam Bai who works on software at Ginkgo Bioworks and writes about math, AI, and biology. He's on LinkedIn and Twitter.