nemos.glm_hmm.GLMHMM#

class nemos.glm_hmm.GLMHMM(n_states, observation_model='Bernoulli', inverse_link_function=None, regularizer='Ridge', regularizer_strength=1.0, dirichlet_initial_proba=None, dirichlet_transition_proba=None, solver_name=None, solver_kwargs=None, maxiter=1000, tol=1e-08, seed=Array([0, 123], dtype=uint32), hmm_initialization_funcs=None, model_initialization_funcs=None)[source]#

Bases: BaseHMM

Generalized Linear Model with Hidden Markov Model (GLM-HMM).

This model combines a Generalized Linear Model (GLM) with a Hidden Markov Model (HMM) to capture state-dependent relationships between predictors and neural or behavioral responses. The GLM-HMM is suitable for modeling time series data where the relationship between inputs and outputs varies according to an underlying latent state that evolves over time following Markovian dynamics.

The model assumes that at each time step, the system is in one of n_states discrete hidden states. Each state has its own GLM parameters (coefficients and intercept), and transitions between states are governed by a transition probability matrix. The model is fitted using the Expectation-Maximization (EM) algorithm.

Below is a table of the default inverse link function for the available observation models.

| Observation Model | Default Inverse Link Function |
|---|---|
| Poisson | \(e^x\) |
| Gamma | \(1/x\) |
| Bernoulli | \(1 / (1 + e^{-x})\) |
| NegativeBinomial | \(e^x\) |
| Gaussian | \(x\) |
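
As a quick illustration of the table above, the default inverse link functions can be written in plain NumPy. This is only a sketch and not the nemos implementation; the dictionary name `default_inverse_links` is hypothetical.

```python
import numpy as np

# Hypothetical lookup mirroring the table above; not part of the nemos API.
default_inverse_links = {
    "Poisson": np.exp,
    "Gamma": lambda x: 1.0 / x,
    "Bernoulli": lambda x: 1.0 / (1.0 + np.exp(-x)),
    "NegativeBinomial": np.exp,
    "Gaussian": lambda x: x,
}

# Apply each inverse link to the same linear predictor values.
linear_predictor = np.array([-1.0, 0.5, 2.0])
for name, inv_link in default_inverse_links.items():
    print(name, inv_link(linear_predictor))
```

Passing a callable as inverse_link_function replaces the default for the chosen observation model.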

Below is a table listing the default and available solvers for each regularizer.

| Regularizer | Default Solver | Available Solvers |
|---|---|---|
| UnRegularized | LBFGS | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient |
| Ridge | LBFGS | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient |
| Lasso | ProximalGradient | ProximalGradient |
| ElasticNet | ProximalGradient | ProximalGradient |
| GroupLasso | ProximalGradient | ProximalGradient |

Parameters:
  • n_states (int) – The number of hidden states in the HMM. Must be a positive integer.

  • observation_model (Union[Observations, Literal['Poisson', 'Gamma', 'Bernoulli', 'NegativeBinomial', 'Gaussian']]) – Observation model to use. The model describes the distribution of the response variable. Default is the Bernoulli model. Alternatives are “Poisson”, “Gamma”, “NegativeBinomial”, and “Gaussian”.

  • inverse_link_function (Optional[Callable[[Array], Array]]) – A function that maps the linear combination of predictors into a rate or probability. The default depends on the observation model, see the table above.

  • regularizer (Union[str, Regularizer]) – Regularization scheme used in the M-step for the per-state GLM coefficients. Default is "Ridge". Pass "UnRegularized" to disable regularization.

  • regularizer_strength (Any) – Strength of the regularization applied to the GLM coefficients. Default is 1.0. Ignored when regularizer="UnRegularized".

  • dirichlet_initial_proba (Optional[Array]) – Alpha parameters for the Dirichlet prior over the initial state probabilities. Shape (n_states,). If None, a flat (uninformative) prior is assumed.

  • dirichlet_transition_proba (Optional[Array]) – Alpha parameters for the Dirichlet prior over the transition probabilities. Shape (n_states, n_states). If None, a flat (uninformative) prior is assumed.

  • solver_name (str) – Solver used for the GLM M-step. The solver must be valid for the chosen regularizer (see table above). Default is None, in which case the regularizer’s default solver is selected ("LBFGS" for Ridge / UnRegularized, "ProximalGradient" for Lasso / ElasticNet / GroupLasso).

  • solver_kwargs (Optional[dict]) – Optional dictionary for keyword arguments that are passed to the solver when instantiated. E.g., stepsize, tol, acceleration, etc.

  • maxiter (int) – Maximum number of EM iterations. Default is 1000.

  • tol (float) – Convergence tolerance for the EM algorithm. The algorithm stops when the absolute change in log-likelihood between consecutive iterations falls below this threshold. Default is 1e-8.

  • seed – JAX PRNG key for random number generation during initialization. Default is jax.random.PRNGKey(123).

  • hmm_initialization_funcs (Optional[dict[Literal['initial_proba_init', 'initial_proba_init_kwargs', 'initial_proba_init_custom', 'transition_proba_init', 'transition_proba_init_kwargs', 'transition_proba_init_custom'], InitFunctionHMM | dict[str, Any] | bool]]) – Dictionary of initialization functions for HMM probabilities (initial and transition). Included for scikit-learn compatibility; prefer configuring via the setup() method after construction. If None, defaults from DEFAULT_INIT_FUNCTIONS are used.

  • model_initialization_funcs (Optional[dict[Literal['glm_params_init', 'glm_params_init_kwargs', 'glm_params_init_custom', 'scale_init', 'scale_init_kwargs', 'scale_init_custom'], InitFunctionGLM | InitFunctionHMM | dict[str, Any] | bool]]) – Dictionary of initialization functions for the GLM-specific parameters (coefficients, intercept, and scale). Included for scikit-learn compatibility; prefer configuring via the setup() method after construction. If None, defaults from DEFAULT_INIT_FUNCTIONS_GLMHMM are used.

transition_prob_#

Transition probability matrix of shape (n_states, n_states). Entry [i, j] represents the probability of transitioning from state i to state j.

initial_prob_#

Initial state probability vector of shape (n_states,). Entry [i] represents the probability of starting in state i.

coef_#

GLM coefficients for each state, shape (n_features, n_states).

intercept_#

GLM intercepts (bias terms) for each state, shape (n_states,).

solver_state_#

State of the solver after fitting. May include details like optimization error.

scale_#

Scale parameter for the observation model, shape (n_states,).

dof_resid_#

Degrees of freedom for the residuals.
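
To make the attribute shapes concrete, the sketch below combines arrays laid out like coef_ and intercept_ into one linear predictor per time bin and state. It uses plain NumPy with made-up values and the Bernoulli default inverse link; it is an illustration of the shapes, not the library's internal code.

```python
import numpy as np

rng = np.random.default_rng(0)
n_time_bins, n_features, n_states = 6, 4, 2

X = rng.normal(size=(n_time_bins, n_features))
coef = rng.normal(size=(n_features, n_states))  # same layout as coef_
intercept = np.zeros(n_states)                  # same layout as intercept_

# One linear predictor per time bin and state, shape (n_time_bins, n_states).
linear_predictor = X @ coef + intercept

# Logistic sigmoid (Bernoulli default) maps it to per-state probabilities.
state_proba = 1.0 / (1.0 + np.exp(-linear_predictor))
print(state_proba.shape)  # (6, 2)
```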

Notes

To bypass the initialization functions entirely and provide parameter arrays directly, pass them to the fit() method:

model.fit(X, y, init_params=my_params)
Raises:
  • TypeError – If n_states is not a positive integer.

  • TypeError – If provided regularizer or observation_model are not valid.

  • TypeError – If seed is not a valid JAX PRNG key.

  • KeyError – If hmm_initialization_funcs or model_initialization_funcs contains keys that are not valid for their respective default dictionary.

  • ValueError – If any *_kwargs entry in either initialization-funcs dictionary contains keyword arguments that don’t match the signature of the corresponding initialization function.

  • ValueError – If maxiter is not a positive integer.

  • ValueError – If tol is not a positive float.

Parameters:
  • n_states (int)

  • observation_model (Observations | Literal['Poisson', 'Gamma', 'Bernoulli', 'NegativeBinomial', 'Gaussian'])

  • inverse_link_function (Optional[Callable[[jnp.ndarray], jnp.ndarray]])

  • regularizer (Union[str, Regularizer])

  • regularizer_strength (Any)

  • dirichlet_initial_proba (Union[jnp.ndarray, None])

  • dirichlet_transition_proba (Union[jnp.ndarray | None])

  • solver_name (str)

  • solver_kwargs (Optional[dict])

  • maxiter (int)

  • tol (float)

  • hmm_initialization_funcs (Optional[HMM_INITIALIZATION_FN_DICT])

  • model_initialization_funcs (Optional[GLMHMM_INITIALIZATION_FN_DICT])

Examples

Fit a GLM-HMM

Basic model fitting with the default Bernoulli observation model. The number of hidden states is the only required argument; coef_ carries one column per state, and the HMM transition matrix and initial distribution are exposed as fitted attributes.

>>> import jax
>>> import numpy as np
>>> import nemos as nmo
>>> np.random.seed(123)
>>> X = np.random.normal(size=(200, 4))
>>> y = np.random.binomial(n=1, p=0.5, size=200)
>>> model = nmo.glm_hmm.GLMHMM(n_states=2).fit(X, y)
>>> model.coef_.shape
(4, 2)
>>> model.transition_prob_.shape
(2, 2)
>>> model.initial_prob_.shape
(2,)

Customize the Observation Model

Specify the observation model as a string:

>>> model = nmo.glm_hmm.GLMHMM(n_states=2, observation_model="Poisson")
>>> model.observation_model
PoissonObservations()

Or pass the observation model object directly:

>>> model = nmo.glm_hmm.GLMHMM(
...     n_states=2, observation_model=nmo.observation_models.PoissonObservations()
... )
>>> model.observation_model
PoissonObservations()

Customize the Inverse Link Function

Use a soft-plus inverse link function instead of the observation-model default:

>>> model = nmo.glm_hmm.GLMHMM(n_states=2, inverse_link_function=jax.nn.softplus)
>>> model.inverse_link_function.__name__
'softplus'

Change the Regularization

Regularization applies to the per-state GLM coefficients. The default is Ridge with strength 1.0. Tune the strength:

>>> model = nmo.glm_hmm.GLMHMM(n_states=2, regularizer_strength=0.1).fit(X, y)
>>> model.regularizer, float(model.regularizer_strength)
(Ridge(), 0.1)

Or switch to Lasso for sparse per-state coefficients (Lasso requires a proximal solver):

>>> model = nmo.glm_hmm.GLMHMM(
...     n_states=2,
...     regularizer="Lasso",
...     regularizer_strength=0.01,
...     solver_name="ProximalGradient",
... ).fit(X, y)
>>> model.regularizer
Lasso()

Select a Solver

The solver runs the M-step inside EM. LBFGS, the default for Ridge, can also be selected explicitly; it may converge faster than first-order solvers on smooth losses:

>>> model = nmo.glm_hmm.GLMHMM(n_states=2, solver_name="LBFGS").fit(X, y)
>>> model.solver_name
'LBFGS'

Fit Across Multiple Sessions

Mark session boundaries with session_starts so the HMM resets at each new session start instead of treating the data as a single chain. Pass either a boolean mask of shape (n_time_bins,) with True at each session start, or an integer array of session-start indices — the two are equivalent:

>>> is_new_mask = np.zeros(200, dtype=bool)
>>> is_new_mask[0] = True
>>> is_new_mask[100] = True
>>> model = nmo.glm_hmm.GLMHMM(n_states=2).fit(X, y, session_starts=is_new_mask)
>>> # Equivalent: pass the starts as integer indices.
>>> model = nmo.glm_hmm.GLMHMM(n_states=2).fit(X, y, session_starts=np.array([0, 100]))
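
The equivalence between the two session_starts formats can be checked with plain NumPy. The sketch below only converts between the representations; how nemos normalizes them internally is not shown here, and the variable names are illustrative.

```python
import numpy as np

n_time_bins = 200
start_indices = np.array([0, 100])

# Indices -> boolean mask with True at each session start.
start_mask = np.zeros(n_time_bins, dtype=bool)
start_mask[start_indices] = True

# Boolean mask -> indices.
recovered_indices = np.flatnonzero(start_mask)
print(recovered_indices.tolist())  # [0, 100]
```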

Decode Hidden States

Recover the most likely state sequence (Viterbi decoding, one-hot encoded by default) or the smoothed posterior probabilities from the forward-backward pass:

>>> states = model.decode_state(X, y, session_starts=is_new_mask)
>>> states.shape
(200, 2)
>>> post = model.smooth_proba(X, y, session_starts=is_new_mask)
>>> post.shape
(200, 2)

Simulate from the Fitted Model

Sample a hidden-state trajectory and observations conditioned on inputs:

>>> activity, rates, sim_states = model.simulate(
...     jax.random.key(0), X, state_format="index"
... )
>>> activity.shape, sim_states.shape
((200,), (200,))
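
For intuition about what such a simulation involves, here is a minimal NumPy sketch of the generative process for a Bernoulli GLM-HMM. The function `simulate_bernoulli_glm_hmm` and all parameter values are hypothetical; this is not how simulate() is implemented, only the textbook sampling scheme the model description implies.

```python
import numpy as np

def simulate_bernoulli_glm_hmm(rng, X, coef, intercept, initial_prob, transition_prob):
    """Sample hidden states and Bernoulli observations from a GLM-HMM."""
    n_time_bins = X.shape[0]
    n_states = initial_prob.shape[0]
    states = np.empty(n_time_bins, dtype=int)
    obs = np.empty(n_time_bins, dtype=int)
    for t in range(n_time_bins):
        # First bin draws from the initial distribution, later bins from the
        # transition-matrix row selected by the previous state.
        p_state = initial_prob if t == 0 else transition_prob[states[t - 1]]
        states[t] = rng.choice(n_states, p=p_state)
        # State-specific GLM: linear predictor passed through the sigmoid.
        eta = X[t] @ coef[:, states[t]] + intercept[states[t]]
        obs[t] = rng.binomial(1, 1.0 / (1.0 + np.exp(-eta)))
    return obs, states

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
coef = rng.normal(size=(3, 2))
intercept = np.zeros(2)
initial_prob = np.array([0.5, 0.5])
transition_prob = np.array([[0.95, 0.05], [0.10, 0.90]])
obs, states = simulate_bernoulli_glm_hmm(
    rng, X, coef, intercept, initial_prob, transition_prob
)
print(obs.shape, states.shape)  # (50,) (50,)
```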

Use a Dict of Arrays as Input

Features can be passed as any JAX pytree of 2-D arrays; the fitted coef_ will share the same pytree structure, with the trailing axis indexing states:

>>> X_dict = {"input_1": X[:, :2], "input_2": X[:, 2:]}
>>> model = nmo.glm_hmm.GLMHMM(n_states=2).fit(X_dict, y)
>>> type(model.coef_)
<class 'dict'>
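
The following NumPy sketch illustrates why a dict of feature blocks is interchangeable with a single concatenated design matrix: summing the per-key products gives the same linear predictor. The coefficient values are made up, and the assumption that each leaf has shape (n_features_k, n_states) follows from the description above rather than from inspecting a fitted model.

```python
import numpy as np

rng = np.random.default_rng(0)
n_states = 2
X = rng.normal(size=(10, 4))
X_dict = {"input_1": X[:, :2], "input_2": X[:, 2:]}

# Hypothetical coefficients: one (n_features_k, n_states) array per key.
coef_dict = {
    key: rng.normal(size=(arr.shape[1], n_states)) for key, arr in X_dict.items()
}

# Summing the per-key products ...
eta_dict = sum(X_dict[key] @ coef_dict[key] for key in X_dict)
# ... matches one product with features and coefficients concatenated.
coef_flat = np.concatenate([coef_dict["input_1"], coef_dict["input_2"]], axis=0)
eta_flat = X @ coef_flat
print(np.allclose(eta_dict, eta_flat))  # True
```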

Attributes

dirichlet_initial_proba

Alpha parameters of the Dirichlet prior over the initial probabilities of HMM states.

dirichlet_transition_proba

Alpha parameters of the Dirichlet prior over the transition probabilities of HMM states.

hmm_initialization_funcs

Dictionary of initialization functions for HMM parameters.

inverse_link_function

Inverse link function mapping the linear predictor to the emission space.

maxiter

EM maximum number of iterations.

model_initialization_funcs

Dictionary of initialization functions for model parameters.

n_states

Number of hidden states of the HMM.

observation_model

The observation model governing the emission distribution at each state.

optimizer_init_state

Provides the initialization function for the optimizer state.

optimizer_run

Provides the function to execute the optimization process.

optimizer_update

Provides the function for updating the state during the optimization process.

regularizer

Getter for the regularizer attribute.

regularizer_strength

Regularizer strength getter.

seed

Random seed as a jax PRNG key.

solver_kwargs

Getter for the solver_kwargs attribute.

solver_name

Getter for the solver_name attribute.

solver_spec

Getter for the solver specification.

tol

Tolerance for the EM algorithm convergence criterion.

__init__(n_states, observation_model='Bernoulli', inverse_link_function=None, regularizer='Ridge', regularizer_strength=1.0, dirichlet_initial_proba=None, dirichlet_transition_proba=None, solver_name=None, solver_kwargs=None, maxiter=1000, tol=1e-08, seed=Array([0, 123], dtype=uint32), hmm_initialization_funcs=None, model_initialization_funcs=None)[source]#
Parameters:
  • n_states (int)

  • observation_model (Observations | Literal['Poisson', 'Gamma', 'Bernoulli', 'NegativeBinomial', 'Gaussian'])

  • inverse_link_function (Callable[[Array], Array] | None)

  • regularizer (str | Regularizer)

  • regularizer_strength (Any)

  • dirichlet_initial_proba (Array | None)

  • dirichlet_transition_proba (Array | None)

  • solver_name (str)

  • solver_kwargs (dict | None)

  • maxiter (int)

  • tol (float)

  • hmm_initialization_funcs (dict[Literal['initial_proba_init', 'initial_proba_init_kwargs', 'initial_proba_init_custom', 'transition_proba_init', 'transition_proba_init_kwargs', 'transition_proba_init_custom'], ~nemos.hmm.initialize_parameters.InitFunctionHMM | dict[str, ~typing.Any] | bool] | None)

  • model_initialization_funcs (dict[Literal['glm_params_init', 'glm_params_init_kwargs', 'glm_params_init_custom', 'scale_init', 'scale_init_kwargs', 'scale_init_custom'], ~nemos.glm_hmm.initialize_parameters.InitFunctionGLM | ~nemos.hmm.initialize_parameters.InitFunctionHMM | dict[str, ~typing.Any] | bool] | None)

Methods

__init__(n_states[, observation_model, ...])

compute_loss(params, X, y, *args, **kwargs)

Compute the loss function for the model.

decode_state(X, y[, session_starts, ...])

Compute the most likely hidden state sequence (Viterbi decoding).

filter_proba(X, y[, session_starts])

Compute filtering posterior probabilities over hidden states.

fit(X, y[, init_params, session_starts])

Fit the GLM-HMM via Expectation-Maximization.

get_params([deep])

From scikit-learn, get parameters by inspecting init.

initialize_optimizer_and_state(init_params, X, y)

Initialize the optimization routine and its state for running fit and update.

initialize_params(X, y)

Initialize model parameters.

save_params(filename)

Save GLM-HMM model parameters and fit state to a .npz file.

score(X, y[, session_starts])

Marginal log-likelihood of the data under the fitted HMM.

set_params(**params)

Manage warnings in case of multiple parameter settings.

setup([initial_proba_init, ...])

Configure how fit() initializes each model parameter.

simulate(random_key, feedforward_input[, ...])

Simulate neural activity and hidden states from the model.

smooth_proba(X, y[, session_starts])

Compute smoothing posterior probabilities over hidden states.

update(params, opt_state, X, y, *args[, ...])

Run a single EM iteration on the GLM-HMM.

classmethod __init_subclass__(**kwargs)#

Set the set_{method}_request methods.

This uses PEP-487 [1] to set the set_{method}_request methods. It looks for the information available in the set default values which are set using __metadata_request__* class attributes, or inferred from method signatures.

The __metadata_request__* class attributes are used when a method does not explicitly accept a metadata through its arguments or if the developer would like to specify a request value for those metadata which are different from the default None.

References

__repr__()[source]#

Hierarchical repr for the GLMHMM class.

Return type:

str

__sklearn_tags__()#

Return regression model specific estimator tags.

compute_loss(params, X, y, *args, **kwargs)#

Compute the loss function for the model.

This method validates inputs and converts user-provided parameters to the internal representation before computing the loss.

Parameters:
  • params (UserProvidedParamsT) – Parameter tuple of (coefficients, intercept).

  • X (DESIGN_INPUT_TYPE) – Input data, array of shape (n_time_bins, n_features) or pytree of same.

  • y (jnp.ndarray) – Target data, array of shape (n_time_bins,) for single neuron models or (n_time_bins, n_neurons) for population models.

  • *args – Additional positional arguments passed to the model-specific loss function.

  • **kwargs – Additional keyword arguments passed to the model-specific loss function.

Returns:

The loss value (negative log-likelihood).

Return type:

jnp.ndarray

Raises:

ValueError – If inputs or parameters have incompatible shapes or invalid values.

decode_state(X, y, session_starts=None, state_format='one-hot')[source]#

Compute the most likely hidden state sequence (Viterbi decoding).

Finds the single most likely sequence of hidden states that best explains the observed data. Uses the Viterbi (max-sum) algorithm to compute the state sequence that maximizes the joint probability of states and observations.

Unlike smooth_proba() and filter_proba(), which return a probability distribution over states at each time bin, this method makes a hard assignment to the single globally optimal state path.

The decoded states answer: “What is the most likely sequence of states that generated the observed data?”

Parameters:
  • X (Union[DESIGN_INPUT_TYPE, ArrayLike]) – Predictors, shape (n_time_bins, n_features). A pytree of 2-D arrays sharing the leading time axis is also accepted.

  • y (ArrayLike) – Observations, shape (n_time_bins,) for a single neuron or (n_time_bins, n_neurons) for a population model. A pynapple Tsd/TsdFrame is accepted; session boundaries are then inferred from time_support.

  • session_starts (Optional[ArrayLike]) –

    Optional session boundaries. Accepts:

    • a boolean array of shape (n_time_bins,) with True at each session start,

    • an integer array of session-start indices,

    • a pynapple IntervalSet (requires X or y to be a pynapple object to supply timestamps).

    If None, the entire input is treated as a single session.

  • state_format (Literal['one-hot', 'index']) –

    Format of the returned state sequence:

    • "one-hot" (default): binary array of shape (n_time_bins, n_states) with a single 1 per row.

    • "index": integer array of shape (n_time_bins,) with values in [0, n_states - 1].

Returns:

Most likely state sequence. Shape and dtype depend on state_format (see above). Returns a pynapple TsdFrame (columns "state_0", "state_1", …) for "one-hot" format or a pynapple Tsd for "index" format when the inputs are pynapple objects; otherwise returns a JAX array.

Return type:

jnp.ndarray | nap.Tsd | nap.TsdFrame

Raises:
  • ValueError – If the model has not been fitted (call fit() first).

  • ValueError – If state_format is not "one-hot" or "index".

  • ValueError – If X or y contain NaN values in the interior of an epoch (boundary NaNs are allowed and removed before inference).

  • ValueError – If X and y have inconsistent shapes or feature counts.

See also

smooth_proba

Compute smoothing posteriors (soft, probabilistic state assignments).

filter_proba

Compute filtering posteriors (causal, conditioned on past observations).

Notes

Viterbi decoding finds the globally optimal state sequence, which can differ from the sequence of states that are individually most probable at each time bin (as returned by smooth_proba()). For uncertainty estimates use smooth_proba() instead. Session boundaries reset the Viterbi recursion so that no path crosses session borders.
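
The max-sum recursion described above can be sketched in a few lines of NumPy. This is a generic single-session Viterbi implementation operating on precomputed per-state log-likelihoods, not the nemos code; the function name `viterbi` and the toy inputs are assumptions made for illustration.

```python
import numpy as np

def viterbi(log_lik, initial_prob, transition_prob):
    """Most likely state path given per-bin, per-state log-likelihoods."""
    n_time_bins, n_states = log_lik.shape
    log_trans = np.log(transition_prob)
    # Best log-probability of any path ending in each state at the first bin.
    score = np.log(initial_prob) + log_lik[0]
    backpointer = np.zeros((n_time_bins, n_states), dtype=int)
    for t in range(1, n_time_bins):
        candidates = score[:, None] + log_trans  # indexed [from_state, to_state]
        backpointer[t] = np.argmax(candidates, axis=0)
        score = candidates.max(axis=0) + log_lik[t]
    # Trace the best path backwards from the best final state.
    path = np.empty(n_time_bins, dtype=int)
    path[-1] = np.argmax(score)
    for t in range(n_time_bins - 1, 0, -1):
        path[t - 1] = backpointer[t, path[t]]
    return path

log_lik = np.log(np.array([[0.9, 0.1], [0.8, 0.2], [0.2, 0.8], [0.1, 0.9]]))
path = viterbi(log_lik, np.array([0.5, 0.5]), np.array([[0.9, 0.1], [0.1, 0.9]]))
print(path.tolist())  # [0, 0, 1, 1]

# The "one-hot" format corresponds to indexing an identity matrix.
one_hot = np.eye(2, dtype=int)[path]
```

Handling several sessions would amount to running this recursion independently on each session's slice of the data.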

Examples

Decode the most likely state sequence as integer indices:

>>> import numpy as np
>>> import nemos as nmo
>>> np.random.seed(123)
>>> X = np.random.randn(100, 5)
>>> y = np.random.poisson(2, size=100)
>>> model = nmo.glm_hmm.GLMHMM(n_states=3, observation_model="Poisson").fit(X, y)
>>> states = model.decode_state(X, y, state_format="index")
>>> states.shape
(100,)

One-hot output (default):

>>> states_onehot = model.decode_state(X, y)
>>> states_onehot.shape
(100, 3)
>>> bool(np.all(states_onehot.sum(axis=1) == 1))
True
property dirichlet_initial_proba: Array | None#

Alpha parameters of the Dirichlet prior over the initial probabilities of HMM states.

If None, a flat prior is assumed.

property dirichlet_transition_proba: Array | None#

Alpha parameters of the Dirichlet prior over the transition probabilities of HMM states.

If None, a flat prior is assumed.
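
To show what a Dirichlet prior does to a probability estimate, the sketch below computes the standard maximum a posteriori (MAP) estimate of a transition matrix from expected transition counts. With every alpha equal to 1 (a flat prior) it reduces to the normalized counts. The function name and the counts are hypothetical, and the exact update used by nemos is not documented in this section, so treat this as the textbook formula only.

```python
import numpy as np

def map_transition_matrix(expected_counts, alpha):
    """Row-wise MAP estimate under independent Dirichlet(alpha[i]) priors.

    Assumes every entry of expected_counts + alpha - 1 is non-negative.
    """
    numerator = expected_counts + alpha - 1.0
    return numerator / numerator.sum(axis=1, keepdims=True)

# Hypothetical expected transition counts, as an E-step might produce.
expected_counts = np.array([[8.0, 2.0], [3.0, 7.0]])

flat = map_transition_matrix(expected_counts, np.ones((2, 2)))
sticky = map_transition_matrix(expected_counts, np.array([[5.0, 1.0], [1.0, 5.0]]))
print(flat)    # alpha = 1 leaves the normalized counts unchanged
print(sticky)  # larger diagonal alpha pulls mass toward self-transitions
```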

filter_proba(X, y, session_starts=None)[source]#

Compute filtering posterior probabilities over hidden states.

Computes the probability of being in each hidden state at each time bin, conditioned only on observations up to that time bin. Uses the forward pass of the forward-backward algorithm, providing causal (online) state estimates that rely solely on past and current observations.

The filtering posteriors answer: “Given observations up to time t, what is the probability that the system is in state k at time t?”

Parameters:
  • X (Union[DESIGN_INPUT_TYPE, ArrayLike]) – Predictors, shape (n_time_bins, n_features). A pytree of 2-D arrays sharing the leading time axis is also accepted.

  • y (Union[NDArray, jnp.ndarray, nap.Tsd]) – Observations, shape (n_time_bins,) for a single neuron or (n_time_bins, n_neurons) for a population model. A pynapple Tsd/TsdFrame is accepted; session boundaries are then inferred from time_support.

  • session_starts (Optional[ArrayLike]) –

    Optional session boundaries. Accepts:

    • a boolean array of shape (n_time_bins,) with True at each session start,

    • an integer array of session-start indices,

    • a pynapple IntervalSet (requires X or y to be a pynapple object to supply timestamps).

    If None, the entire input is treated as a single session.

Returns:

Filtering posterior probabilities, shape (n_time_bins, n_states). Each row sums to 1. Returns a pynapple TsdFrame (with columns named "state_0", "state_1", …) when the inputs are pynapple objects; otherwise returns a JAX array.

Return type:

jnp.ndarray | nap.TsdFrame

Raises:
  • ValueError – If the model has not been fitted (call fit() first).

  • ValueError – If X or y contain NaN values in the interior of an epoch (boundary NaNs are allowed and removed before inference).

  • ValueError – If X and y have inconsistent shapes or feature counts.

See also

smooth_proba

Compute smoothing posteriors (conditioned on all observations).

decode_state

Compute the most likely state sequence via Viterbi decoding.

Notes

Filtering is causal: each posterior at time t uses only observations up to t, making it suitable for online or real-time applications. For retrospective analysis where all data are available, smooth_proba() also conditions on later observations and therefore typically gives more accurate state estimates. Session boundaries reset the HMM chain so that no information crosses session borders.
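
The forward recursion behind filtering can be sketched as follows. This is a generic normalized forward pass in NumPy for a single session, taking precomputed per-state likelihoods as input; `forward_filter` is a hypothetical name and the code is not the nemos implementation.

```python
import numpy as np

def forward_filter(lik, initial_prob, transition_prob):
    """Filtering posteriors p(z_t | y_1..t) from per-bin, per-state likelihoods."""
    n_time_bins, n_states = lik.shape
    filt = np.empty((n_time_bins, n_states))
    pred = initial_prob                   # prior over the state at the first bin
    for t in range(n_time_bins):
        unnorm = pred * lik[t]            # weight the prediction by the likelihood
        filt[t] = unnorm / unnorm.sum()   # normalize so each row sums to 1
        pred = filt[t] @ transition_prob  # one-step-ahead state prediction
    return filt

lik = np.array([[0.9, 0.1], [0.8, 0.2], [0.2, 0.8], [0.1, 0.9]])
filt = forward_filter(lik, np.array([0.5, 0.5]), np.array([[0.9, 0.1], [0.1, 0.9]]))
print(filt.shape)  # (4, 2)
```

Because row t depends only on likelihoods up to t, truncating the input after any bin leaves the earlier rows unchanged.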

Examples

Fit a GLM-HMM and compute filtering posteriors (causal/online):

>>> import numpy as np
>>> import nemos as nmo
>>> np.random.seed(123)
>>> X = np.random.randn(100, 5)
>>> y = np.random.poisson(2, size=100)
>>> model = nmo.glm_hmm.GLMHMM(n_states=3, observation_model="Poisson").fit(X, y)
>>> filt = model.filter_proba(X, y)
>>> filt.shape
(100, 3)
>>> bool(np.allclose(filt.sum(axis=1), 1.0))
True

With pynapple inputs the result is returned as a TsdFrame:

>>> import pynapple as nap
>>> t = np.arange(100) * 0.01
>>> X_tsd = nap.TsdFrame(t=t, d=X)
>>> y_tsd = nap.Tsd(t=t, d=y.astype(float))
>>> type(model.filter_proba(X_tsd, y_tsd)).__name__
'TsdFrame'
fit(X, y, init_params=None, session_starts=None)[source]#

Fit the GLM-HMM via Expectation-Maximization.

Runs the EM algorithm until the absolute change in log-likelihood between consecutive iterations falls below tol or maxiter is reached. Fitted parameters are exposed on the instance as coef_, intercept_, scale_, initial_prob_, transition_prob_, plus solver_state_ (EM trace) and dof_resid_.
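
The stopping rule can be pictured with a small, library-free loop. In this sketch `em_step` stands in for one EM iteration that returns updated parameters and the current log-likelihood; both `run_until_converged` and `toy_step` are hypothetical helpers, not part of nemos.

```python
def run_until_converged(em_step, params, maxiter=1000, tol=1e-8):
    """Iterate em_step until the log-likelihood change drops below tol."""
    previous_ll = float("-inf")
    for iteration in range(1, maxiter + 1):
        params, log_lik = em_step(params)
        if abs(log_lik - previous_ll) < tol:
            return params, iteration, True  # converged
        previous_ll = log_lik
    return params, maxiter, False           # ran out of iterations

# Toy stand-in for an EM step: the "log-likelihood" approaches 0 geometrically.
def toy_step(params):
    new_params = params * 0.5
    return new_params, -new_params

params, n_iter, converged = run_until_converged(toy_step, 1.0, maxiter=100, tol=1e-8)
print(n_iter, converged)  # 27 True
```

The non-converged branch corresponds to the situation in which fit() emits its RuntimeWarning.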

How parameters are initialized:

  • If init_params is None (typical), the per-state GLM parameters and HMM probabilities are produced by the initializers configured via setup() (or the package defaults when setup() was never called).

  • If init_params is provided, it bypasses the initializers entirely. It must be a 5-tuple (coef, intercept, scale, initial_prob, transition_prob) whose shapes are consistent with X, y, and n_states.
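
As a sketch of what such a 5-tuple might look like, the arrays below use the shapes listed for the fitted attributes (coef_, intercept_, scale_, initial_prob_, transition_prob_). The values are arbitrary, and whether NumPy arrays are accepted as-is or must be converted to JAX arrays is an assumption not confirmed by this page.

```python
import numpy as np

rng = np.random.default_rng(0)
n_features, n_states = 4, 2

# Small random coefficients so the states do not start out identical.
coef = 0.1 * rng.normal(size=(n_features, n_states))
intercept = np.zeros(n_states)
scale = np.ones(n_states)
initial_prob = np.full(n_states, 1.0 / n_states)

# "Sticky" transitions: mostly stay in the current state; rows sum to 1.
transition_prob = np.full((n_states, n_states), 0.05 / (n_states - 1))
np.fill_diagonal(transition_prob, 0.95)

my_params = (coef, intercept, scale, initial_prob, transition_prob)
print([p.shape for p in my_params])
```

A tuple like this would then be passed as fit(X, y, init_params=my_params).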

Parameters:
  • X (DESIGN_INPUT_TYPE) – Predictors, shape (n_time_bins, n_features). A pytree of arrays sharing the leading dimension is also accepted; the fitted coef_ mirrors the pytree structure (with a trailing state axis). A pynapple TsdFrame is accepted.

  • y (Union[NDArray, jnp.ndarray, nap.Tsd]) – Observations, shape (n_time_bins,) for a single neuron or (n_time_bins, n_neurons) for a population model. A pynapple Tsd/TsdFrame is accepted.

  • init_params (Optional[GLMHMMUserParams]) – Optional explicit initial parameters as a 5-tuple (coef, intercept, scale, initial_prob, transition_prob). When None (default), the initializers configured by setup() (or the defaults) are used.

  • session_starts (Optional[jnp.ndarray]) –

    Optional session boundaries for the HMM. Accepts:

    • a boolean array of shape (n_time_bins,) with True at each session start,

    • an integer array of session-start indices,

    • a pynapple IntervalSet (requires X or y to be a pynapple object to supply timestamps).

    If X or y is a pynapple object and session_starts is None, the (unique, enforced) time_support of the pynapple input determines the session starts. With no pynapple input and session_starts=None, the whole input is treated as a single session.

Returns:

The fitted estimator.

Return type:

GLMHMM

Raises:
  • ValueError – If inputs fail dimensionality, shape, or consistency checks (e.g. coef features do not match X.shape[1], or NaNs appear mid-epoch).

  • TypeError – If init_params is not a 5-tuple or has incompatible leaf types.

Warns:

RuntimeWarning – Emitted when EM runs out of iterations without satisfying the tol criterion (solver_state_.iterations == maxiter). Consider enabling float64, raising maxiter, or loosening tol.

Examples

Basic fit with default Bernoulli observations:

>>> import numpy as np
>>> import nemos as nmo
>>> np.random.seed(0)
>>> X = np.random.normal(size=(200, 4))
>>> y = np.random.binomial(n=1, p=0.5, size=200)
>>> model = nmo.glm_hmm.GLMHMM(n_states=2).fit(X, y)
>>> model.coef_.shape, model.transition_prob_.shape
((4, 2), (2, 2))

Multiple sessions via explicit session_starts:

>>> session_starts = np.array([0, 100])
>>> model = nmo.glm_hmm.GLMHMM(n_states=2).fit(X, y, session_starts=session_starts)

See also

setup

Configure the initializers used when init_params is None.

update

Run a single EM iteration (advanced, manual loop).

get_metadata_routing()#

Get metadata routing of this object.

Please check User Guide on how the routing mechanism works.

Returns:

routing – A MetadataRequest encapsulating routing information.

Return type:

MetadataRequest

get_params(deep=True)#

From scikit-learn, get parameters by inspecting init.

Parameters:

deep – If True, will return the parameters for this estimator and contained subobjects that are estimators.

Return type:

dict

Returns:

A dictionary containing the parameters. Key is the parameter name, value is the parameter value.

property hmm_initialization_funcs: dict[Literal['initial_proba_init', 'initial_proba_init_kwargs', 'initial_proba_init_custom', 'transition_proba_init', 'transition_proba_init_kwargs', 'transition_proba_init_custom'], InitFunctionHMM | dict[str, Any] | bool] | None#

Dictionary of initialization functions for HMM parameters.

initialize_optimizer_and_state(init_params, X, y)#

Initialize the optimization routine and its state for running fit and update.

This method must be called before using update() for iterative optimization. It sets up the solver with the provided initial parameters and data.

Parameters:
  • X (DESIGN_INPUT_TYPE) – Input data, array of shape (n_time_bins, n_features) or pytree of same.

  • y (jnp.ndarray) – Target data, array of shape (n_time_bins,) for single neuron models or (n_time_bins, n_neurons) for population models.

  • init_params (UserProvidedParamsT) – Initial parameter tuple of (coefficients, intercept).

Returns:

Initial solver state.

Return type:

SolverState

Raises:

ValueError – If inputs or parameters have incompatible shapes or invalid values.

initialize_params(X, y)#

Initialize model parameters.

Initialize coefficients with zeros and intercept by matching the mean firing rate.

Parameters:
  • X (DESIGN_INPUT_TYPE) – Input data, array of shape (n_time_bins, n_features) or pytree of same.

  • y (jnp.ndarray) – Target data, array of shape (n_time_bins,) for single neuron models or (n_time_bins, n_neurons) for population models.

Returns:

Initial parameter tuple of (coefficients, intercept).

Return type:

UserProvidedParamsT

property inverse_link_function#

Inverse link function mapping the linear predictor to the emission space.

Always a callable. If None was passed at construction time, this is resolved to the observation model’s default (e.g. jnp.exp for Poisson, 1 / x for Gamma, jax.nn.sigmoid for Bernoulli). Shared across all hidden states.

property maxiter#

EM maximum number of iterations.

property model_initialization_funcs: MODEL_INITIALIZATION_FN_DICT_T | None#

Dictionary of initialization functions for model parameters.

property n_states: int#

Number of hidden states of the HMM.

property observation_model: Observations#

The observation model governing the emission distribution at each state.

Always an instance of an Observations subclass. The same distribution is used across all hidden states (per-state differences come from the state-specific coefficients/intercept/scale, not from the family). If a string alias was passed at construction time it is resolved to the corresponding instance here.

property optimizer_init_state: None | Callable[[Any, Array, Array], SolverState]#

Provides the initialization function for the optimizer state.

This function is responsible for initializing the optimizer state, necessary for the start of the optimizer process. It sets up initial values for parameters like gradients and step sizes based on the model configuration and input data.

Returns:

The function to initialize the optimizer state, if available; otherwise, None if the optimizer has not yet been instantiated.

property optimizer_run: None | Callable[[Any, Array, Array], Tuple[Any, SolverState, Aux]]#

Provides the function to execute the optimization process.

This function runs the optimizer using the initialized parameters and state, performing the optimization to fit the model to the data. It iteratively updates the model parameters until a stopping criterion is met, such as convergence or exceeding a maximum number of iterations.

Returns:

The function to run the optimization process, if available; otherwise, None if the optimizer has not yet been instantiated.

property optimizer_update: None | Callable[[Any, NamedTuple, Array, Array], Tuple[Any, SolverState, Aux]]#

Provides the function for updating the state during the optimization process.

This function is used to perform a single update step in the optimization process. It updates the model’s parameters based on the current state, data, and gradients. It is typically used in scenarios where fine-grained control over each optimizer step is necessary, such as in online learning or complex optimization scenarios.

Returns:

The function to perform a single optimization update step, if available; otherwise, None if the optimizer has not yet been instantiated.

property regularizer: None | Regularizer#

Getter for the regularizer attribute.

property regularizer_strength: Any#

Regularizer strength getter.

save_params(filename)[source]#

Save GLM-HMM model parameters and fit state to a .npz file.

Persists hyperparameters returned by get_params() together with the fitted attributes (coef_, intercept_, scale_, initial_prob_, transition_prob_, dof_resid_). The solver_state_ is intentionally excluded as it is solver-specific and not needed to reuse the fitted model. The file can be reloaded with nemos.load_model().

If the model was configured with custom initialization functions, pass them back to nemos.load_model() via mapping_dict to restore them (see example below). Built-in initializers are resolved automatically.

Parameters:

filename (Union[str, Path]) – Path of the output file (.npz format).

Examples

Default round-trip — built-in initializers are resolved automatically on load:

>>> import os, tempfile
>>> import numpy as np
>>> import nemos as nmo
>>> np.random.seed(0)
>>> X = np.random.normal(size=(80, 3))
>>> y = np.random.binomial(n=1, p=0.5, size=80)
>>> model = nmo.glm_hmm.GLMHMM(n_states=2).fit(X, y)
>>> with tempfile.TemporaryDirectory() as d:
...     path = os.path.join(d, "glmhmm.npz")
...     model.save_params(path)
...     loaded = nmo.load_model(path)
>>> bool(np.allclose(model.coef_, loaded.coef_))
True

Round-trip with a custom GLM-params initializer. Pass it back as a partial dict under model_initialization_funcs; remaining slots fall back to the saved (built-in) names:

>>> import jax.numpy as jnp
>>> def my_glm_init(
...     n_states, X, y, inverse_link_function, observation_model,
...     session_starts, random_key,
... ):
...     return jnp.zeros((X.shape[1], n_states)), jnp.zeros((n_states,))
>>> model = nmo.glm_hmm.GLMHMM(n_states=2)
>>> model.setup(glm_params_init=my_glm_init)
>>> _ = model.fit(X, y)
>>> with tempfile.TemporaryDirectory() as d:
...     path = os.path.join(d, "glmhmm.npz")
...     model.save_params(path)
...     loaded = nmo.load_model(
...         path,
...         mapping_dict={
...             "model_initialization_funcs": {"glm_params_init": my_glm_init},
...         },
...     )
>>> loaded.model_initialization_funcs["glm_params_init"] is my_glm_init
True

score(X, y, session_starts=None)#

Marginal log-likelihood of the data under the fitted HMM.

HMM-family models score only by log-likelihood. Variance-based or deviance-based pseudo-R² metrics are not implemented because they depend on a null/saturated-model construction that has no clean analogue for latent-state sequence models. Compute AIC/BIC or held-out log-likelihood externally if needed.

Parameters:
  • X (Union[DESIGN_INPUT_TYPE, ArrayLike]) – Input data/design matrix, shape (n_samples, n_features).

  • y (ArrayLike) – Output data/observations, shape (n_samples, n_observations).

  • session_starts (Optional[ArrayLike]) –

    Optional array indicating user-provided session boundaries. Can be:

    • a boolean array indicating session starts, shape (n_samples,),

    • an integer array of indices marking session starts, shape (n_sessions,),

    • a pynapple.IntervalSet marking session epochs (requires either X or y to be a pynapple Tsd or TsdFrame to get timestamps).

    If None, creates a default array treating all data as one session.

Return type:

jnp.ndarray

Returns:

The marginal log-likelihood (summed over time).
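Because score returns a summed log-likelihood, information criteria can be computed from it directly. The following is a minimal sketch under stated assumptions: the helper names are hypothetical, the log-likelihood value is a placeholder, and the free-parameter count assumes a single-output model without a scale parameter (per-state coefficients and intercept, plus the free entries of the initial and transition probabilities).

```python
import math


def count_free_params(n_states, n_features):
    # Assumption: one output variable and no scale parameter.
    glm = n_states * (n_features + 1)        # coefficients + intercept per state
    initial = n_states - 1                   # initial probabilities sum to 1
    transition = n_states * (n_states - 1)   # each transition row sums to 1
    return glm + initial + transition


def aic(log_likelihood, n_params):
    return 2 * n_params - 2 * log_likelihood


def bic(log_likelihood, n_params, n_samples):
    return n_params * math.log(n_samples) - 2 * log_likelihood


k = count_free_params(n_states=2, n_features=3)  # 8 + 1 + 2 = 11
ll = -52.3  # placeholder standing in for float(model.score(X, y))
print(aic(ll, k), bic(ll, k, n_samples=80))
```

With regularization the effective number of parameters can be smaller than this count, so the values are best treated as a rough guide when comparing numbers of states.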

property seed#

Random seed as a jax PRNG key.

set_fit_request(*, init_params='$UNCHANGED$', session_starts='$UNCHANGED$')#

Configure whether metadata should be requested to be passed to the fit method.

Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to fit.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Parameters:
  • init_params (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for init_params parameter in fit.

  • session_starts (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for session_starts parameter in fit.

  • self (GLMHMM)

Returns:

self – The updated object.

Return type:

object

set_params(**params)#

Manage warnings in case of multiple parameter settings.

Parameters:

params (Any)

set_score_request(*, session_starts='$UNCHANGED$')#

Configure whether metadata should be requested to be passed to the score method.

Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to score.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Parameters:
  • session_starts (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for session_starts parameter in score.

  • self (GLMHMM)

Returns:

self – The updated object.

Return type:

object

setup(initial_proba_init=None, initial_proba_init_kwargs=None, transition_proba_init=None, transition_proba_init_kwargs=None, glm_params_init=None, glm_params_init_kwargs=None, scale_init=None, scale_init_kwargs=None)[source]#

Configure how fit() initializes each model parameter.

Calling setup() is optional: if it is never called, fitting starts from the default initializers listed below. Use it to change the initialization strategy by providing either the name of a built-in initialization function or a custom callable. Each argument left as None keeps the previously configured value; only the parameters you supply are updated.

Available built-in initialization functions:

  • initial_proba_init: "uniform" (default), "random", "dirichlet", "kmeans".

  • transition_proba_init: "sticky" (default), "uniform", "random", "dirichlet", "kmeans".

  • glm_params_init: "random" (default), "kmeans".

  • scale_init: "constant" (default), "kmeans".

Parameters:
  • initial_proba_init (Union[Literal['uniform', 'random', 'dirichlet', 'kmeans'], InitFunctionHMM, None]) – Built-in name or custom callable used to initialize the initial-state probabilities (shape (n_states,)).

  • initial_proba_init_kwargs (Optional[dict]) – Extra keyword arguments forwarded to initial_proba_init.

  • transition_proba_init (Union[Literal['sticky', 'uniform', 'random', 'dirichlet', 'kmeans'], InitFunctionHMM, None]) – Built-in name or custom callable used to initialize the transition matrix (shape (n_states, n_states)).

  • transition_proba_init_kwargs (Optional[dict]) – Extra keyword arguments forwarded to transition_proba_init.

  • glm_params_init (Union[Literal['random', 'kmeans'], InitFunctionGLM, None]) – Built-in name or custom callable used to initialize the per-state GLM coefficients and intercepts.

  • glm_params_init_kwargs (Optional[dict]) – Extra keyword arguments forwarded to glm_params_init.

  • scale_init (Union[Literal['constant', 'kmeans'], InitFunctionGLM, None]) – Built-in name or custom callable used to initialize the scale parameter of the observation model (e.g. variance for Gaussian, dispersion for NegativeBinomial). Ignored by observation models without a scale.

  • scale_init_kwargs (Optional[dict]) – Extra keyword arguments forwarded to scale_init.

Raises:

ValueError – If a custom callable’s signature is incompatible with the protocols described in the Notes below, or if a *_kwargs entry contains keys that don’t match the corresponding initializer’s signature.

Notes

Custom callables must satisfy one of two typing.Protocol classes:

  • initial_proba_init and transition_proba_init must satisfy InitFunctionHMM and return a jnp.ndarray of shape (n_states,) or (n_states, n_states) respectively.

  • glm_params_init and scale_init must satisfy InitFunctionGLM. glm_params_init returns (coef, intercept) matched to the design and n_states; scale_init returns the scale array for the observation model.
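The shape requirements above imply that the HMM-side outputs are probability arrays: a vector summing to 1, or a matrix whose rows each sum to 1. The sketch below builds a diagonal-dominant transition matrix and checks both kinds of output. It uses NumPy for portability, both helper names are hypothetical, and the 0.95 stay probability is an illustrative value rather than the nemos "sticky" default. The functions are not wrapped in the InitFunctionHMM signature, which is not reproduced here.

```python
import numpy as np


def sticky_transition(n_states, stay_prob=0.95):
    # Hypothetical helper: put `stay_prob` on the diagonal and spread the
    # remaining mass evenly over the other states.
    if n_states == 1:
        return np.ones((1, 1))
    off_diag = (1.0 - stay_prob) / (n_states - 1)
    mat = np.full((n_states, n_states), off_diag)
    np.fill_diagonal(mat, stay_prob)
    return mat


def is_valid_proba(arr, n_states):
    # True for an (n_states,) probability vector or an
    # (n_states, n_states) matrix whose rows each sum to 1.
    arr = np.asarray(arr)
    shape_ok = arr.shape in ((n_states,), (n_states, n_states))
    return bool(
        shape_ok and np.all(arr >= 0) and np.allclose(arr.sum(axis=-1), 1.0)
    )


transition = sticky_transition(3)
initial = np.full(3, 1.0 / 3.0)
print(is_valid_proba(transition, 3), is_valid_proba(initial, 3))
```

A check like is_valid_proba can be run on the output of a custom initializer before passing the callable to setup().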

To inspect a protocol’s signature, import and help() it:

from nemos.hmm.initialize_parameters import InitFunctionHMM
from nemos.glm_hmm.initialize_parameters import InitFunctionGLM
help(InitFunctionHMM)  # or help(InitFunctionGLM)

All arguments must appear in the function signature even when unused, so the framework can supply them uniformly.

Examples

Switch a parameter to a different built-in scheme by passing its label:

>>> from nemos.glm_hmm import GLMHMM
>>> model = GLMHMM(n_states=3)
>>> model.setup(initial_proba_init="random", glm_params_init="kmeans")

Plug in a custom callable matching the GLM-side protocol:

>>> import jax.numpy as jnp
>>> def my_glm_init(
...     n_states, X, y, inverse_link_function, observation_model,
...     session_starts, random_key,
... ):
...     coef = jnp.zeros((X.shape[1], n_states))
...     intercept = jnp.zeros((n_states,))
...     return coef, intercept
>>> model.setup(glm_params_init=my_glm_init)

simulate(random_key, feedforward_input, state_format='index', session_starts=None)[source]#

Simulate neural activity and hidden states from the model.

Simulates a trajectory through the hidden state space according to the HMM dynamics, then generates observations from the GLM emission model conditioned on each state.
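The two-stage procedure described above can be sketched in NumPy for a single-session Bernoulli model. This is an illustration of the generative process, not the nemos implementation: the function name is hypothetical, and the parameter values are made up for the example. The coefficient array uses the (n_features, n_states) layout.

```python
import numpy as np


def simulate_glm_hmm(rng, X, coef, intercept, initial_prob, transition_prob):
    """Sample a Markov state chain, then Bernoulli observations per state."""
    n_time_bins = X.shape[0]
    n_states = initial_prob.shape[0]
    states = np.empty(n_time_bins, dtype=int)
    rates = np.empty(n_time_bins)
    for t in range(n_time_bins):
        # The first bin draws from the initial distribution; later bins draw
        # from the transition row of the previous state.
        p = initial_prob if t == 0 else transition_prob[states[t - 1]]
        states[t] = rng.choice(n_states, p=p)
        # State-specific linear predictor passed through the logistic function.
        eta = X[t] @ coef[:, states[t]] + intercept[states[t]]
        rates[t] = 1.0 / (1.0 + np.exp(-eta))
    # Observations are Bernoulli draws with the state-conditioned rates.
    activity = rng.binomial(1, rates)
    return activity, rates, states


rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
coef = rng.normal(size=(3, 2))  # (n_features, n_states)
intercept = np.zeros(2)
initial_prob = np.array([0.5, 0.5])
transition_prob = np.array([[0.95, 0.05], [0.05, 0.95]])
activity, rates, states = simulate_glm_hmm(
    rng, X, coef, intercept, initial_prob, transition_prob
)
```

With several sessions, the chain would restart from initial_prob at each session start instead of continuing from the previous state.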

Parameters:
  • random_key (jax.Array) – JAX random key for reproducible simulation.

  • feedforward_input (DESIGN_INPUT_TYPE) – Design matrix of shape (n_time_bins, n_features). If a pynapple Tsd/TsdFrame is provided, session boundaries are detected from time_support and the hidden state chain is reset at each session start.

  • state_format (Literal['one-hot', 'index']) –

    Format for the returned states:

    • "index": Integer array of shape (n_time_bins,) with state indices.

    • "one-hot": Binary array of shape (n_time_bins, n_states).

  • session_starts (Optional[jax.Array]) –

    Optional session boundaries. Accepts:

    • a boolean array of shape (n_time_bins,) with True at each session start,

    • an integer array of session-start indices,

    • a pynapple IntervalSet (requires feedforward_input to be a pynapple object to supply timestamps).

    If feedforward_input is a pynapple object and session_starts is None, the time_support determines the session starts. With no pynapple input and session_starts=None, the whole input is treated as a single session.

Return type:

Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]

Returns:

  • simulated_activity – Simulated observations from the emission model. Shape (n_time_bins,) for single neuron or (n_time_bins, n_neurons) for population models.

  • firing_rates – Predicted firing rates conditioned on the simulated states. Shape (n_time_bins,) or (n_time_bins, n_neurons).

  • simulated_states – Simulated hidden state trajectory. Shape depends on state_format.

Raises:

ValueError – If the model has not been fitted (call fit() first).

Examples

>>> import jax
>>> import numpy as np
>>> import nemos as nmo
>>> np.random.seed(123)
>>> X = np.random.randn(100, 3)
>>> y = np.random.binomial(1, 0.5, 100)
>>> model = nmo.glm_hmm.GLMHMM(n_states=2, observation_model="Bernoulli")
>>> model = model.fit(X, y)
>>> key = jax.random.key(0)
>>> X_new = np.random.randn(50, 3)
>>> activity, rates, states = model.simulate(key, X_new)
>>> activity.shape
(50,)
>>> states.shape
(50,)

See also

decode_state

Infer most likely state sequence from observations.

smooth_proba

Compute posterior state probabilities.

smooth_proba(X, y, session_starts=None)[source]#

Compute smoothing posterior probabilities over hidden states.

Computes the probability of being in each hidden state at each time bin, conditioned on the entire observed sequence. Uses the forward-backward algorithm to incorporate information from both past and future observations, providing optimal state estimates given all available data.

The smoothing posteriors answer: “Given all observations, what is the probability that the system was in state k at time t?”

Parameters:
  • X (Union[DESIGN_INPUT_TYPE, ArrayLike]) – Predictors, shape (n_time_bins, n_features). A pytree of 2-D arrays sharing the leading time axis is also accepted.

  • y (Union[NDArray, jnp.ndarray, nap.Tsd]) – Observations, shape (n_time_bins,) for a single neuron or (n_time_bins, n_neurons) for a population model. A pynapple Tsd/TsdFrame is accepted; session boundaries are then inferred from time_support.

  • session_starts (Optional[ArrayLike]) –

    Optional session boundaries. Accepts:

    • a boolean array of shape (n_time_bins,) with True at each session start,

    • an integer array of session-start indices,

    • a pynapple IntervalSet (requires X or y to be a pynapple object to supply timestamps).

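    If None and the inputs are not pynapple objects, the entire input is treated as a single session. With pynapple inputs, session boundaries are inferred from time_support.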
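The boolean and integer forms of session_starts carry the same information. The sketch below converts between them with NumPy, assuming the first session starts at index 0; the variable names are illustrative only.

```python
import numpy as np

n_time_bins = 10

# Integer form: index of the first bin of each session.
start_indices = np.array([0, 4, 7])

# Equivalent boolean form: True at each session start.
is_start = np.zeros(n_time_bins, dtype=bool)
is_start[start_indices] = True

# Converting back recovers the indices.
recovered = np.flatnonzero(is_start)

# Session label of every bin: three sessions of lengths 4, 3, and 3.
session_id = np.cumsum(is_start) - 1
print(recovered, session_id)
```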

Returns:

Smoothing posterior probabilities, shape (n_time_bins, n_states). Each row sums to 1. Returns a pynapple TsdFrame (with columns named "state_0", "state_1", …) when the inputs are pynapple objects; otherwise returns a JAX array.

Return type:

jnp.ndarray | nap.TsdFrame

Raises:
  • ValueError – If the model has not been fitted (call fit() first).

  • ValueError – If X or y contain NaN values in the interior of an epoch (boundary NaNs are allowed and removed before inference).

  • ValueError – If X and y have inconsistent shapes or feature counts.

See also

filter_proba

Compute filtering posteriors (conditioned on observations up to the current time bin only).

decode_state

Compute the most likely state sequence via Viterbi decoding.

Notes

Smoothing is non-causal: it uses the entire observed sequence, so its state estimates are typically more accurate than filtering estimates, which use only observations up to the current time bin. For online or real-time applications use filter_proba() instead. Session boundaries reset the HMM chain so that no information crosses session borders.
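The difference between filtering and smoothing can be made concrete with a scaled forward-backward pass. The following NumPy sketch handles one session and takes precomputed per-state observation likelihoods as input. It illustrates the textbook algorithm rather than the nemos implementation, and the function name and input values are hypothetical.

```python
import numpy as np


def forward_backward(likelihood, initial_prob, transition_prob):
    """Filtering and smoothing posteriors for a single session.

    likelihood[t, k] is p(y_t | state k); rows of transition_prob sum to 1.
    """
    n_time_bins, n_states = likelihood.shape
    alpha = np.empty((n_time_bins, n_states))  # filtering posteriors
    norm = np.empty(n_time_bins)               # per-bin normalizers
    # Forward pass: predict, weight by the likelihood, normalize.
    a = initial_prob * likelihood[0]
    norm[0] = a.sum()
    alpha[0] = a / norm[0]
    for t in range(1, n_time_bins):
        a = (alpha[t - 1] @ transition_prob) * likelihood[t]
        norm[t] = a.sum()
        alpha[t] = a / norm[t]
    # Backward pass: bring in information from future observations.
    beta = np.ones((n_time_bins, n_states))
    for t in range(n_time_bins - 2, -1, -1):
        beta[t] = transition_prob @ (likelihood[t + 1] * beta[t + 1]) / norm[t + 1]
    smoothed = alpha * beta
    # The log normalizers sum to the marginal log-likelihood.
    return alpha, smoothed, np.log(norm).sum()


rng = np.random.default_rng(1)
likelihood = rng.uniform(0.1, 1.0, size=(4, 2))
initial_prob = np.array([0.6, 0.4])
transition_prob = np.array([[0.9, 0.1], [0.2, 0.8]])
filtered, smoothed, log_likelihood = forward_backward(
    likelihood, initial_prob, transition_prob
)
```

At the last time bin there are no future observations, so the filtered and smoothed posteriors coincide; at earlier bins they generally differ.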

Examples

Fit a GLM-HMM and compute smoothing posteriors:

>>> import numpy as np
>>> import nemos as nmo
>>> np.random.seed(123)
>>> X = np.random.randn(100, 5)
>>> y = np.random.poisson(2, size=100)
>>> model = nmo.glm_hmm.GLMHMM(n_states=3, observation_model="Poisson").fit(X, y)
>>> posteriors = model.smooth_proba(X, y)
>>> posteriors.shape
(100, 3)
>>> bool(np.allclose(posteriors.sum(axis=1), 1.0))
True

With pynapple inputs the result is returned as a TsdFrame:

>>> import pynapple as nap
>>> t = np.arange(100) * 0.01
>>> X_tsd = nap.TsdFrame(t=t, d=X)
>>> y_tsd = nap.Tsd(t=t, d=y.astype(float))
>>> type(model.smooth_proba(X_tsd, y_tsd)).__name__
'TsdFrame'

property solver_kwargs#

Getter for the solver_kwargs attribute.

property solver_name: str#

Getter for the solver_name attribute.

property solver_spec: SolverSpec#

Getter for the solver specification.

property tol#

Tolerance for the EM algorithm convergence criterion.

The algorithm stops when the absolute change in log-likelihood between consecutive iterations falls below this threshold:

|log_likelihood_current - log_likelihood_previous| < tol
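The stopping rule can be written as a one-line check. The sketch below applies it to a made-up sequence of log-likelihood values; the function name and the trace are hypothetical and serve only to show when the criterion triggers.

```python
def em_converged(log_likelihood_current, log_likelihood_previous, tol=1e-8):
    # Stop when the absolute change in log-likelihood is below `tol`.
    return abs(log_likelihood_current - log_likelihood_previous) < tol


# Hypothetical log-likelihood values from successive EM iterations.
trace = [-120.0, -101.5, -100.2, -100.19, -100.19]

# Index of the first iteration at which the criterion is met.
stop_at = next(
    i for i in range(1, len(trace)) if em_converged(trace[i], trace[i - 1])
)
print(stop_at)  # 4
```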

Returns:

Convergence tolerance value.

Return type:

float

update(params, opt_state, X, y, *args, session_starts=None, n_samples=None, **kwargs)[source]#

Run a single EM iteration on the GLM-HMM.

Performs one E-step / M-step pair starting from the supplied parameters and EM state, updates the model’s fitted attributes (coef_, intercept_, scale_, initial_prob_, transition_prob_, solver_state_, dof_resid_) in place, and returns the updated parameter tuple and EM state. Intended for callers that need fine-grained control over EM iteration (e.g. checkpointing, custom convergence criteria) instead of the bundled fit() loop.

initialize_optimizer_and_state() must be called first so that the EM step function and initial opt_state are available.

Parameters:
  • params (GLMHMMUserParams) – Current model parameters as a 5-tuple (coef, intercept, scale, initial_prob, transition_prob) matching the structure produced by initialize_params().

  • opt_state (NamedTuple) – EM state returned by initialize_optimizer_and_state() or by the previous call to update().

  • X (DESIGN_INPUT_TYPE) – Predictors, shape (n_time_bins, n_features) (or a pytree of arrays of the same shape).

  • y (jnp.ndarray) – Observations, shape (n_time_bins,) or (n_time_bins, n_neurons).

  • session_starts (Optional[jnp.ndarray]) –

    Optional session boundaries. Accepts:

    • a boolean array of shape (n_time_bins,) with True at each session start,

    • an integer array of session-start indices,

    • a pynapple IntervalSet (requires X or y to be a pynapple object to supply timestamps).

    If None, the entire input is treated as a single session.

  • n_samples (Optional[int]) – Total sample count to use when estimating the residual degrees of freedom. Defaults to X.shape[0].

Return type:

StepResult

Returns:

  • params – Updated user-facing parameter tuple.

  • state – Updated EM state.

Raises:

ValueError – If inputs fail shape/consistency validation.

Examples

>>> import numpy as np
>>> import nemos as nmo
>>> np.random.seed(0)
>>> X = np.random.normal(size=(80, 3))
>>> y = np.random.binomial(n=1, p=0.5, size=80)
>>> model = nmo.glm_hmm.GLMHMM(n_states=2)
>>> init_params = model.initialize_params(X, y)
>>> opt_state = model.initialize_optimizer_and_state(init_params, X, y)
>>> new_params, new_state = model.update(init_params, opt_state, X, y)