# abpt/src/model/equilibrium.py
# Synced at commit f37be5a ("auto: sync run_testformer_wikitext_combo_remote.py")
"""Equilibrium Signal — unified routing/confidence/plasticity trigger.

Computes the deviation of activations from running per-feature statistics
(accumulated during training).
Near-zero overhead: the signal itself keeps only a running mean and variance
per feature (BatchNorm-style) and has no learned parameters.

ED(x) = || (x - mu_running) / sigma_running ||_2 / sqrt(d_model)

Small ED    → forward pass (confident)
Medium ED   → branching (uncertain, explore alternatives)
Large ED    → backward pass (re-process through earlier layers)
Critical ED → plastic activation (adapt to new context)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class EquilibriumSignal(nn.Module):
    """Per-token deviation of activations from running per-feature statistics.

    In training mode the module keeps a BatchNorm-style exponential moving
    average of the per-feature mean and (unbiased) variance of its input.
    For every token it reports the RMS z-score of the activation vector
    relative to those running statistics.

    Args:
        d_model: Size of the feature (last) dimension.
        momentum: EMA weight given to the current batch's statistics.
        warmup_steps: Number of tracked training batches before
            ``is_warming_up`` becomes False.
    """

    def __init__(self, d_model: int, momentum: float = 0.1, warmup_steps: int = 50):
        super().__init__()
        self.d_model = d_model
        self.momentum = momentum
        self.warmup_steps = warmup_steps
        # Running statistics (like BatchNorm), seeded with mean 0 / variance 1.
        self.register_buffer("running_mean", torch.zeros(d_model))
        self.register_buffer("running_var", torch.ones(d_model))
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))

    @property
    def is_warming_up(self) -> bool:
        """True until ``warmup_steps`` non-empty training batches have been tracked."""
        return self.num_batches_tracked.item() < self.warmup_steps

    @torch.no_grad()
    def _update_running_stats(self, x: torch.Tensor) -> None:
        """Fold the batch's per-feature mean/variance into the running buffers."""
        # Treat every leading dim as a sample axis, so [B, T, D] and [N, D] both work.
        flat = x.detach().reshape(-1, self.d_model)
        n = flat.shape[0]
        if n == 0:
            # Empty batch: mean/var would be NaN and permanently corrupt the buffers.
            return
        self.running_mean.mul_(1 - self.momentum).add_(flat.mean(dim=0), alpha=self.momentum)
        if n > 1:
            # Unbiased variance of a single sample is NaN; keep the previous estimate.
            self.running_var.mul_(1 - self.momentum).add_(flat.var(dim=0), alpha=self.momentum)
        self.num_batches_tracked += 1

    def forward(self, x: torch.Tensor) -> dict:
        """Compute equilibrium deviation for each token.

        Args:
            x: [..., D] layer output activations, typically [B, T, D].

        Returns:
            dict with:
                ed: [...] RMS over features of
                    (x - running_mean) / sqrt(running_var), e.g. [B, T].
                x: the unchanged input (pass-through).
                warming_up: bool, True while still in the warmup phase.
        """
        # Update running stats during training only; eval uses frozen buffers.
        if self.training:
            self._update_running_stats(x)
        # [..., D]; the 1e-8 guards against a zero running variance.
        normalized = (x - self.running_mean) / (self.running_var.sqrt() + 1e-8)
        # L2 norm over features divided by sqrt(D) == RMS of the z-scores.
        ed = normalized.norm(dim=-1) / (self.d_model ** 0.5)
        return {"ed": ed, "x": x, "warming_up": self.is_warming_up}
