Fit GLM for Classification
The ClassifierGLM models categorical or discrete outcomes such as behavioral choices.
Key differences from the standard GLM:
- predict returns predicted class labels.
- predict_proba returns (log-)probabilities for each class.
- set_classes sets the class labels. By default, classes are assumed to be {0, ..., n_classes - 1}; this assumption can be overridden by calling set_classes with an array of class labels.
- The observation_model parameter cannot be set at model initialization, since CategoricalObservations is the only compatible observation model.
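To make the relation between predict and predict_proba concrete, here is a minimal NumPy sketch. It assumes that predict returns the most probable class and then translates that column index into the labels registered with set_classes; the helper labels_from_proba is hypothetical and not part of the NeMoS API. Because the logarithm is monotonic, the same argmax works on log-probabilities.

```python
import numpy as np


def labels_from_proba(probs, classes):
    """Map per-class (log-)probabilities to labels via the most probable class.

    Hypothetical helper: assumes predict() behaves like an argmax over
    predict_proba(); it is illustrative only, not part of NeMoS.
    """
    probs = np.asarray(probs)
    classes = np.asarray(classes)
    # column index of the most probable class for each sample
    best = np.argmax(probs, axis=1)
    # translate column indices into user-facing class labels
    return classes[best]


# three samples, three classes, with custom labels instead of {0, 1, 2}
probs_demo = np.array([[0.7, 0.2, 0.1],
                       [0.1, 0.3, 0.6],
                       [0.2, 0.5, 0.3]])
print(labels_from_proba(probs_demo, ["left", "right", "no-go"]))
# log-probabilities give the same ranking, hence the same labels
print(labels_from_proba(np.log(probs_demo), [0, 1, 2]))
```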
Generate Synthetic Data
In this example, we simulate categorical choice data to demonstrate the classifier. With real data, replace this section with code that loads your experimental observations.
import jax
import numpy as np
import nemos as nmo
np.random.seed(200)
n_samples, n_features, n_classes = 1000, 5, 3
X = np.random.randn(n_samples, n_features)
# simulate categorical choices using known coefficients
true_coef = 2 * np.random.randn(n_features, n_classes)
true_intercept = np.zeros(n_classes)
model = nmo.glm.ClassifierGLM(n_classes)
model.coef_ = true_coef
model.intercept_ = true_intercept
# set the class labels explicitly (here identical to the default {0, ..., n_classes - 1})
model.set_classes(np.arange(n_classes))
true_choice, _ = model.simulate(jax.random.PRNGKey(124), X)
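The simulation step can be pictured with the following NumPy-only sketch. It assumes a multinomial-logit (softmax) link from the linear predictor X @ coef + intercept to class probabilities, which is the standard choice for categorical GLMs; the exact parametrization inside ClassifierGLM may differ, and simulate_choices is a hypothetical helper. Variable names carry a _sim suffix so they do not overwrite X or true_choice above.

```python
import numpy as np


def simulate_choices(X, coef, intercept, rng):
    """Draw one categorical choice per row of X.

    Assumption: softmax (multinomial-logit) link; this is an illustrative
    sketch, not NeMoS's implementation of ClassifierGLM.simulate.
    """
    # linear predictor: one score per sample and class
    logits = X @ coef + intercept  # shape (n_samples, n_classes)
    # subtract the row-wise max before exponentiating for numerical stability
    logits = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    # sample a class index for each row from its categorical distribution
    choices = np.array([rng.choice(probs.shape[1], p=p) for p in probs])
    return choices, probs


rng = np.random.default_rng(0)
X_sim = rng.standard_normal((1000, 5))
coef_sim = 2 * rng.standard_normal((5, 3))
choice_sim, probs_sim = simulate_choices(X_sim, coef_sim, np.zeros(3), rng)
print(choice_sim.shape, probs_sim.shape)  # (1000,) (1000, 3)
```

Large coefficient magnitudes (the factor of 2 above) make the softmax sharper, so choices become more predictable from X.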
Fit the Model and Predict Choices
# fit a fresh model on the first train_samples observations; the rest are held out for evaluation
model = nmo.glm.ClassifierGLM(n_classes)
train_samples = 500
model.fit(X[:train_samples], true_choice[:train_samples])
# predict class labels for all samples (training and held-out)
predicted_choice = model.predict(X)
# get class probabilities
probs = model.predict_proba(X)
print(f"Probability shape: {probs.shape}") # (n_samples, n_classes)
Probability shape: (1000, 3)
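Before plotting, it can help to check that each row of the probabilities is a proper distribution and to compare held-out accuracy with chance level (1 / n_classes). The sketch below uses NumPy only; summarize_predictions and its log_probs flag are hypothetical, and you would set log_probs=True only if your predict_proba call returns log-probabilities. Applied to this tutorial, a call would look like summarize_predictions(true_choice, predicted_choice, probs, train_samples).

```python
import numpy as np


def summarize_predictions(y_true, y_pred, probs, train_samples, log_probs=False):
    """Held-out accuracy, chance level, and a normalization check.

    Hypothetical helper: `log_probs` says whether `probs` holds
    log-probabilities, which are exponentiated before the check.
    """
    y_true, y_pred, probs = map(np.asarray, (y_true, y_pred, probs))
    p = np.exp(probs) if log_probs else probs
    # each row should sum to one across classes
    rows_normalized = bool(np.allclose(p.sum(axis=1), 1.0))
    # score only the samples the model did not see during fitting
    test = slice(train_samples, None)
    accuracy = float(np.mean(y_true[test] == y_pred[test]))
    chance = 1.0 / p.shape[1]
    return {"accuracy": accuracy, "chance": chance, "rows_normalized": rows_normalized}


# toy example: first 3 samples "training", last 3 held out
y_true_toy = np.array([0, 1, 2, 0, 1, 2])
y_pred_toy = np.array([0, 1, 2, 0, 2, 2])
probs_toy = np.full((6, 3), 1 / 3)
print(summarize_predictions(y_true_toy, y_pred_toy, probs_toy, train_samples=3))
```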
Visualize Results
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
# evaluate on the held-out samples only (not used for fitting)
cm = confusion_matrix(true_choice[train_samples:], predicted_choice[train_samples:])
disp = ConfusionMatrixDisplay(cm)
disp.plot(text_kw=dict(fontsize=15))
plt.title("Confusion Matrix", fontsize=20)
plt.xlabel(disp.ax_.get_xlabel(), fontsize=15)
plt.ylabel(disp.ax_.get_ylabel(), fontsize=15)
plt.show()
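For reference, the matrix plotted above simply counts (true, predicted) pairs, with rows indexing the true class and columns the predicted class, as in sklearn.metrics.confusion_matrix. The NumPy sketch below reproduces that convention when the labels are the integers 0..n_classes-1 and each of them occurs in the true or predicted labels; confusion_counts is a hypothetical helper. The diagonal holds correct predictions, so dividing it by the row totals gives per-class recall.

```python
import numpy as np


def confusion_counts(y_true, y_pred, n_classes):
    """Count (true, predicted) pairs; rows = true class, columns = predicted.

    Hypothetical helper mirroring sklearn's row/column convention for
    integer labels 0..n_classes-1.
    """
    cm = np.zeros((n_classes, n_classes), dtype=int)
    # np.add.at accumulates correctly even when an index pair repeats
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


y_true_toy = [0, 0, 1, 2, 2, 2]
y_pred_toy = [0, 1, 1, 2, 0, 2]
cm_toy = confusion_counts(y_true_toy, y_pred_toy, 3)
print(cm_toy)
# per-class recall: diagonal (correct) counts divided by row totals
print(np.diag(cm_toy) / cm_toy.sum(axis=1))
# overall accuracy: correct counts divided by all counts
print(np.trace(cm_toy) / cm_toy.sum())
```

Applied to the held-out arrays, confusion_counts(true_choice[train_samples:], predicted_choice[train_samples:], n_classes) should agree with the scikit-learn matrix shown in the plot.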