%matplotlib inline
import warnings
# Ignore the first specific warning
warnings.filterwarnings(
    "ignore",
    message="plotting functions contained within `_documentation_utils` are intended for nemos's documentation.",
    category=UserWarning,
)
# Ignore the second specific warning
warnings.filterwarnings(
    "ignore",
    message="Ignoring cached namespace 'core'",
    category=UserWarning,
)
# Ignore the 0/0 warning raised when computing tuning curves. The pattern is matched
# against the start of the message, so it must not end with a stray space.
warnings.filterwarnings(
    "ignore",
    message="invalid value encountered in divide",
    category=RuntimeWarning,
)
FeaturePytree example
This small example notebook shows how to use our custom FeaturePytree objects instead of arrays to represent the design matrix. It will show that these two representations are equivalent.
This demo will fit a Poisson GLM to recordings streamed from the DANDI archive. We will first show the simple case, with a single neuron receiving a single type of input (head direction). We will then add a second type of input (spatial position) to the same model, to demonstrate how FeaturePytrees make it easier to examine separate types of inputs.
First, however, let’s briefly discuss FeaturePytrees.
import jax
import jax.numpy as jnp
import numpy as np
import nemos as nmo
np.random.seed(111)
FeaturePytrees
A FeaturePytree is a custom NeMoS object used to represent design matrices, GLM coefficients, and other similar variables. It is a simple pytree, a dictionary with strings as keys and arrays as values. These arrays must all have the same number of elements along the first dimension, which represents the time points, but can have different numbers of elements along the other dimensions (and even different numbers of dimensions).
example_pytree = nmo.pytrees.FeaturePytree(feature_0=np.random.normal(size=(100, 1, 2)),
                                           feature_1=np.random.normal(size=(100, 2)),
                                           feature_2=np.random.normal(size=(100, 5)))
example_pytree
feature_0: shape (100, 1, 2), dtype float64
feature_1: shape (100, 2), dtype float64
feature_2: shape (100, 5), dtype float64
FeaturePytrees can be indexed like a dictionary, so we can grab a single feature:
example_pytree['feature_0'].shape
(100, 1, 2)
We can get the number of time points from the length or from the shape attribute:
print(len(example_pytree))
print(example_pytree.shape)
100
(100,)
We can also jointly index into the FeaturePytree’s leaves:
example_pytree[:10]
feature_0: shape (10, 1, 2), dtype float64
feature_1: shape (10, 2), dtype float64
feature_2: shape (10, 5), dtype float64
We can add new features after initialization, as long as they have the same number of time points.
example_pytree['feature_3'] = np.zeros((100, 2, 4))
However, if we try to add a new feature with the wrong number of time points, we’ll get an exception:
try:
    example_pytree['feature_4'] = np.zeros((99, 2, 4))
except ValueError as e:
    print(e)
All arrays must have same number of time points, 100
Similarly, if we try to add a feature that’s not an array:
try:
    example_pytree['feature_4'] = "Strings are very predictive"
except ValueError as e:
    print(e)
All values must be arrays of at least 1 dimension!
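The behavior shown above (a shared time axis, array-only values, joint slicing) can be mimicked with a small container class. The following is a hypothetical, simplified sketch called MiniFeaturePytree; it is not how NeMoS implements FeaturePytree, it only mirrors the checks and error messages printed above.

```python
import numpy as np


class MiniFeaturePytree:
    """Hypothetical, simplified stand-in for nmo.pytrees.FeaturePytree.

    Illustration only: it mimics the checks shown above, but it is not the NeMoS
    implementation and it is not registered with JAX, so
    jax.tree_util.tree_map would treat it as a single leaf.
    """

    def __init__(self, **features):
        self._data = {}
        for key, value in features.items():
            self[key] = value

    def __len__(self):
        # Number of time points shared by every feature (0 when empty).
        if not self._data:
            return 0
        return next(iter(self._data.values())).shape[0]

    @property
    def shape(self):
        return (len(self),)

    def __setitem__(self, key, value):
        if not isinstance(key, str):
            raise ValueError("Keys must be strings!")
        if not isinstance(value, np.ndarray) or value.ndim < 1:
            raise ValueError("All values must be arrays of at least 1 dimension!")
        if self._data and value.shape[0] != len(self):
            raise ValueError(f"All arrays must have same number of time points, {len(self)}")
        self._data[key] = value

    def __getitem__(self, key):
        if isinstance(key, str):
            return self._data[key]
        # A slice (or index array) is applied jointly to every feature along time.
        return MiniFeaturePytree(**{k: v[key] for k, v in self._data.items()})

    def __repr__(self):
        return "\n".join(f"{k}: shape {v.shape}, dtype {v.dtype}" for k, v in self._data.items())


rng = np.random.default_rng(111)
pt = MiniFeaturePytree(feature_0=rng.normal(size=(100, 1, 2)),
                       feature_1=rng.normal(size=(100, 2)))
print(len(pt), pt.shape)  # 100 (100,)
print(pt[:10])
try:
    pt["feature_4"] = np.zeros((99, 2, 4))
except ValueError as e:
    print(e)  # All arrays must have same number of time points, 100
```

The time-axis check lives in __setitem__, so it applies both at construction and when features are added later, which is the behavior demonstrated above.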
FeaturePytrees are intended to be used with jax.tree_util.tree_map, a useful function for performing computations on arbitrary pytrees, preserving their structure.
We can map lambda functions:
mapped = jax.tree_util.tree_map(lambda x: x**2, example_pytree)
print(mapped)
mapped['feature_1']
feature_0: shape (100, 1, 2), dtype float64
feature_1: shape (100, 2), dtype float64
feature_2: shape (100, 5), dtype float64
feature_3: shape (100, 2, 4), dtype float64
array([[5.45183002e+00, 1.15606483e+00],
[3.54531739e+00, 1.55395816e+00],
[2.02448321e-01, 1.03200557e+00],
[5.96191117e-01, 1.48684966e+00],
[6.10952753e-01, 3.04129026e-01],
[2.36076038e+00, 2.66317025e-01],
[5.59943626e-01, 1.90938078e+00],
[5.05106186e-01, 6.84611471e-03],
[2.56259393e-01, 2.41880785e+00],
[1.72284458e+00, 3.02600901e-01],
[1.03796246e+00, 6.18155410e-03],
[4.80876143e-01, 3.55152103e-01],
[1.45874142e-02, 7.38642441e-02],
[1.25311157e+00, 7.41861244e-01],
[8.11542614e-01, 8.13309176e-01],
[1.28633116e+00, 9.93578902e-02],
[3.45997654e-03, 1.22878629e-01],
[2.36001666e+00, 3.11904966e-02],
[2.15261893e+00, 2.33474397e+00],
[6.62998961e-01, 3.30439030e+00],
[6.69486087e-03, 8.04013247e-03],
[3.61061077e-02, 9.85351550e-04],
[4.51987003e-01, 1.14217081e+00],
[1.52720281e-01, 1.33155885e-01],
[2.81435381e+00, 3.56236727e-02],
[5.12260166e+00, 7.55248484e-01],
[3.56086895e+00, 9.67230545e-01],
[3.27970828e+00, 1.29043215e-01],
[1.61946543e-02, 2.08260854e-01],
[5.09534456e-03, 2.58362331e+00],
[1.45801179e-01, 4.87087146e-01],
[3.15617764e-03, 3.30445072e-01],
[3.41234180e+00, 1.40705753e+00],
[3.95771584e-01, 5.80184990e-05],
[2.04865618e-01, 6.48252939e-02],
[4.38347939e-03, 1.27892735e-01],
[4.73308144e-02, 1.87656627e-01],
[1.42017591e+00, 5.75539014e-02],
[4.91673517e-01, 4.45696191e-02],
[3.78509065e-01, 1.47044724e+00],
[1.38674014e-01, 1.90322132e+00],
[3.85825658e-01, 3.25337246e-01],
[1.59406146e-02, 2.07207541e-03],
[6.85010939e-01, 2.56408357e-02],
[1.75284457e+00, 1.16418291e+00],
[2.20364349e+00, 3.17292157e-03],
[2.38381565e-01, 2.03698029e-02],
[3.73722775e-01, 1.41370260e-01],
[3.21101442e-01, 1.16330709e-01],
[3.87783673e+00, 4.65788745e-01],
[1.40069075e+00, 3.59942720e-02],
[1.82969334e+00, 1.18764313e-01],
[7.54854859e-03, 1.97889909e+00],
[1.83973372e-01, 9.71110480e-03],
[1.97530629e-01, 8.57962728e-01],
[3.85370313e+00, 1.01884023e-01],
[3.21434145e-03, 1.78591423e+00],
[1.09249406e+00, 4.99992884e-01],
[1.09415779e-01, 3.42741401e-02],
[7.08140528e-03, 5.40930142e-02],
[1.44745594e+00, 9.81139281e-01],
[1.86987979e-01, 4.61590646e-04],
[1.10109916e+00, 1.26030420e+00],
[1.15222704e+00, 7.64796089e-01],
[5.83946680e-01, 6.71690866e+00],
[1.52963209e+00, 3.14285317e-01],
[1.32872586e-01, 2.91653213e-01],
[5.26652787e-01, 8.21635743e-05],
[9.48471483e-01, 1.93025655e+00],
[3.12346476e+00, 5.96762432e+00],
[2.80093035e+00, 3.69544883e-02],
[2.47971319e+00, 1.00992879e+00],
[8.56730019e-02, 1.00640253e+00],
[4.47497682e+00, 6.82710821e-01],
[1.40868807e-03, 1.39808960e+00],
[4.37884824e+00, 1.10675124e+00],
[1.55902889e+00, 4.78874515e-01],
[3.13648564e-01, 2.16344654e+00],
[2.65269372e-01, 1.06181600e+00],
[1.17902638e-01, 1.84137767e-01],
[7.57807154e-03, 4.43745213e-02],
[1.47401168e-03, 1.40032894e+00],
[2.73849655e-01, 5.40856687e+00],
[3.92893202e-02, 2.91925413e-01],
[1.40287820e+00, 2.46812577e+00],
[4.26857463e-01, 1.33300314e+00],
[4.30655331e-03, 1.12071288e+00],
[3.26776277e+00, 9.95189833e+00],
[6.58230551e-01, 1.99168858e-01],
[7.28313356e-01, 1.04781500e+00],
[2.43742610e-01, 5.11278378e-01],
[8.21537461e+00, 8.79373531e-01],
[9.78346048e-01, 3.19456819e+00],
[3.94546914e+00, 9.95753401e-01],
[1.05538878e+00, 1.64826127e-01],
[8.97150940e-01, 3.78314870e+00],
[8.27701146e-02, 4.64614023e-02],
[1.19659435e-01, 3.30443656e+00],
[6.68192017e-01, 8.40939003e-03],
[2.52590088e+00, 8.33983949e-02]])
Or functions from jax or numpy that operate on arrays:
mapped = jax.tree_util.tree_map(jnp.exp, example_pytree)
print(mapped)
mapped['feature_1']
feature_0: shape (100, 1, 2), dtype float32
feature_1: shape (100, 2), dtype float32
feature_2: shape (100, 5), dtype float32
feature_3: shape (100, 2, 4), dtype float32
Array([[ 0.09681867, 0.34122792],
[ 0.15214804, 3.4784214 ],
[ 1.5682222 , 0.3620848 ],
[ 2.1643803 , 0.29541788],
[ 0.45765728, 1.7358183 ],
[ 0.21513778, 0.5968681 ],
[ 0.47317317, 3.9820764 ],
[ 2.0354323 , 0.9205893 ],
[ 1.6590095 , 0.21113622],
[ 3.7157173 , 0.5768969 ],
[ 0.36102632, 1.0817963 ],
[ 0.49984744, 1.8147476 ],
[ 1.1283748 , 1.3122979 ],
[ 0.32646757, 0.42260653],
[ 2.461711 , 0.40582365],
[ 3.108578 , 0.729635 ],
[ 0.94287497, 0.7043073 ],
[ 0.21518984, 0.838108 ],
[ 0.23057465, 0.21697202],
[ 2.2574763 , 6.1582847 ],
[ 0.92143583, 0.9142358 ],
[ 0.82694584, 0.96909726],
[ 0.5105331 , 0.3434465 ],
[ 1.4781547 , 1.4403776 ],
[ 0.18682113, 1.2077297 ],
[ 0.10400496, 0.4193496 ],
[ 6.5997148 , 0.37400776],
[ 6.11654 , 1.4322202 ],
[ 1.1357102 , 0.6335882 ],
[ 0.93110645, 4.989648 ],
[ 0.6826048 , 0.49762112],
[ 0.94536906, 0.5627929 ],
[ 0.15766977, 0.30538118],
[ 0.5330693 , 1.0076461 ],
[ 0.6359592 , 1.289956 ],
[ 0.93593633, 1.4299235 ],
[ 1.2430356 , 1.5421746 ],
[ 3.2927113 , 1.271127 ],
[ 0.49599257, 1.2350546 ],
[ 0.5405161 , 3.3622823 ],
[ 1.4511982 , 3.9732041 ],
[ 1.8610646 , 1.7689452 ],
[ 1.1345727 , 0.9555004 ],
[ 0.4370735 , 0.8520351 ],
[ 0.26608208, 2.941657 ],
[ 4.412615 , 0.9452284 ],
[ 1.6294513 , 0.8669944 ],
[ 0.54262936, 0.6866076 ],
[ 0.5674185 , 1.4064558 ],
[ 7.165102 , 0.5053585 ],
[ 0.30620277, 1.208913 ],
[ 3.867706 , 0.7084881 ],
[ 1.0907683 , 0.24494208],
[ 1.5356 , 1.103564 ],
[ 1.5596231 , 2.525055 ],
[ 0.14042453, 0.7267355 ],
[ 1.0583332 , 0.26279497],
[ 2.8440366 , 0.49307117],
[ 1.3920543 , 0.83099395],
[ 1.0877932 , 1.2618501 ],
[ 3.3304338 , 0.37138176],
[ 0.6489361 , 1.0217171 ],
[ 0.35017133, 0.32542193],
[ 0.341838 , 0.41705957],
[ 2.1471987 , 0.07489263],
[ 0.2903167 , 0.57085985],
[ 0.6945321 , 0.58271956],
[ 0.48398155, 1.0091056 ],
[ 2.6482394 , 0.24924058],
[ 0.17078793, 11.506039 ],
[ 5.3313155 , 1.2119559 ],
[ 4.8293433 , 0.36606216],
[ 1.3400401 , 2.7269835 ],
[ 0.12058334, 0.43768176],
[ 1.0382457 , 0.3065396 ],
[ 0.12336969, 2.8634365 ],
[ 3.4854972 , 1.9977221 ],
[ 0.57118434, 0.22972646],
[ 1.6737106 , 0.3568483 ],
[ 0.709376 , 0.6510865 ],
[ 0.9166293 , 0.81005543],
[ 0.9623348 , 0.30624956],
[ 1.6875985 , 0.09772161],
[ 0.82019323, 0.58257276],
[ 3.2688282 , 4.811587 ],
[ 0.5203033 , 0.31519696],
[ 0.9364826 , 0.34692827],
[ 0.16403173, 0.04265278],
[ 2.250864 , 0.64000237],
[ 0.42595875, 2.7832751 ],
[ 0.61036205, 2.0442631 ],
[17.570957 , 0.39150804],
[ 2.6888504 , 0.16740564],
[ 0.13719933, 2.71251 ],
[ 2.7935724 , 1.5007843 ],
[ 0.38783297, 6.993854 ],
[ 0.7499881 , 1.2405429 ],
[ 1.413287 , 0.16238083],
[ 0.44156495, 0.91237634],
[ 4.900358 , 1.3348085 ]], dtype=float32)
We can change the dimensionality of our pytree:
mapped = jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=-1), example_pytree)
print(mapped)
mapped['feature_1']
feature_0: shape (100, 1), dtype float32
feature_1: shape (100,), dtype float32
feature_2: shape (100,), dtype float32
feature_3: shape (100, 2), dtype float32
Array([-1.70506001e+00, -3.18161368e-01, -2.82967091e-01, -2.23615140e-01,
-1.15077883e-01, -1.02626789e+00, 3.16754788e-01, 3.13983470e-01,
-5.24515510e-01, 3.81240010e-01, -4.70090777e-01, -4.87529933e-02,
1.96279079e-01, -9.90369201e-01, -4.89979982e-04, 4.09477264e-01,
-2.04681024e-01, -8.56421471e-01, -1.49758375e+00, 1.31602287e+00,
-8.57444555e-02, -1.10703193e-01, -8.70511889e-01, 3.77849877e-01,
-7.44430661e-01, -1.56618345e+00, 4.51773822e-01, 1.08511114e+00,
-1.64548904e-01, 7.67991841e-01, -5.39877772e-01, -3.15511703e-01,
-1.51672351e+00, -3.10743392e-01, -9.90063548e-02, 1.45706534e-01,
3.25374991e-01, 7.15807617e-01, -2.45039582e-01, 2.98694551e-01,
8.75981212e-01, 5.95766068e-01, 4.03680503e-02, -4.93890733e-01,
-1.22488678e-01, 7.14069366e-01, 1.72760263e-01, -4.93660539e-01,
-1.12792626e-01, 6.43367529e-01, -4.96893108e-01, 5.04019797e-01,
-6.59925580e-01, 2.63733059e-01, 6.85353518e-01, -1.14113891e+00,
-6.39842987e-01, 1.69061333e-01, 7.28239045e-02, 1.58365071e-01,
1.06288910e-01, -2.05468193e-01, -1.08598280e+00, -9.73972321e-01,
-9.13767934e-01, -8.98697257e-01, -4.52283084e-01, -3.58322024e-01,
-2.07720846e-01, 3.37769687e-01, 9.32916760e-01, 2.84879208e-01,
6.47947788e-01, -1.47083867e+00, -5.72437942e-01, -5.20273685e-01,
9.70309138e-01, -1.01545465e+00, -2.57700711e-01, -3.86241138e-01,
-1.48852363e-01, -6.10873938e-01, -9.01163042e-01, -3.69258285e-01,
1.37772918e+00, -9.03950453e-01, -5.62130809e-01, -2.48117924e+00,
1.82515368e-01, 8.51078033e-02, 1.10667199e-01, 9.64249134e-01,
-3.99110883e-01, -4.94223028e-01, 7.16654539e-01, 4.98925626e-01,
-3.60744521e-02, -7.35946298e-01, -4.54566419e-01, 9.39048052e-01], dtype=float32)
Or the number of time points:
mapped = jax.tree_util.tree_map(lambda x: x[::10], example_pytree)
print(mapped)
mapped['feature_1']
feature_0: shape (10, 1, 2), dtype float64
feature_1: shape (10, 2), dtype float64
feature_2: shape (10, 5), dtype float64
feature_3: shape (10, 2, 4), dtype float64
array([[-2.33491542, -1.07520455],
[-1.01880443, 0.07862286],
[-0.08182213, -0.08966679],
[-0.38183921, -0.69791629],
[ 0.3723896 , 1.37957288],
[-1.18350782, 0.18972156],
[ 1.20310263, -0.99052475],
[ 1.67359803, 0.1922355 ],
[-0.08705212, -0.21065261],
[-0.49370296, 0.71503733]])
If we map a function whose outputs cannot form a FeaturePytree (because they are scalars or non-arrays), we get a plain dictionary instead:
print(jax.tree_util.tree_map(jnp.mean, example_pytree))
print(jax.tree_util.tree_map(lambda x: x.shape, example_pytree))
import matplotlib.pyplot as plt
import pynapple as nap
nap.nap_config.suppress_conversion_warnings = True
{'feature_0': Array(0.01142166, dtype=float32), 'feature_1': Array(-0.16257457, dtype=float32), 'feature_2': Array(0.00767873, dtype=float32), 'feature_3': Array(0., dtype=float32)}
{'feature_0': (100, 1, 2), 'feature_1': (100, 2), 'feature_2': (100, 5), 'feature_3': (100, 2, 4)}
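For a flat pytree like this one, mapping a function amounts to applying it to every value and keeping the keys. The sketch below uses a hypothetical helper, map_flat, on a plain dict of NumPy arrays to make that explicit; nested pytrees, which jax.tree_util.tree_map also handles, would need recursion.

```python
import numpy as np


def map_flat(f, features):
    """Apply f to each value of a flat {name: array} dict, keeping the names.

    For a flat dict of arrays this yields the same keys and values as
    jax.tree_util.tree_map(f, features); nested pytrees would need recursion.
    """
    return {key: f(value) for key, value in features.items()}


rng = np.random.default_rng(0)
features = {"feature_1": rng.normal(size=(100, 2)), "feature_2": rng.normal(size=(100, 5))}

squared = map_flat(lambda x: x ** 2, features)         # same keys, same shapes
means = map_flat(lambda x: x.mean(axis=-1), features)  # drops the last axis
shapes = map_flat(lambda x: x.shape, features)         # values are no longer arrays
print(shapes)  # {'feature_1': (100, 2), 'feature_2': (100, 5)}
```

The last call shows why a plain dictionary is returned in that case: tuples of shapes share no time axis, so there is nothing for a FeaturePytree to validate.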
FeaturePytrees and GLM
These properties make FeaturePytrees useful for representing design matrices and similar objects for the GLM.
First, let’s get our dataset and do some initial exploration of it. To do so, we’ll use pynapple to stream data from the DANDI archive.
Attention
We need some additional packages for this portion, which you can install with pip install dandi pynapple
io = nmo.fetch.download_dandi_data(
"000582",
"sub-11265/sub-11265_ses-07020602_behavior+ecephys.nwb",
)
nwb = nap.NWBFile(io.read(), lazy_loading=False)
print(nwb)
07020602
┍━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━┑
│ Keys │ Type │
┝━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━┥
│ units │ TsGroup │
│ ElectricalSeriesLFP │ Tsd │
│ SpatialSeriesLED2 │ TsdFrame │
│ SpatialSeriesLED1 │ TsdFrame │
│ ElectricalSeries │ Tsd │
┕━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━┙
This dataset has cells that are tuned for head direction and 2D position. Let’s compute some simple tuning curves to see if we can find a cell that looks tuned for both.
tc, binsxy = nap.compute_2d_tuning_curves(nwb['units'], nwb['SpatialSeriesLED1'].dropna(), 20)
fig, axes = plt.subplots(3, 3, figsize=(9, 9))
for i, ax in zip(tc.keys(), axes.flatten()):
    ax.imshow(tc[i], origin="lower", aspect="auto")
    ax.set_title("Unit {}".format(i))
axes[-1,-1].remove()
plt.tight_layout()
# compute head direction.
diff = nwb['SpatialSeriesLED1'].values-nwb['SpatialSeriesLED2'].values
head_dir = np.arctan2(*diff.T)
head_dir = nap.Tsd(nwb['SpatialSeriesLED1'].index, head_dir)
tune_head = nap.compute_1d_tuning_curves(nwb['units'], head_dir.dropna(), 30)
fig, axes = plt.subplots(3, 3, figsize=(9, 9), subplot_kw={'projection': 'polar'})
for i, ax in zip(tune_head.columns, axes.flatten()):
    ax.plot(tune_head.index, tune_head[i])
    ax.set_title("Unit {}".format(i))
axes[-1,-1].remove()
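Each tuning curve is, roughly, the number of spikes in a feature bin divided by the time the animal spent in that bin (occupancy). Bins the animal never visits have zero occupancy, so the division gives 0/0 = NaN, which is what triggers pynapple’s "invalid value encountered in divide" RuntimeWarning. The following is a simplified NumPy sketch of that computation using a hypothetical helper, tuning_curve_1d; it is not pynapple’s implementation, and its binning details may differ.

```python
import numpy as np


def tuning_curve_1d(spike_counts, feature, n_bins, sampling_rate, feature_range):
    """Firing rate (spikes/s) in each feature bin: spikes in bin / time spent in bin.

    Simplified illustration, not pynapple's implementation. Bins that are never
    visited have zero occupancy, so they come out as 0/0 = NaN.
    """
    edges = np.linspace(feature_range[0], feature_range[1], n_bins + 1)
    # Bin index 0..n_bins-1 for every sample (values outside the range go to the end bins).
    idx = np.digitize(feature, edges[1:-1])
    spikes_per_bin = np.bincount(idx, weights=spike_counts, minlength=n_bins)
    occupancy_s = np.bincount(idx, minlength=n_bins) / sampling_rate
    with np.errstate(invalid="ignore"):  # silence the 0/0 warning; the NaNs remain
        rates = spikes_per_bin / occupancy_s
    return rates, edges


rate = 100.0  # hypothetical sampling rate (Hz)
feature = np.linspace(-np.pi, 0.0, 1000, endpoint=False)  # only half the circle is visited
counts = np.ones_like(feature)  # one spike per sample, i.e. 100 spikes/s wherever visited
tc, edges = tuning_curve_1d(counts, feature, n_bins=4, sampling_rate=rate,
                            feature_range=(-np.pi, np.pi))
print(tc)  # about 100 spikes/s in the two visited bins, NaN in the two unvisited ones
```

The NaN entries are expected rather than a bug: they mark feature values for which there is no data to estimate a rate.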


# save image for thumbnail
from pathlib import Path
import os
root = os.environ.get("READTHEDOCS_OUTPUT")
if root:
    path = Path(root) / "html/_static/thumbnails/how_to_guide"
else:
    # if local, store in ../_build/html/...
    path = Path("../_build/html/_static/thumbnails/how_to_guide")
# make sure the folder exists if run from build
if root or Path("../assets/stylesheets").exists():
    path.mkdir(parents=True, exist_ok=True)
if path.exists():
    fig.savefig(path / "plot_07_glm_pytree.svg")
Okay, let’s use unit number 7.
Now let’s set up our design matrix. First, let’s fit the head direction by itself. Head direction is a circular variable (pi and -pi are adjacent to each other), so we need to use a basis that has this property as well. CyclicBSplineEval is one such basis.
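To see why a cyclic basis matters, the sketch below compares a simple periodic basis with a non-periodic one at the two ends of the circle. It swaps in von Mises-style bumps, exp(kappa * (cos(theta - mu) - 1)), instead of the cyclic B-splines that CyclicBSplineEval uses; the helper names, the number of functions and the widths are illustrative assumptions.

```python
import numpy as np


def periodic_bumps(theta, n_basis=10, kappa=4.0):
    """Von Mises-style bumps: a simple periodic basis (illustration only,
    not the cyclic B-splines used by CyclicBSplineEval)."""
    centers = np.linspace(-np.pi, np.pi, n_basis, endpoint=False)
    # cos() is 2*pi-periodic, so every bump wraps around the circle.
    return np.exp(kappa * (np.cos(theta[:, None] - centers[None, :]) - 1.0))


def linear_bumps(theta, n_basis=10, width=0.5):
    """Gaussian bumps on the real line: NOT periodic, shown for contrast."""
    centers = np.linspace(-np.pi, np.pi, n_basis, endpoint=False)
    return np.exp(-0.5 * ((theta[:, None] - centers[None, :]) / width) ** 2)


ends = np.array([-np.pi, np.pi])
cyc = periodic_bumps(ends)
lin = linear_bumps(ends)
print(np.allclose(cyc[0], cyc[1]))  # True: -pi and pi get the same features
print(np.allclose(lin[0], lin[1]))  # False: the linear basis treats them as far apart
```

With the periodic basis, any weighted sum of features is automatically continuous across the -pi/pi boundary, which is the property the text asks for.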
Let’s create our basis and then arrange our data properly.
unit_no = 7
spikes = nwb['units'][unit_no]
basis = nmo.basis.CyclicBSplineEval(10, order=5)
x = np.linspace(-np.pi, np.pi, 100)
plt.figure()
plt.plot(x, basis.compute_features(x))
# Find the interval on which head_dir has no NaNs
head_dir = head_dir.dropna()
# Keep the second of the two epochs, since the first one is very short
valid_data = head_dir.time_support.loc[[1]]
head_dir = head_dir.restrict(valid_data)
# Count spikes at the same rate as head direction, over the same epoch
spikes = spikes.count(bin_size=1/head_dir.rate, ep=valid_data)
# the time points for spike are in the middle of these bins (whereas for
# head_dir they're at the ends), so use interpolate to shift head_dir to the
# center.
head_dir = head_dir.interpolate(spikes)
X = nmo.pytrees.FeaturePytree(head_direction=basis.compute_features(head_dir))
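The binning and alignment steps above can be sketched with NumPy alone. Assumptions: a hypothetical 50 Hz sampling rate, a synthetic wrapped angle and random spike times. Because the angle is circular, the sketch unwraps it before linear interpolation, so that a sample near pi followed by one near -pi is not interpolated through 0, and then wraps it back.

```python
import numpy as np

rng = np.random.default_rng(1)
rate_hz = 50.0                                           # hypothetical sampling rate
sample_times = np.arange(500) / rate_hz                  # behavior sampled at these times
head_angle = np.angle(np.exp(1j * 2.0 * sample_times))   # synthetic angle in (-pi, pi]
spike_times = np.sort(rng.uniform(0.0, 9.9, size=300))   # synthetic spike train

# Count spikes in bins of width 1/rate_hz whose left edges are the sample times.
bin_edges = np.append(sample_times, sample_times[-1] + 1.0 / rate_hz)
counts, _ = np.histogram(spike_times, bins=bin_edges)
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

# Move the angle from the bin edges to the bin centers: unwrap, interpolate, re-wrap.
unwrapped = np.unwrap(head_angle)
angle_at_centers = np.angle(np.exp(1j * np.interp(bin_centers, sample_times, unwrapped)))
```

After this step, counts and angle_at_centers share the same time points, one row per bin, which is the alignment the GLM design matrix needs.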

Now we’ll fit our GLM and then see what our head direction tuning looks like:
model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.001)
model.fit(X, spikes)
print(model.coef_['head_direction'])
bs_vis = basis.compute_features(x)
tuning = jnp.einsum('b, tb->t', model.coef_['head_direction'], bs_vis)
plt.figure()
plt.polar(x, tuning)
[-0.24226123 -0.6421874 -0.88985455 -1.005565 -0.69316936 -0.25309175
0.30488047 1.1731167 1.378537 0.6832005 ]
[<matplotlib.lines.Line2D at 0x7f23bdc53990>]

This looks like a smoothed version of our tuning curve, like we’d expect!
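The einsum call above is a matrix-vector product: it sums the basis functions weighted by their coefficients, giving the head-direction contribution to the linear predictor. Assuming the model uses an exponential inverse link (the usual choice for a Poisson GLM), adding the intercept and exponentiating turns this into a predicted rate per time bin. The arrays below are random stand-ins, not the fitted model.

```python
import numpy as np

rng = np.random.default_rng(2)
n_points, n_basis = 100, 10
bs_vis = rng.random((n_points, n_basis))  # stand-in for basis.compute_features(x)
coef = rng.normal(size=n_basis)           # stand-in for model.coef_['head_direction']
intercept = -1.0                          # hypothetical stand-in for the fitted intercept

# 'b,tb->t' sums over the basis index b: a matrix-vector product.
tuning_einsum = np.einsum("b,tb->t", coef, bs_vis)
tuning_matmul = bs_vis @ coef
print(np.allclose(tuning_einsum, tuning_matmul))  # True

# Assuming an exponential inverse link, the predicted rate per time bin is:
rate_per_bin = np.exp(tuning_matmul + intercept)
```

Either form works; einsum just makes the summed index explicit, which becomes handy for the 2D spatial basis later.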
For a more direct comparison, we can compute the tuning function from the model’s predicted firing rates and plot it against the one estimated from the spike counts.
# predict rates and convert back to pynapple
rates_nap = nap.TsdFrame(t=head_dir.t, d=np.asarray(model.predict(X)))
# compute tuning function
tune_head_model = nap.compute_1d_tuning_curves_continuous(rates_nap, head_dir, 30)
# compare model prediction with data
fig, ax = plt.subplots(1, 1, subplot_kw={'projection': 'polar'})
ax.plot(tune_head[7], label="counts")
# multiply by the sampling rate for converting to spike/sec.
ax.plot(tune_head_model * rates_nap.rate, label="model")
ax.legend()
# Let's compare this to using arrays, to see what it looks like:
model = nmo.glm.GLM()
model.fit(X['head_direction'], spikes)
model.coef_
Array([ 0.05909842, -0.38455936, -1.3328478 , -1.8809333 , -1.0995328 ,
0.02065016, 0.65151244, 1.3542147 , 1.2034482 , 1.0014157 ], dtype=float32)

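The way of interacting with the GLM object is identical. The coefficients are similar to, but not identical to, the ones above, because this model is fit without the Ridge penalty (regularizer_strength=0.001) used earlier; passing the same regularizer settings makes the two fits directly comparable.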
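To see why a Ridge penalty moves the coefficients, here is a minimal NumPy sketch of a penalized Poisson GLM objective, fit by plain gradient descent on simulated data. The penalty scaling, the intercept handling and the solver NeMoS uses are assumptions here and may differ; the point is only that the penalty pulls the weights toward zero, so penalized and unpenalized fits give different coefficients on the same design matrix.

```python
import numpy as np


def poisson_ridge_loss_grad(params, X, y, alpha):
    """Mean Poisson negative log-likelihood with an exponential inverse link
    (the constant log(y!) term is dropped), plus 0.5 * alpha * ||w||^2 on the
    weights only. Returns (loss, gradient w.r.t. [w, b])."""
    w, b = params[:-1], params[-1]
    eta = X @ w + b          # linear predictor (log rate)
    rate = np.exp(eta)       # predicted rate per time bin
    n = X.shape[0]
    loss = np.mean(rate - y * eta) + 0.5 * alpha * np.sum(w ** 2)
    resid = rate - y         # derivative of the per-sample NLL w.r.t. eta
    grad_w = X.T @ resid / n + alpha * w
    grad_b = np.mean(resid)
    return loss, np.append(grad_w, grad_b)


def fit(X, y, alpha, lr=0.1, n_iter=5000):
    """Plain gradient descent; adequate for this small, well-conditioned toy problem."""
    params = np.zeros(X.shape[1] + 1)
    for _ in range(n_iter):
        _, grad = poisson_ridge_loss_grad(params, X, y, alpha)
        params = params - lr * grad
    return params


rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 3))               # hypothetical design matrix
w_true = np.array([0.5, -0.3, 0.2])
y = rng.poisson(np.exp(X @ w_true + 0.1))   # simulated spike counts

params_mle = fit(X, y, alpha=0.0)    # unpenalized fit
params_ridge = fit(X, y, alpha=1.0)  # Ridge-penalized fit
print(np.linalg.norm(params_mle[:-1]), np.linalg.norm(params_ridge[:-1]))
```

The penalized weights have a smaller norm than the unpenalized ones; increasing alpha shrinks them further, while alpha=0 recovers w_true closely in this simulation.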
However, with a single type of feature, it’s unclear why exactly this is helpful. Let’s add a feature for the animal’s position in space. For this feature, we need a 2d basis. Let’s use some raised cosine bumps and organize our data similarly.
pos_basis = nmo.basis.RaisedCosineLinearEval(10) * nmo.basis.RaisedCosineLinearEval(10)
spatial_pos = nwb['SpatialSeriesLED1'].restrict(valid_data)
X['spatial_position'] = pos_basis.compute_features(*spatial_pos.values.T)
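Multiplying two bases, as in RaisedCosineLinearEval(10) * RaisedCosineLinearEval(10), gives 10 × 10 = 100 features, matching the 100 spatial_position coefficients printed below. As we understand it, each product feature is one function from the first basis times one function from the second, evaluated at the same sample. The sketch below implements that row-wise outer product with a hypothetical helper; the column ordering NeMoS uses may differ from this one.

```python
import numpy as np


def product_features(feat_a, feat_b):
    """Row-wise outer product of two design matrices.

    (T, n_a) and (T, n_b) -> (T, n_a * n_b): each output column is one basis
    function from feat_a times one from feat_b, at the same time point.
    """
    n_t = feat_a.shape[0]
    return np.einsum("ti,tj->tij", feat_a, feat_b).reshape(n_t, -1)


rng = np.random.default_rng(3)
feat_x = rng.random((5, 10))  # stand-in for a 1D basis evaluated on x-position
feat_y = rng.random((5, 10))  # stand-in for a 1D basis evaluated on y-position
feat_xy = product_features(feat_x, feat_y)
print(feat_xy.shape)  # (5, 100)
```

In this ordering, column i * 10 + j holds feat_x[:, i] * feat_y[:, j], so each coefficient corresponds to a bump located at one (x, y) grid cell.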
Running the GLM is identical to before, but now coef_ has two separate keys, one for each feature type.
model = nmo.glm.GLM(solver_name="LBFGS")
model.fit(X, spikes)
model.coef_
{'head_direction': Array([-0.09402025, -0.42544287, -1.3601326 , -2.2187333 , -1.0918001 ,
0.4141025 , 0.6911496 , 1.0892535 , 1.446903 , 0.54534346], dtype=float32),
'spatial_position': Array([ 1.00051284e-01, 1.00856972e+00, 6.81885242e-01, -3.09586376e-01,
-7.17706263e-01, 5.74160218e-02, 1.87587768e-01, -9.45991635e-01,
-1.29657471e+00, -1.03216279e+00, -5.22445261e-01, -9.29365605e-02,
-2.48249292e-01, -1.72176972e-01, 4.37599242e-01, 1.12689054e+00,
6.95358634e-01, -1.20578337e+00, -5.93989789e-01, 1.08557773e+00,
-9.13282752e-01, -3.77993882e-01, -1.15488447e-01, -2.82066226e-01,
5.06545424e-01, 6.60092711e-01, -1.07800938e-01, -1.30765152e+00,
-7.75339156e-02, 2.34187341e+00, -1.19679663e-02, 1.35469389e+00,
7.89580643e-01, -2.85468996e-01, -4.47471559e-01, -9.65955973e-01,
-1.13836372e+00, -1.26636922e+00, -7.48979211e-01, 6.96704984e-01,
4.93707880e-02, 4.03982520e-01, 3.99184138e-01, 7.14728773e-01,
-2.25699648e-01, -7.26205468e-01, 4.38113928e-01, 7.43483663e-01,
-3.19963396e-01, -4.08950686e-01, -1.82423159e-01, -8.19022596e-01,
8.89168009e-02, 5.65515220e-01, -7.43878186e-02, 7.43348002e-01,
1.43106616e+00, 8.81953478e-01, -7.00419009e-01, -9.95423317e-01,
9.42733943e-01, 3.29298526e-02, -6.00966573e-01, -7.87943304e-01,
-6.21314704e-01, 7.80552998e-02, 9.34794545e-04, -7.95975327e-01,
-1.39013267e+00, -8.67521346e-01, 1.43283892e+00, 1.86127007e-01,
-9.47542965e-01, 6.25953823e-02, 7.32932866e-01, -3.95648897e-01,
-8.98551702e-01, -6.18883252e-01, -7.80938148e-01, -7.19587743e-01,
5.49517214e-01, -4.59617853e-01, -5.19327939e-01, 1.41857791e+00,
1.49307621e+00, -8.48016813e-02, -6.07651711e-01, 5.51722765e-01,
6.18988872e-01, -7.40070522e-01, -1.60670206e-01, -6.20351851e-01,
-1.55294389e-01, 8.26539397e-01, -1.40935227e-01, -3.99333954e-01,
4.05770004e-01, 1.28821862e+00, 9.58565950e-01, -5.60608864e-01], dtype=float32)}
Let’s visualize our tuning. Head direction looks pretty much the same (though the values are slightly different, as we can see when printing out the coefficients).
bs_vis = basis.compute_features(x)
tuning = jnp.einsum('b,nb->n', model.coef_['head_direction'], bs_vis)
print(model.coef_['head_direction'])
plt.figure()
plt.polar(x, tuning.T)
[-0.09402025 -0.42544287 -1.3601326 -2.2187333 -1.0918001 0.4141025
0.6911496 1.0892535 1.446903 0.54534346]
[<matplotlib.lines.Line2D at 0x7f23c49ec5d0>]

And the spatial tuning again looks like a smoothed version of our earlier tuning curves.
_, _, pos_bs_vis = pos_basis.evaluate_on_grid(50, 50)
pos_tuning = jnp.einsum('b,ijb->ij', model.coef_['spatial_position'], pos_bs_vis)
plt.figure()
plt.imshow(pos_tuning)
<matplotlib.image.AxesImage at 0x7f23b3b5c510>

We could do all this with matrices as well, but then we have to keep track ourselves of which columns belong to which feature:
X_mat = nmo.utils.pynapple_concatenate_jax([X['head_direction'], X['spatial_position']], -1)
model = nmo.glm.GLM()
model.fit(X_mat, spikes)
model.coef_[..., :basis.n_basis_funcs]
Array([-0.14071149, -1.1683542 , -0.874997 , -3.0199964 , -0.82603216,
0.01578746, 0.4781742 , 0.9998034 , 1.1875935 , 0.47495434], dtype=float32)
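The slicing in the last line generalizes: given the number of columns in each block, in concatenation order, the flat coefficient vector can be split back into named pieces. split_coefficients is a hypothetical helper written for illustration; it automates the bookkeeping that the FeaturePytree version handled for us through its keys.

```python
import numpy as np


def split_coefficients(coef, sizes):
    """Split a flat coefficient vector into named blocks.

    `sizes` maps each feature name to its number of columns, in the same order
    the design-matrix blocks were concatenated (hypothetical helper).
    """
    offsets = np.cumsum([0] + list(sizes.values()))
    if offsets[-1] != coef.shape[-1]:
        raise ValueError(f"expected {offsets[-1]} coefficients, got {coef.shape[-1]}")
    return {name: coef[..., start:stop]
            for name, start, stop in zip(sizes, offsets[:-1], offsets[1:])}


coef_flat = np.arange(110.0)  # stand-in for model.coef_ from the concatenated fit
blocks = split_coefficients(coef_flat, {"head_direction": 10, "spatial_position": 100})
print({name: block.shape for name, block in blocks.items()})
# {'head_direction': (10,), 'spatial_position': (100,)}
```

The size check catches the most common mistake with the matrix approach: a block order or width that no longer matches how the columns were concatenated.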