nemos.simulation.simulate_recurrent

nemos.simulation.simulate_recurrent(coupling_coef, feedforward_coef, intercepts, random_key, feedforward_input, coupling_basis_matrix, init_y, inverse_link_function=jax.nn.softplus)

Simulate neural activity using the GLM as a recurrent network.

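Starting from an initial window of activity, this function projects neural activity forward in time using the GLM parameters (for example, those of a fitted model). At each time bin, the rate of every neuron is obtained by passing the sum of three terms through the inverse link function: its intercept, its coupling input (the recent activity of all neurons filtered through the coupling basis), and its feedforward input (external drives such as convolved currents or light intensities). New activity is then sampled from these rates and becomes part of the history used for the next bin.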
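The recursion described above can be sketched in plain NumPy. This is a hypothetical re-implementation for illustration, not the nemos source: `simulate_recurrent_sketch` and `softplus` are names introduced here, it assumes Poisson sampling of spike counts, it uses a NumPy `Generator` in place of a JAX key, and it assumes row `w` of the history window is weighted by row `w` of `coupling_basis_matrix` (the library may order the window differently).

```python
import numpy as np


def softplus(x):
    # Numerically stable log(1 + exp(x))
    return np.logaddexp(0.0, x)


def simulate_recurrent_sketch(coupling_coef, feedforward_coef, intercepts, rng,
                              feedforward_input, coupling_basis_matrix, init_y,
                              inverse_link_function=softplus):
    """Hypothetical NumPy sketch of a recurrent GLM simulation.

    Assumptions: Poisson observations, and row w of the history window is
    weighted by row w of the coupling basis (no flipping of the basis).
    """
    n_time_bins = feedforward_input.shape[0]
    n_neurons = init_y.shape[1]
    window = np.array(init_y, dtype=float)  # (window_size, n_neurons)
    # Combine basis and coefficients into full filters: (receiver, sender, window_size)
    coupling_filters = np.einsum("rsb,wb->rsw", coupling_coef, coupling_basis_matrix)
    activity = np.zeros((n_time_bins, n_neurons))
    rates = np.zeros((n_time_bins, n_neurons))
    for t in range(n_time_bins):
        # Coupling drive: each receiver sums its filtered history over senders and lags
        coupling_drive = np.einsum("rsw,ws->r", coupling_filters, window)
        # Feedforward drive: per-neuron dot product of inputs and coefficients
        feedforward_drive = np.einsum("nk,nk->n", feedforward_input[t], feedforward_coef)
        rates[t] = inverse_link_function(coupling_drive + feedforward_drive + intercepts)
        # Assumed Poisson observation model
        activity[t] = rng.poisson(rates[t])
        # Slide the window: drop the oldest bin and append the new sample
        window = np.vstack([window[1:], activity[t]])
    return activity, rates
```

Because every sampled bin is fed back into `window`, the loop cannot be vectorized over time; this sequential dependence is what makes the simulation recurrent.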

Parameters:
  • coupling_coef (NDArray) – Coefficients for the coupling (recurrent connections) between neurons. Expected shape: (n_neurons, n_neurons, n_basis_coupling), indexed as (receiver, sender, basis function).

  • feedforward_coef (NDArray) – Coefficients for the feedforward inputs to each neuron. Expected shape: (n_neurons, n_basis_input).

  • intercepts (NDArray) – Bias term for each neuron. Expected shape: (n_neurons,).

  • random_key (Array) – JAX PRNG key (e.g., created with jax.random.key) used to seed the simulation.

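  • feedforward_input (Union[NDArray, Array]) – External input to the model, representing factors such as convolved currents or light intensities. Its first dimension sets the number of simulated time bins. Expected shape: (n_time_bins, n_neurons, n_basis_input). The signature provides no default for this argument; a coupling-only simulation can be obtained by setting feedforward_coef to zeros, which removes the feedforward term.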

  • coupling_basis_matrix (Union[NDArray, Array]) – Basis functions for the coupling filters, evaluated over the history window. Combined with coupling_coef, they define both the between-neuron coupling filters and each neuron's self-coupling (auto-correlation) filter. Expected shape: (window_size, n_basis_coupling).

  • init_y (Union[NDArray, Array]) – Initial activity (spike counts for a PoissonGLM) that provides the history needed to start the simulation. Expected shape: (window_size, n_neurons).

  • inverse_link_function (Callable) – Inverse link function of the observation model, mapping the linear predictor to firing rates. Defaults to jax.nn.softplus.

Returns:

  • simulated_activity – Simulated activity (spike counts for a PoissonGLM) for each neuron over time. Shape: (n_time_bins, n_neurons).

  • firing_rates – Firing rates from which the activity was sampled, for each neuron over time. Shape: (n_time_bins, n_neurons).

Raises:
  • ValueError – If there’s an inconsistency between the number of neurons in model parameters.

  • ValueError – If the number of neurons in the input arguments does not match the model parameters.
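The two documented error cases follow from the expected shapes listed under Parameters. The function below is a hypothetical sketch of such checks, not the nemos validation code. `check_simulation_inputs` is a name introduced here. The basis-size and window-size checks at the end are extra consistency conditions implied by the documented shapes; whether nemos performs them, and with which messages, is an assumption.

```python
import numpy as np


def check_simulation_inputs(coupling_coef, feedforward_coef, intercepts,
                            feedforward_input, coupling_basis_matrix, init_y):
    """Hypothetical shape checks derived from the documented shapes."""
    n_neurons = intercepts.shape[0]
    # Consistency among model parameters
    if coupling_coef.shape[:2] != (n_neurons, n_neurons):
        raise ValueError("coupling_coef must have shape (n_neurons, n_neurons, n_basis_coupling).")
    if feedforward_coef.shape[0] != n_neurons:
        raise ValueError("feedforward_coef must have one row per neuron.")
    # Consistency between input arguments and model parameters
    if feedforward_input.shape[1] != n_neurons or init_y.shape[1] != n_neurons:
        raise ValueError("Number of neurons in the inputs does not match the model parameters.")
    # Additional checks implied by the documented shapes (assumed, see lead-in)
    if feedforward_input.shape[2] != feedforward_coef.shape[1]:
        raise ValueError("n_basis_input differs between feedforward_input and feedforward_coef.")
    if coupling_basis_matrix.shape[1] != coupling_coef.shape[2]:
        raise ValueError("n_basis_coupling differs between coupling_basis_matrix and coupling_coef.")
    if coupling_basis_matrix.shape[0] != init_y.shape[0]:
        raise ValueError("init_y must have window_size rows, matching coupling_basis_matrix.")
```

With the shapes used in the example below (2 neurons, 10 coupling basis functions, a 100-bin window, 1 feedforward basis function), all checks pass.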

Examples

>>> import numpy as np
>>> import jax
>>> import matplotlib.pyplot as plt
>>> from nemos.simulation import simulate_recurrent
>>>
>>> n_neurons = 2
>>> coupling_duration = 100
>>> feedforward_input = np.random.normal(size=(1000, n_neurons, 1))
>>> coupling_basis = np.random.normal(size=(coupling_duration, 10))
>>> coupling_coef = np.random.normal(size=(n_neurons, n_neurons, 10))
>>> intercept = -9 * np.ones(n_neurons)
>>> init_spikes = np.zeros((coupling_duration, n_neurons))
>>> random_key = jax.random.key(123)
>>> spikes, rates = simulate_recurrent(
...     coupling_coef=coupling_coef,
...     feedforward_coef=np.ones((n_neurons, 1)),
...     intercepts=intercept,
...     random_key=random_key,
...     feedforward_input=feedforward_input,
...     coupling_basis_matrix=coupling_basis,
...     init_y=init_spikes
... )
>>> _ = plt.figure()
>>> _ = plt.plot(rates[:, 0], label="Neuron 0 rate")
>>> _ = plt.plot(rates[:, 1], label="Neuron 1 rate")
>>> _ = plt.legend()
>>> _ = plt.title("Simulated firing rates")
>>> _ = plt.show()