blackjax.vi.svgd

Module Contents

Classes

svgd

Implements the (basic) user interface for the svgd algorithm.

Functions

rbf_kernel(x, y[, length_scale])

update_median_heuristic(→ SVGDState)

Median heuristic for setting the bandwidth of RBF kernels.

rbf_kernel(x, y, length_scale=1)

update_median_heuristic(state: SVGDState) → SVGDState

Median heuristic for setting the bandwidth of RBF kernels.

A reasonable middle ground for choosing the length_scale of the RBF kernel is the empirical median of the squared distances between particles. This strategy is called the median heuristic.
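The median heuristic described above can be sketched in plain JAX. This is an illustration only, not the library's implementation: the helper names `rbf_kernel_sketch` and `median_heuristic_sketch` are hypothetical, the kernel form exp(-||x - y||^2 / length_scale) is an assumption, and particles are assumed to be stored as an (n, d) array rather than a general pytree.

```python
import jax.numpy as jnp


def rbf_kernel_sketch(x, y, length_scale=1.0):
    # Assumed kernel form: exp(-||x - y||^2 / length_scale).
    return jnp.exp(-jnp.sum((x - y) ** 2) / length_scale)


def median_heuristic_sketch(particles):
    # particles: array of shape (n, d), one particle per row.
    diffs = particles[:, None, :] - particles[None, :, :]
    sq_dists = jnp.sum(diffs**2, axis=-1)  # (n, n) squared distances
    # Keep the strictly lower triangle so each unordered pair is counted
    # once and the zero self-distances on the diagonal are excluded.
    rows, cols = jnp.tril_indices(particles.shape[0], k=-1)
    return jnp.median(sq_dists[rows, cols])


particles = jnp.array([[0.0], [1.0], [3.0]])
length_scale = median_heuristic_sketch(particles)  # median of {1, 9, 4}
k01 = rbf_kernel_sketch(particles[0], particles[1], length_scale)
```

Excluding the diagonal is a design choice in this sketch: including the n zero self-distances would pull the median toward zero for small particle sets.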

class svgd

Implements the (basic) user interface for the svgd algorithm.

Parameters:
  • grad_logdensity_fn – gradient, or an estimate of the gradient, of the target log density function to sample approximately from

  • optimizer – Optax compatible optimizer, which conforms to the optax.GradientTransformation protocol

  • kernel – positive semidefinite kernel

  • update_kernel_parameters – function that updates the kernel parameters given the current state of the particles

Return type:

A SamplingAlgorithm.
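To show how the parameters above fit together, here is a minimal sketch of an SVGD update written directly in JAX. It does not call the blackjax API: `rbf`, `svgd_direction`, and `step` are hypothetical names, plain gradient ascent with a fixed step size stands in for the Optax optimizer, the length scale is held fixed instead of being refreshed by update_kernel_parameters, and the kernel form is an assumption.

```python
import jax
import jax.numpy as jnp


def rbf(x, y, length_scale):
    # Assumed kernel form: exp(-||x - y||^2 / length_scale).
    return jnp.exp(-jnp.sum((x - y) ** 2) / length_scale)


def svgd_direction(particles, grad_logdensity_fn, length_scale):
    # Score (gradient of the log density) at every particle, shape (n, d).
    scores = jax.vmap(grad_logdensity_fn)(particles)

    def phi(x_i):
        # k(x_j, x_i) and its gradient with respect to x_j, for every j.
        k_vals = jax.vmap(lambda x_j: rbf(x_j, x_i, length_scale))(particles)
        k_grads = jax.vmap(
            jax.grad(lambda x_j: rbf(x_j, x_i, length_scale))
        )(particles)
        # Attractive term (kernel-weighted scores) plus repulsive term.
        return jnp.mean(k_vals[:, None] * scores + k_grads, axis=0)

    return jax.vmap(phi)(particles)


def logdensity(x):
    # Standard normal target, up to an additive constant.
    return -0.5 * jnp.sum(x**2)


grad_logdensity_fn = jax.grad(logdensity)
step_size = 0.1


@jax.jit
def step(particles):
    direction = svgd_direction(particles, grad_logdensity_fn, 1.0)
    return particles + step_size * direction


# Start all particles far from the mode of the target.
initial_particles = jnp.linspace(3.0, 5.0, 20).reshape(-1, 1)
particles = initial_particles
for _ in range(200):
    particles = step(particles)
```

The kernel-weighted score term pulls particles toward high-density regions, while the kernel gradient term pushes them apart so they do not collapse onto the mode.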

init
build_kernel