"""Utility functions for coupling filter definition."""
from __future__ import annotations
from typing import Callable, Tuple, Union
import jax
import jax.numpy as jnp
import lazy_loader as lazy
import numpy as np
from numpy.typing import NDArray
from . import validation
from .pytrees import FeaturePytree
# Lazy load to avoid importing scipy.stats at module level
scipy = lazy.load("scipy")
[docs]
def difference_of_gammas(
    ws: int,
    upper_percentile: float = 0.99,
    inhib_a: float = 1.0,
    excit_a: float = 2.0,
    inhib_b: float = 1.0,
    excit_b: float = 2.0,
) -> NDArray:
    r"""
    Generate coupling filter as a Gamma pdf difference.

    The filter is the excitatory gamma pdf minus the inhibitory gamma pdf,
    sampled on ``ws`` equi-spaced points and rescaled to unit L2 norm.

    Parameters
    ----------
    ws:
        The window size of the filter.
    upper_percentile:
        Upper bound of the gamma range as a percentile. Both pdfs are
        evaluated over the range [0, xmax], where xmax is the larger of the
        two distributions' ``ppf(upper_percentile)``.
    inhib_a:
        The ``a`` constant for the gamma pdf of the inhibitory part of the filter.
    excit_a:
        The ``a`` constant for the gamma pdf of the excitatory part of the filter.
    inhib_b:
        The ``b`` constant for the gamma pdf of the inhibitory part of the filter.
    excit_b:
        The ``b`` constant for the gamma pdf of the excitatory part of the filter.

    Notes
    -----
    The probability density function of a gamma distribution is parametrized as
    follows [1]_ :

    .. math::
        p(x;\; a, b) = \frac{b^a x^{a-1} e^{-b x}}{\Gamma(a)},

    where :math:`\Gamma(a)` refers to the gamma function, see [1]_. In terms of
    the SciPy parametrization, ``b`` is the inverse of the ``scale`` parameter.

    Returns
    -------
    filter:
        The coupling filter, shape ``(ws,)``, normalized to unit L2 norm.

    Raises
    ------
    ValueError:
        If any of the Gamma parameters is lesser or equal to 0 (or NaN).
    ValueError:
        If the upper_percentile is not in [0, 1).
    ValueError:
        If the difference of the two pdfs is identically zero on the sampled
        range (e.g. identical excitatory and inhibitory parameters), since
        the filter cannot be normalized.

    References
    ----------
    .. [1] SciPy Docs - :obj:`scipy.stats.gamma`

    Examples
    --------
    .. plot::
        :include-source: True
        :caption: Difference of Gammas.

        >>> import matplotlib.pyplot as plt
        >>> from nemos.simulation import difference_of_gammas
        >>> coupling_duration = 100
        >>> inhib_a, inhib_b = 1.0, 1.0
        >>> excit_a, excit_b = 2.0, 2.0
        >>> coupling_filter = difference_of_gammas(
        ...     ws=coupling_duration,
        ...     inhib_a=inhib_a,
        ...     inhib_b=inhib_b,
        ...     excit_a=excit_a,
        ...     excit_b=excit_b
        ... )
        >>> _ = plt.plot(coupling_filter)
        >>> _ = plt.title("Coupling filter from difference of gammas")
        >>> _ = plt.show()
    """
    # check that the gamma parameters are positive (scipy returns
    # nans otherwise but no exception is raised). ``not (value > 0)`` is used
    # instead of ``value <= 0`` so that NaN parameters are rejected as well.
    variables = {
        "excit_a": excit_a,
        "inhib_a": inhib_a,
        "excit_b": excit_b,
        "inhib_b": inhib_b,
    }
    for name, value in variables.items():
        if not (value > 0):
            raise ValueError(f"Gamma parameter {name} must be >0.")

    # check for valid percentile (the chained comparison also rejects NaN)
    if not (0 <= upper_percentile < 1):
        raise ValueError(
            f"upper_percentile should lie in the [0, 1) interval. {upper_percentile} provided instead!"
        )

    sts = scipy.stats
    # scipy parametrizes the gamma with ``scale``, the inverse of the rate ``b``
    gm_inhibition = sts.gamma(a=inhib_a, scale=1 / inhib_b)
    gm_excitation = sts.gamma(a=excit_a, scale=1 / excit_b)

    # calculate upper bound for the evaluation
    xmax = max(gm_inhibition.ppf(upper_percentile), gm_excitation.ppf(upper_percentile))

    # equi-spaced sample covering the range
    x = np.linspace(0, xmax, ws)

    # compute difference of gammas & normalize
    gamma_diff = gm_excitation.pdf(x) - gm_inhibition.pdf(x)
    norm = np.linalg.norm(gamma_diff, ord=2)

    # A zero norm would silently turn the whole filter into NaNs. The size
    # guard keeps the empty result for ``ws == 0`` unchanged.
    if gamma_diff.size > 0 and norm == 0:
        raise ValueError(
            "The difference of gammas is identically zero over the evaluation range "
            "and cannot be normalized. Use different parameters for the excitatory "
            "and inhibitory gamma distributions, or a larger `upper_percentile`."
        )
    return gamma_diff / norm
