"""Local linear analysis of leaky mRNN models (``mrnntorch.analysis.linear.leaky_linear``)."""

import torch
from mrnntorch.mrnn.leaky_mrnn import mRNN
from typing import Tuple
import warnings


[docs] class mLinearization: """Local linear analysis utilities for leaky :class:`mRNN` models.""" def __init__( self, rnn: mRNN, *args, ): """Initialize the linearization helper for a model and region subset. Args: rnn (mRNN): Network to analyze. *args (str): Optional recurrent region names to include in the linearized subspace. If omitted, all recurrent regions are used. """ self.rnn = rnn # Regions which are treated as grid elements self.zero_states = torch.zeros( size=( 1, rnn.total_num_units, ) ) self.region_list = ( self.rnn.hid_regions if not args else [region for region in rnn._ensure_order(*args)] ) # Regions treated as static inputs for grid elements self.static_region_list = ( [] if self.region_list == self.rnn.hid_regions else self.rnn.get_excluded_hid_regions(*self.region_list) ) def __call__( self, input: torch.Tensor, x: torch.Tensor, delta_input: torch.Tensor, delta_state: torch.Tensor, delta_state_static: torch.Tensor | None = None, h: torch.Tensor | None = None, dh: bool = False, ) -> torch.Tensor: """Alias for :meth:`forward`.""" return self.forward( input, x, delta_input, delta_state, delta_state_static=delta_state_static, h=h, dh=dh, )
[docs] def forward( self, input: torch.Tensor, x: torch.Tensor, delta_input: torch.Tensor, delta_state: torch.Tensor, delta_state_static: torch.Tensor | None = None, h: torch.Tensor | None = None, dh: bool = False, ) -> torch.Tensor: """Evaluate the first-order Taylor approximation of the leaky dynamics. Args: input (torch.Tensor): External input at the operating point. x (torch.Tensor): Pre-activation state about which to linearize. delta_input (torch.Tensor): Input perturbation. delta_state (torch.Tensor): Perturbation for the state of the included \ region subset. Should be x perturbations or h perturbations if dh=True h (torch.Tensor | None): Activation corresponding to ``x`` when \ linearizing hidden activations directly. delta_h_static (torch.Tensor | None): Perturbation applied to excluded \ regions when only a subset of regions is linearized. dh (bool): If ``True``, linearize the hidden activation update instead \ of the pre-activation update. Returns: torch.Tensor: Linearized next state in the requested coordinates. 
""" # Assert correct shapes assert input.dim() == 1 assert x.dim() == 1 if h is not None: assert h.dim() == 1 # Flatten delta since it can be batched if delta_state.dim() > 1: delta_state = delta_state.flatten(start_dim=0, end_dim=-2) # Get jacobians for included regions _jacobian, _jacobian_inp = self.jacobian(input, x, h=h, dh=dh) if len(self.static_region_list) >= 1: # Get jacobians for excluded regions if available _jacobian_exc, _ = self.jacobian( input, x, excluded_regions=True, h=h, dh=dh ) else: _jacobian_exc = None # reshape to pass into RNN inp = input.unsqueeze(0).unsqueeze(0) x = x.unsqueeze(0) if h is not None: h = h.unsqueeze(0) # Get h_next for affine function x_next, h_next = self.rnn(inp, x, h0=h) out = h_next if dh else x_next out = self.rnn.get_region_activity(out, *self.region_list) if _jacobian_exc is None or delta_state_static is None: pert = ( out.squeeze(0) + (_jacobian @ delta_state.T).T + (_jacobian_inp @ delta_input) ) else: pert = ( out.squeeze(0) + (_jacobian @ delta_state.T).T + (_jacobian_exc @ delta_state_static) + (_jacobian_inp @ delta_input) ) return pert
[docs] def jacobian( self, input: torch.Tensor, x: torch.Tensor, excluded_regions: bool = False, h: torch.Tensor | None = None, dh: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """Return Jacobians of the leaky update with respect to state and input. Args: input (torch.Tensor): Input vector at which to linearize. x (torch.Tensor): Pre-activation state at which to linearize. h (torch.Tensor | None): Hidden activation used when ``dh`` is ``True``. excluded_regions (bool): If ``True``, return the projection from excluded recurrent regions into the included region subset. dh (bool): If ``True``, differentiate the hidden activation output rather than the pre-activation state output. Returns: Tuple[torch.Tensor, torch.Tensor]: Jacobian with respect to state followed by Jacobian with respect to input. """ assert isinstance(excluded_regions, bool) assert x.dim() == 1 assert input.dim() == 1 if h is not None: assert h.dim() == 1 """ Taking jacobian of x with respect to F In this case, the form should be: J_(ij)(x) = -I_(ij) + W_(ij)h'(x_j) """ input = input.unsqueeze(0).unsqueeze(0) x = x.unsqueeze(0) if h is not None and not dh: warnings.warn( "Provided h will be ignored since dh is False. If you want to include h, set dh to True." 
) # Only pay attention to h if dh is true # if dh is False, h will be ignored if dh: assert h is not None h = h.unsqueeze(0) # For leaky mrnn, there are three inputs and two outputs if dh: _, h_jacobians = torch.autograd.functional.jacobian(self.rnn, (input, x, h)) # unpack the tuples for x and h h_jacobian_input, _, h_jacobian_h = h_jacobians _jacobian = h_jacobian_h _jacobian_input = h_jacobian_input else: x_jacobians, _ = torch.autograd.functional.jacobian(self.rnn, (input, x)) # unpack the tuples for x and h x_jacobian_input, x_jacobian_x = x_jacobians _jacobian = x_jacobian_x _jacobian_input = x_jacobian_input # Squeeze values now to get proper weight subsets _jacobian = self._jac_nxd(_jacobian) _jacobian_input = self._jac_nxd(_jacobian_input) if excluded_regions and len(self.static_region_list) >= 1: excluded_to_included = [] for r_i in self.region_list: excluded_to_region = [] for r_e in self.static_region_list: to_start, to_end = self.rnn.get_region_indices(r_i) from_start, from_end = self.rnn.get_region_indices(r_e) projection = _jacobian[to_start:to_end, from_start:from_end] excluded_to_region.append(projection) excluded_to_region = torch.cat(excluded_to_region, dim=-1) excluded_to_included.append(excluded_to_region) _jacobian = torch.cat(excluded_to_included, dim=0) else: _jacobian = self.rnn.get_weight_subset(*self.region_list, W=_jacobian) # Get subsets for input jacobians input_to_rec = [] for r_i in self.region_list: input_to_region = [] for r_e in self.rnn.inp_dict: to_start, to_end = self.rnn.get_region_indices(r_i) from_start, from_end = self.rnn.get_region_indices(r_e) projection = _jacobian_input[to_start:to_end, from_start:from_end] input_to_region.append(projection) input_to_region = torch.cat(input_to_region, dim=-1) input_to_rec.append(input_to_region) _jacobian_input = torch.cat(input_to_rec, dim=0) return _jacobian, _jacobian_input
[docs] def eigendecomposition( self, input: torch.Tensor, x: torch.Tensor, h: torch.Tensor | None = None, dh: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Compute the eigendecomposition of the local Jacobian. Args: input (torch.Tensor): Input vector at which to linearize. x (torch.Tensor): Pre-activation state at which to linearize. h (torch.Tensor | None): Hidden activation used when ``dh`` is ``True``. dh (bool): If ``True``, eigendecompose the hidden-state Jacobian. Returns: torch.Tensor: Real parts of eigenvalues. torch.Tensor: Imag parts of eigenvalues. torch.Tensor: Eigenvectors stacked column-wise. """ _jacobian, _ = self.jacobian(input, x, h=h, dh=dh) eigenvalues, eigenvectors = torch.linalg.eig(_jacobian) # Split real and imaginary parts reals = [] for eigenvalue in eigenvalues: reals.append(eigenvalue.real.item()) reals = torch.tensor(reals) ims = [] for eigenvalue in eigenvalues: ims.append(eigenvalue.imag.item()) ims = torch.tensor(ims) return reals, ims, eigenvectors
def _jac_nxd(self, jac): """ broadcast jacobian to nxd jacobian will be nxd with a bunch of extra 1 dimensional squeeze all one dims, and account for single inputs/units jac should never be more than 3 dims, if so there are likely other issues """ # Squeeze values now to get proper weight subsets jac = jac.squeeze() if jac.dim() == 0: jac = jac.unsqueeze(0) if jac.dim() == 1: jac = jac.unsqueeze(1) return jac