Use a logdensity function that is not compatible with JAX’s primitives
We recommend using Blackjax with log-probability functions that are compatible with JAX’s primitives. These can be built manually or with Aesara, Numpyro, Oryx, PyMC, or TensorFlow-Probability.
Nevertheless, you may have a good reason to use a function that is incompatible with JAX’s primitives, whether for performance reasons or for compatibility with an already-implemented model. Who are we to judge?
In this example we show how this can be done using JAX’s experimental host_callback API, and hint at a faster solution.
Aesara model compiled to Numba
The following example builds a logdensity function with Aesara, compiles it with Numba and uses Blackjax to sample from the posterior distribution of the model.
import aesara.tensor as at
import numpy as np
srng = at.random.RandomStream(0)
loc = np.array([-2, 0, 3.2, 2.5])
scale = np.array([1.2, 1, 5, 2.8])
weights = np.array([0.2, 0.3, 0.1, 0.4])
N_rv = srng.normal(loc, scale, name="N")
I_rv = srng.categorical(weights, name="I")
Y_rv = N_rv[I_rv]
WARNING (aesara.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
We can sample from the prior predictive distribution to make sure the model is correctly implemented:
import aesara
sampling_fn = aesara.function((), Y_rv)
print(sampling_fn())
print(sampling_fn())
2.516455713134264
0.1609480326942554
We do not care about the posterior distribution of the indicator variable I_rv so we marginalize it out, and subsequently build the logdensity’s graph:
from aeppl import joint_logprob
y_vv = Y_rv.clone()
i_vv = I_rv.clone()
logdensity = []
for i in range(4):
    i_vv = at.as_tensor(i, dtype="int64")
    component_logdensity, _ = joint_logprob(realized={Y_rv: y_vv, I_rv: i_vv})
    logdensity.append(component_logdensity)
logdensity = at.stack(logdensity, axis=0)
total_logdensity = at.logsumexp(at.log(weights) + logdensity)
We are now ready to compile the logdensity to Numba:
logdensity_fn = aesara.function((y_vv,), total_logdensity, mode="NUMBA")
logdensity_fn(1.)
array(-3.15039347)
As is, we cannot use this function within jit-compiled functions written with JAX, or apply jax.grad to get its gradient:
import jax

try:
    jax.jit(logdensity_fn)(1.)
except Exception:
    print("JAX raised an exception while jit-compiling!")

try:
    jax.grad(logdensity_fn)(1.)
except Exception:
    print("JAX raised an exception while differentiating!")
JAX raised an exception while jit-compiling!
JAX raised an exception while differentiating!
Indeed, a function written with Numba is incompatible with JAX’s primitives. Luckily Aesara can build the model’s gradient graph and compile it to Numba as well:
total_logdensity_grad = at.grad(total_logdensity, y_vv)
logdensity_grad_fn = aesara.function((y_vv,), total_logdensity_grad, mode="NUMBA")
logdensity_grad_fn(1.)
array(-0.44711512)
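As a sanity check that does not depend on Aesara or Numba, the two values printed above can be reproduced with plain NumPy. The sketch below is not part of the original example, and the names `normal_logpdf`, `reference_logdensity` and `reference_grad` are hypothetical. It assumes that each `joint_logprob` term already contains the categorical contribution `log(weights[i])` on top of the normal log-density. This is inferred from the printed numbers rather than from AePPL's documentation. Under that assumption the weights enter twice, so the quantity being evaluated is log Σ wᵢ² N(y | μᵢ, σᵢ) rather than the usual mixture density log Σ wᵢ N(y | μᵢ, σᵢ).

```python
import numpy as np

loc = np.array([-2, 0, 3.2, 2.5])
scale = np.array([1.2, 1, 5, 2.8])
weights = np.array([0.2, 0.3, 0.1, 0.4])


def normal_logpdf(y, loc, scale):
    # Log-density of a normal distribution, evaluated for every component.
    z = (y - loc) / scale
    return -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2 * np.pi)


def reference_logdensity(y):
    # Assumption: each joint_logprob term is the normal log-density of
    # component i plus the categorical log-probability log(weights[i]).
    component_logdensity = normal_logpdf(y, loc, scale) + np.log(weights)
    # Same final expression as in the Aesara graph: logsumexp over components.
    return np.logaddexp.reduce(np.log(weights) + component_logdensity)


def reference_grad(y, eps=1e-6):
    # Central finite difference, used only to cross-check the gradient.
    return (reference_logdensity(y + eps) - reference_logdensity(y - eps)) / (2 * eps)


value = reference_logdensity(1.0)
grad = reference_grad(1.0)
print(value, grad)
```

If the two numbers agree with the outputs of `logdensity_fn(1.)` and `logdensity_grad_fn(1.)`, the compiled functions compute what this sketch assumes.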
Use jax.experimental.host_callback to call Numba functions
To call logdensity_fn within JAX, we need to define a function that calls it via JAX’s host_callback. This wrapper function is not differentiable with JAX, however, so we also need to define the function’s custom_vjp, and use host_callback to call the gradient-computing function as well:
import jax
import jax.experimental.host_callback as hcb
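@jax.custom_vjp
def numba_logpdf(arg):
    # Evaluate the Numba-compiled logdensity on the host.
    return hcb.call(lambda x: logdensity_fn(x).item(), arg, result_shape=arg)

def call_grad(arg):
    # Evaluate the Numba-compiled gradient on the host.
    return hcb.call(lambda x: logdensity_grad_fn(x).item(), arg, result_shape=arg)

def vjp_fwd(arg):
    # Return the primal output and save the gradient as the residual.
    return numba_logpdf(arg), call_grad(arg)

def vjp_bwd(grad_x, y_bar):
    # Chain rule: scale the saved gradient by the incoming cotangent.
    return (grad_x * y_bar,)

numba_logpdf.defvjp(vjp_fwd, vjp_bwd)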
And we can now call the function from a jitted function and apply jax.grad without JAX complaining:
jax.jit(numba_logpdf)(1.), jax.grad(numba_logpdf)(1.)
(Array(-3.1503935, dtype=float32), Array(-0.44711512, dtype=float32))
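Recent JAX releases deprecate jax.experimental.host_callback in favor of jax.pure_callback, so it is worth checking the changelog of the JAX version you have installed. The sketch below shows the same custom_vjp pattern written with jax.pure_callback. It is a minimal illustration, not the implementation used in this example: `host_logdensity` and `host_logdensity_grad` are hypothetical NumPy stand-ins for the Numba-compiled functions (a standard normal log-density up to a constant, and its derivative), chosen so that the block runs without Aesara or Numba.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical stand-ins for host-side functions that JAX cannot trace,
# e.g. functions compiled with Numba.
def host_logdensity(x):
    x = np.asarray(x)
    return np.asarray(-0.5 * x**2, dtype=x.dtype)


def host_logdensity_grad(x):
    x = np.asarray(x)
    return np.asarray(-x, dtype=x.dtype)


@jax.custom_vjp
def logdensity(x):
    x = jnp.asarray(x)
    # The callback output has the same shape and dtype as the input.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_logdensity, out_type, x)


def logdensity_fwd(x):
    x = jnp.asarray(x)
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    # Save the host-computed gradient as the residual for the backward pass.
    grad = jax.pure_callback(host_logdensity_grad, out_type, x)
    return logdensity(x), grad


def logdensity_bwd(grad, y_bar):
    # Chain rule: scale the saved gradient by the incoming cotangent.
    return (grad * y_bar,)


logdensity.defvjp(logdensity_fwd, logdensity_bwd)

x = jnp.asarray(1.5)
value = jax.jit(logdensity)(x)
gradient = jax.grad(logdensity)(x)
print(value, gradient)
```

With pure_callback the callback is expected to be free of side effects, which holds for a logdensity and its gradient. The overhead of returning to Python on every call remains.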
We can also use Blackjax’s NUTS sampler to sample from the model’s posterior distribution:
import blackjax
inverse_mass_matrix = np.ones(1)
step_size = 1e-3
nuts = blackjax.nuts(numba_logpdf, step_size, inverse_mass_matrix)
init = nuts.init(0.)
rng_key = jax.random.PRNGKey(0)
state, info = nuts.step(rng_key, init)
for _ in range(10):
    rng_key, nuts_key = jax.random.split(rng_key)
    state, _ = nuts.step(nuts_key, state)
print(state)
HMCState(position=Array(2.909098, dtype=float32, weak_type=True), logdensity=Array(-3.7348793, dtype=float32), logdensity_grad=Array(-0.11291742, dtype=float32))
If you run this on your machine you will notice that it runs quite slowly compared to a pure-JAX equivalent. That’s because host_callback implies a lot of back-and-forth with Python. To see this, let’s compare execution times, with pure Numba on the one hand:
%%time
for _ in range(100_000):
    logdensity_fn(100)
CPU times: user 4.03 s, sys: 7.98 ms, total: 4.04 s
Wall time: 4.04 s
And the host_callback wrapper on the other hand, with 100 times fewer iterations:
%%time
for _ in range(1_000):
    numba_logpdf(100.)
CPU times: user 12.6 s, sys: 128 ms, total: 12.8 s
Wall time: 12.7 s
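That’s a lot of overhead: roughly 40 µs per call for the Numba function against roughly 13 ms per call through host_callback, which is about 300 times slower.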
So while the implementation is simple considering what we’re trying to achieve, it is only recommended for workloads where most of the time is spent evaluating the logdensity and its gradient, and where this overhead becomes irrelevant.
Use custom XLA calls to call Numba functions faster
To avoid this kind of overhead we can use an XLA custom call to execute Numba functions, so that there is no callback to Python inside loops. Writing a function that performs such custom calls given a Numba function is out of scope for this tutorial, but you can get inspiration from jax-triton to implement a custom call to a Numba function. You will also need to register a custom vjp, but you already know how to do that.