class RoutingDecision(nn.Module):
    """Converts equilibrium deviation (ED) into routing decisions.

    Three ordered thresholds split ED into four buckets: 0=forward, 1=branch,
    2=backward, 3=plastic. The thresholds are an exponential moving average of
    per-batch ED quantiles, taken at the cumulative ``target_fractions``. This
    keeps Stage B from collapsing to a single route just because the absolute
    ED scale shifted.

    Small learnable offsets let training nudge the boundaries around the
    quantile-derived defaults. Each offset is bounded to +/- ``offset_scale``
    by ``tanh``.

    Args:
        init_thresholds: Initial value of the running thresholds (EMA seed).
        target_fractions: Desired share of tokens per bucket. Normalized to
            sum to 1; the first three cumulative sums are the quantile levels.
        threshold_momentum: EMA weight given to each training batch's quantiles.
        temperature: Sharpness of the soft ``route_probs`` distribution.
        offset_scale: Maximum absolute learnable shift of each threshold.
    """

    def __init__(
        self,
        init_thresholds: tuple[float, float, float] = (0.75, 1.0, 1.35),
        target_fractions: tuple[float, float, float, float] = (0.55, 0.25, 0.15, 0.05),
        threshold_momentum: float = 0.2,
        temperature: float = 8.0,
        offset_scale: float = 0.2,
    ):
        super().__init__()
        fractions = torch.tensor(target_fractions, dtype=torch.float32)
        fractions = fractions / fractions.sum().clamp_min(1e-8)
        # Quantile levels at the three bucket boundaries (the last cumsum is always 1).
        self.register_buffer("target_cdf", fractions.cumsum(dim=0)[:-1])
        self.register_buffer("running_thresholds", torch.tensor(init_thresholds, dtype=torch.float32))
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
        self.threshold_momentum = threshold_momentum
        self.temperature = temperature
        self.offset_scale = offset_scale
        # Unconstrained; squashed by tanh and scaled by offset_scale before use.
        self.threshold_offsets = nn.Parameter(torch.zeros(3))

    def _batch_thresholds(self, ed: torch.Tensor) -> torch.Tensor:
        """Return the ``target_cdf`` quantiles of ``ed`` as a shape-[3] tensor.

        Half/bfloat16 input is upcast to float32 because ``torch.quantile``
        only accepts float32/float64. Empty input falls back to a *copy* of
        ``running_thresholds``, so callers never alias the buffer they may
        update in place.
        """
        flat = ed.detach().reshape(-1)
        if flat.numel() == 0:
            return self.running_thresholds.clone()
        if flat.dtype not in (torch.float32, torch.float64):
            flat = flat.float()
        return torch.quantile(flat, self.target_cdf.to(device=flat.device, dtype=flat.dtype))

    def _ordered_thresholds(self, base: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply learnable offsets to ``base`` and enforce t1 < t2 < t3.

        Consecutive thresholds are kept at least 1e-3 apart.
        """
        offsets = self.offset_scale * torch.tanh(self.threshold_offsets).to(base.device, base.dtype)
        raw = base + offsets
        min_gap = torch.tensor(1e-3, device=base.device, dtype=base.dtype)
        t1 = raw[0]
        t2 = torch.maximum(raw[1], t1 + min_gap)
        t3 = torch.maximum(raw[2], t2 + min_gap)
        return t1, t2, t3

    @property
    def theta1(self) -> torch.Tensor:
        """Effective forward/branch boundary derived from the running thresholds."""
        return self._ordered_thresholds(self.running_thresholds)[0]

    @property
    def theta2(self) -> torch.Tensor:
        """Effective branch/backward boundary derived from the running thresholds."""
        return self._ordered_thresholds(self.running_thresholds)[1]

    @property
    def theta3(self) -> torch.Tensor:
        """Effective backward/plastic boundary derived from the running thresholds."""
        return self._ordered_thresholds(self.running_thresholds)[2]

    def forward(self, ed: torch.Tensor) -> dict:
        """Classify each token into routing buckets.

        Args:
            ed: [B, T] equilibrium deviation.

        Returns:
            dict with:
                route: [B, T] hard bucket index: 0=forward, 1=branch,
                    2=backward, 3=plastic.
                route_probs: [B, T, 4] soft routing probabilities.
                thresholds: [3] the ordered thresholds (t1, t2, t3) used.
        """
        # Only non-empty training batches move the EMA. An empty batch has no
        # quantiles; folding one in used to shrink the thresholds via aliasing.
        if self.training and ed.numel() > 0:
            batch_thresholds = self._batch_thresholds(ed)
            with torch.no_grad():
                self.running_thresholds.mul_(1 - self.threshold_momentum).add_(
                    batch_thresholds.to(self.running_thresholds.device, self.running_thresholds.dtype),
                    alpha=self.threshold_momentum,
                )
                self.num_batches_tracked += 1

        if self.training or self.num_batches_tracked.item() > 0:
            base_thresholds = self.running_thresholds.to(device=ed.device, dtype=ed.dtype)
        else:
            # Never calibrated: fall back to this batch's own quantiles.
            base_thresholds = self._batch_thresholds(ed).to(device=ed.device, dtype=ed.dtype)

        t1, t2, t3 = self._ordered_thresholds(base_thresholds)
        left_width = (t2 - t1).clamp_min(1e-3)
        right_width = (t3 - t2).clamp_min(1e-3)
        # Inner buckets are centred between their thresholds; the outer buckets
        # are centred one neighbouring bucket-width beyond t1 / t3.
        centers = torch.stack(
            [
                t1 - left_width,
                (t1 + t2) * 0.5,
                (t2 + t3) * 0.5,
                t3 + right_width,
            ]
        )
        logits = -self.temperature * (ed.unsqueeze(-1) - centers).abs()
        probs = torch.softmax(logits, dim=-1)
        thresholds = torch.stack([t1, t2, t3])
        route = torch.bucketize(ed, thresholds)
        return {"route": route, "route_probs": probs, "thresholds": thresholds}
class TokenEnergyBudget(nn.Module):
    """Caps per-token compute according to the soft routing distribution.

    Each route has a fixed cost: forward=1, branch=2, backward=3, plastic=4.

    1. A token's expected cost under ``route_probs`` is computed.
    2. If a sequence's total exceeds ``total_budget_ratio * T``, every token's
       cost is shrunk proportionally (it is never inflated).
    3. Each value is clamped to ``[1, max_budget_per_token]`` and rounded.

    The lower clamp and the rounding mean a sequence's final total can still
    exceed the cap.

    Args:
        max_budget_per_token: Upper bound on any single token's budget.
        total_budget_ratio: Per-sequence budget expressed as a multiple of T.
    """

    def __init__(self, max_budget_per_token: int = 4, total_budget_ratio: float = 2.0):
        super().__init__()
        self.max_per_token = max_budget_per_token
        self.total_budget_ratio = total_budget_ratio

    def forward(self, ed: torch.Tensor, route_probs: torch.Tensor) -> torch.Tensor:
        """Compute the compute budget of every token.

        Args:
            ed: [B, T] equilibrium deviation. Only its shape and device are used.
            route_probs: [B, T, 4] routing probabilities.

        Returns:
            [B, T] int64 budget per token, in ``[1, max_per_token]``.
        """
        _, seq_len = ed.shape
        per_sequence_cap = int(seq_len * self.total_budget_ratio)

        route_costs = torch.tensor([1.0, 2.0, 3.0, 4.0], device=ed.device)
        token_cost = torch.sum(route_probs * route_costs, dim=-1)  # [B, T]

        # Shrink factor capped at 1: over-budget sequences deflate, others stay put.
        sequence_cost = torch.sum(token_cost, dim=-1, keepdim=True)  # [B, 1]
        shrink = torch.clamp(per_sequence_cap / (sequence_cost + 1e-8), max=1.0)

        scaled = torch.clamp(token_cost * shrink, min=1, max=self.max_per_token)
        return torch.round(scaled).long()