(dirichlet_mixture_of_multinomials)=
:::{post} Jan 8, 2022
:tags: mixture model,
:category: advanced
:author: Byron J. Smith, Abhipsha Das, Oriol Abril-Pla
:::
This example notebook demonstrates the use of a Dirichlet mixture of multinomials (a.k.a. Dirichlet-multinomial or DM) to model categorical count data. Models like this one are important in a variety of areas, including natural language processing, ecology, bioinformatics, and more.
The Dirichlet-multinomial can be understood as draws from a Multinomial distribution where each sample has a slightly different probability vector, which is itself drawn from a common Dirichlet distribution. This contrasts with the Multinomial distribution, which assumes that all observations arise from a single fixed probability vector. The extra sample-to-sample variation in the probability vector enables the Dirichlet-multinomial to accommodate more variable (a.k.a. over-dispersed) count data than the Multinomial.
Other examples of over-dispersed count distributions are the Beta-binomial (which can be thought of as a special case of the DM) or the Negative binomial distributions.
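To make the Beta-binomial connection concrete, here is a minimal sketch (standard library only) that writes out the DM log-pmf with `lgamma` and checks that, with two categories, it matches the Beta-binomial log-pmf. The function names and the values of `n`, `a` and `b` are illustrative assumptions, not part of the original notebook.

```python
from math import comb, exp, lgamma, log


def dm_logpmf(counts, alpha):
    """Log-pmf of the Dirichlet-multinomial for a single count vector."""
    n = sum(counts)
    a0 = sum(alpha)
    # log of the multinomial coefficient n! / prod(x_j!)
    log_coef = lgamma(n + 1) - sum(lgamma(x + 1) for x in counts)
    # log of Gamma(a0) / Gamma(n + a0) * prod Gamma(x_j + a_j) / Gamma(a_j)
    log_ratio = lgamma(a0) - lgamma(n + a0) + sum(
        lgamma(x + a) - lgamma(a) for x, a in zip(counts, alpha)
    )
    return log_coef + log_ratio


def log_beta(u, v):
    """Log of the Beta function B(u, v)."""
    return lgamma(u) + lgamma(v) - lgamma(u + v)


def betabinom_logpmf(x, n, a, b):
    """Log-pmf of the Beta-binomial, written out independently of the DM."""
    return log(comb(n, x)) + log_beta(x + a, n - x + b) - log_beta(a, b)


n, a, b = 12, 2.5, 0.7  # arbitrary illustrative values
max_diff = max(
    abs(dm_logpmf([x, n - x], [a, b]) - betabinom_logpmf(x, n, a, b))
    for x in range(n + 1)
)
# The two-category DM probabilities should also sum to one.
total = sum(exp(dm_logpmf([x, n - x], [a, b])) for x in range(n + 1))
print(f"largest log-pmf difference: {max_diff:.2e}, total probability: {total:.6f}")
```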
The DM is also an example of marginalizing a mixture distribution over its latent parameters. This notebook will demonstrate the performance benefits that come from taking that approach.
Let us simulate some over-dispersed, categorical count data for this example.
Here we are simulating from the DM distribution itself, so it is perhaps tautological to fit that model, but rest assured that data like these really do appear in real-world count data {cite:p}`madsen2005modelingdirichlet,nowicka2016drimseq,goodhardt1984thedirichlet`.
Here we will discuss a community ecology example, pretending that we have observed counts of $k=5$ different tree species in $n=10$ different forests.
Our simulation will produce a two-dimensional matrix of integers (counts) where each row, (zero-)indexed by $i \in \{0, \ldots, n-1\}$, is an observation (different forest), and each column, indexed by $j \in \{0, \ldots, k-1\}$, is a category (tree species). We'll parameterize this distribution with three things:
Here, and throughout this notebook, we've used a convenient reparameterization of the Dirichlet distribution from one to two parameters, $\alpha=\mathrm{conc} \times \mathrm{frac}$, as this fits our desired interpretation.
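As a rough illustration of why this two-parameter form is convenient: when $\mathrm{frac}$ sums to one, a Dirichlet with $\alpha=\mathrm{conc} \times \mathrm{frac}$ has mean $\mathrm{frac}$ and per-component variance $\mathrm{frac}_j (1 - \mathrm{frac}_j) / (\mathrm{conc} + 1)$, so $\mathrm{conc}$ only controls the spread. The sketch below checks this by simulation; the values of `frac` and `conc` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
frac = np.array([0.45, 0.30, 0.15, 0.09, 0.01])  # hypothetical; must sum to 1

results = {}
for conc in (2.0, 20.0, 200.0):  # hypothetical concentrations
    alpha = conc * frac
    draws = rng.dirichlet(alpha, size=100_000)
    results[conc] = {
        "mean": draws.mean(axis=0),
        "sd": draws.std(axis=0),
        # Analytic standard deviation of each Dirichlet component
        "analytic_sd": np.sqrt(frac * (1 - frac) / (conc + 1)),
    }
    print(f"conc={conc:>5}: mean={results[conc]['mean'].round(3)}, sd={results[conc]['sd'].round(3)}")
```

The mean stays at `frac` for every `conc`, while the standard deviation shrinks as `conc` grows.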
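Each observation from the DM is simulated by first drawing a probability vector $p_i \sim \mathrm{Dirichlet}(\alpha=\mathrm{conc} \times \mathrm{frac})$, and then drawing the counts for that observation from a Multinomial distribution with probabilities $p_i$.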
Notice that each observation gets its own latent parameter $p_i$, simulated independently from a common Dirichlet distribution.
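The two-step simulation described above might be sketched with NumPy as follows. All numeric values here (`true_frac`, `true_conc`, `total_count`) and some of the species labels are assumptions chosen for illustration; only the shapes ($n=10$ forests, $k=5$ species) come from the text.

```python
import numpy as np

rng = np.random.default_rng(42)

# Species labels: partly hypothetical stand-ins.
trees = ["pine", "oak", "ebony", "rosewood", "mahogany"]
n, k = 10, len(trees)

true_frac = np.array([0.45, 0.30, 0.15, 0.09, 0.01])  # hypothetical expected fractions
true_conc = 6.0  # hypothetical concentration
total_count = 50  # hypothetical number of trees counted in each forest

# Step 1: one latent probability vector p_i per forest, shape (n, k).
true_p = rng.dirichlet(true_conc * true_frac, size=n)

# Step 2: counts for each forest are drawn from a multinomial with its own p_i.
observed_counts = np.vstack([rng.multinomial(total_count, p_i) for p_i in true_p])

print(observed_counts)
```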
The first model that we will fit to these data is a plain multinomial model, where the only parameter is the expected fraction of each category, $\mathrm{frac}$, which we will give a Dirichlet prior. While the uniform prior ($\alpha_j=1$ for each $j$) works well, if we have independent beliefs about the fraction of each tree, we could encode this into our prior, e.g. increasing the value of $\alpha_j$ where we expect a higher fraction of species-$j$.
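As an aside, because the Dirichlet prior is conjugate to the multinomial likelihood, the posterior of this plain multinomial model is itself a Dirichlet whose parameters are the prior $\alpha$ plus the per-species counts summed over all forests. The sketch below uses that closed form instead of MCMC; the stand-in data are simulated with hypothetical values, and the interval shown is an equal-tailed 94% interval rather than an HDI.

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical stand-in data: 10 forests, 5 species, 50 trees per forest.
true_frac = np.array([0.45, 0.30, 0.15, 0.09, 0.01])
latent_p = rng.dirichlet(6.0 * true_frac, size=10)
counts = np.vstack([rng.multinomial(50, p_i) for p_i in latent_p])

prior_alpha = np.ones(5)  # uniform Dirichlet prior, alpha_j = 1
# Conjugate update: the plain multinomial model pools every forest's counts.
post_alpha = prior_alpha + counts.sum(axis=0)

draws = rng.dirichlet(post_alpha, size=20_000)
lower, upper = np.quantile(draws, [0.03, 0.97], axis=0)  # equal-tailed, not an HDI
for j in range(5):
    print(f"species {j}: true={true_frac[j]:.2f}, 94% interval=({lower[j]:.3f}, {upper[j]:.3f})")
```

Since the model treats all 500 trees as draws from one probability vector, these intervals come out narrow regardless of how much the forests differ from each other.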
The trace plots look fairly good; visually, each parameter appears to be moving around the posterior well.
Likewise, diagnostics in the parameter summary table all look fine. Here we've added a column estimating the effective sample size per second of sampling.
Here we've drawn a forest-plot, showing the mean and 94% HDIs from our posterior approximation. Interestingly, because we know what the underlying frequencies are for each species (dashed lines), we can comment on the accuracy of our inferences. And now the issues with our model become apparent; notice that the 94% HDIs don't include the true values for tree species 0, 1, 3. We might have seen one HDI miss, but three???
...what's going on?
Let's troubleshoot this model using a posterior-predictive check, comparing our data to simulated data conditioned on our posterior estimates.
Here we're plotting histograms of the predicted counts against the observed counts for each species.
(Notice that the y-axis isn't full height and clips the distributions for species mahogany in purple.)
And now we can start to see why our posterior HDI deviates from the true parameters for three of five species (vertical lines).
See that for all of the species the observed counts are frequently quite far from the predictions
conditioned on the posterior distribution.
This is particularly obvious for (e.g.) oak where we have one observation of more than 30
trees of this species, despite the posterior predictive mass being concentrated far below that.
This is overdispersion at work, and a clear sign that we need to adjust our model to accommodate it.
Posterior predictive checks are one of the best ways to diagnose model misspecification, and this example is no different.
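One way to put a number on overdispersion is to compare the variance of the counts with the variance a multinomial would imply, $N \, \mathrm{frac}_j (1 - \mathrm{frac}_j)$ for $N$ trees per forest. For DM counts this variance is inflated by the factor $(N + \mathrm{conc}) / (1 + \mathrm{conc})$. The following sketch checks both statements by simulation; the values of `frac`, `conc` and `total_count` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(7)

frac = np.array([0.45, 0.30, 0.15, 0.09, 0.01])  # hypothetical
conc = 6.0  # hypothetical
total_count = 50  # hypothetical
n_sim = 20_000

# Counts from a plain multinomial with a single fixed probability vector.
multinomial_counts = rng.multinomial(total_count, frac, size=n_sim)

# Counts from the DM: a fresh probability vector for every simulated forest.
latent_p = rng.dirichlet(conc * frac, size=n_sim)
dm_counts = np.vstack([rng.multinomial(total_count, p_i) for p_i in latent_p])

multinomial_var = total_count * frac * (1 - frac)
multinomial_ratio = multinomial_counts.var(axis=0) / multinomial_var
dm_ratio = dm_counts.var(axis=0) / multinomial_var
expected_inflation = (total_count + conc) / (1 + conc)

print("multinomial variance ratio:", multinomial_ratio.round(2))
print("DM variance ratio:         ", dm_ratio.round(2))
print("expected DM inflation:     ", expected_inflation)
```

A ratio near one is what the plain multinomial can explain; ratios far above one are the signature of overdispersion that the posterior predictive check reveals visually.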
Let's go ahead and model our data using the DM distribution.
For this model we'll keep the same prior on the expected frequencies of each species, $\mathrm{frac}$. We'll also add a strictly positive parameter, $\mathrm{conc}$, for the concentration.
In this iteration of our model we'll explicitly include the latent multinomial probability, $p_i$, modeling the $\mathrm{true\_p}_i$ from our simulations (which we would not observe in the real world).
Compare this diagram to the first. Here the latent, Dirichlet distributed $p$ separates the multinomial from the expected frequencies, $\mathrm{frac}$, accounting for overdispersion of counts relative to the simple multinomial model.
Here we had to increase `target_accept` from 0.8 to 0.9 to not get drowned in divergences.
We also got a warning about the `rhat` statistic, although we'll ignore it for now.
More interesting is how much longer it took to sample this model than the first.
This is partly because our model has an additional ~$(n \times k)$ parameters,
but it seems like there are other geometric challenges for NUTS as well.
We'll see if we can fix these in the next model, but for now let's take a look at the traces.
The divergences seem to occur when the estimated fraction of the rare species (mahogany) is very close to zero.
On the other hand, since we know the ground-truth for $\mathrm{frac}$, we can congratulate ourselves that the HDIs include the true values for all of our species!
Modeling this mixture has made our inferences robust to the overdispersion of counts, while the plain multinomial is very sensitive. Notice that the HDI is much wider than before for each $\mathrm{frac}_j$. In this case that makes the difference between correct and incorrect inferences.
This is great, but we can do better.
The slightly too large $\hat{R}$ value for `frac[mahogany]` is a bit concerning, and it's surprising
that our $\mathrm{ESS} \; \mathrm{sec}^{-1}$ is quite small.
Happily, the Dirichlet distribution is conjugate to the multinomial, and therefore there's a convenient closed form for the marginalized distribution, i.e. the Dirichlet-multinomial distribution, which was added to PyMC in 3.11.0.
Let's take advantage of this, marginalizing out the explicit latent parameter, $p_i$, replacing the combination of this node and the multinomial with the DM to make an equivalent model.
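To see numerically what marginalizing out $p_i$ means, the sketch below averages the multinomial pmf over many Dirichlet draws of $p$ and compares the result with the closed-form DM pmf. This is an illustration written with NumPy and `math.lgamma`, not PyMC's implementation; the function names, the count vector `x` and the values of `frac` and `conc` are all assumptions.

```python
from math import exp, lgamma

import numpy as np

rng = np.random.default_rng(3)


def multinomial_logpmf(x, p):
    """Multinomial log-pmf of one count vector x for each row of p."""
    x = np.asarray(x)
    log_coef = lgamma(x.sum() + 1) - sum(lgamma(v + 1) for v in x)
    return log_coef + np.sum(x * np.log(p), axis=-1)


def dm_logpmf(x, alpha):
    """Closed-form Dirichlet-multinomial log-pmf of one count vector x."""
    n, a0 = sum(x), sum(alpha)
    log_coef = lgamma(n + 1) - sum(lgamma(v + 1) for v in x)
    log_ratio = lgamma(a0) - lgamma(n + a0) + sum(
        lgamma(v + a) - lgamma(a) for v, a in zip(x, alpha)
    )
    return log_coef + log_ratio


frac = np.array([0.5, 0.3, 0.2])  # hypothetical
conc = 4.0  # hypothetical
alpha = conc * frac
x = [3, 1, 2]  # hypothetical observed counts

# Monte Carlo marginalization: average the multinomial pmf over draws of p.
p_draws = rng.dirichlet(alpha, size=200_000)
mc_pmf = np.exp(multinomial_logpmf(x, p_draws)).mean()
closed_form_pmf = exp(dm_logpmf(x, alpha))

print(f"Monte Carlo: {mc_pmf:.5f}, closed form: {closed_form_pmf:.5f}")
```

The closed form gives the same answer without ever sampling $p$, which is why the marginalized model has roughly $n \times k$ fewer parameters for the sampler to explore.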
The plate diagram shows that we've collapsed what had been the latent Dirichlet and the multinomial nodes together into a single DM node.
It samples much more quickly and without any of the warnings from before!
Trace plots look fuzzy and KDEs are clean.
We see that $\hat{R}$ is close to $1$ everywhere and $\mathrm{ESS} \; \mathrm{sec}^{-1}$ is much higher. Our reparameterization (marginalization) has greatly improved the sampling! (And, thankfully, the HDIs look similar to the other model.)
This all looks very good, but what if we didn't have the ground-truth?
Posterior predictive checks to the rescue (again)!
(Notice, again, that the y-axis isn't full height, and clips the distributions for mahogany in purple.)
Compared to the multinomial (plots on the right), PPCs for the DM (left) show that the observed data is an entirely reasonable realization of our model. This is great news!
Let's go a step further and try to put a number on how much better our DM model is relative to the raw multinomial. We'll use leave-one-out cross-validation to compare the out-of-sample predictive ability of the two.
Unsurprisingly, the DM outclasses the multinomial by a mile, assigning a weight of 100% to the over-dispersed model.
While the `warning=True` flag for the multinomial distribution indicates that the numerical value cannot be fully trusted, the large difference in `elpd_loo` is further confirmation that between the two, the DM should be greatly favored for prediction, parameter inference, etc.
Obviously the DM is not a perfect model in every case, but it is often a better choice than the multinomial, much more robust while taking on just one additional parameter.
There are a number of shortcomings to the DM that we should keep in mind when selecting a model. The biggest problem is that, while more flexible than the multinomial, the DM still ignores the possibility of underlying correlations between categories. If one of our tree species relies on another, for instance, the model we've used here will not effectively account for this. In that case, swapping the vanilla Dirichlet distribution for something fancier (e.g. the Generalized Dirichlet or Logistic-Multivariate Normal) may be worth considering.
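The correlation limitation can be seen directly by simulation. Any two components of a Dirichlet draw are negatively correlated, whereas a logistic-normal draw (a softmax of multivariate normal variables) can make two categories rise and fall together. The sketch below compares the two; the Dirichlet parameters and the covariance matrix are hypothetical values chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(11)
n_draws = 50_000

# Dirichlet: the correlation between components 0 and 1 is always negative.
dirichlet_p = rng.dirichlet([2.0, 2.0, 2.0], size=n_draws)
dirichlet_corr = np.corrcoef(dirichlet_p[:, 0], dirichlet_p[:, 1])[0, 1]

# Logistic-normal: latent normals 0 and 1 are strongly positively correlated.
cov = np.array(
    [
        [1.00, 0.95, 0.00],
        [0.95, 1.00, 0.00],
        [0.00, 0.00, 1.00],
    ]
)
z = rng.multivariate_normal(np.zeros(3), cov, size=n_draws)
# Softmax maps each latent vector onto the probability simplex.
logistic_normal_p = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
logistic_normal_corr = np.corrcoef(logistic_normal_p[:, 0], logistic_normal_p[:, 1])[0, 1]

print(f"Dirichlet correlation:       {dirichlet_corr:.2f}")
print(f"logistic-normal correlation: {logistic_normal_corr:.2f}")
```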
:::{bibliography}
:filter: docname in docnames
:::
:::{include} page_footer.md
:::