import torch
from mrnntorch.mrnn.elman_mrnn import ElmanmRNN
from typing import Tuple
[docs]
class emLinearization:
"""Local linear analysis utilities for :class:`ElmanmRNN` models."""
def __init__(
self,
rnn: ElmanmRNN,
*args,
):
"""Initialize the linearization helper for a model and region subset.
Args:
rnn (ElmanmRNN): Network to analyze.
*args (str): Optional recurrent region names to include in the
linearized subspace. If omitted, all recurrent regions are used.
"""
self.rnn = rnn
# Regions which are treated as grid elements
self.zero_states = torch.zeros(
size=(
1,
rnn.total_num_units,
)
)
self.region_list = (
self.rnn.hid_regions
if not args
else [region for region in rnn._ensure_order(*args)]
)
# Regions treated as static inputs for grid elements
self.static_region_list = (
[]
if self.region_list == self.rnn.hid_regions
else self.rnn.get_excluded_hid_regions(*self.region_list)
)
def __call__(
self,
input: torch.Tensor,
h: torch.Tensor,
delta_input: torch.Tensor,
delta_h: torch.Tensor,
delta_h_static: torch.Tensor | None = None,
) -> torch.Tensor:
"""Alias for :meth:`forward`."""
return self.forward(
input, h, delta_input, delta_h, delta_h_static=delta_h_static
)
[docs]
def forward(
self,
input: torch.Tensor,
h: torch.Tensor,
delta_input: torch.Tensor,
delta_h: torch.Tensor,
delta_h_static: torch.Tensor | None = None,
) -> torch.Tensor:
"""Evaluate the first-order Taylor approximation of the Elman dynamics.
Args:
input (torch.Tensor): External input at the operating point.
h (torch.Tensor): Hidden state about which to linearize.
delta_input (torch.Tensor): Input perturbation.
delta_h (torch.Tensor): Hidden-state perturbation for included regions.
delta_h_static (torch.Tensor | None): Perturbation applied to excluded
regions when only a subset of regions is linearized.
Returns:
torch.Tensor: Linearized next hidden state.
"""
# Assert correct shapes
assert input.dim() == 1
assert delta_input.dim() == 1
assert h.dim() == 1
if delta_h.dim() > 1:
delta_h = delta_h.flatten(start_dim=0, end_dim=-2)
# Get jacobians for included regions
_jacobian, _jacobian_inp = self.jacobian(input, h)
if len(self.static_region_list) >= 1:
# Get jacobians for excluded regions if available
_jacobian_exc, _ = self.jacobian(input, h, excluded_regions=True)
else:
_jacobian_exc = None
# reshape to pass into RNN
inp = input.unsqueeze(0).unsqueeze(0)
h = h.unsqueeze(0)
# Get h_next for affine function
h_next = self.rnn(inp, h)
h_next = self.rnn.get_region_activity(h_next, *self.region_list)
if _jacobian_exc is None or delta_h_static is None:
pert = (
h_next.squeeze(0)
+ (_jacobian @ delta_h.T).T
+ (_jacobian_inp @ delta_input)
)
else:
pert = (
h_next.squeeze(0)
+ (_jacobian @ delta_h.T).T
+ (_jacobian_exc @ delta_h_static)
+ (_jacobian_inp @ delta_input)
)
return pert
[docs]
def jacobian(
self,
input: torch.Tensor,
h: torch.Tensor,
excluded_regions: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Return Jacobians of the Elman update with respect to state and input.
Args:
input (torch.Tensor): Input vector at which to linearize.
h (torch.Tensor): Hidden state at which to linearize.
excluded_regions (bool): If ``True``, return the projection from
excluded recurrent regions into the included region subset.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Jacobian with respect to hidden
state followed by Jacobian with respect to input.
"""
assert isinstance(excluded_regions, bool)
assert h.dim() == 1
assert input.dim() == 1
"""
Taking jacobian of x with respect to F
In this case, the form should be:
J_(ij)(x) = -I_(ij) + W_(ij)h'(x_j)
"""
input = input.unsqueeze(0).unsqueeze(0)
h = h.unsqueeze(0)
# For elman mrnn, there is a single output
h_jacobians = torch.autograd.functional.jacobian(self.rnn, (input, h))
# unpack the tuples for x and h
h_jacobian_input, h_jacobian_h = h_jacobians
_jacobian = h_jacobian_h
_jacobian_input = h_jacobian_input
# Squeeze values now to get proper weight subsets
_jacobian = self._jac_nxd(_jacobian)
_jacobian_input = self._jac_nxd(_jacobian_input)
if excluded_regions and len(self.static_region_list) >= 1:
excluded_to_included = []
for r_i in self.region_list:
excluded_to_region = []
for r_e in self.static_region_list:
to_start, to_end = self.rnn.get_region_indices(r_i)
from_start, from_end = self.rnn.get_region_indices(r_e)
projection = _jacobian[to_start:to_end, from_start:from_end]
excluded_to_region.append(projection)
excluded_to_region = torch.cat(excluded_to_region, dim=-1)
excluded_to_included.append(excluded_to_region)
_jacobian = torch.cat(excluded_to_included, dim=0)
else:
_jacobian = self.rnn.get_weight_subset(*self.region_list, W=_jacobian)
# Get subsets for input jacobians
input_to_rec = []
for r_i in self.region_list:
input_to_region = []
for r_e in self.rnn.inp_dict:
to_start, to_end = self.rnn.get_region_indices(r_i)
from_start, from_end = self.rnn.get_region_indices(r_e)
projection = _jacobian_input[to_start:to_end, from_start:from_end]
input_to_region.append(projection)
input_to_region = torch.cat(input_to_region, dim=-1)
input_to_rec.append(input_to_region)
_jacobian_input = torch.cat(input_to_rec, dim=0)
return _jacobian, _jacobian_input
[docs]
def eigendecomposition(
self,
input: torch.Tensor,
h: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute the eigendecomposition of the local hidden-state Jacobian.
Args:
input (torch.Tensor): Input vector at which to linearize.
h (torch.Tensor): Hidden state at which to linearize.
Returns:
torch.Tensor: Real parts of eigenvalues.
torch.Tensor: Imag parts of eigenvalues.
torch.Tensor: Eigenvectors stacked column-wise.
"""
_jacobian, _ = self.jacobian(input, h)
eigenvalues, eigenvectors = torch.linalg.eig(_jacobian)
# Split real and imaginary parts
reals = []
for eigenvalue in eigenvalues:
reals.append(eigenvalue.real.item())
reals = torch.tensor(reals)
ims = []
for eigenvalue in eigenvalues:
ims.append(eigenvalue.imag.item())
ims = torch.tensor(ims)
return reals, ims, eigenvectors
def _jac_nxd(self, jac):
"""
broadcast jacobian to nxd
jacobian will be nxd with a bunch of extra 1 dimensional
squeeze all one dims, and account for single inputs/units
jac should never be more than 3 dims, if so there are likely other issues
"""
# Squeeze values now to get proper weight subsets
jac = jac.squeeze()
if jac.dim() == 0:
jac = jac.unsqueeze(0)
if jac.dim() == 1:
jac = jac.unsqueeze(1)
return jac