blackjax.smc.tempered

Module Contents

Classes

TemperedSMCState

Current state for the tempered SMC algorithm.

Functions

init(particles)

build_kernel(logprior_fn, loglikelihood_fn, mcmc_step_fn, mcmc_init_fn, resampling_fn) → Callable

Build the base Tempered SMC kernel.

class TemperedSMCState

Current state for the tempered SMC algorithm.

particles: PyTree

The particles’ positions.

weights: Array

The weights associated with the particles.

lmbda: float

Current value of the tempering parameter.

particles: blackjax.types.ArrayTree
weights: blackjax.types.Array
lmbda: float
init(particles: blackjax.types.ArrayLikeTree)
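
The page gives no description for init, so the following is only a minimal sketch of what an initial state with the three fields above could look like. It is plain Python, not blackjax's implementation: SketchTemperedState and sketch_init are hypothetical names, and starting from uniform weights with lmbda = 0 (pure prior) is an assumption.

```python
from typing import Any, List, NamedTuple


class SketchTemperedState(NamedTuple):
    """Hypothetical stand-in mirroring the three documented fields."""

    particles: Any        # particle positions, one entry per particle
    weights: List[float]  # one weight per particle
    lmbda: float          # current value of the tempering parameter


def sketch_init(particles):
    # Assumption: every particle starts with the same weight and the
    # tempering parameter starts at 0, i.e. the target is the prior.
    num_particles = len(particles)
    weights = [1.0 / num_particles] * num_particles
    return SketchTemperedState(particles, weights, 0.0)


state = sketch_init([0.3, -1.2, 0.8, 2.1])
```

In this sketch the number of particles is read off the length of the input, so the caller only has to supply positions drawn from the prior.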
build_kernel(logprior_fn: Callable, loglikelihood_fn: Callable, mcmc_step_fn: Callable, mcmc_init_fn: Callable, resampling_fn: Callable) → Callable

Build the base Tempered SMC kernel.

Tempered SMC uses tempering to sample from a distribution given by

\[p(x) \propto p_0(x) \exp(-V(x))\]

where \(p_0\) is the prior distribution, typically easy to sample from and for which the density is easy to compute, and \(\exp(-V(x))\) is an unnormalized likelihood term for which \(V(x)\) is easy to compute pointwise.
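
The tempering path itself is not written out here. A common formulation, assumed in what follows, introduces a parameter \(\lambda \in [0, 1]\) (the lmbda field of the state) and targets the intermediate distributions

\[p_{\lambda}(x) \propto p_0(x) \exp(-\lambda V(x)), \qquad 0 = \lambda_0 < \lambda_1 < \dots < \lambda_T = 1,\]

so that \(\lambda = 0\) recovers the prior and \(\lambda = 1\) recovers the target. Under this assumption, moving from \(\lambda_k\) to \(\lambda_{k+1}\) reweights each particle \(x_i\) by

\[w_i \propto \exp\big(-(\lambda_{k+1} - \lambda_k) V(x_i)\big).\]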

Parameters:
  • logprior_fn – A function that computes the log density of the prior distribution.

  • loglikelihood_fn – A function that returns the log-likelihood, \(-V(x)\), at a given position.

  • mcmc_step_fn – A function that creates an MCMC kernel from a log-probability density function.

  • mcmc_init_fn (Callable) – A function that creates a new MCMC state from a position and a log-probability density function.

  • resampling_fn – A random function that resamples generated particles based on their weights.

  • num_mcmc_iterations – Number of iterations in the MCMC chain.

Returns:

  A callable that takes a rng_key and a TemperedSMCState containing the current state of the chain, and that returns a new state of the chain along with information about the transition.
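
To make the roles of these arguments concrete, here is a rough sketch of one tempered SMC transition (reweight, resample, MCMC move) on a one-dimensional toy problem. It is plain Python rather than blackjax's API: resample, rw_metropolis and tempered_step are hypothetical helpers, the Gaussian prior and likelihood are made up for illustration, and the random-walk Metropolis move is a stand-in for whatever mcmc_step_fn would build.

```python
import math
import random


def logprior_fn(x):
    # Standard normal prior, up to an additive constant.
    return -0.5 * x * x


def loglikelihood_fn(x):
    # Gaussian likelihood of one observation y = 2; this plays the role of -V(x).
    return -0.5 * (x - 2.0) ** 2


def resample(rng, particles, weights):
    # Multinomial resampling: draw particles with probability equal to their weight.
    return rng.choices(particles, weights=weights, k=len(particles))


def rw_metropolis(rng, x, logdensity_fn, step_size=0.5):
    # One random-walk Metropolis move leaving exp(logdensity_fn) invariant.
    proposal = x + step_size * rng.gauss(0.0, 1.0)
    log_accept = logdensity_fn(proposal) - logdensity_fn(x)
    # 1 - random() lies in (0, 1], so the logarithm is always defined.
    if math.log(1.0 - rng.random()) < log_accept:
        return proposal
    return x


def tempered_step(rng, particles, lmbda, next_lmbda, num_mcmc_steps=5):
    delta = next_lmbda - lmbda
    # 1. Reweight: the incremental weight is the likelihood raised to the power delta.
    log_w = [delta * loglikelihood_fn(x) for x in particles]
    max_log_w = max(log_w)  # subtract the maximum for numerical stability
    w = [math.exp(lw - max_log_w) for lw in log_w]
    total = sum(w)
    weights = [wi / total for wi in w]
    # 2. Resample according to the normalized weights.
    particles = resample(rng, particles, weights)

    # 3. Move each particle with MCMC targeting the new tempered density.
    def tempered_logdensity(x):
        return logprior_fn(x) + next_lmbda * loglikelihood_fn(x)

    for _ in range(num_mcmc_steps):
        particles = [rw_metropolis(rng, x, tempered_logdensity) for x in particles]
    return particles


rng = random.Random(0)
particles = [rng.gauss(0.0, 1.0) for _ in range(2000)]  # draws from the prior
schedule = [0.0, 0.25, 0.5, 0.75, 1.0]  # fixed tempering schedule (an assumption)
for lmbda, next_lmbda in zip(schedule[:-1], schedule[1:]):
    particles = tempered_step(rng, particles, lmbda, next_lmbda)
posterior_mean = sum(particles) / len(particles)
```

For this toy model the exact posterior is Gaussian with mean 1 and variance 0.5, so posterior_mean should land close to 1. The fixed schedule is the simplest choice; the step from one value of the tempering parameter to the next could equally be chosen adaptively.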