Source code for blackjax.smc.ess
# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All things related to SMC effective sample size"""
from typing import Callable
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from blackjax.types import Array, ArrayLikeTree
def ess(log_weights: Array) -> float:
    """Compute the effective sample size (ESS) from log-weights.

    This is ``exp(log_ess(log_weights))``, i.e.
    ``(sum_i w_i) ** 2 / sum_i w_i ** 2`` with ``w_i = exp(log_weights[i])``.
    For non-negative weights with at least one positive entry the value lies
    between 1 (a single weight dominates) and the number of weights (all
    weights equal).

    Parameters
    ----------
    log_weights: Array
        Log-weights of the sample; they do not need to be normalized.

    Returns
    -------
    ess: float
        The effective sample size, as produced by ``jnp.exp`` (a JAX array).
    """
    log_value = log_ess(log_weights)
    return jnp.exp(log_value)
def log_ess(log_weights: Array) -> float:
    """Compute the logarithm of the effective sample size (ESS).

    With (possibly unnormalized) weights ``w_i = exp(log_weights[i])`` this
    returns ``log((sum_i w_i) ** 2 / sum_i w_i ** 2)``. Both sums are
    evaluated in log space with ``logsumexp`` to avoid overflow/underflow
    when the log-weights are very large or very negative.

    Parameters
    ----------
    log_weights: Array
        Log-weights of the sample. They need not be normalized: adding the
        same constant to every entry leaves the result unchanged. Because
        ``logsumexp`` is called without an ``axis`` argument, all entries
        are pooled together.

    Returns
    -------
    log_ess: float
        The logarithm of the effective sample size (a 0-d JAX array).
    """
    # log(sum_i w_i) and log(sum_i w_i ** 2), both computed stably.
    log_total = jsp.special.logsumexp(log_weights)
    log_total_sq = jsp.special.logsumexp(2 * log_weights)
    return 2 * log_total - log_total_sq
def ess_solver(
    logdensity_fn: Callable,
    particles: ArrayLikeTree,
    target_ess: float,
    max_delta: float,
    root_solver: Callable,
):
    """Find the tempering increment that achieves a target relative ESS.

    ``logdensity_fn`` is evaluated once on ``particles``. ``root_solver`` then
    searches for the increment ``delta`` such that the incremental
    log-weights ``-delta * logdensity_fn(particles)`` have an effective sample
    size equal to ``target_ess * n_particles``.

    Parameters
    ----------
    logdensity_fn: Callable
        Function applied to the whole ``particles`` pytree. It must return one
        value per particle, presumably by being vectorized (e.g. vmapped) by
        the caller; confirm at call sites. NOTE(review): the log-weights are
        ``-delta`` times this value. Callers must pass a function whose sign
        matches that convention.
    particles: ArrayLikeTree
        Current particles, with the particle index on the leading axis.
    target_ess: float
        The ESS targeted for the next increment of SMC tempering, relative
        to the number of particles.
    max_delta: float
        Max acceptable delta increment; forwarded to ``root_solver``.
    root_solver: Callable
        A solver to find the root of a function. It is called as
        ``root_solver(f, delta0, min_delta, max_delta)``, here with
        ``delta0=0.0`` and ``min_delta=0.0``.

    Returns
    -------
    delta: float
        The value returned by ``root_solver``: the increment that
        (approximately) solves for the target ESS.
    """
    logprob = logdensity_fn(particles)
    # One log-density value per particle. Reading the count from ``logprob``
    # ties it to the weights actually used and does not depend on the
    # structure of the ``particles`` pytree.
    n_particles = logprob.shape[0]
    target_val = jnp.log(n_particles * target_ess)

    def fun_to_solve(delta):
        # nan_to_num replaces NaN (e.g. ``0 * inf`` at delta == 0) and
        # infinite products with finite numbers, so they do not poison the
        # ESS and the root search.
        log_weights = jnp.nan_to_num(-delta * logprob)
        return log_ess(log_weights) - target_val

    return root_solver(fun_to_solve, 0.0, 0.0, max_delta)