(faster_sampling_notebook)=

# Faster Sampling with JAX and Numba

:::{post} July 11, 2023
:tags: hierarchical model, JAX, numba, scaling
:category: reference, intermediate
:author: Thomas Wiecki
:::

PyMC can compile its models to various execution backends through PyTensor, including:

- C
- JAX
- Numba

By default, PyMC uses the C backend, and the compiled model is then called from the Python-based samplers.

However, by compiling to other backends, we can use samplers written in languages other than Python that call the PyMC model without any Python overhead.

For the JAX backend, the NumPyro and BlackJAX NUTS samplers are available. To use these samplers, you have to install `numpyro` and `blackjax`. Both are available through conda/mamba: `mamba install -c conda-forge numpyro blackjax`.

For the Numba backend, there is the Nutpie sampler, written in Rust. To use this sampler, you need `nutpie` installed: `mamba install -c conda-forge nutpie`.
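Since these samplers are optional installs, it can help to check which ones are present before choosing one. The sketch below is a hypothetical helper (`available_nuts_samplers` and `pick_nuts_sampler` are illustrative names, not PyMC API) that assumes each sampler name accepted by the `nuts_sampler` argument of `pm.sample` corresponds to an importable package of the same name, with `"pymc"` always available because it ships with PyMC. The preference order is arbitrary, not a recommendation.

```python
import importlib.util

# Hypothetical mapping: `nuts_sampler` name -> package that must be importable.
SAMPLER_PACKAGES = {
    "nutpie": "nutpie",
    "numpyro": "numpyro",
    "blackjax": "blackjax",
}


def available_nuts_samplers():
    """Return sampler names whose backing package can be found.

    "pymc" (the default Python NUTS sampler) is always listed first
    because it is part of PyMC itself.
    """
    found = ["pymc"]
    for name, package in SAMPLER_PACKAGES.items():
        # find_spec only locates the package on the path; it does not import it.
        if importlib.util.find_spec(package) is not None:
            found.append(name)
    return found


def pick_nuts_sampler(preferred=("nutpie", "numpyro", "blackjax")):
    """Return the first preferred sampler that is installed, else "pymc"."""
    available = available_nuts_samplers()
    for name in preferred:
        if name in available:
            return name
    return "pymc"


print(available_nuts_samplers())
print(pick_nuts_sampler())
```

The result can then be passed straight through, e.g. `pm.sample(nuts_sampler=pick_nuts_sampler())`, so the same notebook runs whether or not the optional samplers are installed.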

We will use a simple probabilistic PCA model as our example.

## Sampling using Python NUTS sampler

## Sampling using NumPyro JAX NUTS sampler

## Sampling using BlackJAX NUTS sampler

## Sampling using Nutpie Rust NUTS sampler

## Authors

Authored by Thomas Wiecki in July 2023

:::{include} ../page_footer.md
:::