pyhgf.model.HGF
- class pyhgf.model.HGF(n_levels: int | None = 2, model_type: str = 'continuous', initial_mean: Dict = {'1': 0.0, '2': 0.0, '3': 0.0}, initial_precision: Dict = {'1': 1.0, '2': 1.0, '3': 1.0}, continuous_precision: float | ndarray | Array | bool_ | number | bool | int | complex = 10000.0, tonic_volatility: Dict = {'1': -3.0, '2': -3.0, '3': -3.0}, volatility_coupling: Dict = {'1': 1.0, '2': 1.0}, eta0: float | ndarray | Array | bool_ | number | bool | int | complex = 0.0, eta1: float | ndarray | Array | bool_ | number | bool | int | complex = 1.0, binary_precision: float | ndarray | Array | bool_ | number | bool | int | complex = inf, tonic_drift: Dict = {'1': 0.0, '2': 0.0, '3': 0.0})
The two-level and three-level Hierarchical Gaussian Filters (HGF).
This class uses pre-made node structures that correspond to the most widely used HGF configurations. The inputs can be continuous or binary.
- Attributes:
- model_type
The type of model implemented (“continuous”, “binary” or “custom”).
- n_levels
The number of levels in the hierarchy, including the input vector. It cannot be less than 2.
- Note:
The parameter structure also includes the value and volatility coupling strengths with children and parents (i.e. “value_coupling_parents”, “value_coupling_children”, “volatility_coupling_parents”, “volatility_coupling_children”).
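The note above implies that each coupling strength is recorded at both ends of an edge, once on the child (`*_coupling_parents`) and once on the parent (`*_coupling_children`). The following is a minimal sketch of that bookkeeping. It assumes a hypothetical layout in which each node's parameters are a plain dict and coupling strengths are tuples aligned with the node's parent or child indices; the dictionaries and the helpers `volatility_coupling` and `couplings_are_symmetric` are illustrative only and do not reproduce pyhgf's internal data structures.

```python
from typing import Dict, Optional, Tuple

# Hypothetical layout (assumption, not pyhgf internals): node index -> parameters.
# Coupling strengths are tuples aligned position by position with the node's
# parent (or child) indices in the edge tables below.
attributes: Dict[int, Dict[str, Tuple[float, ...]]] = {
    1: {"volatility_coupling_parents": (1.0,), "volatility_coupling_children": ()},
    2: {"volatility_coupling_parents": (), "volatility_coupling_children": (1.0,)},
}
# Hypothetical edge tables: node 2 is the volatility parent of node 1.
volatility_parents: Dict[int, Tuple[int, ...]] = {1: (2,), 2: ()}
volatility_children: Dict[int, Tuple[int, ...]] = {1: (), 2: (1,)}


def volatility_coupling(child: int, parent: int) -> Optional[float]:
    """Return the coupling strength stored on `child` for `parent`, or None if unlinked."""
    parents = volatility_parents.get(child, ())
    if parent not in parents:
        return None
    return attributes[child]["volatility_coupling_parents"][parents.index(parent)]


def couplings_are_symmetric() -> bool:
    """Check that the child-side and parent-side copies of every coupling agree."""
    for child, parents in volatility_parents.items():
        for i, parent in enumerate(parents):
            child_side = attributes[child]["volatility_coupling_parents"][i]
            j = volatility_children[parent].index(child)
            parent_side = attributes[parent]["volatility_coupling_children"][j]
            if child_side != parent_side:
                return False
    return True
```

Storing the value on both sides lets an update step read the coupling from whichever node it is currently processing, at the cost of having to keep the two copies in sync.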
- __init__(n_levels: int | None = 2, model_type: str = 'continuous', initial_mean: Dict = {'1': 0.0, '2': 0.0, '3': 0.0}, initial_precision: Dict = {'1': 1.0, '2': 1.0, '3': 1.0}, continuous_precision: float | ndarray | Array | bool_ | number | bool | int | complex = 10000.0, tonic_volatility: Dict = {'1': -3.0, '2': -3.0, '3': -3.0}, volatility_coupling: Dict = {'1': 1.0, '2': 1.0}, eta0: float | ndarray | Array | bool_ | number | bool | int | complex = 0.0, eta1: float | ndarray | Array | bool_ | number | bool | int | complex = 1.0, binary_precision: float | ndarray | Array | bool_ | number | bool | int | complex = inf, tonic_drift: Dict = {'1': 0.0, '2': 0.0, '3': 0.0}) → None
Parameterization of the HGF model.
- Parameters:
- n_levels
The number of levels in the perceptual model (2 or 3). If None, the node hierarchy is not created and can be provided afterwards. Defaults to 2 (a two-level HGF).
- model_type
The model type to use (can be “continuous” or “binary”).
- initial_mean
A dictionary mapping each level of the hierarchy (keyed as “1”, “2”, “3”) to its initial mean. Defaults to 0.0 at every level.
- initial_precision
A dictionary mapping each level of the hierarchy to its initial precision. Defaults to 1.0 at every level.
- continuous_precision
The expected precision of the continuous input node. Defaults to 1e4. Only relevant if model_type="continuous".
- tonic_volatility
A dictionary mapping each level of the hierarchy to its tonic volatility. This is the tonic part of the variance (the part that does not depend on the parent node). Defaults to -3.0 at every level.
- volatility_coupling
A dictionary containing the volatility coupling strength at different levels of the hierarchy. It scales the phasic part of the variance (the part that depends on the parent node) and therefore defines the strength of the connection between a node and its parent. Defaults to 1.0.
- eta0
The first categorical value of the binary node. Defaults to 0.0. Only relevant if model_type="binary".
- eta1
The second categorical value of the binary node. Defaults to 1.0. Only relevant if model_type="binary".
- binary_precision
The precision of the binary input node. Defaults to jnp.inf (noise-free binary observations). Only relevant if model_type="binary".
- tonic_drift
A dictionary mapping each level of the hierarchy to its tonic drift, i.e. the drift of the random walk. Defaults to 0.0 at every level (no drift).
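The dictionaries `initial_mean`, `tonic_volatility`, `volatility_coupling` and `tonic_drift` jointly define the random walk at each level. The sketch below samples hidden trajectories from a two-level model, assuming the standard HGF formulation (not checked against pyhgf's source): level 2 follows x2(k) ~ N(x2(k-1) + ρ2, exp(ω2)) and level 1 follows x1(k) ~ N(x1(k-1) + ρ1, exp(κ1 · x2(k) + ω1)), where ω is `tonic_volatility` (a log-variance), κ is `volatility_coupling` and ρ is `tonic_drift`. The function `simulate_two_level_hgf` is hypothetical and only illustrates the role of each parameter; it is not part of pyhgf.

```python
import math
import random
from typing import Dict, List, Tuple


def simulate_two_level_hgf(
    n_steps: int,
    initial_mean: Dict[str, float],
    tonic_volatility: Dict[str, float],
    volatility_coupling: Dict[str, float],
    tonic_drift: Dict[str, float],
    seed: int = 0,
) -> Tuple[List[float], List[float]]:
    """Sample hidden states x1, x2 from a two-level HGF generative model.

    Assumed (standard HGF) parameterization, with dict keys = level numbers as strings:
        x2[k] ~ N(x2[k-1] + rho2, exp(omega2))
        x1[k] ~ N(x1[k-1] + rho1, exp(kappa1 * x2[k] + omega1))
    """
    rng = random.Random(seed)
    x1, x2 = initial_mean["1"], initial_mean["2"]
    traj1: List[float] = []
    traj2: List[float] = []
    for _ in range(n_steps):
        # Level 2: random walk whose log-variance is purely tonic (omega2).
        sd2 = math.sqrt(math.exp(tonic_volatility["2"]))
        x2 = rng.gauss(x2 + tonic_drift["2"], sd2)
        # Level 1: log-variance = phasic part (kappa1 * x2) + tonic part (omega1).
        log_var1 = volatility_coupling["1"] * x2 + tonic_volatility["1"]
        x1 = rng.gauss(x1 + tonic_drift["1"], math.sqrt(math.exp(log_var1)))
        traj1.append(x1)
        traj2.append(x2)
    return traj1, traj2


# Usage with the defaults documented above (level "3" entries are simply unused).
x1, x2 = simulate_two_level_hgf(
    n_steps=200,
    initial_mean={"1": 0.0, "2": 0.0},
    tonic_volatility={"1": -3.0, "2": -3.0},
    volatility_coupling={"1": 1.0},
    tonic_drift={"1": 0.0, "2": 0.0},
)
```

With κ1 = 0 the variance at level 1 is purely tonic; a larger κ1 lets excursions of x2 modulate how quickly x1 wanders.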
Methods
- __init__([n_levels, model_type, ...]): Parameterization of the HGF model.
- add_edges([kind, parent_idxs, ...]): Add a value or volatility coupling link between a set of nodes.
- add_nodes([kind, n_nodes, node_parameters, ...]): Add new input/state node(s) to the neural network.
- create_belief_propagation_fn([overwrite, ...]): Create the belief propagation function.
- get_network(): Return the attributes, edges and update sequence defining the network.
- input_custom_sequence(update_branches, ...): Add new observations with custom update sequences.
- input_data(input_data[, time_steps, ...]): Add new observations.
- plot_correlations(): Plot the heatmap of cross-trajectory correlations.
- plot_network(): Visualize the node network using GraphViz.
- plot_nodes(node_idxs, **kwargs): Plot the belief trajectories of the selected node(s).
- plot_trajectories(**kwargs): Plot the parameter trajectories.
- surprise([response_function, ...]): Surprise of the model, conditioned on the response function.
- to_pandas(): Export the node trajectories and surprise as a pandas DataFrame.
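What `surprise` returns depends on the response function it is given. As a hedged illustration of what "surprise" usually means in the HGF, the sketch below computes the negative log-probability of one observation under the model's prediction, for a continuous input (Gaussian) and for a binary input observed without noise (the default `binary_precision` of jnp.inf). The names `gaussian_surprise` and `binary_surprise` are hypothetical and are not part of pyhgf; the expected mean, precision and probability are assumed to come from the model's predictions at that time step.

```python
import math


def gaussian_surprise(u: float, expected_mean: float, expected_precision: float) -> float:
    """Negative log-density of u under N(expected_mean, 1 / expected_precision)."""
    return 0.5 * (
        math.log(2 * math.pi)
        - math.log(expected_precision)
        + expected_precision * (u - expected_mean) ** 2
    )


def binary_surprise(u: int, expected_probability: float) -> float:
    """Negative log-probability of a noise-free binary outcome u in {0, 1}.

    expected_probability is the predicted probability that u == 1.
    """
    p = expected_probability if u == 1 else 1.0 - expected_probability
    return -math.log(p)
```

Summing such per-step values over a time series gives a single scalar that can serve as a model-fit criterion; whether `surprise()` aggregates in exactly this way depends on the response function passed to it.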
Attributes
input_idxs
Indices of the state nodes that can observe new data points by default.