pyhgf.distribution.hgf_logp
- pyhgf.distribution.hgf_logp(mean_1: ArrayLike = 0.0, mean_2: ArrayLike = 0.0, mean_3: ArrayLike = 0.0, precision_1: ArrayLike = 1.0, precision_2: ArrayLike = 1.0, precision_3: ArrayLike = 1.0, tonic_volatility_1: ArrayLike = -3.0, tonic_volatility_2: ArrayLike = -3.0, tonic_volatility_3: ArrayLike = -3.0, tonic_drift_1: ArrayLike = 0.0, tonic_drift_2: ArrayLike = 0.0, tonic_drift_3: ArrayLike = 0.0, volatility_coupling_1: ArrayLike = 1.0, volatility_coupling_2: ArrayLike = 1.0, input_precision: ArrayLike = inf, response_function_parameters: ArrayLike = array([1.]), vectorized_logp: Callable = <PjitFunction of <function logp>>, input_data: ArrayLike = nan, response_function_inputs: ArrayLike = nan, time_steps: ArrayLike = nan) -> Array
Here ArrayLike stands for jax.Array | numpy.ndarray | numpy.bool_ | numpy.number | bool | int | float | complex (the definition of jax.typing.ArrayLike).
Compute log-probabilities of a batch of Hierarchical Gaussian Filters.
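To make "log-probabilities of a batch" and the summed return value concrete, the sketch below scores several input time series in parallel with a toy Gaussian observation model and then sums the result into one scalar. The functions `gaussian_logpdf` and `toy_batch_logp` are hypothetical stand-ins, not pyhgf's `vectorized_logp`: the toy model keeps a fixed prediction per model, whereas a real HGF updates its predictions at every time step.

```python
import numpy as np


def gaussian_logpdf(x, mean, precision):
    """Element-wise log-density of x under N(mean, 1 / precision)."""
    return 0.5 * (np.log(precision) - np.log(2 * np.pi)) - 0.5 * precision * (x - mean) ** 2


def toy_batch_logp(input_data, mean, precision):
    """Hypothetical stand-in for a batched log-probability.

    input_data has shape (n_models, n_time_steps); mean and precision hold
    one value per model. Each model is scored independently along axis 1,
    then the per-model log-probabilities are summed into a single scalar.
    """
    mean = np.asarray(mean, dtype=float)[:, None]  # shape (n_models, 1)
    precision = np.asarray(precision, dtype=float)[:, None]  # shape (n_models, 1)
    per_step = gaussian_logpdf(input_data, mean, precision)  # (n_models, n_time_steps)
    per_model = per_step.sum(axis=1)  # one log-probability per model
    return per_model.sum()  # total log-probability, i.e. negative total surprise


# Two input time series stacked along the first axis -> two models.
input_data = np.array([[0.1, -0.2, 0.05], [1.0, 0.9, 1.1]])
print(toy_batch_logp(input_data, mean=[0.0, 1.0], precision=[4.0, 4.0]))
```

Because the models are independent, the batched total equals the sum of the totals obtained by scoring each model on its own.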
Hint
This function supports broadcasting along the first axis, which means that it can fit multiple models in parallel when several input time series are provided. When a network parameter is a float, that value is used for all models. When a network parameter is an array, its length should match the number of input time series (the first dimension of input_data), and each model uses its corresponding value.
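The broadcasting rule described in the hint can be sketched in NumPy as follows: a float is repeated for every model, while an array must provide one entry per input time series. The helper `broadcast_parameter` is hypothetical; it only mirrors the documented behaviour and is not how pyhgf implements it internally.

```python
import numpy as np


def broadcast_parameter(value, n_models):
    """Return one parameter value per model (hypothetical helper).

    A scalar is repeated for all models; an array must already have
    length n_models along its first axis.
    """
    value = np.asarray(value, dtype=float)
    if value.ndim == 0:
        # Scalar: the same value is used for every model.
        return np.full(n_models, value)
    if value.shape[0] != n_models:
        raise ValueError(
            f"expected {n_models} values (one per model), got {value.shape[0]}"
        )
    # Array: model i uses value[i].
    return value


# Three input time series stacked along the first axis -> three models.
input_data = np.random.default_rng(0).normal(size=(3, 100))
n_models = input_data.shape[0]

tonic_volatility_2 = broadcast_parameter(-3.0, n_models)  # [-3., -3., -3.]
mean_1 = broadcast_parameter([0.0, 0.5, 1.0], n_models)  # [0., 0.5, 1.]
```

Raising an error on a length mismatch, rather than silently recycling values, reflects the requirement that an array parameter match the number of input time series.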
- Parameters:
- mean_1
The mean at the first level of the HGF.
- mean_2
The mean at the second level of the HGF.
- mean_3
The mean at the third level of the HGF. The value of this parameter will be ignored when using a two-level HGF (n_levels=2).
- precision_1
The precision at the first level of the HGF.
- precision_2
The precision at the second level of the HGF.
- precision_3
The precision at the third level of the HGF. The value of this parameter will be ignored when using a two-level HGF (n_levels=2).
- tonic_volatility_1
The tonic volatility at the first level of the HGF. This parameter represents the tonic part of the variance (the part that is not inherited from parent nodes).
- tonic_volatility_2
The tonic volatility at the second level of the HGF. This parameter represents the tonic part of the variance (the part that is not inherited from parent nodes).
- tonic_volatility_3
The tonic volatility at the third level of the HGF. This parameter represents the tonic part of the variance (the part that is not inherited from parent nodes). The value of this parameter will be ignored when using a two-level HGF (n_levels=2).
- tonic_drift_1
The tonic drift at the first level of the HGF. This parameter represents the drift of the random walk.
- tonic_drift_2
The tonic drift at the second level of the HGF. This parameter represents the drift of the random walk.
- tonic_drift_3
The tonic drift at the third level of the HGF. This parameter represents the drift of the random walk. The value of this parameter will be ignored when using a two-level HGF (n_levels=2).
- volatility_coupling_1
The volatility coupling between the first and second levels of the HGF. This represents the phasic part of the variance (the part affected by the parent nodes). Defaults to 1.0.
- volatility_coupling_2
The volatility coupling between the second and third levels of the HGF. This represents the phasic part of the variance (the part affected by the parent nodes). Defaults to 1.0. The value of this parameter will be ignored when using a two-level HGF (n_levels=2).
- input_precision
The expected precision associated with the continuous or binary input, depending on the model type. The default is np.inf.
- response_function_parameters
An array of additional parameters that will be passed to the response function. This can include values over which inference is performed in a PyMC model (e.g. the inverse temperature of a binary softmax).
- vectorized_logp
A vectorized log-probability function for a two- or three-level HGF.
- input_data
An array of input time series where the first dimension is the number of models to fit in parallel.
- response_function_inputs
An array of behavioural input passed to the response function where the first dimension is the number of models to fit in parallel.
- time_steps
An array of input time steps where the first dimension is the number of models to fit in parallel.
- Returns:
- log_prob
The sum of the log probabilities (negative surprise).
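To illustrate how the tonic_volatility_*, tonic_drift_* and volatility_coupling_* parameters described above shape the model, here is a sketch of the generative random walk of a two-level continuous HGF. It assumes the standard HGF parameterisation, in which a level's step variance is Δt · exp(κ · x_parent + ω) (κ the volatility coupling, ω the tonic volatility) and its mean is shifted by the tonic drift; pyhgf's exact equations may differ in detail. The functions `step_variance` and `simulate_two_levels` are hypothetical, are not part of pyhgf, and do not perform the belief updates whose surprise `hgf_logp` returns.

```python
import numpy as np


def step_variance(tonic_volatility, volatility_coupling, parent_value, time_step=1.0):
    """Variance of one random-walk step (assumed standard HGF form).

    The tonic part (tonic_volatility) is fixed; the phasic part
    (volatility_coupling * parent_value) changes with the parent node.
    """
    return time_step * np.exp(volatility_coupling * parent_value + tonic_volatility)


def simulate_two_levels(n_steps, tonic_volatility_1=-3.0, tonic_volatility_2=-3.0,
                        tonic_drift_1=0.0, tonic_drift_2=0.0,
                        volatility_coupling_1=1.0, seed=0):
    """Sample trajectories of a two-level continuous HGF (illustrative only).

    Unit time steps are assumed, so each drift is added once per step.
    """
    rng = np.random.default_rng(seed)
    x1 = np.zeros(n_steps)
    x2 = np.zeros(n_steps)
    for k in range(1, n_steps):
        # Level 2 has no parent here: only its tonic volatility sets the step size.
        x2[k] = rng.normal(x2[k - 1] + tonic_drift_2, np.sqrt(np.exp(tonic_volatility_2)))
        # Level 1: the drift shifts the mean, level 2 scales the step variance.
        var_1 = step_variance(tonic_volatility_1, volatility_coupling_1, x2[k])
        x1[k] = rng.normal(x1[k - 1] + tonic_drift_1, np.sqrt(var_1))
    return x1, x2


x1, x2 = simulate_two_levels(200)
```

With volatility_coupling_1 = 1.0, a one-unit increase in the level-2 value multiplies the level-1 step variance by e; setting the coupling to 0 leaves only the tonic part.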