pyhgf.plots.plot_nodes

pyhgf.plots.plot_nodes(network: Network, node_idxs: int | List[int], ci: bool = True, show_surprise: bool = True, show_observations: bool = False, show_current_state: bool = False, figsize: Tuple[int, int] = (12, 5), color: Tuple | str | None = None, axs: List | Axes | None = None)

Plot the trajectory of expected sufficient statistics of a set of nodes.

This function plots the expected mean and the expected precision (converted into a standard deviation) before each observation, and the Gaussian surprise after each observation. If show_observations is True, it also plots the input received from the child node(s) (the expected mean of the child node(s) for value coupling).
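The two derived quantities described above (the standard deviation obtained from the expected precision, and the Gaussian surprise) can be sketched in plain Python. This is a minimal stand-alone illustration, not pyhgf's internal code: the helper names precision_to_std, gaussian_surprise and ci_band are hypothetical, and the shaded uncertainty band is assumed to span one standard deviation on each side of the expected mean.

```python
import math
from typing import List, Tuple


def precision_to_std(precision: float) -> float:
    """Standard deviation implied by a precision: sqrt(1 / pi_hat)."""
    return math.sqrt(1.0 / precision)


def gaussian_surprise(
    observation: float, expected_mean: float, expected_precision: float
) -> float:
    """Negative log density of the observation under N(mu_hat, 1 / pi_hat)."""
    return 0.5 * (
        math.log(2 * math.pi)
        - math.log(expected_precision)
        + expected_precision * (observation - expected_mean) ** 2
    )


def ci_band(
    expected_means: List[float], expected_precisions: List[float]
) -> List[Tuple[float, float]]:
    """(lower, upper) bounds of mean +/- one standard deviation at each time step.

    Assumption: a one-standard-deviation band, as suggested by the `ci` parameter.
    """
    return [
        (mean - precision_to_std(prec), mean + precision_to_std(prec))
        for mean, prec in zip(expected_means, expected_precisions)
    ]


# Usage: two time steps with expected precisions 4.0 and 1.0.
band = ci_band([0.0, 2.0], [4.0, 1.0])  # [(-0.5, 0.5), (1.0, 3.0)]
surprise = gaussian_surprise(observation=1.0, expected_mean=1.0, expected_precision=1.0)
```

The surprise grows quadratically with the prediction error and is scaled by the expected precision, so a confident prediction that misses produces a larger shaded peak than an uncertain one.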

Parameters:
network

An instance of the main Network class.

node_idxs

The index(es) of the probabilistic node(s) that should be plotted. If multiple indexes are provided, multiple rows will be appended to the figure, one for each node.

ci

Whether to show the uncertainty around the value estimates (using the standard deviation \(\sqrt{\frac{1}{\hat{\pi}}}\)). Defaults to True.

show_surprise

If True, the surprise, defined as the negative log probability of the observation given the expectation, is plotted in the background of the figure as a grey shaded area. Defaults to True.

show_observations

If True, show the observations received from the child node(s). For value-coupled nodes, the expected mean of the child node(s) is plotted. This option is not supported for volatility coupling. Defaults to False.

show_current_state

If True, plot the current states (mean and precision) on top of the expected states (mean and precision). Defaults to False.

figsize

The width and height of the figure. Defaults to (12, 5).

color

The color of the main curve showing the belief trajectory.

axs

A list of Matplotlib axes instances on which to draw the trajectories. Its length should match the number of nodes in node_idxs (one row per node). Defaults to None (a new figure is created).
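The relationship between node_idxs and axs (an int or a list of indexes, one row per node, and a user-supplied axs whose length should match) can be sketched as a small validation step. This is a hypothetical helper written for illustration only; pyhgf's actual argument handling may differ. In practice axs would typically come from matplotlib's plt.subplots(nrows=...), which returns a sequence of axes when nrows is greater than one; here plain placeholder objects stand in for axes so the sketch runs without matplotlib.

```python
from typing import Any, List, Optional, Union


def normalize_node_idxs(node_idxs: Union[int, List[int]]) -> List[int]:
    """Accept a single index or a list of indexes and always return a list."""
    if isinstance(node_idxs, int):
        return [node_idxs]
    return list(node_idxs)


def n_rows_for(node_idxs: Union[int, List[int]], axs: Optional[Any] = None) -> int:
    """Number of figure rows needed (one per node), validating a supplied axs.

    Hypothetical helper: a single object without len() is treated as one axis,
    any sized container (list, tuple, 1-D array) as one axis per element.
    """
    idxs = normalize_node_idxs(node_idxs)
    if axs is not None:
        try:
            n_axes = len(axs)
        except TypeError:
            n_axes = 1  # a single Axes instance
        if n_axes != len(idxs):
            raise ValueError(
                f"Got {n_axes} axes for {len(idxs)} node(s); provide one axis per node."
            )
    return len(idxs)


# Usage with placeholder axes: two nodes, two axes -> two rows.
rows = n_rows_for([1, 2], axs=["ax_top", "ax_bottom"])
```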

Returns:
axs

The Matplotlib axes instances on which the trajectories were plotted.

Examples

Visualization of nodes’ trajectories from a three-level continuous HGF model.

from pyhgf import load_data
from pyhgf.model import HGF

# Set up standard 3-level HGF for continuous inputs
hgf = HGF(
    n_levels=3,
    model_type="continuous",
    initial_mean={"1": 1.04, "2": 1.0, "3": 1.0},
    initial_precision={"1": 1e4, "2": 1e1, "3": 1e1},
    tonic_volatility={"1": -13.0, "2": -2.0, "3": -2.0},
    tonic_drift={"1": 0.0, "2": 0.0, "3": 0.0},
    volatility_coupling={"1": 1.0, "2": 1.0},
)

# Read USD-CHF data
timeserie = load_data("continuous")

# Feed input
hgf.input_data(input_data=timeserie)

# Plot
hgf.plot_nodes(node_idxs=1)