nemos.simulation.regress_filter
- nemos.simulation.regress_filter(coupling_filters, eval_basis)
Approximate scipy.stats.gamma-based filters with basis functions.
Find the ordinary least squares (OLS) weights that represent each filter as a linear combination of the basis functions.
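The fit described above amounts to one least-squares solve per (receiver, sender) pair. Because every pair shares the same basis matrix, all the solves can be done at once by flattening the neuron axes into columns. The following is a minimal NumPy sketch under that assumption. `ols_filter_weights` is a hypothetical name, not part of nemos, and this is not necessarily how nemos implements the fit. The sketch puts the basis axis last, so `weights[i, j]` holds the coefficients for the filter from sender `j` to receiver `i`. The Gaussian-bump basis is a stand-in for an evaluated nemos basis.

```python
import numpy as np


def ols_filter_weights(coupling_filters: np.ndarray, eval_basis: np.ndarray) -> np.ndarray:
    """Hypothetical sketch: OLS weights expressing each filter in a basis.

    coupling_filters: (window_size, n_neurons_receiver, n_neurons_sender)
    eval_basis:       (window_size, n_basis_funcs)
    returns:          (n_neurons_receiver, n_neurons_sender, n_basis_funcs)
    """
    window_size, n_recv, n_send = coupling_filters.shape
    n_basis = eval_basis.shape[1]
    # Flatten the neuron axes so every filter becomes one column of the target.
    targets = coupling_filters.reshape(window_size, n_recv * n_send)
    # Solve min ||eval_basis @ W - targets||^2 for all columns at once.
    coeffs, *_ = np.linalg.lstsq(eval_basis, targets, rcond=None)
    # coeffs has shape (n_basis, n_recv * n_send): restore the neuron axes,
    # then move the basis axis last.
    return coeffs.reshape(n_basis, n_recv, n_send).transpose(1, 2, 0)


# Toy check: filters that lie exactly in the span of the basis are recovered.
rng = np.random.default_rng(0)
window_size, n_basis = 50, 5
t = np.linspace(0, 1, window_size)
centers = np.linspace(0, 1, n_basis)
# Gaussian bumps stand in for an evaluated basis (assumption, not RaisedCosineLog).
basis = np.exp(-((t[:, None] - centers[None, :]) ** 2) / (2 * 0.1**2))
true_w = rng.normal(size=(2, 3, n_basis))
# Build filters of shape (window_size, 2, 3) from the true weights.
filters = np.einsum("tb,rsb->trs", basis, true_w)

w = ols_filter_weights(filters, basis)
print(w.shape)                 # (2, 3, 5)
print(np.allclose(w, true_w))  # True
```

With a real basis the filters are usually not exactly representable, so the weights give the best approximation in the least-squares sense rather than an exact match.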
- Parameters:
coupling_filters (NDArray) – The coupling filters, of shape (window_size, n_neurons_receiver, n_neurons_sender).
eval_basis (NDArray) – The evaluated basis functions, of shape (window_size, n_basis_funcs).
- Returns:
weights – The OLS weights for each receiver/sender pair, of shape (n_neurons_receiver, n_neurons_sender, n_basis_funcs).
- Return type:
NDArray
- Raises:
ValueError – If eval_basis is not two-dimensional.
ValueError – If coupling_filters is not three-dimensional.
ValueError – If window_size differs between eval_basis and coupling_filters.
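The three conditions listed under Raises can be sketched as plain dimension comparisons. The helper below is hypothetical: nemos performs its own validation, and its error messages may differ. It only illustrates which inputs are expected to trigger a `ValueError`.

```python
import numpy as np


def check_regress_filter_inputs(coupling_filters: np.ndarray, eval_basis: np.ndarray) -> None:
    """Hypothetical sketch of the documented input checks."""
    # The evaluated basis must be (window_size, n_basis_funcs).
    if eval_basis.ndim != 2:
        raise ValueError(
            f"eval_basis must be 2-dimensional, got {eval_basis.ndim} dimensions."
        )
    # The filters must be (window_size, n_neurons_receiver, n_neurons_sender).
    if coupling_filters.ndim != 3:
        raise ValueError(
            f"coupling_filters must be 3-dimensional, got {coupling_filters.ndim} dimensions."
        )
    # Both arrays must be sampled on the same window.
    if eval_basis.shape[0] != coupling_filters.shape[0]:
        raise ValueError(
            "window_size mismatch: eval_basis has "
            f"{eval_basis.shape[0]} samples, coupling_filters has {coupling_filters.shape[0]}."
        )


# Valid inputs pass silently; a window-size mismatch raises.
check_regress_filter_inputs(np.zeros((100, 2, 2)), np.zeros((100, 10)))
try:
    check_regress_filter_inputs(np.zeros((100, 2, 2)), np.zeros((80, 10)))
except ValueError as err:
    print(err)
```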
Examples
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from nemos.simulation import regress_filter, difference_of_gammas
>>> from nemos.basis import RaisedCosineLogEval
>>> filter_duration = 100
>>> n_basis_funcs = 10
>>> # A single receiver/sender filter, reshaped to (window_size, 1, 1)
>>> filter_bank = difference_of_gammas(filter_duration).reshape(filter_duration, 1, 1)
>>> _, basis = RaisedCosineLogEval(n_basis_funcs).evaluate_on_grid(filter_duration)
>>> weights = regress_filter(filter_bank, basis)[0, 0]
>>> print("Weights shape:", weights.shape)
Weights shape: (10,)
>>> _ = plt.plot(filter_bank[:, 0, 0], label="True filter")
>>> _ = plt.plot(basis.dot(weights), "--", label="Approx. filter")
>>> _ = plt.legend()
>>> _ = plt.title("True vs. Approximated Filters")
>>> _ = plt.show()