| | """ |
| | Half-Life Regularizer for FDRA Oscillators |
| | |
| | This implements the exact mathematical regularizer from the Cursor instructions: |
| | |
| | ## Regularizer 1: Log-Uniform Half-Life Prior (primary) |
| | |
| | Target distribution: p(τ) ∝ 1/τ for τ ∈ [τ_min, τ_max] |
| | This gives equal mass per temporal decade (log scale). |
| | |
| | Loss: |
| | z_i = log(τ_i) |
| | μ = mean(z_i) |
| | σ² = mean((z_i - μ)²) |
| | |
| | μ* = (log(τ_min) + log(τ_max)) / 2 |
| | σ²* = (log(τ_max) - log(τ_min))² / 12 |
| | |
| | L_HL = α*(μ - μ*)² + β*(σ² - σ²*)² |
| | |
| | ## Regularizer 2: Long-Tail Survival Constraint (supporting) |
| | |
| | Ensure existence of long-range oscillators: |
| | s_i = σ(k * (τ_i - γ*L)) |
| | tail_mass = mean(s_i) |
| | L_tail = max(0, ρ - tail_mass)² |
| | |
| | Where: |
| | γ = 0.5 (fraction of full context) |
| | ρ = 0.05 (minimum fraction of oscillators) |
| | k = 10.0 (sigmoid sharpness) |
| | |
| | ## Regularizer 3: Tau Bounds Constraint (CRITICAL FIX) |
| | |
| | The moment-matching loss (L_HL) can be satisfied by a pathological bimodal |
| | distribution with taus outside [tau_min, tau_max]. This creates oscillators |
| | that are either useless (tau << 1) or extreme (tau >> L). |
| | |
| | L_bounds = mean(relu(tau_min - tau_i)^2) + mean(relu(tau_i - tau_max)^2) |
| | |
| | ## Combined Loss |
| | |
| | L_total = L_task + λ1 * L_HL + λ2 * L_tail + λ3 * L_bounds |
| | |
| | Authors: Half-Life Regularization Implementation |
| | Date: 2026-01-22 |
| | """ |
| |
|
| | import numpy as np |
| | from typing import Dict, Tuple, Optional, Any |
| | from dataclasses import dataclass |
| | from pathlib import Path |
| | import json |
| | from datetime import datetime |
| |
|
| |
|
@dataclass
class HalfLifeRegularizerConfig:
    """Configuration for half-life regularization.

    Defaults assume a 4096-token context with half-lives spanning the
    full context (tau in [1, 4096]).
    """

    # Context / half-life range.
    sequence_length: int = 4096  # full context length L
    tau_min: float = 1.0         # shortest target half-life
    tau_max: float = 4096.0      # longest target half-life (typically == sequence_length)

    # Regularizer 1 (log-uniform prior): moment-matching weights.
    alpha: float = 1.0  # weight on (mean(log tau) - target mean)^2
    beta: float = 1.0   # weight on (var(log tau) - target var)^2

    # Regularizer 2 (long-tail survival).
    gamma: float = 0.5  # tau > gamma * sequence_length counts as long-range
    rho: float = 0.05   # minimum soft fraction of long-range oscillators
    k: float = 10.0     # sigmoid sharpness for the soft long-range count

    # Combined-loss weights for regularizers 1 and 2.
    lambda1: float = 0.01  # weight on the log-uniform prior loss
    lambda2: float = 0.01  # weight on the long-tail survival loss

    # Regularizer 3 (tau bounds constraint).
    lambda3: float = 0.1          # weight on the bounds loss
    bound_sharpness: float = 5.0  # scale applied to out-of-bounds distance before squaring
