pyhgf.plots.plot_trajectories#

pyhgf.plots.plot_trajectories(network: Network, ci: bool = True, show_surprise: bool = True, show_posterior: bool = False, show_total_surprise: bool = False, figsize: Tuple[int, int] = (18, 9), axs: List | Axes | None = None) Axes[source]#

Plot the trajectories of the nodes’ sufficient statistics and surprise.

This function will plot the expected mean and precision (converted into standard deviation) and the surprise at each level of the node structure.

Parameters:
network

An instance of the main Network class.

ci

Show the uncertainty around the values estimates (standard deviation).

show_surprise

If True plot each node’s surprise together with sufficient statistics. If False, only the input node’s surprise is depicted.

show_posterior

If True, plot the posterior mean and precision on the top of expected mean and precision. Defaults to False.

show_total_surprise

If True, plot the sum of surprises across all nodes in the bottom panel. Defaults to False.

figsize

The width and height of the figure. Defaults to (18, 9) for a two-level model, or to (18, 12) for a three-level model.

axs

A list of Matplotlib axes instances where to draw the trajectories. This should correspond to the number of nodes in the structure. The default is None (create a new figure).

Returns:
axs

The Matplotlib axes instances where to plot the trajectories.

Examples

Visualization of nodes’ trajectories from a three-level continuous HGF model.

from pyhgf import load_data
from pyhgf.model import HGF

# Set up standard 3-level HGF for continuous inputs
hgf = HGF(
    n_levels=3,
    model_type="continuous",
    initial_mean={"1": 1.04, "2": 1.0, "3": 1.0},
    initial_precision={"1": 1e4, "2": 1e1, "3": 1e1},
    tonic_volatility={"1": -13.0, "2": -2.0, "3": -2.0},
    tonic_drift={"1": 0.0, "2": 0.0, "3": 0.0},
    volatility_coupling={"1": 1.0, "2": 1.0},
)

# Read USD-CHF data
timeserie = load_data("continuous")

# Feed input
hgf.input_data(input_data=timeserie)

# Plot
hgf.plot_trajectories();
../../_images/pyhgf-plots-plot_trajectories-1.png

Visualization of nodes’ trajectories from a three-level binary HGF model.

from pyhgf import load_data
from pyhgf.model import HGF
import jax.numpy as jnp

# Read binary input
u, _ = load_data("binary")

three_levels_hgf = HGF(
    n_levels=3,
    model_type="binary",
    initial_mean={"1": .0, "2": .5, "3": 0.},
    initial_precision={"1": .0, "2": 1e4, "3": 1e1},
    tonic_volatility={"1": None, "2": -6.0, "3": -2.0},
    tonic_drift={"1": None, "2": 0.0, "3": 0.0},
    volatility_coupling={"1": None, "2": 1.0},
    eta0=0.0,
    eta1=1.0,
    binary_precision = jnp.inf,
)

# Feed input
three_levels_hgf = three_levels_hgf.input_data(u)

# Plot
three_levels_hgf.plot_trajectories();
../../_images/pyhgf-plots-plot_trajectories-2.png