input_region

class mrnntorch.region.input_region.InputRegion(num_units, sign='pos', device='cuda')

Bases: Region

recurrent_region

class mrnntorch.region.recurrent_region.RecurrentRegion(num_units, base_firing=0, init=0, sign='pos', parent_region=None, learnable_bias=False, device='cuda')

Bases: Region

region_base

Region definitions for mRNN.

Provides the base Region class and the Connection container. The concrete RecurrentRegion and InputRegion classes derive from Region, and every region owns the parameters and masks of its outgoing connections.

class mrnntorch.region.region_base.Connection(parameter: Tensor | None = None, weight_mask: Tensor | None = None, sign_matrix: Tensor | None = None, zero_connection: Tensor | None = None)

Bases: object

Container for a single inter-region connection and its masks.

parameter: Tensor | None = None
sign_matrix: Tensor | None = None
weight_mask: Tensor | None = None
zero_connection: Tensor | None = None
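Because every field defaults to None, a Connection can be populated piecemeal. The sketch below mirrors the four documented fields in a stand-in dataclass (ConnectionSketch is a hypothetical name, not part of mrnntorch) and shows one plausible way to represent a trainable edge versus a fixed zero edge. The (dst_units, src_units) tensor layout is an assumption, as is how mrnntorch itself fills these fields.

```python
from dataclasses import dataclass
from typing import Optional

import torch
from torch import Tensor


@dataclass
class ConnectionSketch:
    """Hypothetical stand-in mirroring the documented Connection fields."""
    parameter: Optional[Tensor] = None
    weight_mask: Optional[Tensor] = None
    sign_matrix: Optional[Tensor] = None
    zero_connection: Optional[Tensor] = None


# Trainable edge from a 4-unit source to a 3-unit destination
# (the (dst, src) orientation is an assumption).
trainable = ConnectionSketch(
    parameter=torch.nn.Parameter(torch.rand(3, 4)),
    weight_mask=torch.ones(3, 4),
    sign_matrix=torch.ones(3, 4),
)

# Fixed zero edge: no trainable parameter, only a constant zero tensor.
fixed_zero = ConnectionSketch(zero_connection=torch.zeros(3, 4))
```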
class mrnntorch.region.region_base.Region(num_units: int, sign: str = 'pos', device: str = 'cuda')

Bases: Module

Base class for regions used by mRNN.

Models outgoing connections to other regions along with simple region properties. Each region maintains its own connection parameters and masks (including sign masks enforcing Dale’s Law when used by mRNN).
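One common way to enforce Dale's Law with a sign mask is to keep the learned magnitudes non-negative and multiply them by a fixed ±1 matrix determined by the source region's sign. The sketch below assumes that scheme (ReLU on the raw parameter, then the binary weight mask, then the sign matrix); the function name effective_weight and the (dst_units, src_units) layout are hypothetical, and mrnntorch may combine its masks differently.

```python
import torch


def effective_weight(parameter: torch.Tensor,
                     weight_mask: torch.Tensor,
                     sign_matrix: torch.Tensor) -> torch.Tensor:
    """Hypothetical: combine a raw parameter with its masks.

    ReLU keeps magnitudes non-negative, the binary weight_mask zeroes
    pruned entries, and the +/-1 sign_matrix fixes each source unit's sign.
    """
    return torch.relu(parameter) * weight_mask * sign_matrix


src_units, dst_units = 4, 3
raw = torch.randn(dst_units, src_units, requires_grad=True)
mask = torch.ones(dst_units, src_units)

# Inhibitory source region (sign='neg'): every outgoing weight is <= 0.
w_neg = effective_weight(raw, mask, -torch.ones(dst_units, src_units))

# Excitatory source region (sign='pos'): every outgoing weight is >= 0.
w_pos = effective_weight(raw, mask, torch.ones(dst_units, src_units))
```

Because the sign is applied after the ReLU, gradient updates can change a weight's magnitude but never flip its sign, which is the property Dale's Law requires.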

Attributes:

    num_units (int): Number of units in the region.
    sign (str): “pos” for excitatory or “neg” for inhibitory outputs.
    device (str): Torch device string for tensors.
    connections (dict): Mapping of destination region name -> dict with
        keys {“parameter”, “weight_mask”, “sign_matrix”, “zero_connection”}.
    masks (dict): Convenience masks, including “ones” and “zeros” of length num_units.

add_connection(dst_region_name: str, dst_region_units: int, sparsity: float | None = None, zero_connection: bool = False)

Add a connection from this region to the region named dst_region_name.

Creates a trainable weight parameter and associated non-trainable masks. If sparsity is provided, a binary mask is sampled to achieve the requested sparsity.

Args:

    dst_region_name (str): Name of the destination region.
    dst_region_units (int): Number of units in the destination region.
    sparsity (float | None): Fraction of nonzero connections (0-1).
    zero_connection (bool): If True, registers a fixed zero connection
        (no trainable parameters are created for this edge).
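The sketch below illustrates what add_connection might store for a single edge, using the keys listed under connections. It is not mrnntorch's implementation: build_connection is a hypothetical helper, the (dst_units, src_units) layout is assumed, and the mask keeps exactly round(sparsity * n) entries, reading sparsity as the fraction of nonzero connections as the Args description states (the library may instead sample each entry independently).

```python
from typing import Optional

import torch


def build_connection(src_units: int, dst_units: int, sign: str = "pos",
                     sparsity: Optional[float] = None,
                     zero_connection: bool = False,
                     device: str = "cpu") -> dict:
    """Hypothetical sketch of the per-edge entries add_connection could create."""
    shape = (dst_units, src_units)  # assumed (dst, src) layout
    if zero_connection:
        # Fixed zero edge: no trainable parameter is created.
        return {"parameter": None, "weight_mask": None, "sign_matrix": None,
                "zero_connection": torch.zeros(shape, device=device)}

    # Trainable weights; masks below are plain (non-trainable) tensors.
    parameter = torch.nn.Parameter(torch.rand(shape, device=device))

    weight_mask = torch.ones(shape, device=device)
    if sparsity is not None:
        # Keep exactly round(sparsity * n) randomly chosen entries.
        n = dst_units * src_units
        keep = round(sparsity * n)
        flat = torch.zeros(n, device=device)
        flat[torch.randperm(n, device=device)[:keep]] = 1.0
        weight_mask = flat.view(shape)

    # +1 for excitatory sources, -1 for inhibitory ones.
    value = 1.0 if sign == "pos" else -1.0
    sign_matrix = torch.full(shape, value, device=device)
    return {"parameter": parameter, "weight_mask": weight_mask,
            "sign_matrix": sign_matrix, "zero_connection": None}
```

For example, build_connection(4, 5, sign="neg", sparsity=0.25) yields a 5×4 mask with exactly 5 ones and a sign matrix filled with -1.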

has_connection_to(region: str) → bool

Return whether this region already defines a connection to region.
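Since connections is keyed by destination region name, has_connection_to plausibly reduces to a membership test on that mapping. The sketch below assumes so and uses the check to wire an edge only once. ToyRegion and connect_once are hypothetical names; only the method names and parameter names mirror the documented Region API, and the sketch does not depend on mrnntorch.

```python
class ToyRegion:
    """Hypothetical stand-in; only the method names mirror the Region API."""

    def __init__(self, num_units: int):
        self.num_units = num_units
        self.connections: dict[str, dict] = {}

    def has_connection_to(self, region: str) -> bool:
        # A connection exists iff the destination name is a key.
        return region in self.connections

    def add_connection(self, dst_region_name: str, dst_region_units: int) -> None:
        self.connections[dst_region_name] = {"dst_units": dst_region_units}


def connect_once(src: ToyRegion, dst_name: str, dst_units: int) -> bool:
    """Add src -> dst only if it is missing; return True if an edge was added."""
    if src.has_connection_to(dst_name):
        return False
    src.add_connection(dst_name, dst_units)
    return True
```

Guarding with has_connection_to keeps repeated wiring calls idempotent instead of silently replacing an existing edge's parameters.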