(marginalizing-models)=

# Automatic marginalization of discrete variables

:::{post} Jan 20, 2024
:tags: mixture model
:category: intermediate, how-to
:author: Rob Zinkov
:::

PyMC is very amenable to sampling models with discrete latent variables. But if you insist on using the NUTS sampler exclusively, you will need to get rid of your discrete variables somehow. The best way to do this is by marginalizing them out, as you then benefit from the Rao-Blackwell theorem and get lower-variance estimates of your parameters.

Formally, the argument goes like this: samplers can be understood as approximating the expectation $\mathbb{E}_{p(x, z)}[f(x, z)]$ for some function $f$ with respect to a distribution $p(x, z)$. By the law of total expectation we know that

$$\mathbb{E}_{p(x, z)}[f(x, z)] = \mathbb{E}_{p(z)}\left[\mathbb{E}_{p(x \mid z)}\left[f(x, z)\right]\right]$$

Letting $g(z) = \mathbb{E}_{p(x \mid z)}\left[f(x, z)\right]$, we know by the law of total variance that

$$\mathbb{V}_{p(x, z)}[f(x, z)] = \mathbb{V}_{p(z)}[g(z)] + \mathbb{E}_{p(z)}\left[\mathbb{V}_{p(x \mid z)}\left[f(x, z)\right]\right]$$

Because the second term is an expectation of a variance, it is always non-negative, and thus we know

$$\mathbb{V}_{p(x, z)}[f(x, z)] \geq \mathbb{V}_{p(z)}[g(z)]$$

Intuitively, marginalizing variables in your model lets you use $g$ instead of $f$. This lower variance manifests most directly in lower Monte-Carlo standard error (mcse), and indirectly in a generally higher effective sample size (ESS).

Unfortunately, the computation to do this is often tedious and unintuitive. Luckily, pymc-extras now supports a way to do this work automatically!

:::{include} ../extra_installs.md
:::

As a motivating example, consider a Gaussian mixture model.

## Gaussian mixture model

There are two ways to specify the same model. In the first, the choice of mixture component is explicit.
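
A minimal sketch of this explicit formulation is shown below; the component means, mixture weight, and random seed are illustrative choices rather than anything prescribed by the method.

```python
import arviz as az
import pymc as pm
import pymc_extras as pmx
import pytensor.tensor as pt

# Fixed component means (illustrative values)
component_means = pt.as_tensor([-2.0, 2.0])

with pmx.MarginalModel() as explicit_mixture:
    # Discrete assignment: which component each draw of y comes from
    idx = pm.Bernoulli("idx", p=0.7)
    y = pm.Normal("y", mu=component_means[idx], sigma=1.0)

    # With a discrete variable present, PyMC assigns a compound step
    # (NUTS for y, a Metropolis-type step for idx)
    idata_explicit = pm.sample(random_seed=32)

az.summary(idata_explicit)
```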

The other way is to use the built-in {class}`NormalMixture <pymc.NormalMixture>` distribution. Here the mixture assignment is not an explicit variable in our model. There is nothing unique about the first model other than that we initialized it with {class}`pmx.MarginalModel <pymc_extras.MarginalModel>` instead of {class}`pm.Model <pymc.model.core.Model>`; this different class is what will allow us to marginalize out variables later.
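
A corresponding sketch using the pre-built distribution, with weights matching the Bernoulli probability assumed above:

```python
with pm.Model() as prebuilt_mixture:
    # Mixture assignment is implicit; only the marginal distribution of y remains
    y = pm.NormalMixture("y", w=[0.3, 0.7], mu=[-2.0, 2.0], sigma=1.0)
    idata_prebuilt = pm.sample(random_seed=32)

az.summary(idata_prebuilt)
```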

We can immediately see that the marginalized model has a higher ESS. Let's now marginalize out the choice in the explicit model and see what changes.
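
Marginalization is a single call on the model object; roughly as in the following sketch:

```python
# Remove idx as a free variable by summing it out of the model's logp
explicit_mixture.marginalize(["idx"])

# Resample: with no discrete variables left, NUTS is used on its own
with explicit_mixture:
    idata_marginal = pm.sample(random_seed=32)

az.summary(idata_marginal)
```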

As we can see, the `idx` variable is gone now. We were also able to use the NUTS sampler exclusively, and the ESS has improved.

But {class}`MarginalModel <pymc_extras.MarginalModel>` has a distinct advantage: it still knows about the discrete variables that were marginalized out, and we can obtain estimates of the posterior of `idx` given the other variables. We do this using the {meth}`recover_marginals <pymc_extras.MarginalModel.recover_marginals>` method.
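
A sketch of that call; it is assumed here that `recover_marginals` extends the existing trace in place with `idx` draws and their log-probabilities `lp_idx`:

```python
# Draw the marginalized idx conditional on the posterior samples
# (assumed to add "idx" and "lp_idx" to idata_marginal in place)
explicit_mixture.recover_marginals(idata_marginal)

az.summary(idata_marginal, var_names=["idx", "lp_idx"])
```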

This `idx` variable lets us recover the mixture assignment after running the NUTS sampler! We can split out the samples of `y` by reading off the mixture label from the associated `idx` for each sample.
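
For example, a hypothetical plotting snippet along these lines splits the `y` draws by their recovered label:

```python
import matplotlib.pyplot as plt

# Flatten chains and draws into a single sample dimension
post = az.extract(idata_marginal)
y_draws = post["y"].values
idx_draws = post["idx"].values

# Histogram the y draws separately for each recovered mixture label
plt.hist(y_draws[idx_draws == 0], bins=30, alpha=0.6, label="component 0")
plt.hist(y_draws[idx_draws == 1], bins=30, alpha=0.6, label="component 1")
plt.xlabel("y")
plt.legend()
```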

One important thing to notice is that this discrete variable has a lower ESS, particularly in the tails, which means `idx` might not be estimated well there. If this is important, I recommend using `lp_idx` instead, which is the log-probability of `idx` given the sampled values of the other variables at each iteration. The benefits of working with `lp_idx` will be explored further in the next example.

## Coal mining model

The same methods work for the {ref}`Coal mining <pymc:pymc_overview#case-study-2-coal-mining-disasters>` switchpoint model as well. The coal mining dataset records the number of coal mining disasters in the UK between 1851 and 1962. The time series spans a period when mining safety regulations were being introduced, and we try to estimate when that change occurred using a discrete switchpoint variable.
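
A sketch of the model, reusing the imports from the mixture example; the disaster counts below are simulated placeholders standing in for the real 1851-1962 series:

```python
import numpy as np

years = np.arange(1851, 1963)

# Placeholder data: simulate a drop in the disaster rate around 1890
# (the original example uses the recorded UK disaster counts instead)
rng = np.random.default_rng(32)
disasters = rng.poisson(np.where(years < 1890, 3.0, 1.0))

with pmx.MarginalModel() as disaster_model:
    switchpoint = pm.DiscreteUniform("switchpoint", lower=years.min(), upper=years.max())
    early_rate = pm.Exponential("early_rate", 1.0)
    late_rate = pm.Exponential("late_rate", 1.0)
    # Disaster rate before vs. after the (unknown) switchpoint year
    rate = pm.math.switch(switchpoint >= years, early_rate, late_rate)
    pm.Poisson("disasters", rate, observed=disasters)
```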

We will sample the model both before and after we marginalize out the switchpoint variable, as sketched below.
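
```python
# Sample with the discrete switchpoint still in the model (compound step)
with disaster_model:
    before_marg = pm.sample(random_seed=32)

# Marginalize the switchpoint out and sample again with NUTS only
disaster_model.marginalize(["switchpoint"])
with disaster_model:
    after_marg = pm.sample(random_seed=32)

az.summary(before_marg, var_names=["early_rate", "late_rate"])
az.summary(after_marg, var_names=["early_rate", "late_rate"])
```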

As before, the ESS improved massively.

Finally, let us recover the switchpoint variable.
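
Again via `recover_marginals`, with the same in-place assumption as before:

```python
# Assumed to add "switchpoint" draws and "lp_switchpoint" to after_marg
disaster_model.recover_marginals(after_marg)

az.summary(after_marg, var_names=["switchpoint"])
```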

While `recover_marginals` is able to sample the discrete variables that were marginalized out, the probabilities associated with each draw often offer a cleaner estimate of the discrete variable, particularly for low-probability values. This is best illustrated by comparing a histogram of the sampled values with a plot of the log-probabilities.
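
A sketch of such a comparison; the `lp_switchpoint` name and its extra dimension over the support of the switchpoint are assumptions modeled on the `lp_idx` naming above:

```python
switch_draws = after_marg.posterior["switchpoint"].values.ravel()

# Average the per-draw probabilities over chains and draws to get a posterior pmf
lp_switch = after_marg.posterior["lp_switchpoint"]  # assumed name and layout
probs = np.exp(lp_switch).mean(dim=("chain", "draw"))

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 3))
ax1.hist(switch_draws, bins=np.arange(years.min(), years.max() + 2) - 0.5)
ax1.set_title("Sampled switchpoint values")
ax2.plot(years, probs)
ax2.set_title("Posterior probabilities from lp_switchpoint")
```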

By plotting a histogram of the sampled values instead of working with the log-probabilities directly, we are left with a noisier and more incomplete picture of the underlying discrete distribution.

## Authors

## References

:::{bibliography}
:filter: docname in docnames
:::

## Watermark

:::{include} ../page_footer.md
:::