# agiformer/src/models/memory.py
## Developer: inkbytefo
## Modified: 2025-11-23
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class HebbianMemory(nn.Module):
"""
Hebbian Memory Module (Fast Weights).
Implements the update rule:
M_t = lambda * M_{t-1} + K_t * V_t^T
O_t = Q_t * M_t
CRITICAL CHANGE:
To prevent numerical overflow in parallel computation (cumsum),
the decay rate (lambda) is constrained to the range [0.99, 1.0].
This ensures lambda^(-L) does not explode for L=1024.
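Example (shapes are illustrative, not the original training configuration):
>>> mem = HebbianMemory(d_model=512, num_heads=8)
>>> out = mem(torch.randn(2, 1024, 512))
>>> out.shape
torch.Size([2, 1024, 512])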
"""
def __init__(self, d_model, num_heads=8, dropout=0.1):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.qkv = nn.Linear(d_model, 3 * d_model)
self.out_proj = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
# Feature map: ELU(x) + 1 (the "+1" is applied in forward) keeps kernel features strictly positive
self.feature_map = nn.ELU()
# Learnable Decay Parameter
# Initialized to generate sigmoid output ~0.5, mapped to range later
self.decay_logits = nn.Parameter(torch.zeros(num_heads))
self.norm = nn.LayerNorm(d_model)
# Plasticity Factor (Alpha) - Controlled externally
self.plasticity = 1.0
def set_plasticity(self, alpha):
"""
Updates the plasticity coefficient (alpha).
alpha: float in [0, 1].
0.1 -> Childhood (Fast forgetting)
0.99 -> Adulthood (Stable memory)
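Example (values are illustrative, not the original curriculum schedule):
>>> mem = HebbianMemory(d_model=64, num_heads=4)
>>> mem.set_plasticity(0.5)
>>> mem.plasticity
0.5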
"""
self.plasticity = alpha
@torch.amp.autocast('cuda', enabled=False)
def forward(self, x):
# CRITICAL: Bypass AMP for this entire module to prevent NaN.
# With low plasticity (e.g. 0.1), the clamped decay exponents reach ±50, so the
# decay factors grow to ~exp(50) ~= 5e21 and the cumsum accumulates intermediate
# values far beyond the float16 range (max ~6.5e4).
# We must use float32 for all computations, including the linear layers.
input_dtype = x.dtype  # Remember the original dtype before casting
x = x.float()  # Ensure all computations run in float32
B, L, D = x.shape
H = self.num_heads
E = self.head_dim
# 1. Projections
qkv = self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1)
# Reshape (B, L, H, E)
q = q.view(B, L, H, E)
k = k.view(B, L, H, E)
v = v.view(B, L, H, E)
# 2. Feature Map (Kernel Trick)
q = self.feature_map(q) + 1.0
k = self.feature_map(k) + 1.0
# Scale Q to prevent magnitude explosion
q = q / math.sqrt(E)
# 3. Decay Factor (Lambda) - STABILIZED
# Map sigmoid (0,1) to (0.990, 1.0)
# This prevents overflow. 0.99^-1024 ~= 29468 (Safe for FP32)
raw_sigmoid = torch.sigmoid(self.decay_logits).view(1, 1, H, 1)
lambdas = 0.99 + (0.01 * raw_sigmoid)
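# Illustrative check (not from the source): with zero-initialised logits,
# sigmoid(0) = 0.5 gives lambda = 0.995, and 0.995^-1023 = exp(1023 * 0.00501)
# ~= exp(5.13) ~= 169, comfortably within FP32 range.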
# Apply Plasticity Schedule
# Effective Lambda = Lambda * Alpha
# If Alpha is low (childhood), decay is very fast.
lambdas = lambdas * self.plasticity
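# Note (illustrative arithmetic): with alpha = 0.1 the effective lambda is ~0.0995,
# log(lambda) ~= -2.31, so the raw exponent at position 1023 is roughly ±2360;
# the clamp below caps it at ±50 before exp() is taken.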
# 4. Parallel Hebbian Update
# Formula: O_i = Q_i * sum_{j=1}^{i} lambda^{i-j} * (K_j V_j^T)
# Implementation: O_i = (Q_i * lambda^i) * cumsum_j(lambda^{-j} * K_j V_j^T)
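# The factorization relies on lambda^{i-j} = lambda^i * lambda^{-j}: each outer
# product K_j V_j^T is pre-scaled by lambda^{-j}, accumulated with a cumsum over j,
# and the query at position i re-applies lambda^i, recovering the causal decay.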
indices = torch.arange(L, device=x.device, dtype=torch.float32).view(1, L, 1, 1)
# Use log-space arithmetic to prevent overflow/underflow
log_lambdas = torch.log(lambdas.clamp(min=1e-10))
# Clamp the exponent BEFORE exp() to prevent overflow
# We use ±50 as a safe range that works for float32
exp_k = (-indices * log_lambdas).clamp(min=-50, max=50)
exp_q = (indices * log_lambdas).clamp(min=-50, max=50)
# Compute decay factors
decay_k = torch.exp(exp_k) # lambda^-indices
decay_q = torch.exp(exp_q) # lambda^indices
k_decayed = k * decay_k
# Memory State Accumulation (KV)
# (B, L, H, E) * (B, L, H, E) -> (B, L, H, E, E)
# Einsum: b l h e, b l h f -> b l h e f
kv = torch.einsum('blhe,blhf->blhef', k_decayed, v)
# Cumsum (The "Write" Operation)
memory_state = torch.cumsum(kv, dim=1) # (B, L, H, E, E)
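# Note: summing lambda^{-j} * K_j V_j^T up to position i and re-scaling by lambda^i
# (via q_decayed below) reproduces the recurrent state from the class docstring,
# M_i = lambda * M_{i-1} + K_i V_i^T, up to the exponent clamp above.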
# Denominator Accumulation (Z) for normalization
k_sum_decayed = torch.cumsum(k_decayed, dim=1) # (B, L, H, E)
# Read Operation (Query * Memory)
q_decayed = q * decay_q
# Num: (B, L, H, E) * (B, L, H, E, E) -> (B, L, H, E)
num = torch.einsum('blhe,blhef->blhf', q_decayed, memory_state)
# Den: (B, L, H, E) * (B, L, H, E) -> (B, L, H)
den = torch.einsum('blhe,blhe->blh', q_decayed, k_sum_decayed)
den = den.unsqueeze(-1) + 1e-6 # Stability epsilon
out = num / den
# Final Projection
out = out.reshape(B, L, D)
out = self.out_proj(out)
# Norm and dropout also run in float32; cast back to the original dtype on return
out = self.dropout(self.norm(out))
return out.to(input_dtype) # Convert back to original dtype
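# Minimal smoke test (illustrative sketch; the batch size, sequence length and
# model width below are assumptions, not the original training configuration).
if __name__ == "__main__":
    mem = HebbianMemory(d_model=256, num_heads=8)
    x = torch.randn(2, 128, 256)
    mem.set_plasticity(0.1)   # "childhood" regime: fast forgetting
    y_child = mem(x)
    mem.set_plasticity(0.99)  # "adulthood" regime: stable memory
    y_adult = mem(x)
    print(y_child.shape, y_adult.shape)  # both torch.Size([2, 128, 256])
    # Outputs should be finite (no NaN/Inf) in both regimes
    print(torch.isfinite(y_child).all().item(), torch.isfinite(y_adult).all().item())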