import torch
import numpy as np
import time
from copy import deepcopy
from rnntoolkit.fixed_points.fp import FixedPointCollection
from rnntoolkit.fixed_points.fp_finder import FixedPointFinderBase
from mrnntorch.mrnn.elman_mrnn import ElmanmRNN
class emFixedPointFinder(FixedPointFinderBase[ElmanmRNN]):
    """Fixed-point finder specialized for :class:`ElmanmRNN` dynamics.

    Candidate hidden states are optimized jointly with Adam to minimize
    ``q = 0.5 * ||h - F(h)||^2`` (averaged over candidates), where ``F`` is a
    single forward step of the wrapped network and the norm is taken only over
    the regions selected for optimization.
    """

    _default_hps = {
        "lr_init": 1e-4,
        "lr_patience": 5,
        "lr_factor": 0.95,
        "lr_cooldown": 0,
        "tol_q": 1e-12,
        "tol_dq": 1e-20,
        "max_iters": 5000,
        "do_rerun_q_outliers": False,
        "outlier_q_scale": 10.0,
        "do_exclude_distance_outliers": True,
        "outlier_distance_scale": 10.0,
        "tol_unique": 1e-3,
        "max_n_unique": np.inf,
        "dtype": "float32",
        "random_seed": 0,
        "verbose": True,
        "super_verbose": False,
        "n_iters_per_print_update": 100,
        "batch_first": True,
    }

    @classmethod
    def default_hps(cls):
        """Return a deep copy of the default hyperparameters dict.

        The deep copy protects against external updates to the defaults, which
        in turn protects against unintended interactions with the hashing done
        by the Hyperparameters class.

        Returns:
            dict: Hyperparameter names mapped to their default values.
        """
        return deepcopy(cls._default_hps)

    def __init__(
        self,
        rnn: ElmanmRNN,
        lr_init: float = _default_hps["lr_init"],
        lr_patience: int = _default_hps["lr_patience"],
        lr_factor: float = _default_hps["lr_factor"],
        lr_cooldown: int = _default_hps["lr_cooldown"],
        tol_q: float = _default_hps["tol_q"],
        tol_dq: float = _default_hps["tol_dq"],
        max_iters: int = _default_hps["max_iters"],
        do_rerun_q_outliers: bool = _default_hps["do_rerun_q_outliers"],
        outlier_q_scale: float = _default_hps["outlier_q_scale"],
        do_exclude_distance_outliers: bool = _default_hps[
            "do_exclude_distance_outliers"
        ],
        outlier_distance_scale: float = _default_hps["outlier_distance_scale"],
        tol_unique: float = _default_hps["tol_unique"],
        max_n_unique: float = _default_hps["max_n_unique"],
        dtype: str = _default_hps["dtype"],
        random_seed: int = _default_hps["random_seed"],
        verbose: bool = _default_hps["verbose"],
        super_verbose: bool = _default_hps["super_verbose"],
        n_iters_per_print_update: int = _default_hps["n_iters_per_print_update"],
        batch_first: bool | None = None,
    ):
        """Initialize fixed-point search hyperparameters for an Elman mRNN.

        Args:
            rnn (ElmanmRNN): Network whose fixed points will be optimized.
            lr_init (float): Initial optimizer learning rate.
            lr_patience (int): Plateau scheduler patience.
            lr_factor (float): Plateau scheduler decay factor.
            lr_cooldown (int): Plateau scheduler cooldown.
            tol_q (float): Absolute fixed-point objective tolerance.
            tol_dq (float): Per-step objective improvement tolerance.
            max_iters (int): Maximum optimization iterations.
            do_rerun_q_outliers (bool): Whether to rerun optimization on high-q outliers.
            outlier_q_scale (float): Multiplier used to classify q outliers.
            do_exclude_distance_outliers (bool): Whether to discard distant fixed points.
            outlier_distance_scale (float): Distance threshold scale for outlier removal.
            tol_unique (float): Tolerance used to collapse duplicate fixed points.
            max_n_unique (float): Maximum number of unique fixed points to retain
                (``np.inf`` keeps all of them).
            dtype (str): Torch dtype name used during optimization.
            random_seed (int): Random seed for reproducible sampling.
            verbose (bool): Whether to print high-level progress.
            super_verbose (bool): Whether to print per-iteration progress.
            n_iters_per_print_update (int): Iteration interval between progress prints.
            batch_first (bool | None): Whether network inputs are laid out as
                ``[batch, time, features]`` (True) or ``[time, batch, features]``
                (False). ``None`` keeps any value already provided by the base
                class, otherwise uses the default from ``_default_hps``.
        """
        super().__init__(rnn)
        self.dtype = dtype
        self.device = next(rnn.parameters()).device
        self.torch_dtype = getattr(torch, self.dtype)
        # _fp_optimization reads self.batch_first; only override a value that
        # may already exist (e.g. from the base class) when explicitly asked.
        if batch_first is not None:
            self.batch_first = batch_first
        elif not hasattr(self, "batch_first"):
            self.batch_first = self._default_hps["batch_first"]
        # Make random sequences reproducible
        self.random_seed = random_seed
        self.rng = np.random.RandomState(random_seed)
        # Optimization hyperparameters
        self.lr_init = lr_init
        self.lr_patience = lr_patience
        self.lr_factor = lr_factor
        self.lr_cooldown = lr_cooldown
        self.tol_q = tol_q
        self.tol_dq = tol_dq
        self.max_iters = max_iters
        self.do_rerun_q_outliers = do_rerun_q_outliers
        self.outlier_q_scale = outlier_q_scale
        self.do_exclude_distance_outliers = do_exclude_distance_outliers
        self.outlier_distance_scale = outlier_distance_scale
        self.tol_unique = tol_unique
        self.max_n_unique = max_n_unique
        self.verbose = verbose
        self.super_verbose = super_verbose
        self.n_iters_per_print_update = n_iters_per_print_update

    # *************************************************************************
    # Primary exposed functions ***********************************************
    # *************************************************************************

    def find_fixed_points(
        self,
        initial_states: torch.Tensor,
        ext_inputs: torch.Tensor,
        *args,
        stim_inp: torch.Tensor | None = None,
        W_rec: torch.Tensor | None = None,
        n_rounds_q_opt: int = 1,
    ) -> tuple[FixedPointCollection, FixedPointCollection]:
        """Find RNN fixed points by jointly optimizing from many initial states.

        Args:
            initial_states: Initial RNN states from which the optimization
                searches for fixed points.
            ext_inputs: External inputs to the RNN, held constant during the
                search.
            *args: Names of the regions to optimize. If empty, every region in
                ``self.rnn.region_dict`` is optimized.
            stim_inp: Optional additional stimulus input to the network. When
                ``do_rerun_q_outliers`` is set, the same ``stim_inp`` is reused
                for the outlier subset. NOTE(review): a per-state stimulus
                presumably will not match that subset's batch size — confirm.
            W_rec: Optional fixed weight matrix to use in place of the
                network's recurrent weights in the forward pass.
            n_rounds_q_opt: Number of rounds of extra optimization to run on
                q outliers.

        Returns:
            unique_fps: FixedPointCollection of the unique fixed points after
                optimizing from all initial_states. Uniqueness is decided by
                ``FixedPointCollection.get_unique`` using ``tol_unique``.
            all_fps: FixedPointCollection of the likely redundant fixed points
                (and associated metadata) from ALL initializations, i.e. the
                full set before duplicates are filtered out.
        """
        all_fps = self._fp_optimization(
            initial_states,
            ext_inputs,
            *args,
            stim_inp=stim_inp,
            W_rec=W_rec,
        )
        # Filter out duplicates from the first optimization round
        unique_fps = all_fps.get_unique()
        self._print_if_verbose("\tIdentified %d unique fixed points." % unique_fps.n)
        if self.do_exclude_distance_outliers:
            unique_fps = self._exclude_distance_outliers(unique_fps, initial_states)
        # Optionally run additional optimization iterations on identified
        # fixed points with q values on the large side of the q-distribution.
        if self.do_rerun_q_outliers:
            unique_fps = self._run_additional_iterations_on_outliers(
                unique_fps,
                *args,
                stim_inp=stim_inp,
                W_rec=W_rec,
                n_rounds=n_rounds_q_opt,
            )
            # Filter out duplicates from the second optimization round
            unique_fps = unique_fps.get_unique()
        # Optionally subselect from the unique fixed points (e.g., for
        # computational savings when not all are needed.)
        if unique_fps.n > self.max_n_unique:
            self._print_if_verbose(
                "\tRandomly selecting %d unique "
                "fixed points to keep." % self.max_n_unique
            )
            max_n_unique = int(self.max_n_unique)
            idx_keep = list(self.rng.choice(unique_fps.n, max_n_unique, replace=False))
            unique_fps = unique_fps[idx_keep]
        self._print_if_verbose("\tFixed point finding complete.\n")
        return unique_fps, all_fps

    # *************************************************************************
    # Helper functions ********************************************************
    # *************************************************************************

    def _run_additional_iterations_on_outliers(
        self,
        fps: FixedPointCollection,
        *args,
        n_rounds: int = 1,
        stim_inp: torch.Tensor | None = None,
        W_rec: torch.Tensor | None = None,
    ) -> FixedPointCollection:
        """Rerun optimization on candidate fixed points with unusually large q.

        A point is an outlier when its q exceeds ``median(q) *
        outlier_q_scale`` (the threshold is fixed from the incoming ``fps``).
        Each round re-optimizes only the current outliers, starting from their
        current ``xstar``, and writes the results back into ``fps``.

        Known issue: additional iterations do not always reduce q. This may be
        because the learning-rate schedule restarts from a value that is too
        large.

        Args:
            fps (FixedPointCollection): Candidate fixed points from a prior run.
            *args (str): Optional recurrent regions to optimize.
            n_rounds (int): Maximum number of rerun rounds to perform.
            stim_inp (torch.Tensor | None): Optional stimulus input during reruns.
            W_rec (torch.Tensor | None): Optional recurrent weight matrix override.

        Returns:
            FixedPointCollection: ``fps`` with the outlier entries updated.
        """
        assert fps.qstar is not None
        outlier_min_q = float(np.median(fps.qstar) * self.outlier_q_scale)

        def perform_outlier_optimization(
            fps: FixedPointCollection,
        ) -> FixedPointCollection:
            """Re-optimize the current q outliers and write them back into fps."""
            idx_outliers = self.identify_q_outliers(fps, outlier_min_q)
            outlier_fps = fps[idx_outliers.tolist()]
            n_prev_iters = outlier_fps.n_iters
            inputs = outlier_fps.inputs
            initial_states = outlier_fps.xstar
            self._print_if_verbose(
                "\tPerforming another round of "
                "joint optimization, "
                "over outlier states only."
            )
            assert inputs is not None
            assert n_prev_iters is not None
            updated_outlier_fps = self._fp_optimization(
                initial_states,
                inputs,
                *args,
                stim_inp=stim_inp,
                W_rec=W_rec,
            )
            assert updated_outlier_fps.n_iters is not None
            # Report the cumulative iteration count across rounds.
            updated_outlier_fps.n_iters += n_prev_iters
            fps[idx_outliers.tolist()] = updated_outlier_fps
            return fps

        def outlier_update(fps: FixedPointCollection) -> torch.Tensor:
            """Identify the current q outliers and report their count."""
            idx_outliers = self.identify_q_outliers(fps, outlier_min_q)
            n_outliers = len(idx_outliers)
            self._print_if_verbose(
                "\n\tDetected %d putative outliers "
                "(q>%.2e)." % (n_outliers, outlier_min_q)
            )
            return idx_outliers

        idx_outliers = outlier_update(fps)
        for _ in range(n_rounds):
            if len(idx_outliers) == 0:
                break
            fps = perform_outlier_optimization(fps)
            idx_outliers = outlier_update(fps)
        return fps

    def _fp_optimization(
        self,
        initial_states: torch.Tensor,
        ext_inp: torch.Tensor,
        *args,
        stim_inp: torch.Tensor | None = None,
        W_rec: torch.Tensor | None = None,
    ) -> FixedPointCollection:
        """Optimize a batch of candidate states toward fixed points jointly.

        Model parameters have ``requires_grad`` disabled while optimizing so
        that ``backward()`` on the objective does not touch them. Their
        original flags are restored afterwards, even if optimization raises.

        Args:
            initial_states (torch.Tensor): Initial recurrent states to optimize.
            ext_inp (torch.Tensor): Constant external inputs paired with each state.
            *args (str): Optional recurrent regions to optimize; all regions in
                ``self.rnn.region_dict`` when empty.
            stim_inp (torch.Tensor | None): Optional stimulus input during
                optimization; a single-feature zero input is used when None.
            W_rec (torch.Tensor | None): Optional recurrent weight matrix override.

        Returns:
            FixedPointCollection: Optimized fixed points and optimization metadata.
        """
        # Single-step inputs carry a length-1 time axis at time_dim.
        time_dim = 1 if self.batch_first else 0
        batch_dim = 1 - time_dim

        initial_states = self._broadcast_nxd(initial_states, tile_n=1)
        n = initial_states.shape[0]

        # Broadcast external input to n states and add the time axis. detach()
        # (instead of flipping requires_grad in place) keeps the caller's tensor
        # untouched and works for non-leaf tensors too.
        ext_inp = self._broadcast_nxd(ext_inp, tile_n=n).detach()
        ext_inp = ext_inp.unsqueeze(time_dim)

        if stim_inp is not None:
            stim_inp = self._broadcast_nxd(stim_inp, tile_n=n)
            stim_inp = stim_inp.unsqueeze(time_dim)
        else:
            # Single-feature zero stimulus laid out like ext_inp.
            zeros_shape = (n, 1, 1) if self.batch_first else (1, n, 1)
            stim_inp = torch.zeros(size=zeros_shape)
        stim_inp = stim_inp.to(device=self.device, dtype=self.torch_dtype).detach()

        # Assert the batch sizes line up with the number of initial states.
        assert ext_inp.shape[batch_dim] == n
        assert stim_inp.shape[batch_dim] == n

        self._print_if_verbose(
            "\nSearching for fixed points from %d initial states.\n" % n
        )

        # Ensure that fixed point optimization does not alter RNN parameters.
        self._print_if_verbose(
            "\tFreezing model parameters so model is not affected by fixed point optimization."
        )
        params = list(self.rnn.parameters())
        prev_requires_grad = [p.requires_grad for p in params]
        for p in params:
            p.requires_grad_(False)

        regions = list(args) if args else list(self.rnn.region_dict.keys())
        try:
            return self._run_joint_optimization(
                initial_states, ext_inp, stim_inp, W_rec, regions, time_dim
            )
        finally:
            # Restore the caller's training configuration.
            for p, flag in zip(params, prev_requires_grad):
                p.requires_grad_(flag)

    def _run_joint_optimization(
        self,
        initial_states: torch.Tensor,
        ext_inp: torch.Tensor,
        stim_inp: torch.Tensor,
        W_rec: torch.Tensor | None,
        regions: list,
        time_dim: int,
    ) -> FixedPointCollection:
        """Run the Adam/plateau loop over the selected regions and package results."""
        n = initial_states.shape[0]
        self._print_if_verbose("\tFinding fixed points via joint optimization.")

        # Split the state region by region (region_dict order, which is also
        # the order they are concatenated back in). Each piece is detached and
        # cloned so in-place optimizer updates never write back into
        # initial_states, which the slices may otherwise be views of.
        region_tensors = []
        opt_params = []
        for region in self.rnn.region_dict:
            act = self.rnn.get_region_activity(initial_states, region).detach().clone()
            if region in regions:
                act.requires_grad_(True)
                opt_params.append(act)
            region_tensors.append(act)

        optimizer = torch.optim.Adam(opt_params, lr=self.lr_init)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=self.lr_factor,
            patience=self.lr_patience,
            cooldown=self.lr_cooldown,
            threshold=1e-10,
        )
        if W_rec is not None:
            # Constant copy: no gradient flows into the caller's matrix.
            W_rec = W_rec.detach().clone()

        iter_count = 1
        t_start = time.time()
        q_prev_b = torch.full(
            (n,), float("nan"), device=self.device, dtype=self.torch_dtype
        )
        while True:
            h = torch.cat(region_tensors, dim=-1)
            F_x_bxd = self.rnn(
                ext_inp,
                h,
                stim_inp,
                noise=False,
                W_rec=W_rec,
            ).squeeze(time_dim)

            # Residual is measured only over the regions being optimized.
            h_prev = torch.cat(
                [self.rnn.get_region_activity(h, region) for region in regions],
                dim=-1,
            )
            h_next = torch.cat(
                [self.rnn.get_region_activity(F_x_bxd, region) for region in regions],
                dim=-1,
            )
            dx_bxd = h_prev - h_next
            q_b = 0.5 * torch.sum(torch.square(dx_bxd), dim=-1)
            q_scalar = torch.mean(q_b)
            dq_b = torch.abs(q_b - q_prev_b)

            optimizer.zero_grad()
            q_scalar.backward()
            optimizer.step()
            scheduler.step(metrics=q_scalar.detach())
            iter_learning_rate = optimizer.param_groups[0]["lr"]

            ev_q_b = q_b.detach().cpu()
            ev_dq_b = dq_b.detach().cpu()

            if (
                self.super_verbose
                and np.mod(iter_count, self.n_iters_per_print_update) == 0
            ):
                self._print_iter_update(
                    iter_count, t_start, ev_q_b, ev_dq_b, iter_learning_rate
                )

            # dq is scaled by the learning rate; otherwise very small steps due
            # to very small learning rates would spuriously indicate
            # convergence. This is roughly equivalent to measuring the
            # gradient norm.
            if iter_count > 1 and torch.all(
                torch.logical_or(
                    ev_dq_b < self.tol_dq * iter_learning_rate, ev_q_b < self.tol_q
                )
            ):
                self._print_if_verbose("\tOptimization complete to desired tolerance.")
                break

            if iter_count + 1 > self.max_iters:
                self._print_if_verbose(
                    "\tMaximum iteration count reached. Terminating."
                )
                break

            # Detach so the previous iteration's graph is not kept alive.
            q_prev_b = q_b.detach()
            iter_count += 1

        if self.verbose:
            self._print_iter_update(
                iter_count, t_start, ev_q_b, ev_dq_b, iter_learning_rate, is_final=True
            )

        # The reported fixed point spans all regions; regions that were not
        # optimized keep their initial values. NOTE(review): F_xstar and qstar
        # were evaluated before the final optimizer.step(), so they lag xstar
        # by one update.
        xstar = torch.cat(region_tensors, dim=-1).detach().cpu()
        F_xstar = F_x_bxd.detach().cpu()
        # Same n_iters for every initialization (joint optimization).
        n_iters = torch.tile(torch.tensor([iter_count]), dims=(F_xstar.shape[0],))
        inputs_bxd = ext_inp.squeeze(time_dim)
        return FixedPointCollection(
            xstar=xstar,
            x_init=initial_states,
            inputs=inputs_bxd,
            F_xstar=F_xstar,
            qstar=ev_q_b,
            dq=ev_dq_b,
            n_iters=n_iters,
            tol_unique=self.tol_unique,
            dtype=self.torch_dtype,
        )

    def _exclude_distance_outliers(
        self, fps: FixedPointCollection, initial_states: torch.Tensor
    ) -> FixedPointCollection:
        """Drop fixed points that are too far from the supplied initial states."""
        idx_keep = self.get_fp_non_distance_outliers(
            fps, initial_states, self.outlier_distance_scale
        )
        return fps[idx_keep.tolist()]

    def _print_if_verbose(self, *args, **kwargs):
        """Print only when verbose logging is enabled."""
        if self.verbose:
            print(*args, **kwargs)

    @classmethod
    def _print_iter_update(
        cls,
        iter_count: int,
        t_start: float,
        q: torch.Tensor,
        dq: torch.Tensor,
        lr: float,
        is_final: bool = False,
    ):
        """Print one standardized optimization progress line, ending in a newline.

        Args:
            iter_count: Current iteration number (>= 1).
            t_start: ``time.time()`` value at the start of optimization.
            q: Per-state objective values.
            dq: Per-state objective changes since the previous iteration.
            lr: Current learning rate.
            is_final: Use the multi-line summary layout instead of one line.
        """
        t_elapsed = time.time() - t_start
        avg_iter_time = t_elapsed / iter_count
        if is_final:
            delimiter = "\n\t\t"
            print("\t\t%d iters%s" % (iter_count, delimiter), end="")
        else:
            delimiter = ", "
            print("\tIter: %d%s" % (iter_count, delimiter), end="")
        # q is a torch tensor, where ``q.size`` is a method; use numel().
        if q.numel() == 1:
            print(
                "q = %.2e%sdq = %.2e%s" % (q.item(), delimiter, dq.item(), delimiter),
                end="",
            )
        else:
            mean_q = torch.mean(q)
            std_q = torch.std(q)
            mean_dq = torch.mean(dq)
            std_dq = torch.std(dq)
            print(
                "q = %.2e +/- %.2e%s"
                "dq = %.2e +/- %.2e%s"
                % (mean_q, std_q, delimiter, mean_dq, std_dq, delimiter),
                end="",
            )
        print("learning rate = %.2e%s" % (lr, delimiter), end="")
        # Terminate the line so consecutive updates don't run together.
        print("avg iter time = %.2e sec" % avg_iter_time)