Use with Oryx models

Oryx is a probabilistic programming library written in JAX and is therefore natively compatible with Blackjax. In this notebook we show how to use Oryx as a modeling language together with Blackjax as an inference library.

We reproduce the example in Oryx’s documentation and train a Bayesian Neural Network (BNN) on the iris dataset:

from sklearn import datasets

# Load the iris dataset and pull out the inputs and the targets.
iris = datasets.load_iris()
features = iris['data']
labels = iris['target']

# Model dimensions: the size of the last axis of `features` and the number
# of entries in `iris.target_names`.
num_features = features.shape[-1]
num_classes = len(iris.target_names)
Hide code cell source
# Report the dataset dimensions computed above.
print(f"Number of features: {num_features}")
print(f"Number of classes: {num_classes}")
print(f"Number of data points: {features.shape[0]}")
Number of features: 4
Number of classes: 3
Number of data points: 150

Oryx’s approach, like Aesara’s, is to implement probabilistic models as generative models and then apply transformations to obtain the log-probability density function. We begin by implementing a dense layer with a normal prior on the weights, using the function random_variable to define random variables:

import jax
from oryx.core.ppl import random_variable


from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions


def dense(dim_out, activation=jax.nn.relu):
    """Build a dense layer whose weights and bias are random variables.

    Args:
        dim_out: Number of output units of the layer.
        activation: Function applied to the affine output. Defaults to
            ``jax.nn.relu``.

    Returns:
        A function ``forward(key, x)`` that draws ``w`` and ``b`` from
        standard normal priors and returns ``activation(dot(w, x) + b)``.
    """

    def forward(key, x):
        # The input width is read from the data rather than fixed up front.
        dim_in = x.shape[-1]
        # Independent keys for the weight matrix and the bias vector.
        w_key, b_key = jax.random.split(key)
        w = random_variable(
            tfd.Sample(tfd.Normal(0., 1.), sample_shape=(dim_out, dim_in)),
            name='w'
        )(w_key)
        b = random_variable(
            tfd.Sample(tfd.Normal(0., 1.), sample_shape=(dim_out,)),
            name='b'
        )(b_key)

        # Go through `jax.numpy` rather than the `jnp` alias: the alias is
        # only imported in a later cell, so relying on it would make this
        # layer fail with a NameError if it were called before that import.
        return activation(jax.numpy.dot(w, x) + b)

    return forward
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[3], line 2
      1 import jax
----> 2 from oryx.core.ppl import random_variable
      5 from tensorflow_probability.substrates import jax as tfp
      6 tfd = tfp.distributions

File ~/checkouts/readthedocs.org/user_builds/blackjax/envs/latest/lib/python3.8/site-packages/oryx/__init__.py:16
      1 # Copyright 2020 The TensorFlow Probability Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ============================================================================
     15 """Oryx is a neural network mini-library built on top of Jax."""
---> 16 from oryx import bijectors
     17 from oryx import core
     18 from oryx import distributions

File ~/checkouts/readthedocs.org/user_builds/blackjax/envs/latest/lib/python3.8/site-packages/oryx/bijectors/__init__.py:18
     15 """Module for probability bijectors and related functions."""
     16 import inspect
---> 18 from tensorflow_probability.python.experimental.substrates import jax as tfp
     19 from oryx.bijectors import bijector_extensions
     21 tfb = tfp.bijectors

File ~/checkouts/readthedocs.org/user_builds/blackjax/envs/latest/lib/python3.8/site-packages/tensorflow_probability/python/experimental/__init__.py:31
     15 """TensorFlow Probability API-unstable package.
     16 
     17 This package contains potentially useful code which is under active development
   (...)
     27 You are welcome to try any of this out (and tell us how well it works for you!).
     28 """
     30 from tensorflow_probability.python.experimental import auto_batching
---> 31 from tensorflow_probability.python.experimental import bayesopt
     32 from tensorflow_probability.python.experimental import bijectors
     33 from tensorflow_probability.python.experimental import distribute

File ~/checkouts/readthedocs.org/user_builds/blackjax/envs/latest/lib/python3.8/site-packages/tensorflow_probability/python/experimental/bayesopt/__init__.py:17
      1 # Copyright 2023 The TensorFlow Probability Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ============================================================================
     15 """TensorFlow Probability experimental Bayesopt package."""
