| """ |
| Stage 4: Activation steering via projection decay (Apr 2026 update). |
| |
| NEW SEMANTICS: |
| h_new = h - (1 - alpha) * P · h |
| where P is either: |
| P = ŵ ŵ^T (rank-1 projector, single direction) → "v1_raw" |
| P = Q^T Q (rank-k projector, subspace) → "v_pca_subspace" |
| |
| α represents the "ability level": |
| - alpha = 1: no change (baseline) |
| - alpha = 0: full removal of the cognitive subspace |
| - alpha < 0: over-suppression (rare, prone to collapse) |
| - alpha > 1: amplification (rare) |
| |
| JOINT STEERING (anti-leak): |
| When suppressing one dimension, optionally also softly suppress the other |
| to prevent compensatory activation (e.g. suppressing planning causing |
| monitoring trigger spike). Coupling factor `beta` controls strength. |
| h_new = h - (1-α) * P_target · h - (1-α) * β * P_other · h |
| |
| Hook point: decoder layer output (post-layer residual stream). |
| """ |
| import torch |
| from typing import Dict, List, Optional, Union |
| from configs.model import MODEL_CONFIG, ANTI_LEAK_BETA |
|
|
|
|
| |
| |
| |
# Alpha value meaning "no steering": h_new = h - (1 - 1) * P·h = h.
NEUTRAL_ALPHA = 1.0




def is_neutral_alpha(alpha: Optional[float], eps: float = 1e-5) -> bool:
    """Return True iff `alpha` is numerically the neutral value (1.0).

    Args:
        alpha: steering strength under the new semantics (1 = no change).
            `None` is treated as non-neutral rather than raising, so callers
            can pass through unset config values.
        eps: absolute tolerance for the float comparison.

    Returns:
        True when `alpha` is within `eps` of NEUTRAL_ALPHA, else False.
    """
    if alpha is None:
        return False
    return abs(alpha - NEUTRAL_ALPHA) <= eps
|
|
|
|
| |
| |
| |
| def _make_projector(direction: torch.Tensor, device, dtype): |
| """ |
| Given a direction or subspace basis, return a function |
| proj(h) -> P · h |
| where h is (B, S, D) and the result is (B, S, D). |
| |
| direction shapes: |
| (D,) : rank-1 projector ŵŵ^T |
| (k, D) : rank-k projector Q^T Q |
| """ |
| direction = direction.to(device=device, dtype=dtype) |
| if direction.dim() == 1: |
| |
| n = direction.norm() |
| if n < 1e-8: |
| return None |
| w = (direction / n).to(dtype) |
| def proj(h): |
| scalar = h @ w |
| return scalar.unsqueeze(-1) * w |
| return proj |
| elif direction.dim() == 2: |
| |
| if direction.shape[0] == 0 or direction.shape[1] == 0: |
| return None |
| Q = direction.to(dtype) |
| def proj(h): |
| |
| coords = h @ Q.T |
| return coords @ Q |
| return proj |
| else: |
| return None |
|
|
|
|
| |
| |
| |
class ResidualSteerer:
    """
    Suppress (or amplify) a cognitive direction in the residual stream.

    A forward hook on each target decoder layer rewrites its output as
        h_new = h - (1 - alpha) * P · h
    where P · h = (h · ŵ) ŵ for a single direction, or Q^T Q · h for a
    subspace basis (see _make_projector).
    """
    def __init__(
        self,
        model,
        directions: Dict[int, torch.Tensor],
        alpha: float = NEUTRAL_ALPHA,
    ):
        self.model = model
        self.directions = directions
        self.alpha = alpha
        self.handles = []
        # Place projectors on the same device/dtype as the model weights.
        first_param = next(model.parameters())
        self._device = first_param.device
        self._dtype = first_param.dtype


    def _make_hook(self, layer_id: int):
        # Build the projector once; the hook closure reuses it every call.
        proj = _make_projector(self.directions[layer_id], self._device, self._dtype)
        strength = 1.0 - float(self.alpha)

        # Degenerate direction or neutral alpha: leave the output untouched.
        if proj is None or abs(strength) < 1e-9:
            def noop(module, inputs, output):
                return output
            return noop


        def steer(module, inputs, output):
            # Decoder layers may return (hidden_states, *extras) tuples.
            if isinstance(output, tuple):
                steered = output[0] - strength * proj(output[0])
                return (steered,) + output[1:]
            return output - strength * proj(output)
        return steer


    def start(self):
        """Register one forward hook per steered layer."""
        for layer_id in self.directions:
            handle = self.model.model.layers[layer_id].register_forward_hook(
                self._make_hook(layer_id)
            )
            self.handles.append(handle)


    def stop(self):
        """Remove every registered hook."""
        while self.handles:
            self.handles.pop().remove()
