API#

Update functions#

Update functions are the heart of probabilistic networks as they shape the propagation of beliefs in the neural hierarchy. The library implements the standard variational updates for value and volatility coupling, as described in Weber et al. (2023).

The updates module contains the update functions used during belief propagation. Update functions are available through three sub-modules, organized according to their functional roles. The first updates, triggered top-down (from the leaves to the roots of the network), are the prediction steps, which recover the current state of inference. The second updates are the prediction error steps, signalling the divergence between the prediction and the new observation (for input nodes) or the new state (for state nodes). Interleaved with these steps are the posterior update steps, where a node receives the prediction errors from its child nodes and estimates new sufficient statistics.
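As a rough illustration of how these three steps interleave for a single continuous state node observed through a noisy input, consider the following self-contained sketch. The variable names (`mu`, `pi`, `omega`, `input_precision`) are generic HGF notation, not pyhgf identifiers, and the dict-based node is purely illustrative:

```python
import math

def prediction(node):
    # Prediction step: form the expectation before the observation arrives.
    node["expected_mean"] = node["mu"]
    node["expected_precision"] = 1.0 / (1.0 / node["pi"] + math.exp(node["omega"]))

def prediction_error(node, value):
    # Prediction error step: divergence between prediction and observation.
    node["vape"] = value - node["expected_mean"]

def posterior_update(node, input_precision):
    # Posterior update step: new sufficient statistics from the error.
    node["pi"] = node["expected_precision"] + input_precision
    node["mu"] = node["expected_mean"] + (input_precision / node["pi"]) * node["vape"]

node = {"mu": 0.0, "pi": 1.0, "omega": -2.0}
for u in [0.5, 0.8, 0.6]:
    prediction(node)
    prediction_error(node, u)
    posterior_update(node, input_precision=10.0)
```

The same prediction / prediction error / posterior update ordering generalizes to arbitrary networks, with the update sequence derived from the graph structure.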

Posterior updates#

Update the sufficient statistics of a state node after receiving prediction errors from its child nodes. The prediction errors of all the child nodes below should be computed before calling the posterior update step.
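For a continuous value parent, this step amounts to a precision-weighted update driven by the child's value prediction error. A minimal sketch in generic HGF notation (the function name and the coupling strength `alpha` are illustrative assumptions, not pyhgf parameters):

```python
def posterior_update_value_parent(
    expected_mean, expected_precision,      # parent's prediction-step outputs
    child_expected_precision, child_vape,   # child's expected precision and value PE
    alpha=1.0,                              # value coupling strength (assumed name)
):
    """Update a parent's mean and precision from one child's value
    prediction error.

    The parent's precision grows with the child's expected precision,
    and the mean moves toward the prediction error in proportion to
    that precision weight.
    """
    precision = expected_precision + alpha**2 * child_expected_precision
    mean = expected_mean + (alpha * child_expected_precision / precision) * child_vape
    return mean, precision
```

With several children, the contributions of each child are accumulated before the new statistics are computed.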

Categorical nodes#

categorical_state_update(attributes, ...)

Update the categorical input node given an array of binary observations.

Continuous nodes#

posterior_update_mean_continuous_node(...)

Update the mean of a state node using the value prediction errors.

posterior_update_precision_continuous_node(...)

Update the precision of a state node using the volatility prediction errors.

continuous_node_posterior_update(attributes, ...)

Update the posterior of a continuous node using the standard HGF update.

continuous_node_posterior_update_ehgf(...)

Update the posterior of a continuous node using the eHGF update.

Exponential family#

posterior_update_exponential_family(...)

Update the parameters of an exponential family distribution.

Prediction steps#

Compute the expectation for future observations given the influence of parent nodes. The prediction steps are executed for all nodes, top-down, before any observation.
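For a continuous state node, the prediction step produces an expected mean and an expected precision, where the previous posterior variance is inflated by the volatility expected for the coming time step. A hedged sketch in generic HGF notation (function and parameter names are illustrative, not pyhgf's signatures):

```python
import math

def predict_precision(precision, tonic_volatility, volatility_parent_mean=0.0, kappa=1.0):
    # Expected precision: the previous posterior variance (1 / precision)
    # is inflated by the volatility predicted for this time step, which
    # depends on the tonic volatility and any volatility parent.
    predicted_volatility = math.exp(tonic_volatility + kappa * volatility_parent_mean)
    return 1.0 / (1.0 / precision + predicted_volatility)

def predict_mean(mean, value_parent_mean=None):
    # Expected mean: a Gaussian random walk keeps the previous mean;
    # a value parent would instead set the expectation (unit coupling
    # assumed here for simplicity).
    return mean if value_parent_mean is None else value_parent_mean
```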

Binary nodes#

binary_state_node_prediction(attributes, ...)

Get the new expected mean and precision of a binary state node.

Continuous nodes#

predict_mean(attributes, edges, node_idx)

Compute the expected mean of a continuous state node.

predict_precision(attributes, edges, node_idx)

Compute the expected precision of a continuous state node.

continuous_node_prediction(attributes, ...)

Update the expected mean and expected precision of a continuous node.

Dirichlet processes#

dirichlet_node_prediction(edges, attributes, ...)

Prediction of a Dirichlet process node.

Prediction error steps#

Compute the value and volatility prediction errors of a given node. The prediction errors can only be computed after the posterior update (or observation) of that node.
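The two kinds of prediction error can be sketched as follows, in generic HGF notation (illustrative names, not pyhgf's signatures). The value prediction error is sent to value parents; the volatility prediction error, sent to volatility parents, is zero when the realized deviation matches what the expected precision implied:

```python
def value_prediction_error(mean, expected_mean):
    # VAPE: divergence between the node's posterior mean and its prediction.
    return mean - expected_mean

def volatility_prediction_error(mean, precision, expected_mean, expected_precision):
    # VOPE: positive when the observed (squared) deviation exceeds what
    # the predicted precision implied, negative when it falls short,
    # and exactly zero when posterior and prediction agree.
    return (
        expected_precision / precision
        + expected_precision * (mean - expected_mean) ** 2
        - 1.0
    )
```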

binary_state_node_prediction_error(...)

Compute the value prediction errors and predicted precision of a binary node.

binary_finite_state_node_prediction_error(...)

Update the posterior of a binary node given finite precision of the input.

categorical_state_prediction_error(...)

Prediction error from a categorical state node.

continuous_node_value_prediction_error(...)

Compute the value prediction error of a state node.

continuous_node_volatility_prediction_error(...)

Compute the volatility prediction error of a state node.

continuous_node_prediction_error(attributes, ...)

Store prediction errors in an input node.

dirichlet_node_prediction_error(edges, ...)

Prediction error and update the child networks of a Dirichlet process node.

update_cluster(operands, edges, node_idx)

Update an existing cluster.

create_cluster(operands, edges, node_idx)

Create a new cluster.

get_candidate(value, sensory_precision, ...)

Find the best cluster candidate given previous clusters and an input value.

likely_cluster_proposal(mean_mu_G0, ...[, ...])

Sample likely new belief distributions given pre-existing clusters.

clusters_likelihood(value, expected_mean, ...)

Likelihood of a parametrized candidate under the new observation.

Distribution#

The Hierarchical Gaussian Filter as a PyMC distribution. This distribution can be embedded in models using PyMC>=5.0.0.

logp

Compute the log-probability of a decision model under belief trajectories.

hgf_logp

Compute log-probabilities of a batch of Hierarchical Gaussian Filters.

HGFLogpGradOp

Gradient Op for the HGF distribution.

HGFDistribution

The HGF distribution, compatible with PyMC >= 5.0.

HGFPointwise

The HGF distribution returning pointwise log probability.

Model#

The main class is used to create a standard Hierarchical Gaussian Filter for binary or continuous inputs, with two or three levels. This class wraps the underlying JAX modules and creates a standard node structure for these models.
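To make the two-level continuous model concrete, here is a deliberately simplified, self-contained filter written in plain Python: the first level tracks the input, the second tracks the log-volatility of the first. The second-level precision update drops the higher-order correction term of the exact update (see Weber et al., 2023), so this is a sketch of the kind of computation the HGF class wraps, not the library's algorithm:

```python
import math

def two_level_hgf(inputs, omega_1=-4.0, omega_2=-4.0, kappa=1.0, input_precision=1e2):
    """Simplified two-level continuous HGF filter (illustrative only)."""
    mu_1, pi_1 = 0.0, 1.0   # first level: tracks the input
    mu_2, pi_2 = 0.0, 1.0   # second level: tracks log-volatility of x1
    trajectory = []
    for u in inputs:
        # --- prediction steps (top-down, before the observation) ---
        nu_1 = math.exp(kappa * mu_2 + omega_1)   # predicted volatility of x1
        pihat_1 = 1.0 / (1.0 / pi_1 + nu_1)
        muhat_1 = mu_1
        pihat_2 = 1.0 / (1.0 / pi_2 + math.exp(omega_2))
        muhat_2 = mu_2
        # --- first level: prediction error and posterior update ---
        pi_1 = pihat_1 + input_precision
        mu_1 = muhat_1 + (input_precision / pi_1) * (u - muhat_1)
        # --- second level: volatility prediction error and update ---
        vope_1 = pihat_1 / pi_1 + pihat_1 * (mu_1 - muhat_1) ** 2 - 1.0
        w_1 = nu_1 * pihat_1
        pi_2 = pihat_2 + 0.5 * (kappa * w_1) ** 2   # simplified precision update
        mu_2 = muhat_2 + (kappa * w_1 / (2.0 * pi_2)) * vope_1
        trajectory.append((mu_1, mu_2))
    return trajectory
```

With a constant input stream, the first level converges on the input value while the second level drifts toward low volatility; the library's classes implement the exact updates with JAX.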

HGF

The two-level and three-level Hierarchical Gaussian Filters (HGF).

Network

A predictive coding neural network.

Plots#

Plotting functionalities to visualize parameter trajectories and correlations after observing new data.

plot_trajectories(network[, ci, ...])

Plot the trajectories of the nodes' sufficient statistics and surprise.

plot_correlations(network)

Plot the heatmap correlation of the sufficient statistics trajectories.

plot_network(network)

Visualization of node network using GraphViz.

plot_nodes(network, node_idxs[, ci, ...])

Plot the trajectory of expected sufficient statistics of a set of nodes.

Response#

A collection of response functions. A response function is a callable that takes at least the HGF instance as input after observation and returns the surprise.
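A custom response function can therefore be any callable with this contract. The sketch below sums the Gaussian surprise of each observation under the first level's expectations; the attributes on `model` (`observations`, `expected_means`, `expected_precisions`) are a hypothetical duck-typed layout for illustration, not pyhgf's internal structure:

```python
import math

def gaussian_surprise(x, expected_mean, expected_precision):
    # Negative log-density of x under a Gaussian prediction in precision form.
    return 0.5 * (
        math.log(2.0 * math.pi)
        - math.log(expected_precision)
        + expected_precision * (x - expected_mean) ** 2
    )

def first_level_surprise(model, response_function_inputs=None):
    """A response function: total Gaussian surprise of the observations
    under the first level's expected mean and precision trajectories.
    The attribute names on `model` are assumptions for this sketch.
    """
    return sum(
        gaussian_surprise(u, muhat, pihat)
        for u, muhat, pihat in zip(
            model.observations, model.expected_means, model.expected_precisions
        )
    )
```

Because the return value is a scalar surprise (negative log-likelihood), such callables can be plugged directly into model fitting.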

first_level_gaussian_surprise(hgf[, ...])

Gaussian surprise at the first level of a probabilistic network.

total_gaussian_surprise(hgf[, ...])

Sum of the Gaussian surprise across the probabilistic network.

first_level_binary_surprise(hgf[, ...])

Time series of binary surprises for all binary state nodes.

binary_softmax(hgf[, ...])

Surprise under the binary softmax model.

binary_softmax_inverse_temperature(hgf[, ...])

Surprise from a binary softmax parametrized by the inverse temperature.

Utils#

Utilities for manipulating neural networks.

beliefs_propagation(attributes, inputs, ...)

Update the network's parameters after observing new data point(s).

list_branches(node_idxs, edges[, branch_list])

Return the branch of a network from a given set of root nodes.

fill_categorical_state_node(network, ...)

Generate a binary network implied by categorical state(-transition) nodes.

get_update_sequence(network, update_type)

Generate an update sequence from the network's structure.

to_pandas(network)

Export the node trajectories and surprise as a Pandas data frame.

add_edges(attributes, edges[, kind, ...])

Add a value or volatility coupling link between a set of nodes.

get_input_idxs(edges)

List all possible default input nodes.

Math#

Math functions and probability densities.

MultivariateNormal()

The multivariate normal as an exponential family distribution.

Normal()

The univariate normal as an exponential family distribution.

gaussian_predictive_distribution(x, xi, nu)

Density of the Gaussian-predictive distribution.

gaussian_density(x, mean, precision)

Gaussian density as defined by mean and precision.

sigmoid(x[, lower_bound, upper_bound])

Logistic sigmoid function.

binary_surprise(x, expected_mean)

Surprise at a binary outcome.

gaussian_surprise(x, expected_mean, ...)

Surprise at an outcome under a Gaussian prediction.

dirichlet_kullback_leibler(alpha_1, alpha_2)

Compute the Kullback-Leibler divergence between two Dirichlet distributions.

binary_surprise_finite_precision(value, ...)

Compute the binary surprise with finite precision.
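Two of the simplest quantities in this module can be sketched directly, assuming the standard definitions (a rescaled logistic sigmoid and the Bernoulli negative log-probability); these are illustrative re-implementations, not the library's code:

```python
import math

def sigmoid(x, lower_bound=0.0, upper_bound=1.0):
    # Logistic sigmoid, optionally rescaled to [lower_bound, upper_bound].
    return lower_bound + (upper_bound - lower_bound) / (1.0 + math.exp(-x))

def binary_surprise(x, expected_mean):
    # Surprise (negative log-probability) of a binary outcome x under a
    # Bernoulli prediction with mean expected_mean.
    return -math.log(expected_mean if x else 1.0 - expected_mean)
```

For example, observing a head under a fair-coin prediction yields a surprise of log(2) nats.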