Source code for blackjax.smc.base

# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, NamedTuple, Optional, Tuple

import jax
import jax.numpy as jnp

from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey


class SMCState(NamedTuple):
    """Container for the current population of an SMC sampler.

    particles
        PyTree holding the particles. The number of particles is read from
        the leading dimension of the first leaf (see `init`).
    weights
        One weight per particle. `init` sets them to ``1 / num_particles``
        and `step` returns them normalized so that they sum to one.

    """

    particles: ArrayTree
    weights: Array
class SMCInfo(NamedTuple):
    """Additional information on the SMC step.

    ancestors: Array
        The indices returned by the resampling function, i.e. which of the
        previous particles were selected as the starting points of the
        update step.
    log_likelihood_increment: float
        The log-likelihood increment due to the current step of the SMC
        algorithm, computed in `step` as the log of the mean of the
        unnormalized weights.
    update_info: NamedTuple
        The extra information returned by the update function alongside the
        new particles. Its content depends on the update function in use.

    """

    ancestors: Array
    log_likelihood_increment: float
    update_info: NamedTuple
def init(particles: ArrayLikeTree):
    """Create the initial SMC state with uniform weights.

    Parameters
    ----------
    particles
        PyTree of particles. The number of particles is taken to be the size
        of the leading dimension of the first leaf of this PyTree.

    Returns
    -------
    An `SMCState` holding `particles` unchanged together with weights that
    are all equal to ``1 / num_particles``.

    """
    # Only the first leaf is inspected; the other leaves are not checked for
    # a matching leading dimension.
    leaves, _ = jax.tree_util.tree_flatten(particles)
    population_size = leaves[0].shape[0]
    uniform_weights = jnp.ones(population_size) / population_size
    return SMCState(particles, uniform_weights)
def step(
    rng_key: PRNGKey,
    state: SMCState,
    update_fn: Callable,
    weigh_fn: Callable,
    resample_fn: Callable,
    num_resampled: Optional[int] = None,
) -> Tuple[SMCState, SMCInfo]:
    """General SMC sampling step.

    `update_fn` here corresponds to the Markov kernel $M_{t+1}$, and `weigh_fn`
    corresponds to the potential function $G_t$. We first resample the current
    particles with `resample_fn`, then use `update_fn` to generate new
    particles from the resampled ones, and finally weigh these new particles
    using `weigh_fn`.

    The `update_fn` and `weigh_fn` functions must be batched by the caller,
    either using `jax.vmap` or `jax.pmap`.

    In Feynman-Kac terms, the algorithm goes roughly as follows:

    .. code::

        M_t: update_fn
        G_t: weigh_fn
        R_t: resample_fn
        idx = R_t(weights)
        x_t = x_tm1[idx]
        x_{t+1} = M_t(x_t)
        weights = G_t(x_{t+1})

    Parameters
    ----------
    rng_key
        Key used to generate pseudo-random numbers.
    state
        Current state of the SMC sampler: particles and their respective
        (normalized) weights.
    update_fn
        Function that takes an array of keys and particles and returns a
        tuple ``(new_particles, update_info)``. It is called with
        `num_resampled` keys.
    weigh_fn
        Function that assigns a log-weight to each of the new particles.
    resample_fn
        Function called as ``resample_fn(key, weights, num_resampled)`` that
        returns the indices of the particles to keep.
    num_resampled
        The number of particles to resample. This can be used to implement
        Waste-Free SMC :cite:p:`dau2020waste`, in which case we resample a number
        :math:`M<N` of particles, and the update function is in charge of
        returning :math:`N` samples. Defaults to the current number of
        particles.

    Returns
    -------
    new_state
        An `SMCState` that contains the new particles generated by this SMC
        step and their normalized weights.
    info
        An `SMCInfo` object that contains extra information about the SMC
        transition.

    """
    updating_key, resampling_key = jax.random.split(rng_key, 2)

    num_particles = state.weights.shape[0]
    if num_resampled is None:
        num_resampled = num_particles

    resampling_idx = resample_fn(resampling_key, state.weights, num_resampled)
    # `jax.tree_map` is a deprecated top-level alias that recent JAX releases
    # no longer provide; `jax.tree_util.tree_map` is the long-standing
    # equivalent and matches the `jax.tree_util` usage elsewhere in this file.
    particles = jax.tree_util.tree_map(
        lambda x: x[resampling_idx], state.particles
    )

    # One key per resampled particle for the batched update function.
    keys = jax.random.split(updating_key, num_resampled)
    particles, update_info = update_fn(keys, particles)

    log_weights = weigh_fn(particles)
    logsum_weights = jax.scipy.special.logsumexp(log_weights)
    # Log of the mean unnormalized weight; the mean is taken with respect to
    # the number of particles in the incoming state.
    normalizing_constant = logsum_weights - jnp.log(num_particles)
    # Subtracting the log-sum before exponentiating yields weights that sum
    # to one.
    weights = jnp.exp(log_weights - logsum_weights)

    return SMCState(particles, weights), SMCInfo(
        resampling_idx, normalizing_constant, update_info
    )