Use with PyMC models
BlackJAX can work with any log-probability function, as long as it is built from JAX primitives so that JAX can trace, compile and differentiate it. In this notebook we show how to use PyMC as a modeling language and BlackJAX as an inference library.
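As a toy illustration of what "built from JAX primitives" means, the hypothetical log-density below uses only jax.numpy operations, so JAX can differentiate and compile it. Gradient-based samplers such as NUTS rely on exactly this gradient. The function is a made-up standard normal example, not part of PyMC or BlackJAX.

```python
import jax
import jax.numpy as jnp

def logdensity(x):
    # Standard normal log-density up to an additive constant, using jax.numpy only
    return -0.5 * jnp.sum(x ** 2)

# Because only JAX operations are used, JAX can differentiate and JIT-compile it
grad_logdensity = jax.jit(jax.grad(logdensity))

x = jnp.array([1.0, -2.0])
value = logdensity(x)          # -0.5 * (1 + 4) = -2.5
gradient = grad_logdensity(x)  # equals -x, i.e. [-1.0, 2.0]
```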
Before you start
You will need PyMC to run this example. Please follow the installation instructions on PyMC’s repository.
We will reproduce the Eight Schools example from the TensorFlow Probability (TFP) documentation; see that example for a description of the problem and of the model used.
import numpy as np
J = 8
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
We implement the non-centered version of the hierarchical model:
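import pymc as pm
with pm.Model() as model:
    mu = pm.Normal("mu", mu=0.0, sigma=10.0)  # mean of the school effects
    tau = pm.HalfCauchy("tau", 5.0)  # spread of the school effects
    theta = pm.Normal("theta", mu=0, sigma=1, shape=J)  # standardized school offsets
    theta_1 = mu + tau * theta  # school effects (non-centered parameterization)
    obs = pm.Normal("obs", mu=theta_1, sigma=sigma, shape=J, observed=y)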
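For reference, the PyMC model above corresponds to the generative model below, where the standardized offsets (theta in the code) are written with a tilde and the school effects (theta_1 in the code) are written as theta_j. Normal distributions are written with their standard deviation as the second argument, matching PyMC's sigma. Drawing standard-normal offsets and rescaling them by tau, instead of drawing theta_j from N(mu, tau) directly (the centered version), is what makes this parameterization non-centered; it typically eases the funnel-shaped posterior geometry that tau induces when there is little data per group.

```latex
\begin{aligned}
\mu &\sim \mathcal{N}(0, 10) \\
\tau &\sim \mathrm{HalfCauchy}(5) \\
\tilde{\theta}_j &\sim \mathcal{N}(0, 1), \qquad j = 1, \dots, 8 \\
\theta_j &= \mu + \tau \, \tilde{\theta}_j \\
y_j &\sim \mathcal{N}(\theta_j, \sigma_j)
\end{aligned}
```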
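We need to translate the model into a log-probability function that BlackJAX can use to perform inference. For that we use the get_jaxified_logp function from PyMC's internals. It returns a JAX function that takes a list of arrays, one per entry of model.value_vars and in the same order, and returns the model's joint log-density on the unconstrained (transformed) space.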
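try:
    from pymc.sampling.jax import get_jaxified_logp  # newer PyMC releases
except ImportError:
    from pymc.sampling_jax import get_jaxified_logp  # older PyMC releases
rvs = [rv.name for rv in model.value_vars]  # value-variable names, e.g. "tau_log__"
logdensity_fn = get_jaxified_logp(model)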
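To see what logdensity_fn computes, here is a hand-written sketch of the same joint log-density in JAX. It rests on assumptions that you should check against your PyMC version: the position list is ordered [mu, tau_log__, theta] (the order of model.value_vars), PyMC samples tau on the log scale, and the log-Jacobian of that transform is included. The names halfcauchy_logpdf and logdensity_by_hand are hypothetical helpers, not PyMC functions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Eight Schools data, same values as y and sigma above
y_data = jnp.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma_data = jnp.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

def halfcauchy_logpdf(x, beta):
    # log of 2 / (pi * beta * (1 + (x / beta)**2)), valid for x >= 0
    return jnp.log(2.0) - jnp.log(jnp.pi * beta) - jnp.log1p((x / beta) ** 2)

def logdensity_by_hand(position):
    # Assumed layout: [mu, tau_log__, theta], the order of model.value_vars
    mu, log_tau, theta = position
    tau = jnp.exp(log_tau)
    logp = norm.logpdf(mu, 0.0, 10.0)
    # Prior on tau plus the log-Jacobian of tau = exp(log_tau), which is log_tau
    logp += halfcauchy_logpdf(tau, 5.0) + log_tau
    logp += jnp.sum(norm.logpdf(theta, 0.0, 1.0))
    # Likelihood of the observed effects under the non-centered model
    logp += jnp.sum(norm.logpdf(y_data, mu + tau * theta, sigma_data))
    return logp

# JAX differentiates the list-structured position, as NUTS requires
grad_fn = jax.grad(logdensity_by_hand)
```

If the assumptions hold, evaluating both functions at PyMC's initial point, for example on [model.initial_point()[rv] for rv in rvs], should give the same value up to floating-point error; a mismatch suggests the variable order or transforms differ in your PyMC version.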
We can now run BlackJAX's window adaptation, which tunes the step size and inverse mass matrix of the NUTS sampler over 1,000 warmup steps:
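import blackjax
import jax
# Get the initial position from PyMC, as a list ordered like rvs
init_position_dict = model.initial_point()
init_position = [init_position_dict[rv] for rv in rvs]
# Split the key so that warmup and sampling use independent random streams
rng_key = jax.random.PRNGKey(1234)
rng_key, warmup_key = jax.random.split(rng_key)
# Tune NUTS for 1,000 warmup steps; parameters holds the adapted
# step size and inverse mass matrix
adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
(last_state, parameters), _ = adapt.run(warmup_key, init_position, 1000)
kernel = blackjax.nuts(logdensity_fn, **parameters).step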
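The window adaptation run above is modeled on Stan's warmup, which alternates fast phases (step size only) with slow windows that also re-estimate the inverse mass matrix. The sketch below shows roughly how such a schedule splits the 1,000 warmup steps, assuming Stan's default sizes: 75 initial fast steps, 50 final fast steps, and a first slow window of 25 steps that doubles each time. The helper stan_style_schedule is hypothetical, and BlackJAX's internal schedule may differ in detail.

```python
def stan_style_schedule(num_steps, initial_buffer=75, final_buffer=50, first_window=25):
    """Return (phase, length) pairs for a Stan-style warmup (hypothetical helper).

    Slow windows double in size; the last one is stretched so that it ends
    exactly where the final fast buffer begins.
    """
    slow_steps = num_steps - initial_buffer - final_buffer
    if slow_steps < first_window:
        raise ValueError("num_steps is too small for these buffer sizes")
    schedule = [("fast", initial_buffer)]
    size, remaining = first_window, slow_steps
    while True:
        # If the next (doubled) window would not fit, absorb the remainder now
        if remaining - size < 2 * size:
            schedule.append(("slow", remaining))
            break
        schedule.append(("slow", size))
        remaining -= size
        size *= 2
    schedule.append(("fast", final_buffer))
    return schedule
```

For 1,000 warmup steps this gives slow windows of 25, 50, 100, 200 and 500 steps between fast buffers of 75 and 50 steps. The long final slow window is where the mass matrix estimate benefits from the most samples.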
Let us now perform inference with the tuned kernel:
def inference_loop(rng_key, kernel, initial_state, num_samples):
    # Run kernel for num_samples steps with jax.lax.scan, one fresh key per step
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)  # carry the state, collect (state, info)
    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
    return states, infos
states, infos = inference_loop(rng_key, kernel, last_state, 50_000)
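BlackJAX keeps the structure of the initial position, so states.position should be a list ordered like rvs, each entry with a leading axis of 50,000 draws. Because PyMC samples tau on the log scale (as tau_log__, assuming PyMC's default log transform for HalfCauchy), the draws have to be mapped back before they are interpreted. The helper to_constrained below is a hypothetical sketch of that post-processing, assuming the order [mu, tau_log__, theta].

```python
import jax.numpy as jnp

def to_constrained(position):
    """Map unconstrained draws back to the model's parameters (hypothetical helper).

    Assumes position is [mu, tau_log__, theta] with shapes
    (num_samples,), (num_samples,) and (num_samples, J).
    """
    mu, log_tau, theta = position
    tau = jnp.exp(log_tau)  # undo the log transform applied to tau
    # Non-centered school effects theta_1 = mu + tau * theta, shape (num_samples, J)
    theta_1 = mu[:, None] + tau[:, None] * theta
    return {"mu": mu, "tau": tau, "theta_1": theta_1}
```

With the sampler output above, samples = to_constrained(states.position) followed by samples["theta_1"].mean(axis=0) would give a posterior mean effect for each of the eight schools.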