|
|
|
|
| |
| |
| |
class JointResidualSteerer:
    """
    Apply joint steering on TWO dimensions (planning + monitoring) simultaneously.
    Used to prevent compensatory activation when suppressing one dimension.

    Steering equation:
        h_new = h - (1-α_target) * P_target · h
                  - (1-α_target) * β * P_other · h

    Args:
        model: HF model
        target_dirs: {layer_id: direction or basis} - dimension being primarily steered
        other_dirs: {layer_id: direction or basis} - dimension being coupled (anti-leak)
        alpha: steering strength for target (NEW SEMANTICS, 1=no change, 0=full)
        beta: coupling factor for the other dimension (default ANTI_LEAK_BETA=0.3)
    """
    def __init__(
        self,
        model,
        target_dirs: Dict[int, torch.Tensor],
        other_dirs: Dict[int, torch.Tensor],
        alpha: float = NEUTRAL_ALPHA,
        beta: float = ANTI_LEAK_BETA,
    ):
        self.model = model
        self.target_dirs = target_dirs
        self.other_dirs = other_dirs
        self.alpha = alpha
        self.beta = beta
        self.handles = []
        # Place projectors on the same device/dtype as the model weights.
        self._device = next(model.parameters()).device
        self._dtype = next(model.parameters()).dtype


    def _make_hook(self, layer_id: int):
        # BUGFIX: start() registers hooks over the UNION of target/other
        # layers, so a layer may appear only in other_dirs. The target
        # lookup must be guarded exactly like the other_dirs lookup;
        # previously `self.target_dirs[layer_id]` raised KeyError here.
        target_proj = (_make_projector(self.target_dirs[layer_id], self._device, self._dtype)
                       if layer_id in self.target_dirs else None)
        other_proj = (_make_projector(self.other_dirs[layer_id], self._device, self._dtype)
                      if layer_id in self.other_dirs else None)
        scale_target = 1.0 - float(self.alpha)
        # Coupled (anti-leak) strength scales with the target strength.
        scale_other = scale_target * float(self.beta)


        # No usable projector on this layer: return an identity hook.
        if target_proj is None and other_proj is None:
            def noop(module, inputs, output):
                return output
            return noop


        def fn(module, inputs, output):
            # Decoder layers may return (hidden_states, *extras) tuples.
            if isinstance(output, tuple):
                h = output[0]
                rest = output[1:]
            else:
                h = output
                rest = None
            h_new = h
            if target_proj is not None and abs(scale_target) > 1e-9:
                h_new = h_new - scale_target * target_proj(h_new)
            if other_proj is not None and abs(scale_other) > 1e-9:
                h_new = h_new - scale_other * other_proj(h_new)
            if rest is not None:
                return (h_new,) + rest
            return h_new
        return fn


    def start(self):
        """Register hooks on every layer steered by either dimension."""
        all_layers = set(self.target_dirs.keys()) | set(self.other_dirs.keys())
        for li in all_layers:
            layer = self.model.model.layers[li]
            h = layer.register_forward_hook(self._make_hook(li))
            self.handles.append(h)


    def stop(self):
        """Remove every registered hook."""
        for h in self.handles:
            h.remove()
        self.handles = []
|
|
|
|
| |
| |
| |
# Prompting-baseline instructions, keyed by cognitive dimension
# ("planning" / "monitoring"). Appended to the system prompt by
# build_force_prompt; the prompting counterpart to activation steering.

# Suppress: tell the model NOT to exhibit the dimension.
FORCE_SUPPRESS_PROMPTS = {
    "planning": (
        "IMPORTANT: Solve this problem WITHOUT planning, WITHOUT stating strategies, "
        "WITHOUT outlining steps in advance. Just compute directly."
    ),
    "monitoring": (
        "IMPORTANT: Solve this problem without double-checking, without verifying, "
        "without saying 'wait' or 'let me check'. Just produce the answer directly."
    ),
}


# Enhance: tell the model to lean INTO the dimension.
FORCE_ENHANCE_PROMPTS = {
    "planning": (
        "IMPORTANT: Before starting, explicitly state your plan. Break the problem "
        "into clearly labeled steps. Discuss multiple strategies and choose one. "
        "Reference your plan as you execute."
    ),
    "monitoring": (
        "IMPORTANT: After each step, verify your work. Say 'wait, let me check'. "
        "Substitute values back to confirm. Consider alternative interpretations."
    ),
}
|
|
|
|
def build_force_prompt(base_system_prompt: str, dimension: str, mode: str) -> str:
    """Append the prompting-baseline instruction for `dimension` to the system prompt.

    mode "suppress" / "enhance" selects the instruction table; any other mode
    returns the base prompt unchanged. Raises KeyError for an unknown dimension
    (when mode is one of the two known ones), matching the table lookup.
    """
    tables = {
        "suppress": FORCE_SUPPRESS_PROMPTS,
        "enhance": FORCE_ENHANCE_PROMPTS,
    }
    if mode not in tables:
        return base_system_prompt
    return f"{base_system_prompt}\n\n{tables[mode][dimension]}"
|
|