elman_mrnn

mRNN core module.

Implements the multi-regional recurrent neural network (mRNN) building blocks and step-wise dynamics, along with helpers for connectivity, constraints, and initialization.

class mrnntorch.mrnn.elman_mrnn.ElmanmRNN(config: str = None, activation: str = 'relu', noise_level_act: float = 0.01, noise_level_inp: float = 0.01, rec_constrained: bool = True, inp_constrained: bool = True, batch_first: bool = True, spectral_radius: float = None, config_finalize: bool = True, device: str = 'cuda')

Bases: mRNNBase

Elman-style multi-regional RNN that evolves only hidden activations.

batched_initial_condition(batch_size: int) → Tuple[Tensor, Tensor]

Return the batched initial hidden state.

forward(inp: Tensor, h0: Tensor, stim_input: Tensor | None = None, noise: bool = False, W_rec: Tensor | None = None) → Tensor

Run the recurrent dynamics over a sequence.

Discretized update: h_{t+1} = activation(W_rec h_t + W_inp u_t + b + noise).

Args:

inp (torch.Tensor): Input sequence, shape [B, T, I] if batch_first else [T, B, I].
h0 (torch.Tensor): Initial hidden activation, shape [B, H].
stim_input (torch.Tensor | None): Optional additive input with the same temporal layout as inp and feature size H.
noise (bool): If True, add Gaussian noise to hidden state and inputs.
W_rec (torch.Tensor | None): Optional recurrent weight matrix to use instead of the network's own.

Returns:

torch.Tensor: Hidden activation sequence h_seq, matching the temporal layout of inp.
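For illustration, the Elman update can be written as a plain PyTorch loop. This is a minimal sketch, not the library's forward(): the function name elman_forward_sketch is hypothetical, W_rec is assumed to be [H, H], W_inp [H, I], b a length-H vector, and the noise term is omitted.

```python
import torch


def elman_forward_sketch(inp, h0, W_rec, W_inp, b, stim_input=None,
                         activation=torch.relu, batch_first=True):
    """Hypothetical sketch of h_{t+1} = activation(W_rec h_t + W_inp u_t + b).

    inp: [B, T, I] if batch_first else [T, B, I]; h0: [B, H].
    stim_input (optional): additive input with the same layout as inp and size H.
    Noise is omitted for clarity.
    """
    if not batch_first:
        # Work internally in batch-first layout.
        inp = inp.transpose(0, 1)
        if stim_input is not None:
            stim_input = stim_input.transpose(0, 1)

    h = h0
    h_seq = []
    for t in range(inp.shape[1]):
        pre = h @ W_rec.T + inp[:, t] @ W_inp.T + b
        if stim_input is not None:
            pre = pre + stim_input[:, t]
        h = activation(pre)
        h_seq.append(h)

    h_seq = torch.stack(h_seq, dim=1)  # [B, T, H]
    return h_seq if batch_first else h_seq.transpose(0, 1)


# Usage with random weights
torch.manual_seed(0)
B, T, I, H = 2, 5, 3, 4
h_seq = elman_forward_sketch(torch.randn(B, T, I), torch.zeros(B, H),
                             0.1 * torch.randn(H, H), torch.randn(H, I),
                             torch.zeros(H))
print(h_seq.shape)  # torch.Size([2, 5, 4])
```

Returning only the activation sequence matches the single Tensor annotated on this forward(), since the Elman model keeps no separate pre-activation state.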

mrnntorch.mrnn.elman_mrnn.linear(x)

Return x unchanged.

leaky_mrnn

mRNN core module.

Implements the multi-regional recurrent neural network (mRNN) building blocks and step-wise dynamics, along with helpers for connectivity, constraints, and initialization.

mrnntorch.mrnn.leaky_mrnn.linear(x)

Return x unchanged.

class mrnntorch.mrnn.leaky_mrnn.mRNN(config: str = None, activation: str = 'relu', noise_level_act: float = 0.01, noise_level_inp: float = 0.01, rec_constrained: bool = True, inp_constrained: bool = True, batch_first: bool = True, spectral_radius: float = None, config_finalize: bool = True, device: str = 'cuda', dt: float = 10, tau: float = 100)

Bases: mRNNBase

Leaky multi-regional RNN with separate pre-activation and activation states.

batched_initial_condition(batch_size: int) → Tuple[Tensor, Tensor]

Return batched initial pre-activation and activation states.

forward(inp: Tensor, x0: Tensor, h0: Tensor | None = None, stim_input: Tensor | None = None, noise: bool = False, W_rec: Tensor | None = None) → Tuple[Tensor, Tensor]

Run the recurrent dynamics over a sequence.

Discretized update: x_{t+1} = x_t + alpha * (-x_t + W_rec h_t + W_inp u_t + b + noise) and h_{t+1} = activation(x_{t+1}).

Args:

inp (torch.Tensor): Input sequence, shape [B, T, I] if batch_first else [T, B, I].
x0 (torch.Tensor): Initial pre-activation hidden state, shape [B, H].
h0 (torch.Tensor | None): Optional initial activation, shape [B, H].
stim_input (torch.Tensor | None): Optional additive input with the same temporal layout as inp and feature size H.
noise (bool): If True, add Gaussian noise to hidden state and inputs.
W_rec (torch.Tensor | None): Optional recurrent weight matrix to use instead of the network's own.

Returns:

tuple[torch.Tensor, torch.Tensor]: (x_seq, h_seq) sequences matching the temporal layout of inp.
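The leaky update can be sketched the same way. This sketch assumes alpha = dt / tau (so the default dt=10, tau=100 gives alpha = 0.1), batch-first inputs, h0 = activation(x0) when h0 is not given, and no noise; leaky_forward_sketch is a hypothetical name, not part of mrnntorch.

```python
import torch


def leaky_forward_sketch(inp, x0, W_rec, W_inp, b, h0=None,
                         dt=10.0, tau=100.0, activation=torch.relu):
    """Hypothetical sketch of
    x_{t+1} = x_t + alpha * (-x_t + W_rec h_t + W_inp u_t + b), h_{t+1} = activation(x_{t+1}).

    inp: [B, T, I] (batch-first); x0, h0: [B, H]. Noise is omitted.
    """
    alpha = dt / tau  # assumed definition of alpha
    x = x0
    h = activation(x0) if h0 is None else h0  # assumed default for h0
    x_seq, h_seq = [], []
    for t in range(inp.shape[1]):
        drive = h @ W_rec.T + inp[:, t] @ W_inp.T + b
        x = x + alpha * (-x + drive)  # Euler step of the leaky dynamics
        h = activation(x)
        x_seq.append(x)
        h_seq.append(h)
    return torch.stack(x_seq, dim=1), torch.stack(h_seq, dim=1)
```

With zero input, zero weights and zero bias, each step multiplies x by 1 - alpha = 0.9, so a state starting at 1 decays to 0.9, 0.81, 0.729, and so on.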

mrnn_base

mRNN core module.

Implements the multi-regional recurrent neural network (mRNN) building blocks and step-wise dynamics, along with helpers for connectivity, constraints, and initialization.

mrnntorch.mrnn.mrnn_base.linear(x)

class mrnntorch.mrnn.mrnn_base.mRNNBase(config: str = None, activation: str = 'relu', noise_level_act: float = 0.01, noise_level_inp: float = 0.01, rec_constrained: bool = True, inp_constrained: bool = True, batch_first: bool = True, spectral_radius: float = None, config_finalize: bool = True, device: str = 'cuda')

Bases: Module

add_input_connection(src_region: str, dst_region: str, sparsity: float | None = None)

Create an input connection from an input region to a recurrent region. Explicit zero connections are currently not supported; region pairs left unconnected are filled with zeros by finalize_inp_connectivity().

Args:

src_region (str): Source input region name.
dst_region (str): Destination recurrent region name.
sparsity (float | None): Fraction of connections to keep (0-1). If None, a dense mask is used.

add_input_region(name: str, num_units: int, sign: str = 'pos')

Add an input region to the network.

Args:

name (str): Input region name (unique key).
num_units (int): Number of input channels in this region.
sign (str): “pos” or “neg”; used to set the sign mask for inputs.

add_recurrent_connection(src_region: str, dst_region: str, sparsity: float = None)

Create a recurrent connection from one region to another.

Registers the weight parameter and associated masks. If sparsity is provided, a binary connectivity mask is sampled accordingly.

Explicit zero connections are currently not supported; region pairs left unconnected are filled with zeros by finalize_rec_connectivity().

Args:

src_region (str): Source recurrent region name.
dst_region (str): Destination recurrent region name.
sparsity (float | None): Fraction of connections to keep (0-1). If None, a dense mask is used.
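One plausible way to sample such a mask is sketched below. It assumes sparsity is the probability of keeping each connection independently, and that rows index the destination region and columns the source region (the convention get_projection documents). sample_connection_mask is a hypothetical helper, not the library's code; rejecting sparsity <= 0 reflects the rule that zero connections cannot be created.

```python
import torch


def sample_connection_mask(dst_units, src_units, sparsity=None, generator=None):
    """Hypothetical helper: binary mask of shape [dst_units, src_units].

    sparsity is taken as the fraction of connections kept (0-1); None gives
    a dense mask of ones.
    """
    if sparsity is None:
        return torch.ones(dst_units, src_units)
    if not 0.0 < sparsity <= 1.0:
        raise ValueError("sparsity must be in (0, 1]")
    # Keep each connection independently with probability `sparsity`.
    keep = torch.rand(dst_units, src_units, generator=generator) < sparsity
    return keep.float()


g = torch.Generator().manual_seed(0)
mask = sample_connection_mask(100, 200, sparsity=0.3, generator=g)
print(mask.shape)  # torch.Size([100, 200])
```

Such a mask is usually multiplied elementwise with the weight block so that pruned entries stay at zero during training.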

add_recurrent_region(name: str, num_units: int, sign: str = 'pos', base_firing: float = 0, init: float = 0, parent_region: str = None, learnable_bias: bool = False)

Add a recurrent region to the network.

Args:

name (str): Region name (unique key).
num_units (int): Number of units in this region.
sign (str): “pos” for excitatory or “neg” for inhibitory outputs.
base_firing (float | torch.Tensor): Baseline firing per unit.
init (float): Initial pre-activation value for units in this region.
parent_region (str | None): Optional parent region identifier.
learnable_bias (bool): If True, baseline firing is trainable.

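apply_dales_law(W_rec: Tensor, W_rec_mask: Tensor, W_rec_sign_matrix: Tensor) → Tensor

Apply Dale’s law constraints to the recurrent weight matrix. Dale’s law states that a neuron is either excitatory or inhibitory, not both, so all outgoing weights of a unit must share the same sign.

Args:

W_rec (torch.Tensor): Learnable recurrent weight matrix.
W_rec_mask (torch.Tensor): Binary mask of allowed connections.
W_rec_sign_matrix (torch.Tensor): Sign constraints for Dale’s law.

Returns:

torch.Tensor: Constrained weight matrix.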
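A common way to impose this constraint is sketched here under assumptions: W_rec_sign_matrix holds +1 in every column that belongs to an excitatory (“pos”) source unit and -1 for an inhibitory (“neg”) one, and magnitudes are taken with abs(). dales_law_sketch is a hypothetical name; the library may use a different nonlinearity, such as ReLU, to obtain non-negative magnitudes.

```python
import torch


def dales_law_sketch(W_rec, W_rec_mask, W_rec_sign_matrix):
    # Keep only magnitudes, zero out disallowed connections, then impose
    # the sign of each source unit (column).
    return torch.abs(W_rec) * W_rec_mask * W_rec_sign_matrix


# Three units: units 0 and 1 excitatory, unit 2 inhibitory.
signs = torch.tensor([1.0, 1.0, -1.0])
W_rec_sign_matrix = signs.expand(3, 3)  # every row is `signs`, so column j has sign signs[j]
W_rec_mask = torch.ones(3, 3)
W_rec_mask[0, 1] = 0.0  # disallow the connection from unit 1 to unit 0

torch.manual_seed(0)
W_eff = dales_law_sketch(torch.randn(3, 3), W_rec_mask, W_rec_sign_matrix)
```

Because the raw weights are only used through their magnitudes, gradient updates can change a weight's size but never flip the sign fixed by its source unit.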

batched_initial_condition(*args, **kwargs) → Tuple[Tensor, Tensor]

combine_states(states_a: Tensor, states_b: Tensor, region_list_a: list, region_list_b: list, keep_dims=True) → Tensor

Concatenate two state tensors along the unit dimension, ordering the regions according to the given region lists.

Args:

states_a (Tensor): States from the regions in region_list_a (assumed to be in exactly that order).
states_b (Tensor): States from the regions in region_list_b (assumed to be in exactly that order).
region_list_a (list): Hidden regions of states_a, in order.
region_list_b (list): Hidden regions of states_b, in order.

Returns:

full_state (Tensor): Concatenation of states_a and states_b along the unit dimension.

compute_spectral_radius(weight: Tensor) → float

Compute the spectral radius (max |eigenvalue|) of a square matrix.

Args:

weight (torch.Tensor): Square weight matrix.

Returns:

torch.Tensor: Spectral radius as a scalar tensor.
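In PyTorch the spectral radius can be computed from the (generally complex) eigenvalues returned by torch.linalg.eigvals. This is a minimal sketch; spectral_radius_sketch is a hypothetical name, and the library's implementation may differ, for example in whether it returns a float or a tensor.

```python
import torch


def spectral_radius_sketch(weight: torch.Tensor) -> torch.Tensor:
    # eigvals returns complex eigenvalues even for real input; abs() gives their moduli.
    return torch.linalg.eigvals(weight).abs().max()


# A 90-degree rotation has eigenvalues +i and -i, so its spectral radius is 1
# even though no eigenvalue is real.
rot = torch.tensor([[0.0, -1.0], [1.0, 0.0]])
rho = spectral_radius_sketch(rot)
```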

finalize_connectivity()

Finalize both input and recurrent connectivity. This is a convenience method so that users do not have to call finalize_inp_connectivity() and finalize_rec_connectivity() separately.

finalize_inp_connectivity()

Fill all remaining input connections with zeros and set the finalized flag to True.

finalize_rec_connectivity()

Fill all remaining recurrent connections with zeros and set the finalized flag to True.

forward(*args, **kwargs)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

gen_w(dict_: dict) → tuple[Tensor, Tensor, Tensor]

Generates the full recurrent connectivity matrix and associated masks.

Returns:
tuple: (W_rec, W_rec_mask, W_rec_sign_matrix)
  • W_rec: Learnable weight matrix

  • W_rec_mask: Binary mask for allowed connections

  • W_rec_sign_matrix: Sign constraints for Dale’s Law

get_excluded_hid_regions(*args) → list

Return all hidden regions in the network that are not passed as arguments.

get_excluded_inp_regions(*args) → list

Return all input regions in the network that are not passed as arguments.

get_projection(to: str, from_: str) → Tensor

Return the block of the weight matrix for the projection from region from_ to region to.

Args:

to (str): Name of the region receiving the projection (rows).
from_ (str): Name of the projecting region (columns).

Returns:

torch.Tensor: weight matrix of from_->to projection

get_region_activity(act: Tensor, *args) → Tensor

Return the part of the hidden activity act that belongs to the specified regions.

Args:

act (torch.Tensor): Model hidden activity, with units in the last dimension (-1).
*args (str): Names of the regions to collect activity from.

Returns:

torch.Tensor: Hidden activity for the specified regions only.

get_region_indices(region: str) → tuple[int, int]

Gets the start and end indices for a specific region in the hidden state vector.

Args:

region (str): Name of the region

Returns:

tuple: (start_idx, end_idx)
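Assuming regions occupy contiguous blocks of the hidden state in the order they were added, both the indices and the per-region activity reduce to cumulative sums and slicing along the last dimension. The region names, the region_sizes dictionary and both helpers below are hypothetical, and the end index is taken as exclusive (Python slice convention), which the docstring does not specify.

```python
import torch

# Hypothetical layout: region name -> number of units, in hidden-state order.
region_sizes = {"ctx": 4, "str": 3, "thal": 2}


def region_indices_sketch(region_sizes, region):
    """Return (start, end) of `region`, with `end` exclusive."""
    start = 0
    for name, size in region_sizes.items():
        if name == region:
            return start, start + size
        start += size
    raise KeyError(f"unknown region: {region}")


def region_activity_sketch(act, region_sizes, *regions):
    """Slice the last (unit) dimension of `act` for each region and concatenate."""
    pieces = []
    for region in regions:
        start, end = region_indices_sketch(region_sizes, region)
        pieces.append(act[..., start:end])
    return torch.cat(pieces, dim=-1)


act = torch.arange(9.0).expand(2, 5, 9)  # [B, T, H] with H = 4 + 3 + 2
print(region_indices_sketch(region_sizes, "str"))  # (4, 7)
print(region_activity_sketch(act, region_sizes, "thal")[0, 0])  # tensor([7., 8.])
```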

get_region_size(region: str) → int

Get the number of units in a region.

Args:

region (str): Region to get the size of.

Returns:

int: Number of units in the region.

get_weight_subset(*args, W: Tensor | None = None) → Tensor

Gather the weights among all regions in args: every connection between them, in both directions, including each region's connections to itself. The result is a square matrix.

Args:

*args (str): Names of the regions to include.
W (torch.Tensor | None): Weight matrix to take the subset from instead of the network's own.

Returns:

torch.Tensor: Subset of the total weight matrix.
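Under the same assumptions as before (contiguous region blocks, rows as receiving regions and columns as projecting regions, as in get_projection), both a single projection and a square subset can be read off the full matrix with slicing and index gathering. The helper names and region layout are hypothetical, and this subset keeps regions in the order given, which may differ from the library's ordering.

```python
import torch

# Hypothetical layout: region name -> number of units, in hidden-state order.
region_sizes = {"ctx": 4, "str": 3, "thal": 2}


def _bounds(region):
    """(start, end) of a region's block, end exclusive."""
    start = 0
    for name, size in region_sizes.items():
        if name == region:
            return start, start + size
        start += size
    raise KeyError(f"unknown region: {region}")


def projection_sketch(W, to, from_):
    r0, r1 = _bounds(to)     # rows: receiving region
    c0, c1 = _bounds(from_)  # columns: projecting region
    return W[r0:r1, c0:c1]


def weight_subset_sketch(W, *regions):
    # Unit indices of all requested regions, then the square block among them.
    idx = torch.cat([torch.arange(*_bounds(r)) for r in regions])
    return W[idx][:, idx]


W = torch.arange(81.0).reshape(9, 9)
print(projection_sketch(W, "str", "ctx").shape)  # torch.Size([3, 4])
print(weight_subset_sketch(W, "ctx", "thal").shape)  # torch.Size([6, 6])
```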

property hid_noise_const

Noise constant used for hidden activity.

property hid_regions

Names of all hidden regions in the network, as a list.

property initial_condition: Tensor

Create an initial pre-activation state xn for the network.

Returns:

Tensor: Tensor shaped like xn, with each region's units set to that region's initial value.
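A sketch of how such a vector can be assembled, assuming each region contributes num_units copies of its init value in the order the regions were added. The region layout and initial_condition_sketch are hypothetical, and repeating the vector across a batch is shown only as one way a batched initial condition could be formed.

```python
import torch

# Hypothetical regions: name -> (num_units, init), in hidden-state order.
regions = {"ctx": (4, 0.0), "str": (3, -0.5), "thal": (2, 0.1)}


def initial_condition_sketch(regions):
    # One constant block per region, concatenated along the unit dimension.
    return torch.cat([torch.full((n,), float(init)) for n, init in regions.values()])


x0 = initial_condition_sketch(regions)     # shape [H] = [9]
x0_batched = x0.unsqueeze(0).repeat(3, 1)  # shape [B, H] = [3, 9]
```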

property inp_noise_const

Noise constant used for inputs.

property inp_regions

Names of all input regions in the network, as a list.

named_inp_regions(prefix: str = '')

Iterate over input region names and region objects.

Args:

prefix (str, optional): Defaults to ‘’.

named_rec_regions(prefix: str = '')

Iterate over recurrent region names and region objects.

Args:

prefix (str, optional): Defaults to ‘’.

set_spectral_radius(W: Tensor, W_tmp: Tensor | None = None) → Tensor

Scale recurrent weights so their spectral radius matches self.spectral_radius.

Usage:
  1. Define regions and connections (via config or manual methods).

  2. If building manually, call finalize_connectivity() first.

  3. Set self.spectral_radius and call this method.

If W_tmp is given, the spectral radius is measured on W_tmp (e.g., the Dale’s law constrained weight matrix) instead of on W.
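Because the spectral radius is homogeneous (rho(cW) = |c| rho(W)), rescaling needs only one multiplication. The sketch below assumes that is how the method works and that, when W_tmp is given, the scale factor is computed from W_tmp but applied to W. That is still exact when W_tmp is a positively homogeneous function of W, such as abs(W) times fixed masks and signs. scale_to_spectral_radius_sketch is a hypothetical name.

```python
import torch


def scale_to_spectral_radius_sketch(W, target, W_tmp=None):
    # Measure the radius on W_tmp if given (e.g. a Dale's-law-constrained matrix), else on W.
    ref = W if W_tmp is None else W_tmp
    rho = torch.linalg.eigvals(ref).abs().max()
    # Assumes rho > 0; a nilpotent matrix (rho == 0) cannot be rescaled this way.
    return W * (target / rho)


torch.manual_seed(0)
W = torch.randn(6, 6)
W_scaled = scale_to_spectral_radius_sketch(W, 1.5)
```

Scaling the raw matrix W by a positive factor c scales abs(W) by the same c, which is why measuring on the constrained matrix and rescaling the raw one lands the constrained matrix on the target radius.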

property tonic_inp

Collects baseline firing rates for all regions.

Returns:

torch.Tensor: Vector of baseline firing rates