| |
|
| |
|
class HalfLifeRegularizer:
    """
    Half-Life Regularizer for FDRA Oscillator Banks.

    Prevents decay-parameter collapse by regularizing the half-life
    distribution toward a log-uniform target over [tau_min, tau_max],
    plus a long-tail survival constraint and a hard bounds penalty.

    Usage:
        config = HalfLifeRegularizerConfig()
        regularizer = HalfLifeRegularizer(config)

        # During training:
        lambdas = oscillator_bank.lambdas
        loss, metrics = regularizer.compute(lambdas)

        # Add to total loss:
        total_loss = task_loss + loss

        # Log metrics:
        log(metrics)
    """

    def __init__(self, config: "HalfLifeRegularizerConfig"):
        """
        Precompute the log-uniform targets and the long-range threshold.

        Args:
            config: Regularizer hyperparameters.
        """
        self.config = config

        z_min = np.log(config.tau_min)
        z_max = np.log(config.tau_max)

        # Target moments of log(tau) under a uniform distribution on
        # [log(tau_min), log(tau_max)]: mean = midpoint, var = width^2 / 12.
        self.mu_star = (z_min + z_max) / 2.0
        self.sigma2_star = (z_max - z_min) ** 2 / 12.0

        # Half-lives above gamma * L count as "long-range".
        self.tau_threshold = config.gamma * config.sequence_length

    def lambdas_to_half_lives(self, lambdas: np.ndarray) -> np.ndarray:
        """
        Convert decay parameters to half-lives.

        τ_i = ln(0.5) / ln(λ_i)

        Args:
            lambdas: Decay parameters, shape (N,), expected in (0, 1).

        Returns:
            taus: Half-lives, shape (N,)
        """
        # Clip away 0 and 1 so log() is finite and the division is safe.
        safe_lambdas = np.clip(lambdas, 1e-10, 1.0 - 1e-10)
        taus = np.log(0.5) / np.log(safe_lambdas)
        return taus

    def compute_log_uniform_loss(
        self,
        lambdas: np.ndarray
    ) -> Tuple[float, Dict[str, float]]:
        """
        Compute Log-Uniform Half-Life Prior loss (Regularizer 1).

        L_HL = α*(μ - μ*)² + β*(σ² - σ²*)²

        where μ, σ² are the empirical mean/variance of log(τ) and
        μ*, σ²* are the log-uniform targets precomputed in __init__.

        Args:
            lambdas: Decay parameters, shape (N,)

        Returns:
            loss: Scalar loss value
            metrics: Dictionary of component metrics
        """
        taus = self.lambdas_to_half_lives(lambdas)
        z = np.log(taus)

        # Empirical first and second central moments of log(tau).
        mu = np.mean(z)
        sigma2 = np.var(z)

        mean_loss = self.config.alpha * (mu - self.mu_star) ** 2
        var_loss = self.config.beta * (sigma2 - self.sigma2_star) ** 2

        loss = mean_loss + var_loss

        metrics = {
            "log_tau_mean": float(mu),
            "log_tau_var": float(sigma2),
            "log_tau_target_mean": float(self.mu_star),
            "log_tau_target_var": float(self.sigma2_star),
            "mean_deviation": float(abs(mu - self.mu_star)),
            "var_deviation": float(abs(sigma2 - self.sigma2_star)),
            "log_uniform_loss": float(loss),
        }

        return float(loss), metrics

    def compute_long_tail_loss(
        self,
        lambdas: np.ndarray
    ) -> Tuple[float, Dict[str, float]]:
        """
        Compute Long-Tail Survival Constraint loss (Regularizer 2).

        s_i = σ(k * (τ_i - γ*L))
        tail_mass = mean(s_i)
        L_tail = max(0, ρ - tail_mass)²

        The sigmoid gives a differentiable "soft count" of oscillators
        whose half-life exceeds the long-range threshold.

        Args:
            lambdas: Decay parameters, shape (N,)

        Returns:
            loss: Scalar loss value
            metrics: Dictionary of component metrics
        """
        taus = self.lambdas_to_half_lives(lambdas)

        # Clip the sigmoid argument to avoid overflow in exp().
        x = self.config.k * (taus - self.tau_threshold)
        x = np.clip(x, -500, 500)
        s = 1.0 / (1.0 + np.exp(-x))

        # Soft fraction of long-range oscillators.
        tail_mass = np.mean(s)

        # Hinge: penalize only when tail mass falls below rho.
        deficit = max(0, self.config.rho - tail_mass)
        loss = deficit ** 2

        # Hard (non-differentiable) count, reported for diagnostics only.
        n_long_range = np.sum(taus > self.tau_threshold)
        frac_long_range = n_long_range / len(taus)

        metrics = {
            "tail_mass": float(tail_mass),
            "tail_target": float(self.config.rho),
            "tail_deficit": float(deficit),
            "n_long_range": int(n_long_range),
            "frac_long_range": float(frac_long_range),
            "tau_threshold": float(self.tau_threshold),
            "long_tail_loss": float(loss),
        }

        return float(loss), metrics

    def compute_bounds_loss(
        self,
        lambdas: np.ndarray
    ) -> Tuple[float, Dict[str, float]]:
        """
        Compute tau bounds constraint loss (Regularizer 3).

        CRITICAL FIX: The moment-matching loss alone can be satisfied by
        a pathological bimodal distribution with taus outside
        [tau_min, tau_max].

        This loss penalizes taus below tau_min or above tau_max:

            L_bounds = mean(relu(tau_min - tau_i)^2) + mean(relu(tau_i - tau_max)^2)

        The out-of-bounds distance is scaled by `bound_sharpness` before
        squaring, giving a configurable soft penalty.

        Args:
            lambdas: Decay parameters, shape (N,)

        Returns:
            loss: Scalar loss value
            metrics: Dictionary of component metrics
        """
        taus = self.lambdas_to_half_lives(lambdas)
        k = self.config.bound_sharpness

        # Quadratic penalty below the lower bound.
        below_min = np.maximum(0, self.config.tau_min - taus)
        lower_penalty = np.mean((k * below_min) ** 2)

        # Quadratic penalty above the upper bound.
        above_max = np.maximum(0, taus - self.config.tau_max)
        upper_penalty = np.mean((k * above_max) ** 2)

        loss = lower_penalty + upper_penalty

        n_below_min = np.sum(taus < self.config.tau_min)
        n_above_max = np.sum(taus > self.config.tau_max)

        metrics = {
            "bounds_loss": float(loss),
            "lower_bound_penalty": float(lower_penalty),
            "upper_bound_penalty": float(upper_penalty),
            "n_below_tau_min": int(n_below_min),
            "n_above_tau_max": int(n_above_max),
            "frac_in_bounds": float(1 - (n_below_min + n_above_max) / len(taus)),
        }

        return float(loss), metrics

    def compute(self, lambdas: np.ndarray) -> Tuple[float, Dict[str, Any]]:
        """
        Compute total half-life regularization loss.

        L_total = λ1 * L_HL + λ2 * L_tail + λ3 * L_bounds

        CRITICAL: L_bounds prevents the pathological case where
        moment-matching is satisfied by a bimodal distribution with taus
        outside [tau_min, tau_max].

        Args:
            lambdas: Decay parameters, shape (N,)

        Returns:
            loss: Total regularization loss
            metrics: All component metrics (merged from all three
                regularizers, plus summary statistics of tau)
        """
        log_uniform_loss, log_uniform_metrics = self.compute_log_uniform_loss(lambdas)
        long_tail_loss, long_tail_metrics = self.compute_long_tail_loss(lambdas)
        bounds_loss, bounds_metrics = self.compute_bounds_loss(lambdas)

        # Weighted combination of the three regularizers.
        total_loss = (
            self.config.lambda1 * log_uniform_loss +
            self.config.lambda2 * long_tail_loss +
            self.config.lambda3 * bounds_loss
        )

        # Summary statistics of the current half-life distribution.
        taus = self.lambdas_to_half_lives(lambdas)

        metrics = {
            "total_regularization_loss": float(total_loss),
            "log_uniform_component": float(self.config.lambda1 * log_uniform_loss),
            "long_tail_component": float(self.config.lambda2 * long_tail_loss),
            "bounds_component": float(self.config.lambda3 * bounds_loss),
            "tau_min": float(np.min(taus)),
            "tau_max": float(np.max(taus)),
            "tau_mean": float(np.mean(taus)),
            "tau_median": float(np.median(taus)),
            **log_uniform_metrics,
            **long_tail_metrics,
            **bounds_metrics,
        }

        return float(total_loss), metrics

    def compute_gradient(
        self,
        lambdas: np.ndarray,
        epsilon: float = 1e-5
    ) -> np.ndarray:
        """
        Compute gradient of regularization loss w.r.t. lambdas.

        Uses central finite differences for simplicity (O(N) evaluations
        of compute()). In a real implementation, this would use autodiff.

        Args:
            lambdas: Decay parameters, shape (N,)
            epsilon: Perturbation size

        Returns:
            grad: Gradient, shape (N,)
        """
        grad = np.zeros_like(lambdas)

        for i in range(len(lambdas)):
            # Forward perturbation.
            lambdas_plus = lambdas.copy()
            lambdas_plus[i] += epsilon
            loss_plus, _ = self.compute(lambdas_plus)

            # Backward perturbation.
            lambdas_minus = lambdas.copy()
            lambdas_minus[i] -= epsilon
            loss_minus, _ = self.compute(lambdas_minus)

            # Central difference.
            grad[i] = (loss_plus - loss_minus) / (2 * epsilon)

        return grad

    def diagnose(self, lambdas: np.ndarray) -> str:
        """
        Generate diagnostic string for current half-life distribution.

        Reports all three regularizer components (including the bounds
        constraint) and a log-scale histogram of half-lives.

        Args:
            lambdas: Decay parameters

        Returns:
            Diagnostic string
        """
        loss, metrics = self.compute(lambdas)
        taus = self.lambdas_to_half_lives(lambdas)

        lines = [
            "=" * 60,
            "HALF-LIFE REGULARIZER DIAGNOSTICS",
            "=" * 60,
            "",
            "Current Distribution:",
            f"  τ range:  [{metrics['tau_min']:.1f}, {metrics['tau_max']:.1f}]",
            f"  τ mean:   {metrics['tau_mean']:.1f}",
            f"  τ median: {metrics['tau_median']:.1f}",
            "",
            "Target Distribution:",
            f"  τ range: [{self.config.tau_min}, {self.config.tau_max}]",
            f"  log(τ) target mean: {self.mu_star:.3f}",
            f"  log(τ) target var:  {self.sigma2_star:.3f}",
            "",
            "Log-Uniform Prior:",
            f"  log(τ) mean: {metrics['log_tau_mean']:.3f} (target: {metrics['log_tau_target_mean']:.3f})",
            f"  log(τ) var:  {metrics['log_tau_var']:.3f} (target: {metrics['log_tau_target_var']:.3f})",
            f"  Mean deviation: {metrics['mean_deviation']:.3f}",
            f"  Var deviation:  {metrics['var_deviation']:.3f}",
            f"  Loss: {metrics['log_uniform_loss']:.6f}",
            "",
            "Long-Tail Survival:",
            f"  Threshold: τ > {metrics['tau_threshold']:.1f}",
            f"  Long-range count: {metrics['n_long_range']}/{len(lambdas)} ({metrics['frac_long_range']:.1%})",
            f"  Tail mass (soft): {metrics['tail_mass']:.3f} (target: {metrics['tail_target']:.3f})",
            f"  Loss: {metrics['long_tail_loss']:.6f}",
            "",
            "Tau Bounds Constraint:",
            f"  In-bounds fraction: {metrics['frac_in_bounds']:.1%}",
            f"  Below τ_min: {metrics['n_below_tau_min']}",
            f"  Above τ_max: {metrics['n_above_tau_max']}",
            f"  Loss: {metrics['bounds_loss']:.6f}",
            "",
            "Total Regularization Loss:",
            f"  Log-uniform component: {metrics['log_uniform_component']:.6f}",
            f"  Long-tail component:   {metrics['long_tail_component']:.6f}",
            f"  Bounds component:      {metrics['bounds_component']:.6f}",
            f"  Total: {metrics['total_regularization_loss']:.6f}",
            "",
        ]

        # Log-scale histogram over the configured [tau_min, tau_max] range.
        # (Previously the lower edge was hard-coded to 10^0 = 1, which was
        # wrong whenever tau_min != 1.)
        lines.append("Half-Life Histogram (log scale):")
        bins = np.logspace(
            np.log10(self.config.tau_min), np.log10(self.config.tau_max), 11
        )
        hist, _ = np.histogram(taus, bins=bins)
        for i, count in enumerate(hist):
            bar = "█" * count
            lines.append(f"  [{bins[i]:7.1f}, {bins[i+1]:7.1f}): {count:2d} {bar}")

        lines.append("")
        lines.append("=" * 60)

        return "\n".join(lines)
