elman_mrnn¶
mRNN core module.
Implements the multi-regional recurrent neural network (mRNN) building blocks and step-wise dynamics, along with helpers for connectivity, constraints, and initialization.
- class mrnntorch.mrnn.elman_mrnn.ElmanmRNN(config: str = None, activation: str = 'relu', noise_level_act: float = 0.01, noise_level_inp: float = 0.01, rec_constrained: bool = True, inp_constrained: bool = True, batch_first: bool = True, spectral_radius: float = None, config_finalize: bool = True, device: str = 'cuda')[source]¶
Bases: mRNNBase
Elman-style multi-regional RNN that evolves only hidden activations.
- batched_initial_condition(batch_size: int) Tuple[Tensor, Tensor][source]¶
Return the batched initial hidden state.
- forward(inp: Tensor, h0: Tensor, stim_input: Tensor | None = None, noise: bool = False, W_rec: Tensor | None = None) Tensor[source]¶
Run the recurrent dynamics over a sequence.
Discretized update:
x_{t+1} = W_rec h_t + W_inp u_t + b + noise and h_{t+1} = activation(x_{t+1}).
- Args:
inp (torch.Tensor): Input sequence. Shape [B, T, I] if batch_first else [T, B, I].
x0 (torch.Tensor): Initial pre-activation hidden state, shape [B, H].
h0 (torch.Tensor): Initial activation, shape [B, H].
*args (torch.Tensor): Optional additive inputs with the same temporal layout as inp and feature size H.
noise (bool): If True, add Gaussian noise to hidden state and inputs.
- Returns:
tuple[torch.Tensor, torch.Tensor]: (x_seq, h_seq) sequences matching the temporal layout of inp.
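The Elman-style update above can be sketched in a few lines of NumPy. This is only an illustration of the recurrence under stated assumptions, not the library's implementation: the function names `elman_step` and `elman_rollout` are hypothetical, ReLU is assumed as the activation, inputs are assumed batch-first, and only the activation sequence is returned.

```python
import numpy as np

def elman_step(h, u, W_rec, W_inp, b, noise_std=0.0, rng=None):
    """One step: pre-activation from the previous activation, then ReLU.

    h: [B, H] previous activation, u: [B, I] input at this time step.
    W_rec: [H, H], W_inp: [H, I], b: [H]. All names are illustrative.
    """
    x_next = h @ W_rec.T + u @ W_inp.T + b
    if noise_std > 0.0:
        if rng is None:
            rng = np.random.default_rng()
        x_next = x_next + noise_std * rng.standard_normal(x_next.shape)
    h_next = np.maximum(x_next, 0.0)  # ReLU activation (assumed)
    return x_next, h_next

def elman_rollout(inp, h0, W_rec, W_inp, b):
    """Run the step over a batch-first sequence inp of shape [B, T, I]."""
    h = h0
    hs = []
    for t in range(inp.shape[1]):
        _, h = elman_step(h, inp[:, t], W_rec, W_inp, b)
        hs.append(h)
    return np.stack(hs, axis=1)  # [B, T, H]
```

With zero recurrent weights and unit input weights, a constant input of 1 yields an activation of 1 at every step, which is a quick sanity check on shapes and layout.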
leaky_mrnn¶
mRNN core module.
Implements the multi-regional recurrent neural network (mRNN) building blocks and step-wise dynamics, along with helpers for connectivity, constraints, and initialization.
- class mrnntorch.mrnn.leaky_mrnn.mRNN(config: str = None, activation: str = 'relu', noise_level_act: float = 0.01, noise_level_inp: float = 0.01, rec_constrained: bool = True, inp_constrained: bool = True, batch_first: bool = True, spectral_radius: float = None, config_finalize: bool = True, device: str = 'cuda', dt: float = 10, tau: float = 100)[source]¶
Bases: mRNNBase
Leaky multi-regional RNN with separate pre-activation and activation states.
- batched_initial_condition(batch_size: int) Tuple[Tensor, Tensor][source]¶
Return batched initial pre-activation and activation states.
- forward(inp: Tensor, x0: Tensor, h0: Tensor | None = None, stim_input: Tensor | None = None, noise: bool = False, W_rec: Tensor | None = None) Tuple[Tensor, Tensor][source]¶
Run the recurrent dynamics over a sequence.
Discretized update:
x_{t+1} = x_t + alpha * (-x_t + W_rec h_t + W_inp u_t + b + noise) and h_{t+1} = activation(x_{t+1}).
- Args:
inp (torch.Tensor): Input sequence. Shape [B, T, I] if batch_first else [T, B, I].
x0 (torch.Tensor): Initial pre-activation hidden state, shape [B, H].
h0 (torch.Tensor): Initial activation, shape [B, H].
*args (torch.Tensor): Optional additive inputs with the same temporal layout as inp and feature size H.
noise (bool): If True, add Gaussian noise to hidden state and inputs.
- Returns:
tuple[torch.Tensor, torch.Tensor]: (x_seq, h_seq) sequences matching the temporal layout of inp.
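A minimal NumPy sketch of one leaky step may help make the update concrete. It assumes `alpha = dt / tau` (a common convention for leaky rate networks; the documentation above does not define `alpha`) and a ReLU activation. The function name `leaky_step` is hypothetical and the noise term is omitted for brevity.

```python
import numpy as np

def leaky_step(x, h, u, W_rec, W_inp, b, dt=10.0, tau=100.0):
    """One Euler step of the leaky dynamics.

    x: [B, H] pre-activation, h: [B, H] activation, u: [B, I] input.
    alpha = dt / tau is an assumption, not confirmed by the source.
    """
    alpha = dt / tau
    x_next = x + alpha * (-x + h @ W_rec.T + u @ W_inp.T + b)
    h_next = np.maximum(x_next, 0.0)  # ReLU activation (assumed)
    return x_next, h_next
```

With the defaults `dt=10` and `tau=100`, each step moves the pre-activation one tenth of the way toward its current drive, so a constant input of 1 with no recurrence gives 0.1 after one step and 0.19 after two.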
mrnn_base¶
mRNN core module.
Implements the multi-regional recurrent neural network (mRNN) building blocks and step-wise dynamics, along with helpers for connectivity, constraints, and initialization.
- class mrnntorch.mrnn.mrnn_base.mRNNBase(config: str = None, activation: str = 'relu', noise_level_act: float = 0.01, noise_level_inp: float = 0.01, rec_constrained: bool = True, inp_constrained: bool = True, batch_first: bool = True, spectral_radius: float = None, config_finalize: bool = True, device: str = 'cuda')[source]¶
Bases: Module
- add_input_connection(src_region: str, dst_region: str, sparsity: float | None = None)[source]¶
Create an input connection from an input region to a recurrent region. Currently not allowed to make a zero connection.
- Args:
src_region (str): Source input region name.
dst_region (str): Destination recurrent region name.
sparsity (float | None): Fraction of connections to keep (0-1). If None, a dense mask is used.
- add_input_region(name: str, num_units: int, sign: str = 'pos')[source]¶
Add an input region to the network.
- Args:
name (str): Input region name (unique key).
num_units (int): Number of input channels in this region.
sign (str): "pos" or "neg"; used to set the sign mask for inputs.
- add_recurrent_connection(src_region: str, dst_region: str, sparsity: float = None)[source]¶
Create a recurrent connection from one region to another.
Registers the weight parameter and associated masks. If sparsity is provided, a binary connectivity mask is sampled accordingly.
Currently not allowed to make a zero connection.
- Args:
src_region (str): Source recurrent region name.
dst_region (str): Destination recurrent region name.
sparsity (float | None): Fraction of connections to keep (0-1). If None, a dense mask is used.
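To illustrate what sampling a binary connectivity mask could look like, here is a small NumPy sketch. It assumes each connection is kept independently with probability `sparsity`; the library may instead keep an exact count or use a different sampling scheme. The function name `sample_mask` and the `seed` parameter are hypothetical.

```python
import numpy as np

def sample_mask(n_dst, n_src, sparsity=None, seed=0):
    """Return a [n_dst, n_src] mask of 0s and 1s.

    sparsity is read as the fraction of connections to keep.
    None gives a dense (all-ones) mask.
    """
    if sparsity is None:
        return np.ones((n_dst, n_src))
    rng = np.random.default_rng(seed)
    # Keep an entry when its uniform draw falls below the keep fraction.
    return (rng.random((n_dst, n_src)) < sparsity).astype(float)
```

The mask would then be multiplied elementwise with the weight block for that connection so that dropped entries stay at zero.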
- add_recurrent_region(name: str, num_units: int, sign: str = 'pos', base_firing: float = 0, init: float = 0, parent_region: str = None, learnable_bias: bool = False)[source]¶
Add a recurrent region to the network.
- Args:
name (str): Region name (unique key).
num_units (int): Number of units in this region.
sign (str): "pos" for excitatory or "neg" for inhibitory outputs.
base_firing (float | torch.Tensor): Baseline firing per unit.
init (float): Initial pre-activation value for units in this region.
parent_region (str | None): Optional parent region identifier.
learnable_bias (bool): If True, baseline firing is trainable.
- apply_dales_law(W_rec: Tensor, W_rec_mask: Tensor, W_rec_sign_matrix: Tensor) Tensor[source]¶
Applies Dale’s Law constraints to the recurrent weight matrix. Dale’s Law states that a neuron can be either excitatory or inhibitory, but not both.
- Returns:
torch.Tensor: Constrained weight matrix
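One common way to enforce Dale's Law is to take the magnitude of each weight, zero out disallowed connections with the mask, and then apply a fixed sign per presynaptic unit. The NumPy sketch below shows that recipe. It is an assumption about the general technique, not a copy of this method's internals: the library might use a rectifier instead of an absolute value, and the function name `apply_dales_law_sketch` is hypothetical.

```python
import numpy as np

def apply_dales_law_sketch(W_rec, W_rec_mask, W_rec_sign_matrix):
    """Constrain weights so each presynaptic unit has a single sign.

    W_rec: raw weights, W_rec_mask: 0/1 allowed connections,
    W_rec_sign_matrix: +1 / -1 entries (one sign per source column here).
    """
    return W_rec_mask * np.abs(W_rec) * W_rec_sign_matrix

W = np.array([[1.0, -2.0], [-3.0, 4.0]])
mask = np.array([[1.0, 1.0], [0.0, 1.0]])
sign = np.array([[1.0, -1.0], [1.0, -1.0]])  # column 0 excitatory, column 1 inhibitory
W_dale = apply_dales_law_sketch(W, mask, sign)
```

In this example the first column ends up non-negative, the second non-positive, and the masked entry is zero.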
- combine_states(states_a: Tensor, states_b: Tensor, region_list_a: list, region_list_b: list, keep_dims=True) Tensor[source]¶
Take two states and concatenate them in order according to their given region lists.
- Args:
states_a (Tensor): A tensor of states from regions in region_list_a (assumed to be in exact order).
states_b (Tensor): A tensor of states from regions in region_list_b (assumed to be in exact order).
region_list_a (list): List containing hidden regions of states_a in order.
region_list_b (list): List containing hidden regions of states_b in order.
- Returns:
full_state (Tensor): concatenated tensor of states_a and states_b along unit dimension
- compute_spectral_radius(weight: Tensor) float[source]¶
Compute the spectral radius (max |eigenvalue|) of a square matrix.
- Args:
weight (torch.Tensor): Square weight matrix.
- Returns:
torch.Tensor: Spectral radius as a scalar tensor.
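The definition above (largest eigenvalue magnitude) can be checked with a short NumPy sketch. This is an illustration of the quantity being computed, not the library's code, and the function name `spectral_radius` is hypothetical.

```python
import numpy as np

def spectral_radius(weight):
    """Return max |eigenvalue| of a square matrix as a Python float."""
    eigenvalues = np.linalg.eigvals(weight)  # may be complex
    return float(np.max(np.abs(eigenvalues)))

# A diagonal matrix has its diagonal entries as eigenvalues.
r_diag = spectral_radius(np.diag([0.5, -2.0]))
# A 90-degree rotation has eigenvalues +i and -i, both of magnitude 1.
r_rot = spectral_radius(np.array([[0.0, -1.0], [1.0, 0.0]]))
```

Taking the absolute value before the maximum matters because eigenvalues of a non-symmetric recurrent matrix are generally complex.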
- finalize_connectivity()[source]¶
Finalize both input and recurrent connectivity. This is a convenience method so users don't have to call the recurrent and input finalization methods separately.
- finalize_inp_connectivity()[source]¶
Fill the remaining input connections with zeros and ensure the finalized flag is set to true.
- finalize_rec_connectivity()[source]¶
Fill the remaining recurrent connections with zeros and ensure the finalized flag is set to true.
- forward(*args, **kwargs)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- gen_w(dict_: dict) tuple[Tensor, Tensor, Tensor][source]¶
Generates the full recurrent connectivity matrix and associated masks.
- Returns:
- tuple: (W_rec, W_rec_mask, W_rec_sign_matrix)
W_rec: Learnable weight matrix
W_rec_mask: Binary mask for allowed connections
W_rec_sign_matrix: Sign constraints for Dale’s Law
- get_excluded_hid_regions(*args) list[source]¶
Return all of the hidden regions in the network not given to this function
- get_excluded_inp_regions(*args) list[source]¶
Return all of the input regions in the network not given to this function
- get_region_activity(act: Tensor, *args) Tensor[source]¶
Takes in the hidden activity hn and the specified region(s) and returns the activity for only those regions.
- Args:
act (torch.Tensor): Tensor containing model hidden activity. Activations must be in the last dimension (-1).
args (str): Names of the regions to collect activity from.
- Returns:
region_hn: tensor containing hidden activity only for specified region
- get_region_indices(region: str) tuple[int, int][source]¶
Gets the start and end indices for a specific region in the hidden state vector.
- Args:
region (str): Name of the region
- Returns:
tuple: (start_idx, end_idx)
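As a sketch of how start and end indices could be derived, the snippet below assumes regions occupy contiguous blocks of the hidden state in the order they were added, and that the end index is exclusive. Both are assumptions, and the region names and the function `region_indices` are hypothetical.

```python
def region_indices(region_sizes, region):
    """Return (start_idx, end_idx) for a region in a flat hidden vector.

    region_sizes: dict mapping region name -> number of units, in
    insertion order. The end index is exclusive, as in Python slicing.
    """
    start = 0
    for name, size in region_sizes.items():
        if name == region:
            return start, start + size
        start += size
    raise KeyError(f"unknown region: {region}")

sizes = {"region_a": 3, "region_b": 5, "region_c": 2}  # hypothetical layout
start_idx, end_idx = region_indices(sizes, "region_b")
```

Under these assumptions, the activity of a region is the slice `act[..., start_idx:end_idx]` along the last dimension.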
- get_region_size(region: str) int[source]¶
Get the number of units in a region
- Args:
region (str): region to get size of
- get_weight_subset(*args, W: Tensor | None = None) Tensor[source]¶
Gather a subset of the weights from all regions in args to and from each other and themselves.
This should return a square matrix of all connections between regions in args
- Args:
args (str): All regions specified.
W (torch.Tensor): Use this specified weight matrix instead.
- Returns:
torch.Tensor: subset of the total weight matrix
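Extracting the square block of connections among a set of regions amounts to indexing the weight matrix with the same unit indices on both axes. The NumPy sketch below assumes each region is given as a `(start, end)` index pair with an exclusive end; the function name `weight_subset` is hypothetical and the library's own method takes region names instead.

```python
import numpy as np

def weight_subset(W, region_slices):
    """Return the square sub-matrix of W restricted to the given regions.

    region_slices: list of (start, end) pairs, end exclusive.
    """
    idx = np.concatenate([np.arange(start, end) for start, end in region_slices])
    # np.ix_ builds an open mesh so rows and columns are both restricted.
    return W[np.ix_(idx, idx)]

W_full = np.arange(16.0).reshape(4, 4)
W_sub = weight_subset(W_full, [(0, 1), (3, 4)])  # units 0 and 3 only
```

The result contains connections within each selected region as well as those between them, in the order the regions were listed.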
- property hid_noise_const¶
noise constant used for hidden activity
- property hid_regions¶
Returns names of all hidden regions in network as a list
- property initial_condition: Tensor¶
Create an initial xn for the network
- Returns:
Tensor: tensor like xn filled with region specified initial conds
- property inp_noise_const¶
noise constant used for inputs
- property inp_regions¶
Returns names of all input regions in network as a list
- named_inp_regions(prefix: str = '')[source]¶
Loop through inp region names and objects
- Args:
prefix (str, optional): Defaults to ‘’.
- named_rec_regions(prefix: str = '')[source]¶
Loop through rec region names and objects
- Args:
prefix (str, optional): Defaults to ‘’.
- set_spectral_radius(W: Tensor, W_tmp: Tensor | None = None) Tensor[source]¶
Scale recurrent weights so their spectral radius matches self.spectral_radius.
- Usage:
Define regions and connections (via config or manual methods).
If building manually, call finalize_connectivity() first.
Set self.spectral_radius and call this method.
W_tmp will compute the spectral radius of another network (e.g. the Dale's law network).
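Rescaling to a target spectral radius is typically done by multiplying the weights by the ratio of the target to the current radius. The NumPy sketch below shows that idea. Reading `W_tmp` as "the matrix whose radius is measured, while `W` is the matrix that gets scaled" is an interpretation of the description above, and the function name `rescale_to_radius` is hypothetical.

```python
import numpy as np

def rescale_to_radius(W, target_radius, W_tmp=None):
    """Scale W so the measured spectral radius equals target_radius.

    If W_tmp is given, its radius is measured instead of W's
    (assumed interpretation, e.g. a Dale's-law-constrained copy).
    """
    reference = W if W_tmp is None else W_tmp
    current = np.max(np.abs(np.linalg.eigvals(reference)))
    return W * (target_radius / current)

W = np.diag([2.0, 1.0])
W_scaled = rescale_to_radius(W, 0.5)
W_scaled_ref = rescale_to_radius(W, 0.5, W_tmp=np.diag([4.0, 1.0]))
```

Because eigenvalues scale linearly with the matrix, a single multiplicative factor is enough; no iteration is needed.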
- property tonic_inp¶
Collects baseline firing rates for all regions.
- Returns:
torch.Tensor: Vector of baseline firing rates