
Deriving the variational lower bound

Basic properties of the variational lower bound, a.k.a. ELBO (evidence lower bound).

Often in probabilistic modelling, we are interested in maximising the probability our model assigns to some observed data, by tuning the model parameters \theta to maximise \prod_i p_\theta(x_i), where the x_i are the observed data points. The fact that we are maximising the product of the p_\theta(x_i) corresponds to an assumption that each x_i is drawn i.i.d. from the true data distribution p_{\text{data}}(x).

In practice, it’s mathematically and computationally much more convenient to consider the logarithm of the product, so that our objective to maximise with respect to \theta is:

\sum_i \log p_\theta(x_i)

In the rest of this post we’ll simplify things by just considering \log p_\theta(x) for a single data point x.
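To see the computational point concretely, here is a minimal sketch (assuming, purely for illustration, a standard normal density as a stand-in for p_\theta): the product of densities underflows to zero for even moderately large datasets, while the sum of log-densities stays perfectly well behaved.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
x = rng.normal(size=1000)        # observed data x_1, ..., x_n (illustrative)

# Product of densities: underflows to 0.0 long before n = 1000
print(np.prod(norm.pdf(x)))

# Sum of log-densities: perfectly well behaved (around -1400 here)
print(np.sum(norm.logpdf(x)))
```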

Latent Variable Models

In a Latent Variable Model (LVM), as is the case for Variational Autoencoders, our model distribution is obtained by combining a simple distribution p(z) with a parametrised family of conditional distributions p_\theta(x|z), so that our objective can be written

\log p_\theta(x) = \log \left( \int p_\theta(x|z) p(z) dz \right).

Although p(z) and p_\theta(x|z) will generally be chosen to be simple, it may be impossible to compute \log p_\theta(x) analytically due to the need to solve the integral inside the logarithm. In many practical situations (e.g. anything involving neural networks), we'd like not only to evaluate \log p_\theta(x) but also to differentiate it with respect to \theta if we are to fit the model.
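To make the difficulty concrete, here is a small sketch for a toy model chosen to be simple enough that the integral can actually be computed: p(z) = \mathcal{N}(0, 1) and p_\theta(x|z) = \mathcal{N}(\theta z, \sigma^2), whose marginal is \mathcal{N}(0, \theta^2 + \sigma^2) in closed form. The linear-Gaussian decoder and the particular values of \theta, \sigma and x are assumptions made purely for illustration; for a neural network decoder no such closed form or one-dimensional quadrature is available.

```python
import numpy as np
from scipy.stats import norm
from scipy.integrate import quad

theta, sigma, x = 2.0, 0.5, 1.3           # illustrative values

# Numerically integrate p_theta(x|z) p(z) over z, then take the log
integrand = lambda z: norm.pdf(x, loc=theta * z, scale=sigma) * norm.pdf(z)
integral, _ = quad(integrand, -10.0, 10.0)
print(np.log(integral))

# Closed-form marginal for this model: p_theta(x) = N(0, theta^2 + sigma^2)
print(norm.logpdf(x, loc=0.0, scale=np.sqrt(theta**2 + sigma**2)))
```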

Variational Inference

The magic of variational inference hinges on the following two key observations.

First, we can choose any distribution q(z) whose support covers that of p(z), multiply the integrand by \frac{q(z)}{q(z)} and rearrange without changing the value of the integral. (This has a strong connection to Importance Sampling, see below.) Thus we can rewrite our objective as

\log p_\theta(x) = \log \left( \int p_\theta(x|z) \frac{p(z)}{q(z)} q(z) dz \right).

Second, since \log is concave and the integral can be written as an expectation, we can use Jensen’s inequality to swap the \log and \mathbb{E}. This results in a (variational) lower bound consisting of terms we can evaluate, provided we have chosen p_\theta(x|z), p(z) and q(z) suitably:

\begin{aligned} \log p_\theta(x) &= \log \left( \mathbb{E}_{q(z)} p_\theta(x|z) \frac{p(z)}{q(z)} \right) \\&\geq \mathbb{E}_{q(z)} \left[ \log p_\theta(x|z) + \log p(z) - \log q(z) \right] \\&=\mathbb{E}_{q(z)} \left[ \log p_\theta(x|z) \right] - \text{KL}\left[q(z) || p(z) \right] \end{aligned}
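As a sanity check, here is a Monte Carlo sketch of this bound for the same toy linear-Gaussian model, with an arbitrarily chosen Gaussian q(z); the estimate of the right-hand side sits below the exact \log p_\theta(x), as the inequality promises.

```python
import numpy as np
from scipy.stats import norm

theta, sigma, x = 2.0, 0.5, 1.3           # same toy model as above
m, s = 0.3, 0.8                           # q(z) = N(m, s^2), chosen arbitrarily

rng = np.random.default_rng(0)
z = rng.normal(m, s, size=100_000)        # z ~ q(z)

elbo = np.mean(norm.logpdf(x, theta * z, sigma)   # log p_theta(x|z)
               + norm.logpdf(z)                   # log p(z)
               - norm.logpdf(z, m, s))            # log q(z)

print(elbo)                                                   # lower bound
print(norm.logpdf(x, 0.0, np.sqrt(theta**2 + sigma**2)))      # exact log p_theta(x)
```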

Recall that the above inequality holds for any q(z). Since the best choice of q turns out to depend on the data point in question, and we are probably interested in fitting the model to multiple data points, we can replace q(z) with a conditional distribution q_\phi(z|x) that depends on x and a parameter \phi. This is the notation you'll often see in the literature (e.g. the original VAE paper, equation (3)):

\begin{aligned} \log p_\theta(x) &\geq \mathbb{E}_{q_\phi(z|x)} \left[ \log p_\theta(x|z) \right] - \text{KL}\left[q_\phi(z|x) || p(z) \right] =: \mathcal{L}(x, \theta, \phi) \end{aligned}
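For the common special case q_\phi(z|x) = \mathcal{N}(\mu, \sigma^2) with a standard normal prior, the KL term has a closed form and the reconstruction term is typically estimated with a single reparameterised sample. The sketch below assumes that setup; in a real VAE, \mu and \sigma would be the outputs of an encoder network applied to x and decoder_logpdf would be a decoder network, so the names here are illustrative placeholders.

```python
import numpy as np
from scipy.stats import norm

def elbo_one_sample(x, mu, sigma, decoder_logpdf, rng):
    """One-sample estimate of L(x, theta, phi) with q_phi(z|x) = N(mu, sigma^2)."""
    # Closed-form KL[N(mu, sigma^2) || N(0, 1)]
    kl = 0.5 * (mu**2 + sigma**2 - 1.0 - np.log(sigma**2))
    # Single reparameterised sample z = mu + sigma * eps, eps ~ N(0, 1)
    z = mu + sigma * rng.standard_normal()
    return decoder_logpdf(x, z) - kl

# Illustration with the toy linear-Gaussian decoder from above;
# in a real VAE, mu and sigma would come from an encoder applied to x.
theta, sigma_dec = 2.0, 0.5
decoder_logpdf = lambda x, z: norm.logpdf(x, loc=theta * z, scale=sigma_dec)
rng = np.random.default_rng(0)
print(elbo_one_sample(1.3, mu=0.6, sigma=0.25, decoder_logpdf=decoder_logpdf, rng=rng))
```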

Note that the terms variational lower bound, evidence lower bound and ELBO are used interchangeably in the literature.

How tight is the variational lower bound?

By properties of the logarithm and one application of Bayes’ rule, it’s straightforward to calculate the tightness of this bound.

\begin{aligned} &\log p_\theta(x) - \mathcal{L}(x, \theta, \phi)\\&=\log p_\theta(x) - \mathbb{E}_{q_\phi(z|x)} \left[ \log p_\theta(x|z) \right] + \text{KL}\left[q_\phi(z|x) || p(z) \right] \\&= \mathbb{E}_{q_\phi(z|x)} \left[\log p_\theta(x) - \log p_\theta(x|z) + \log q_\phi(z|x) - \log p(z) \right] \\&= \mathbb{E}_{q_\phi(z|x)} \left[\log p_\theta(x) - \log \frac{p_\theta(z|x) \, p_\theta(x)}{p(z)} + \log q_\phi(z|x) - \log p(z) \right] \\&= \mathbb{E}_{q_\phi(z|x)} \left[\log q_\phi(z|x) - \log p_\theta(z|x) \right] \\&= \text{KL}\left[q_\phi(z|x) || p_\theta(z|x) \right]\end{aligned}

Summary

Writing the above equations in a slightly compressed form, we have

\log p_\theta(x) = \mathcal{L}(x, \theta, \phi) + \text{KL}\left[q_\phi(z|x) || p_\theta(z|x)\right] \geq \mathcal{L}(x, \theta, \phi)

To repeat this in words: the Jensen gap of the variational lower bound is the KL divergence between q_\phi(z|x) and the true posterior p_\theta(z|x). For a fixed \theta, maximising \mathcal{L}(x, \theta, \phi) with respect to \phi is equivalent to minimising \text{KL}\left[q_\phi(z|x) || p_\theta(z|x)\right]. This is why q_\phi(z|x) is often called the approximate posterior.
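For the toy linear-Gaussian model from the sketches above, the true posterior is available in closed form, so we can verify this identity numerically (up to Monte Carlo error in the ELBO estimate).

```python
import numpy as np
from scipy.stats import norm

theta, sigma, x = 2.0, 0.5, 1.3
m, s = 0.3, 0.8                                    # q_phi(z|x) = N(m, s^2)

# Exact posterior for this model: N(theta*x/(theta^2+sigma^2), sigma^2/(theta^2+sigma^2))
post_var = sigma**2 / (theta**2 + sigma**2)
post_mean = theta * x / (theta**2 + sigma**2)

# KL between two univariate Gaussians: KL[q_phi(z|x) || p_theta(z|x)]
kl = (np.log(np.sqrt(post_var) / s)
      + (s**2 + (m - post_mean)**2) / (2.0 * post_var)
      - 0.5)

# Monte Carlo estimate of the ELBO, as before
rng = np.random.default_rng(0)
z = rng.normal(m, s, size=100_000)
elbo = np.mean(norm.logpdf(x, theta * z, sigma) + norm.logpdf(z) - norm.logpdf(z, m, s))

print(elbo + kl)                                              # ELBO + Jensen gap
print(norm.logpdf(x, 0.0, np.sqrt(theta**2 + sigma**2)))      # exact log p_theta(x)
```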


Bonus material: connection to Importance Sampling

In principle, you could try to approximate the integral \int p_\theta(x|z) p(z) dz numerically by Monte Carlo sampling: draw a bunch of samples z_1, \ldots, z_k \sim p(z) and estimate the integral as

\int p_\theta(x|z) p(z) dz = \mathbb{E}_{p(z)}p_\theta(x|z) \approx \frac{1}{k}\sum_{i=1}^k p_\theta(x|z_i) .

Of course, this probably wouldn’t help much for fitting the model, as performing Monte Carlo integration inside an inner optimisation loop would be painfully slow. But there’s a second reason why this is a sub-optimal course of action.

p_\theta(x|z) is the probability of the particular data point x given z. Let’s suppose that for each z, p_\theta(x|z) only puts a significant amount of probability mass on a small set of x, and that this set differs as we vary z. (Note: this will be the case with Gaussian decoders with concentrated covariances for most non-trivial datasets.) Then for a fixed x, p_\theta(x|z) will be very small for most values of z and huge for a tiny set of values. In other words, our estimator \frac{1}{k}\sum_{i=1}^k p_\theta(x|z_i) will have extremely large variance.
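The toy linear-Gaussian model illustrates this if we shrink the decoder's standard deviation (values chosen purely for illustration): only a small fraction of prior samples carry any appreciable weight, and the spread of the weights is large compared to their mean.

```python
import numpy as np
from scipy.stats import norm

theta, sigma, x = 2.0, 0.01, 1.3                 # concentrated decoder (illustrative)
rng = np.random.default_rng(0)

z = rng.normal(size=10_000)                      # z_i ~ p(z)
w = norm.pdf(x, loc=theta * z, scale=sigma)      # p_theta(x | z_i)

print(w.mean())                                        # naive estimate of p_theta(x)
print(norm.pdf(x, 0.0, np.sqrt(theta**2 + sigma**2)))  # exact value
print(w.std() / w.mean())                              # relative spread of the weights: much larger than 1
print(np.mean(w > 1e-3 * w.max()))                     # only a small fraction of samples carry real weight
```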

We can improve things by using a trick called Importance Sampling, which really amounts to the observation that for any distribution q(z), multiplying the integrand by \frac{q(z)}{q(z)} and rearranging doesn’t change the value of the integral.

\begin{aligned}\int p_\theta(x|z) p(z) dz &= \int p_\theta(x|z) \frac{p(z)}{q(z)} q(z) dz \\ &= \mathbb{E}_{q(z)}p_\theta(x|z) \frac{p(z)}{q(z)} \\&\approx \frac{1}{k}\sum_{i=1}^k p_\theta(x|z_i)\frac{p(z_i)}{q(z_i)} \qquad z_1, \ldots, z_k \sim q(z) \end{aligned}

The idea here is that if q(z) is chosen to put more mass on values of z for which p_\theta(x|z) is large, the importance sampling estimator will have lower variance than the naive one. In fact, if we could choose q(z) = p_\theta(z|x) — the posterior distribution over z — our estimator would have zero variance! This means it would be possible to perfectly estimate the integral with only one sample. To see this, observe that by Bayes’ rule,

\begin{aligned} p_\theta(x|z)\frac{p(z)}{p_\theta(z|x)} &= p_\theta(x|z)\frac{p(z)p_\theta(x)}{p_\theta(x|z)p(z)} \\ &=p_\theta(x)\end{aligned}
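This is easy to check numerically for the toy model, where the posterior happens to be available in closed form (an assumption of the example, not something we have in general): every importance weight comes out equal to p_\theta(x).

```python
import numpy as np
from scipy.stats import norm

theta, sigma, x = 2.0, 0.01, 1.3
rng = np.random.default_rng(0)

# Closed-form posterior for the toy model:
# p_theta(z|x) = N(theta*x/(theta^2+sigma^2), sigma^2/(theta^2+sigma^2))
post_std = sigma / np.sqrt(theta**2 + sigma**2)
post_mean = theta * x / (theta**2 + sigma**2)

z = rng.normal(post_mean, post_std, size=5)          # z_i ~ p_theta(z|x)
weights = (norm.pdf(x, theta * z, sigma) * norm.pdf(z)
           / norm.pdf(z, post_mean, post_std))       # p_theta(x|z_i) p(z_i) / q(z_i)

print(weights)                                             # all equal to p_theta(x)
print(norm.pdf(x, 0.0, np.sqrt(theta**2 + sigma**2)))      # exact p_theta(x)
```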

So regardless of which z\sim p_\theta(z|x) we would draw, our one-sample Monte Carlo estimator would give the correct answer. Unfortunately, calculating p_\theta(z|x) itself requires knowing the value of p_\theta(x), so this insight doesn’t give us a trick to quickly calculate p_\theta(x)! It does, however, give us a connection to the Jensen gap of the variational bound. Since p_\theta(x|z)\frac{p(z)}{p_\theta(z|x)} is constant in z, \mathbb{E}_{p_\theta(z|x)}p_\theta(x|z) \frac{p(z)}{p_\theta(z|x)} is the expectation of a constant function and thus

\log p_\theta(x) = \log\left(\mathbb{E}_{p_\theta(z|x)}p_\theta(x|z) \frac{p(z)}{p_\theta(z|x)}\right) = \mathbb{E}_{p_\theta(z|x)} \log\left( p_\theta(x|z) \frac{p(z)}{p_\theta(z|x)}\right)

The right hand side is the variational lower bound with q_\phi(z|x) = p_\theta(z|x). This equation says that this bound is tight when the approximate posterior is equal to the true posterior, which we already learned above.