Example 2: Estimating the mean and precision of a time-varying Gaussian distribution

import sys

from IPython.utils import io

if 'google.colab' in sys.modules:
    # install the dependencies silently when running on Google Colab
    with io.capture_output() as captured:
        !pip install pyhgf watermark
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy.stats import norm

from pyhgf.model import Network

While the standard continuous HGF assumes that the precision of the input node is known (usually set to a high value), this assumption can be relaxed so that the filter also estimates this quantity from the data. In this notebook, we demonstrate how to infer the mean, the precision, or both at the same time, using the appropriate value and volatility coupling parents.
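The kind of process named in the title can be made concrete with a small simulation. In this sketch, a Gaussian's mean and its log-variance both drift over time as independent Gaussian random walks. In the HGF framing, a value parent would track the drifting mean and a volatility parent would track the drifting variance of the input. The function `simulate_time_varying_gaussian` and all of its parameters (starting values, step sizes, seed) are hypothetical illustrations. They are not part of pyhgf.

```python
import numpy as np


def simulate_time_varying_gaussian(n, mu0=5.0, log_var0=0.0,
                                   mean_step_sd=0.05, log_var_step_sd=0.02,
                                   seed=123):
    """Hypothetical simulator: the mean and log-variance of a Gaussian each
    follow a Gaussian random walk, and one observation is drawn per step."""
    rng = np.random.default_rng(seed)
    # random walk on the mean (the quantity a value parent would track)
    means = mu0 + np.cumsum(rng.normal(0.0, mean_step_sd, size=n))
    # random walk on the log-variance (the quantity a volatility parent would track)
    log_vars = log_var0 + np.cumsum(rng.normal(0.0, log_var_step_sd, size=n))
    precisions = np.exp(-log_vars)
    # one observation per time step, drawn from the current Gaussian
    u = rng.normal(means, np.sqrt(1.0 / precisions))
    return u, means, precisions


u, means, precisions = simulate_time_varying_gaussian(1000)
```

The examples below use a stationary Gaussian (fixed mean and standard deviation). This keeps the plots easy to read. The same networks can be fed a drifting signal like `u` above.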

Unknown mean, known precision

Hint

The continuous Hierarchical Gaussian Filter is an example of a model that assumes a continuous input with known precision and unknown mean. It further assumes that the mean changes over time. The model can track this rate of change through a volatility node added on top of the value parent (two-level continuous HGF). It can even track the rate of change of this rate of change through a second volatility parent (three-level continuous HGF).
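The hierarchy described in the hint can be written as a generative model. This sketch uses the usual HGF assumption: each state performs a Gaussian random walk whose step variance is `exp(kappa * parent + omega)`, where `parent` is the current value of its volatility parent. The function name, the `omegas`/`kappas` values and the input precision are illustrative choices, not pyhgf defaults, and pyhgf's `Network` does not use this function.

```python
import numpy as np


def simulate_three_level_hgf(n, omegas=(-4.0, -4.0, -6.0), kappas=(1.0, 1.0),
                             input_precision=1.0, seed=0):
    """Hypothetical generative sketch of a three-level continuous HGF.

    x1: value parent of the input (the drifting mean)
    x2: volatility parent of x1 (controls the step variance of x1)
    x3: volatility parent of x2 (controls the step variance of x2)
    """
    rng = np.random.default_rng(seed)
    x1, x2, x3 = np.zeros(n), np.zeros(n), np.zeros(n)
    for k in range(1, n):
        # top level: plain random walk with fixed step variance exp(omega_3)
        x3[k] = x3[k - 1] + rng.normal(0.0, np.sqrt(np.exp(omegas[2])))
        # volatility coupling: x3 sets the step variance of x2
        x2[k] = x2[k - 1] + rng.normal(
            0.0, np.sqrt(np.exp(kappas[1] * x3[k] + omegas[1])))
        # volatility coupling: x2 sets the step variance of x1
        x1[k] = x1[k - 1] + rng.normal(
            0.0, np.sqrt(np.exp(kappas[0] * x2[k] + omegas[0])))
    # value coupling: the input is centred on x1 with known precision
    u = rng.normal(x1, np.sqrt(1.0 / input_precision))
    return u, x1, x2, x3


u, x1, x2, x3 = simulate_three_level_hgf(200)
```

Setting `omegas[2]` to a very negative value keeps `x3` essentially at zero. That approximately recovers the two-level model.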

np.random.seed(123)
dist_mean, dist_std = 5, 1
input_data = np.random.normal(loc=dist_mean, scale=dist_std, size=1000)
mean_hgf = (
    Network()
    .add_nodes(precision=1.0, autoconnection_strength=0)
    .add_nodes(value_children=0, tonic_volatility=-8.0)
    .input_data(input_data)
)
mean_hgf.plot_network()

Note

We set the tonic volatility of the value parent to a low value (-8.0) for visualization purposes. Increasing it can make the model learn in fewer iterations.
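The effect of the tonic volatility on learning speed can be seen in a simplified update. Consider a single value parent, with no volatility parent of its own, tracking a Gaussian input of known precision. Before each observation, the random walk inflates the parent's variance by `exp(tonic_volatility)`. A larger tonic volatility therefore keeps the expected precision lower and the precision-weighted learning rate higher. This is a Kalman-filter-style sketch written for illustration, not pyhgf's implementation. The prior mean of 0 and prior precision of 1, and the name `value_parent_filter`, are assumptions.

```python
import numpy as np


def value_parent_filter(u, tonic_volatility, input_precision=1.0,
                        mu0=0.0, pi0=1.0):
    """Hypothetical sketch: track the mean of a Gaussian input with known
    precision using a value parent that follows a Gaussian random walk."""
    mu, pi = mu0, pi0
    means = np.empty(len(u))
    for k, obs in enumerate(u):
        # prediction: the random walk adds exp(tonic_volatility) to the variance
        pi_hat = 1.0 / (1.0 / pi + np.exp(tonic_volatility))
        # update: posterior precision adds the input precision
        pi = pi_hat + input_precision
        # precision-weighted prediction error (learning rate = input_precision / pi)
        mu = mu + (input_precision / pi) * (obs - mu)
        means[k] = mu
    return means


u = np.full(50, 5.0)  # constant input, so only the learning rate matters
slow = value_parent_filter(u, tonic_volatility=-8.0)
fast = value_parent_filter(u, tonic_volatility=-2.0)
print(slow[-1], fast[-1])
```

With this constant input, the run with tonic volatility -2.0 ends closer to 5 than the run with -8.0. This is because its learning rate is larger at every step.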

# get the node trajectories
fig, ax = plt.subplots(figsize=(12, 5))

x = np.linspace(-10, 10, 1000)
for i, color in zip([0, 2, 5, 10, 50, 500], plt.cm.Greys(np.linspace(0.2, 1, 6))):

    # extract the sufficient statistics from the input node at iteration i
    mean = mean_hgf.node_trajectories[0]["expected_mean"][i]
    std = np.sqrt(1 / mean_hgf.node_trajectories[0]["expected_precision"][i])

    # the model expectations
    ax.plot(x, norm(mean, std).pdf(x), color=color, label=i)

# the sampling distribution
ax.fill_between(x, norm(dist_mean, dist_std).pdf(x), color="#582766", alpha=0.2)

ax.legend(title="Iterations")
ax.set_xlabel("Input (u)")
ax.set_ylabel("Density")
plt.grid(linestyle=":")
sns.despine()

Known mean, unknown precision

Unknown mean, unknown precision

np.random.seed(123)
dist_mean, dist_std = 5, 1
input_data = np.random.normal(loc=dist_mean, scale=dist_std, size=1000)
mean_precision_hgf = (
    Network()
    .add_nodes(precision=1e4)
    .add_nodes(value_children=0, tonic_volatility=-6.0)
    .add_nodes(volatility_children=0, mean=6.0, tonic_volatility=-8.0)
    .input_data(input_data)
)
mean_precision_hgf.plot_network()
fig, ax = plt.subplots(figsize=(12, 5))

x = np.linspace(-10, 10, 1000)
for i, color in zip(range(0, 250, 25), plt.cm.Greys(np.linspace(0.2, 1, 10))):

    # extract the sufficient statistics from the input node (and parents)
    mean = mean_precision_hgf.node_trajectories[0]["expected_mean"][i]
    std = np.sqrt(1 / mean_precision_hgf.node_trajectories[0]["expected_precision"][i])

    # the model expectations
    ax.plot(x, norm(mean, std).pdf(x), color=color, label=i)

# the sampling distribution
ax.fill_between(x, norm(dist_mean, dist_std).pdf(x), color="#582766", alpha=0.2)

ax.legend(title="Iterations")
ax.set_xlabel("Input (u)")
ax.set_ylabel("Density")
plt.grid(linestyle=":")
sns.despine()
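A model-free reference for these curves is the running sample mean and precision of the same kind of input, computed with Welford's online algorithm. This is a swapped-in, non-Bayesian technique with no priors and no volatility. It is shown only as a sanity check on the values the HGF estimates should approach for a stationary input. The name `running_mean_precision` is hypothetical.

```python
import numpy as np


def running_mean_precision(u):
    """Welford's online algorithm: running mean and precision
    (1 / unbiased sample variance) after each observation."""
    means = np.empty(len(u))
    precisions = np.full(len(u), np.nan)  # undefined until two samples are seen
    mean, m2 = 0.0, 0.0
    for k, x in enumerate(u, start=1):
        delta = x - mean
        mean += delta / k
        # m2 accumulates the sum of squared deviations from the running mean
        m2 += delta * (x - mean)
        means[k - 1] = mean
        if k > 1 and m2 > 0:
            precisions[k - 1] = (k - 1) / m2
    return means, precisions


rng = np.random.default_rng(123)
u = rng.normal(loc=5.0, scale=1.0, size=1000)
means, precisions = running_mean_precision(u)
print(means[-1], precisions[-1])  # close to the sampling values (mean 5, precision 1)
```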

System configuration

%load_ext watermark
%watermark -n -u -v -iv -w -p pyhgf,jax,jaxlib
Last updated: Sun Oct 13 2024

Python implementation: CPython
Python version       : 3.12.7
IPython version      : 8.28.0

pyhgf : 0.2.0
jax   : 0.4.31
jaxlib: 0.4.31

seaborn   : 0.13.2
sys       : 3.12.7 (main, Oct  1 2024, 15:17:55) [GCC 11.4.0]
pyhgf     : 0.2.0
matplotlib: 3.9.2
IPython   : 8.28.0
numpy     : 1.26.0

Watermark: 2.5.0