import torch
import torch.nn as nn
from mrnntorch.region.region_base import Region, DEFAULT_REC_REGIONS
class RecurrentRegion(Region):
    """Region of recurrent units with a per-unit initial state and baseline firing.

    Extends ``Region`` with two per-unit vectors of length ``num_units``:
    ``init`` (initial pre-activation) and ``base_firing`` (baseline firing).
    ``base_firing`` can optionally be made trainable.
    """

    def __init__(
        self,
        num_units,
        base_firing=DEFAULT_REC_REGIONS["base_firing"],
        init=DEFAULT_REC_REGIONS["init"],
        sign=DEFAULT_REC_REGIONS["sign"],
        parent_region=DEFAULT_REC_REGIONS["parent_region"],
        learnable_bias=DEFAULT_REC_REGIONS["learnable_bias"],
        device=DEFAULT_REC_REGIONS["device"],
    ):
        """Initialize a recurrent region.

        Args:
            num_units (int): Number of units in the region.
            base_firing (float | torch.Tensor): Baseline firing for each unit.
                A scalar is broadcast to every unit. A tensor must be
                broadcastable to ``(num_units,)``.
            init (float | torch.Tensor): Initial pre-activation value for
                units, broadcast to ``(num_units,)`` in the same way.
            sign (str): "pos" or "neg" indicating excitatory/inhibitory outputs.
            parent_region (str | None): Optional parent identifier.
            learnable_bias (bool): If truthy, wrap ``base_firing`` in an
                ``nn.Parameter`` so it is trainable.
            device (str): Torch device string.
        """
        super().__init__(num_units, sign=sign, device=device)
        self.init = self._per_unit(init)
        self.learnable_bias = learnable_bias
        self.parent_region = parent_region
        base_firing = self._per_unit(base_firing)
        # Use a truthiness test rather than ``is True``. This lets numpy/torch
        # booleans (e.g. from a loaded config) also enable the learnable bias.
        if learnable_bias:
            self.base_firing = nn.Parameter(base_firing)
        else:
            # NOTE(review): this is a plain tensor attribute. If ``Region`` is
            # an ``nn.Module``, it will not follow ``.to()`` or appear in
            # ``state_dict``. Consider ``register_buffer``; confirm against
            # Region.
            self.base_firing = base_firing

    def _per_unit(self, value):
        """Broadcast ``value`` to a ``(num_units,)`` tensor on ``self.device``.

        Multiplying by a ones vector (instead of using ``torch.full``) lets
        ``value`` be either a scalar or an already per-unit tensor.
        """
        return torch.as_tensor(value, device=self.device) * torch.ones(
            size=(self.num_units,), device=self.device
        )