pyhgf.model.Network

class pyhgf.model.Network

A predictive coding neural network.

This is the core class for defining and manipulating neural networks. A network consists of (1) attributes, (2) structure, and (3) update sequences.
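The three components can be illustrated with a minimal plain-Python sketch of one belief propagation step. Everything in it is hypothetical. The AdjacencyLists class is only a stand-in that mirrors the value/volatility parents/children described for the edges attribute. Attributes are plain dictionaries rather than pyhgf's own containers. toy_value_update is a toy rule, not an HGF update equation.

```python
from typing import Callable, Dict, NamedTuple, Optional, Tuple


class AdjacencyLists(NamedTuple):
    """Hypothetical stand-in for pyhgf.typing.AdjacencyLists (one per node)."""
    value_parents: Optional[Tuple[int, ...]]
    volatility_parents: Optional[Tuple[int, ...]]
    value_children: Optional[Tuple[int, ...]]
    volatility_children: Optional[Tuple[int, ...]]


Attributes = Dict[int, Dict[str, float]]
Edges = Tuple[AdjacencyLists, ...]
UpdateFn = Callable[[Attributes, Edges, int], Attributes]


def toy_value_update(attributes: Attributes, edges: Edges, node_idx: int) -> Attributes:
    """Toy rule (not the HGF equations): move the node's mean halfway
    toward the mean of its first value child."""
    child_idx = edges[node_idx].value_children[0]
    new = {idx: dict(params) for idx, params in attributes.items()}  # copy, never mutate the input
    new[node_idx]["mean"] += 0.5 * (attributes[child_idx]["mean"] - attributes[node_idx]["mean"])
    return new


def propagate(attributes: Attributes, edges: Edges,
              update_sequence: Tuple[Tuple[int, UpdateFn], ...]) -> Attributes:
    """One belief propagation step: apply each (node_idx, update_fn) pair in order."""
    for node_idx, update_fn in update_sequence:
        attributes = update_fn(attributes, edges, node_idx)
    return attributes


# 1. Attributes: one parameter dictionary per node.
attributes = {0: {"mean": 1.0}, 1: {"mean": 0.0}, 2: {"mean": 0.0}}
# 2. Structure: node 1 is the value parent of node 0, node 2 the value parent of node 1.
edges = (
    AdjacencyLists((1,), None, None, None),
    AdjacencyLists((2,), None, (0,), None),
    AdjacencyLists(None, None, (1,), None),
)
# 3. Update sequence: update node 1 first, then node 2 (bottom-up).
update_sequence = ((1, toy_value_update), (2, toy_value_update))

new_attributes = propagate(attributes, edges, update_sequence)
print(new_attributes[1]["mean"], new_attributes[2]["mean"])  # 0.5 0.25
```

Because each update function returns a new attribute dictionary instead of mutating its input, the order of the update sequence fully determines what every step sees. Here node 2 is updated from node 1's already-updated mean.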

Attributes:
attributes

The attributes of the probabilistic nodes.

edges

The edges of the probabilistic nodes, as a tuple of pyhgf.typing.AdjacencyLists. The tuple has one entry per node. For each node, the entry lists the indexes of its value and volatility parents and children.

inputs

Information on the input nodes.

node_trajectories

The dynamics of the nodes' beliefs after updating.

update_sequence

The sequence of update functions that are applied during the belief propagation step.

scan_fn

The function that is passed to jax.lax.scan(). This is a pre-parametrized version of pyhgf.networks.beliefs_propagation().
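The role of scan_fn can be sketched without JAX. The scan function below reproduces the carry-threading contract of jax.lax.scan in plain Python, without tracing or compilation. functools.partial then fixes every argument except (carry, x), which is what "pre-parametrized" means here. toy_belief_propagation is a hypothetical delta-rule stand-in, not pyhgf.networks.beliefs_propagation(). In pyhgf the carry would presumably be the network's attributes rather than a single float; that is an assumption based on the description above.

```python
from functools import partial
from typing import Callable, List, Sequence, Tuple


def scan(f: Callable, init, xs: Sequence):
    """Plain-Python equivalent of the jax.lax.scan contract:
    thread a carry through f and collect the per-step outputs."""
    carry = init
    ys: List = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def toy_belief_propagation(mean: float, observation: float,
                           learning_rate: float) -> Tuple[float, float]:
    """Hypothetical stand-in for one belief propagation step (delta rule, not the HGF)."""
    new_mean = mean + learning_rate * (observation - mean)
    # Return (new carry, value recorded in the trajectory).
    return new_mean, new_mean


# Fix everything except (carry, x) so the signature matches what scan expects.
scan_fn = partial(toy_belief_propagation, learning_rate=0.5)

final_mean, trajectory = scan(scan_fn, 0.0, [1.0, 1.0, 1.0])
print(final_mean, trajectory)  # 0.875 [0.5, 0.75, 0.875]
```

With JAX installed, the same scan_fn could be passed as jax.lax.scan(scan_fn, 0.0, jnp.array([1.0, 1.0, 1.0])). The only requirement is the (carry, x) -> (carry, y) signature. Pre-parametrizing once avoids rebuilding the step function for every call.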

__init__() → None

Initialize an empty neural network.

Methods

__init__()

Initialize an empty neural network.

add_edges([kind, parent_idxs, ...])

Add a value or volatility coupling link between a set of nodes.

add_nodes([kind, n_nodes, node_parameters, ...])

Add new input/state node(s) to the neural network.

create_belief_propagation_fn([overwrite, ...])

Create the belief propagation function.

get_network()

Return the attributes, edges and update sequence defining the network.

input_custom_sequence(update_branches, ...)

Add new observations with custom update sequences.

input_data(input_data[, time_steps, ...])

Add new observations.

plot_correlations()

Plot the heatmap of cross-trajectory correlations.

plot_network()

Visualize the node network using GraphViz.

plot_nodes(node_idxs, **kwargs)

Plot the belief trajectories of the selected node(s).

plot_trajectories(**kwargs)

Plot the parameter trajectories.

surprise(response_function[, ...])

Compute the surprise of the model, conditioned on the response function.

to_pandas()

Export the node trajectories and surprise as a pandas DataFrame.
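add_nodes and add_edges are the methods that shape the edges attribute. The sketch below shows what registering a coupling plausibly means for the adjacency lists: the parent is recorded in the child's parent list, and the child in the parent's child list. Registering both ends is an assumption drawn from the edges description, which lists parents and children for every node. AdjacencyLists and add_coupling are hypothetical stand-ins, not pyhgf's API. The real add_edges signature is only partially shown in the table above.

```python
from typing import NamedTuple, Optional, Tuple


class AdjacencyLists(NamedTuple):
    """Hypothetical stand-in for pyhgf.typing.AdjacencyLists."""
    value_parents: Optional[Tuple[int, ...]] = None
    volatility_parents: Optional[Tuple[int, ...]] = None
    value_children: Optional[Tuple[int, ...]] = None
    volatility_children: Optional[Tuple[int, ...]] = None


def _append(idxs: Optional[Tuple[int, ...]], new_idx: int) -> Tuple[int, ...]:
    # None means "no connection yet"; otherwise extend the existing tuple.
    return (new_idx,) if idxs is None else idxs + (new_idx,)


def add_coupling(edges: Tuple[AdjacencyLists, ...], kind: str,
                 parent_idx: int, child_idx: int) -> Tuple[AdjacencyLists, ...]:
    """Hypothetical helper: register a value or volatility coupling on both ends,
    returning a new edges tuple (NamedTuples are immutable, so use _replace)."""
    if kind not in ("value", "volatility"):
        raise ValueError(f"unknown coupling kind: {kind!r}")
    if parent_idx == child_idx:
        raise ValueError("a node cannot be its own parent")
    new_edges = list(edges)
    child, parent = new_edges[child_idx], new_edges[parent_idx]
    new_edges[child_idx] = child._replace(
        **{f"{kind}_parents": _append(getattr(child, f"{kind}_parents"), parent_idx)})
    new_edges[parent_idx] = parent._replace(
        **{f"{kind}_children": _append(getattr(parent, f"{kind}_children"), child_idx)})
    return tuple(new_edges)


# Three unconnected nodes; then node 1 becomes the value parent of node 0,
# and node 2 becomes the volatility parent of node 1.
edges = tuple(AdjacencyLists() for _ in range(3))
edges = add_coupling(edges, "value", parent_idx=1, child_idx=0)
edges = add_coupling(edges, "volatility", parent_idx=2, child_idx=1)
print(edges[1])
# AdjacencyLists(value_parents=None, volatility_parents=(2,), value_children=(0,), volatility_children=None)
```

Returning a fresh tuple instead of mutating in place keeps the structure compatible with functional, JAX-style code, where the edges can be treated as a static description of the network.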

Attributes

input_idxs

Indexes of the state nodes that can observe new data points by default.