nemos.glm_hmm.GLMHMM
- class nemos.glm_hmm.GLMHMM(n_states, observation_model='Bernoulli', inverse_link_function=None, regularizer='Ridge', regularizer_strength=1.0, dirichlet_initial_proba=None, dirichlet_transition_proba=None, solver_name=None, solver_kwargs=None, maxiter=1000, tol=1e-08, seed=Array([0, 123], dtype=uint32), hmm_initialization_funcs=None, model_initialization_funcs=None)
Bases: BaseHMM
Generalized Linear Model with Hidden Markov Model (GLM-HMM).
This model combines a Generalized Linear Model (GLM) with a Hidden Markov Model (HMM) to capture state-dependent relationships between predictors and neural or behavioral responses. The GLM-HMM suits time series data in which the relationship between inputs and outputs changes according to an underlying latent state that evolves over time with Markovian dynamics.
The model assumes that at each time step the system is in one of n_states discrete hidden states. Each state has its own GLM parameters (coefficients, intercept, and scale), and a transition probability matrix governs the transitions between states. The model is fitted with the Expectation-Maximization (EM) algorithm.
Below is a table of the default inverse link function for each available observation model.
Observation Model     Default Inverse Link Function
Poisson               \(e^x\)
Gamma                 \(1/x\)
Bernoulli             \(1 / (1 + e^{-x})\)
NegativeBinomial      \(e^x\)
Gaussian              \(x\)
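The generative model can be made concrete with a minimal NumPy sketch. This is not the nemos implementation. The sketch writes the default inverse links from the table as plain functions, then samples a hidden-state path and binary responses from a two-state Bernoulli GLM-HMM. The helper name `sample_bernoulli_glmhmm` and all parameter values are hypothetical and chosen only for illustration. The array shapes follow the fitted-attribute shapes documented for this class.

```python
import numpy as np

# Default inverse links from the table above, written with NumPy
# (illustrative only; nemos uses its own JAX implementations).
DEFAULT_INVERSE_LINKS = {
    "Poisson": np.exp,
    "Gamma": lambda x: 1.0 / x,
    "Bernoulli": lambda x: 1.0 / (1.0 + np.exp(-x)),
    "NegativeBinomial": np.exp,
    "Gaussian": lambda x: x,
}


def sample_bernoulli_glmhmm(X, coef, intercept, initial_prob, transition_prob, rng):
    """Hypothetical helper: draw hidden states and binary responses.

    X: (n_time_bins, n_features); coef: (n_features, n_states);
    intercept: (n_states,); initial_prob: (n_states,);
    transition_prob: (n_states, n_states), entry [i, j] = P(i -> j).
    """
    n_time_bins = X.shape[0]
    n_states = initial_prob.shape[0]
    inv_link = DEFAULT_INVERSE_LINKS["Bernoulli"]
    states = np.empty(n_time_bins, dtype=int)
    y = np.empty(n_time_bins, dtype=int)
    for t in range(n_time_bins):
        if t == 0:
            # The first state comes from the initial distribution.
            states[t] = rng.choice(n_states, p=initial_prob)
        else:
            # Markov step: the row of the previous state gives next-state probabilities.
            states[t] = rng.choice(n_states, p=transition_prob[states[t - 1]])
        k = states[t]
        # Each state uses its own coefficient column and intercept.
        p = inv_link(X[t] @ coef[:, k] + intercept[k])
        y[t] = rng.binomial(1, p)
    return states, y


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4))
coef = rng.normal(size=(4, 2))
intercept = np.array([-1.0, 1.0])
initial_prob = np.array([0.5, 0.5])
transition_prob = np.array([[0.95, 0.05], [0.05, 0.95]])
states, y = sample_bernoulli_glmhmm(X, coef, intercept, initial_prob, transition_prob, rng)
```

The "sticky" diagonal of `transition_prob` makes the sampled states persist for many time bins. A GLM-HMM is designed to recover that kind of structure.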
Below is a table listing the default and available solvers for each regularizer.
Regularizer      Default Solver      Available Solvers
UnRegularized    LBFGS               GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient
Ridge            LBFGS               GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient
Lasso            ProximalGradient    ProximalGradient
ElasticNet       ProximalGradient    ProximalGradient
GroupLasso       ProximalGradient    ProximalGradient
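The solver-selection rule in the table, including the fallback to a default when `solver_name` is None, fits in a small lookup. This is a hypothetical sketch that mirrors the table. `resolve_solver` and `is_valid_solver` are illustrative names, not part of the nemos API.

```python
# Hypothetical lookup tables mirroring the regularizer/solver table above.
DEFAULT_SOLVER = {
    "UnRegularized": "LBFGS",
    "Ridge": "LBFGS",
    "Lasso": "ProximalGradient",
    "ElasticNet": "ProximalGradient",
    "GroupLasso": "ProximalGradient",
}

_ALL_SOLVERS = ("GradientDescent", "BFGS", "LBFGS", "NonlinearCG", "ProximalGradient")

AVAILABLE_SOLVERS = {
    "UnRegularized": _ALL_SOLVERS,
    "Ridge": _ALL_SOLVERS,
    # Non-smooth penalties need a proximal method.
    "Lasso": ("ProximalGradient",),
    "ElasticNet": ("ProximalGradient",),
    "GroupLasso": ("ProximalGradient",),
}


def is_valid_solver(regularizer, solver_name):
    """Return True if solver_name may be used with regularizer."""
    return solver_name in AVAILABLE_SOLVERS.get(regularizer, ())


def resolve_solver(regularizer, solver_name=None):
    """Pick the default solver when none is given; otherwise validate the choice."""
    if regularizer not in DEFAULT_SOLVER:
        raise ValueError(f"Unknown regularizer: {regularizer!r}")
    if solver_name is None:
        return DEFAULT_SOLVER[regularizer]
    if not is_valid_solver(regularizer, solver_name):
        raise ValueError(f"{solver_name!r} is not available for {regularizer!r}")
    return solver_name


print(resolve_solver("Ridge"), resolve_solver("Lasso"), resolve_solver("Ridge", "BFGS"))
```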
- Parameters:
n_states (int) – The number of hidden states in the HMM. Must be a positive integer.
observation_model (Union[Observations, Literal['Poisson', 'Gamma', 'Bernoulli', 'NegativeBinomial', 'Gaussian']]) – Observation model to use. The model describes the distribution of the response variable. Default is the Bernoulli model. Alternatives are “Poisson”, “Gamma”, “NegativeBinomial”, and “Gaussian”.
inverse_link_function (Optional[Callable[[Array], Array]]) – A function that maps the linear combination of predictors into a rate or probability. The default depends on the observation model; see the table above.
regularizer (Union[str, Regularizer]) – Regularization scheme used in the M-step for the per-state GLM coefficients. Default is "Ridge". Pass "UnRegularized" to disable regularization.
regularizer_strength (Any) – Strength of the regularization applied to the GLM coefficients. Default is 1.0. Ignored when regularizer="UnRegularized".
dirichlet_initial_proba (Optional[Array]) – Alpha parameters for the Dirichlet prior over the initial state probabilities. Shape (n_states,). If None, a flat (uninformative) prior is assumed.
dirichlet_transition_proba (Optional[Array]) – Alpha parameters for the Dirichlet prior over the transition probabilities. Shape (n_states, n_states). If None, a flat (uninformative) prior is assumed.
solver_name (str) – Solver used for the GLM M-step. The solver must be valid for the chosen regularizer (see the table above). Default is None, in which case the regularizer's default solver is selected ("LBFGS" for Ridge / UnRegularized, "ProximalGradient" for Lasso / ElasticNet / GroupLasso).
solver_kwargs (Optional[dict]) – Optional dictionary of keyword arguments passed to the solver when it is instantiated, e.g. stepsize, tol, acceleration.
maxiter (int) – Maximum number of EM iterations. Default is 1000.
tol (float) – Convergence tolerance for the EM algorithm. The algorithm stops when the absolute change in log-likelihood between consecutive iterations falls below this threshold. Default is 1e-8.
seed – JAX PRNG key for random number generation during initialization. Default is jax.random.PRNGKey(123).
hmm_initialization_funcs (Optional[dict[Literal['initial_proba_init', 'initial_proba_init_kwargs', 'initial_proba_init_custom', 'transition_proba_init', 'transition_proba_init_kwargs', 'transition_proba_init_custom'], InitFunctionHMM | dict[str, Any] | bool]]) – Dictionary of initialization functions for the HMM probabilities (initial and transition). Included for scikit-learn compatibility; prefer configuring via the setup() method after construction. If None, defaults from DEFAULT_INIT_FUNCTIONS are used.
model_initialization_funcs (Optional[dict[Literal['glm_params_init', 'glm_params_init_kwargs', 'glm_params_init_custom', 'scale_init', 'scale_init_kwargs', 'scale_init_custom'], InitFunctionGLM | InitFunctionHMM | dict[str, Any] | bool]]) – Dictionary of initialization functions for the GLM-specific parameters (coefficients, intercept, and scale). Included for scikit-learn compatibility; prefer configuring via the setup() method after construction. If None, defaults from DEFAULT_INIT_FUNCTIONS_GLMHMM are used.
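The sketch below illustrates what the Dirichlet alpha parameters do in an EM update. It assumes the M-step takes the posterior mode (MAP) of a row-wise Dirichlet prior; this is the textbook form, and we have not checked it against nemos's exact update. With all alphas equal to 1 (a flat prior), the update reduces to normalizing the expected transition counts from the E-step. Alphas above 1 on the diagonal pull the estimate toward "sticky" states. The function name and the count values are hypothetical.

```python
import numpy as np


def map_transition_update(expected_counts, alpha=None):
    """Posterior-mode transition matrix under a row-wise Dirichlet prior (hypothetical helper).

    expected_counts[i, j]: expected number of i -> j transitions from the E-step.
    alpha: Dirichlet concentrations with the same shape; None means a flat prior (all ones).
    """
    expected_counts = np.asarray(expected_counts, dtype=float)
    if alpha is None:
        alpha = np.ones_like(expected_counts)
    # The mode of Dirichlet(counts + alpha) is proportional to counts + alpha - 1.
    unnormalized = expected_counts + np.asarray(alpha, dtype=float) - 1.0
    return unnormalized / unnormalized.sum(axis=1, keepdims=True)


counts = np.array([[8.0, 2.0], [1.0, 9.0]])
A_flat = map_transition_update(counts)  # plain maximum likelihood
A_sticky = map_transition_update(counts, alpha=np.array([[11.0, 1.0], [1.0, 11.0]]))
print(A_flat)
print(A_sticky)
```

The same formula, applied to a single row of expected initial-state occupancies, would give the corresponding update for the initial probabilities.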
- transition_prob_
Transition probability matrix of shape (n_states, n_states). Entry [i, j] is the probability of transitioning from state i to state j, so each row sums to 1.
- initial_prob_
Initial state probability vector of shape (n_states,). Entry [i] is the probability of starting in state i.
- coef_
GLM coefficients for each state, shape (n_features, n_states).
- intercept_
GLM intercepts (bias terms) for each state, shape (n_states,).
- solver_state_
State of the solver after fitting. May include details such as the optimization error.
- scale_
Scale parameter of the observation model for each state, shape (n_states,).
- dof_resid_
Degrees of freedom of the residuals.
Notes
To bypass the initialization functions entirely and provide parameter arrays directly, pass them to the fit() method: model.fit(X, y, init_params=my_params), where my_params is a 5-tuple (coef, intercept, scale, initial_prob, transition_prob).
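Below is a sketch of how such a 5-tuple could be built for a plain 2-D design matrix. The shapes follow the fitted attributes listed above. The helper `make_init_params`, its "sticky" transition default, and its small random coefficients are hypothetical choices, not nemos defaults. `scale` is set to ones as a neutral placeholder; whether it is used depends on the observation model. Pytree inputs would need `coef` with a matching pytree structure, which this sketch does not cover.

```python
import numpy as np


def make_init_params(n_features, n_states, stickiness=0.9, seed=0):
    """Hypothetical helper returning (coef, intercept, scale, initial_prob, transition_prob)."""
    rng = np.random.default_rng(seed)
    # Small random coefficients break the symmetry between states; identical
    # columns can leave EM stuck at a solution where all states look alike.
    coef = rng.normal(scale=0.1, size=(n_features, n_states))
    intercept = np.zeros(n_states)
    scale = np.ones(n_states)
    initial_prob = np.full(n_states, 1.0 / n_states)
    if n_states == 1:
        transition_prob = np.ones((1, 1))
    else:
        # Stay in the current state with probability `stickiness`;
        # spread the remaining mass evenly over the other states.
        off_diag = (1.0 - stickiness) / (n_states - 1)
        transition_prob = np.full((n_states, n_states), off_diag)
        np.fill_diagonal(transition_prob, stickiness)
    return coef, intercept, scale, initial_prob, transition_prob


my_params = make_init_params(n_features=4, n_states=3)
# With X of shape (n_time_bins, 4), this could then be passed as:
# model.fit(X, y, init_params=my_params)
```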
- Raises:
TypeError – If n_states is not a positive integer.
TypeError – If the provided regularizer or observation_model is not valid.
TypeError – If seed is not a valid JAX PRNG key.
KeyError – If hmm_initialization_funcs or model_initialization_funcs contains keys that are not valid for their respective default dictionary.
ValueError – If any *_kwargs entry in either initialization-funcs dictionary contains keyword arguments that don't match the signature of the corresponding initialization function.
ValueError – If maxiter is not a positive integer.
ValueError – If tol is not a positive float.
- Parameters:
n_states (int)
observation_model (Observations | Literal['Poisson', 'Gamma', 'Bernoulli', 'NegativeBinomial', 'Gaussian'])
inverse_link_function (Optional[Callable[[jnp.ndarray], jnp.ndarray]])
regularizer (Union[str, Regularizer])
regularizer_strength (Any)
dirichlet_initial_proba (Union[jnp.ndarray, None])
dirichlet_transition_proba (Union[jnp.ndarray | None])
solver_name (str)
solver_kwargs (Optional[dict])
maxiter (int)
tol (float)
hmm_initialization_funcs (Optional[HMM_INITIALIZATION_FN_DICT])
model_initialization_funcs (Optional[GLMHMM_INITIALIZATION_FN_DICT])
Examples
Fit a GLM-HMM
Basic model fitting with the default Bernoulli observation model. The number of hidden states is the only required argument; coef_ carries one column per state, and the HMM transition matrix and initial distribution are exposed as fitted attributes.
>>> import jax
>>> import numpy as np
>>> import nemos as nmo
>>> np.random.seed(123)
>>> X = np.random.normal(size=(200, 4))
>>> y = np.random.binomial(n=1, p=0.5, size=200)
>>> model = nmo.glm_hmm.GLMHMM(n_states=2).fit(X, y)
>>> model.coef_.shape
(4, 2)
>>> model.transition_prob_.shape
(2, 2)
>>> model.initial_prob_.shape
(2,)
Customize the Observation Model
Specify the observation model as a string:
>>> model = nmo.glm_hmm.GLMHMM(n_states=2, observation_model="Poisson")
>>> model.observation_model
PoissonObservations()
Or pass the observation model object directly:
>>> model = nmo.glm_hmm.GLMHMM(
...     n_states=2, observation_model=nmo.observation_models.PoissonObservations()
... )
>>> model.observation_model
PoissonObservations()
Customize the Inverse Link Function
Use a soft-plus inverse link function instead of the observation-model default:
>>> model = nmo.glm_hmm.GLMHMM(n_states=2, inverse_link_function=jax.nn.softplus)
>>> model.inverse_link_function.__name__
'softplus'
Change the Regularization
Regularization applies to the per-state GLM coefficients. The default is Ridge with strength 1.0. Tune the strength:
>>> model = nmo.glm_hmm.GLMHMM(n_states=2, regularizer_strength=0.1).fit(X, y)
>>> model.regularizer, float(model.regularizer_strength)
(Ridge(), 0.1)
Or switch to Lasso for sparse per-state coefficients. Lasso requires a proximal solver; ProximalGradient is also its default, so passing solver_name here only makes the choice explicit:
>>> model = nmo.glm_hmm.GLMHMM(
...     n_states=2,
...     regularizer="Lasso",
...     regularizer_strength=0.01,
...     solver_name="ProximalGradient",
... ).fit(X, y)
>>> model.regularizer
Lasso()
Select a Solver
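The solver is used for the M-step inside EM. Pass solver_name to pick any solver that is valid for the chosen regularizer (see the table above). LBFGS, a quasi-Newton method that often converges quickly on smooth losses, is already the default for Ridge, so this call only makes the choice explicit:
>>> model = nmo.glm_hmm.GLMHMM(n_states=2, solver_name="LBFGS").fit(X, y)
>>> model.solver_name
'LBFGS'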
Fit Across Multiple Sessions
Mark session boundaries with session_starts so that the HMM resets at each new session start instead of treating the data as a single chain. Pass either a boolean mask of shape (n_time_bins,) with True at each session start, or an integer array of session-start indices; the two are equivalent:
>>> is_new_mask = np.zeros(200, dtype=bool)
>>> is_new_mask[0] = True
>>> is_new_mask[100] = True
>>> model = nmo.glm_hmm.GLMHMM(n_states=2).fit(X, y, session_starts=is_new_mask)
>>> # Equivalent: pass the starts as integer indices.
>>> model = nmo.glm_hmm.GLMHMM(n_states=2).fit(X, y, session_starts=np.array([0, 100]))
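The two `session_starts` encodings carry the same information. Below is a NumPy sketch of converting between them and of deriving a per-bin session label. It assumes the first time bin always starts a session; if `is_new_mask[0]` were False, the `cumsum` label would start at -1.

```python
import numpy as np

n_time_bins = 200
is_new_mask = np.zeros(n_time_bins, dtype=bool)
is_new_mask[[0, 100]] = True

# Boolean mask -> integer session-start indices.
start_idx = np.flatnonzero(is_new_mask)

# Integer indices -> boolean mask.
mask_from_idx = np.zeros(n_time_bins, dtype=bool)
mask_from_idx[start_idx] = True

# Session label per time bin: 0 for the first session, 1 for the next, and so on.
session_id = np.cumsum(is_new_mask) - 1
print(start_idx, session_id[[0, 99, 100, 199]])
```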
Decode Hidden States
Recover the most likely state sequence (Viterbi decoding) or the smoothed posterior probabilities from the forward-backward pass. decode_state returns one-hot rows by default, which is why its output has shape (200, 2):
>>> states = model.decode_state(X, y, session_starts=is_new_mask)
>>> states.shape
(200, 2)
>>> post = model.smooth_proba(X, y, session_starts=is_new_mask)
>>> post.shape
(200, 2)
Simulate from the Fitted Model
Sample a hidden-state trajectory and observations conditioned on inputs:
>>> activity, rates, sim_states = model.simulate(
...     jax.random.key(0), X, state_format="index"
... )
>>> activity.shape, sim_states.shape
((200,), (200,))
Use a Dict of Arrays as Input
Features can be passed as any JAX pytree of 2-D arrays; the fitted coef_ will share the same pytree structure, with the trailing axis indexing states:
>>> X_dict = {"input_1": X[:, :2], "input_2": X[:, 2:]}
>>> model = nmo.glm_hmm.GLMHMM(n_states=2).fit(X_dict, y)
>>> type(model.coef_)
<class 'dict'>
Attributes
Alpha parameters of the Dirichlet prior over the initial probabilities of HMM states.
Alpha parameters of the Dirichlet prior over the transition probabilities of HMM states.
Dictionary of initialization functions for HMM parameters.
Inverse link function mapping the linear predictor to the emission space.
EM maximum number of iterations.
Dictionary of initialization functions for model parameters.
Number of hidden states of the HMM.
The observation model governing the emission distribution at each state.
Provides the initialization function for the optimizer state.
Provides the function to execute the optimization process.
Provides the function for updating the state during the optimization process.
Getter for the regularizer attribute.
Regularizer strength getter.
Random seed as a jax PRNG key.
Getter for the solver_kwargs attribute.
Getter for the solver_name attribute.
Getter for the solver specification.
Tolerance for the EM algorithm convergence criterion.
- __init__(n_states, observation_model='Bernoulli', inverse_link_function=None, regularizer='Ridge', regularizer_strength=1.0, dirichlet_initial_proba=None, dirichlet_transition_proba=None, solver_name=None, solver_kwargs=None, maxiter=1000, tol=1e-08, seed=Array([0, 123], dtype=uint32), hmm_initialization_funcs=None, model_initialization_funcs=None)
- Parameters:
n_states (int)
observation_model (Observations | Literal['Poisson', 'Gamma', 'Bernoulli', 'NegativeBinomial', 'Gaussian'])
regularizer (str | Regularizer)
regularizer_strength (Any)
dirichlet_initial_proba (Array | None)
dirichlet_transition_proba (Array | None)
solver_name (str)
solver_kwargs (dict | None)
maxiter (int)
tol (float)
hmm_initialization_funcs (dict[Literal['initial_proba_init', 'initial_proba_init_kwargs', 'initial_proba_init_custom', 'transition_proba_init', 'transition_proba_init_kwargs', 'transition_proba_init_custom'], ~nemos.hmm.initialize_parameters.InitFunctionHMM | dict[str, ~typing.Any] | bool] | None)
model_initialization_funcs (dict[Literal['glm_params_init', 'glm_params_init_kwargs', 'glm_params_init_custom', 'scale_init', 'scale_init_kwargs', 'scale_init_custom'], ~nemos.glm_hmm.initialize_parameters.InitFunctionGLM | ~nemos.hmm.initialize_parameters.InitFunctionHMM | dict[str, ~typing.Any] | bool] | None)
Methods
__init__(n_states[, observation_model, ...])
compute_loss(params, X, y, *args, **kwargs) – Compute the loss function for the model.
decode_state(X, y[, session_starts, ...]) – Compute the most likely hidden state sequence (Viterbi decoding).
filter_proba(X, y[, session_starts]) – Compute filtering posterior probabilities over hidden states.
fit(X, y[, init_params, session_starts]) – Fit the GLM-HMM via Expectation-Maximization.
get_params([deep]) – From scikit-learn, get parameters by inspecting init.
initialize_optimizer_and_state(init_params, X, y) – Initialize the optimization routine and its state for running fit and update.
initialize_params(X, y) – Initialize model parameters.
save_params(filename) – Save GLM-HMM model parameters and fit state to a .npz file.
score(X, y[, session_starts]) – Marginal log-likelihood of the data under the fitted HMM.
set_params(**params) – Manage warnings in case of multiple parameter settings.
setup([initial_proba_init, ...]) – Configure how fit() initializes each model parameter.
simulate(random_key, feedforward_input[, ...]) – Simulate neural activity and hidden states from the model.
smooth_proba(X, y[, session_starts]) – Compute smoothing posterior probabilities over hidden states.
update(params, opt_state, X, y, *args[, ...]) – Run a single EM iteration on the GLM-HMM.
- classmethod __init_subclass__(**kwargs)
Set the set_{method}_request methods.
This uses PEP-487 [1] to set the set_{method}_request methods. It looks for the information available in the set default values, which are set using __metadata_request__* class attributes or inferred from method signatures.
The __metadata_request__* class attributes are used when a method does not explicitly accept metadata through its arguments, or when the developer wants to specify a request value for that metadata that differs from the default None.
References
[1] PEP 487: Simpler customisation of class creation.
- __sklearn_tags__()
Return regression model specific estimator tags.
- compute_loss(params, X, y, *args, **kwargs)
Compute the loss function for the model.
This method validates inputs and converts user-provided parameters to the internal representation before computing the loss.
- Parameters:
params (UserProvidedParamsT) – Parameter tuple of (coefficients, intercept).
X (DESIGN_INPUT_TYPE) – Input data, array of shape (n_time_bins, n_features) or pytree of same.
y (jnp.ndarray) – Target data, array of shape (n_time_bins,) for single neuron models or (n_time_bins, n_neurons) for population models.
*args – Additional positional arguments passed to the model-specific loss function.
**kwargs – Additional keyword arguments passed to the model-specific loss function.
- Returns:
The loss value (negative log-likelihood).
- Return type:
jnp.ndarray
- Raises:
ValueError – If inputs or parameters have incompatible shapes or invalid values.
- decode_state(X, y, session_starts=None, state_format='one-hot')
Compute the most likely hidden state sequence (Viterbi decoding).
Finds the single most likely sequence of hidden states that best explains the observed data. Uses the Viterbi (max-sum) algorithm to compute the state sequence that maximizes the joint probability of states and observations.
Unlike smooth_proba() and filter_proba(), which return a probability distribution over states at each time bin, this method makes a hard assignment to the single globally optimal state path.
The decoded states answer: “What is the most likely sequence of states that generated the observed data?”
- Parameters:
X (Union[DESIGN_INPUT_TYPE, ArrayLike]) – Predictors, shape (n_time_bins, n_features). A pytree of 2-D arrays sharing the leading time axis is also accepted.
y (ArrayLike) – Observations, shape (n_time_bins,) for a single neuron or (n_time_bins, n_neurons) for a population model. A pynapple Tsd/TsdFrame is accepted; session boundaries are then inferred from time_support.
session_starts (Optional[ArrayLike]) – Optional session boundaries. Accepts:
  - a boolean array of shape (n_time_bins,) with True at each session start,
  - an integer array of session-start indices,
  - a pynapple IntervalSet (requires X or y to be a pynapple object to supply timestamps).
  If None, the entire input is treated as a single session.
state_format (Literal['one-hot', 'index']) – Format of the returned state sequence:
  - "one-hot" (default): binary array of shape (n_time_bins, n_states) with a single 1 per row.
  - "index": integer array of shape (n_time_bins,) with values in [0, n_states - 1].
- Returns:
Most likely state sequence. Shape and dtype depend on state_format (see above). When the inputs are pynapple objects, returns a pynapple TsdFrame (columns "state_0", "state_1", …) for "one-hot" format or a pynapple Tsd for "index" format; otherwise returns a JAX array.
- Return type:
jnp.ndarray | nap.TsdFrame | nap.Tsd
- Raises:
ValueError – If the model has not been fitted (call fit() first).
ValueError – If state_format is not "one-hot" or "index".
ValueError – If X or y contain NaN values in the interior of an epoch (boundary NaNs are allowed and removed before inference).
ValueError – If X and y have inconsistent shapes or feature counts.
See also
smooth_proba – Compute smoothing posteriors (soft, probabilistic state assignments).
filter_proba – Compute filtering posteriors (causal, conditioned on observations up to the current time bin).
Notes
Viterbi decoding finds the globally optimal state sequence. This sequence can differ from the sequence of states that are individually most probable at each time bin, i.e. the per-bin argmax of smooth_proba(). For uncertainty estimates, use smooth_proba() instead. Session boundaries reset the Viterbi recursion so that no path crosses session borders.
Examples
Decode the most likely state sequence as integer indices:
>>> import numpy as np
>>> import nemos as nmo
>>> np.random.seed(123)
>>> X = np.random.randn(100, 5)
>>> y = np.random.poisson(2, size=100)
>>> model = nmo.glm_hmm.GLMHMM(n_states=3, observation_model="Poisson").fit(X, y)
>>> states = model.decode_state(X, y, state_format="index")
>>> states.shape
(100,)
One-hot output (default):
>>> states_onehot = model.decode_state(X, y)
>>> states_onehot.shape
(100, 3)
>>> bool(np.all(states_onehot.sum(axis=1) == 1))
True
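The following minimal log-space Viterbi sketch in NumPy shows the kind of recursion decode_state performs. It is the textbook max-sum algorithm, not the nemos code. It assumes the per-state emission log-likelihoods are already available; in a GLM-HMM they would come from each state's GLM. `log_trans[i, j]` holds the log of the i-to-j transition, matching the convention of `transition_prob_`. The sketch handles a single session; with several sessions, the recursion would restart at each session start. A brute-force search over all paths of a tiny random problem checks the result.

```python
import itertools

import numpy as np


def viterbi(log_lik, log_init, log_trans):
    """Most likely state path for one session (log-space max-sum recursion).

    log_lik:   (n_time_bins, n_states), log p(y_t | state, x_t)
    log_init:  (n_states,), log initial state probabilities
    log_trans: (n_states, n_states), log_trans[i, j] = log P(state j at t | state i at t-1)
    """
    n_time_bins, n_states = log_lik.shape
    # best[j]: highest log-joint over all paths that end in state j at the current bin.
    best = log_init + log_lik[0]
    backptr = np.zeros((n_time_bins, n_states), dtype=int)
    for t in range(1, n_time_bins):
        # cand[i, j]: best path ending in i at t-1, followed by the i -> j transition.
        cand = best[:, None] + log_trans
        backptr[t] = np.argmax(cand, axis=0)
        best = cand.max(axis=0) + log_lik[t]
    # Trace the optimal path backwards from the best final state.
    path = np.empty(n_time_bins, dtype=int)
    path[-1] = np.argmax(best)
    for t in range(n_time_bins - 1, 0, -1):
        path[t - 1] = backptr[t, path[t]]
    return path


def log_joint(path, log_lik, log_init, log_trans):
    """Log-joint probability of one explicit state path and the observations."""
    total = log_init[path[0]] + log_lik[0, path[0]]
    for t in range(1, len(path)):
        total += log_trans[path[t - 1], path[t]] + log_lik[t, path[t]]
    return total


# Tiny random problem: 6 time bins, 3 states (3**6 = 729 candidate paths).
rng = np.random.default_rng(1)
n_time_bins, n_states = 6, 3
log_lik = np.log(rng.uniform(0.05, 1.0, size=(n_time_bins, n_states)))
log_init = np.log(np.full(n_states, 1.0 / n_states))
trans = rng.uniform(0.1, 1.0, size=(n_states, n_states))
log_trans = np.log(trans / trans.sum(axis=1, keepdims=True))

path = viterbi(log_lik, log_init, log_trans)
brute_force = max(
    itertools.product(range(n_states), repeat=n_time_bins),
    key=lambda p: log_joint(p, log_lik, log_init, log_trans),
)
# "one-hot" format: one row per time bin with a single 1.
one_hot = np.eye(n_states, dtype=int)[path]
print(path, brute_force)
```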
- property dirichlet_initial_proba: Array | None
Alpha parameters of the Dirichlet prior over the initial probabilities of HMM states.
If None, a flat prior is assumed.
- property dirichlet_transition_proba: Array | None
Alpha parameters of the Dirichlet prior over the transition probabilities of HMM states.
If None, a flat prior is assumed.
- filter_proba(X, y, session_starts=None)
Compute filtering posterior probabilities over hidden states.
Computes the probability of being in each hidden state at each time bin, conditioned only on observations up to that time bin. Uses the forward pass of the forward-backward algorithm, providing causal (online) state estimates that rely solely on past and current observations.
The filtering posteriors answer: “Given observations up to time t, what is the probability that the system is in state k at time t?”
- Parameters:
X (Union[DESIGN_INPUT_TYPE, ArrayLike]) – Predictors, shape (n_time_bins, n_features). A pytree of 2-D arrays sharing the leading time axis is also accepted.
y (Union[NDArray, jnp.ndarray, nap.Tsd]) – Observations, shape (n_time_bins,) for a single neuron or (n_time_bins, n_neurons) for a population model. A pynapple Tsd/TsdFrame is accepted; session boundaries are then inferred from time_support.
session_starts (Optional[ArrayLike]) – Optional session boundaries. Accepts:
  - a boolean array of shape (n_time_bins,) with True at each session start,
  - an integer array of session-start indices,
  - a pynapple IntervalSet (requires X or y to be a pynapple object to supply timestamps).
  If None, the entire input is treated as a single session.
- Returns:
Filtering posterior probabilities, shape (n_time_bins, n_states). Each row sums to 1. Returns a pynapple TsdFrame (with columns named "state_0", "state_1", …) when the inputs are pynapple objects; otherwise returns a JAX array.
- Return type:
jnp.ndarray | nap.TsdFrame
- Raises:
ValueError – If the model has not been fitted (call fit() first).
ValueError – If X or y contain NaN values in the interior of an epoch (boundary NaNs are allowed and removed before inference).
ValueError – If X and y have inconsistent shapes or feature counts.
See also
smooth_proba – Compute smoothing posteriors (conditioned on all observations).
decode_state – Compute the most likely state sequence via Viterbi decoding.
Notes
Filtering is causal: each posterior at time t uses only observations up to t, which makes it suitable for online or real-time applications. For retrospective analysis where all data are available, smooth_proba() also conditions on later observations within the session and typically gives more accurate state estimates. Session boundaries reset the HMM chain so that no information crosses session borders.
Examples
Fit a GLM-HMM and compute filtering posteriors (causal/online):
>>> import numpy as np
>>> import nemos as nmo
>>> np.random.seed(123)
>>> X = np.random.randn(100, 5)
>>> y = np.random.poisson(2, size=100)
>>> model = nmo.glm_hmm.GLMHMM(n_states=3, observation_model="Poisson").fit(X, y)
>>> filt = model.filter_proba(X, y)
>>> filt.shape
(100, 3)
>>> bool(np.allclose(filt.sum(axis=1), 1.0))
True
With pynapple inputs the result is returned as a TsdFrame:
>>> import pynapple as nap
>>> t = np.arange(100) * 0.01
>>> X_tsd = nap.TsdFrame(t=t, d=X)
>>> y_tsd = nap.Tsd(t=t, d=y.astype(float))
>>> type(model.filter_proba(X_tsd, y_tsd)).__name__
'TsdFrame'
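Below is a matching NumPy sketch of the normalized forward pass behind filtering posteriors, including the reset at session starts described in the Notes. It is a textbook version, not the nemos code. The emission likelihoods are taken as given (a hypothetical input), and the name `forward_filter` is illustrative. Summing the log normalizers gives the marginal log-likelihood, the quantity `score` reports; we have not verified that nemos accumulates it exactly this way.

```python
import numpy as np


def forward_filter(lik, initial_prob, transition_prob, is_new_session):
    """Normalized forward pass with chain resets at session starts (hypothetical helper).

    lik:            (n_time_bins, n_states) emission likelihoods p(y_t | state, x_t)
    is_new_session: (n_time_bins,) boolean, True where a new session begins
    Returns the filtering posteriors (rows sum to 1) and the total log-likelihood.
    """
    n_time_bins, n_states = lik.shape
    post = np.empty((n_time_bins, n_states))
    total_log_lik = 0.0
    for t in range(n_time_bins):
        if t == 0 or is_new_session[t]:
            # Reset: start from the initial distribution, so nothing carries over.
            prior = initial_prob
        else:
            # Predict: push the previous bin's posterior through the transition matrix.
            prior = post[t - 1] @ transition_prob
        joint = prior * lik[t]  # update with the current observation
        norm = joint.sum()      # probability of y_t given earlier bins of the same session
        post[t] = joint / norm
        total_log_lik += np.log(norm)
    return post, total_log_lik


rng = np.random.default_rng(2)
lik = rng.uniform(0.1, 1.0, size=(10, 2))
initial_prob = np.array([0.6, 0.4])
transition_prob = np.array([[0.9, 0.1], [0.2, 0.8]])
starts = np.zeros(10, dtype=bool)
starts[[0, 5]] = True

post, ll = forward_filter(lik, initial_prob, transition_prob, starts)
# Filtering each session on its own gives the same result: no information crosses the boundary.
post_a, ll_a = forward_filter(lik[:5], initial_prob, transition_prob, np.zeros(5, dtype=bool))
post_b, ll_b = forward_filter(lik[5:], initial_prob, transition_prob, np.zeros(5, dtype=bool))
```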
- fit(X, y, init_params=None, session_starts=None)
Fit the GLM-HMM via Expectation-Maximization.
Runs the EM algorithm until the absolute change in log-likelihood between consecutive iterations falls below tol or maxiter is reached. Fitted parameters are exposed on the instance as coef_, intercept_, scale_, initial_prob_, transition_prob_, plus solver_state_ (EM trace) and dof_resid_.
How parameters are initialized:
  - If init_params is None (typical), the per-state GLM parameters and HMM probabilities are produced by the initializers configured via setup() (or the package defaults when setup() was never called).
  - If init_params is provided, it bypasses the initializers entirely. It must be a 5-tuple (coef, intercept, scale, initial_prob, transition_prob) whose shapes are consistent with X, y, and n_states.
- Parameters:
X (DESIGN_INPUT_TYPE) – Predictors, shape (n_time_bins, n_features). A pytree of arrays sharing the leading dimension is also accepted; the fitted coef_ mirrors the pytree structure (with a trailing state axis). A pynapple TsdFrame is accepted.
y (Union[NDArray, jnp.ndarray, nap.Tsd]) – Observations, shape (n_time_bins,) for a single neuron or (n_time_bins, n_neurons) for population models. A pynapple Tsd/TsdFrame is accepted.
init_params (Optional[GLMHMMUserParams]) – Optional explicit initial parameters as a 5-tuple (coef, intercept, scale, initial_prob, transition_prob). When None (default), the initializers configured by setup() (or the defaults) are used.
session_starts (Optional[jnp.ndarray]) – Optional session boundaries for the HMM. Accepts:
  - a boolean array of shape (n_time_bins,) with True at each session start,
  - an integer array of session-start indices,
  - a pynapple IntervalSet (requires X or y to be a pynapple object to supply timestamps).
  If X or y is a pynapple object and session_starts is None, the (unique, enforced) time_support of the pynapple input determines the session starts. With no pynapple input and session_starts=None, the whole input is treated as a single session.
- Returns:
The fitted estimator.
- Return type:
GLMHMM
- Raises:
ValueError – If inputs fail dimensionality, shape, or consistency checks (e.g. coef features do not match X.shape[1], or NaNs appear mid-epoch).
TypeError – If init_params is not a 5-tuple or has incompatible leaf types.
- Warns:
RuntimeWarning – Emitted when EM runs out of iterations without satisfying the tol criterion (solver_state_.iterations == maxiter). Consider enabling float64 in JAX (jax.config.update("jax_enable_x64", True)), raising maxiter, or loosening tol.
Examples
Basic fit with default Bernoulli observations:
>>> import numpy as np
>>> import nemos as nmo
>>> np.random.seed(0)
>>> X = np.random.normal(size=(200, 4))
>>> y = np.random.binomial(n=1, p=0.5, size=200)
>>> model = nmo.glm_hmm.GLMHMM(n_states=2).fit(X, y)
>>> model.coef_.shape, model.transition_prob_.shape
((4, 2), (2, 2))
Multiple sessions via explicit session_starts:
>>> session_starts = np.array([0, 100])
>>> model = nmo.glm_hmm.GLMHMM(n_states=2).fit(X, y, session_starts=session_starts)
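The stopping rule described above (stop when the absolute log-likelihood change falls below `tol`, otherwise warn after `maxiter` iterations) can be sketched as a generic loop. `run_em_loop` and `em_step` are hypothetical stand-ins, not nemos internals. `em_step` represents one E-step plus M-step. The toy step below halves a number, so the log-likelihood increases monotonically, as it should in EM, and the iteration count is predictable.

```python
import warnings


def run_em_loop(em_step, params, maxiter=1000, tol=1e-8):
    """Generic EM driver (hypothetical).

    em_step(params) -> (new_params, log_likelihood) stands in for one E-step + M-step.
    Returns (params, log_likelihood, n_iterations). Assumes maxiter >= 1.
    """
    prev_ll = None
    for iteration in range(1, maxiter + 1):
        params, ll = em_step(params)
        # Converged: the absolute log-likelihood change is below tol.
        if prev_ll is not None and abs(ll - prev_ll) < tol:
            return params, ll, iteration
        prev_ll = ll
    warnings.warn("EM reached maxiter without satisfying tol.", RuntimeWarning)
    return params, ll, maxiter


def halving_step(x):
    # Toy "EM step": the log-likelihood -x/2 rises toward 0 geometrically.
    new_x = x / 2.0
    return new_x, -new_x


params, ll, n_iter = run_em_loop(halving_step, 1.0, maxiter=100, tol=1e-3)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _, _, n_short = run_em_loop(halving_step, 1.0, maxiter=3, tol=1e-3)

print(n_iter, n_short)
```

Here the change between iterations k-1 and k is 2**-k, so the loop stops at iteration 10, the first k with 2**-k < 1e-3. With maxiter=3 it stops early and emits the RuntimeWarning.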
- get_metadata_routing()
Get metadata routing of this object.
Please check the scikit-learn User Guide on how the routing mechanism works.
- Returns:
routing – A MetadataRequest encapsulating routing information.
- Return type:
MetadataRequest
- get_params(deep=True)
From scikit-learn, get parameters by inspecting init.
- Parameters:
deep – If True, return the parameters for this estimator and for contained subobjects that are estimators.
- Return type:
dict
- Returns:
A dictionary containing the parameters. Each key is a parameter name and each value is the corresponding parameter value.
- property hmm_initialization_funcs: dict[Literal['initial_proba_init', 'initial_proba_init_kwargs', 'initial_proba_init_custom', 'transition_proba_init', 'transition_proba_init_kwargs', 'transition_proba_init_custom'], InitFunctionHMM | dict[str, Any] | bool] | None
Dictionary of initialization functions for HMM parameters.
- initialize_optimizer_and_state(init_params, X, y)
Initialize the optimization routine and its state for running fit and update.
This method must be called before using update() for iterative optimization. It sets up the solver with the provided initial parameters and data.
- Parameters:
init_params (UserProvidedParamsT) – Initial parameter tuple of (coefficients, intercept).
X (DESIGN_INPUT_TYPE) – Input data, array of shape (n_time_bins, n_features) or pytree of same.
y (jnp.ndarray) – Target data, array of shape (n_time_bins,) for single neuron models or (n_time_bins, n_neurons) for population models.
- Returns:
Initial solver state.
- Return type:
SolverState
- Raises:
ValueError – If inputs or parameters have incompatible shapes or invalid values.
- initialize_params(X, y)
Initialize model parameters.
Initialize coefficients with zeros and intercept by matching the mean firing rate.
- Parameters:
X (DESIGN_INPUT_TYPE) – Input data, array of shape (n_time_bins, n_features) or pytree of same.
y (jnp.ndarray) – Target data, array of shape (n_time_bins,) for single neuron models or (n_time_bins, n_neurons) for population models.
- Returns:
Initial parameter tuple of (coefficients, intercept).
- Return type:
UserProvidedParamsT
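The mean-matching intercept works as follows. With all coefficients at zero, choosing the intercept as the link function (the inverse of the inverse link) applied to the mean response makes the predicted mean equal the empirical mean. The NumPy sketch below illustrates that idea and assumes an invertible inverse link. It is not the nemos initializer, and for a GLM-HMM the initializers configured via setup() may behave differently. `mean_matching_intercept` and `logit` are hypothetical helpers.

```python
import numpy as np


def mean_matching_intercept(y, link):
    """Intercept b with inverse_link(b) == mean(y) when all coefficients are zero.

    `link` is the inverse of the model's inverse link function. With axis=0,
    a 2-D y of shape (n_time_bins, n_neurons) gives one intercept per neuron.
    """
    return link(np.mean(y, axis=0))


def logit(p):
    return np.log(p / (1.0 - p))


# Poisson: the inverse link is exp, so the link is log.
y_counts = np.array([0, 2, 1, 3, 4])
b_poisson = mean_matching_intercept(y_counts, np.log)

# Bernoulli: the inverse link is the sigmoid, so the link is logit.
y_binary = np.array([0, 1, 1, 1])
b_bernoulli = mean_matching_intercept(y_binary, logit)
print(b_poisson, b_bernoulli)
```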
- property inverse_link_function
Inverse link function mapping the linear predictor to the emission space.
Always a callable. If None was passed at construction time, this resolves to the observation model's default (e.g. jnp.exp for Poisson, 1 / x for Gamma, jax.nn.sigmoid for Bernoulli). Shared across all hidden states.
- property maxiter
EM maximum number of iterations.
- property model_initialization_funcs: MODEL_INITIALIZATION_FN_DICT_T | None
Dictionary of initialization functions for model parameters.
- property observation_model: Observations
The observation model governing the emission distribution at each state.
Always an instance of an Observations subclass. The same distribution is used across all hidden states; per-state differences come from the state-specific coefficients, intercept, and scale, not from the distribution family. If a string alias was passed at construction time, it is resolved to the corresponding instance here.
- property optimizer_init_state: None | Callable[[Any, Array, Array], SolverState]
Provides the initialization function for the optimizer state.
This function is responsible for initializing the optimizer state, necessary for the start of the optimizer process. It sets up initial values for parameters like gradients and step sizes based on the model configuration and input data.
- Returns:
The function to initialize the optimizer state, if available; otherwise, None if the optimizer has not yet been instantiated.
- property optimizer_run: None | Callable[[Any, Array, Array], Tuple[Any, SolverState, Aux]]
Provides the function to execute the optimization process.
This function runs the optimizer using the initialized parameters and state, performing the optimization to fit the model to the data. It iteratively updates the model parameters until a stopping criterion is met, such as convergence or exceeding a maximum number of iterations.
- Returns:
The function to run the optimization process, if available; otherwise, None if the optimizer has not yet been instantiated.
- property optimizer_update: None | Callable[[Any, NamedTuple, Array, Array], Tuple[Any, SolverState, Aux]]
Provides the function for updating the state during the optimization process.
This function is used to perform a single update step in the optimization process. It updates the model’s parameters based on the current state, data, and gradients. It is typically used in scenarios where fine-grained control over each optimizer step is necessary, such as in online learning or complex optimization scenarios.
- Returns:
The function to perform a single optimization update step, if available; otherwise, None if the optimizer has not yet been instantiated.
- property regularizer: None | Regularizer
Getter for the regularizer attribute.
- save_params(filename)
Save GLM-HMM model parameters and fit state to a .npz file.
Persists the hyperparameters returned by get_params() together with the fitted attributes (coef_, intercept_, scale_, initial_prob_, transition_prob_, dof_resid_). The solver_state_ is intentionally excluded because it is solver-specific and not needed to reuse the fitted model. The file can be reloaded with nemos.load_model().
If the model was configured with custom initialization functions, pass them back to nemos.load_model() via mapping_dict to restore them (see the example below). Built-in initializers are resolved automatically.
Examples
Default round-trip — built-in initializers are resolved automatically on load:
>>> import os, tempfile
>>> import numpy as np
>>> import nemos as nmo
>>> np.random.seed(0)
>>> X = np.random.normal(size=(80, 3))
>>> y = np.random.binomial(n=1, p=0.5, size=80)
>>> model = nmo.glm_hmm.GLMHMM(n_states=2).fit(X, y)
>>> with tempfile.TemporaryDirectory() as d:
...     path = os.path.join(d, "glmhmm.npz")
...     model.save_params(path)
...     loaded = nmo.load_model(path)
>>> bool(np.allclose(model.coef_, loaded.coef_))
True
Round-trip with a custom GLM-params initializer. Pass it back as a partial dict under model_initialization_funcs; the remaining slots fall back to the saved (built-in) names:
>>> import jax.numpy as jnp
>>> def my_glm_init(
...     n_states, X, y, inverse_link_function, observation_model,
...     session_starts, random_key,
... ):
...     return jnp.zeros((X.shape[1], n_states)), jnp.zeros((n_states,))
>>> model = nmo.glm_hmm.GLMHMM(n_states=2)
>>> model.setup(glm_params_init=my_glm_init)
>>> _ = model.fit(X, y)
>>> with tempfile.TemporaryDirectory() as d:
...     path = os.path.join(d, "glmhmm.npz")
...     model.save_params(path)
...     loaded = nmo.load_model(
...         path,
...         mapping_dict={
...             "model_initialization_funcs": {"glm_params_init": my_glm_init},
...         },
...     )
>>> loaded.model_initialization_funcs["glm_params_init"] is my_glm_init
True
- score(X, y, session_starts=None)
Marginal log-likelihood of the data under the fitted HMM.
HMM-family models score only by log-likelihood. Variance-based or deviance-based pseudo-R² metrics are not implemented because they depend on a null/saturated-model construction that has no clean analogue for latent-state sequence models. Compute AIC/BIC or held-out log-likelihood externally if needed.
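As a minimal sketch of computing AIC/BIC externally, both criteria can be derived from the summed log-likelihood returned by score(). The helpers count_free_params and information_criteria are hypothetical (not part of nemos), and the parameter count is an assumption: one coefficient per feature, neuron and state, one intercept per neuron and state, a caller-supplied number of scale parameters, and probability vectors constrained to sum to 1.

```python
import math


def count_free_params(n_features, n_states, n_neurons=1, n_scale_params=0):
    """Count free parameters of a GLM-HMM (hypothetical helper).

    Assumes per-state GLM weights (n_features coefficients + 1 intercept per
    neuron), caller-supplied scale parameters, n_states - 1 free initial
    probabilities, and n_states - 1 free entries per transition-matrix row.
    """
    glm_params = n_states * n_neurons * (n_features + 1)
    initial_params = n_states - 1
    transition_params = n_states * (n_states - 1)
    return glm_params + n_scale_params + initial_params + transition_params


def information_criteria(log_likelihood, n_params, n_samples):
    """Return (AIC, BIC) given a log-likelihood summed over n_samples bins."""
    aic = 2 * n_params - 2 * log_likelihood
    bic = n_params * math.log(n_samples) - 2 * log_likelihood
    return aic, bic


# Made-up log-likelihood value, for illustration only.
k = count_free_params(n_features=3, n_states=2)  # 8 + 0 + 1 + 2 = 11
aic, bic = information_criteria(log_likelihood=-50.0, n_params=k, n_samples=80)
```

With a fitted model, log_likelihood would be float(model.score(X, y)); when comparing, for example, different n_states on the same data, lower values favor a model.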
- Parameters:
X (Union[DESIGN_INPUT_TYPE, ArrayLike]) – Input data/design matrix, shape (n_samples, n_features).
y (ArrayLike) – Output data/observations, shape (n_samples, n_observations).
session_starts (Optional[ArrayLike]) – Optional array indicating user-provided session boundaries. Can be:
- a boolean array indicating session starts, shape (n_samples,),
- an integer array of indices marking session starts, shape (n_sessions,),
- a pynapple.IntervalSet marking session epochs (requires either X or y to be a pynapple Tsd or TsdFrame to get timestamps).
If None, a default array is created that treats all data as one session.
- Return type:
jnp.ndarray
- Returns:
The marginal log-likelihood (summed over time).
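The marginal log-likelihood of an HMM is usually computed with the forward algorithm. The NumPy sketch below is a textbook version, not the nemos implementation: it assumes the per-state emission log-likelihoods log p(y_t | state k) have already been evaluated into an (n_time_bins, n_states) array, converts integer session-start indices into the boolean form accepted by session_starts, and restarts the chain from the initial-state distribution at every session start, so the log-likelihoods of the sessions add up. All function names are hypothetical.

```python
import numpy as np


def logsumexp(a, axis=None):
    """Numerically stable log(sum(exp(a))) along axis."""
    m = np.max(a, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))
    return out.item() if axis is None else np.squeeze(out, axis=axis)


def starts_to_mask(start_indices, n_time_bins):
    """Boolean session-start mask from integer session-start indices."""
    mask = np.zeros(n_time_bins, dtype=bool)
    mask[np.asarray(start_indices, dtype=int)] = True
    return mask


def marginal_log_likelihood(log_emissions, initial_prob, transition_prob, start_mask):
    """Sum over sessions of log p(y_session) via the forward algorithm."""
    log_init = np.log(initial_prob)
    log_trans = np.log(transition_prob)
    total, log_alpha = 0.0, None
    for t in range(log_emissions.shape[0]):
        if log_alpha is None or start_mask[t]:
            if log_alpha is not None:
                total += logsumexp(log_alpha)  # close the previous session
            log_alpha = log_init + log_emissions[t]  # restart the chain
        else:
            # log alpha_t(j) = log p(y_t | j) + log sum_i alpha_{t-1}(i) A[i, j]
            log_alpha = log_emissions[t] + logsumexp(log_alpha[:, None] + log_trans, axis=0)
    return total + logsumexp(log_alpha)
```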
- property seed
Random seed as a jax PRNG key.
- set_fit_request(*, init_params='$UNCHANGED$', session_starts='$UNCHANGED$')
Configure whether metadata should be requested to be passed to the fit method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to fit.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- Parameters:
init_params (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for the init_params parameter in fit.
session_starts (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for the session_starts parameter in fit.
self (GLMHMM)
- Returns:
self – The updated object.
- Return type:
GLMHMM
- set_params(**params)
Manage warnings in case of multiple parameter settings.
- Parameters:
params (Any)
- set_score_request(*, session_starts='$UNCHANGED$')
Configure whether metadata should be requested to be passed to the score method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to score.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- setup(initial_proba_init=None, initial_proba_init_kwargs=None, transition_proba_init=None, transition_proba_init_kwargs=None, glm_params_init=None, glm_params_init_kwargs=None, scale_init=None, scale_init_kwargs=None)
Configure how fit() initializes each model parameter.
Calling setup() is optional: if it is never called, fitting starts from the default initializers listed below. Use it to change the initialization strategy by providing either the name of a built-in initialization function or a custom callable. Each argument left as None keeps the previously configured value; only the parameters you supply are updated.
Available built-in initialization functions:
initial_proba_init: "uniform" (default), "random", "dirichlet", "kmeans".
transition_proba_init: "sticky" (default), "uniform", "random", "dirichlet", "kmeans".
glm_params_init: "random" (default), "kmeans".
scale_init: "constant" (default), "kmeans".
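The label "sticky" conventionally denotes a transition matrix that favors remaining in the current state. The exact values nemos uses are not stated here, so the sketch below is an illustration under an assumption: a hypothetical self-transition probability p_stay sits on the diagonal and the remaining mass is spread evenly across the other states, shown next to a "uniform" initial-state vector. Both helper names are hypothetical.

```python
import numpy as np


def sticky_transition_matrix(n_states, p_stay=0.95):
    """Row-stochastic matrix with p_stay on the diagonal (hypothetical helper)."""
    if n_states == 1:
        return np.ones((1, 1))
    # Split the leftover probability evenly among the other n_states - 1 states.
    off_diag = (1.0 - p_stay) / (n_states - 1)
    matrix = np.full((n_states, n_states), off_diag)
    np.fill_diagonal(matrix, p_stay)
    return matrix


def uniform_initial_proba(n_states):
    """Equal probability of starting in each state."""
    return np.full(n_states, 1.0 / n_states)
```

A custom transition_proba_init built on logic like this would still have to be wrapped in a callable satisfying InitFunctionHMM, whose signature can be inspected with help(InitFunctionHMM).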
- Parameters:
initial_proba_init (Union[Literal['uniform', 'random', 'dirichlet', 'kmeans'], InitFunctionHMM, None]) – Built-in name or custom callable used to initialize the initial-state probabilities (shape (n_states,)).
initial_proba_init_kwargs (Optional[dict]) – Extra keyword arguments forwarded to initial_proba_init.
transition_proba_init (Union[Literal['sticky', 'uniform', 'random', 'dirichlet', 'kmeans'], InitFunctionHMM, None]) – Built-in name or custom callable used to initialize the transition matrix (shape (n_states, n_states)).
transition_proba_init_kwargs (Optional[dict]) – Extra keyword arguments forwarded to transition_proba_init.
glm_params_init (Union[Literal['random', 'kmeans'], InitFunctionGLM, None]) – Built-in name or custom callable used to initialize the per-state GLM coefficients and intercepts.
glm_params_init_kwargs (Optional[dict]) – Extra keyword arguments forwarded to glm_params_init.
scale_init (Union[Literal['constant', 'kmeans'], InitFunctionGLM, None]) – Built-in name or custom callable used to initialize the scale parameter of the observation model (e.g. variance for Gaussian, dispersion for NegativeBinomial). Ignored by observation models without a scale.
scale_init_kwargs (Optional[dict]) – Extra keyword arguments forwarded to scale_init.
- Raises:
ValueError – If a custom callable's signature is incompatible with the corresponding protocol (see Notes), or if a *_kwargs entry contains keys that don't match the corresponding initializer's signature.
Notes
Custom callables must satisfy one of two typing.Protocol classes:
initial_proba_init and transition_proba_init must satisfy InitFunctionHMM and return a jnp.ndarray of shape (n_states,) or (n_states, n_states), respectively.
glm_params_init and scale_init must satisfy InitFunctionGLM. glm_params_init returns (coef, intercept) with shapes matched to the design matrix and n_states; scale_init returns the scale array for the observation model.
To inspect a protocol's signature, import it and call help() on it:
from nemos.hmm.initialize_parameters import InitFunctionHMM
from nemos.glm_hmm.initialize_parameters import InitFunctionGLM
help(InitFunctionHMM)  # or help(InitFunctionGLM)
All arguments must appear in the function signature even when unused, so the framework can supply them uniformly.
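Because every protocol argument must appear in the signature, a quick pre-check with the standard-library inspect module can catch a missing name before calling setup(). This is a rough sketch, not the validation nemos performs: missing_init_args is a hypothetical helper, and the required names are the ones used by the custom glm_params_init examples in this documentation.

```python
import inspect

# Argument names used by the custom glm_params_init examples in these docs.
GLM_INIT_ARGS = (
    "n_states", "X", "y", "inverse_link_function",
    "observation_model", "session_starts", "random_key",
)


def missing_init_args(func, required=GLM_INIT_ARGS):
    """Return required argument names absent from func's signature (hypothetical)."""
    params = inspect.signature(func).parameters
    return [name for name in required if name not in params]


def complete_init(n_states, X, y, inverse_link_function, observation_model,
                  session_starts, random_key):
    """Lists every argument, even those it would not use."""
    return None


def incomplete_init(n_states, X, y):
    """Omits four protocol arguments."""
    return None


print(missing_init_args(complete_init))    # []
print(missing_init_args(incomplete_init))  # ['inverse_link_function', 'observation_model', 'session_starts', 'random_key']
```

A function that collects arguments through **kwargs would be reported as incomplete by this check; whether nemos accepts such a signature is not covered here.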
Examples
Switch a parameter to a different built-in scheme by passing its label:
>>> from nemos.glm_hmm import GLMHMM
>>> model = GLMHMM(n_states=3)
>>> model.setup(initial_proba_init="random", glm_params_init="kmeans")
Plug in a custom callable matching the GLM-side protocol:
>>> import jax.numpy as jnp
>>> def my_glm_init(
...     n_states, X, y, inverse_link_function, observation_model,
...     session_starts, random_key,
... ):
...     coef = jnp.zeros((X.shape[1], n_states))
...     intercept = jnp.zeros((n_states,))
...     return coef, intercept
>>> model.setup(glm_params_init=my_glm_init)
- simulate(random_key, feedforward_input, state_format='index', session_starts=None)
Simulate neural activity and hidden states from the model.
Simulates a trajectory through the hidden state space according to the HMM dynamics, then generates observations from the GLM emission model conditioned on each state.
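The two-stage procedure (sample a state path, then sample emissions conditioned on it) can be sketched in NumPy for a single-neuron Bernoulli GLM-HMM. This is an illustration under stated assumptions, not the nemos implementation: simulate_bernoulli_glmhmm is hypothetical, coefficients are laid out as (n_features, n_states) with intercepts of shape (n_states,) (the layout used by the custom glm_params_init example), the inverse link is the logistic function (the Bernoulli default), and the chain restarts from the initial distribution at each session start.

```python
import numpy as np


def simulate_bernoulli_glmhmm(rng, X, coef, intercept, initial_prob,
                              transition_prob, start_mask=None):
    """Sample (activity, rates, states) for a single-neuron Bernoulli GLM-HMM.

    Hypothetical helper: coef has shape (n_features, n_states), intercept has
    shape (n_states,), and start_mask is a boolean array marking session starts.
    """
    n_time_bins = X.shape[0]
    n_states = initial_prob.shape[0]
    if start_mask is None:
        start_mask = np.zeros(n_time_bins, dtype=bool)
    states = np.empty(n_time_bins, dtype=int)
    for t in range(n_time_bins):
        if t == 0 or start_mask[t]:
            # New session: draw from the initial-state distribution.
            states[t] = rng.choice(n_states, p=initial_prob)
        else:
            # Markov step: draw from the transition row of the previous state.
            states[t] = rng.choice(n_states, p=transition_prob[states[t - 1]])
    # Linear predictor of each bin under the GLM of its sampled state.
    eta = np.sum(X * coef[:, states].T, axis=1) + intercept[states]
    rates = 1.0 / (1.0 + np.exp(-eta))  # logistic inverse link
    activity = rng.binomial(1, rates)
    return activity, rates, states
```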
- Parameters:
random_key (jax.Array) – JAX random key for reproducible simulation.
feedforward_input (DESIGN_INPUT_TYPE) – Design matrix of shape (n_time_bins, n_features). If a pynapple Tsd/TsdFrame is provided, session boundaries are detected from time_support and the hidden state chain is reset at each session start.
state_format (Literal['one-hot', 'index']) – Format for the returned states:
"index": integer array of shape (n_time_bins,) with state indices.
"one-hot": binary array of shape (n_time_bins, n_states).
session_starts (Optional[jax.Array]) – Optional session boundaries. Accepts:
a boolean array of shape (n_time_bins,) with True at each session start,
an integer array of session-start indices,
a pynapple IntervalSet (requires feedforward_input to be a pynapple object to supply timestamps).
If feedforward_input is a pynapple object and session_starts is None, the time_support determines the session starts. With no pynapple input and session_starts=None, the whole input is treated as a single session.
- Return type:
Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
- Returns:
simulated_activity – Simulated observations from the emission model. Shape (n_time_bins,) for a single neuron or (n_time_bins, n_neurons) for population models.
firing_rates – Predicted firing rates conditioned on the simulated states. Shape (n_time_bins,) or (n_time_bins, n_neurons).
simulated_states – Simulated hidden state trajectory. Shape depends on state_format.
- Raises:
ValueError – If the model has not been fit.
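The two state_format options carry the same information. A minimal NumPy sketch of converting between them, with hypothetical helper names (the dtype nemos uses for the one-hot array is not specified here):

```python
import numpy as np


def index_to_one_hot(states, n_states):
    """(n_time_bins,) integer states -> (n_time_bins, n_states) binary array."""
    return np.eye(n_states, dtype=int)[np.asarray(states)]


def one_hot_to_index(one_hot):
    """(n_time_bins, n_states) binary array -> (n_time_bins,) integer states."""
    return np.argmax(np.asarray(one_hot), axis=1)


states = np.array([0, 2, 1, 1])
one_hot = index_to_one_hot(states, n_states=3)
recovered = one_hot_to_index(one_hot)
```

The same argmax also turns smoothing posteriors (as returned by smooth_proba) into a per-bin most probable state, which in general differs from the jointly most likely path returned by decode_state.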
Examples
>>> import jax
>>> import numpy as np
>>> import nemos as nmo
>>> np.random.seed(123)
>>> X = np.random.randn(100, 3)
>>> y = np.random.binomial(1, 0.5, 100)
>>> model = nmo.glm_hmm.GLMHMM(n_states=2, observation_model="Bernoulli")
>>> model = model.fit(X, y)
>>> key = jax.random.key(0)
>>> X_new = np.random.randn(50, 3)
>>> activity, rates, states = model.simulate(key, X_new)
>>> activity.shape
(50,)
>>> states.shape
(50,)
See also
decode_state – Infer the most likely state sequence from observations.
smooth_proba – Compute posterior state probabilities.
- smooth_proba(X, y, session_starts=None)
Compute smoothing posterior probabilities over hidden states.
Computes the probability of being in each hidden state at each time bin, conditioned on the entire observed sequence. Uses the forward-backward algorithm to incorporate information from both past and future observations, providing optimal state estimates given all available data.
The smoothing posteriors answer: "Given all observations, what is the probability that the system was in state k at time t?"
- Parameters:
X (Union[DESIGN_INPUT_TYPE, ArrayLike]) – Predictors, shape (n_time_bins, n_features). A pytree of 2-D arrays sharing the leading time axis is also accepted.
y (Union[NDArray, jnp.ndarray, nap.Tsd]) – Observations, shape (n_time_bins,) for a single neuron or (n_time_bins, n_neurons) for a population model. A pynapple Tsd/TsdFrame is accepted; session boundaries are then inferred from time_support.
session_starts (Optional[ArrayLike]) – Optional session boundaries. Accepts:
a boolean array of shape (n_time_bins,) with True at each session start,
an integer array of session-start indices,
a pynapple IntervalSet (requires X or y to be a pynapple object to supply timestamps).
If None and no pynapple input supplies a time_support, the entire input is treated as a single session.
- Returns:
Smoothing posterior probabilities, shape (n_time_bins, n_states). Each row sums to 1. Returns a pynapple TsdFrame (with columns named "state_0", "state_1", …) when the inputs are pynapple objects; otherwise returns a JAX array.
- Return type:
jnp.ndarray | nap.TsdFrame
- Raises:
ValueError – If the model has not been fitted (call fit() first).
ValueError – If X or y contain NaN values in the interior of an epoch (boundary NaNs are allowed and removed before inference).
ValueError – If X and y have inconsistent shapes or feature counts.
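The forward-backward recursion can be sketched for a single session in NumPy. This is a textbook scaled version, not the nemos implementation; forward_backward is a hypothetical helper that takes precomputed per-state emission likelihoods p(y_t | state k). The forward pass alone yields filtering posteriors (past observations only, the quantity filter_proba reports), and the backward pass converts them into smoothing posteriors. With several sessions, each session would be processed separately, which is what resetting the chain at session boundaries amounts to.

```python
import numpy as np


def forward_backward(emission_lik, initial_prob, transition_prob):
    """Filtering and smoothing posteriors for one session (hypothetical helper).

    emission_lik: (n_time_bins, n_states) array of p(y_t | state k).
    Returns (filtered, smoothed), each of shape (n_time_bins, n_states).
    """
    n_time_bins, n_states = emission_lik.shape
    filtered = np.empty((n_time_bins, n_states))
    norm = np.empty(n_time_bins)

    # Forward pass: filtered[t] = p(state_t | y_0..y_t), renormalized each step.
    a = initial_prob * emission_lik[0]
    norm[0] = a.sum()
    filtered[0] = a / norm[0]
    for t in range(1, n_time_bins):
        a = (filtered[t - 1] @ transition_prob) * emission_lik[t]
        norm[t] = a.sum()
        filtered[t] = a / norm[t]

    # Backward pass, rescaled with the same per-step normalizers.
    beta = np.ones((n_time_bins, n_states))
    for t in range(n_time_bins - 2, -1, -1):
        beta[t] = transition_prob @ (emission_lik[t + 1] * beta[t + 1]) / norm[t + 1]

    # Smoothing: combine past (filtered) and future (beta) information.
    smoothed = filtered * beta
    smoothed /= smoothed.sum(axis=1, keepdims=True)
    return filtered, smoothed
```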
See also
filter_proba – Compute filtering posteriors (conditioned on past observations only).
decode_state – Compute the most likely state sequence via Viterbi decoding.
Notes
Smoothing uses all data (non-causal) and therefore generally gives better state estimates than filtering. For online or real-time applications, use filter_proba() instead. Session boundaries reset the HMM chain so that no information crosses session borders.
Examples
Fit a GLM-HMM and compute smoothing posteriors:
>>> import numpy as np
>>> import nemos as nmo
>>> np.random.seed(123)
>>> X = np.random.randn(100, 5)
>>> y = np.random.poisson(2, size=100)
>>> model = nmo.glm_hmm.GLMHMM(n_states=3, observation_model="Poisson").fit(X, y)
>>> posteriors = model.smooth_proba(X, y)
>>> posteriors.shape
(100, 3)
>>> bool(np.allclose(posteriors.sum(axis=1), 1.0))
True
With pynapple inputs, the result is returned as a TsdFrame:
>>> import pynapple as nap
>>> t = np.arange(100) * 0.01
>>> X_tsd = nap.TsdFrame(t=t, d=X)
>>> y_tsd = nap.Tsd(t=t, d=y.astype(float))
>>> type(model.smooth_proba(X_tsd, y_tsd)).__name__
'TsdFrame'
- property solver_kwargs
Getter for the solver_kwargs attribute.
- property solver_spec: SolverSpec
Getter for the solver specification.
- property tol
Tolerance for the EM algorithm convergence criterion.
The algorithm stops when the absolute change in log-likelihood between consecutive iterations falls below this threshold:
|log_likelihood_current - log_likelihood_previous| < tol
- Returns:
Convergence tolerance value.
- Return type:
float
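The stopping rule can be sketched as a generic loop. run_em and step_fn are hypothetical: step_fn stands in for one EM iteration (for instance, a user-written wrapper around update() and score()) and must return the new parameters together with their log-likelihood. The loop stops when the absolute change falls below tol or after maxiter iterations, mirroring the tol and maxiter constructor arguments.

```python
def run_em(step_fn, params, tol=1e-8, maxiter=1000):
    """Iterate step_fn until |ll_current - ll_previous| < tol or maxiter is hit.

    step_fn(params) -> (new_params, log_likelihood) is a hypothetical EM step.
    Returns (params, last_log_likelihood, n_iterations).
    """
    ll_previous = None
    for n_iter in range(1, maxiter + 1):
        params, ll_current = step_fn(params)
        if ll_previous is not None and abs(ll_current - ll_previous) < tol:
            return params, ll_current, n_iter
        ll_previous = ll_current
    return params, ll_previous, maxiter


# Toy step: the parameter halves and the log-likelihood -x rises toward 0.
def halving_step(x):
    new_x = x / 2
    return new_x, -new_x


params, ll, n_iter = run_em(halving_step, 1.0, tol=1e-8)
```

A relative criterion (change divided by |ll_previous|) would be an alternative design; the rule documented here is absolute.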
- update(params, opt_state, X, y, *args, session_starts=None, n_samples=None, **kwargs)
Run a single EM iteration on the GLM-HMM.
Performs one E-step/M-step pair starting from the supplied parameters and EM state, updates the model's fitted attributes (coef_, intercept_, scale_, initial_prob_, transition_prob_, solver_state_, dof_resid_) in place, and returns the updated parameter tuple and EM state. Intended for callers that need fine-grained control over EM iteration (e.g. checkpointing, custom convergence criteria) instead of the bundled fit() loop.
initialize_optimizer_and_state() must be called first so that the EM step function and initial opt_state are available.
- Parameters:
params (GLMHMMUserParams) – Current model parameters as a 5-tuple (coef, intercept, scale, initial_prob, transition_prob) matching the structure produced by initialize_params().
opt_state (NamedTuple) – EM state returned by initialize_optimizer_and_state() or by the previous call to update().
X (DESIGN_INPUT_TYPE) – Predictors, shape (n_time_bins, n_features) (or a pytree of arrays of the same shape).
y (jnp.ndarray) – Observations, shape (n_time_bins,) or (n_time_bins, n_neurons).
session_starts (Optional[jnp.ndarray]) – Optional session boundaries. Accepts:
a boolean array of shape (n_time_bins,) with True at each session start,
an integer array of session-start indices,
a pynapple IntervalSet (requires X or y to be a pynapple object to supply timestamps).
If None, the entire input is treated as a single session.
n_samples (Optional[int]) – Total sample count to use when estimating the residual degrees of freedom. Defaults to X.shape[0].
- Return type:
StepResult
- Returns:
params – Updated user-facing parameter tuple.
state – Updated EM state.
- Raises:
ValueError – If inputs fail shape/consistency validation.
Examples
>>> import numpy as np
>>> import nemos as nmo
>>> np.random.seed(0)
>>> X = np.random.normal(size=(80, 3))
>>> y = np.random.binomial(n=1, p=0.5, size=80)
>>> model = nmo.glm_hmm.GLMHMM(n_states=2)
>>> init_params = model.initialize_params(X, y)
>>> opt_state = model.initialize_optimizer_and_state(init_params, X, y)
>>> new_params, new_state = model.update(init_params, opt_state, X, y)