Source code for mrnntorch.analysis.flow_fields.elman_flow_field_finder

import torch
from mrnntorch.analysis.linear.elman_linear import emLinearization
from rnntoolkit.flow_fields.flow_field import FlowField
from rnntoolkit.flow_fields.flow_field_finder_base import FlowFieldFinderBase
from mrnntorch.mrnn.elman_mrnn import ElmanmRNN


[docs] class emFlowFieldFinder(FlowFieldFinderBase[ElmanmRNN]): """Flow-field estimator for Elman mRNN trajectories and local linearizations.""" _default_hps = { "num_components": 2, "num_points": 50, "x_offset": 1, "y_offset": 1, "center": 0, "cancel_other_regions": False, "follow_traj": False, "name": "run", "dtype": torch.float32, } def __init__( self, rnn: ElmanmRNN, fit_states: torch.Tensor, num_points: int, x_offset: int, y_offset: int, x_center: int = 0, y_center: int = 0, follow_traj: bool = False, region_list: list = [], cancel_other_regions: bool = False, ): """Initialize a 2D flow-field finder around a trajectory. Args: rnn (ElmanmRNN): Network to analyze. fit_states (torch.Tensor): States used to fit the dimensionality reduction used for the flow-field plane. num_points (int): Number of grid points along each axis. x_offset (int): Horizontal half-width of the sampled grid. y_offset (int): Vertical half-width of the sampled grid. x_center (int): Fixed x-axis center when not following the trajectory. y_center (int): Fixed y-axis center when not following the trajectory. follow_traj (bool): If ``True``, center the grid on each sampled state. region_list (list): Recurrent regions to include in the reduced plane. cancel_other_regions (bool): If ``True``, zero activity in excluded regions. """ super().__init__( rnn, fit_states, num_points, x_offset, y_offset, x_center, y_center, ) # Unload mrnn specific kwargs self.cancel_other_regions = cancel_other_regions self.follow_traj = follow_traj self.zero_states = torch.zeros( size=( 1, rnn.total_num_units, ) ) # Regions which are treated as grid elements self.region_list = self.rnn.hid_regions if not region_list else region_list # Regions treated as static inputs for grid elements self.static_region_list = ( [] if self.region_list == self.rnn.hid_regions else self.rnn.get_excluded_hid_regions(*self.region_list) ) self.linearization = emLinearization(rnn, *self.region_list)
[docs] def find_nonlinear_flow( self, states: torch.Tensor, input: torch.Tensor, stim_input: torch.Tensor | None = None, W: torch.Tensor | None = None, ) -> list: """Compute nonlinear 2D flow fields in a region subspace along a trajectory. Projects selected region activity onto a 2D PCA subspace, constructs a grid around the current point, and advances the system by one step to estimate the local flow (velocity vectors). Can zero out non-selected regions or keep their control values. Args: states (torch.Tensor): Hidden activations over time [batch_size, T, N]. inp (torch.Tensor): External input sequence. stim_input (torch.Tensor | None): Optional additive stimulus input. W (torch.Tensor | None): Optional weight matrix to use. Kwargs: stim_input (torch.Tensor): tensor input to network without weights, acts as manipulation W (torch.Tensor): replace the weight matrix of mRNN with W traj_to_reduce (torch.Tensor): tensor similar to states that will be used for PCA instead of states Returns: list: FlowField object per sampled time. """ flow_field_list = [] if stim_input is None: stim_input = torch.zeros_like(states, dtype=self.dtype) # Reshape to nxd states, input, stim_input = ( self._nxd(states), self._nxd(input), self._nxd(stim_input), ) assert states.shape[0] == input.shape[0] n_states = states.shape[0] """ states is now meant to be network activation, or h. To get valid xs, we will invert h in the compute function """ # get region activity for fitting and reduction tmp_states = self.rnn.get_region_activity(states, *self.region_list) reduced_traj = self._reduce_traj(tmp_states) if not self.static_region_list: # Default to dummy tensor with shape static_states = None else: static_states = self.rnn.get_region_activity( states, *self.static_region_list ) if self.cancel_other_regions: static_states = static_states * torch.zeros_like(static_states) # Now going through trajectory for n in range(n_states): # default for static states reduced_traj_n = reduced_traj[n] input_n = input[n] static_states_n = static_states[n] if static_states is not None else None stim_input_n = stim_input[n] # If follow trajectory is true get grid centered around current t # This will make a different grid for each state (n grids) if self.follow_traj: lower_bound_x, upper_bound_x, lower_bound_y, upper_bound_y = ( self._set_tv_bounds(reduced_traj_n) ) else: lower_bound_x, upper_bound_x, lower_bound_y, upper_bound_y = ( self._set_bounds() ) low_dim_grid, inverse_grid = self._inverse_grid( lower_bound_x, upper_bound_x, lower_bound_y, upper_bound_y, ) # Repeat along the batch dimension to match the grid if static_states_n is not None: static_act_batch = static_states_n.repeat(low_dim_grid.shape[0], 1) else: static_act_batch = None full_input_batch = input_n.repeat(low_dim_grid.shape[0], 1) full_stim_batch = stim_input_n.repeat(low_dim_grid.shape[0], 1) # Combine the grid and static states to treat excluded regions as input if static_act_batch is not None: h_0_flow = self.rnn.combine_states( inverse_grid, static_act_batch, self.region_list, self.static_region_list, ) else: h_0_flow = inverse_grid with torch.no_grad(): # Get activity for current timestep h_next = self.rnn( full_input_batch.unsqueeze(self.time_dim), h_0_flow, stim_input=full_stim_batch.unsqueeze(self.time_dim), noise=False, W_rec=W, ) next_state = self.rnn.get_region_activity(h_next, *self.region_list) next_state_reduced = self._reduce_traj(next_state) x_vel, y_vel = self._compute_velocity(next_state_reduced, low_dim_grid) speed = self._compute_speed(x_vel, y_vel) # Reshape to match FlowField object requirements x_vel, y_vel, low_dim_grid, speed = self._reshape_vals( x_vel, y_vel, low_dim_grid, speed ) flow_field = FlowField(x_vel, y_vel, low_dim_grid, speed) flow_field_list.append(flow_field) return flow_field_list
[docs] def find_linear_flow( self, states: torch.Tensor, input: torch.Tensor, delta_input: torch.Tensor, delta_h_static: torch.Tensor | None = None, ) -> list: """Compute linearized 2D flow fields around sampled trajectory states. Similar to :func:`flow_field`, but uses a local linear approximation (Jacobian) of the dynamics around points on the trajectory instead of a full forward step. Assumes no external input to the selected regions. Args: states (torch.Tensor): Hidden activations over time. inp (torch.Tensor): External input sequence aligned with ``states``. delta_inp (torch.Tensor): Input perturbations for the local linear model. delta_h_static (torch.Tensor | None): Perturbations for recurrent regions excluded from the reduced plane. Returns: list: FlowField objects per sampled time. """ # reshape to nxd states, input, delta_input = ( self._nxd(states), self._nxd(input), self._nxd(delta_input), ) assert input.shape[0] == delta_input.shape[0] assert states.shape[0] == input.shape[0] n_states = states.shape[0] # Lists for x and y velocities flow_field_list = [] # Activity specific to regions in region list for later computations region_tmp = self.rnn.get_region_activity(states, *self.region_list) reduced_traj = self._reduce_traj(region_tmp) # zero out static perturbations if regions are cancelled if self.cancel_other_regions and delta_h_static is not None: delta_h_static = delta_h_static * torch.zeros_like(delta_h_static) for n in range(n_states): states_n = states[n] reduced_traj_n = reduced_traj[n] input_n = input[n] delta_input_n = delta_input[n] delta_h_static_n = delta_h_static[n] if delta_h_static is not None else None # If follow trajectory is true get grid centered around current t # This will make a different grid for each state (n grids) if self.follow_traj: lower_bound_x, upper_bound_x, lower_bound_y, upper_bound_y = ( self._set_tv_bounds(reduced_traj_n) ) else: lower_bound_x, upper_bound_x, lower_bound_y, upper_bound_y = ( self._set_bounds() ) # Inverse the grid to pass through RNN low_dim_grid, inverse_grid = self._inverse_grid( lower_bound_x, upper_bound_x, lower_bound_y, upper_bound_y, ) # Get a perturbation of the activity region_states_n = self.rnn.get_region_activity(states_n, *self.region_list) delta_h = inverse_grid - region_states_n with torch.no_grad(): h_next = self.linearization( input_n, states_n, delta_input_n, delta_h, delta_h_static=delta_h_static_n, ) # Put next h into a grid format h_next = self._reduce_traj(h_next) # Compute velocities between gathered trajectory of grid and original grid values x_vel, y_vel = self._compute_velocity(h_next, low_dim_grid) speed = self._compute_speed(x_vel, y_vel) x_vel, y_vel, low_dim_grid, speed = self._reshape_vals( x_vel, y_vel, low_dim_grid, speed ) flow_field = FlowField(x_vel, y_vel, low_dim_grid, speed) # Reshape data back to grid flow_field_list.append(flow_field) return flow_field_list