(wrapping_jax_function)=

How to wrap a JAX function for use in PyMC

:::{post} Mar 24, 2022
:tags: PyTensor, hidden markov model, JAX
:category: advanced, how-to
:author: Ricardo Vieira
:::

:::{include} ../extra_installs.md
:::

Intro: PyTensor and its backends

PyMC uses the {doc}`PyTensor <pytensor:index>` library to create and manipulate probabilistic graphs. PyTensor is backend-agnostic, meaning it can make use of functions written in different languages or frameworks, including pure Python, NumPy, C, Cython, Numba, and JAX.

All that is needed is to encapsulate such a function in a PyTensor {class}`~pytensor.graph.op.Op`, which enforces a specific API regarding how inputs and outputs of pure "operations" should be handled. It also implements methods for optional extra functionality like symbolic shape inference and automatic differentiation. This is well covered in the PyTensor {ref}`Op documentation <pytensor:op_contract>` and in our {ref}`blackbox_external_likelihood_numpy` pymc-example.

More recently, PyTensor became capable of compiling directly to some of these languages/frameworks, meaning that we can convert a complete PyTensor graph into a JAX or Numba jitted function, whereas traditionally graphs could only be compiled to Python or C.

This has some interesting uses, such as sampling models defined in PyMC with pure JAX samplers, like those implemented in NumPyro or BlackJax.

This notebook illustrates how we can implement a new PyTensor {class}`~pytensor.graph.op.Op` that wraps a JAX function.

Outline

  1. We start along a path similar to that of {ref}`blackbox_external_likelihood_numpy`, which wraps a NumPy function in a PyTensor {class}`~pytensor.graph.op.Op`; this time, we wrap a JAX jitted function instead.
  2. We then enable PyTensor to "unwrap" the JAX function we just wrapped, so that the whole graph can be compiled to JAX. We use this to sample our PyMC model with the JAX-based NumPyro NUTS sampler.

A motivating example: marginal HMM

For illustration purposes, we will simulate data following a simple Hidden Markov Model (HMM), with 3 possible latent states $S \in \{0, 1, 2\}$ and a normal emission likelihood.

$$Y \sim \text{Normal}((S + 1) \cdot \text{signal}, \text{noise})$$

Our HMM will have a fixed Categorical probability $P$ of switching between states, which depends only on the previous state:

$$S_{t+1} \sim \text{Categorical}(P_{S_t})$$

To complete our model, we assume a fixed probability $P_{t0}$ for each possible initial state $S_{t0}$,

$$S_{t0} \sim \text{Categorical}(P_{t0})$$

Simulating data

Let's generate data according to this model! The first step is to set some values for the parameters in our model.

We then write a helper function to generate a single HMM process and use it to create our simulated data.
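A minimal sketch of both steps follows. The parameter values, sequence length, number of processes and random seed are illustrative assumptions, and `simulate_hmm` is written here for illustration; it follows the model equations above, with emissions centered at $(S + 1) \cdot \text{signal}$.

```python
import numpy as np

rng = np.random.default_rng(42)

# Illustrative parameter values (assumptions for this sketch)
n_obs = 70  # length of each HMM process
n_processes = 10  # number of independent HMM processes
p_initial_state = np.array([0.9, 0.09, 0.01])
p_transition = np.array(
    [
        [0.9, 0.09, 0.01],
        [0.1, 0.8, 0.1],
        [0.2, 0.1, 0.7],
    ]
)
emission_signal_true = 1.5
emission_noise_true = 0.5


def simulate_hmm(p_initial_state, p_transition, emission_signal, emission_noise, n_obs, rng):
    """Simulate the hidden states and emissions of a single HMM process."""
    possible_states = np.arange(len(p_initial_state))
    hidden_states = np.empty(n_obs, dtype=int)
    hidden_states[0] = rng.choice(possible_states, p=p_initial_state)
    for t in range(1, n_obs):
        # The next state depends only on the previous one (row of the transition matrix)
        hidden_states[t] = rng.choice(possible_states, p=p_transition[hidden_states[t - 1]])
    # Normal emissions centered at (S + 1) * signal
    emissions = rng.normal((hidden_states + 1) * emission_signal, emission_noise)
    return hidden_states, emissions


simulations = [
    simulate_hmm(p_initial_state, p_transition, emission_signal_true, emission_noise_true, n_obs, rng)
    for _ in range(n_processes)
]
hidden_states = np.stack([states for states, _ in simulations])
emission_observed = np.stack([emissions for _, emissions in simulations])
print(hidden_states.shape, emission_observed.shape)
```

Each row of `emission_observed` is one simulated process, which is the layout the vectorized log-likelihood later in this notebook expects.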

The figure above shows the hidden states and the corresponding observed emissions of our simulated data. Later, we will use this data to perform posterior inferences about the true model parameters.

Computing the marginal HMM likelihood using JAX

We will write a JAX function to compute the likelihood of our HMM model, marginalizing over the hidden states. This allows for more efficient sampling of the remaining model parameters. To achieve this, we will use the well-known forward algorithm, working on the log scale for numerical stability.
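In the notation of our model, writing $y_t$ for the emission observed at step $t$ (with $t = 0$ corresponding to the initial state $S_{t0}$) and $\alpha_t(s) = p(y_0, \dots, y_t, S_t = s)$, the standard log-scale forward recursion can be stated as below. Here $P_{t0, s}$ is the initial probability of state $s$ and $P_{s' s}$ the probability of switching from state $s'$ to state $s$, i.e. an entry of the row $P_{s'}$. Each $\log \sum \exp$ is what a `logsumexp` function computes.

$$
\begin{aligned}
\log \alpha_0(s) &= \log P_{t0, s} + \log \text{Normal}(y_0 \mid (s + 1) \cdot \text{signal}, \text{noise}) \\
\log \alpha_t(s) &= \log \text{Normal}(y_t \mid (s + 1) \cdot \text{signal}, \text{noise}) + \log \sum_{s'=0}^{2} \exp\left(\log \alpha_{t-1}(s') + \log P_{s' s}\right) \\
\log p(y_0, \dots, y_T) &= \log \sum_{s=0}^{2} \exp\left(\log \alpha_T(s)\right)
\end{aligned}
$$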

We will take advantage of JAX's `scan` to obtain an efficient and differentiable log-likelihood, and the handy `vmap` to automatically vectorize this log-likelihood across multiple observed processes.

Our core JAX function computes the marginal log-likelihood of a single HMM process.
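The following is a sketch of such a core function, assuming that `logp_transition[i, j]` holds $\log P(S_{t+1} = j \mid S_t = i)$. The parameter values and the five observations are hypothetical, and `brute_force_logp` is an illustrative helper (not part of any library) that enumerates every hidden path. That is only feasible for very short sequences, but it confirms that the `jax.lax.scan` recursion marginalizes the hidden states correctly.

```python
import itertools

import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def hmm_logp(emission_observed, emission_signal, emission_noise, logp_initial_state, logp_transition):
    """Marginal log-likelihood of a single HMM process (log-scale forward algorithm)."""
    hidden_states = jnp.arange(logp_initial_state.shape[0])

    # logp_emission[t, s] = log Normal(y_t | (s + 1) * signal, noise)
    logp_emission = norm.logpdf(
        emission_observed[:, None], (hidden_states + 1) * emission_signal, emission_noise
    )

    def forward_step(log_alpha_prev, logp_emission_t):
        # logp_transition[i, j] = log P(S_{t+1} = j | S_t = i), so the previous
        # state i is marginalized by reducing over axis 0
        log_alpha = logsumexp(log_alpha_prev[:, None] + logp_transition, axis=0)
        return log_alpha + logp_emission_t, None

    log_alpha_init = logp_initial_state + logp_emission[0]
    log_alpha_last, _ = jax.lax.scan(forward_step, log_alpha_init, logp_emission[1:])
    # Marginalize the final hidden state
    return logsumexp(log_alpha_last)


def brute_force_logp(y, signal, noise, p_initial_state, p_transition):
    """Illustrative check: sum the joint density over every hidden path (short y only)."""
    total = 0.0
    for path in itertools.product(range(len(p_initial_state)), repeat=len(y)):
        s = np.array(path)
        p_path = p_initial_state[s[0]] * np.prod(p_transition[s[:-1], s[1:]])
        density = np.prod(
            np.exp(-0.5 * ((y - (s + 1) * signal) / noise) ** 2) / (noise * np.sqrt(2 * np.pi))
        )
        total += p_path * density
    return np.log(total)


# Hypothetical parameters and a short observed sequence
p_initial_state = np.array([0.9, 0.09, 0.01])
p_transition = np.array([[0.9, 0.09, 0.01], [0.1, 0.8, 0.1], [0.2, 0.1, 0.7]])
y = np.array([1.4, 1.6, 3.1, 2.9, 4.6])

logp = hmm_logp(jnp.asarray(y), 1.5, 0.5, jnp.log(p_initial_state), jnp.log(p_transition))
print(float(logp), brute_force_logp(y, 1.5, 0.5, p_initial_state, p_transition))
```

Note that JAX computes in float32 by default, so small differences from the float64 brute-force value are expected.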

Let's test it with the true parameters and the first simulated HMM process.

We now use `vmap` to vectorize the core function across multiple observations.

Passing a row matrix with only the first simulated HMM process should return the same result.

Our goal, however, is to compute the joint log-likelihood for all the simulated data.
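The vectorized version can be sketched with `jax.vmap`, mapping only over the first argument (one row per HMM process) and summing over processes to obtain the joint log-likelihood. The core function is repeated so the block runs on its own; the data and parameters are again hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def hmm_logp(emission_observed, emission_signal, emission_noise, logp_initial_state, logp_transition):
    """Marginal log-likelihood of a single HMM process (log-scale forward algorithm)."""
    hidden_states = jnp.arange(logp_initial_state.shape[0])
    logp_emission = norm.logpdf(
        emission_observed[:, None], (hidden_states + 1) * emission_signal, emission_noise
    )

    def forward_step(log_alpha_prev, logp_emission_t):
        log_alpha = logsumexp(log_alpha_prev[:, None] + logp_transition, axis=0)
        return log_alpha + logp_emission_t, None

    log_alpha_last, _ = jax.lax.scan(
        forward_step, logp_initial_state + logp_emission[0], logp_emission[1:]
    )
    return logsumexp(log_alpha_last)


def vec_hmm_logp(*args):
    # Vectorize over the first argument only; the parameters are shared
    vmapped = jax.vmap(hmm_logp, in_axes=(0, None, None, None, None))
    # Sum across processes to obtain the joint log-likelihood
    return jnp.sum(vmapped(*args))


jitted_vec_hmm_logp = jax.jit(vec_hmm_logp)

# Hypothetical parameters and two short observed processes
logp0 = jnp.log(jnp.array([0.9, 0.09, 0.01]))
logpT = jnp.log(jnp.array([[0.9, 0.09, 0.01], [0.1, 0.8, 0.1], [0.2, 0.1, 0.7]]))
y = jnp.array([[1.4, 1.6, 3.1, 2.9, 4.6], [1.5, 1.3, 1.7, 3.0, 3.2]])

single = hmm_logp(y[0], 1.5, 0.5, logp0, logpT)
row_matrix = jitted_vec_hmm_logp(y[:1], 1.5, 0.5, logp0, logpT)
joint = jitted_vec_hmm_logp(y, 1.5, 0.5, logp0, logpT)
print(single, row_matrix, joint)
```

Summing inside the vectorized function keeps the output a scalar, which is what the PyTensor wrapper later declares as its output type.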

We will also ask JAX for a function that computes the gradients with respect to each input. This will come in handy later.

Let's print out the gradient with respect to `emission_signal`. We will check that this value is unchanged after we wrap our function in PyTensor.
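A sketch of the jitted gradient function, using `jax.grad` with `argnums` covering all five inputs so that it returns a tuple with one gradient per input. The helper functions are repeated from the previous sketches, and the data are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def hmm_logp(emission_observed, emission_signal, emission_noise, logp_initial_state, logp_transition):
    """Marginal log-likelihood of a single HMM process (log-scale forward algorithm)."""
    hidden_states = jnp.arange(logp_initial_state.shape[0])
    logp_emission = norm.logpdf(
        emission_observed[:, None], (hidden_states + 1) * emission_signal, emission_noise
    )

    def forward_step(log_alpha_prev, logp_emission_t):
        log_alpha = logsumexp(log_alpha_prev[:, None] + logp_transition, axis=0)
        return log_alpha + logp_emission_t, None

    log_alpha_last, _ = jax.lax.scan(
        forward_step, logp_initial_state + logp_emission[0], logp_emission[1:]
    )
    return logsumexp(log_alpha_last)


def vec_hmm_logp(*args):
    vmapped = jax.vmap(hmm_logp, in_axes=(0, None, None, None, None))
    return jnp.sum(vmapped(*args))


# Gradients with respect to all five inputs, returned as a tuple and jitted
jitted_vec_hmm_logp_grad = jax.jit(jax.grad(vec_hmm_logp, argnums=(0, 1, 2, 3, 4)))

# Hypothetical parameters and data
logp0 = jnp.log(jnp.array([0.9, 0.09, 0.01]))
logpT = jnp.log(jnp.array([[0.9, 0.09, 0.01], [0.1, 0.8, 0.1], [0.2, 0.1, 0.7]]))
y = jnp.array([[1.4, 1.6, 3.1, 2.9, 4.6], [1.5, 1.3, 1.7, 3.0, 3.2]])

grads = jitted_vec_hmm_logp_grad(y, 1.5, 0.5, logp0, logpT)
print(grads[1])  # gradient with respect to emission_signal
```

A central finite difference of `vec_hmm_logp` in `emission_signal` offers a rough, float32-limited sanity check of this value.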

Wrapping the JAX function in PyTensor

Now we are ready to wrap our JAX jitted function in a PyTensor {class}`~pytensor.graph.op.Op`, which we can then use in our PyMC models. We recommend checking PyTensor's official {ref}`Op documentation <pytensor:op_contract>` if you want to understand it in more detail.

In brief, we will inherit from {class}`~pytensor.graph.op.Op` and define the following methods:

  1. `make_node`: Creates an {class}`~pytensor.graph.basic.Apply` node that holds together the symbolic inputs and outputs of our operation
  2. `perform`: Python code that returns the evaluation of our operation, given concrete input values
  3. `grad`: Returns a PyTensor symbolic graph that represents the gradient expression of an output cost with respect to its inputs

For `grad`, we will create a second {class}`~pytensor.graph.op.Op` that wraps the jitted gradient function from above.

We recommend using the `eval` debugging helper method to confirm we specified everything correctly. We should get the same outputs as before.

It's also useful to check that the gradient of our {class}`~pytensor.graph.op.Op` can be requested via the PyTensor `grad` interface.

Sampling with PyMC

We are now ready to make inferences about our HMM model with PyMC. We will define priors for each model parameter and use {class}`~pymc.Potential` to add the joint log-likelihood term to our model.

Before we start sampling, we check the logp of each variable at the model's initial point. Bugs tend to manifest themselves in the form of `nan` or `-inf` for the initial probabilities.

We are now ready to sample!

The posteriors look reasonably centered around the true values used to generate our data.

Unwrapping the wrapped JAX function

As mentioned in the beginning, PyTensor can compile an entire graph to JAX. To do this, it needs to know how each {class}`~pytensor.graph.op.Op` in the graph can be converted to a JAX function. This can be done by {term}`dispatch <dispatching>` with {func}`pytensor.link.jax.dispatch.jax_funcify`. Most of the default PyTensor {class}`~pytensor.graph.op.Op`s already have such a dispatch function, but we will need to add a new one for our custom `HMMLogpOp`, since PyTensor has never seen it before.

For that we need a function that returns (another) JAX function performing the same computation as our `perform` method. Fortunately, we started with exactly such a function, so this amounts to 3 short lines of code.

:::{note}
We do not return the jitted function, so that the entire PyTensor graph can be jitted together after being converted to JAX.
:::

For a better understanding of {class}`~pytensor.graph.op.Op` JAX conversions, we recommend reading PyTensor's {doc}`Adding JAX and Numba support for Ops guide <pytensor:extending/creating_a_numba_jax_op>`.

We can test that our conversion function is working properly by compiling a {func}`pytensor.function` with `mode="JAX"`.

We can also compile a JAX function that computes the log probability of each variable in our PyMC model, similar to {meth}`~pymc.Model.point_logps`. We will use the helper method {meth}`~pymc.Model.compile_fn`.

Note that we could have added an equally simple function to convert our `HMMLogpGradOp`, in case we wanted to convert PyTensor gradient graphs to JAX. In our case, we don't need to do this because we will rely on JAX's `grad` function (or, more precisely, NumPyro will rely on it) to obtain these gradients again from our compiled JAX function.

We include a {ref}`short discussion <pytensor_vs_jax>` at the end of this document, to help you better understand the trade-offs between working with PyTensor graphs vs JAX functions, and when you might want to use one or the other.

Sampling with NumPyro

Now that we know our model logp can be entirely compiled to JAX, we can use the handy {func}`pymc.sampling_jax.sample_numpyro_nuts` to sample our model using the pure JAX sampler implemented in NumPyro.

As expected, sampling results look pretty similar!

Depending on the model and computer architecture you are using, a pure JAX sampler can provide considerable speedups.

(pytensor_vs_jax)=

Some brief notes on using PyTensor vs JAX

When should you use JAX?

As we have seen, it is pretty straightforward to interface between PyTensor graphs and JAX functions.

This can be very handy when you want to combine previously implemented JAX functions with PyMC models. We used a marginalized HMM log-likelihood in this example, but the same strategy could be used to do Bayesian inference with Deep Neural Networks or Differential Equations, or pretty much any other function implemented in JAX that can be used in the context of a Bayesian model.

It can also be worthwhile if you need JAX's unique features, such as vectorization, support for tree structures, fine-grained parallelization, and GPU and TPU capabilities.

When should you not use JAX?

Like JAX, PyTensor aims to mimic the NumPy and SciPy APIs, so that writing code in PyTensor should feel very similar to writing code in those libraries.

There are, however, some advantages to working with PyTensor:

  1. PyTensor graphs are considerably easier to {ref}`inspect and debug <pytensor:debug_faq>` than JAX functions
  2. PyTensor has clever {ref}`optimization and stabilization routines <pytensor:optimizations>` that are not possible or not implemented in JAX
  3. PyTensor graphs can be easily {ref}`manipulated after creation <pytensor:graph_rewriting>`

Point 2 means your graphs are likely to perform better if written in PyTensor. In general, you don't have to worry about using specialized functions like `log1p` or `logsumexp`, as PyTensor will be able to detect the equivalent naive expressions and replace them with their specialized counterparts. Importantly, you still benefit from these optimizations when your graph is later compiled to JAX.

The catch is that PyTensor cannot reason about JAX functions, and by association about the {class}`~pytensor.graph.op.Op`s that wrap them. This means that the larger the portion of the graph that is "hidden" inside a JAX function, the less a user will benefit from PyTensor's rewrite and debugging abilities.

Point 3 is more important for library developers. It is the main reason why PyMC developers opted to use PyTensor (and before that, its predecessor Theano) as its backend. Many of the user-facing utilities provided by PyMC rely on the ability to easily parse and manipulate PyTensor graphs.

Bonus: Using a single Op that can compute its own gradients

We had to create two {class}`~pytensor.graph.op.Op`s, one for the function we cared about and a separate one for its gradients. However, JAX provides a `value_and_grad` utility that can return both the value of a function and its gradients. We can do something similar and get away with a single {class}`~pytensor.graph.op.Op` if we are clever about it.

By doing this we can (potentially) save memory and reuse computation that is shared between the function and its gradients. This may be relevant when working with very large JAX functions.

Note that this is only useful if you are interested in taking gradients with respect to your {class}`~pytensor.graph.op.Op` using PyTensor. If your end goal is to compile your graph to JAX, and only then take the gradients (as NumPyro does), then it's better to use the first approach. You don't even need to implement the `grad` method and associated {class}`~pytensor.graph.op.Op` in that case.

We check again that we can take the gradient using the PyTensor `grad` interface.

Authors

Authored by Ricardo Vieira on March 24, 2022 (pymc-examples#302)

Watermark

:::{include} ../page_footer.md
:::