Mixture density networks


Introduction

A mixture density network is a deep feedforward network designed to model the conditional probability density for a multimodal regression problem. The underlying multimodal model is a Gaussian mixture model, characterized by the mixing proportions, means, and covariances of its mixture components. The deep feedforward network learns to output these quantities, enabling the calculation of the probability density function.

Prerequisites

To understand mixture density networks, we recommend familiarity with Gaussian mixture models and deep feedforward networks.

Refer to our comprehensive articles on these topics to first get acquainted with the corresponding concepts.

A GMM recap

We have covered Gaussian mixture models (GMM) in detail in our comprehensive article on the topic. A GMM models the data as having been generated from a mixture of components, each being a Gaussian distribution. In this context, the probability of an example \( \vx \) is calculated as

\begin{equation} p(\vx) = \sum_{k=1}^K P(c=k) \Gauss(\vx;\vmu_k,\mSigma_k) \label{eqn:gmm-pdf} \end{equation}

where \( c \in \set{C_1, \ldots, C_K} \) indicates the mixture component that generated the example, \( P(c=k) \) is the mixing proportion of the \( k \)-th component, and each component is a Gaussian, the \( k \)-th having mean \( \vmu_k \) and covariance \( \mSigma_k \).
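To make the formula concrete, here is a minimal sketch in Python (using NumPy and SciPy; the function name gmm_pdf and the toy parameter values below are our own illustrative choices) that evaluates Equation \eqref{eqn:gmm-pdf} at a given point.

```python
import numpy as np
from scipy.stats import multivariate_normal

def gmm_pdf(x, weights, means, covs):
    """Evaluate the GMM density at a point x.

    weights: (K,) mixing proportions P(c=k), summing to 1
    means:   (K, d) component means mu_k
    covs:    (K, d, d) component covariances Sigma_k
    """
    return sum(w * multivariate_normal.pdf(x, mean=m, cov=S)
               for w, m, S in zip(weights, means, covs))

# Toy two-component mixture in one dimension (illustrative numbers only).
weights = np.array([0.3, 0.7])
means = np.array([[-2.0], [1.5]])
covs = np.array([[[0.5]], [[1.0]]])
print(gmm_pdf(np.array([0.0]), weights, means, covs))
```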

Problem setting

In multimodal regression, the goal is to predict real-valued outputs from a conditional distribution that can have several peaks over the target variable for the same input.

Consider observations of the form \( \vx \in \real^\ndim \) — vectors consisting of \( \ndim \) features, \(\vx = [x_1, x_2, \ldots, x_\ndim] \).

For multimodal regression, we wish to predict \( \vy \in \real^\nclass \) from the conditional distribution \( p(\vy | \vx) \) when the same value of \( \vx \) can lead to multiple peaks in the distribution of \( \vy \).
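As a concrete illustration, consider a commonly used synthetic setup (the specific constants below are our own assumptions): data is generated from a single-valued forward mapping, and the inverse problem is treated as the regression task, so that one input value corresponds to several plausible target values.

```python
import numpy as np

# Generate data with a single-valued forward mapping from y to x, then
# treat predicting y from x as the regression problem.
rng = np.random.default_rng(0)
y = rng.uniform(0.0, 1.0, size=2000)
x = y + 0.3 * np.sin(2.0 * np.pi * y) + rng.normal(0.0, 0.05, size=y.shape)

# Predicting y from x is multimodal: for many values of x there are
# several branches of y with high conditional density p(y | x).
```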

Multimodal regression

GMMs are suitable for modeling multimodal data — data with multiple peaks in their probability density. Since the same value of the input can lead to multiple peaks in the distribution \( p(\vy|\vx) \), we can infer a GMM at each point in the input space. To enable this, the \( k \)-th Gaussian component at the input \( \vx \) will have the mean \( \vmu_k(\vx) \) and the covariance \( \mSigma_k(\vx) \).

Thus, the conditional distribution \( p(\vy|\vx) \) using such a GMM is

\begin{equation} p(\vy|\vx) = \sum_{k=1}^K P(c=k | \vx) \Gauss(\vy;\vmu_k(\vx),\mSigma_k(\vx)) \label{eqn:mdn-pdf} \end{equation}

Note that each input example \( \vx \) has its own custom GMM, since the GMM is conditional on the input.

To compute the conditional distribution in Equation \eqref{eqn:mdn-pdf}, we clearly need three quantities — \( P(c=k|\vx),~ \vmu_k(\vx),~ \mSigma_k(\vx) \) for all mixture components \( k = 1, \ldots, K \).

A mixture density network provides these quantities as its outputs, enabling the calculation of the conditional distribution.

Mixture proportions: \( P(c=k|\vx) \)

As outlined in the article on GMM, the mixture proportions form a multinomial distribution over the \( K \) components so that

$$ \sum_{k=1}^K P(c=k|\vx) = 1 $$

In the context of a deep feedforward network, such an output can be obtained by applying a softmax to a \(K\)-dimensional output vector. The softmax ensures that each element is positive and that the elements sum to 1, thereby correctly modeling a multinomial distribution.
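As a minimal sketch, such an output head could look as follows in PyTorch (the sizes K and hidden_dim and the variable names are illustrative assumptions, not part of the article).

```python
import torch
import torch.nn as nn

K, hidden_dim = 5, 64  # number of components and hidden size (assumed values)

pi_head = nn.Linear(hidden_dim, K)

h = torch.randn(1, hidden_dim)           # features from the shared trunk of the network
pi = torch.softmax(pi_head(h), dim=-1)   # (1, K); positive, sums to 1 over the components
```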

Means: \( \vmu_k(\vx) \)

Each component \( k \) has a mean \( \vmu_k (\vx) \) that is a function of the input. The mean has the same dimensionality as the target variable: if \( \vy \in \real^\nclass \), then \( \vmu_k (\vx) \in \real^\nclass \). Collectively over all the components, the network therefore needs \( K \times \nclass \) outputs for the means. These are modeled in a mixture density network as an output matrix.
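A sketch of the corresponding output head, under the same illustrative PyTorch assumptions as above (K, m, and hidden_dim are assumed values; m plays the role of \( \nclass \)).

```python
import torch
import torch.nn as nn

K, m, hidden_dim = 5, 2, 64  # components, target dimensionality, hidden size (assumed)

mu_head = nn.Linear(hidden_dim, K * m)

h = torch.randn(1, hidden_dim)
mu = mu_head(h).reshape(-1, K, m)  # (batch, K, m): one m-dimensional mean per component
```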

Covariances: \( \mSigma_k (\vx) \)

For simplicity, each covariance is assumed to be a diagonal matrix, since full covariances would lead to additional challenges such as the need to ensure positive definiteness. This means the covariance of each component is represented by the \(\nclass\)-dimensional vector of its diagonal entries, the per-dimension variances. Collectively, the network needs another \( K \times \nclass \) outputs for the covariances of the \( K \) components of the mixture.
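The raw outputs of a linear layer are unconstrained, so a transformation is needed to keep the predicted variances positive; applying a softplus (or an exponential) to the outputs is one common choice, though not the only one. A minimal sketch under that assumption, with the same illustrative names as above:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

K, m, hidden_dim = 5, 2, 64  # components, target dimensionality, hidden size (assumed)

sigma_head = nn.Linear(hidden_dim, K * m)

h = torch.randn(1, hidden_dim)
# Softplus maps the unconstrained outputs to positive values so they can
# serve as the diagonal variances of each component; the small constant
# keeps them bounded away from zero.
sigma2 = F.softplus(sigma_head(h)).reshape(-1, K, m) + 1e-6  # (batch, K, m)
```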

Training

Training a mixture density network is similar to training a deep feedforward network. The loss function of choice is the negative log-likelihood under the mixture model. Training then proceeds by way of gradient-based optimization, using backpropagation to compute the gradients.
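A sketch of this loss is given below, assuming the diagonal-covariance parameterization described above and using a log-sum-exp over components for numerical stability (the function name mdn_nll and the tensor shapes are our own conventions).

```python
import math
import torch

def mdn_nll(y, log_pi, mu, sigma2):
    """Negative log-likelihood of targets under the conditional mixture.

    y:       (batch, m) targets
    log_pi:  (batch, K) log mixing proportions, e.g. a log_softmax of the pi head
    mu:      (batch, K, m) component means
    sigma2:  (batch, K, m) positive diagonal variances
    """
    y = y.unsqueeze(1)  # (batch, 1, m) so it broadcasts against the K components
    # Per-component diagonal Gaussian log-density of the target: (batch, K)
    log_gauss = -0.5 * (torch.log(2 * math.pi * sigma2)
                        + (y - mu) ** 2 / sigma2).sum(dim=-1)
    # log p(y|x) = logsumexp_k [ log pi_k + log N(y; mu_k, sigma2_k) ]
    log_prob = torch.logsumexp(log_pi + log_gauss, dim=-1)  # (batch,)
    return -log_prob.mean()
```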

That being said, there are some challenges to deal with, not only because of the number of outputs but also because of their nature. The true mixture component of an example is unknown; it is a latent variable. This means the learning process has to correctly attribute the loss (and the gradient) to the appropriate components.
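The mixture negative log-likelihood handles this attribution implicitly: differentiating the log-sum-exp over components weights each component's gradient by its posterior probability given the example, often called its responsibility. A small sketch, reusing the log_pi and log_gauss quantities defined in the loss sketch above:

```python
import torch

def responsibilities(log_pi, log_gauss):
    """Posterior probability of each component for an example, i.e. how the
    mixture loss distributes credit: the gradient contribution of component k
    is weighted by exactly this quantity.

    log_pi:    (batch, K) log mixing proportions
    log_gauss: (batch, K) per-component Gaussian log-densities of the target
    """
    return torch.softmax(log_pi + log_gauss, dim=-1)  # (batch, K)
```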

Another challenge is the numerical stability of the learning process. The loss function involves division by the variance terms, so if one of the variances becomes very small, the loss and its gradients can blow up. A common way of dealing with this is gradient clipping: gradients that fall outside a pre-specified range, one unlikely to cause problems, are set to the boundary of that range before the parameter update.
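A minimal sketch of clipping by value in PyTorch is shown below (the stand-in model, data, and loss are placeholders for the mixture density network and its negative log-likelihood).

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 3)                       # stand-in for the mixture density network
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

x, target = torch.randn(8, 4), torch.randn(8, 3)
loss = ((model(x) - target) ** 2).mean()      # stand-in for the mixture NLL

optimizer.zero_grad()
loss.backward()
# Clip each gradient element to the range [-1.0, 1.0] before the update,
# so that a tiny variance cannot produce an unbounded step.
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)
optimizer.step()
```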
