pyhgf.plots.plot_nodes
- pyhgf.plots.plot_nodes(network: Network, node_idxs: int | List[int], ci: bool = True, show_surprise: bool = True, show_posterior: bool = False, figsize: Tuple[int, int] = (12, 5), color: Tuple | str | None = None, axs: List | Axes | None = None)
Plot the trajectory of expected sufficient statistics of a set of nodes.
This function plots the expected mean and precision (converted into a standard deviation) before each observation, and the Gaussian surprise after each observation. If children_inputs is True, the children's inputs are also plotted (mean for value coupling and precision for volatility coupling).
- Parameters:
- network
An instance of the main Network class.
- node_idxs
The index(es) of the probabilistic node(s) that should be plotted. If multiple indexes are provided, multiple rows will be appended to the figure, one for each node.
- ci
Whether to show the uncertainty around the value estimates (using the standard deviation \(\sqrt{\frac{1}{\hat{\pi}}}\)).
- show_surprise
If True, the surprise, defined as the negative log probability of the observation given the expectation, is plotted in the background of the figure as a grey shaded area.
- show_posterior
If True, plot the posterior mean and precision on top of the expected mean and precision. Defaults to False.
- figsize
The width and height of the figure. Defaults to (12, 5).
- color
The color of the main curve showing the beliefs trajectory.
- axs
A Matplotlib axes instance, or a list of them, on which to draw the trajectories. The number of axes should correspond to the number of nodes in the structure. Defaults to None (a new figure is created).
- Returns:
- axs
The Matplotlib axes instances on which the trajectories are plotted.
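The ci parameter describes the uncertainty in terms of the standard deviation \(\sqrt{\frac{1}{\hat{\pi}}}\). The following is a minimal sketch of that conversion in plain Python, not the library's own plotting code. The helper names (precision_to_sd, uncertainty_band), the example values, and the choice of a band one standard deviation wide on each side of the mean are assumptions made for illustration.

```python
import math
from typing import Tuple


def precision_to_sd(expected_precision: float) -> float:
    """Convert a precision (inverse variance) into a standard deviation."""
    if expected_precision <= 0:
        raise ValueError("precision must be strictly positive")
    return math.sqrt(1.0 / expected_precision)


def uncertainty_band(
    expected_mean: float, expected_precision: float
) -> Tuple[float, float]:
    """Return (lower, upper) bounds one standard deviation around the mean.

    The one-standard-deviation width is an assumption of this sketch.
    """
    sd = precision_to_sd(expected_precision)
    return expected_mean - sd, expected_mean + sd


# Hypothetical expected means and precisions for three time steps.
expected_means = [1.0, 1.5, 0.5]
expected_precisions = [4.0, 1.0, 100.0]

bands = [
    uncertainty_band(mean, precision)
    for mean, precision in zip(expected_means, expected_precisions)
]
```

A higher expected precision gives a narrower band: in the values above, the third time step has the highest precision and therefore the tightest bounds.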
Examples
Visualization of nodes’ trajectories from a three-level continuous HGF model.
from pyhgf import load_data
from pyhgf.model import HGF

# Set up standard 3-level HGF for continuous inputs
hgf = HGF(
    n_levels=3,
    model_type="continuous",
    initial_mean={"1": 1.04, "2": 1.0, "3": 1.0},
    initial_precision={"1": 1e4, "2": 1e1, "3": 1e1},
    tonic_volatility={"1": -13.0, "2": -2.0, "3": -2.0},
    tonic_drift={"1": 0.0, "2": 0.0, "3": 0.0},
    volatility_coupling={"1": 1.0, "2": 1.0},
)

# Read USD-CHF data
timeserie = load_data("continuous")

# Feed input
hgf.input_data(input_data=timeserie)

# Plot
hgf.plot_nodes(node_idxs=1)
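The grey area drawn when show_surprise is True is described as the negative log probability of the observation given the expectation. For a Gaussian expectation with mean \(\hat{\mu}\) and precision \(\hat{\pi}\), that quantity can be sketched as below. This is a standalone illustration in plain Python rather than the library's implementation; the function name gaussian_surprise and the example values are hypothetical.

```python
import math


def gaussian_surprise(
    observation: float, expected_mean: float, expected_precision: float
) -> float:
    """Negative log density of the observation under a Gaussian expectation.

    The Gaussian has mean `expected_mean` and variance 1 / `expected_precision`.
    """
    if expected_precision <= 0:
        raise ValueError("precision must be strictly positive")
    squared_error = (observation - expected_mean) ** 2
    return 0.5 * (
        math.log(2.0 * math.pi)
        - math.log(expected_precision)
        + expected_precision * squared_error
    )


# Hypothetical values: one expectation, a close and a distant observation.
close = gaussian_surprise(observation=1.0, expected_mean=1.0, expected_precision=1.0)
distant = gaussian_surprise(observation=3.0, expected_mean=1.0, expected_precision=1.0)
```

The surprise grows with the squared distance between the observation and the expected mean, weighted by the expected precision, so an observation far from a confident expectation produces the largest values.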