blackjax.vi.meanfield_vi

Module Contents

Classes

MFVIState

State of the mean-field approximation (a typed NamedTuple).

MFVIInfo

Diagnostic information returned by step (a typed NamedTuple).

meanfield_vi

High-level implementation of Mean-Field Variational Inference.

Functions

step(rng_key, state, logdensity_fn, optimizer[, num_samples, stl_estimator]) → Tuple[MFVIState, MFVIInfo]

Approximate the target density using the mean-field approximation.

sample(rng_key, state[, num_samples])

Sample from the mean-field approximation.

generate_meanfield_logdensity(mu, rho)

Build the log-density function of the mean-field approximation with parameters mu and rho.

class MFVIState

Typed NamedTuple holding the state of the mean-field Gaussian approximation: the variational parameters and the optimizer state.
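mu: blackjax.types.ArrayTree

Mean of the Gaussian approximation, with the same structure as the target's position.

rho: blackjax.types.ArrayTree

Unconstrained parameter from which the per-coordinate standard deviation of the approximation is derived.

opt_state: optax.OptState

State of the Optax optimizer that updates mu and rho.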
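Because the state is a NamedTuple, an update produces a new state instead of mutating the old one. The stdlib sketch below mirrors the three fields. It makes several hypothetical choices for illustration, and none of them is BlackJAX's code: a dict of floats stands in for an ArrayTree, `init_state` is a made-up helper, its constant initial `rho` of -2.0 is arbitrary, and `opt_state` is a placeholder.

```python
from typing import Any, Dict, NamedTuple


class MFVIStateSketch(NamedTuple):
    # Hypothetical stand-in for MFVIState: dicts of floats play the role of ArrayTrees.
    mu: Dict[str, float]
    rho: Dict[str, float]
    opt_state: Any  # placeholder for an optax.OptState


def init_state(position: Dict[str, float], rho_init: float = -2.0) -> MFVIStateSketch:
    # Hypothetical initialization: zero means, a constant rho, no optimizer state.
    # The position only fixes which coordinates exist.
    mu = {name: 0.0 for name in position}
    rho = {name: rho_init for name in position}
    return MFVIStateSketch(mu=mu, rho=rho, opt_state=None)


state = init_state({"loc": 3.0, "log_scale": 0.5})
# _replace returns an updated copy; the original tuple is left unchanged.
new_state = state._replace(mu={"loc": 0.1, "log_scale": -0.2})
print(MFVIStateSketch._fields, state.mu, new_state.mu)
```

The `opt_state` field is carried alongside `mu` and `rho` so that the optimizer's internal statistics travel with the parameters from one step to the next.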
class MFVIInfo

Typed NamedTuple holding the diagnostic information returned by step.
elbo: float

Value of the variational objective estimated at this step.
step(rng_key: blackjax.types.PRNGKey, state: MFVIState, logdensity_fn: Callable, optimizer: optax.GradientTransformation, num_samples: int = 5, stl_estimator: bool = True) → Tuple[MFVIState, MFVIInfo]

Approximate the target density using the mean-field approximation.

Parameters:
  • rng_key – Key for JAX’s pseudo-random number generator.

  • state – Current state of the mean-field approximation.

  • logdensity_fn – Function that represents the target log-density to approximate.

  • optimizer – Optax GradientTransformation to be used for optimization.

  • num_samples – Number of samples drawn from the approximation at each step to form a Monte Carlo estimate of the Kullback-Leibler divergence between the approximation and the target density.

  • stl_estimator – Whether to use the sticking-the-landing (STL) gradient estimator [RWD17]. The STL estimator reduces gradient variance by dropping the score-function term from the gradient. [ASD20] recommend keeping it enabled for better results.
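The following stdlib-only sketch shows one step for a diagonal Gaussian approximation, to make the roles of `num_samples` and `stl_estimator` concrete. It rests on several assumptions that are not taken from BlackJAX. The standard deviation is taken to be `exp(rho)`. Parameters are plain lists rather than pytrees. Plain gradient descent stands in for the Optax optimizer. Because there is no autodiff, the caller must supply the target's gradient by hand through a hypothetical `grad_logdensity_fn`. With `stl_estimator=True` the sketch keeps only the path derivative through the reparameterized sample. With `False` it adds the score-function term back in.

```python
import math
import random
from typing import Callable, List, NamedTuple, Tuple


class State(NamedTuple):
    # Hypothetical stand-in for MFVIState: lists instead of pytrees, no optimizer state.
    mu: List[float]
    rho: List[float]  # assumed link: sigma = exp(rho)


def normal_logpdf(x: float, mean: float, std: float) -> float:
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def mfvi_step(
    rng: random.Random,
    state: State,
    logdensity_fn: Callable[[List[float]], float],
    grad_logdensity_fn: Callable[[List[float]], List[float]],
    learning_rate: float = 0.05,
    num_samples: int = 5,
    stl_estimator: bool = True,
) -> Tuple[State, float]:
    dim = len(state.mu)
    sigma = [math.exp(r) for r in state.rho]
    grad_mu = [0.0] * dim
    grad_rho = [0.0] * dim
    kl_estimate = 0.0
    for _ in range(num_samples):
        # Reparameterized draw: z = mu + sigma * eps with eps ~ N(0, I).
        eps = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        z = [m + s * e for m, s, e in zip(state.mu, sigma, eps)]
        logq = sum(normal_logpdf(zi, m, s) for zi, m, s in zip(z, state.mu, sigma))
        # Monte Carlo estimate of KL(q || p), up to p's normalizing constant.
        kl_estimate += (logq - logdensity_fn(z)) / num_samples
        grad_logp = grad_logdensity_fn(z)
        for i in range(dim):
            # Path derivative d/dz [log q(z) - log p(z)], chained through
            # dz/dmu = 1 and dz/drho = sigma * eps.
            path = -eps[i] / sigma[i] - grad_logp[i]
            g_mu = path
            g_rho = path * sigma[i] * eps[i]
            if not stl_estimator:
                # Score-function term: d log q / d(mu, rho) with z held fixed.
                g_mu += eps[i] / sigma[i]
                g_rho += eps[i] ** 2 - 1.0
            grad_mu[i] += g_mu / num_samples
            grad_rho[i] += g_rho / num_samples
    # Plain gradient descent on the KL estimate (stands in for the Optax update).
    new_mu = [m - learning_rate * g for m, g in zip(state.mu, grad_mu)]
    new_rho = [r - learning_rate * g for r, g in zip(state.rho, grad_rho)]
    return State(new_mu, new_rho), kl_estimate


# Usage: fit a diagonal Gaussian target with means (1, -2) and std devs (0.5, 2).
target_mean = [1.0, -2.0]
target_std = [0.5, 2.0]


def target_logdensity(z: List[float]) -> float:
    return sum(normal_logpdf(zi, m, s) for zi, m, s in zip(z, target_mean, target_std))


def target_grad(z: List[float]) -> List[float]:
    return [-(zi - m) / s ** 2 for zi, m, s in zip(z, target_mean, target_std)]


rng = random.Random(0)
state = State(mu=[0.0, 0.0], rho=[0.0, 0.0])
for _ in range(2000):
    state, kl = mfvi_step(rng, state, target_logdensity, target_grad)
print(state.mu, [math.exp(r) for r in state.rho], kl)
```

When the target is itself a diagonal Gaussian and q matches it exactly, log q and log p coincide everywhere. The STL gradient is then exactly zero for every sample. The score-function term, by contrast, still fluctuates around zero. This is the variance reduction that the `stl_estimator` description refers to.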

sample(rng_key: blackjax.types.PRNGKey, state: MFVIState, num_samples: int = 1)

Sample from the mean-field approximation.
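Under a mean-field Gaussian, every coordinate is drawn independently as mu + sigma * eps. The stdlib sketch below illustrates this with a hypothetical helper, `sample_meanfield`. In this sketch, dicts of floats stand in for pytrees, the standard deviation is assumed to be `exp(rho)`, and each draw is returned as a separate dict.

```python
import math
import random
from typing import Dict, List


def sample_meanfield(
    rng: random.Random,
    mu: Dict[str, float],
    rho: Dict[str, float],
    num_samples: int = 1,
) -> List[Dict[str, float]]:
    # Each coordinate is drawn independently: z = mu + exp(rho) * eps, eps ~ N(0, 1).
    # The exp link between rho and the standard deviation is an assumption.
    return [
        {name: mu[name] + math.exp(rho[name]) * rng.gauss(0.0, 1.0) for name in mu}
        for _ in range(num_samples)
    ]


rng = random.Random(42)
mu = {"loc": 1.5, "log_scale": -0.3}
rho = {"loc": math.log(0.1), "log_scale": math.log(0.2)}
draws = sample_meanfield(rng, mu, rho, num_samples=20000)
mean_loc = sum(d["loc"] for d in draws) / len(draws)
print(len(draws), mean_loc)
```

Each draw keeps the same keys as `mu`. In the same way, samples from the approximation share the structure of the parameters they were drawn from.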

class meanfield_vi

High-level implementation of Mean-Field Variational Inference.

Parameters:
  • logdensity_fn – A function that represents the log-density function associated with the distribution we want to sample from.

  • optimizer – Optax optimizer used to optimize the ELBO.

  • num_samples – Number of samples drawn at each step to estimate the ELBO.

Return type:

A VIAlgorithm.

init
step
sample
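The high-level `meanfield_vi` object fixes its configuration (`logdensity_fn`, the optimizer and `num_samples`) once and exposes `init`, `step` and `sample`. The stdlib sketch below illustrates that closure pattern. Everything in it is a hypothetical simplification rather than the BlackJAX API: `MeanfieldVISketch` and `make_meanfield_vi` are made-up names, the gradient is written by hand, plain SGD with `learning_rate` replaces Optax, and the standard deviation is assumed to be `exp(rho)`. Only the STL gradient is implemented.

```python
import math
import random
from typing import Callable, List, NamedTuple


class State(NamedTuple):
    mu: List[float]
    rho: List[float]  # assumed link: sigma = exp(rho)


class MeanfieldVISketch(NamedTuple):
    # Hypothetical bundle: hyperparameters are fixed when the object is built.
    init: Callable[[List[float]], State]
    step: Callable[[random.Random, State], State]
    sample: Callable[..., List[List[float]]]


def make_meanfield_vi(
    grad_logdensity_fn: Callable[[List[float]], List[float]],
    learning_rate: float = 0.05,
    num_samples: int = 5,
) -> MeanfieldVISketch:
    def init(position: List[float]) -> State:
        # Hypothetical choice: zero means and unit standard deviations (rho = 0).
        return State([0.0] * len(position), [0.0] * len(position))

    def sample(rng: random.Random, state: State, n: int = 1) -> List[List[float]]:
        return [
            [m + math.exp(r) * rng.gauss(0.0, 1.0) for m, r in zip(state.mu, state.rho)]
            for _ in range(n)
        ]

    def step(rng: random.Random, state: State) -> State:
        sigma = [math.exp(r) for r in state.rho]
        grad_mu = [0.0] * len(sigma)
        grad_rho = [0.0] * len(sigma)
        for _ in range(num_samples):
            eps = [rng.gauss(0.0, 1.0) for _ in sigma]
            z = [m + s * e for m, s, e in zip(state.mu, sigma, eps)]
            grad_logp = grad_logdensity_fn(z)
            for i, (s, e, g) in enumerate(zip(sigma, eps, grad_logp)):
                # STL (path-only) gradient of log q(z) - log p(z), z = mu + sigma * eps.
                path = -e / s - g
                grad_mu[i] += path / num_samples
                grad_rho[i] += path * s * e / num_samples
        return State(
            [m - learning_rate * g for m, g in zip(state.mu, grad_mu)],
            [r - learning_rate * g for r, g in zip(state.rho, grad_rho)],
        )

    return MeanfieldVISketch(init, step, sample)


# Usage: build once, then step only needs a random source and the state.
vi = make_meanfield_vi(lambda z: [-(z[0] - 3.0)])  # target: N(3, 1)
rng = random.Random(0)
state = vi.init([0.0])
for _ in range(1000):
    state = vi.step(rng, state)
draws = vi.sample(rng, state, 5)
print(state, draws)
```

Binding the configuration at construction time keeps the per-iteration call small. It also means a loop over `step` cannot accidentally change `num_samples` or the optimizer between iterations.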
generate_meanfield_logdensity(mu, rho)

Build the log-density function of the mean-field approximation with parameters mu and rho.
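A mean-field Gaussian is a product of independent one-dimensional Gaussians, so its log-density is a sum of per-coordinate normal log-pdfs. The stdlib sketch below builds such a function. Its name, `generate_meanfield_logdensity_sketch`, is hypothetical, lists stand in for pytrees, and the standard deviation is assumed to be `exp(rho)`.

```python
import math
from typing import Callable, List


def generate_meanfield_logdensity_sketch(
    mu: List[float], rho: List[float]
) -> Callable[[List[float]], float]:
    sigma = [math.exp(r) for r in rho]  # assumed link between rho and the std dev

    def meanfield_logdensity(position: List[float]) -> float:
        # Independent coordinates: the joint log-density is a sum of 1-D normal terms.
        return sum(
            -0.5 * ((x - m) / s) ** 2 - math.log(s) - 0.5 * math.log(2.0 * math.pi)
            for x, m, s in zip(position, mu, sigma)
        )

    return meanfield_logdensity


logq = generate_meanfield_logdensity_sketch(mu=[0.0, 1.0], rho=[0.0, math.log(2.0)])
print(logq([0.0, 1.0]), logq([0.5, -1.0]))
```

Returning a closure lets the same log-density be evaluated at many positions, for example on each sample drawn from the approximation.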