pyhgf.model.HGF

class pyhgf.model.HGF(n_levels: int | None = 2, model_type: str = 'continuous', initial_mean: Dict = {'1': 0.0, '2': 0.0, '3': 0.0}, initial_precision: Dict = {'1': 1.0, '2': 1.0, '3': 1.0}, continuous_precision: float | ndarray | Array | bool_ | number | bool | int | complex = 10000.0, tonic_volatility: Dict = {'1': -3.0, '2': -3.0, '3': -3.0}, volatility_coupling: Dict = {'1': 1.0, '2': 1.0}, eta0: float | ndarray | Array | bool_ | number | bool | int | complex = 0.0, eta1: float | ndarray | Array | bool_ | number | bool | int | complex = 1.0, binary_precision: float | ndarray | Array | bool_ | number | bool | int | complex = inf, tonic_drift: Dict = {'1': 0.0, '2': 0.0, '3': 0.0})

The two-level and three-level Hierarchical Gaussian Filters (HGF).

This class uses pre-made node structures that correspond to the most widely used HGF variants. The inputs can be continuous or binary.

Attributes:
model_type

The model implemented (can be “continuous”, “binary” or “custom”).

n_levels

The number of levels in the model, including the input vector. It cannot be less than 2.

Note:

The parameter structure also includes the value and volatility coupling strengths with children and parents (i.e. “value_coupling_parents”, “value_coupling_children”, “volatility_coupling_parents”, “volatility_coupling_children”).

__init__(n_levels: int | None = 2, model_type: str = 'continuous', initial_mean: Dict = {'1': 0.0, '2': 0.0, '3': 0.0}, initial_precision: Dict = {'1': 1.0, '2': 1.0, '3': 1.0}, continuous_precision: float | ndarray | Array | bool_ | number | bool | int | complex = 10000.0, tonic_volatility: Dict = {'1': -3.0, '2': -3.0, '3': -3.0}, volatility_coupling: Dict = {'1': 1.0, '2': 1.0}, eta0: float | ndarray | Array | bool_ | number | bool | int | complex = 0.0, eta1: float | ndarray | Array | bool_ | number | bool | int | complex = 1.0, binary_precision: float | ndarray | Array | bool_ | number | bool | int | complex = inf, tonic_drift: Dict = {'1': 0.0, '2': 0.0, '3': 0.0}) → None

Parameterization of the HGF model.

Parameters:
n_levels

The number of levels in the perceptual model (can be 2 or 3). If None, the node hierarchy is not created and can be provided afterwards (e.g. with add_nodes and add_edges). Defaults to 2 for a 2-level HGF.

model_type : str

The model type to use (can be “continuous” or “binary”).

initial_mean

A dictionary mapping each level of the hierarchy (keys “1”, “2”, “3”) to the initial mean of the corresponding node. Defaults to 0.0 for every level.

initial_precision

A dictionary mapping each level of the hierarchy (keys “1”, “2”, “3”) to the initial precision of the corresponding node. Defaults to 1.0 for every level.

continuous_precision

The expected precision of the continuous input node. Defaults to 1e4. Only relevant if model_type="continuous".

tonic_volatility

A dictionary mapping each level of the hierarchy to its tonic volatility. This is the tonic part of the variance (the part that is not affected by the parent node). Defaults to -3.0 for every level.

volatility_coupling

A dictionary mapping each level (keys “1” and “2”) to the strength of its volatility coupling with its parent. This is the phasic part of the variance (the part that is affected by the parent node) and defines the strength of the connection between a node and its parent. Defaults to 1.0.

eta0

The first categorical value of the binary node. Defaults to 0.0. Only relevant if model_type="binary".

eta1

The second categorical value of the binary node. Defaults to 1.0. Only relevant if model_type="binary".

binary_precision

The precision of the binary input node. Defaults to jnp.inf. Only relevant if model_type="binary".

tonic_drift

A dictionary mapping each level of the hierarchy to its tonic drift, i.e. the drift of the random walk. Defaults to 0.0 for every level (no drift).
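To illustrate how continuous_precision, tonic_volatility, volatility_coupling and tonic_drift interact, here is a minimal sketch of one prediction/update cycle for the first continuous level. It is not pyhgf's implementation: it follows the textbook form of the continuous HGF and assumes that the predicted variance is 1/π̂ = 1/π + Δt·exp(κ·μ_parent + ω), that the drift adds Δt·ρ to the predicted mean, and that the observation arrives with precision continuous_precision. The function name continuous_trial is hypothetical, and the update of the volatility parent itself is omitted.

```python
import math
from typing import Tuple


def continuous_trial(
    mu: float,
    pi: float,
    u: float,
    *,
    mu_parent: float = 0.0,
    tonic_volatility: float = -3.0,
    volatility_coupling: float = 1.0,
    tonic_drift: float = 0.0,
    continuous_precision: float = 1e4,
    time_step: float = 1.0,
) -> Tuple[float, float]:
    """Hypothetical sketch: one prediction/update cycle for the first level."""
    # Prediction of the mean: the drift shifts it by tonic_drift per unit of time.
    mu_hat = mu + time_step * tonic_drift
    # Step variance of the random walk: tonic part (omega) plus phasic part
    # (kappa * parent mean), combined inside the exponential (assumed form).
    step_variance = time_step * math.exp(volatility_coupling * mu_parent + tonic_volatility)
    pi_hat = 1.0 / (1.0 / pi + step_variance)
    # Update: precisions add, and the prediction error is weighted by the
    # input's share of the posterior precision.
    pi_post = pi_hat + continuous_precision
    mu_post = mu_hat + (continuous_precision / pi_post) * (u - mu_hat)
    return mu_post, pi_post


# With the default input precision (1e4) the belief moves almost all the way to u.
mu_a, pi_a = continuous_trial(0.0, 1.0, 1.0)
# With an imprecise input it only moves part of the way.
mu_b, pi_b = continuous_trial(0.0, 1.0, 1.0, continuous_precision=1.0)
# A higher (more volatile) parent mean lowers the predicted precision,
# so the same observation moves the belief further.
mu_c, pi_c = continuous_trial(0.0, 1.0, 1.0, mu_parent=3.0, continuous_precision=1.0)
```

Because the tonic volatility enters through an exponential, the default of -3.0 corresponds to a step variance of about exp(-3) ≈ 0.05 when the parent mean is 0.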
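For model_type="binary", eta0, eta1 and binary_precision describe how an observed input relates to the two states of the binary node. The sketch below assumes the classic binary-HGF input model: eta1 is the value expected when the state is 1, eta0 the value expected when it is 0, and the input is Gaussian around that value with precision binary_precision. With the default binary_precision=inf the input is noise-free, so the posterior is exactly 0 or 1. The function binary_input_posterior is hypothetical and is not part of pyhgf.

```python
import math


def binary_input_posterior(
    mu_hat: float,
    u: float,
    eta0: float = 0.0,
    eta1: float = 1.0,
    binary_precision: float = math.inf,
) -> float:
    """Hypothetical sketch: P(state = 1 | u) given the prior P(state = 1) = mu_hat."""
    if math.isinf(binary_precision):
        # Noise-free input: u must be exactly one of the two categorical values.
        if u == eta1:
            return 1.0
        if u == eta0:
            return 0.0
        raise ValueError("with infinite precision, u must equal eta0 or eta1")
    # Noisy input: Gaussian likelihood of u around each categorical value,
    # weighted by the prior probability of each state.
    like1 = math.exp(-0.5 * binary_precision * (u - eta1) ** 2)
    like0 = math.exp(-0.5 * binary_precision * (u - eta0) ** 2)
    return mu_hat * like1 / (mu_hat * like1 + (1.0 - mu_hat) * like0)


# Default settings: observing 1.0 (= eta1) makes state 1 certain.
print(binary_input_posterior(mu_hat=0.3, u=1.0))  # 1.0
# A noisy input halfway between eta0 and eta1 leaves the prior (0.3) unchanged.
print(binary_input_posterior(mu_hat=0.3, u=0.5, binary_precision=4.0))
```

Finite values of binary_precision let inputs that lie between eta0 and eta1 shift the belief only partially, in proportion to how much closer they are to one category.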

Methods

__init__([n_levels, model_type, ...])

Parameterization of the HGF model.

add_edges([kind, parent_idxs, ...])

Add a value or volatility coupling link between a set of nodes.

add_nodes([kind, n_nodes, node_parameters, ...])

Add new input/state node(s) to the neural network.

cache_belief_propagation_fn()

Blank call to the belief propagation function.

create_belief_propagation_fn([overwrite])

Create the belief propagation function.

get_network()

Return the attributes, structure and update sequence defining the network.

input_custom_sequence(update_branches, ...)

Add new observations with custom update sequences.

input_data(input_data[, time_steps, observed])

Add new observations.

plot_correlations()

Plot the heatmap of cross-trajectories correlation.

plot_network()

Visualization of node network using GraphViz.

plot_nodes(node_idxs, **kwargs)

Plot the node(s) beliefs trajectories.

plot_trajectories(**kwargs)

Plot the parameters trajectories.

set_update_sequence([update_type])

Generate an update sequence from the network's structure.

surprise([response_function, ...])

Surprise of the model conditioned on the response function.

to_pandas()

Export the nodes trajectories and surprise as a Pandas data frame.
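The surprise method scores the model under a response function. A common quantity in HGF modelling is the negative log probability of each observation under the model's prediction, summed over trials. The sketch below computes it for a Gaussian prediction (continuous inputs) and a Bernoulli prediction (binary inputs). Whether the default response_function of surprise computes exactly this is an assumption, and the helper names gaussian_surprise and binary_surprise here are illustrative rather than taken from this reference.

```python
import math


def gaussian_surprise(u: float, mu_hat: float, pi_hat: float) -> float:
    """Negative log density of u under a Gaussian with mean mu_hat and precision pi_hat."""
    return 0.5 * (math.log(2.0 * math.pi) - math.log(pi_hat) + pi_hat * (u - mu_hat) ** 2)


def binary_surprise(u: int, mu_hat: float) -> float:
    """Negative log probability of a binary outcome when P(u = 1) = mu_hat."""
    return -math.log(mu_hat) if u == 1 else -math.log(1.0 - mu_hat)


# Summing per-trial surprises gives one score, e.g. a loss to minimise when fitting.
observations = [0.1, 0.3, -0.2]
continuous_total = sum(gaussian_surprise(u, mu_hat=0.0, pi_hat=1.0) for u in observations)
binary_total = sum(binary_surprise(u, mu_hat=0.5) for u in [1, 0, 1])
```

Larger prediction errors and more confident (higher-precision) wrong predictions both increase the surprise, which is why this sum is a natural objective for parameter estimation.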