---> 17 from tensorflow_probability.python.experimental.bayesopt import acquisition
     18 from tensorflow_probability.python.internal import all_util
     20 _allowed_symbols = [
     21     'acquisition',
     22 ]

File ~/checkouts/readthedocs.org/user_builds/blackjax/envs/latest/lib/python3.8/site-packages/tensorflow_probability/python/experimental/bayesopt/acquisition/__init__.py:17
      1 # Copyright 2023 The TensorFlow Probability Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ============================================================================
     15 """Acquisition Functions."""
---> 17 from tensorflow_probability.python.experimental.bayesopt.acquisition.acquisition_function import AcquisitionFunction
     18 from tensorflow_probability.python.experimental.bayesopt.acquisition.acquisition_function import MCMCReducer
     19 from tensorflow_probability.python.experimental.bayesopt.acquisition.expected_improvement import GaussianProcessExpectedImprovement

File ~/checkouts/readthedocs.org/user_builds/blackjax/envs/latest/lib/python3.8/site-packages/tensorflow_probability/python/experimental/bayesopt/acquisition/acquisition_function.py:22
     19 import tensorflow.compat.v2 as tf
     21 from tensorflow_probability.python.internal import dtype_util
---> 22 from tensorflow_probability.python.internal import prefer_static as ps
     23 from tensorflow_probability.python.internal import tensor_util
     26 class AcquisitionFunction(object, metaclass=abc.ABCMeta):

File ~/checkouts/readthedocs.org/user_builds/blackjax/envs/latest/lib/python3.8/site-packages/tensorflow_probability/python/internal/prefer_static.py:30
     28 from tensorflow.python.framework import ops  # pylint: disable=g-direct-tensorflow-import
     29 from tensorflow.python.framework import tensor_util  # pylint: disable=g-direct-tensorflow-import
---> 30 from tensorflow.python.ops import control_flow_case  # pylint: disable=g-direct-tensorflow-import
     31 from tensorflow.python.util import tf_inspect  # pylint: disable=g-direct-tensorflow-import
     33 JAX_MODE = False

ImportError: cannot import name 'control_flow_case' from 'tensorflow.python.ops' (/home/docs/checkouts/readthedocs.org/user_builds/blackjax/envs/latest/lib/python3.8/site-packages/tensorflow/python/ops/__init__.py)

We now use this layer to build a multi-layer perceptron. The nest function creates “scope tags” that, in this context, allow us to re-use our dense layer multiple times without name collisions in the dictionary that will contain the parameters:

from oryx.core.ppl import nest


def mlp(hidden_sizes, num_classes):
    """Stack dense layers into a multi-layer perceptron.

    Args:
        hidden_sizes: Output width of each hidden layer, in order.
        num_classes: Output width of the final layer, which uses the
            identity as its activation.

    Returns:
        A function ``forward(key, x)`` returning the logits for ``x``.
    """
    num_hidden = len(hidden_sizes)

    def forward(key, x):
        # One key per hidden layer plus one for the output layer.
        layer_keys = jax.random.split(key, num_hidden + 1)

        hidden = x
        for index, width in enumerate(hidden_sizes):
            # Each layer gets its own scope so the 'w'/'b' names do not clash.
            layer = nest(dense(width), scope=f'layer_{index + 1}')
            hidden = layer(layer_keys[index], hidden)

        output_layer = nest(
            dense(num_classes, activation=lambda v: v),
            scope=f'layer_{num_hidden + 1}',
        )
        return output_layer(layer_keys[-1], hidden)

    return forward

Finally, we model the labels as categorical random variables:

import functools

def predict(mlp):
    """Wrap a network into a generative model of the labels.

    Args:
        mlp: Function ``mlp(key, x)`` returning the logits for one input.

    Returns:
        A function ``forward(key, xs)`` that maps ``mlp`` over ``xs`` and
        returns a categorical random variable named ``'y'`` built from the
        resulting logits.
    """

    def forward(key, xs):
        network_key, y_key = jax.random.split(key)
        # The same network key is bound for every row of `xs`.
        per_example = functools.partial(mlp, network_key)
        logits = jax.vmap(per_example)(xs)
        likelihood = tfd.Independent(tfd.Categorical(logits=logits), 1)
        return random_variable(likelihood, name='y')(y_key)

    return forward

