Predicting the Unpredictable 🔮. The Magic of Mixture Density Networks… | by Miguel Dias, PhD | May, 2024



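Mixture Density Networks (MDNs) take your boring old neural network and turn it into a prediction powerhouse: instead of a single value, they output a full probability distribution over possible outcomes. Why settle for one prediction when you can have an entire buffet of potential outcomes?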

If life throws complex, unpredictable scenarios your way, MDNs are ready with a probability-laden safety net.

The Core Idea

In an MDN, the probability density of the target variable t given the input x is represented as a linear combination of kernel functions. These are typically Gaussians, though other kernels can be used. With K components, in math speak:

$$p(t \mid x) = \sum_{i=1}^{K} \alpha_i(x)\, \phi_i(t \mid x)$$

Here, 𝛼ᵢ(x) are the mixing coefficients (and who doesn’t love a good mix, am I right? 🎛️). They determine how much weight each component 𝜙ᵢ(t|x), each Gaussian in our case, holds in the model. Because they depend on x, that weighting can change from one input to the next.
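The sketch below makes the mixture p(t|x) = Σᵢ 𝛼ᵢ(x)𝜙ᵢ(t|x) concrete. It is a minimal, dependency-free Python example for a 1-D target at one fixed input x. The names gaussian_pdf and mixture_density, and the parameter values, are hypothetical stand-ins for what the network would output at that x:

```python
import math


def gaussian_pdf(t, mu, sigma):
    """Density of the 1-D Gaussian N(mu, sigma^2) evaluated at t."""
    z = (t - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))


def mixture_density(t, alphas, mus, sigmas):
    """p(t|x) = sum_i alpha_i(x) * phi_i(t|x) for one fixed input x.

    alphas, mus, sigmas are hypothetical stand-ins for the network's outputs at x.
    """
    return sum(a * gaussian_pdf(t, m, s) for a, m, s in zip(alphas, mus, sigmas))


# Two well-separated components with unequal weights: a bimodal p(t|x).
alphas, mus, sigmas = [0.3, 0.7], [-2.0, 2.0], [0.5, 0.5]
for t in (-2.0, 0.0, 2.0):
    print(f"p({t:+.1f} | x) = {mixture_density(t, alphas, mus, sigmas):.4f}")
```

The density peaks near both means and is close to zero in between. A single Gaussian, being unimodal, cannot express that kind of answer.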

Brewing the Gaussians ☕

Each Gaussian component 𝜙ᵢ(t|x) has its own mean 𝜇ᵢ(x) and variance 𝜎ᵢ(x)². Both are outputs of the network, so they change with the input x.
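The component can be written out in full for a T-dimensional target t. This assumes one standard deviation per output dimension, which is the choice the PyTorch code later in this article makes (sigma has shape (N, K, T)). Each component is then a product of 1-D Gaussians. Here t_d is the d-th entry of t, and 𝜇ᵢ,d(x), 𝜎ᵢ,d(x) are component i’s mean and standard deviation in that dimension:

```latex
\phi_i(t \mid x) = \prod_{d=1}^{T} \frac{1}{\sqrt{2\pi}\,\sigma_{i,d}(x)}
  \exp\!\left( -\frac{\left(t_d - \mu_{i,d}(x)\right)^2}{2\,\sigma_{i,d}(x)^2} \right)
```

For T = 1 this is just a single Gaussian with mean 𝜇ᵢ(x) and variance 𝜎ᵢ(x)². A common alternative shares one 𝜎ᵢ(x) across all output dimensions (an isotropic Gaussian).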

Mixing It Up 🎧 with Coefficients

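The mixing coefficients 𝛼ᵢ(x) balance the influence of each Gaussian component. They must be non-negative and sum to 1, so they are computed by passing the corresponding raw network outputs zᵢ^α through a softmax:

$$\alpha_i(x) = \frac{\exp\left(z_i^{\alpha}\right)}{\sum_{j=1}^{K} \exp\left(z_j^{\alpha}\right)}$$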
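The softmax step can be sketched in a few lines of plain Python. This is a minimal example: the function name softmax and the input scores are hypothetical. It subtracts the largest score before exponentiating. That leaves the result mathematically unchanged but keeps exp from overflowing on large raw outputs:

```python
import math


def softmax(z):
    """Map raw scores z_i^alpha to mixing coefficients that are positive and sum to 1."""
    z_max = max(z)  # shifting every score by the same constant does not change the softmax
    exps = [math.exp(v - z_max) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


alphas = softmax([2.0, 0.5, -1.0])
print([round(a, 3) for a in alphas], sum(alphas))

# A naive math.exp(1000.0) raises OverflowError; the shifted version is fine.
print(softmax([1000.0, 999.0]))
```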

Magical Parameters ✨ Means & Variances

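Means 𝜇 and variances 𝜎² define each Gaussian. And guess what? The standard deviations 𝜎, and therefore the variances, have to be positive! We achieve this by taking the exponential of the corresponding network outputs (element-wise, one value per output dimension), while the means are used as they are:

$$\sigma_i(x) = \exp\left(z_i^{\sigma}\right), \qquad \mu_i(x) = z_i^{\mu}$$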
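Exponentiation guarantees 𝜎 > 0 for any real raw output z. A side benefit of 𝜎 = exp(z) is that ln 𝜎 = z. The Gaussian log-density can therefore be written directly in terms of z, without taking the log of a possibly tiny 𝜎. This is a minimal pure-Python sketch of both points for a 1-D target; the function names are hypothetical:

```python
import math


def sigma_from_raw(z_sigma):
    """sigma = exp(z): strictly positive for any real network output z."""
    return math.exp(z_sigma)


def gaussian_log_pdf_from_raw(t, mu, z_sigma):
    """log N(t | mu, sigma^2) with sigma = exp(z_sigma).

    Uses log(sigma) = z_sigma directly in the normalising term.
    """
    sigma = math.exp(z_sigma)
    return -0.5 * math.log(2.0 * math.pi) - z_sigma - 0.5 * ((t - mu) / sigma) ** 2


for z in (-3.0, 0.0, 3.0):
    print(f"z = {z:+.1f} -> sigma = {sigma_from_raw(z):.4f}")
```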

Alright, so how do we train this beast? Well, it’s all about maximizing the likelihood of our observed data. Fancy terms, I know. Let’s see it in action.

The Log-Likelihood Spell ✨

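The likelihood of our data under the MDN model is the product of the probabilities assigned to each data point, assuming the points are independent. For N training pairs (xₙ, tₙ), in math speak:

$$\mathcal{L} = \prod_{n=1}^{N} p(t_n \mid x_n)$$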

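This basically says, “Hey, what’s the chance we got this data given our model?” But products of many small probabilities quickly underflow to zero on a computer, so we take the log (because math loves logs), which turns our product into a sum. Since the log is strictly increasing, the parameters that maximize the log-likelihood also maximize the likelihood itself:

$$\ln \mathcal{L} = \sum_{n=1}^{N} \ln p(t_n \mid x_n)$$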
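The underflow problem is easy to see concretely. Here is a quick pure-Python check using 1,000 hypothetical per-point likelihoods of 1e-5 each. The raw product is far below the smallest positive float, while the sum of logs stays in range:

```python
import math

# 1,000 hypothetical per-point likelihoods, each fairly small.
probs = [1e-5] * 1000

product = 1.0
for p in probs:
    product *= p  # the true value, 1e-5000, is not representable and rounds to 0.0

log_sum = sum(math.log(p) for p in probs)  # about -11513: no underflow

print(product)
print(log_sum)
```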

Now, here’s the kicker: we actually want to minimize the negative log-likelihood, because our optimization algorithms like to minimize things. So, plugging in the definition of p(t|x), the error function we actually minimize is:

$$E = -\sum_{n=1}^{N} \ln \left( \sum_{i=1}^{K} \alpha_i(x_n)\, \phi_i(t_n \mid x_n) \right)$$

This formula might look intimidating, but it’s just saying we sum up the log probabilities across all data points, then throw in a negative sign because minimization is our jam.
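For 1-D targets, the error function reads E = −Σₙ ln Σᵢ 𝛼ᵢ(xₙ)𝜙ᵢ(tₙ|xₙ). Below is a direct, dependency-free transcription of it. This is a sketch only: mdn_nll and the parameter tuples are hypothetical, with each tuple standing in for the network’s outputs at one input xₙ. Because it exponentiates before taking the log, this direct form can fail when every component assigns a vanishingly small density to a target. The PyTorch loss below avoids that with torch.logsumexp.

```python
import math


def gaussian_pdf(t, mu, sigma):
    """Density of the 1-D Gaussian N(mu, sigma^2) evaluated at t."""
    z = (t - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))


def mdn_nll(targets, params):
    """E = -sum_n ln( sum_i alpha_i(x_n) * phi_i(t_n | x_n) ), 1-D targets.

    params[n] = (alphas, mus, sigmas): hypothetical network outputs for input x_n.
    """
    total = 0.0
    for t, (alphas, mus, sigmas) in zip(targets, params):
        p = sum(a * gaussian_pdf(t, m, s) for a, m, s in zip(alphas, mus, sigmas))
        total -= math.log(p)  # one data point's contribution to E
    return total


targets = [0.0, 1.5]
params = [
    ([1.0], [0.0], [1.0]),                  # a single standard normal
    ([0.5, 0.5], [-1.0, 1.0], [0.5, 0.5]),  # a bimodal mixture
]
print(mdn_nll(targets, params))
```

The PyTorch loss further down returns loss.mean(), the batch average of these per-point terms, rather than their sum. That rescales E but does not move its minimum.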

Now here’s how to translate our wizardry into Python with PyTorch (you can find the full code here).

The Loss Function

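import torch

def mdn_loss(alpha, sigma, mu, target, eps=1e-8):
    # Shapes: alpha (N, K); sigma and mu (N, K, T); target (N, T)
    target = target.unsqueeze(1).expand_as(mu)  # (N, K, T): same target for every component
    m = torch.distributions.Normal(loc=mu, scale=sigma)  # batch of 1-D Gaussians
    log_prob = m.log_prob(target)  # (N, K, T): per-dimension log-density
    log_prob = log_prob.sum(dim=2)  # (N, K): log phi_i(t|x), dimensions treated as independent
    log_alpha = torch.log(alpha + eps)  # Avoid log(0) disaster
    loss = -torch.logsumexp(log_alpha + log_prob, dim=1)  # (N,): -ln sum_i alpha_i * phi_i
    return loss.mean()  # average over the batch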

Here’s the breakdown:

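  1. target = target.unsqueeze(1).expand_as(mu): Insert a mixture axis so the target goes from shape (N, T) to (N, K, T), matching mu, so every component is scored against the same target.
  2. m = torch.distributions.Normal(loc=mu, scale=sigma): Create a batch of normal distributions, one per component and output dimension.
  3. log_prob = m.log_prob(target): Evaluate the log-density of the target under each of them, giving shape (N, K, T).
  4. log_prob = log_prob.sum(dim=2): Sum over the output dimensions, which treats them as independent given the component; the result is ln 𝜙ᵢ(t|x) with shape (N, K).
  5. log_alpha = torch.log(alpha + eps): Take the log of the mixing coefficients; eps keeps log(0) from producing -inf.
  6. loss = -torch.logsumexp(log_alpha + log_prob, dim=1): Compute ln Σᵢ 𝛼ᵢ𝜙ᵢ stably via log-sum-exp over the components, then negate it to get each point’s negative log-likelihood.
  7. return loss.mean(): Average over the batch, which rescales the loss without changing where its minimum is.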
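Step 6 relies on the identity ln Σᵢ 𝛼ᵢ𝜙ᵢ = logsumexpᵢ(ln 𝛼ᵢ + ln 𝜙ᵢ). The dependency-free sketch below shows why the log-sum-exp route matters. The function names are hypothetical; torch.logsumexp performs the same max-shift internally. When a target sits far from every component, exponentiating the log-densities directly underflows to 0.0 and the log fails. Shifting by the largest term first does not:

```python
import math


def log_mixture_naive(log_alpha, log_phi):
    """ln sum_i exp(log_alpha_i + log_phi_i), exponentiating directly."""
    return math.log(sum(math.exp(a + p) for a, p in zip(log_alpha, log_phi)))


def log_mixture_stable(log_alpha, log_phi):
    """Same quantity via log-sum-exp: shift by the largest term before exponentiating."""
    terms = [a + p for a, p in zip(log_alpha, log_phi)]
    m = max(terms)
    return m + math.log(sum(math.exp(v - m) for v in terms))


log_alpha = [math.log(0.5), math.log(0.5)]

# Moderate log-densities: both versions agree.
print(log_mixture_naive(log_alpha, [-2.0, -3.0]), log_mixture_stable(log_alpha, [-2.0, -3.0]))

# A target far from every component: exp(-1000) underflows to 0.0 and log(0.0) fails.
try:
    log_mixture_naive(log_alpha, [-1000.0, -1001.0])
except ValueError:
    print("naive version failed: log(0)")
print(log_mixture_stable(log_alpha, [-1000.0, -1001.0]))
```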

The Neural Network

Let’s create a neural network that’s all set to handle the wizardry:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MDN(nn.Module):
    def __init__(self, input_dim, output_dim, num_hidden, num_mixtures):
        super(MDN, self).__init__()
        # Shared trunk: two tanh hidden layers, each num_hidden units wide
        self.hidden = nn.Sequential(
            nn.Linear(input_dim, num_hidden),
            nn.Tanh(),
            nn.Linear(num_hidden, num_hidden),
            nn.Tanh(),
        )
        # Three heads: K mixing logits, K*T raw scales, K*T means
        self.z_alpha = nn.Linear(num_hidden, num_mixtures)
        self.z_sigma = nn.Linear(num_hidden, num_mixtures * output_dim)
        self.z_mu = nn.Linear(num_hidden, num_mixtures * output_dim)
        self.num_mixtures = num_mixtures
        self.output_dim = output_dim

    def forward(self, x):
        hidden = self.hidden(x)
        alpha = F.softmax(self.z_alpha(hidden), dim=-1)  # (N, K), each row sums to 1
        sigma = torch.exp(self.z_sigma(hidden)).view(-1, self.num_mixtures, self.output_dim)  # (N, K, T), positive
        mu = self.z_mu(hidden).view(-1, self.num_mixtures, self.output_dim)  # (N, K, T), unconstrained
        return alpha, sigma, mu

Notice the softmax applied to 𝛼 (alpha = F.softmax(self.z_alpha(hidden), dim=-1)), so the mixing coefficients sum to 1. Likewise, the exponential applied to 𝜎 (sigma = torch.exp(self.z_sigma(hidden)).view(-1, self.num_mixtures, self.output_dim)) keeps the standard deviations positive, as explained earlier. The means 𝜇 come straight out of the linear layer z_mu with no constraint.
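Per input, z_sigma and z_mu each emit num_mixtures * output_dim numbers, and .view(-1, self.num_mixtures, self.output_dim) regroups them. The output of nn.Linear is contiguous, so view reads it in row-major order. The first output_dim numbers therefore belong to component 0, the next output_dim to component 1, and so on. The pure-Python sketch below mimics that grouping for a single input and counts how many numbers the three heads emit in total. The helper names are hypothetical:

```python
def view_as_components(flat, num_mixtures, output_dim):
    """Mimic .view(num_mixtures, output_dim) on one row of head outputs (row-major)."""
    assert len(flat) == num_mixtures * output_dim
    return [flat[k * output_dim:(k + 1) * output_dim] for k in range(num_mixtures)]


def mdn_output_count(num_mixtures, output_dim):
    """Numbers emitted per input: K alphas + K*T raw sigmas + K*T means."""
    return num_mixtures + 2 * num_mixtures * output_dim


# K = 3 components, T = 2 output dimensions: z_mu emits 6 numbers per input.
z_mu_row = [0.1, 0.2, 1.1, 1.2, 2.1, 2.2]
print(view_as_components(z_mu_row, 3, 2))  # [[0.1, 0.2], [1.1, 1.2], [2.1, 2.2]]
print(mdn_output_count(3, 2))  # 15
```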

The Prediction

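Getting predictions from MDNs is a bit of a trick, because the network returns a whole distribution rather than a single number. One way is to sample from the mixture: first pick a component k with probability 𝛼ₖ, then draw from that component’s Gaussian. Here’s how: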

import itertools
import torch

@torch.no_grad()  # sampling only: no need to build an autograd graph
def get_sample_preds(alpha, sigma, mu, samples=10):
    N, K, T = mu.shape
    sampled_preds = torch.zeros(N, samples, T)
    uniform_samples = torch.rand(N, samples)
    cum_alpha = alpha.cumsum(dim=1)
    for i, j in itertools.product(range(N), range(samples)):
        u = uniform_samples[i, j]
        k = torch.searchsorted(cum_alpha[i], u).item()
        k = min(k, K - 1)  # rounding can leave cum_alpha[i, -1] just below u, giving k == K
        sampled_preds[i, j] = torch.normal(mu[i, k], sigma[i, k])
    return sampled_preds

Here’s the breakdown:

  1. N, K, T = mu.shape: Get the number of data points, mixture components, and output dimensions.
  2. sampled_preds = torch.zeros(N, samples, T): Initialize the tensor to store sampled predictions.
  3. uniform_samples = torch.rand(N, samples): Generate uniform random numbers for sampling.
  4. cum_alpha = alpha.cumsum(dim=1): Compute the cumulative sum of mixture weights.
  5. for i, j in itertools.product(range(N), range(samples)): Loop over each combination of data points and samples.
  6. u = uniform_samples[i, j]: Get a random number for the current sample.
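  7. k = torch.searchsorted(cum_alpha[i], u).item(): Find the first component whose cumulative weight is at least u. Each component k covers an interval of length 𝛼ₖ in [0, 1), so it is picked with probability 𝛼ₖ (inverse-CDF sampling).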
  8. sampled_preds[i, j] = torch.normal(mu[i, k], sigma[i, k]): Sample from the selected Gaussian component.
  9. return sampled_preds: Return the tensor of sampled predictions.
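The same two-stage sampling can be sketched without PyTorch for a 1-D target. This is a minimal illustration, and sample_mixture and the mixture parameters are hypothetical. bisect.bisect_left plays the role of torch.searchsorted with its default right=False, and the clamp guards against cumulative weights that round to just under 1:

```python
import bisect
import itertools
import random


def sample_mixture(alphas, mus, sigmas, n_samples, rng=random):
    """Draw n_samples from a 1-D Gaussian mixture: (1) pick component k with
    probability alpha_k via the inverse CDF, (2) sample from N(mu_k, sigma_k^2)."""
    cum_alpha = list(itertools.accumulate(alphas))
    K = len(alphas)
    out = []
    for _ in range(n_samples):
        u = rng.random()
        k = min(bisect.bisect_left(cum_alpha, u), K - 1)  # clamp guards rounding in cum_alpha[-1]
        out.append(rng.gauss(mus[k], sigmas[k]))
    return out


rng = random.Random(0)
draws = sample_mixture([0.3, 0.7], [-5.0, 5.0], [0.5, 0.5], 10_000, rng)
frac_left = sum(d < 0 for d in draws) / len(draws)
print(f"fraction near -5: {frac_left:.3f}")  # should be close to alpha_0 = 0.3
print(f"sample mean: {sum(draws) / len(draws):.3f}")  # should be close to 0.3*(-5) + 0.7*5 = 2.0
```

The sample mean approaches the mixture mean Σᵢ 𝛼ᵢ𝜇ᵢ, which is 2.0 here. Yet this mixture has almost no probability mass near 2.0. That is one reason to look at samples, or the whole density, rather than a single averaged prediction.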

Let’s apply MDNs to predict ‘Apparent Temperature’ using a simple Weather Dataset. I trained an MDN with 50 hidden units per layer, and guess what? It rocks! 🎸

Find the full code here. Here are some results:
