API

Update functions

Update functions are the heart of probabilistic networks, as they shape the propagation of beliefs in the neural hierarchy. The library implements the standard variational updates for value and volatility coupling, as described in Weber et al. (2023).

The updates module contains the update functions used during belief propagation. They are available through three sub-modules, organized according to their functional roles. We usually distinguish the first updates, triggered top-down (from the leaves to the roots of the networks), which are prediction steps that recover the current state of inference. The second updates are the prediction errors, signalling the divergence between the prediction and the new observation (for input nodes) or the new state (for state nodes). Interleaved with these steps are posterior update steps, in which a node receives prediction errors from its child nodes and estimates new sufficient statistics.
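To make the ordering concrete, here is a minimal sketch of one prediction, prediction error and posterior cycle for a single continuous state node observed through Gaussian noise. It assumes no parents, no drift, a fixed tonic volatility and a known input precision; the helper names are hypothetical and not part of pyhgf. Under these assumptions the cycle reduces to a Kalman-filter step.

```python
import math

# Hypothetical helpers sketching one update cycle; not pyhgf's API.


def predict(mean, precision, tonic_volatility, time_step=1.0):
    """Prediction: with no drift and no parents the mean is carried over,
    while the variance grows by time_step * exp(tonic_volatility)."""
    expected_mean = mean
    expected_variance = 1.0 / precision + time_step * math.exp(tonic_volatility)
    return expected_mean, 1.0 / expected_variance


def value_prediction_error(observation, expected_mean):
    """Prediction error: distance between the observation and the prediction."""
    return observation - expected_mean


def posterior(expected_mean, expected_precision, value_pe, input_precision):
    """Posterior: precisions add, and the mean moves towards the observation
    by a precision-weighted fraction of the prediction error."""
    precision = expected_precision + input_precision
    mean = expected_mean + (input_precision / precision) * value_pe
    return mean, precision


# One cycle for a node starting at mean 0 and precision 1 that observes 1.0.
expected_mean, expected_precision = predict(0.0, 1.0, tonic_volatility=-2.0)
delta = value_prediction_error(1.0, expected_mean)
mean, precision = posterior(expected_mean, expected_precision, delta, input_precision=1e4)
```

Because the observation is far more precise than the prediction in this example, the posterior mean lands very close to the observed value.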

Posterior updates

Update the sufficient statistics of a state node after receiving prediction errors from its child nodes. The prediction errors from all the children below the node should be computed before calling the posterior update step.
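With value coupling, the posterior of a parent is a sum over all of its children, which is why every child's prediction error must be available first. A minimal sketch of the standard value-coupling update, using a hypothetical function that takes plain tuples instead of the network's attributes:

```python
def value_parent_posterior(expected_mean, expected_precision, children):
    """Posterior of a value parent given its children's prediction errors.

    `children` is a list of (coupling_strength, child_expected_precision,
    child_value_prediction_error) tuples. Every child must already have
    computed its prediction error before this step runs.
    """
    # Precision: each child adds its expected precision, scaled by the squared coupling.
    precision = expected_precision + sum(
        alpha ** 2 * pi_hat for alpha, pi_hat, _ in children
    )
    # Mean: precision-weighted sum of the children's value prediction errors.
    mean = expected_mean + sum(
        alpha * pi_hat * delta for alpha, pi_hat, delta in children
    ) / precision
    return mean, precision


# Two children: a weakly coupled but precise child with a negative error,
# and a fully coupled child with a positive error.
mean, precision = value_parent_posterior(0.0, 1.0, [(1.0, 2.0, 0.5), (0.5, 4.0, -1.0)])
```

Here the second child's larger precision-weighted error outweighs the first, so the parent's mean moves below zero.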

Binary nodes

binary_node_update_infinite(attributes, ...)

Update the posterior of a binary node given infinite precision of the input.

binary_node_update_finite(attributes, ...)

Update the posterior of a binary node given finite precision of the input.
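The difference between the two binary updates can be sketched with Bayes' rule. This assumes a common parametrisation in which a finite-precision input is Gaussian around `eta1` when the binary state is 1 and around `eta0` when it is 0; the function names and defaults are hypothetical, not pyhgf's signatures.

```python
import math


def gaussian_pdf(x, mean, precision):
    """Gaussian density parametrised by mean and precision."""
    return math.sqrt(precision / (2 * math.pi)) * math.exp(-0.5 * precision * (x - mean) ** 2)


def binary_posterior_infinite(observation):
    """Infinite input precision: the binary state is read directly off the input."""
    return float(observation)


def binary_posterior_finite(observation, expected_mean, input_precision, eta0=0.0, eta1=1.0):
    """Finite input precision: Bayes' rule between the two possible states.

    expected_mean is the predicted probability that the binary state is 1.
    """
    like_1 = gaussian_pdf(observation, eta1, input_precision)
    like_0 = gaussian_pdf(observation, eta0, input_precision)
    return expected_mean * like_1 / (expected_mean * like_1 + (1.0 - expected_mean) * like_0)
```

An input exactly halfway between `eta0` and `eta1` is uninformative, so the posterior then equals the prediction; as the input precision grows, the finite update approaches the infinite one.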

Categorical nodes

categorical_input_update(attributes, ...)

Update the categorical input node given an array of binary observations.

Continuous nodes

posterior_update_mean_continuous_node(...)

Update the mean of a state node using the value prediction errors.

posterior_update_precision_continuous_node(...)

Update the precision of a state node using the volatility prediction errors.

continuous_node_update(attributes, node_idx, ...)

Update the posterior of a continuous node using the standard HGF update.

continuous_node_update_ehgf(attributes, ...)

Update the posterior of a continuous node using the eHGF update.
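For volatility coupling, the textbook HGF posterior of a parent with a single child can be sketched as follows. This is not pyhgf's implementation: it assumes a unit time step and no drift (so the parent's expected mean is also the value that set the child's predicted variance), and all names are hypothetical.

```python
import math


def volatility_parent_posterior(expected_mean, expected_precision, coupling,
                                tonic_volatility, child_expected_precision,
                                child_volatility_pe):
    """Textbook HGF posterior of a volatility parent with a single child."""
    # Share of the child's predicted variance contributed by the parent:
    # exp(coupling * mu_hat + omega) divided by the child's total predicted variance.
    w = math.exp(coupling * expected_mean + tonic_volatility) * child_expected_precision
    # Precision update driven by the child's volatility prediction error.
    precision = (expected_precision
                 + 0.5 * (coupling * w) ** 2
                 + (coupling * w) ** 2 * child_volatility_pe
                 - 0.5 * coupling ** 2 * w * child_volatility_pe)
    # Mean update: the volatility prediction error weighted by the new posterior variance.
    mean = expected_mean + 0.5 * coupling * w * child_volatility_pe / precision
    return mean, precision
```

A positive volatility prediction error (the child varied more than predicted) raises the parent's mean. Note that when `w < 1/2` and the prediction error is large, this precision update can become negative.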

Prediction steps

Compute the expectation for future observations given the influence of parent nodes. The prediction steps are executed for all nodes, top-down, before any observation.

Binary nodes

binary_state_node_prediction(attributes, ...)

Get the new expected mean and precision of a binary state node.
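In the standard binary HGF, a binary state node's prediction is the logistic sigmoid of its value parent's expected mean, and its expected precision is the inverse of the Bernoulli variance. A minimal sketch, assuming exactly one value parent with unit coupling strength (the function name is hypothetical):

```python
import math


def binary_state_prediction(parent_expected_mean):
    """Expected mean and precision of a binary state node."""
    # Predicted probability that the binary state is 1.
    expected_mean = 1.0 / (1.0 + math.exp(-parent_expected_mean))
    # Precision of the Bernoulli prediction: inverse of mu_hat * (1 - mu_hat).
    expected_precision = 1.0 / (expected_mean * (1.0 - expected_mean))
    return expected_mean, expected_precision
```

A parent expecting 0 yields a maximally uncertain prediction of 0.5, which is also where the expected precision reaches its minimum of 4.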

Continuous nodes

predict_mean(attributes, edges, time_step, ...)

Compute the expected mean of a continuous state node.

predict_precision(attributes, edges, ...)

Compute the expected precision of a continuous state node.

continuous_node_prediction(attributes, ...)

Update the expected mean and expected precision of a continuous node.
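The two continuous prediction steps can be sketched with the standard HGF random walk. This assumes no value parents and no autoconnection, so the previous mean is carried over apart from the drift; volatility parents enter through the log-variance of the walk. The function names and signatures are hypothetical and deliberately differ from the listed ones.

```python
import math


def predicted_mean(mean, time_step=1.0, tonic_drift=0.0):
    """Without value parents, the expected mean drifts linearly with time."""
    return mean + time_step * tonic_drift


def predicted_precision(precision, time_step=1.0, tonic_volatility=0.0, volatility_parents=()):
    """Expected precision after one time step.

    volatility_parents holds (coupling_strength, parent_mean) pairs that shift
    the log-variance of the random walk on top of the tonic volatility.
    """
    log_volatility = tonic_volatility + sum(k * m for k, m in volatility_parents)
    # Previous posterior variance plus the variance accumulated over time_step.
    expected_variance = 1.0 / precision + time_step * math.exp(log_volatility)
    return 1.0 / expected_variance
```

The precision can only decrease during prediction: time always adds variance, and a volatility parent with a high mean adds more of it.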

Prediction error steps

Compute the value and volatility prediction errors of a given node. The prediction error can only be computed after the posterior update (or observation) of a given node.

Inputs

Binary inputs

binary_input_prediction_error_infinite_precision(...)

Compute the prediction error of a binary input assuming an infinite precision.

binary_input_prediction_error_finite_precision(...)

Compute the prediction error of a binary input assuming a finite precision.

Continuous inputs

continuous_input_volatility_prediction_error(...)

Store noise prediction error from an input node.

continuous_input_value_prediction_error(...)

Store value prediction error and expected precision from an input node.

continuous_input_prediction_error(...)

Store prediction errors in an input node.

State nodes

Binary state nodes

binary_state_node_prediction_error(...)

Compute the value prediction errors and predicted precision of a binary node.
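In the standard binary HGF, the binary node's value prediction error is what drives its value parent. A sketch of that step with the textbook update, using hypothetical names and assuming unit coupling:

```python
def binary_value_parent_update(binary_mean, binary_expected_mean,
                               parent_expected_mean, parent_expected_precision):
    """Binary value prediction error and the resulting update of its value parent."""
    # Value prediction error: observed (or inferred) state minus predicted probability.
    delta = binary_mean - binary_expected_mean
    # The parent's precision grows by the Bernoulli variance of the prediction.
    precision = parent_expected_precision + binary_expected_mean * (1.0 - binary_expected_mean)
    # The parent's mean moves by the prediction error scaled by its posterior variance.
    mean = parent_expected_mean + delta / precision
    return delta, mean, precision
```

Because the Bernoulli variance is at most 0.25, a single binary outcome can only raise the parent's precision by a bounded amount.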

Continuous state nodes

continuous_node_value_prediction_error(...)

Compute the value prediction error of a state node.

continuous_node_volatility_prediction_error(...)

Compute the volatility prediction error of a state node.

continuous_node_prediction_error(attributes, ...)

Store prediction errors in a continuous state node.
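Both prediction errors follow the usual HGF definitions. A sketch for a continuous state node, with hypothetical names taking plain numbers (the listed functions instead take the network's `attributes`):

```python
def value_prediction_error(mean, expected_mean):
    """Value prediction error: posterior mean minus expected mean."""
    return mean - expected_mean


def volatility_prediction_error(mean, precision, expected_mean, expected_precision):
    """Volatility prediction error of a continuous state node.

    Compares the observed spread (posterior variance plus squared value
    prediction error) with the predicted variance 1 / expected_precision.
    Zero means the node varied exactly as much as predicted.
    """
    return (1.0 / precision + (mean - expected_mean) ** 2) * expected_precision - 1.0
```

A positive volatility prediction error signals more change than predicted; a negative one signals less.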

Distribution

The Hierarchical Gaussian Filter as a PyMC distribution. This distribution can be embedded in models using PyMC>=5.0.0.

hgf_logp([tonic_volatility_1, ...])

HGF log-probability given input data, response function and parameters.

HGFLogpGradOp([input_data, time_steps, ...])

Gradient Op for the HGF distribution.

HGFDistribution([input_data, time_steps, ...])

The HGF distribution, compatible with PyMC >= 5.0.

Model

The main class is used to create a standard Hierarchical Gaussian Filter for binary or continuous inputs, with two or three levels. This class wraps the previous JAX modules and creates a standard node structure for these models.

HGF([n_levels, model_type, update_type, ...])

The two-level and three-level Hierarchical Gaussian Filters (HGF).

Plots

Plotting functionalities to visualize parameter trajectories and correlations after observing new data.

plot_trajectories(hgf[, ci, show_surprise, ...])

Plot the trajectories of the nodes' sufficient statistics and surprise.

plot_correlations(hgf)

Plot the correlation heatmap of the sufficient statistics trajectories.

plot_network(hgf)

Visualize the node network using Graphviz.

plot_nodes(hgf, node_idxs[, ci, ...])

Plot the sufficient statistics trajectories of a set of nodes.

Response

A collection of response functions. A response function is simply a callable that takes at least the HGF instance (after it has observed the data) as input and returns a surprise.

first_level_gaussian_surprise(hgf[, ...])

Gaussian surprise at the first level of a probabilistic network.

total_gaussian_surprise(hgf[, ...])

Sum of the Gaussian surprise across the probabilistic network.

first_level_binary_surprise(hgf[, ...])

Sum of the binary surprise along the time series (binary HGF).

binary_softmax(hgf[, ...])

Surprise under the binary softmax model.

binary_softmax_inverse_temperature(hgf[, ...])

Surprise from a binary softmax parametrized by the inverse temperature.
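A sketch of the surprise a binary softmax with inverse temperature could return, assuming the common form in which the probability of responding 1 is mu^beta / (mu^beta + (1 - mu)^beta). Unlike the listed response functions, this hypothetical helper works on plain lists of predicted probabilities rather than an HGF instance, and it may differ in detail from the library's parametrization.

```python
import math


def binary_softmax_surprise(responses, expected_means, inverse_temperature=1.0):
    """Total surprise of binary responses under a softmax on predicted probabilities.

    beta = 1 reproduces the predicted probability, larger beta makes choices
    more deterministic, and beta = 0 makes them random.
    """
    total = 0.0
    for response, mu in zip(responses, expected_means):
        p_one = mu ** inverse_temperature / (
            mu ** inverse_temperature + (1.0 - mu) ** inverse_temperature
        )
        # Probability of the response actually given.
        p = p_one if response == 1 else 1.0 - p_one
        total -= math.log(p)
    return total
```

Summing negative log-probabilities over trials gives a single scalar that can be minimized, or negated and used as a log-likelihood, when fitting the inverse temperature.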

Networks

Utilities for manipulating networks of probabilistic nodes.

beliefs_propagation(attributes, input_data, ...)

Update the network's parameters after observing new data point(s).

trim_sequence(exclude_node_idxs, ...)

Remove steps from an update sequence that depend on a set of nodes.

list_branches(node_idxs, edges[, branch_list])

Return the branch of a network from a given set of root nodes.

fill_categorical_state_node(hgf, node_idx, ...)

Generate a binary network implied by categorical state(-transition) nodes.

get_update_sequence(hgf)

Generate an update sequence from the network's structure.

Math

Math functions and probability densities.

MultivariateNormal()

The multivariate normal as an exponential family distribution.

Normal()

The univariate normal as an exponential family distribution.

gaussian_predictive_distribution(x, xi, nu)

Density of the Gaussian-predictive distribution.

gaussian_density(x, mean, precision)

Gaussian density as defined by mean and precision.

sigmoid(x[, lower_bound, upper_bound])

Logistic sigmoid function.

binary_surprise(x, expected_mean)

Surprise at a binary outcome.

gaussian_surprise(x, expected_mean, ...)

Surprise at an outcome under a Gaussian prediction.

dirichlet_kullback_leibler(alpha_1, alpha_2)

Compute the Kullback-Leibler divergence between two Dirichlet distributions.

binary_surprise_finite_precision(value, ...)

Compute the binary surprise with finite precision.
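Several of the math helpers correspond to standard formulas. Below is a plain-Python sketch of a few of them, assuming that the bounds of `sigmoid` rescale its output range, that the Gaussian surprise is parametrized by the expected precision, and that the first Dirichlet argument is the distribution the expectation is taken under. These are readable reference versions, not pyhgf's JAX code; `digamma` is hand-rolled because the Python standard library has none.

```python
import math


def sigmoid(x, lower_bound=0.0, upper_bound=1.0):
    """Logistic sigmoid rescaled to the interval (lower_bound, upper_bound)."""
    return lower_bound + (upper_bound - lower_bound) / (1.0 + math.exp(-x))


def binary_surprise(x, expected_mean):
    """Negative log-probability of a binary outcome x given P(x = 1) = expected_mean."""
    return -math.log(expected_mean) if x == 1 else -math.log(1.0 - expected_mean)


def gaussian_surprise_from_precision(x, expected_mean, expected_precision):
    """Negative log-density of x under N(expected_mean, 1 / expected_precision)."""
    return 0.5 * (math.log(2 * math.pi) - math.log(expected_precision)
                  + expected_precision * (x - expected_mean) ** 2)


def digamma(x):
    """Digamma via the recurrence psi(x) = psi(x + 1) - 1/x and an asymptotic series."""
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    return result + math.log(x) - 0.5 / x - inv2 * (1.0 / 12 - inv2 * (1.0 / 120 - inv2 / 252))


def dirichlet_kl(alpha_1, alpha_2):
    """KL(Dir(alpha_1) || Dir(alpha_2)) for two concentration vectors of equal length."""
    a0, b0 = sum(alpha_1), sum(alpha_2)
    kl = math.lgamma(a0) - math.lgamma(b0)
    for a, b in zip(alpha_1, alpha_2):
        kl += math.lgamma(b) - math.lgamma(a) + (a - b) * (digamma(a) - digamma(a0))
    return kl
```

As a check on the Dirichlet formula, Dir(2, 1) against the uniform Dir(1, 1) gives log 2 - 1/2, which matches integrating 2x log(2x) over [0, 1].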