[docs]
def regress_filter(coupling_filters: NDArray, eval_basis: NDArray) -> NDArray:
    """Approximate scipy.stats.gamma based filters with basis function.

    Find the Ordinary Least Squares weights for representing the filters in terms of basis functions.

    Parameters
    ----------
    coupling_filters:
        The coupling filters. Shape ``(window_size, n_neurons_receiver, n_neurons_sender)``
    eval_basis:
        The evaluated basis function, shape ``(window_size, n_basis_funcs)``

    Returns
    -------
    weights:
        The weights for each neuron. Shape ``(n_neurons_receiver, n_neurons_sender, n_basis_funcs)``

    Raises
    ------
    ValueError
        If eval_basis is not two-dimensional.
    ValueError
        If coupling_filters is not three-dimensional.
    ValueError
        If window_size differs between eval_basis and coupling_filters.

    Examples
    --------
    .. plot::
        :include-source: True
        :caption: Least-squares approximate filter.

        >>> import numpy as np
        >>> import matplotlib.pyplot as plt
        >>> from nemos.simulation import regress_filter, difference_of_gammas
        >>> from nemos.basis import RaisedCosineLogEval
        >>> filter_duration = 100
        >>> n_basis_funcs = 10
        >>> filter_bank = difference_of_gammas(filter_duration).reshape(filter_duration, 1, 1)
        >>> _, basis = RaisedCosineLogEval(n_basis_funcs).evaluate_on_grid(filter_duration)
        >>> weights = regress_filter(filter_bank, basis)[0, 0]
        >>> print("Weights shape:", weights.shape)
        Weights shape: (10,)
        >>> _ = plt.plot(filter_bank[:, 0, 0], label="True filter")
        >>> _ = plt.plot(basis.dot(weights), "--", label="Approx. filter")
        >>> _ = plt.legend()
        >>> _ = plt.title("True vs. Approximated Filters")
        >>> _ = plt.show()
    """
    # Accept any array-like (lists, jax arrays, ...); no-op for numpy arrays.
    coupling_filters = np.asarray(coupling_filters)
    eval_basis = np.asarray(eval_basis)

    # check shapes
    if eval_basis.ndim != 2:
        raise ValueError(
            "eval_basis must be a 2 dimensional array, "
            "shape (window_size, n_basis_funcs). "
            f"{eval_basis.ndim} dimensional array provided instead!"
        )
    if coupling_filters.ndim != 3:
        raise ValueError(
            "coupling_filters must be a 3 dimensional array, "
            "shape (window_size, n_neurons, n_neurons). "
            f"{coupling_filters.ndim} dimensional array provided instead!"
        )

    ws, n_neurons_receiver, n_neurons_sender = coupling_filters.shape

    # check that window size matches
    if eval_basis.shape[0] != ws:
        raise ValueError(
            "window_size mismatch. The window size of coupling_filters and eval_basis "
            f"does not match. coupling_filters has a window size of {ws}; "
            f"eval_basis has a window size of {eval_basis.shape[0]}."
        )

    # Flatten the (receiver, sender) axes so that each column is one filter:
    # shape (window_size, n_neurons_receiver * n_neurons_sender)
    filters_reshaped = coupling_filters.reshape(ws, -1)

    # Solve the least squares problem for all filters at once
    # (one right-hand side per filter); result is (n_basis_funcs, n_filters)
    weights = np.linalg.lstsq(eval_basis, filters_reshaped, rcond=None)[0]

    # Unflatten the filter axis and move the basis axis last:
    # (n_basis_funcs, receiver, sender) -> (receiver, sender, n_basis_funcs)
    weights = np.transpose(
        weights.reshape(-1, n_neurons_receiver, n_neurons_sender), axes=(1, 2, 0)
    )
    return weights
