input_region
recurrent_region
region_base
Region definitions for mRNN. Provides a base Region class plus the concrete RecurrentRegion and InputRegion containers, which own connection parameters and masks.
- class mrnntorch.region.region_base.Connection(parameter: Tensor | None = None, weight_mask: Tensor | None = None, sign_matrix: Tensor | None = None, zero_connection: Tensor | None = None)
Bases: object
Container for a single inter-region connection and its masks.
- parameter: Tensor | None = None
- sign_matrix: Tensor | None = None
- weight_mask: Tensor | None = None
- zero_connection: Tensor | None = None
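The four fields can be pictured with a plain dataclass. This is a minimal sketch, not the library's class: `ConnectionSketch` and `effective_weight` are hypothetical names, nested lists stand in for torch tensors, and the rule that the effective weight is the elementwise product |parameter| × weight_mask × sign_matrix is an assumption about how the masks are combined (a common way to enforce Dale's Law), not something this page documents.

```python
from dataclasses import dataclass
from typing import List, Optional

Matrix = List[List[float]]  # stand-in for a 2-D torch.Tensor


@dataclass
class ConnectionSketch:
    """Hypothetical mirror of Connection; every field is optional."""
    parameter: Optional[Matrix] = None        # trainable weights
    weight_mask: Optional[Matrix] = None      # 1 where an edge exists, 0 otherwise
    sign_matrix: Optional[Matrix] = None      # +1 / -1 per entry (Dale's Law)
    zero_connection: Optional[Matrix] = None  # fixed all-zero block


def effective_weight(conn: ConnectionSketch) -> Optional[Matrix]:
    """ASSUMED combination rule: |parameter| * weight_mask * sign_matrix, elementwise."""
    if conn.parameter is None:
        # No trainable parameter: fall back to the fixed zero block (may be None).
        return conn.zero_connection
    return [
        [abs(p) * m * s for p, m, s in zip(p_row, m_row, s_row)]
        for p_row, m_row, s_row in zip(
            conn.parameter, conn.weight_mask, conn.sign_matrix
        )
    ]


# An inhibitory edge with one masked-out entry.
inhib = ConnectionSketch(
    parameter=[[0.5, -0.2], [0.3, 0.1]],
    weight_mask=[[1, 1], [0, 1]],
    sign_matrix=[[-1, -1], [-1, -1]],
)
w = effective_weight(inhib)  # every surviving entry is negative; the masked one is zero
```

Taking the absolute value before applying the sign means a raw parameter can drift to either sign during training while the effective weight keeps the sign dictated by the source region.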
- class mrnntorch.region.region_base.Region(num_units: int, sign: str = 'pos', device: str = 'cuda')
Bases: Module
Base class for regions used by mRNN.
Models a region's outgoing connections to other regions along with basic region properties (size, sign, and device). Each region maintains its own connection parameters and masks, including the sign masks that enforce Dale’s Law when the region is used by mRNN.
- Attributes:
  num_units (int): Number of units in the region.
  sign (str): “pos” for excitatory or “neg” for inhibitory outputs.
  device (str): Torch device string for tensors.
  connections (dict): Mapping of destination region name -> dict with keys {“parameter”, “weight_mask”, “sign_matrix”, “zero_connection”}.
  masks (dict): Convenience masks including “ones” and “zeros” of length num_units.
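The documented attributes can be mirrored in plain Python to show their shapes. This is a hypothetical sketch, not the real class: `RegionSketch` and `sign_factor` are invented names, the masks are lists rather than tensors, `device` is kept only as a label, and mapping "pos"/"neg" to a +1/-1 factor on outgoing weights is an assumption about how the sign is applied.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class RegionSketch:
    """Hypothetical plain-Python mirror of Region's documented attributes."""
    num_units: int
    sign: str = "pos"     # "pos" = excitatory outputs, "neg" = inhibitory
    device: str = "cuda"  # label only; no tensors are allocated here
    # destination region name -> {"parameter", "weight_mask", "sign_matrix", "zero_connection"}
    connections: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    masks: Dict[str, List[float]] = field(init=False, default_factory=dict)

    def __post_init__(self) -> None:
        if self.sign not in ("pos", "neg"):
            raise ValueError(f"sign must be 'pos' or 'neg', got {self.sign!r}")
        # Convenience masks of length num_units.
        self.masks = {"ones": [1.0] * self.num_units, "zeros": [0.0] * self.num_units}

    @property
    def sign_factor(self) -> int:
        """ASSUMED +1 / -1 factor applied to every outgoing weight (Dale's Law)."""
        return 1 if self.sign == "pos" else -1


inh = RegionSketch(num_units=3, sign="neg", device="cpu")
print(inh.masks["ones"], inh.sign_factor)  # [1.0, 1.0, 1.0] -1
```

Validating `sign` at construction time keeps a typo such as "negative" from silently producing an excitatory region.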
- add_connection(dst_region_name: str, dst_region_units: int, sparsity: float | None = None, zero_connection: bool = False)
Add a connection from this region to the region named dst_region_name.
Creates a trainable weight parameter and associated non-trainable masks. If sparsity is provided, a binary mask is sampled to achieve the requested sparsity.
- Args:
  dst_region_name (str): Name of the destination region.
  dst_region_units (int): Number of units in the destination region.
  sparsity (float | None): Fraction of nonzero connections (0-1).
  zero_connection (bool): If True, registers a fixed zero connection (no trainable parameters are created for this edge).
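The steps add_connection describes (trainable weights, a sampled sparsity mask, a sign mask, or a fixed zero block) can be sketched with the standard library. Everything here is an assumption for illustration: `add_connection_sketch` and `sample_mask` are hypothetical helpers, the weight shape (dst_region_units, src_units), the uniform initialisation, the exact-count mask sampling, and setting unused dict entries to None are not taken from the library.

```python
import random
from typing import Any, Dict, List, Optional

Matrix = List[List[float]]


def sample_mask(rows: int, cols: int, density: float, rng: random.Random) -> Matrix:
    """Binary mask with exactly round(density * rows * cols) ones at random positions."""
    if not 0.0 <= density <= 1.0:
        raise ValueError("sparsity must lie in [0, 1]")
    n = rows * cols
    on = set(rng.sample(range(n), round(density * n)))
    return [[1.0 if r * cols + c in on else 0.0 for c in range(cols)] for r in range(rows)]


def add_connection_sketch(
    connections: Dict[str, Dict[str, Any]],
    src_units: int,
    src_sign: str,
    dst_region_name: str,
    dst_region_units: int,
    sparsity: Optional[float] = None,
    zero_connection: bool = False,
    seed: int = 0,
) -> None:
    """Hypothetical stand-in for Region.add_connection; weights are (dst, src)."""
    rng = random.Random(seed)
    if zero_connection:
        # Fixed zero edge: no trainable parameter is created.
        connections[dst_region_name] = {
            "parameter": None,
            "weight_mask": None,
            "sign_matrix": None,
            "zero_connection": [[0.0] * src_units for _ in range(dst_region_units)],
        }
        return
    if sparsity is None:
        mask = [[1.0] * src_units for _ in range(dst_region_units)]  # dense edge
    else:
        mask = sample_mask(dst_region_units, src_units, sparsity, rng)
    sign = 1.0 if src_sign == "pos" else -1.0  # sign follows the source region
    connections[dst_region_name] = {
        # Placeholder initialisation; the library's actual scheme is not documented here.
        "parameter": [[rng.uniform(0.0, 0.1) for _ in range(src_units)]
                      for _ in range(dst_region_units)],
        "weight_mask": mask,
        "sign_matrix": [[sign] * src_units for _ in range(dst_region_units)],
        "zero_connection": None,
    }


conns: Dict[str, Dict[str, Any]] = {}
add_connection_sketch(conns, src_units=4, src_sign="neg",
                      dst_region_name="dst", dst_region_units=5, sparsity=0.5)
print(sum(map(sum, conns["dst"]["weight_mask"])))  # 10.0
```

Sampling an exact number of ones, rather than drawing each entry independently, guarantees the requested fraction of nonzero connections even for small regions; a per-entry Bernoulli draw would only match it on average.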