nemos.simulation.regress_filter

nemos.simulation.regress_filter(coupling_filters, eval_basis)

Approximate scipy.stats.gamma-based filters with basis functions.

Find the ordinary least squares (OLS) weights that represent each filter as a linear combination of the basis functions.
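A minimal sketch of this least-squares fit in plain NumPy. It assumes that all (receiver, sender) filters are fit jointly with `numpy.linalg.lstsq` and that the basis axis comes last in the result; `ols_filter_weights` is a hypothetical name, and this is not necessarily how nemos implements `regress_filter`.

```python
import numpy as np


def ols_filter_weights(coupling_filters: np.ndarray, eval_basis: np.ndarray) -> np.ndarray:
    """Hypothetical OLS sketch returning weights of shape (n_receiver, n_sender, n_basis_funcs)."""
    window_size, n_receiver, n_sender = coupling_filters.shape
    # Stack each (receiver, sender) filter as one column: (window_size, n_receiver * n_sender)
    targets = coupling_filters.reshape(window_size, -1)
    # Solve min_W ||eval_basis @ W - targets||^2 for all columns at once
    weights, *_ = np.linalg.lstsq(eval_basis, targets, rcond=None)
    # weights is (n_basis_funcs, n_receiver * n_sender); move the basis axis last
    return weights.T.reshape(n_receiver, n_sender, -1)


# Filters that lie exactly in the span of the basis are recovered exactly
rng = np.random.default_rng(0)
basis = rng.normal(size=(100, 5))
true_w = rng.normal(size=(2, 3, 5))
filters = np.einsum("tb,rsb->trs", basis, true_w)
w = ols_filter_weights(filters, basis)
print(w.shape)  # (2, 3, 5)
print(np.allclose(w, true_w))  # True
```

Solving all columns in a single `lstsq` call is equivalent to fitting each filter separately, because OLS decouples across target columns when they share the same design matrix.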

Parameters:
  • coupling_filters (NDArray) – The coupling filters, with shape (window_size, n_neurons_receiver, n_neurons_sender).

  • eval_basis (NDArray) – The basis functions evaluated on window_size samples, with shape (window_size, n_basis_funcs).

Returns:

weights – The weights for each pair of receiver and sender neurons, with shape (n_neurons_receiver, n_neurons_sender, n_basis_funcs).

Return type:

NDArray
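Assuming the weights are laid out as (n_neurons_receiver, n_neurons_sender, n_basis_funcs), as the `[0, 0]` indexing in the example implies, the approximated filters can be rebuilt for all neuron pairs at once with `numpy.einsum`. The random `weights` below are only a stand-in for the output of `regress_filter`, used to illustrate the shapes.

```python
import numpy as np

rng = np.random.default_rng(1)
window_size, n_receiver, n_sender, n_basis_funcs = 50, 2, 3, 4
eval_basis = rng.normal(size=(window_size, n_basis_funcs))
# Stand-in for regress_filter output (random values, shape illustration only)
weights = rng.normal(size=(n_receiver, n_sender, n_basis_funcs))

# Sum over the basis axis: the result has the same shape as coupling_filters
approx_filters = np.einsum("tb,rsb->trs", eval_basis, weights)
print(approx_filters.shape)  # (50, 2, 3)

# For a single receiver/sender pair this matches eval_basis.dot(weights[r, s])
assert np.allclose(approx_filters[:, 1, 2], eval_basis.dot(weights[1, 2]))
```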

Raises:
  • ValueError – If eval_basis is not two-dimensional.

  • ValueError – If coupling_filters is not three-dimensional.

  • ValueError – If window_size differs between eval_basis and coupling_filters.
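The three documented error conditions can be mirrored by a small validation helper. This is a hypothetical sketch: `check_regress_filter_inputs`, the order of the checks, and the error messages are illustrative and are not taken from nemos.

```python
import numpy as np


def check_regress_filter_inputs(coupling_filters: np.ndarray, eval_basis: np.ndarray) -> None:
    """Hypothetical re-implementation of the three documented shape checks."""
    if eval_basis.ndim != 2:
        raise ValueError(f"eval_basis must be 2-dimensional, got ndim={eval_basis.ndim}.")
    if coupling_filters.ndim != 3:
        raise ValueError(f"coupling_filters must be 3-dimensional, got ndim={coupling_filters.ndim}.")
    if coupling_filters.shape[0] != eval_basis.shape[0]:
        raise ValueError(
            "window_size mismatch: coupling_filters has "
            f"{coupling_filters.shape[0]} samples, eval_basis has {eval_basis.shape[0]}."
        )


# A 1-D filter (missing the two neuron axes) is rejected
try:
    check_regress_filter_inputs(np.zeros(100), np.zeros((100, 10)))
except ValueError as err:
    print(err)  # coupling_filters must be 3-dimensional, got ndim=1.
```

A single filter therefore needs to be reshaped to (window_size, 1, 1) before fitting, as the example does.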

Examples

>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from nemos.simulation import regress_filter, difference_of_gammas
>>> from nemos.basis import RaisedCosineLogEval
>>> filter_duration = 100
>>> n_basis_funcs = 10
>>> # Reshape the single filter to (window_size, n_neurons_receiver, n_neurons_sender)
>>> filter_bank = difference_of_gammas(filter_duration).reshape(filter_duration, 1, 1)
>>> _, basis = RaisedCosineLogEval(n_basis_funcs).evaluate_on_grid(filter_duration)
>>> weights = regress_filter(filter_bank, basis)[0, 0]
>>> print("Weights shape:", weights.shape)
Weights shape: (10,)
>>> _ = plt.plot(filter_bank[:, 0, 0], label="True filter")
>>> _ = plt.plot(basis.dot(weights), "--", label="Approx. filter")
>>> _ = plt.legend()
>>> _ = plt.title("True vs. Approximated Filters")
>>> _ = plt.show()