Example 1: Bayesian filtering of cardiac volatility#


import sys
from IPython.utils import io
if 'google.colab' in sys.modules:

  with io.capture_output() as captured:
      ! pip install pyhgf watermark systole
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
from systole import import_dataset1
from systole.detection import ecg_peaks
from systole.plots import plot_raw
from systole.utils import input_conversion

from pyhgf.distribution import HGFDistribution
from pyhgf.model import HGF
from pyhgf.response import total_gaussian_surprise

plt.rcParams["figure.constrained_layout.use"] = True

The nodalized version of the Hierarchical Gaussian Filter implemented in pyhgf makes it possible to create filters with multiple inputs. Here, we illustrate how to use this feature to create an agent that filters its physiological signal in real time. We use a two-level Hierarchical Gaussian Filter to predict the dynamics of the instantaneous heart rate (the RR interval measured at each heartbeat). We then extract the trajectory of surprise at each predictive node to relate it to the cognitive task performed by the participant while the signal was being recorded.
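As a reminder, the two-level continuous HGF inverts a generative model in which the first level performs a Gaussian random walk whose (log-scale) volatility is controlled by the second level. A minimal NumPy simulation of that forward model, with illustrative parameter values that are not derived from the dataset used below:

```python
import numpy as np

rng = np.random.default_rng(123)

n = 500
x2 = np.zeros(n)  # second level: controls the log-volatility of the first level
x1 = np.zeros(n)  # first level: tracks the observed quantity (e.g. RR intervals)

# illustrative tonic volatilities (the parameters estimated later in this example)
omega_1, omega_2 = -4.0, -2.0

for t in range(1, n):
    # the second level is a Gaussian random walk with fixed step variance exp(omega_2)
    x2[t] = x2[t - 1] + rng.normal(0.0, np.sqrt(np.exp(omega_2)))
    # the first level's step variance is modulated by the second level: exp(omega_1 + x2)
    x1[t] = x1[t - 1] + rng.normal(0.0, np.sqrt(np.exp(omega_1 + x2[t])))
```

Filtering inverts this process: given observations of the first level, the model infers the trajectories of both levels and their volatilities.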

Loading and preprocessing physiological recording#

We use the physiological dataset included in Systole as an example. This recording contains electrocardiography (ECG) and respiration signals.

# import the ECG and respiration recordings as a pandas data frame
physio_df = import_dataset1(modalities=['ECG', 'Respiration'])

# extract the ECG channel
ecg = physio_df.ecg

Plot the signal with the instantaneous heart rate derivation#

plot_raw(ecg, modality='ecg', sfreq=1000, show_heart_rate=True);
[Figure: raw ECG signal with the derived instantaneous heart rate]

Preprocessing#

# detect R peaks using the Pan-Tompkins algorithm
_, peaks = ecg_peaks(physio_df.ecg)

# convert the peaks into a RR time series
rr = input_conversion(x=peaks, input_type="peaks", output_type="rr_s")
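The conversion above turns a boolean peaks vector into successive RR intervals in seconds. A self-contained NumPy sketch of the same operation, on a small hypothetical peaks vector (not the Systole implementation):

```python
import numpy as np

# hypothetical boolean peaks vector sampled at 1000 Hz,
# with R peaks at samples 0, 800 and 1700
peaks = np.zeros(2000, dtype=bool)
peaks[[0, 800, 1700]] = True

# RR intervals (in seconds) are the differences between successive
# peak indices, divided by the sampling frequency
sfreq = 1000
rr_s = np.diff(np.where(peaks)[0]) / sfreq
print(rr_s)  # [0.8 0.9]
```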

Model#

Note

Here we use the total Gaussian surprise (pyhgf.response.total_gaussian_surprise()) as a response function. This response function deviates from the default behaviour for the continuous HGF in that it returns the sum of the surprise for all the probabilistic nodes in the network, whereas the default (pyhgf.response.first_level_gaussian_surprise()) only computes the surprise at the first level (i.e. the value parent of the continuous input node). We explicitly specify this parameter here to indicate that we want our model to minimise its prediction errors over all variables, and not only at the observation level. In this case, however, the results are expected to be very similar between the two methods.

hgf_logp_op = HGFDistribution(
    n_levels=2,
    model_type="continuous",
    input_data=rr[np.newaxis, :],
    response_function=total_gaussian_surprise,
)
with pm.Model() as two_level_hgf:

    # prior over the tonic volatility of the second level
    tonic_volatility_2 = pm.Normal("tonic_volatility_2", -2.0, 2.0)

    # HGF log-likelihood, with the first-level tonic volatility fixed at -4.0
    pm.Potential("hgf_loglike", hgf_logp_op(tonic_volatility_1=-4.0, tonic_volatility_2=tonic_volatility_2))
pm.model_to_graphviz(two_level_hgf)

[Figure: graphical representation of the model]

with two_level_hgf:
    idata = pm.sample(chains=2, cores=1)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (2 chains in 1 job)
NUTS: [tonic_volatility_2]


Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 13 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
az.plot_trace(idata);
[Figure: posterior trace of tonic_volatility_2]
# retrieve the posterior mean of the second-level tonic volatility
tonic_volatility_2 = az.summary(idata)["mean"]["tonic_volatility_2"]
hgf = HGF(
    n_levels=2,
    model_type="continuous",
    initial_mean={"1": rr[0], "2": -4.0},
    initial_precision={"1": 1e4, "2": 1e1},
    tonic_volatility={"1": -4.0, "2": tonic_volatility_2},
    tonic_drift={"1": 0.0, "2": 0.0},
    volatility_coupling={"1": 1.0}).input_data(input_data=rr)
hgf.plot_trajectories();
[Figure: belief trajectories of the two-level HGF fitted to the RR time series]

System configuration#

%load_ext watermark
%watermark -n -u -v -iv -w -p pyhgf,jax,jaxlib
Last updated: Sun Oct 13 2024

Python implementation: CPython
Python version       : 3.12.7
IPython version      : 8.28.0

pyhgf : 0.2.0
jax   : 0.4.31
jaxlib: 0.4.31

pymc      : 5.17.0
systole   : 0.3.0
sys       : 3.12.7 (main, Oct  1 2024, 15:17:55) [GCC 11.4.0]
numpy     : 1.26.0
arviz     : 0.20.0
matplotlib: 3.9.2
IPython   : 8.28.0
pyhgf     : 0.2.0

Watermark: 2.5.0