We can now build the BNN and sample an initial position for the inference algorithm using joint_sample:

import jax.numpy as jnp
from oryx.core.ppl import joint_sample


# Two hidden layers of 50 units each, followed by the output layer.
bnn = mlp([50, 50], num_classes)

# Sample every random variable of the network once; the all-ones vector only
# serves as an input of length `num_features`.
initial_weights = joint_sample(bnn)(
    jax.random.PRNGKey(0),
    jnp.ones(num_features),
)

print(initial_weights.keys())
Hide code cell source
# Total number of scalar entries across all leaves of the weights pytree.
num_parameters = sum(
    leaf.size for leaf in jax.tree_util.tree_leaves(initial_weights)
)
print(f"Number of parameters in the model: {num_parameters}")

To sample from this model we will need to obtain its joint distribution log-probability using joint_log_prob:

from oryx.core.ppl import joint_log_prob

def logdensity_fn(weights):
    """Joint log-probability of ``weights`` and the observed labels.

    Args:
        weights: Mapping of the network's random variables to their values.

    Returns:
        The value of ``joint_log_prob(predict(bnn))`` evaluated at
        ``weights`` extended with ``y=labels``, for the inputs ``features``.
    """
    log_prob = joint_log_prob(predict(bnn))
    # Condition on the data by adding the labels under the name 'y'.
    observed = dict(weights, y=labels)
    return log_prob(observed, features)

We can now run the window adaptation to get good values for the parameters of the NUTS algorithm:

%%time
import blackjax

rng_key = jax.random.PRNGKey(0)
# Split off a dedicated key for warmup and rebind `rng_key` to the other half,
# so that later sampling with `rng_key` does not reuse the random stream that
# the adaptation already consumed.
rng_key, warmup_key = jax.random.split(rng_key)

# Window adaptation for NUTS, run for 100 steps from the initial weights.
adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
(last_state, parameters), _ = adapt.run(warmup_key, initial_weights, 100)

# Build the NUTS transition kernel with the adapted parameters.
kernel = blackjax.nuts(logdensity_fn, **parameters).step

and sample from the model’s posterior distribution:

Hide code cell content
def inference_loop(rng_key, kernel, initial_state, num_samples):
    """Apply ``kernel`` repeatedly and collect every intermediate result.

    Args:
        rng_key: PRNG key, split into one key per step.
        kernel: Function ``kernel(key, state)`` returning ``(state, info)``.
        initial_state: State passed to the first call of ``kernel``.
        num_samples: Number of steps to run.

    Returns:
        A tuple ``(states, infos)`` with the per-step outputs of ``kernel``
        stacked by ``jax.lax.scan``.
    """

    def advance(current_state, step_key):
        next_state, step_info = kernel(step_key, current_state)
        # The new state is both the carry and part of the recorded output.
        return next_state, (next_state, step_info)

    step_keys = jax.random.split(rng_key, num_samples)
    _, history = jax.lax.scan(advance, initial_state, step_keys)
    states, infos = history

    return states, infos
%%time

# Run 100 steps of the kernel, starting from the last state of the adaptation.
states, infos = inference_loop(rng_key, kernel, last_state, 100)

We can now use our samples to estimate the accuracy averaged over the posterior distribution. We use intervene to “inject” the posterior values of the weights instead of sampling from the prior distribution:

from oryx.core.ppl import intervene

posterior_weights = states.position


def _logits_for_weights(weights):
    """Return the logits of every row of `features` for one set of weights."""
    # `intervene` fixes the named random variables to the given values, so
    # the key passed below is a fixed placeholder.
    pinned_bnn = intervene(bnn, **weights)
    return jax.vmap(lambda x: pinned_bnn(jax.random.PRNGKey(0), x))(features)


# Outer map over the sampled weights, inner map over the data points.
output_logits = jax.vmap(_logits_for_weights)(posterior_weights)

output_probs = jax.nn.softmax(output_logits)
Hide code cell source
# Accuracy of each sample's own predictions, averaged over samples and points.
print(
    'Average sample accuracy:',
    (output_probs.argmax(axis=-1) == labels[None]).mean(),
)

# Accuracy of the prediction obtained by first averaging the class
# probabilities over the leading (sample) axis.
print(
    'BMA accuracy:',
    (output_probs.mean(axis=0).argmax(axis=-1) == labels[None]).mean(),
)