# Source code for mrnntorch.mrnn.mrnn_base

"""mRNN core module.

Implements the multi-regional recurrent neural network (mRNN) building blocks
and step-wise dynamics, along with helpers for connectivity, constraints, and
initialization."""

import torch
import torch.nn as nn
import numpy as np
import json
from typing import Tuple
from collections import OrderedDict
from mrnntorch.region.region_base import Region
from mrnntorch.region.recurrent_region import RecurrentRegion
from mrnntorch.region.input_region import InputRegion
from mrnntorch.region.region_base import (
    DEFAULT_REC_REGIONS,
    DEFAULT_REGION_BASE,
    DEFAULT_CONNECTIONS,
)


# Default constructor arguments for mRNNBase; the class signature reads its
# defaults from this mapping so they are defined in one place.
DEFAULTS_MRNN = dict(
    config=None,  # no JSON config: regions/connections are added manually
    activation="relu",
    noise_level_act=0.01,  # std of hidden-state noise
    noise_level_inp=0.01,  # std of input noise
    rec_constrained=True,  # apply Dale's Law to recurrent weights
    inp_constrained=True,  # apply Dale's Law to input weights
    batch_first=True,
    spectral_radius=None,  # None: recurrent weights are not rescaled
    config_finalize=True,  # finalize connectivity right after reading config
    device="cuda",
)


def linear(x):
    """Identity activation: return ``x`` unchanged (used for ``activation="linear"``)."""
    return x