| |
|
| |
|
def simulate_collapse_and_recovery(seed: int = 0):
    """
    Simulate the half-life collapse problem and demonstrate regularization.

    This shows:
    1. Initial log-uniform distribution (good)
    2. Simulated collapse to short half-lives (bad, mimics training at scale)
    3. Regularization gradient direction (recovery)

    Args:
        seed: RNG seed for the simulated collapse, so the demonstration
            is reproducible run-to-run.

    Returns:
        Dict with "initial", "collapsed", and "regularized" entries, each
        holding {"loss": float, "metrics": dict}.
    """
    print("=" * 70)
    print("HALF-LIFE COLLAPSE AND REGULARIZATION DEMONSTRATION")
    print("=" * 70)

    config = HalfLifeRegularizerConfig(
        sequence_length=4096,
        tau_min=1.0,
        tau_max=4096.0,
        lambda1=0.01,
        lambda2=0.01
    )

    regularizer = HalfLifeRegularizer(config)

    # 1. A healthy bank: half-lives spaced log-uniformly over the context.
    print("\n1. INITIAL DISTRIBUTION (Log-Uniform)")
    print("-" * 60)

    n_oscillators = 32
    log_taus_init = np.linspace(np.log(1.0), np.log(4096.0), n_oscillators)
    taus_init = np.exp(log_taus_init)
    # Invert tau = ln(0.5)/ln(lambda): lambda = 0.5^(1/tau).
    lambdas_init = np.power(0.5, 1.0 / taus_init)

    loss_init, metrics_init = regularizer.compute(lambdas_init)
    print(f"  Half-lives: [{metrics_init['tau_min']:.1f}, {metrics_init['tau_max']:.1f}]")
    print(f"  Regularization loss: {loss_init:.6f}")
    print(f"  Long-range oscillators: {metrics_init['n_long_range']}/{n_oscillators}")

    # 2. Collapse: all half-lives shrink to a narrow short-range band.
    print("\n2. COLLAPSED DISTRIBUTION (Training at Scale)")
    print("-" * 60)
    print("  Simulating what happens during GPT-2 scale training...")

    rng = np.random.default_rng(seed)  # seeded for reproducibility
    taus_collapsed = rng.uniform(2, 10, n_oscillators)
    lambdas_collapsed = np.power(0.5, 1.0 / taus_collapsed)

    loss_collapsed, metrics_collapsed = regularizer.compute(lambdas_collapsed)
    print(f"  Half-lives: [{metrics_collapsed['tau_min']:.1f}, {metrics_collapsed['tau_max']:.1f}]")
    # Guard the ratio: loss_init can be arbitrarily close to zero.
    print(f"  Regularization loss: {loss_collapsed:.6f} ({loss_collapsed / max(loss_init, 1e-12):.0f}x initial)")
    print(f"  Long-range oscillators: {metrics_collapsed['n_long_range']}/{n_oscillators}")

    # 3. The regularizer's gradient tells each oscillator which way to move.
    print("\n3. REGULARIZATION GRADIENT ANALYSIS")
    print("-" * 60)

    grad = regularizer.compute_gradient(lambdas_collapsed)

    print("  Gradient direction indicates how to adjust λ_i to reduce loss:")
    print("  (Negative gradient → increase λ → longer half-life)")
    print()

    for i in range(min(5, n_oscillators)):
        tau_i = taus_collapsed[i]
        grad_i = grad[i]
        direction = "→ increase τ" if grad_i < 0 else "→ decrease τ"
        print(f"  Osc {i}: τ={tau_i:.1f}, grad={grad_i:+.4f} {direction}")

    print(f"  ... ({n_oscillators - 5} more)")
    print(f"\n  Mean gradient magnitude: {np.mean(np.abs(grad)):.4f}")

    # 4. One plain gradient-descent step on the regularizer alone.
    print("\n4. AFTER REGULARIZATION STEP")
    print("-" * 60)

    lr = 1.0
    lambdas_reg = lambdas_collapsed - lr * grad
    # Keep lambdas in the valid decay range (0, 1).
    lambdas_reg = np.clip(lambdas_reg, 0.01, 0.9999)

    loss_reg, metrics_reg = regularizer.compute(lambdas_reg)

    print(f"  Half-lives: [{metrics_reg['tau_min']:.1f}, {metrics_reg['tau_max']:.1f}]")
    # Guard the ratio against a (theoretically) zero collapsed loss.
    print(f"  Regularization loss: {loss_reg:.6f} ({loss_reg / max(loss_collapsed, 1e-12):.1%} of collapsed)")
    print(f"  Long-range oscillators: {metrics_reg['n_long_range']}/{n_oscillators}")

    # 5. Side-by-side summary table.
    print("\n5. SUMMARY")
    print("-" * 60)
    print(f"""
  State              | Loss      | τ range         | Long-range
  -------------------|-----------|-----------------|------------
  Initial (good)     | {loss_init:.6f}  | [{metrics_init['tau_min']:.1f}, {metrics_init['tau_max']:.1f}]   | {metrics_init['n_long_range']}/{n_oscillators}
  Collapsed (bad)    | {loss_collapsed:.6f}  | [{metrics_collapsed['tau_min']:.1f}, {metrics_collapsed['tau_max']:.1f}]      | {metrics_collapsed['n_long_range']}/{n_oscillators}
  After 1 reg step   | {loss_reg:.6f}  | [{metrics_reg['tau_min']:.1f}, {metrics_reg['tau_max']:.1f}]   | {metrics_reg['n_long_range']}/{n_oscillators}
""")

    print("=" * 70)
    print("CONCLUSION:")
    print("  The regularizer provides gradients that push collapsed half-lives")
    print("  back toward a log-uniform distribution spanning the full context.")
    print("=" * 70)

    return {
        "initial": {"loss": loss_init, "metrics": metrics_init},
        "collapsed": {"loss": loss_collapsed, "metrics": metrics_collapsed},
        "regularized": {"loss": loss_reg, "metrics": metrics_reg},
    }
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: run the collapse-and-recovery demonstration.
    simulate_collapse_and_recovery()
| |
|