[docs]
def simulate_recurrent(
    coupling_coef: NDArray,
    feedforward_coef: NDArray,
    intercepts: NDArray,
    random_key: jax.Array,
    feedforward_input: Union[NDArray, jnp.ndarray],
    coupling_basis_matrix: Union[NDArray, jnp.ndarray],
    init_y: Union[NDArray, jnp.ndarray],
    inverse_link_function: Callable = jax.nn.softplus,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Simulate neural activity using the GLM as a recurrent network.

    This function projects neural activity into the future, employing the fitted
    parameters of the GLM. It is capable of simulating activity based on a combination
    of historical activity and external feedforward inputs like convolved currents, light
    intensities, etc.

    Parameters
    ----------
    coupling_coef :
        Coefficients for the coupling (recurrent connections) between neurons.
        Expected shape: ``(n_neurons (receiver), n_neurons (sender), n_basis_coupling)``.
    feedforward_coef :
        Coefficients for the feedforward inputs to each neuron.
        Expected shape: ``(n_neurons, n_basis_input)``.
    intercepts :
        Bias term for each neuron. Expected shape: ``(n_neurons,)``.
    random_key :
        jax.random.key for seeding the simulation.
    feedforward_input :
        External input matrix to the model, representing factors like convolved currents,
        light intensities, etc. Must be an array (pytrees are rejected); its first axis
        sets the number of simulated time bins.
        Expected shape: ``(n_time_bins, n_neurons, n_basis_input)``.
    coupling_basis_matrix :
        Basis matrix for coupling, representing between-neuron couplings
        and auto-correlations. Expected shape: ``(window_size, n_basis_coupling)``.
    init_y :
        Initial observation (spike counts for PoissonGLM) matrix that kickstarts the simulation.
        Expected shape: ``(window_size, n_neurons)``.
    inverse_link_function :
        The inverse link function for the observation model.

    Returns
    -------
    simulated_activity :
        Simulated activity (spike counts for PoissonGLMs) for each neuron over time.
        Shape, ``(n_time_bins, n_neurons)``.
    firing_rates :
        Simulated rates for each neuron over time. Shape, ``(n_time_bins, n_neurons)``.

    Raises
    ------
    ValueError
        If ``feedforward_input`` is a ``FeaturePytree``.
    ValueError
        If there's an inconsistency between the number of neurons in model parameters.
    ValueError
        If the number of neurons in input arguments doesn't match with model parameters.
    ValueError
        If the number of basis functions in ``coupling_basis_matrix`` and
        ``coupling_coef`` differ.
    ValueError
        If ``init_y`` and ``coupling_basis_matrix`` have different window sizes.

    Examples
    --------
    .. plot::
        :include-source: True
        :caption: Recurrently connected GLM simulations.

        >>> import numpy as np
        >>> import jax
        >>> import matplotlib.pyplot as plt
        >>> from nemos.simulation import simulate_recurrent
        >>> np.random.seed(42)
        >>> n_neurons = 2
        >>> coupling_duration = 100
        >>> feedforward_input = np.random.normal(size=(1000, n_neurons, 1))
        >>> coupling_basis = np.random.normal(size=(coupling_duration, 10))
        >>> coupling_coef = 0.5*np.random.normal(size=(n_neurons, n_neurons, 10))
        >>> intercept = -9 * np.ones(n_neurons)
        >>> init_spikes = np.zeros((coupling_duration, n_neurons))
        >>> random_key = jax.random.key(123)
        >>> spikes, rates = simulate_recurrent(
        ...     coupling_coef=coupling_coef,
        ...     feedforward_coef=np.ones((n_neurons, 1)),
        ...     intercepts=intercept,
        ...     random_key=random_key,
        ...     feedforward_input=feedforward_input,
        ...     coupling_basis_matrix=coupling_basis,
        ...     init_y=init_spikes
        ... )
        >>> _ = plt.figure()
        >>> _ = plt.plot(rates[:, 0], label="Neuron 0 rate")
        >>> _ = plt.plot(rates[:, 1], label="Neuron 1 rate")
        >>> _ = plt.legend()
        >>> _ = plt.title("Simulated firing rates")
        >>> _ = plt.show()
    """
    if isinstance(feedforward_input, FeaturePytree):
        raise ValueError("simulate_recurrent works only with arrays, not pytrees.")

    # convert to jnp.ndarray of floats
    coupling_basis_matrix = jnp.asarray(coupling_basis_matrix, dtype=float)
    coupling_coef = jnp.asarray(coupling_coef, dtype=float)
    feedforward_coef = jnp.asarray(feedforward_coef, dtype=float)
    intercepts = jnp.asarray(intercepts, dtype=float)
    feedforward_input = jax.tree_util.tree_map(
        lambda x: jnp.asarray(x, dtype=float), feedforward_input
    )
    init_y = jnp.asarray(init_y, dtype=float)

    # check that n_neurons is consistent (the intercepts define the reference)
    n_neurons = intercepts.shape[0]
    if (
        feedforward_input.shape[1] != n_neurons
        or feedforward_coef.shape[0] != n_neurons
        or init_y.shape[1] != n_neurons
        or coupling_coef.shape[0] != n_neurons
        or coupling_coef.shape[1] != n_neurons
    ):
        raise ValueError(
            "The number of neurons provided in the inputs is inconsistent!"
        )

    # checks the input size
    validation.check_tree_leaves_dimensionality(
        feedforward_input,
        expected_dim=3,
        err_message="`feedforward_input` must be three-dimensional, with shape "
        "(n_timebins, n_neurons, n_features) or pytree of the same shape.",
    )

    # The feature axis is axis 1 of `feedforward_coef` and axis 2 of
    # `feedforward_input`; the message reports those same axes.
    validation.check_tree_axis_consistency(
        feedforward_coef,
        feedforward_input,
        axis_1=1,
        axis_2=2,
        err_message="Inconsistent number of features. "
        f"spike basis coefficients has {feedforward_coef.shape[1]} features, "
        f"X has {feedforward_input.shape[2]} features instead!",
    )
    validation.error_invalid_entry(feedforward_input)

    # validate y
    validation.check_tree_leaves_dimensionality(
        init_y,
        expected_dim=2,
        err_message="`init_y` must be two-dimensional, with shape (window_size, n_neurons).",
    )

    # Flatten (sender, basis) so the coupling drive is a single mat-vec product:
    # (n_neurons, n_neurons, n_basis) -> (n_neurons, n_neurons * n_basis)
    n_basis = coupling_coef.shape[-1]
    coupling_coef = coupling_coef.reshape(n_neurons, -1)
    if coupling_basis_matrix.shape[1] * n_neurons != coupling_coef.shape[1]:
        raise ValueError(
            f"Inconsistent number of features. `coupling_basis_matrix` assumes "
            f"{coupling_basis_matrix.shape[1]} basis functions for the coupling filters, "
            f"`coupling_coef` assumes {n_basis} basis functions instead."
        )

    if init_y.shape[0] != coupling_basis_matrix.shape[0]:
        raise ValueError(
            "`init_y` and `coupling_basis_matrix`"
            " should have the same window size! "
            f"`init_y` window size: {init_y.shape[0]}, "
            f"`coupling_basis_matrix` window size: {coupling_basis_matrix.shape[0]}"
        )

    # one independent PRNG key per simulated time bin
    subkeys = jax.random.split(random_key, num=feedforward_input.shape[0])

    # Pre-compute feedforward contribution: (n_samples, n_neurons)
    feed_forward_contrib = jnp.einsum("ik,tik->ti", feedforward_coef, feedforward_input)

    # Pre-flip the basis along the window axis so that the most recent sample of
    # the activity history (last row) is weighted by the first basis sample,
    # as in a convolution.
    coupling_basis_flipped = coupling_basis_matrix[::-1]

    def scan_fn(
        activity: jnp.ndarray, inputs: Tuple[jax.Array, jnp.ndarray]
    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
        """Advance the simulation by one time bin.

        Parameters
        ----------
        activity :
            Carry of the scan: the last ``window_size`` samples of activity,
            shape ``(window_size, n_neurons)``, oldest sample first.
        inputs :
            Tuple ``(key, ff_input)`` with the PRNG key for this bin and the
            pre-computed feedforward drive, shape ``(n_neurons,)``.

        Returns
        -------
        :
            The updated activity window and the tuple
            ``(new_act, firing_rate)`` for this time bin.
        """
        key, ff_input = inputs

        # Project the activity history onto each (flipped) basis function:
        # (window, n_neurons), (window, n_basis) -> (n_neurons, n_basis),
        # then flatten to (n_neurons * n_basis,) to match `coupling_coef`.
        conv_act = jnp.einsum("wn,wb->nb", activity, coupling_basis_flipped).reshape(-1)

        # Predict firing rate
        firing_rate = inverse_link_function(
            coupling_coef.dot(conv_act) + ff_input + intercepts
        )

        # Simulate activity
        new_act = jax.random.poisson(key, firing_rate)

        # Slide the window: drop the oldest sample, append the new one
        activity = jnp.vstack((activity[1:], new_act))
        return activity, (new_act, firing_rate)

    # Iterate over (subkeys, feed_forward_contrib) together
    _, outputs = jax.lax.scan(scan_fn, init_y, (subkeys, feed_forward_contrib))
    simulated_activity, firing_rates = outputs
    return simulated_activity, firing_rates