class mRNNBase(nn.Module):
    """Multi-Regional Recurrent Neural Network (mRNN) base class.

    Simulates interactions between multiple recurrent "regions" with optional
    input regions. Supports Dale's Law constraints, tonic inputs, and noise in
    the hidden state and inputs.

    Key features:
        - Multiple regions with independent sizes and signs (excitatory/inhibitory)
        - Dale's Law constraints via sign masks on weights
        - Optional noise on hidden state and input
        - Tonic (baseline) input per region
        - JSON-based configuration or fully manual construction

    ``forward`` and ``batched_initial_condition`` raise NotImplementedError
    here; subclasses provide the dynamics. The noise constants read
    ``self.alpha``, which this class never assigns, so subclasses are
    presumably expected to set it (e.g. ``dt / tau``) — TODO confirm.
    """

    def __init__(
        self,
        config: str = DEFAULTS_MRNN["config"],
        activation: str = DEFAULTS_MRNN["activation"],
        noise_level_act: float = DEFAULTS_MRNN["noise_level_act"],
        noise_level_inp: float = DEFAULTS_MRNN["noise_level_inp"],
        rec_constrained: bool = DEFAULTS_MRNN["rec_constrained"],
        inp_constrained: bool = DEFAULTS_MRNN["inp_constrained"],
        batch_first: bool = DEFAULTS_MRNN["batch_first"],
        spectral_radius: float = DEFAULTS_MRNN["spectral_radius"],
        config_finalize: bool = DEFAULTS_MRNN["config_finalize"],
        device: str = DEFAULTS_MRNN["device"],
    ):
        """Create the network, optionally building it from a JSON config.

        Args:
            config (str | None): Path to a JSON configuration file describing
                recurrent regions, input regions, and their connections. If
                None, build the network manually by calling the add_* methods.
            activation (str): One of {"relu", "tanh", "sigmoid", "softplus",
                "linear"}.
            noise_level_act (float): Std of hidden-state noise term. Default: 0.01.
            noise_level_inp (float): Std of input noise term. Default: 0.01.
            rec_constrained (bool): If True, apply Dale's Law to recurrent
                weights. Default: True.
            inp_constrained (bool): If True, apply Dale's Law to input
                weights. Default: True.
            batch_first (bool): If True, sequences are [B, T, ...]; else
                [T, B, ...]. Stored only; interpreted by subclasses.
            spectral_radius (float | None): If set, recurrent weights are
                rescaled during finalization so their spectral radius equals
                this value.
            config_finalize (bool): If True and a config is supplied, finalize
                connectivity after reading the config. Default: True.
            device (str): Torch device string (e.g. "cpu" or "cuda").
                Default: "cuda".

        Raises:
            Exception: If ``activation`` is not a supported name.
        """
        super().__init__()

        # Region containers. These are plain dicts, not nn.ModuleDict, so
        # nn.Module does not register anything held by the regions.
        # NOTE(review): whether learnable region biases reach
        # model.parameters() depends on the Region classes — confirm.
        self.region_dict = {}
        self.inp_dict = {}
        self.region_mask_dict = {}

        self.rec_constrained = rec_constrained
        self.inp_constrained = inp_constrained
        self.device = device
        self.batch_first = batch_first
        self.sigma_recur = noise_level_act
        self.sigma_input = noise_level_inp
        self.activation_name = activation
        self.spectral_radius = spectral_radius
        self.config_finalize = config_finalize
        self.rec_finalized = False
        self.inp_finalized = False

        # Unit counts start at zero so they exist before any region is added;
        # add_recurrent_region / add_input_region keep them up to date.
        self.total_num_units = 0
        self.total_num_inputs = 0

        # Only common activations are implemented.
        module_activations = {
            "relu": nn.ReLU,
            "tanh": nn.Tanh,
            "sigmoid": nn.Sigmoid,
            "softplus": nn.Softplus,
        }
        if activation == "linear":
            self.activation = linear
        elif activation in module_activations:
            self.activation = module_activations[activation]()
        else:
            raise Exception(
                "Only relu, tanh, sigmoid, softplus, or linear activations "
                "are implemented"
            )

        # The config is optional: without it, subclasses build the network
        # by calling the add_* methods themselves.
        if config is not None:
            self._build_from_config(config)

    def _build_from_config(self, config: str):
        """Read a JSON config file and add the regions/connections it lists.

        Configuration file protocol: none of the regions or connections need
        to be fully specified in the config. Users can specify some regions
        and connections in the config and define the rest manually.
        ``config_finalize`` (default True) assumes that the connections in the
        config are all of the network's connections, so connectivity is
        finalized right after they are added. Set it to False to keep adding
        regions and connections in a custom model afterwards. Missing or empty
        top-level keys default to ``{}`` and add nothing. Connections whose
        regions are not defined raise an error.

        Args:
            config (str): Path to the JSON configuration file.
        """
        with open(config, "r", encoding="utf-8") as f:
            config_file = json.load(f)

        # Nothing has to be specified; absent sections default to empty.
        for key in (
            "recurrent_regions",
            "input_regions",
            "recurrent_connections",
            "input_connections",
        ):
            config_file.setdefault(key, {})

        # Fill per-entry defaults in place before using the entries.
        self._create_def_values(config_file)

        for region in config_file["recurrent_regions"]:
            self.add_recurrent_region(
                name=region["name"],
                num_units=region["num_units"],
                sign=region["sign"],
                base_firing=region["base_firing"],
                init=region["init"],
                parent_region=region["parent_region"],
                learnable_bias=region["learnable_bias"],
            )

        for region in config_file["input_regions"]:
            self.add_input_region(
                name=region["name"],
                num_units=region["num_units"],
                sign=region["sign"],
            )

        if config_file["recurrent_connections"]:
            for connection in config_file["recurrent_connections"]:
                self.add_recurrent_connection(
                    src_region=connection["src_region"],
                    dst_region=connection["dst_region"],
                    sparsity=connection["sparsity"],
                )
            # Finalization pads unspecified region pairs with zero connections.
            if self.config_finalize:
                self.finalize_rec_connectivity()

        if config_file["input_connections"]:
            for connection in config_file["input_connections"]:
                self.add_input_connection(
                    src_region=connection["src_region"],
                    dst_region=connection["dst_region"],
                    sparsity=connection["sparsity"],
                )
            if self.config_finalize:
                self.finalize_inp_connectivity()

    def __setitem__(self, idx: str, region: RecurrentRegion | InputRegion):
        """Replace an existing recurrent or input region by name.

        Only the dict entry is replaced; region masks, unit counts and any
        finalized weight parameters are not recomputed.

        Args:
            idx (str): Name of an existing recurrent or input region.
            region (RecurrentRegion | InputRegion): Replacement; must be the
                same kind of region as the one being replaced.

        Raises:
            ValueError: If ``idx`` is unknown or ``region`` has the wrong type.
        """
        assert isinstance(idx, str), "Only string indexing to regions is allowed"
        if idx in self.region_dict:
            if not isinstance(region, RecurrentRegion):
                raise ValueError(
                    "Not a RecurrentRegion object, cannot assign to recurrent region"
                )
            self.region_dict[idx] = region
        elif idx in self.inp_dict:
            if not isinstance(region, InputRegion):
                raise ValueError(
                    "Not an InputRegion object, cannot assign to input region"
                )
            self.inp_dict[idx] = region
        else:
            raise ValueError("Index not a valid recurrent or input region")

    def __getitem__(self, idx: str) -> RecurrentRegion | InputRegion:
        """Return the recurrent or input region called ``idx``.

        Recurrent regions are looked up before input regions.

        Raises:
            ValueError: If ``idx`` names no region.
        """
        assert isinstance(idx, str), "Only string indexing to regions is allowed"
        if idx in self.region_dict:
            return self.region_dict[idx]
        if idx in self.inp_dict:
            return self.inp_dict[idx]
        raise ValueError("Index not a valid recurrent or input region")

    def add_recurrent_region(
        self,
        name: str,
        num_units: int,
        sign: str = DEFAULT_REC_REGIONS["sign"],
        base_firing: float = DEFAULT_REC_REGIONS["base_firing"],
        init: float = DEFAULT_REC_REGIONS["init"],
        parent_region: str = DEFAULT_REC_REGIONS["parent_region"],
        learnable_bias: bool = DEFAULT_REC_REGIONS["learnable_bias"],
    ):
        """Add a recurrent region to the network.

        Args:
            name (str): Region name (unique key; an existing name is replaced).
            num_units (int): Number of units in this region.
            sign (str): "pos" for excitatory or "neg" for inhibitory outputs.
            base_firing (float | torch.Tensor): Baseline firing per unit.
            init (float): Initial pre-activation value for units in this region.
            parent_region (str | None): Optional parent region identifier.
            learnable_bias (bool): If True, baseline firing is trainable.

        Raises:
            Exception: If recurrent connectivity is already finalized.
        """
        if self.rec_finalized:
            raise Exception(
                "Recurrent connectivity already finalized, please "
                "include all regions and connections beforehand"
            )

        self.region_dict[name] = RecurrentRegion(
            num_units=num_units,
            base_firing=base_firing,
            init=init,
            sign=sign,
            device=self.device,
            parent_region=parent_region,
            learnable_bias=learnable_bias,
        )
        self.total_num_units = self._get_total_num_units(self.region_dict)

        # Each mask spans all recurrent units, so every region's mask must be
        # regenerated to cover the newly added region.
        for region in self.region_dict:
            self.region_mask_dict[region] = self._gen_region_mask(region)

    def add_input_region(
        self,
        name: str,
        num_units: int,
        sign: str = DEFAULT_REGION_BASE["sign"],
    ):
        """Add an input region to the network.

        Args:
            name (str): Input region name (unique key).
            num_units (int): Number of input channels in this region.
            sign (str): "pos" or "neg"; used to set the sign mask for inputs.

        Raises:
            Exception: If input connectivity is already finalized.
        """
        if self.inp_finalized:
            raise Exception(
                "Input connectivity already finalized, "
                "please include all regions and connections beforehand"
            )
        self.inp_dict[name] = InputRegion(
            num_units=num_units, sign=sign, device=self.device
        )
        self.total_num_inputs = self._get_total_num_units(self.inp_dict)

    def add_recurrent_connection(
        self,
        src_region: str,
        dst_region: str,
        sparsity: float | None = DEFAULT_CONNECTIONS["sparsity"],
    ):
        """Create a recurrent connection from one region to another.

        The connection is stored on the source region and its weight is
        initialized in place: with ``_constrained_default_init_rec`` when
        ``rec_constrained`` is set, otherwise Xavier-normal. Zero connections
        cannot be created here; finalization adds those.

        Args:
            src_region (str): Source recurrent region name.
            dst_region (str): Destination recurrent region name.
            sparsity (float | None): Fraction of connections to keep (0-1).
                If None, a dense mask is used.

        Raises:
            Exception: If recurrent connectivity is already finalized.
        """
        if self.rec_finalized:
            raise Exception(
                "Recurrent connectivity already finalized, "
                "please include all regions and connections beforehand"
            )

        self.region_dict[src_region].add_connection(
            dst_region_name=dst_region,
            dst_region_units=self.region_dict[dst_region].num_units,
            sparsity=sparsity,
        )

        weight = self.region_dict[src_region][dst_region].parameter
        if self.rec_constrained:
            self._constrained_default_init_rec(weight)
        else:
            nn.init.xavier_normal_(weight)

    def add_input_connection(
        self,
        src_region: str,
        dst_region: str,
        sparsity: float | None = DEFAULT_CONNECTIONS["sparsity"],
    ):
        """Create a connection from an input region to a recurrent region.

        The weight is initialized in place: with
        ``_constrained_default_init_inp`` when ``inp_constrained`` is set,
        otherwise Xavier-normal. Zero connections cannot be created here.

        Args:
            src_region (str): Source input region name.
            dst_region (str): Destination recurrent region name.
            sparsity (float | None): Fraction of connections to keep (0-1).
                If None, a dense mask is used.

        Raises:
            Exception: If input connectivity is already finalized.
        """
        if self.inp_finalized:
            raise Exception(
                "Input connectivity already finalized, "
                "please include all regions and connections beforehand"
            )

        self.inp_dict[src_region].add_connection(
            dst_region_name=dst_region,
            dst_region_units=self.region_dict[dst_region].num_units,
            sparsity=sparsity,
        )

        weight = self.inp_dict[src_region][dst_region].parameter
        if self.inp_constrained:
            self._constrained_default_init_inp(weight)
        else:
            nn.init.xavier_normal_(weight)

    def set_spectral_radius(
        self, W: torch.Tensor, W_tmp: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Rescale ``W`` so its spectral radius equals ``self.spectral_radius``.

        Called by :meth:`finalize_rec_connectivity` when ``spectral_radius``
        is set. The scale factor is measured on ``W_tmp`` when given (the
        finalizer passes the Dale's-Law-constrained or masked version of
        ``W``), otherwise on ``W`` itself.

        Args:
            W (torch.Tensor): Square weight matrix to rescale.
            W_tmp (torch.Tensor | None): Matrix whose spectral radius sets the
                scale factor.

        Returns:
            torch.Tensor: ``W * (self.spectral_radius / current_radius)``.

        Raises:
            ValueError: If the measured spectral radius is zero, which would
                otherwise produce inf/NaN weights.
        """
        reference = W if W_tmp is None else W_tmp
        cur_spectral_radius = self.compute_spectral_radius(reference)
        if cur_spectral_radius == 0:
            raise ValueError(
                "Cannot rescale a weight matrix whose spectral radius is zero"
            )
        return (W / cur_spectral_radius) * self.spectral_radius

    def finalize_connectivity(self):
        """Finalize recurrent and input connectivity, skipping parts already done.

        Convenience wrapper so users need not call both finalizers separately.
        """
        if not self.rec_finalized:
            self.finalize_rec_connectivity()
        if not self.inp_finalized:
            self.finalize_inp_connectivity()

    def finalize_rec_connectivity(self):
        """Build ``W_rec`` and its masks; no recurrent changes are allowed afterwards.

        Missing region pairs are filled with zero connections, the full matrix
        is assembled with :meth:`gen_w`, optionally rescaled to
        ``self.spectral_radius``, and stored as parameters (masks and sign
        matrix with ``requires_grad=False``).
        """
        for region in self.region_dict:
            self._get_full_connectivity(self.region_dict[region])

        W_rec, W_rec_mask, W_rec_sign_matrix = self.gen_w(self.region_dict)

        if self.spectral_radius is not None:
            # Measure the radius of the effective matrix the network uses.
            if self.rec_constrained:
                W_rec_tmp = self.apply_dales_law(W_rec, W_rec_mask, W_rec_sign_matrix)
            else:
                W_rec_tmp = W_rec * W_rec_mask
            W_rec = self.set_spectral_radius(W_rec, W_tmp=W_rec_tmp)

        self.W_rec = nn.Parameter(W_rec)
        self.W_rec_mask = nn.Parameter(W_rec_mask, requires_grad=False)
        self.W_rec_sign_matrix = nn.Parameter(W_rec_sign_matrix, requires_grad=False)

        self.rec_finalized = True

    def finalize_inp_connectivity(self):
        """Build ``W_inp`` and its masks; no input changes are allowed afterwards.

        Missing input->recurrent pairs are filled with zero connections.
        """
        for inp in self.inp_dict:
            self._get_full_connectivity(self.inp_dict[inp])

        W_inp, W_inp_mask, W_inp_sign_matrix = self.gen_w(self.inp_dict)

        self.W_inp = nn.Parameter(W_inp)
        self.W_inp_mask = nn.Parameter(W_inp_mask, requires_grad=False)
        self.W_inp_sign_matrix = nn.Parameter(W_inp_sign_matrix, requires_grad=False)

        self.inp_finalized = True

    def compute_spectral_radius(self, weight: torch.Tensor) -> torch.Tensor:
        """Compute the spectral radius (max |eigenvalue|) of a square matrix.

        Args:
            weight (torch.Tensor): Square weight matrix.

        Returns:
            torch.Tensor: Spectral radius as a scalar tensor.
        """
        eig_vals = torch.linalg.eigvals(weight)
        return eig_vals.abs().max()

    def gen_w(self, dict_: dict) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Assemble the full weight, mask and sign matrices from per-region connections.

        Each source region in ``dict_`` contributes one column block, made by
        stacking its connections to every recurrent region (in
        ``self.region_dict`` order) along dim 0. Column blocks follow
        ``dict_`` order, so rows index destination units and columns index
        source units.

        Args:
            dict_ (dict): ``self.region_dict`` for recurrent weights or
                ``self.inp_dict`` for input weights. Every source must already
                have a connection to every recurrent region; the finalizers
                ensure this via ``_get_full_connectivity``.

        Returns:
            tuple: (W, W_mask, W_sign_matrix)
                - W: weight matrix
                - W_mask: binary mask of allowed connections
                - W_sign_matrix: sign constraints for Dale's Law
        """
        weight_blocks, mask_blocks, sign_blocks = [], [], []
        dst_names = list(self.region_dict)

        for src_name in dict_:
            src = dict_[src_name]
            # Region[name] returns the connection object towards region `name`.
            connections = [src[dst] for dst in dst_names]
            weight_blocks.append(torch.cat([c.parameter for c in connections], dim=0))
            mask_blocks.append(torch.cat([c.weight_mask for c in connections], dim=0))
            sign_blocks.append(torch.cat([c.sign_matrix for c in connections], dim=0))

        return (
            torch.cat(weight_blocks, dim=1),
            torch.cat(mask_blocks, dim=1),
            torch.cat(sign_blocks, dim=1),
        )

    def apply_dales_law(
        self,
        W_rec: torch.Tensor,
        W_rec_mask: torch.Tensor,
        W_rec_sign_matrix: torch.Tensor,
    ) -> torch.Tensor:
        """Apply Dale's Law: ``mask * |W| * sign``.

        Dale's Law states that a neuron can be either excitatory or inhibitory,
        but not both. Taking ``|W|`` lets the sign matrix alone fix each
        entry's sign.

        Returns:
            torch.Tensor: Constrained weight matrix.
        """
        return W_rec_mask * torch.abs(W_rec) * W_rec_sign_matrix

    @property
    def tonic_inp(self):
        """Baseline firing of all recurrent regions, concatenated in region order.

        Returns:
            torch.Tensor: Vector of baseline firing rates on ``self.device``.
        """
        return torch.cat(
            [region.base_firing for region in self.region_dict.values()]
        ).to(self.device)

    def named_rec_regions(self, prefix: str = ""):
        """Yield ``(prefix + name, region)`` for every recurrent region.

        Args:
            prefix (str, optional): Defaults to ''.
        """
        for name, region in self.region_dict.items():
            yield prefix + name, region

    def named_inp_regions(self, prefix: str = ""):
        """Yield ``(prefix + name, region)`` for every input region.

        Args:
            prefix (str, optional): Defaults to ''.
        """
        for name, region in self.inp_dict.items():
            yield prefix + name, region

    def get_region_size(self, region: str) -> int:
        """Return the number of units in a recurrent or input region.

        Args:
            region (str): Region to get the size of.

        Raises:
            KeyError: If ``region`` names no recurrent or input region.
        """
        if region in self.region_dict:
            return self.region_dict[region].num_units
        return self.inp_dict[region].num_units

    def get_region_activity(self, act: torch.Tensor, *args) -> torch.Tensor:
        """Return the slice of hidden activity belonging to the given regions.

        Parent regions are expanded into their children. The selected regions
        are returned in network (``region_dict``) order, matching the hidden
        state layout that :meth:`combine_states` expects. A region requested
        more than once, directly or via a parent, appears once.

        Args:
            act (torch.Tensor): Hidden activity; units must be in the last
                dimension.
            args (str): Names of recurrent (or parent) regions. If none are
                given, ``act`` is returned unchanged.

        Returns:
            torch.Tensor: Activity restricted to the requested regions.

        Raises:
            Exception: If an input region or an unknown region is requested.
        """
        if not args:
            return act

        for region in args:
            if region in self.inp_dict:
                raise Exception("Can only get activity for recurrent regions")

        regions = self._expand_parent_regions(*args)
        for region in regions:
            if region not in self.region_dict:
                raise Exception("Not an input or recurrent region")

        ordered = self._ensure_order(*regions)
        return torch.cat(
            [act[..., slice(*self.get_region_indices(region))] for region in ordered],
            dim=-1,
        )

    def get_weight_subset(self, *args, W: torch.Tensor | None = None) -> torch.Tensor:
        """Gather the square block of weights among the given recurrent regions.

        Parent regions are expanded into their children and regions are taken
        in network order. Names that are neither a recurrent region nor a
        parent are ignored.

        Args:
            args (str): Recurrent (or parent) region names. If none are given,
                the whole effective weight matrix (or ``W``) is returned.
            W (torch.Tensor | None): Weight matrix to slice instead of the
                network's effective recurrent weights (Dale's Law applied when
                ``rec_constrained``, otherwise masked).

        Returns:
            torch.Tensor: Subset of the total weight matrix (rows: receiving
            units, columns: projecting units).

        Raises:
            Exception: If an input region is requested (use get_projection).
        """
        mrnn_weight = self._effective_rec_weight() if W is None else W
        if not args:
            return mrnn_weight

        for region in args:
            if region in self.inp_dict:
                raise Exception("Can only gather input subsets using get_projection")

        regions = self._ensure_order(*self._expand_parent_regions(*args))
        spans = [self.get_region_indices(region) for region in regions]

        rows = [
            torch.cat(
                [mrnn_weight[dst_start:dst_end, src_start:src_end] for src_start, src_end in spans],
                dim=1,
            )
            for dst_start, dst_end in spans
        ]
        return torch.cat(rows, dim=0)

    def combine_states(
        self,
        states_a: torch.Tensor,
        states_b: torch.Tensor,
        region_list_a: list,
        region_list_b: list,
        keep_dims=True,
    ) -> torch.Tensor:
        """Interleave two partial states into a full state in network order.

        Walks the recurrent regions in ``region_dict`` order and takes the
        next chunk from ``states_a`` if the region is in ``region_list_a``,
        else from ``states_b`` if it is in ``region_list_b``.

        Args:
            states_a (Tensor): States of the regions in region_list_a, in
                network order, units in the last dimension.
            states_b (Tensor): States of the regions in region_list_b, in
                network order, units in the last dimension.
            region_list_a (list): Regions contained in states_a.
            region_list_b (list): Regions contained in states_b.
            keep_dims (bool): If True, restore the leading dimensions of
                ``states_a``; otherwise return a 2-D (flattened, units) tensor.

        Returns:
            full_state (Tensor): Concatenation along the unit dimension.

        Raises:
            Exception: If a recurrent region is in neither list.
        """
        assert states_a.dim() == states_b.dim()
        lead_shape = tuple(states_a.shape)[:-1]

        # Flatten leading dims so slicing is always on a 2-D tensor.
        flat_a = torch.flatten(states_a, end_dim=-2)
        flat_b = torch.flatten(states_b, end_dim=-2)

        offset_a = 0
        offset_b = 0
        pieces = []
        for region in self.region_dict:
            size = self.get_region_size(region)
            if region in region_list_a:
                pieces.append(flat_a[:, offset_a : offset_a + size])
                offset_a += size
            elif region in region_list_b:
                pieces.append(flat_b[:, offset_b : offset_b + size])
                offset_b += size
            else:
                raise Exception(f"region {region} not in either list")

        full_state = torch.cat(pieces, dim=-1)
        if keep_dims:
            full_state = torch.reshape(full_state, (*lead_shape, full_state.shape[-1]))
        return full_state

    def get_projection(self, to: str, from_: str) -> torch.Tensor:
        """Gather the effective weights of the ``from_`` -> ``to`` projection.

        Parent regions are expanded into their children. Effective weights
        have Dale's Law applied when the relevant ``*_constrained`` flag is
        set; otherwise the connectivity mask is applied.

        Args:
            to (str): Recurrent (or parent) region receiving the projection (rows).
            from_ (str): Recurrent, parent, or input region projecting (columns).

        Returns:
            torch.Tensor: Weight matrix of the from_->to projection.

        Raises:
            Exception: If ``from_`` is unknown or ``to`` is not recurrent.
        """
        to_regions = self._expand_parent_regions(to)
        from_regions = self._expand_parent_regions(from_)

        # Parent regions group recurrent regions, so they use W_rec.
        if from_ in self.region_dict or self._check_if_parent_region(from_):
            weight = self._effective_rec_weight()
        elif from_ in self.inp_dict:
            weight = self._effective_inp_weight()
        else:
            raise Exception("from_ region not in region or input dictionary")

        # Weight rows always index recurrent units.
        if any(region not in self.region_dict for region in to_regions):
            raise Exception("to region must be a recurrent region")

        to_spans = [self.get_region_indices(region) for region in to_regions]
        from_spans = [self.get_region_indices(region) for region in from_regions]

        rows = [
            torch.cat(
                [weight[to_start:to_end, from_start:from_end] for from_start, from_end in from_spans],
                dim=1,
            )
            for to_start, to_end in to_spans
        ]
        return torch.cat(rows, dim=0)

    def get_region_indices(self, region: str) -> tuple[int, int]:
        """Return ``(start_idx, end_idx)`` of a region's units.

        Indices are into the recurrent unit vector for recurrent regions and
        into the input vector for input regions.

        Args:
            region (str): Name of a single (non-parent) region.

        Returns:
            tuple: (start_idx, end_idx), end exclusive.

        Raises:
            ValueError: If ``region`` is a parent region.
            Exception: If ``region`` is neither recurrent nor input.
        """
        if self._check_if_parent_region(region):
            raise ValueError(
                "Can only get indices of a single region, not parent region"
            )

        if region in self.region_dict:
            dict_ = self.region_dict
        elif region in self.inp_dict:
            dict_ = self.inp_dict
        else:
            raise Exception("Not an input or recurrent region")

        start_idx = 0
        for name, cur_region in dict_.items():
            if name == region:
                return start_idx, start_idx + cur_region.num_units
            start_idx += cur_region.num_units
        # Unreachable: `region` was found in dict_ above.
        return start_idx, start_idx

    @property
    def initial_condition(self) -> torch.Tensor:
        """Initial pre-activation values of all recurrent regions, concatenated in region order.

        Returns:
            Tensor: Vector on ``self.device`` built from each region's ``init``.
        """
        return torch.cat([region.init for region in self.region_dict.values()]).to(
            self.device
        )

    def batched_initial_condition(
        self, *args, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Not implemented in the base class; subclasses must override."""
        raise NotImplementedError

    @property
    def hid_regions(self):
        """Names of all hidden (recurrent) regions, as a list."""
        return list(self.region_dict)

    @property
    def inp_regions(self):
        """Names of all input regions, as a list."""
        return list(self.inp_dict)

    def get_excluded_hid_regions(self, *args) -> list:
        """Return the recurrent regions not named in ``args``, in network order."""
        return [region for region in self.region_dict if region not in args]

    def get_excluded_inp_regions(self, *args) -> list:
        """Return the input regions not named in ``args``, in network order."""
        return [region for region in self.inp_dict if region not in args]

    def forward(self, *args, **kwargs):
        """Not implemented in the base class; subclasses define the dynamics."""
        raise NotImplementedError

    @property
    def hid_noise_const(self):
        """Hidden-noise scale ``(1 / alpha) * sqrt(2 * alpha * sigma_recur**2)``.

        Reads ``self.alpha``, which is not assigned in this class.
        """
        return (1 / self.alpha) * np.sqrt(2 * self.alpha * self.sigma_recur**2)

    @property
    def inp_noise_const(self):
        """Input-noise scale ``(1 / alpha) * sqrt(2 * alpha * sigma_input**2)``.

        Reads ``self.alpha``, which is not assigned in this class.
        """
        return (1 / self.alpha) * np.sqrt(2 * self.alpha * self.sigma_input**2)

    def _hid_noise(self, batch_shape: int):
        """Sample Gaussian hidden-state noise for one timestep.

        Args:
            batch_shape (int): Batch size.

        Returns:
            Tensor: (batch_shape, total_num_units) noise scaled by hid_noise_const.
        """
        return self.hid_noise_const * torch.randn(
            size=(batch_shape, self.total_num_units), device=self.device
        )

    def _inp_noise(self, batch_shape: int):
        """Sample Gaussian input noise for one timestep.

        Args:
            batch_shape (int): Batch size.

        Returns:
            Tensor: (batch_shape, total_num_inputs) noise scaled by inp_noise_const.
        """
        return self.inp_noise_const * torch.randn(
            size=(batch_shape, self.total_num_inputs), device=self.device
        )

    def _create_def_values(self, config: dict):
        """Fill missing config fields with library defaults, modifying ``config`` in place.

        A missing region ``name`` defaults to the default name plus the
        region's index, so generated names are unique.

        Args:
            config (dict): Parsed network configuration.
        """

        def fill(entries, defaults, index_names):
            # Add every default parameter the entry does not specify.
            for i, entry in enumerate(entries):
                for param in defaults:
                    if param in entry:
                        continue
                    if index_names and param == "name":
                        entry[param] = defaults[param] + str(i)
                    else:
                        entry[param] = defaults[param]

        fill(config["recurrent_regions"], DEFAULT_REC_REGIONS, index_names=True)
        fill(config["recurrent_connections"], DEFAULT_CONNECTIONS, index_names=False)
        fill(config["input_regions"], DEFAULT_REGION_BASE, index_names=True)
        fill(config["input_connections"], DEFAULT_CONNECTIONS, index_names=False)

    def _gen_region_mask(self, region: str) -> torch.Tensor:
        """Build a mask over all recurrent units that selects one region.

        Args:
            region (str): Region name.

        Returns:
            torch.Tensor: The region's "ones" mask in its own slot and every
            other region's "zeros" mask elsewhere, on ``self.device``.
        """
        mask = [
            self.region_dict[other].masks["ones" if other == region else "zeros"]
            for other in self.region_dict
        ]
        return torch.cat(mask).to(self.device)

    def _get_full_connectivity(self, region: Region):
        """Add zero connections from ``region`` to every recurrent region it lacks.

        Args:
            region (Region): Region object to complete connections for.
        """
        for other_region in self.region_dict:
            if not region.has_connection_to(other_region):
                region.add_connection(
                    dst_region_name=other_region,
                    dst_region_units=self.region_dict[other_region].num_units,
                    sparsity=None,
                    zero_connection=True,
                )

    def _get_total_num_units(self, dict_: dict) -> int:
        """Return the total number of units across all regions in ``dict_``.

        Args:
            dict_ (dict): Either region_dict or inp_dict.
        """
        return sum(region.num_units for region in dict_.values())

    def _check_if_parent_region(self, parent_region: str) -> bool:
        """Return True if any recurrent region lists ``parent_region`` as its parent."""
        return any(
            region.parent_region == parent_region
            for region in self.region_dict.values()
        )

    def _get_child_regions(self, parent_region: str) -> tuple[str, ...]:
        """Return, in network order, the recurrent regions whose parent is ``parent_region``."""
        return tuple(
            name
            for name, region in self.region_dict.items()
            if region.parent_region == parent_region
        )

    def _expand_parent_regions(self, *regions) -> list[str]:
        """Replace each parent name with its children, dropping repeated inputs.

        Order follows ``regions``; children keep network order.
        """
        expanded = []
        for region in OrderedDict.fromkeys(regions):
            if self._check_if_parent_region(region):
                expanded.extend(self._get_child_regions(region))
            else:
                expanded.append(region)
        return expanded

    def _effective_rec_weight(self) -> torch.Tensor:
        """Return ``W_rec`` with Dale's Law applied if ``rec_constrained``, else masked."""
        if self.rec_constrained:
            return self.apply_dales_law(
                self.W_rec, self.W_rec_mask, self.W_rec_sign_matrix
            )
        return self.W_rec * self.W_rec_mask

    def _effective_inp_weight(self) -> torch.Tensor:
        """Return ``W_inp`` with Dale's Law applied if ``inp_constrained``, else masked."""
        if self.inp_constrained:
            return self.apply_dales_law(
                self.W_inp, self.W_inp_mask, self.W_inp_sign_matrix
            )
        return self.W_inp * self.W_inp_mask

    def _constrained_default_init_rec(self, weight: torch.Tensor):
        """Default in-place init for recurrent weights under Dale's Law.

        Samples from U(0, sqrt(1 / (2H))), where H is ``total_num_units`` at
        call time (recurrent units added so far). Values are non-negative;
        signs come later from the sign matrix in :meth:`apply_dales_law`.
        """
        nn.init.uniform_(weight, a=0, b=np.sqrt(1 / (2 * self.total_num_units)))

    def _constrained_default_init_inp(self, weight: torch.Tensor):
        """Default in-place init for input weights under Dale's Law.

        Samples from U(0, sqrt(1 / (H + I))), with H = ``total_num_units`` and
        I = ``total_num_inputs`` at call time. Signs come later from the sign
        matrix in :meth:`apply_dales_law`.
        """
        nn.init.uniform_(
            weight,
            a=0,
            b=np.sqrt(1 / (self.total_num_units + self.total_num_inputs)),
        )

    def _ensure_order(self, *args) -> tuple[str, ...]:
        """Return the recurrent regions named in ``args`` in network order.

        Duplicates and names that are not recurrent regions are dropped.
        """
        return tuple(r for r in self.region_dict if r in args)