nemos.simulation.simulate_recurrent
- nemos.simulation.simulate_recurrent(coupling_coef, feedforward_coef, intercepts, random_key, feedforward_input, coupling_basis_matrix, init_y, inverse_link_function=jax.nn.softplus)
Simulate neural activity using the GLM as a recurrent network.
This function projects neural activity forward in time using the fitted parameters of the GLM. The simulated activity depends on both the recent history of all neurons (through the coupling filters) and external feedforward inputs such as convolved currents or light intensities.
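The recurrent update described above can be sketched in plain NumPy. This is a hypothetical re-implementation for illustration only, not nemos' code: it assumes a Poisson observation model, uses NumPy's `default_rng` in place of `jax.random.key`, and assumes that row r of the coupling basis weights row r of the activity window (oldest bin first). The library's exact time alignment may differ, and the name `simulate_recurrent_sketch` is made up.

```python
import numpy as np


def softplus(x):
    # Numerically stable log(1 + exp(x)); stands in for jax.nn.softplus.
    return np.logaddexp(0.0, x)


def simulate_recurrent_sketch(coupling_coef, feedforward_coef, intercepts, rng,
                              feedforward_input, coupling_basis_matrix, init_y,
                              inverse_link_function=softplus):
    """Hypothetical NumPy sketch of a recurrent Poisson GLM simulation (not nemos' code)."""
    n_time_bins = feedforward_input.shape[0]
    n_neurons = intercepts.shape[0]
    # Sliding window of the most recent activity, shape (window_size, n_neurons),
    # ordered from the oldest bin to the newest.
    window = np.asarray(init_y, dtype=float).copy()
    spikes = np.zeros((n_time_bins, n_neurons))
    rates = np.zeros((n_time_bins, n_neurons))
    for t in range(n_time_bins):
        # Project each sender's history onto the basis: (n_basis_coupling, n_neurons).
        # Assumption: row r of the basis weights row r of the window.
        conv = coupling_basis_matrix.T @ window
        # Recurrent drive to receiver i: sum over senders j and basis functions k.
        recurrent = np.einsum("ijk,kj->i", coupling_coef, conv)
        # Feedforward drive: per-neuron dot product over the input basis.
        feedforward = np.einsum("ik,ik->i", feedforward_coef, feedforward_input[t])
        rates[t] = inverse_link_function(recurrent + feedforward + intercepts)
        # Poisson observation model: draw spike counts from the current rates.
        spikes[t] = rng.poisson(rates[t])
        # Slide the window: drop the oldest bin, append the newly simulated counts.
        window = np.vstack([window[1:], spikes[t]])
    return spikes, rates


# Usage: no coupling and no feedforward drive, so every rate equals softplus(-1).
rng = np.random.default_rng(0)
spikes, rates = simulate_recurrent_sketch(
    coupling_coef=np.zeros((2, 2, 3)),
    feedforward_coef=np.ones((2, 1)),
    intercepts=-1.0 * np.ones(2),
    rng=rng,
    feedforward_input=np.zeros((20, 2, 1)),
    coupling_basis_matrix=np.ones((5, 3)),
    init_y=np.zeros((5, 2)),
)
```

Feeding each newly drawn count back into the window is what makes the GLM "recurrent": the next rate depends on activity the simulation itself produced.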
- Parameters:
coupling_coef (NDArray) – Coefficients for the coupling (recurrent connections) between neurons. Expected shape: (n_neurons (receiver), n_neurons (sender), n_basis_coupling).
feedforward_coef (NDArray) – Coefficients for the feedforward inputs to each neuron. Expected shape: (n_neurons, n_basis_input).
intercepts (NDArray) – Bias term for each neuron. Expected shape: (n_neurons,).
random_key (Array) – jax.random.key for seeding the simulation.
feedforward_input (Union[NDArray, Array]) – External input matrix to the model, representing factors such as convolved currents or light intensities. When not provided, only the coupling terms drive the simulation. Expected shape: (n_time_bins, n_neurons, n_basis_input).
init_y (Union[NDArray, Array]) – Initial observation matrix (spike counts for PoissonGLM) that kickstarts the simulation. Expected shape: (window_size, n_neurons).
coupling_basis_matrix (Union[NDArray, Array]) – Basis matrix for coupling, representing between-neuron couplings and auto-correlations. Expected shape: (window_size, n_basis_coupling).
inverse_link_function (Callable) – The inverse link function for the observation model. Defaults to softplus.
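The `feedforward_input` description mentions convolved currents. One hypothetical way to build an array with the documented shape (n_time_bins, n_neurons, n_basis_input) is to convolve a 1-D stimulus causally with each column of an input basis and give every neuron the same drive. The helper `build_feedforward_input` and the shared-drive choice are assumptions for illustration; nemos provides its own basis objects that may be preferable in practice.

```python
import numpy as np


def build_feedforward_input(stimulus, input_basis, n_neurons):
    """Hypothetical helper: causally convolve a 1-D stimulus with each basis column."""
    n_time_bins = stimulus.shape[0]
    n_basis_input = input_basis.shape[1]
    conv = np.empty((n_time_bins, n_basis_input))
    for k in range(n_basis_input):
        # Keeping the first n_time_bins samples of the full convolution makes it causal:
        # output bin t only depends on stimulus bins 0..t.
        conv[:, k] = np.convolve(stimulus, input_basis[:, k], mode="full")[:n_time_bins]
    # Give every neuron the same drive: (n_time_bins, n_neurons, n_basis_input).
    return np.repeat(conv[:, None, :], n_neurons, axis=1)


# Usage: a single pulse at bin 10, filtered by a 4-bin basis with 2 columns.
stimulus = np.zeros(50)
stimulus[10] = 1.0
input_basis = np.column_stack([np.ones(4), np.arange(1.0, 5.0)])
feedforward_input = build_feedforward_input(stimulus, input_basis, n_neurons=3)
```

With this layout, `feedforward_coef` of shape (n_neurons, n_basis_input) lets each neuron weight the shared basis-filtered stimulus differently.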
- Returns:
simulated_activity – Simulated activity (spike counts for PoissonGLMs) for each neuron over time. Shape: (n_time_bins, n_neurons).
firing_rates – Simulated rates for each neuron over time. Shape: (n_time_bins, n_neurons).
- Raises:
ValueError – If the model parameters disagree on the number of neurons.
ValueError – If the number of neurons in the input arguments does not match the model parameters.
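The two error conditions above can be illustrated with a small shape check. This is a hypothetical sketch mirroring the documented conditions, not nemos' actual validation code; the function name `check_neuron_counts` and the error messages are made up, and the library may check more (or different) axes.

```python
import numpy as np


def check_neuron_counts(coupling_coef, feedforward_coef, intercepts,
                        feedforward_input, init_y):
    """Hypothetical check mirroring the documented ValueErrors (not nemos' code)."""
    # First condition: all model parameters must agree on the number of neurons.
    param_counts = {
        "coupling_coef (receiver axis)": coupling_coef.shape[0],
        "coupling_coef (sender axis)": coupling_coef.shape[1],
        "feedforward_coef": feedforward_coef.shape[0],
        "intercepts": intercepts.shape[0],
    }
    if len(set(param_counts.values())) != 1:
        raise ValueError(f"Inconsistent number of neurons in model parameters: {param_counts}")
    n_neurons = intercepts.shape[0]
    # Second condition: the input arrays must match the parameters.
    input_counts = {"feedforward_input": feedforward_input.shape[1], "init_y": init_y.shape[1]}
    mismatched = {name: n for name, n in input_counts.items() if n != n_neurons}
    if mismatched:
        raise ValueError(f"Inputs do not match the {n_neurons} neurons in the parameters: {mismatched}")
    return n_neurons


# Usage: consistent shapes pass; a 3-neuron input against 2-neuron parameters fails.
n_ok = check_neuron_counts(np.zeros((2, 2, 3)), np.zeros((2, 1)), np.zeros(2),
                           np.zeros((10, 2, 1)), np.zeros((5, 2)))
try:
    check_neuron_counts(np.zeros((2, 2, 3)), np.zeros((2, 1)), np.zeros(2),
                        np.zeros((10, 3, 1)), np.zeros((5, 2)))
    input_mismatch_rejected = False
except ValueError:
    input_mismatch_rejected = True
```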
Examples
>>> import numpy as np
>>> import jax
>>> import matplotlib.pyplot as plt
>>> from nemos.simulation import simulate_recurrent
>>>
>>> n_neurons = 2
>>> coupling_duration = 100
>>> feedforward_input = np.random.normal(size=(1000, n_neurons, 1))
>>> coupling_basis = np.random.normal(size=(coupling_duration, 10))
>>> coupling_coef = np.random.normal(size=(n_neurons, n_neurons, 10))
>>> intercept = -9 * np.ones(n_neurons)
>>> init_spikes = np.zeros((coupling_duration, n_neurons))
>>> random_key = jax.random.key(123)
>>> spikes, rates = simulate_recurrent(
...     coupling_coef=coupling_coef,
...     feedforward_coef=np.ones((n_neurons, 1)),
...     intercepts=intercept,
...     random_key=random_key,
...     feedforward_input=feedforward_input,
...     coupling_basis_matrix=coupling_basis,
...     init_y=init_spikes
... )
>>> _ = plt.figure()
>>> _ = plt.plot(rates[:, 0], label="Neuron 0 rate")
>>> _ = plt.plot(rates[:, 1], label="Neuron 1 rate")
>>> _ = plt.legend()
>>> _ = plt.title("Simulated firing rates")
>>> _ = plt.show()
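In the example, the intercept of -9 sets a very low baseline: when the coupling and feedforward terms sum to zero, the default softplus inverse link gives log(1 + e^-9), roughly 1.2e-4 expected counts per time bin. The random feedforward input and coupling push the actual rates away from this baseline. A quick check, using `np.logaddexp` as a stand-in for `jax.nn.softplus`:

```python
import numpy as np

# softplus(x) = log(1 + exp(x)), evaluated at the example's intercept of -9.
baseline_rate = np.logaddexp(0.0, -9.0)
print(f"{baseline_rate:.3e}")
```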