# Source code for blackjax.smc.base
# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, NamedTuple, Optional, Tuple
import jax
import jax.numpy as jnp
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey
class SMCState(NamedTuple):
    """State of the SMC sampler.

    particles: PyTree
        The current particles. Each leaf is indexed along its leading axis
        by particle (``step`` gathers ``x[idx]`` on every leaf).
    weights: Array
        The normalized weights of the particles, one entry per particle.

    """

    particles: ArrayTree
    # Normalized particle weights; `init` sets them uniform and `step`
    # recomputes them from the log-weights returned by `weigh_fn`.
    weights: Array
class SMCInfo(NamedTuple):
    """Additional information on the SMC step.

    ancestors: Array
        The indices, produced by the resampling step, of the particles that
        were selected and then moved by the update (MCMC) pass.
    log_likelihood_increment: float
        The log-likelihood increment due to the current step of the SMC
        algorithm, i.e. the log of the mean unnormalized particle weight.
    update_info: NamedTuple
        Extra information returned by the update function.

    """

    # Field order matters: `step` constructs this tuple positionally.
    ancestors: Array
    log_likelihood_increment: float
    update_info: NamedTuple
def init(particles: ArrayLikeTree) -> SMCState:
    """Initialize the SMC sampler with uniformly weighted particles.

    Parameters
    ----------
    particles
        PyTree of initial particles. The particle index is assumed to be the
        leading axis of every leaf; the number of particles is read from the
        first leaf.

    Returns
    -------
    An ``SMCState`` holding ``particles`` and a weight vector of length
    ``num_particles`` whose entries all equal ``1 / num_particles``.

    Raises
    ------
    ValueError
        If ``particles`` contains no leaves, so the number of particles
        cannot be inferred.

    """
    leaves = jax.tree_util.tree_leaves(particles)
    if not leaves:
        raise ValueError(
            "`particles` must contain at least one array leaf to infer the "
            "number of particles."
        )
    # Infer the number of particles from the size of the leading dimension of
    # the first leaf of the inputted PyTree.
    num_particles = leaves[0].shape[0]
    weights = jnp.ones(num_particles) / num_particles
    return SMCState(particles, weights)
def step(
    rng_key: PRNGKey,
    state: SMCState,
    update_fn: Callable,
    weigh_fn: Callable,
    resample_fn: Callable,
    num_resampled: Optional[int] = None,
) -> Tuple[SMCState, SMCInfo]:
    """General SMC sampling step.

    `update_fn` here corresponds to the Markov kernel :math:`M_{t+1}`, and
    `weigh_fn` corresponds to the potential function :math:`G_t`. We first
    resample the current particles with `resample_fn`, then use `update_fn`
    to generate new particles from the resampled ones, and finally weigh
    these new particles using `weigh_fn`.

    The `update_fn` and `weigh_fn` functions must be batched by the caller,
    either using `jax.vmap` or `jax.pmap`.

    In Feynman-Kac terms, the algorithm goes roughly as follows:

    .. code::

        M_t: update_fn
        G_t: weigh_fn
        R_t: resample_fn
        idx = R_t(weights)
        x_t = x_tm1[idx]
        x_{t+1} = M_t(x_t)
        weights = G_t(x_{t+1})

    Parameters
    ----------
    rng_key
        Key used to generate pseudo-random numbers.
    state
        Current state of the SMC sampler: particles and their respective
        normalized weights.
    update_fn
        Function that takes an array of keys (one per resampled particle)
        and the resampled particles, and returns a pair
        ``(new_particles, update_info)``.
    weigh_fn
        Function that assigns a log-weight to each particle.
    resample_fn
        Function called as ``resample_fn(key, weights, num_samples)`` that
        returns the indices of the selected particles.
    num_resampled
        The number of particles to resample. This can be used to implement
        Waste-Free SMC :cite:p:`dau2020waste`, in which case we resample a
        number :math:`M<N` of particles, and the update function is in charge
        of returning :math:`N` samples. Defaults to the current number of
        particles.

    Returns
    -------
    new_state
        An `SMCState` holding the new particles and their normalized weights.
    info
        An `SMCInfo` object that contains extra information about the SMC
        transition: the ancestor indices, the log-likelihood increment, and
        the information returned by `update_fn`.

    """
    updating_key, resampling_key = jax.random.split(rng_key, 2)

    num_particles = state.weights.shape[0]
    if num_resampled is None:
        num_resampled = num_particles

    # Select ancestors from the current weights and gather them leaf by leaf.
    # `jax.tree_map` is deprecated (and removed in newer JAX releases), so use
    # the stable `jax.tree_util.tree_map` spelling.
    resampling_idx = resample_fn(resampling_key, state.weights, num_resampled)
    particles = jax.tree_util.tree_map(
        lambda x: x[resampling_idx], state.particles
    )

    # Move the resampled particles with the Markov kernel, one key each.
    keys = jax.random.split(updating_key, num_resampled)
    particles, update_info = update_fn(keys, particles)

    # `weigh_fn` returns log-weights; normalize them in log space with
    # logsumexp for numerical stability.
    log_weights = weigh_fn(particles)
    logsum_weights = jax.scipy.special.logsumexp(log_weights)
    # log(mean(exp(log_weights))): the log-likelihood increment of this step.
    # `num_particles` (N) is used, not `num_resampled`, because `update_fn`
    # is expected to return N particles (see Waste-Free SMC above).
    log_likelihood_increment = logsum_weights - jnp.log(num_particles)
    weights = jnp.exp(log_weights - logsum_weights)

    return SMCState(particles, weights), SMCInfo(
        resampling_idx, log_likelihood_increment, update_info
    )