Variational inference

Machine Learning

Introduction

Variational inference is an approach for approximating conditional densities of latent variables given observed variables, especially useful in the case of complicated distributions. It is widely used to approximate posterior densities for Bayesian models, as an alternative to Markov chain Monte Carlo (MCMC) sampling methods. In the machine learning domain, it has been used extensively for inference in topic models such as latent Dirichlet allocation (LDA) due to its greater speed and easier scalability compared to MCMC approaches.

In this article, we will explain the foundations of variational inference and then provide an example walk-through of inferring densities over the latent parameters of a univariate Gaussian model using variational inference.

Prerequisites

To understand variational inference, we recommend familiarity with the concepts in

Follow the above links to first get acquainted with the corresponding concepts.

Problem setting

Variational inference is used in latent variable models — models over two groups of variables, in which hidden (latent) variables govern the generation of the observable variables.

Suppose the observable variables are denoted as a \( \ndim \)-dimensional multivariate random variable \( \vx \) and the hidden variables are denoted as an \( \nclass \)-dimensional multivariate random variable \( \vz \). The joint density of \( \vx \) and \( \vz \) is

$$ p(\vx,\vz) = p(\vz) p(\vx | \vz) $$

Intuitively, this formulation amounts to the following steps in data generation.

  1. Draw the latent variables \( \vz \) from the prior distribution \( p(\vz) \).
  2. Sample \( \vx \) from the likelihood of the observable variables given \( \vz \).

During inference on such Bayesian models, we need to infer the posterior distribution of the latent variables conditioned on the observed data — \( p(\vz | \vx) \). Since \( \vz \) is latent, this is often challenging, especially for complicated distributions. In machine learning, this is where variational inference typically helps — the inference of \( p(\vz | \vx) \).

Variational inference: The main idea

One way to address inference in complicated Bayesian models with latent variables is to use MCMC sampling. Although quite accurate, it is slow to converge and typically does not scale well.

Instead of the more traditional MCMC sampling, why not use an optimization approach to perform inference? The main idea behind variational inference is just that — cast inference as an optimization problem.

Intuitively, we assume that the conditional distribution \( p(\vz | \vx) \) can be closely approximated by some member of a family of distributions \( \mathcal{Q} \). Then, we find the distribution \( q(\vz) \) in the search space \( \mathcal{Q} \) that is most similar to \( p(\vz | \vx) \). As the measure of similarity between the distributions, variational inference uses the Kullback-Leibler divergence (KL-divergence).

We formalize this intuition as

\begin{equation} \star{q}(\vz) = \argmin_{q(\vz) \in \mathcal{Q}} D_{\text{KL}}\left( q(\vz) || p(\vz|\vx) \right) \end{equation}

where, \( D_{\text{KL}} (q || p) \) is the KL-divergence between the distributions \( q \) and \( p \).
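To make the objective concrete, here is a minimal NumPy sketch (the function name and the example distributions are illustrative, not from any particular library) that computes the KL-divergence between two discrete distributions. Note that the KL-divergence is not symmetric, which is why the direction \( D_{\text{KL}}(q || p) \) matters in variational inference.

```python
import numpy as np

def kl_divergence(q, p):
    """D_KL(q || p) for two discrete distributions given as probability vectors."""
    q = np.asarray(q, dtype=float)
    p = np.asarray(p, dtype=float)
    mask = q > 0                      # terms with q = 0 contribute zero by convention
    return np.sum(q[mask] * np.log(q[mask] / p[mask]))

# KL-divergence is asymmetric: D_KL(q || p) != D_KL(p || q) in general.
q = [0.4, 0.4, 0.2]
p = [0.7, 0.2, 0.1]
print(kl_divergence(q, p), kl_divergence(p, q))
```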

In the next couple of sections, we will understand the reasons behind the choice of KL-divergence as the loss function to be minimized for variational inference.

From likelihood to KL-divergence

Typically, a good generative model should have a high log-likelihood of having generated the observed data, the so-called evidence. In other words, for a data point \( \vx \), it is desirable to have a high value of \( \log p(\vx) \). Expanding this in terms of the latent variables, we get

\begin{aligned} \log p(\vx) &= \log \left(\int_{\vz} p(\vx,\vz) d\vz \right) \\\\ &= \log \left(\int_{\vz} p(\vx,\vz) \frac{q(\vz)}{q(\vz)} d\vz \right) \\\\ &= \log \left(\int_{\vz} \frac{p(\vx,\vz)}{q(\vz)} q(\vz) d\vz \right) \\\\ &= \log \left(\expect{\vz \sim q(\vz)}{\frac{p(\vx,\vz)}{q(\vz)}} \right) \\\\ \end{aligned}

In these steps, we first expanded the marginal into an integral over the joint density. If \( \vz \) follows a discrete probability distribution, imagine a summation instead of the integral; the remaining analysis stays the same with integrals replaced by summations.

Then, in the second step, we retained equality by multiplying and dividing by an arbitrary distribution \( q(\vz) \). Finally, through some simple mathematical manipulation, we identified this quantity to be an expectation over the distribution \( q(\vz) \), because \( \expect{a}{f(b,a)} = \int_{a} f(b,a) p(a) da \).

Now, note that logarithm is a concave function. So, by Jensen's inequality, \( \log \expect{}{a} \ge \expect{}{\log a} \).
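As a quick numerical sanity check of Jensen's inequality for the concave logarithm, the following sketch (using an arbitrary positive random variable chosen purely for illustration) compares \( \log \expect{}{a} \) against \( \expect{}{\log a} \) using samples.

```python
import numpy as np

rng = np.random.default_rng(1)
a = rng.lognormal(mean=0.0, sigma=1.0, size=100_000)  # an arbitrary positive random variable
lhs = np.log(np.mean(a))     # log E[a]
rhs = np.mean(np.log(a))     # E[log a]
print(lhs, rhs, lhs >= rhs)  # Jensen's inequality for the concave logarithm
```

Applying Jensen's inequality to our expression,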

\begin{aligned} \log p(\vx) &= \log \left(\expect{\vz \sim q(\vz)}{\frac{p(\vx,\vz)}{q(\vz)}} \right) \\\\ &\ge \expect{\vz \sim q(\vz)}{\log \left(\frac{p(\vx,\vz)}{q(\vz)} \right)} \\\\
\end{aligned}

Thus, using Jensen's inequality, we have obtained a lower bound on the log-likelihood.

Trudging along, with some basic mathematical manipulation, we get,

\begin{aligned} \log p(\vx) &\ge \expect{\vz \sim q(\vz)}{\log \left(\frac{p(\vx,\vz)}{q(\vz)} \right)} \\\\
&= \expect{\vz \sim q(\vz)}{\log \left(\frac{p(\vx)p(\vz|\vx)}{q(\vz)} \right)} \\\\
&= \expect{\vz \sim q(\vz)}{\log p(\vx)} + \expect{\vz \sim q(\vz)}{\log \frac{p(\vz|\vx)}{q(\vz)}}\\\\ &= \expect{\vz \sim q(\vz)}{\log p(\vx)} - D_{\text{KL}}\left( q(\vz) || p(\vz|\vx) \right) \\\\ \label{eqn:lik-to-kldiv} \end{aligned}

where, we have replaced \( \expect{\vz \sim q(\vz)}{\log \frac{p(\vz|\vx)}{q(\vz)}} \) with the KL-divergence term \( D_{\text{KL}}\left( q(\vz) || p(\vz|\vx) \right) \) because, \( D_{\text{KL}}(q || p) = \expect{a \sim q}{\log \frac{q(a)}{p(a)}} \).

The first term, \( \expect{\vz \sim q(\vz)}{\log p(\vx)} \), in the last equation is a constant with respect to the distribution \( q(\vz) \), since the evidence does not depend on \( \vz \). Thus, to make this lower bound on the evidence log-likelihood \( \log p(\vx) \) as tight as possible, we need to minimize the second term, the KL-divergence.

In fact, if we are able to drive the KL-divergence term to zero, the bound holds with equality, as is easy to check.

Therefore, minimizing the KL-divergence from the approximate distribution \( q(\vz) \) to the true conditional density \( p(\vz|\vx) \) is the key to tightening the bound on the log-likelihood of the evidence.

Evidence lower bound (ELBO)

Now that we know that the KL-divergence term needs to be minimized, we face another challenge: \( p(\vz|\vx) \) may be difficult to compute directly, given the complicated relationships among elements of \( \vz \) and \( \vx \). So, let's simplify the last equation by rewriting the conditional density \( p(\vz|\vx) \) as \( \frac{p(\vz,\vx)}{p(\vx)} \) (kind of working backwards through the previous equations!).

\begin{aligned} \log p(\vx) &\ge \expect{\vz \sim q(\vz)}{\log p(\vx)} - D_{\text{KL}}\left( q(\vz) || p(\vz|\vx) \right) \\\\ &= \expect{\vz \sim q(\vz)}{\log p(\vx)} + \expect{\vz \sim q(\vz)}{\log \frac{p(\vz|\vx)}{q(\vz)}}\\\\ &= \expect{\vz \sim q(\vz)}{\log p(\vx)} + \expect{\vz \sim q(\vz)}{\log \frac{p(\vz,\vx)}{p(\vx)q(\vz)}}\\\\ &= \expect{\vz \sim q(\vz)}{\log p(\vx)} + \expect{\vz \sim q(\vz)}{\log p(\vz,\vx)} - \expect{\vz \sim q(\vz)}{\log p(\vx)} - \expect{\vz \sim q(\vz)}{\log q(\vz)}\\\\ &= \expect{\vz \sim q(\vz)}{\log p(\vz,\vx)} - \expect{\vz \sim q(\vz)}{\log q(\vz)}\\\\ \label{eqn:elbo} \end{aligned}

This last bound that we derived in Equation \eqref{eqn:elbo} is known as the variational lower bound. It is also known as the evidence lower bound (ELBO) because it is a lower bound on the evidence \( \log p(\vx) \). Maximizing this lower bound with respect to \( q \) is equivalent to minimizing the KL-divergence between \( q(\vz) \) and \( p(\vz|\vx) \).

We simplify notation by writing the ELBO as a function of \( q \)

$$ \text{ELBO}(q) = \expect{\vz \sim q(\vz)}{\log p(\vz,\vx)} - \expect{\vz \sim q(\vz)}{\log q(\vz)} $$

Therefore, although the motivation is to minimize the KL-divergence \( D_{\text{KL}}\left( q(\vz)||p(\vz|\vx) \right) \), it is often infeasible to compute \( p(\vz|\vx) \), so variational inference approaches instead maximize the ELBO, which requires only the joint density \( p(\vz,\vx) \).

\begin{equation} \star{q} = \argmax_{q \in \mathcal{Q}} \text{ELBO}(q) \label{eqn:vi-opt}
\end{equation}
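To build intuition for the ELBO as an objective, here is a small Monte Carlo sketch on a toy model. The model, \( p(z) = \mathcal{N}(0,1) \) with \( p(x|z) = \mathcal{N}(z,1) \), and the Gaussian form of \( q \) are assumptions chosen purely for illustration; the estimate is largest when \( q \) matches the true posterior.

```python
import numpy as np
from scipy import stats

def elbo_monte_carlo(x, m, s, n_samples=100_000, seed=0):
    """Monte Carlo estimate of ELBO(q) = E_q[log p(z, x)] - E_q[log q(z)]
    for a toy model p(z) = N(0, 1), p(x | z) = N(z, 1), with q(z) = N(m, s^2)."""
    rng = np.random.default_rng(seed)
    z = rng.normal(m, s, size=n_samples)                 # z ~ q(z)
    log_joint = stats.norm.logpdf(z, 0.0, 1.0) + stats.norm.logpdf(x, z, 1.0)
    log_q = stats.norm.logpdf(z, m, s)
    return np.mean(log_joint - log_q)

# For this toy model the true posterior is p(z | x) = N(x / 2, 1 / 2), so the
# ELBO is closer to log p(x) for the matching q than for a poorer q.
x_obs = 1.0
print(elbo_monte_carlo(x_obs, m=0.5, s=np.sqrt(0.5)))  # near-optimal q
print(elbo_monte_carlo(x_obs, m=0.0, s=1.0))           # poorer q
```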

Factorized distributions

Now that we have set up the inference problem in terms of optimizing for \( q(\vz) \), we need to understand suitable choices for the search space \( \mathcal{Q} \). With many latent variables and complex interdependencies among them, \( \mathcal{Q} \) can become a complicated and infeasible search space, because \( q(\vz) \) jointly models all those latent variables.

Instead, the practice in variational inference is to use simpler distributions. This is typically achieved through the concept of factorized distributions — the idea of splitting (factorizing) a joint distribution into a product of easily modeled disjoint groups. For example, imagine that we split the vector \( \vz \) into \( K \) such groups — \( \vz_1,\ldots,\vz_K \) — such that each group can be modeled easily. In this case, the factorized joint distribution is

\begin{equation} q(\vz) = \prod_{k=1}^K q_k(\vz_k) \label{eqn:factorized-distribution} \end{equation}

Note the subscript on the factors \( q_k \). This is intentional, to denote that these group-level distributions are specific to their group. Moreover, the simple product of these factors implies that the groups are modeled as independent of each other.

This factorized form commonly used in variational inference corresponds to an approximation framework used in physics known as mean field theory.

Equipped with the optimization problem in Equation \eqref{eqn:vi-opt} and a simplifying factorization such as that in Equation \eqref{eqn:factorized-distribution}, we have significantly reduced the challenge of modeling complicated conditional densities. We can choose any simplifying factorization, and then maximize the ELBO, thereby minimizing the KL-divergence, to approximate the conditional density. For groups of continuous latent variables, we can use a continuous factor such as a multivariate Gaussian distribution, and for groups of discrete latent variables, we can use a categorical distribution. Such is the power of variational inference.
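As a small illustration of such a mean-field factorization (the specific factors and their parameters below are hypothetical), the log-density of \( q(\vz) \) is simply the sum of the group-level log-densities, and continuous and discrete factors can be mixed freely.

```python
import numpy as np
from scipy import stats

# A mean-field approximation q(z) = q_1(z_1) * q_2(z_2) with a Gaussian factor
# for a continuous group and a categorical factor for a discrete group.
q1 = stats.multivariate_normal(mean=[0.0, 1.0], cov=np.eye(2))  # q_1(z_1)
q2_probs = np.array([0.2, 0.5, 0.3])                            # q_2(z_2)

def log_q(z1, z2):
    # Independence across groups means the log-densities simply add.
    return q1.logpdf(z1) + np.log(q2_probs[z2])

print(log_q(np.array([0.1, 0.9]), 2))
```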

Optimum factors

Given that we can adopt any simple factors for our factorized joint distribution of latent variables, we are now ready to derive a generic solution for the optimal value of each factor using the evidence lower bound.

Recapping some previous equations

$$ \star{q} = \argmax_{q \in \mathcal{Q}} \text{ELBO}(q) $$

where,

$$ \text{ELBO}(q) = \expect{\vz \sim q(\vz)}{\log p(\vz,\vx)} - \expect{\vz \sim q(\vz)}{\log q(\vz)} $$

Suppose we are using the factorization, \( q(\vz) = \prod_{k=1}^K q_k(\vz_k) \). We can now write the ELBO in terms of these factors.

To optimize the factor over the group \( \vz_k \), we can treat the remaining groups \( \vz_{-k} = \set{\vz_1,\ldots,\vz_{k-1},\vz_{k+1},\ldots,\vz_{K}} \) as fixed and rewrite the ELBO as

\begin{aligned} \text{ELBO}(q_k) &= \expect{\vz_k \sim q_k(\vz_k)}{\expect{\vz_{-k}}{\log p(\vz_k, \vz_{-k}, \vx)}} - \expect{\vz_k \sim q_k(\vz_k)}{\expect{\vz_{-k}}{\log \left(q_k(\vz_k) \prod_{\vz_j \in \vz_{-k}} q_j(\vz_j) \right)}} \\\\ &= \expect{\vz_k \sim q_k(\vz_k)}{\expect{\vz_{-k}}{\log p(\vz_k, \vz_{-k}, \vx)}} - \expect{\vz_k \sim q_k(\vz_k)}{\log q_k(\vz_k)} - \expect{\vz_{-k}}{\log \left(\prod_{\vz_j \in \vz_{-k}} q_j(\vz_j) \right)} \\\\ &= \expect{\vz_k \sim q_k(\vz_k)}{\expect{\vz_{-k}}{\log p(\vz_k, \vz_{-k}, \vx)}} - \expect{\vz_k \sim q_k(\vz_k)}{\log q_k(\vz_k)} + \text{constant} \\\\ &= \expect{\vz_k \sim q_k(\vz_k)}{\log a} - \expect{\vz_k \sim q_k(\vz_k)}{\log q_k(\vz_k)} + \text{constant} \\\\ &= - \expect{\vz_k \sim q_k(\vz_k)}{\log \frac{q_k(\vz_k)}{a}} + \text{constant} \\\\ &= - D_{\text{KL}}\left(q_k(\vz_k) || a\right) + \text{constant} \\\\ \end{aligned}

In the third step, we folded all the terms that are constant with respect to the expectation over \( q_k(\vz_k) \) into a \( \text{constant} \).

Then, in the fourth step, to simplify notation, we used a substitute variable \( a \), such that

$$ a = \alpha \textexp{ \expect{\vz_{-k}}{\log p(\vz_k, \vz_{-k}, \vx)}} $$

where, the constant \( \alpha \) consumes all the constant terms from the previous step.

With this substitution, we find in the last step that the ELBO, viewed as a function of \( q_k(\vz_k) \), is equal to the negative of the KL-divergence between \( q_k(\vz_k) \) and \( a \), plus some constant term. Thus, maximizing the ELBO is equivalent to minimizing this KL-divergence term. The KL-divergence between two distributions is minimized when the two distributions are equal, that is, when \( q_k(\vz_k) \) is equal to \( a \).

In other words, the ELBO is maximized for the optimum factor

\begin{equation} \star{q}_k(\vz_k) \propto \textexp{ \expect{\vz_{-k}}{\log p(\vz_k, \vz_{-k}, \vx)}} \label{eqn:optimum-factors} \end{equation}

where, we have subsumed the constant \( \alpha \) into the direct proportionality. In practice, the constant is figured out by inspection, as will be clear in subsequent discussion.

Coordinate ascent variational inference (CAVI)

In Equation \eqref{eqn:optimum-factors} for finding the optimum factors, we see that \( \star{q}_k(\vz_k) \) depends on an expectation over all other latent variables \( \vz_{-k} \), whose factors may themselves be unknown initially.

This suggests an iterative strategy for arriving at the optimum factors for all variables.

  1. Initialize the factors \( q_k(\vz_k) \) for all \( k = 1, \ldots, K \)
  2. For all \( k = 1, \ldots, K \), iteratively set \( q_k \propto \textexp{ \expect{\vz_{-k}}{\log p(\vz_k, \vz_{-k}, \vx)}} \)
  3. Compute \( \text{ELBO}(q) \). If ELBO has not converged from its previous computation, go back to step 2.

This coordinate-wise iterative algorithm is known as coordinate ascent variational inference (CAVI).
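In code, the CAVI loop has a simple structure. The following is a generic skeleton, not a library API; `update_factor` and `elbo` are hypothetical model-specific callables that the user supplies.

```python
def cavi(factors, update_factor, elbo, max_iter=100, tol=1e-6):
    """Generic CAVI skeleton (a sketch under the assumptions stated above).

    factors       : list of current factor parameters [q_1, ..., q_K]
    update_factor : callable (k, factors) -> new parameters of q_k, implementing
                    q_k proportional to exp(E_{z_{-k}}[log p(z_k, z_{-k}, x)])
    elbo          : callable factors -> scalar value of ELBO(q)
    """
    prev_elbo = -float("inf")
    for _ in range(max_iter):
        # Step 2: update each factor in turn, holding the others fixed.
        for k in range(len(factors)):
            factors[k] = update_factor(k, factors)
        # Step 3: check convergence of the ELBO.
        current_elbo = elbo(factors)
        if abs(current_elbo - prev_elbo) < tol:
            break
        prev_elbo = current_elbo
    return factors
```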

Recipe for variational inference

Given all that we have covered so far, the recipe for variational inference is quite straightforward.

  1. Identify if the model you are dealing with has two groups of variables — latent and observable.
  2. Factorize the joint distribution of the latent variables into easier and well-modeled densities of groups of latent variables, the so-called factors.
  3. Establish the evidence lower bound (ELBO)
  4. Calculate the optimum factors by maximizing the ELBO, using the CAVI procedure.
  5. That's it!

A simple example: Univariate Gaussian with priors over mean and precision

Given a univariate dataset \( \dataset = \set{x_1,\ldots,x_\ndata} \), our goal is to infer the posterior distribution for the mean \( \mu \) and precision \( \tau \) of the univariate Gaussian that generated the dataset.

The likelihood function is

$$ p(\dataset | \mu, \tau) = \left( \frac{\tau}{2\pi} \right)^{\ndata/2} \textexp{-\frac{\tau}{2} \sum_{\ndatasmall=1}^{\ndata} (x_\ndatasmall - \mu)^2} $$

For Bayesian modeling, we will introduce conjugate priors for the mean and precision. For the precision, \( \tau \), we will use the conjugate Gamma prior

$$ p(\tau) = \text{Gamma}(\tau | a_0, b_0) $$

We will model the mean with the conjugate Gaussian prior, conditioned on the precision,

$$ p(\mu | \tau) = \Gauss{\mu | \mu_0, (\lambda_0\tau)^{-1}} $$

Given that our latent variables are \( \mu \) and \( \tau \), we will use the simple factorized distribution as follows

$$ q(\mu,\tau) = q_\mu(\mu) q_\tau(\tau) $$

Now, let's find \( \star{q}_\mu \) using ELBO.

The analysis for finding \( \star{q}_\tau \) is similar and we recommend you try it out on your own.

Example: Finding \( \star{q}_\mu(\mu) \)

We know from the section on optimum factors that

\begin{aligned} \star{q}_\mu(\mu) &\propto \textexp{\expect{\tau \sim q_\tau(\tau)}{\log p(\mu, \tau, \dataset)}} \\\\ &\propto \textexp{\expect{\tau \sim q_\tau(\tau)}{\log p(\dataset | \mu, \tau) + \log p(\mu | \tau) + \text{constant}}} \\\\ &\propto \textexp{-\frac{\expect{\tau \sim q_\tau(\tau)}{\tau}}{2} \left(\lambda_0(\mu - \mu_0)^2 + \sum_{\ndatasmall=1}^{\ndata} (x_\ndatasmall - \mu)^2 \right) + \text{constant}} \end{aligned}

On close inspection of the final equation, if we complete the square over \( \mu \), we find that \( \star{q}_\mu(\mu) \) is a Gaussian
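Concretely, expanding the exponent and collecting the terms that depend on \( \mu \) (absorbing everything else into the constant) gives

\begin{aligned} -\frac{\expect{\tau \sim q_\tau(\tau)}{\tau}}{2} \left(\lambda_0(\mu - \mu_0)^2 + \sum_{\ndatasmall=1}^{\ndata} (x_\ndatasmall - \mu)^2 \right) &= -\frac{\expect{\tau \sim q_\tau(\tau)}{\tau}}{2} \left( (\lambda_0 + \ndata)\mu^2 - 2\mu(\lambda_0\mu_0 + \ndata\bar{x}) \right) + \text{constant} \\\\ &= -\frac{(\lambda_0 + \ndata)\expect{\tau \sim q_\tau(\tau)}{\tau}}{2} \left( \mu - \frac{\lambda_0\mu_0 + \ndata\bar{x}}{\lambda_0 + \ndata} \right)^2 + \text{constant} \end{aligned}

Reading off the mean and precision from this quadratic form, we conclude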

$$ \star{q}_\mu(\mu) = \Gauss{\mu | \mu_\ndata, \lambda_\ndata^{-1}} $$

where,

$$ \mu_\ndata = \frac{\lambda_0\mu_0 + \ndata \bar{x}}{\lambda_0 + \ndata} $$

and

$$ \lambda_\ndata = (\lambda_0 + \ndata) \expect{\tau \sim q_\tau(\tau)}{\tau} $$

Although we did not assume any functional form for \( q_\mu(\mu) \), we discovered it to be a Gaussian that depends on the observations in the available data \( \dataset \) through the sample mean \( \bar{x} = \frac{1}{\ndata} \sum_{\ndatasmall=1}^{\ndata} x_\ndatasmall \).
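Putting everything together, here is a minimal NumPy sketch of CAVI for this model. It assumes the analogous derivation for \( \star{q}_\tau(\tau) \), which yields a Gamma distribution with shape \( a_\ndata = a_0 + (\ndata+1)/2 \) and a rate computed from expectations under \( q_\mu(\mu) \). The function and variable names are illustrative, and, for brevity, convergence is checked on the change in the rate parameter rather than on the ELBO.

```python
import numpy as np

def cavi_univariate_gaussian(x, mu0=0.0, lambda0=1.0, a0=1.0, b0=1.0,
                             max_iter=100, tol=1e-10):
    """CAVI updates for q(mu, tau) = q_mu(mu) q_tau(tau) under the model
    p(tau) = Gamma(a0, b0), p(mu | tau) = N(mu0, 1/(lambda0 * tau)),
    p(x_i | mu, tau) = N(mu, 1/tau).  A sketch under the stated assumptions."""
    x = np.asarray(x, dtype=float)
    n = x.size
    x_bar = x.mean()

    # These two variational parameters are fixed by the data and the prior.
    mu_n = (lambda0 * mu0 + n * x_bar) / (lambda0 + n)   # mean of q_mu
    a_n = a0 + (n + 1) / 2.0                             # shape of q_tau

    b_n = b0                                             # initial rate of q_tau
    for _ in range(max_iter):
        # Update q_mu: its precision lambda_n uses E[tau] = a_n / b_n.
        lambda_n = (lambda0 + n) * a_n / b_n

        # Update q_tau: its rate b_n uses expectations under q_mu.
        e_sq_data = np.sum((x - mu_n) ** 2) + n / lambda_n   # E[sum_i (x_i - mu)^2]
        e_sq_prior = (mu_n - mu0) ** 2 + 1.0 / lambda_n      # E[(mu - mu0)^2]
        b_new = b0 + 0.5 * (e_sq_data + lambda0 * e_sq_prior)

        if abs(b_new - b_n) < tol:
            b_n = b_new
            break
        b_n = b_new

    return mu_n, lambda_n, a_n, b_n


# Example usage on synthetic data from N(2, 0.5^2): E[tau] should be near 4.
rng = np.random.default_rng(0)
data = rng.normal(loc=2.0, scale=0.5, size=500)
mu_n, lambda_n, a_n, b_n = cavi_univariate_gaussian(data)
print(f"E[mu] = {mu_n:.3f}, E[tau] = {a_n / b_n:.3f}")
```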
