pyhgf.plots.plot_nodes

pyhgf.plots.plot_nodes(network: Network, node_idxs: int | List[int], ci: bool = True, show_surprise: bool = True, show_posterior: bool = False, figsize: Tuple[int, int] = (12, 5), color: Tuple | str | None = None, axs: List | Axes | None = None)

Plot the trajectory of expected sufficient statistics of a set of nodes.

This function plots the expected mean and precision (converted into a standard deviation) before each observation, and the Gaussian surprise after each observation. If children_inputs is True, it also plots the children's input (the mean for value coupling and the precision for volatility coupling).
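The two quantities described above can be computed from a node's expected mean \(\hat{\mu}\) and expected precision \(\hat{\pi}\). The sketch below is illustrative only and is not pyhgf's code: the helper names are hypothetical, the uncertainty band is assumed to span one standard deviation \(\sqrt{1/\hat{\pi}}\) on each side of the mean, and the surprise is taken to be the negative log density of an observation \(u\) under \(\mathcal{N}(\hat{\mu}, \hat{\pi}^{-1})\), i.e. \(\frac{1}{2}\left(\log 2\pi - \log \hat{\pi} + \hat{\pi}(u - \hat{\mu})^2\right)\). pyhgf's internal computation may differ in detail.

```python
import math
from statistics import NormalDist
from typing import Tuple


def uncertainty_band(expected_mean: float, expected_precision: float) -> Tuple[float, float]:
    """Return (lower, upper) = mean -/+ one standard deviation.

    Hypothetical helper; the one-standard-deviation width is an assumption.
    """
    sd = math.sqrt(1.0 / expected_precision)  # precision -> standard deviation
    return expected_mean - sd, expected_mean + sd


def gaussian_surprise(u: float, expected_mean: float, expected_precision: float) -> float:
    """Negative log density of u under N(expected_mean, 1 / expected_precision).

    Hypothetical helper illustrating the Gaussian surprise.
    """
    return 0.5 * (
        math.log(2 * math.pi)
        - math.log(expected_precision)
        + expected_precision * (u - expected_mean) ** 2
    )


# Example values for one time step (made up for illustration).
mu_hat, pi_hat, u = 1.0, 4.0, 1.3

lower, upper = uncertainty_band(mu_hat, pi_hat)  # precision 4 -> sd 0.5

# Cross-check the closed form against the standard library's normal density.
reference = -math.log(NormalDist(mu=mu_hat, sigma=math.sqrt(1.0 / pi_hat)).pdf(u))
print(lower, upper, gaussian_surprise(u, mu_hat, pi_hat), reference)
```

In this form, a large prediction error \(u - \hat{\mu}\) is more surprising under a high expected precision, while an observation at the expected mean is less surprising under a high expected precision because of the \(-\log \hat{\pi}\) term.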

Parameters:
network

An instance of the main Network class.

node_idxs

The index(es) of the probabilistic node(s) that should be plotted. If multiple indexes are provided, multiple rows will be appended to the figure, one for each node.

ci

Whether to show the uncertainty around the value estimates (using the standard deviation \(\sqrt{\frac{1}{\hat{\pi}}}\), where \(\hat{\pi}\) is the expected precision).

show_surprise

If True, the surprise, defined as the negative log probability of the observation given the expectation, is plotted in the background of the figure as a grey shaded area.

show_posterior

If True, plot the posterior mean and precision on top of the expected mean and precision. Defaults to False.

figsize

The width and height of the figure. Defaults to (12, 5).

color

The color of the main curve showing the beliefs trajectory.

axs

A Matplotlib Axes instance, or a list of them, on which to draw the trajectories. The number of axes should match the number of nodes in node_idxs (one row per node). Defaults to None (a new figure is created).
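The pairing between node_idxs and axs can be sketched with a small validation helper. This is a hypothetical illustration of the documented contract, not pyhgf's implementation: it assumes that an integer index is treated as a one-element list, that a single axes object is accepted for a single node, and that a supplied list of axes must contain exactly one entry per node. Plain strings stand in for Matplotlib Axes objects so that the sketch runs without Matplotlib.

```python
from typing import Any, List, Optional, Tuple, Union


def match_axes_to_nodes(
    node_idxs: Union[int, List[int]], axs: Any = None
) -> List[Tuple[int, Optional[Any]]]:
    """Pair each node index with the axes it would be drawn on.

    Hypothetical helper. When axs is None, each node is paired with None,
    meaning a new figure with one row per node would be created.
    """
    idxs = [node_idxs] if isinstance(node_idxs, int) else list(node_idxs)
    if axs is None:
        return [(idx, None) for idx in idxs]
    # Treat a single (non-list) axes object like a one-element list.
    axes_list = list(axs) if isinstance(axs, (list, tuple)) else [axs]
    if len(axes_list) != len(idxs):
        raise ValueError(
            f"got {len(axes_list)} axes for {len(idxs)} nodes; "
            "provide one axes per node"
        )
    return list(zip(idxs, axes_list))


# Two nodes, two placeholder axes: one row per node.
print(match_axes_to_nodes([1, 2], axs=["ax_top", "ax_bottom"]))
```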

Returns:
axs

The Matplotlib axes instances on which the trajectories are plotted.

Examples

Visualization of nodes’ trajectories from a three-level continuous HGF model.

from pyhgf import load_data
from pyhgf.model import HGF

# Set up standard 3-level HGF for continuous inputs
hgf = HGF(
    n_levels=3,
    model_type="continuous",
    initial_mean={"1": 1.04, "2": 1.0, "3": 1.0},
    initial_precision={"1": 1e4, "2": 1e1, "3": 1e1},
    tonic_volatility={"1": -13.0, "2": -2.0, "3": -2.0},
    tonic_drift={"1": 0.0, "2": 0.0, "3": 0.0},
    volatility_coupling={"1": 1.0, "2": 1.0},
)

# Read USD-CHF data
timeserie = load_data("continuous")

# Feed input
hgf.input_data(input_data=timeserie)

# Plot the trajectories of node 1
hgf.plot_nodes(node_idxs=1)