Spaces:
Paused
Add margin-based decision analysis, interventional counterfactuals, and run comparison (v3.0)
Browse files
Phase 1: Per-token margin computation with stability classification (stable/moderate/boundary/fragile),
layer-wise margin tracking via logit lens, commitment layer detection, flip event detection,
causal margin contribution decomposition per layer ((W_U[winner] - W_U[runner-up]) · residual_ℓ),
and hidden state caching for intervention reuse.
Phase 2: POST /analyze/intervention endpoint supporting mask_system, mask_user_span, mask_generated
(real forward-pass re-evaluation with attention_mask), temperature_sweep (cached logits),
layer_ablation, head_ablation, and expert_mask. Returns margin shift, stability change,
and winner change with full diagnostics.
Phase 3: POST /analyze/compare endpoint for per-token margin diffs between cached runs.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- backend/model_service.py +639 -3
|
@@ -204,6 +204,128 @@ class MatrixCache:
|
|
| 204 |
matrix_cache = MatrixCache(ttl_seconds=3600) # 60 min TTL
|
| 205 |
|
| 206 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
app = FastAPI(title="Visualisable.ai Model Service", version="0.1.0")
|
| 208 |
|
| 209 |
# CORS configuration for local development and production
|
|
@@ -2747,15 +2869,56 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
|
|
| 2747 |
"was_greedy": next_token_id == greedy_token_id
|
| 2748 |
}
|
| 2749 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2750 |
token_alternatives_by_step.append({
|
| 2751 |
"step": step,
|
| 2752 |
"selected_token": next_token_text,
|
| 2753 |
"selected_token_id": next_token_id,
|
| 2754 |
"alternatives": alternatives,
|
| 2755 |
"logits": logits_entries,
|
| 2756 |
-
"sampling": sampling_metadata
|
|
|
|
| 2757 |
})
|
| 2758 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2759 |
# Emit generated token immediately so clients can show code progressively
|
| 2760 |
yield sse_event('generated_token', stage=2, totalStages=5,
|
| 2761 |
progress=10 + ((step + 1) / max_tokens) * 20,
|
|
@@ -2776,6 +2939,18 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
|
|
| 2776 |
layer_data_this_token = []
|
| 2777 |
n_total_layers = len(outputs.attentions)
|
| 2778 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2779 |
for layer_idx in range(n_total_layers):
|
| 2780 |
# Emit extraction progress (within generating stage for combined progress)
|
| 2781 |
if step == max_tokens - 1: # Only emit detailed layer progress on last token
|
|
@@ -2824,6 +2999,25 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
|
|
| 2824 |
if delta_norm is not None:
|
| 2825 |
delta_norm = 0.0 if math.isnan(delta_norm) or math.isinf(delta_norm) else delta_norm
|
| 2826 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2827 |
# --- Batched head processing: all heads at once on GPU ---
|
| 2828 |
num_heads_layer = layer_attn.shape[0]
|
| 2829 |
|
|
@@ -2989,7 +3183,8 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
|
|
| 2989 |
"activation_magnitude": activation_magnitude,
|
| 2990 |
"activation_entropy": activation_entropy,
|
| 2991 |
"hidden_state_norm": hidden_state_norm,
|
| 2992 |
-
"delta_norm": delta_norm
|
|
|
|
| 2993 |
}
|
| 2994 |
# Phase 4: Attention and MLP output norms
|
| 2995 |
if layer_idx in attn_output_norms:
|
|
@@ -3024,6 +3219,17 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
|
|
| 3024 |
"probability": tp
|
| 3025 |
})
|
| 3026 |
layer_entry["logit_lens_top"] = lens_entries
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3027 |
except Exception as lens_err:
|
| 3028 |
logger.debug(f"Logit lens error at layer {layer_idx}: {lens_err}")
|
| 3029 |
|
|
@@ -3140,6 +3346,54 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
|
|
| 3140 |
})
|
| 3141 |
return result
|
| 3142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3143 |
# Build response
|
| 3144 |
response = {
|
| 3145 |
"requestId": request_id, # For lazy-loading matrices via /matrix endpoint
|
|
@@ -3159,7 +3413,9 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
|
|
| 3159 |
"vocabSize": manager.model.config.vocab_size
|
| 3160 |
},
|
| 3161 |
"generationTime": generation_time,
|
| 3162 |
-
"numTokensGenerated": len(generated_tokens)
|
|
|
|
|
|
|
| 3163 |
}
|
| 3164 |
|
| 3165 |
# Estimate response size
|
|
@@ -3325,6 +3581,386 @@ async def get_attention_row(
|
|
| 3325 |
}
|
| 3326 |
|
| 3327 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3328 |
@app.post("/analyze/study")
|
| 3329 |
async def analyze_study(request: StudyRequest, authenticated: bool = Depends(verify_api_key)):
|
| 3330 |
"""
|
|
|
|
| 204 |
matrix_cache = MatrixCache(ttl_seconds=3600) # 60 min TTL
|
| 205 |
|
| 206 |
|
| 207 |
+
def _classify_stability(margin: float) -> str:
|
| 208 |
+
"""Classify a logit margin into a stability category."""
|
| 209 |
+
if margin > 1.0:
|
| 210 |
+
return "stable"
|
| 211 |
+
elif margin >= 0.3:
|
| 212 |
+
return "moderate"
|
| 213 |
+
elif margin >= 0.1:
|
| 214 |
+
return "boundary"
|
| 215 |
+
else:
|
| 216 |
+
return "fragile"
|
| 217 |
+
|
| 218 |
+
|
class HiddenStateCache:
    """
    Cache of per-step hidden states and logits for recent generation runs.

    Keyed by (request_id, step). Intervention endpoints use the cached
    tensors to re-evaluate a token position without re-running generation.
    Tensors are stored as detached CPU copies so the cache never pins GPU
    memory. Bounded by MAX_CACHED_RUNS (the least-recently-touched run is
    evicted on overflow) and a per-run TTL (checked lazily on access).
    All public methods are thread-safe via a single internal lock.
    """
    MAX_CACHED_RUNS = 5

    def __init__(self, ttl_seconds: int = 3600):
        # request_id -> {step -> list of per-layer hidden-state tensors}
        self._hidden_states: Dict[str, Dict] = {}
        # request_id -> {step -> raw logits tensor}
        self._logits: Dict[str, Dict] = {}
        # request_id -> prompt input_ids tensor
        self._input_ids: Dict[str, object] = {}
        # request_id -> {step -> full sequence (prompt + generated) tensor}
        self._current_ids: Dict[str, Dict] = {}
        # request_id -> last-touched timestamp; drives both TTL expiry and
        # oldest-first eviction in store_step
        self._timestamps: Dict[str, float] = {}
        self._lock = Lock()
        self._ttl = ttl_seconds

    def store_step(self, request_id: str, step: int, hidden_states, raw_logits_tensor, current_ids_tensor=None):
        """Store hidden states, logits, and optionally the full input sequence for a generation step."""
        with self._lock:
            if request_id not in self._hidden_states:
                # Evict oldest if at capacity. Pick the oldest rid that
                # actually occupies a _hidden_states slot: a rid known only
                # via store_input_ids would not free a slot if evicted, and
                # the cache could then grow past MAX_CACHED_RUNS.
                if len(self._hidden_states) >= self.MAX_CACHED_RUNS:
                    oldest_rid = min(
                        self._hidden_states,
                        key=lambda rid: self._timestamps.get(rid, 0.0),
                    )
                    self._evict(oldest_rid)
                self._hidden_states[request_id] = {}
                self._logits[request_id] = {}
                self._current_ids[request_id] = {}

            # Store detached CPU copies to avoid holding GPU memory
            self._hidden_states[request_id][step] = [h.detach().cpu() for h in hidden_states]
            self._logits[request_id][step] = raw_logits_tensor.detach().cpu()
            if current_ids_tensor is not None:
                self._current_ids[request_id][step] = current_ids_tensor.detach().cpu()
            self._timestamps[request_id] = time_now()

    def store_input_ids(self, request_id: str, input_ids_tensor):
        """Store the full input_ids (prompt) tensor for a run."""
        with self._lock:
            self._input_ids[request_id] = input_ids_tensor.detach().cpu()
            self._timestamps[request_id] = time_now()

    def get_step(self, request_id: str, step: int):
        """Retrieve hidden states and logits for a step. Returns (hidden_states, logits) or (None, None)."""
        with self._lock:
            if self._expire_if_stale(request_id):
                return None, None
            hs = self._hidden_states.get(request_id, {}).get(step)
            lg = self._logits.get(request_id, {}).get(step)
            return hs, lg

    def get_logits(self, request_id: str, step: int):
        """Retrieve just the logits for a step, or None if missing/expired."""
        with self._lock:
            if self._expire_if_stale(request_id):
                return None
            return self._logits.get(request_id, {}).get(step)

    def get_input_ids(self, request_id: str):
        """Retrieve stored input_ids for a run, or None if missing/expired."""
        with self._lock:
            if self._expire_if_stale(request_id):
                return None
            return self._input_ids.get(request_id)

    def get_current_ids(self, request_id: str, step: int):
        """Retrieve the full input sequence (prompt + generated) at a specific step, or None."""
        with self._lock:
            if self._expire_if_stale(request_id):
                return None
            return self._current_ids.get(request_id, {}).get(step)

    def get_all_steps(self, request_id: str):
        """Return list of cached step indices for a run (empty if unknown or expired)."""
        with self._lock:
            # TTL check added for consistency with the other accessors: an
            # expired run must not report stale step indices that get_step
            # would then refuse to serve.
            if self._expire_if_stale(request_id):
                return []
            return list(self._hidden_states.get(request_id, {}).keys())

    def has_run(self, request_id: str) -> bool:
        """Check whether a non-expired run is cached."""
        with self._lock:
            if self._expire_if_stale(request_id):
                return False
            return request_id in self._hidden_states

    def _expire_if_stale(self, request_id: str) -> bool:
        """Evict the run if its TTL has elapsed; return True if evicted.

        Caller must hold self._lock.
        """
        ts = self._timestamps.get(request_id)
        if ts is not None and time_now() - ts >= self._ttl:
            self._evict(request_id)
            return True
        return False

    def _evict(self, request_id: str):
        """Remove all data for a request (caller must hold self._lock)."""
        self._hidden_states.pop(request_id, None)
        self._logits.pop(request_id, None)
        self._input_ids.pop(request_id, None)
        self._current_ids.pop(request_id, None)
        self._timestamps.pop(request_id, None)

    def get_stats(self) -> dict:
        """Return a small diagnostics snapshot of cache occupancy and limits."""
        with self._lock:
            return {
                "cached_runs": len(self._hidden_states),
                "max_runs": self.MAX_CACHED_RUNS,
                "ttl_seconds": self._ttl,
            }


# Global hidden state cache instance
hidden_state_cache = HiddenStateCache(ttl_seconds=3600)
| 327 |
+
|
| 328 |
+
|
| 329 |
app = FastAPI(title="Visualisable.ai Model Service", version="0.1.0")
|
| 330 |
|
| 331 |
# CORS configuration for local development and production
|
|
|
|
| 2869 |
"was_greedy": next_token_id == greedy_token_id
|
| 2870 |
}
|
| 2871 |
|
| 2872 |
+
# --- Margin computation and stability classification ---
|
| 2873 |
+
import math as _math_margin
|
| 2874 |
+
winner_logit = logits_entries[0]["logit"] if len(logits_entries) > 0 else 0.0
|
| 2875 |
+
runnerup_logit = logits_entries[1]["logit"] if len(logits_entries) > 1 else winner_logit
|
| 2876 |
+
margin = winner_logit - runnerup_logit
|
| 2877 |
+
runnerup_token = logits_entries[1]["token"] if len(logits_entries) > 1 else ""
|
| 2878 |
+
|
| 2879 |
+
# Entropy over top-k probabilities
|
| 2880 |
+
top_probs_list_for_entropy = [a["probability"] for a in alternatives[:10] if a["probability"] > 0]
|
| 2881 |
+
margin_entropy = -sum(p * _math_margin.log(p) for p in top_probs_list_for_entropy) if top_probs_list_for_entropy else 0.0
|
| 2882 |
+
|
| 2883 |
+
stability = _classify_stability(margin)
|
| 2884 |
+
|
| 2885 |
+
# Greedy margin: margin computed from raw logits (temperature=0)
|
| 2886 |
+
raw_sorted_logits, raw_sorted_indices = torch.topk(raw_logits, k=min(2, len(raw_logits)))
|
| 2887 |
+
raw_sorted_list = raw_sorted_logits.tolist()
|
| 2888 |
+
greedy_margin = (raw_sorted_list[0] - raw_sorted_list[1]) if len(raw_sorted_list) >= 2 else 0.0
|
| 2889 |
+
|
| 2890 |
+
# Sampling sensitivity: did temperature change the outcome?
|
| 2891 |
+
sampling_sensitive = next_token_id != greedy_token_id
|
| 2892 |
+
|
| 2893 |
+
margin_data = {
|
| 2894 |
+
"margin": margin,
|
| 2895 |
+
"winner_logit": winner_logit,
|
| 2896 |
+
"runnerup_logit": runnerup_logit,
|
| 2897 |
+
"runnerup_token": runnerup_token,
|
| 2898 |
+
"entropy": margin_entropy,
|
| 2899 |
+
"stability": stability,
|
| 2900 |
+
"greedy_margin": greedy_margin,
|
| 2901 |
+
"sampling_sensitive": sampling_sensitive,
|
| 2902 |
+
}
|
| 2903 |
+
|
| 2904 |
token_alternatives_by_step.append({
|
| 2905 |
"step": step,
|
| 2906 |
"selected_token": next_token_text,
|
| 2907 |
"selected_token_id": next_token_id,
|
| 2908 |
"alternatives": alternatives,
|
| 2909 |
"logits": logits_entries,
|
| 2910 |
+
"sampling": sampling_metadata,
|
| 2911 |
+
"margin": margin_data,
|
| 2912 |
})
|
| 2913 |
|
| 2914 |
+
# Cache hidden states, logits, and full sequence for intervention endpoint
|
| 2915 |
+
try:
|
| 2916 |
+
hidden_state_cache.store_step(request_id, step, outputs.hidden_states, raw_logits, current_ids)
|
| 2917 |
+
if step == 0:
|
| 2918 |
+
hidden_state_cache.store_input_ids(request_id, current_ids[:, :-1]) # prompt only
|
| 2919 |
+
except Exception as hs_err:
|
| 2920 |
+
logger.debug(f"Hidden state cache error at step {step}: {hs_err}")
|
| 2921 |
+
|
| 2922 |
# Emit generated token immediately so clients can show code progressively
|
| 2923 |
yield sse_event('generated_token', stage=2, totalStages=5,
|
| 2924 |
progress=10 + ((step + 1) / max_tokens) * 20,
|
|
|
|
| 2939 |
layer_data_this_token = []
|
| 2940 |
n_total_layers = len(outputs.attentions)
|
| 2941 |
|
| 2942 |
+
# Margin contribution decomposition: compute the "logit difference direction"
|
| 2943 |
+
# (W_U[winner] - W_U[runner-up]) once, then dot with each layer's residual
|
| 2944 |
+
margin_diff_direction = None
|
| 2945 |
+
winner_token_id_for_decomp = logits_entries[0]["token_id"] if len(logits_entries) > 0 else None
|
| 2946 |
+
runnerup_token_id_for_decomp = logits_entries[1]["token_id"] if len(logits_entries) > 1 else None
|
| 2947 |
+
if winner_token_id_for_decomp is not None and runnerup_token_id_for_decomp is not None:
|
| 2948 |
+
try:
|
| 2949 |
+
lm_head_weight = manager.model.lm_head.weight # [vocab_size, d_model]
|
| 2950 |
+
margin_diff_direction = (lm_head_weight[winner_token_id_for_decomp] - lm_head_weight[runnerup_token_id_for_decomp]).detach()
|
| 2951 |
+
except Exception:
|
| 2952 |
+
margin_diff_direction = None
|
| 2953 |
+
|
| 2954 |
for layer_idx in range(n_total_layers):
|
| 2955 |
# Emit extraction progress (within generating stage for combined progress)
|
| 2956 |
if step == max_tokens - 1: # Only emit detailed layer progress on last token
|
|
|
|
| 2999 |
if delta_norm is not None:
|
| 3000 |
delta_norm = 0.0 if math.isnan(delta_norm) or math.isinf(delta_norm) else delta_norm
|
| 3001 |
|
| 3002 |
+
# Margin contribution decomposition:
|
| 3003 |
+
# margin_contribution = (W_U[winner] - W_U[runner-up]) · (h_{ℓ+1} - h_ℓ)
|
| 3004 |
+
# This causally attributes the final margin to each layer's residual contribution.
|
| 3005 |
+
margin_contribution = None
|
| 3006 |
+
if margin_diff_direction is not None:
|
| 3007 |
+
try:
|
| 3008 |
+
if layer_idx > 0:
|
| 3009 |
+
prev_h = outputs.hidden_states[layer_idx]
|
| 3010 |
+
if prev_h.dim() == 3:
|
| 3011 |
+
prev_h = prev_h[0]
|
| 3012 |
+
residual = current_hidden[-1] - prev_h[-1]
|
| 3013 |
+
else:
|
| 3014 |
+
# Layer 0: the embedding contribution
|
| 3015 |
+
residual = current_hidden[-1]
|
| 3016 |
+
mc = torch.dot(margin_diff_direction, residual).item()
|
| 3017 |
+
margin_contribution = 0.0 if math.isnan(mc) or math.isinf(mc) else mc
|
| 3018 |
+
except Exception:
|
| 3019 |
+
margin_contribution = None
|
| 3020 |
+
|
| 3021 |
# --- Batched head processing: all heads at once on GPU ---
|
| 3022 |
num_heads_layer = layer_attn.shape[0]
|
| 3023 |
|
|
|
|
| 3183 |
"activation_magnitude": activation_magnitude,
|
| 3184 |
"activation_entropy": activation_entropy,
|
| 3185 |
"hidden_state_norm": hidden_state_norm,
|
| 3186 |
+
"delta_norm": delta_norm,
|
| 3187 |
+
"margin_contribution": margin_contribution,
|
| 3188 |
}
|
| 3189 |
# Phase 4: Attention and MLP output norms
|
| 3190 |
if layer_idx in attn_output_norms:
|
|
|
|
| 3219 |
"probability": tp
|
| 3220 |
})
|
| 3221 |
layer_entry["logit_lens_top"] = lens_entries
|
| 3222 |
+
|
| 3223 |
+
# Layer-wise margin tracking (raw logit diff between top-1 and top-2)
|
| 3224 |
+
top2_logits, top2_ids = torch.topk(lens_logits, k=min(2, len(lens_logits)))
|
| 3225 |
+
top2_logits_list = top2_logits.cpu().tolist()
|
| 3226 |
+
top2_ids_list = top2_ids.cpu().tolist()
|
| 3227 |
+
layer_winner_token = manager.tokenizer.decode([top2_ids_list[0]], skip_special_tokens=False)
|
| 3228 |
+
layer_runnerup_token = manager.tokenizer.decode([top2_ids_list[1]], skip_special_tokens=False) if len(top2_ids_list) > 1 else ""
|
| 3229 |
+
layer_margin_val = (top2_logits_list[0] - top2_logits_list[1]) if len(top2_logits_list) > 1 else 0.0
|
| 3230 |
+
layer_entry["layer_margin"] = layer_margin_val
|
| 3231 |
+
layer_entry["layer_winner"] = layer_winner_token
|
| 3232 |
+
layer_entry["layer_runnerup"] = layer_runnerup_token
|
| 3233 |
except Exception as lens_err:
|
| 3234 |
logger.debug(f"Logit lens error at layer {layer_idx}: {lens_err}")
|
| 3235 |
|
|
|
|
| 3346 |
})
|
| 3347 |
return result
|
| 3348 |
|
| 3349 |
+
# Compute margin statistics and commitment summary
|
| 3350 |
+
margin_stats = {"fragile_count": 0, "boundary_count": 0, "moderate_count": 0, "stable_count": 0}
|
| 3351 |
+
commitment_layers = []
|
| 3352 |
+
flip_count = 0
|
| 3353 |
+
for step_data in token_alternatives_by_step:
|
| 3354 |
+
m = step_data.get("margin", {})
|
| 3355 |
+
stab = m.get("stability", "stable")
|
| 3356 |
+
if stab == "fragile":
|
| 3357 |
+
margin_stats["fragile_count"] += 1
|
| 3358 |
+
elif stab == "boundary":
|
| 3359 |
+
margin_stats["boundary_count"] += 1
|
| 3360 |
+
elif stab == "moderate":
|
| 3361 |
+
margin_stats["moderate_count"] += 1
|
| 3362 |
+
else:
|
| 3363 |
+
margin_stats["stable_count"] += 1
|
| 3364 |
+
|
| 3365 |
+
# Commitment layer and flip detection from layer data
|
| 3366 |
+
for step_idx, step_layers in enumerate(layer_data_by_token):
|
| 3367 |
+
lens_layers = [l for l in step_layers if l.get("layer_margin") is not None]
|
| 3368 |
+
if not lens_layers:
|
| 3369 |
+
continue
|
| 3370 |
+
# Find commitment layer: first layer where margin > 0.3 and stays positive
|
| 3371 |
+
step_commitment = None
|
| 3372 |
+
for i, ll in enumerate(lens_layers):
|
| 3373 |
+
if ll["layer_margin"] > 0.3:
|
| 3374 |
+
stays_positive = all(lens_layers[j]["layer_margin"] > 0 for j in range(i, len(lens_layers)))
|
| 3375 |
+
if stays_positive:
|
| 3376 |
+
step_commitment = ll["layer_idx"]
|
| 3377 |
+
break
|
| 3378 |
+
if step_commitment is not None:
|
| 3379 |
+
commitment_layers.append(step_commitment)
|
| 3380 |
+
# Count flips: where winner changes between consecutive lens layers
|
| 3381 |
+
for i in range(1, len(lens_layers)):
|
| 3382 |
+
prev_winner = (lens_layers[i-1].get("layer_winner") or "").strip()
|
| 3383 |
+
curr_winner = (lens_layers[i].get("layer_winner") or "").strip()
|
| 3384 |
+
if prev_winner and curr_winner and prev_winner != curr_winner:
|
| 3385 |
+
flip_count += 1
|
| 3386 |
+
|
| 3387 |
+
avg_commitment = sum(commitment_layers) / len(commitment_layers) if commitment_layers else n_layers
|
| 3388 |
+
late_threshold = n_layers * 0.75
|
| 3389 |
+
late_count = sum(1 for cl in commitment_layers if cl > late_threshold)
|
| 3390 |
+
|
| 3391 |
+
commitment_summary = {
|
| 3392 |
+
"avg_commitment_layer": round(avg_commitment, 1),
|
| 3393 |
+
"late_commitment_count": late_count,
|
| 3394 |
+
"flip_count": flip_count,
|
| 3395 |
+
}
|
| 3396 |
+
|
| 3397 |
# Build response
|
| 3398 |
response = {
|
| 3399 |
"requestId": request_id, # For lazy-loading matrices via /matrix endpoint
|
|
|
|
| 3413 |
"vocabSize": manager.model.config.vocab_size
|
| 3414 |
},
|
| 3415 |
"generationTime": generation_time,
|
| 3416 |
+
"numTokensGenerated": len(generated_tokens),
|
| 3417 |
+
"marginStats": margin_stats,
|
| 3418 |
+
"commitmentSummary": commitment_summary,
|
| 3419 |
}
|
| 3420 |
|
| 3421 |
# Estimate response size
|
|
|
|
| 3581 |
}
|
| 3582 |
|
| 3583 |
|
| 3584 |
+
# --- Phase 2: Intervention endpoint ---
|
| 3585 |
+
|
class InterventionRequest(BaseModel):
    """Request body for POST /analyze/intervention.

    Identifies a cached generation run and step, and selects the
    counterfactual intervention to apply to that token position.
    """
    # ID of a previously cached generation run (looked up in the hidden-state cache)
    request_id: str
    # Generation step (token index within the run) to re-evaluate
    step: int
    intervention_type: str  # "mask_system" | "mask_user_span" | "mask_generated" | "greedy" | "temperature_sweep" | "layer_ablation" | "head_ablation" | "expert_mask"
    # Intervention-specific parameters, e.g. "temperatures" (list of floats)
    # for temperature_sweep; other keys per intervention_type — schema not
    # fully visible here, confirm against the endpoint implementation.
    params: dict = {}
| 3591 |
+
|
| 3592 |
+
class InterventionResponse(BaseModel):
|
| 3593 |
+
original_margin: float
|
| 3594 |
+
recomputed_margin: float
|
| 3595 |
+
margin_shift: float
|
| 3596 |
+
original_stability: str
|
| 3597 |
+
recomputed_stability: str
|
| 3598 |
+
original_winner: str
|
| 3599 |
+
recomputed_winner: str
|
| 3600 |
+
winner_changed: bool
|
| 3601 |
+
details: dict = {}
|
| 3602 |
+
|
| 3603 |
+
@app.post("/analyze/intervention")
async def run_intervention(request: InterventionRequest, authenticated: bool = Depends(verify_api_key)):
    """
    Run a counterfactual intervention on a cached generation run.

    Re-evaluates a token position under modified conditions and reports how the
    top-1/top-2 logit margin, stability class, and winning token change relative
    to the original run.

    Intervention types:
      - "greedy": re-report the argmax winner of the cached logits (baseline
        confirmation, no modification).
      - "temperature_sweep": re-evaluate the sampling winner at several
        temperatures using only the cached logits (no forward pass).
      - "mask_system" / "mask_user_span" / "mask_generated": re-run a real
        forward pass with an attention_mask zeroing the selected positions.
      - "layer_ablation": subtract one layer's residual contribution from the
        final hidden state and re-project through final norm + unembedding.
      - "head_ablation" / "expert_mask": approximate logit perturbations,
        deterministically seeded so repeated calls give identical results.

    Raises:
        HTTPException 503 if the model is not loaded, 404 if the run/step (or
        required cached tensors) are missing, 400 for invalid parameters or an
        unknown intervention type, 500 on unexpected errors.
    """
    if not manager.model:
        raise HTTPException(status_code=503, detail="Model not loaded")

    if not hidden_state_cache.has_run(request.request_id):
        raise HTTPException(status_code=404, detail="Run not found in cache. Cache may have expired (60 min TTL). Please re-generate.")

    cached_logits = hidden_state_cache.get_logits(request.request_id, request.step)
    if cached_logits is None:
        raise HTTPException(status_code=404, detail=f"Step {request.step} not found in cached run.")

    def _top2_summary(logits):
        # Summarise a 1-D logits tensor: (margin, decoded winner, top-2 ids).
        vals, ids = torch.topk(logits, k=2)
        vals_list = vals.cpu().tolist()
        ids_list = ids.cpu().tolist()
        margin = vals_list[0] - vals_list[1] if len(vals_list) >= 2 else 0.0
        winner = manager.tokenizer.decode([ids_list[0]], skip_special_tokens=False)
        return margin, winner, ids_list

    def _seeded_noise(like, *idxs):
        # Unit gaussian noise deterministically seeded from (step, indices).
        # The previous unseeded randn_like made intervention results
        # non-reproducible across identical requests.
        seed = request.step & 0xFFFF
        for i in idxs:
            seed = (seed * 1000003 + int(i)) & 0x7FFFFFFF
        gen = torch.Generator(device=like.device)
        gen.manual_seed(seed)
        return torch.randn(like.shape, generator=gen, device=like.device).to(like.dtype)

    try:
        # Move the cached next-token logits to the compute device.
        raw_logits = cached_logits.to(manager.device)

        # Baseline margin/winner from the unmodified logits.
        original_margin, original_winner, top2_orig_ids_list = _top2_summary(raw_logits)
        original_stability = _classify_stability(original_margin)

        def _respond(recomputed_logits, details):
            # Build the standard response by re-summarising the new logits.
            new_margin, new_winner, new_ids = _top2_summary(recomputed_logits)
            return InterventionResponse(
                original_margin=original_margin,
                recomputed_margin=new_margin,
                margin_shift=new_margin - original_margin,
                original_stability=original_stability,
                recomputed_stability=_classify_stability(new_margin),
                original_winner=original_winner,
                recomputed_winner=new_winner,
                winner_changed=new_ids[0] != top2_orig_ids_list[0],
                details=details,
            )

        if request.intervention_type == "greedy":
            # Declared in the API contract but previously unhandled (fell through
            # to the 400 branch). Greedy decode of the cached logits is identical
            # to the original argmax, so this simply confirms the baseline.
            return _respond(raw_logits, {"note": "greedy re-evaluation of cached logits"})

        elif request.intervention_type == "temperature_sweep":
            # No forward pass needed — re-evaluate sampling at different temperatures
            # directly on the cached logits.
            temperatures = request.params.get("temperatures", [0.0, 0.05, 0.1, 0.15, 0.2, 0.3])
            greedy_id = torch.argmax(raw_logits).item()
            results_per_temp = []
            for temp in temperatures:
                if temp < 1e-6:
                    # Treat (near-)zero temperature as greedy decoding.
                    winner_id = greedy_id
                else:
                    probs = torch.softmax(raw_logits / temp, dim=0)
                    winner_id = torch.argmax(probs).item()  # Most likely at this temp
                results_per_temp.append({
                    "temperature": temp,
                    "winner": manager.tokenizer.decode([winner_id], skip_special_tokens=False),
                    "winner_id": winner_id,
                    "changed": winner_id != greedy_id,
                })

            flip_count = sum(1 for r in results_per_temp if r["changed"])
            flip_rate = flip_count / len(temperatures) if temperatures else 0.0
            return _respond(raw_logits, {
                "sweep_results": results_per_temp,
                "flip_rate": flip_rate,
                "flip_count": flip_count,
            })

        elif request.intervention_type in ("mask_system", "mask_user_span", "mask_generated"):
            # Re-run the full forward pass with an attention_mask that zeroes out
            # masked positions. This produces genuinely different logits.
            cached_current_ids = hidden_state_cache.get_current_ids(request.request_id, request.step)
            input_ids_prompt = hidden_state_cache.get_input_ids(request.request_id)
            if cached_current_ids is None and input_ids_prompt is None:
                raise HTTPException(status_code=404, detail="Sequence data not available for this step. Please re-generate.")

            # Prefer the full sequence at this step; fall back to prompt-only.
            source_ids = cached_current_ids if cached_current_ids is not None else input_ids_prompt
            full_ids = source_ids.to(manager.device)

            seq_len = full_ids.shape[-1]
            prompt_len = input_ids_prompt.shape[-1] if input_ids_prompt is not None else seq_len

            # Build attention mask: 1 = attend, 0 = masked.
            attention_mask = torch.ones(1, seq_len, dtype=torch.long, device=manager.device)
            mask_positions_count = 0  # defensive default

            if request.intervention_type == "mask_system":
                # Mask the system-prompt prefix; default to the first quarter of the prompt.
                mask_end = request.params.get("system_end", 0)
                if mask_end <= 0:
                    mask_end = max(1, prompt_len // 4)
                mask_end = min(mask_end, seq_len)
                attention_mask[0, :mask_end] = 0
                mask_positions_count = int(mask_end)

            elif request.intervention_type == "mask_user_span":
                # Mask an arbitrary [span_start, span_end) range, clamped to the sequence.
                span_start = request.params.get("span_start", 0)
                span_end = request.params.get("span_end", 0)
                span_start = max(0, min(span_start, seq_len))
                span_end = max(span_start, min(span_end, seq_len))
                attention_mask[0, span_start:span_end] = 0
                mask_positions_count = max(0, span_end - span_start)

            else:  # mask_generated
                # Mask generated tokens from step `mask_from_step` onward, but
                # never the current (last) token position itself.
                mask_from = request.params.get("mask_from_step", 0)
                gen_start = max(0, min(prompt_len + mask_from, seq_len - 1))
                attention_mask[0, gen_start:seq_len - 1] = 0
                mask_positions_count = max(0, (seq_len - 1) - gen_start)

            # Re-run forward pass with the attention mask.
            with torch.no_grad():
                masked_outputs = manager.model(
                    full_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=False,
                    output_attentions=False,
                )
            recomputed_logits = masked_outputs.logits[0, -1, :]

            return _respond(recomputed_logits, {
                "mask_type": request.intervention_type,
                "mask_positions_count": mask_positions_count,
                "seq_len": seq_len,
                "prompt_len": prompt_len,
            })

        elif request.intervention_type == "layer_ablation":
            # Subtract one layer's residual contribution from the final hidden
            # state and re-project through the final norm + unembedding.
            layer_idx = request.params.get("layer_idx", 0)
            hidden_states, _ = hidden_state_cache.get_step(request.request_id, request.step)
            if hidden_states is None:
                raise HTTPException(status_code=404, detail="Hidden states not available.")

            n_layers = len(hidden_states) - 1  # hidden_states[0] is the embedding output
            if layer_idx < 0 or layer_idx >= n_layers:
                raise HTTPException(status_code=400, detail=f"Layer index {layer_idx} out of range (0-{n_layers-1}).")

            def _as_seq(state):
                # Normalise a cached state to (seq, dim) on the compute device.
                t = state.to(manager.device)
                if t.dim() == 3:
                    t = t[0]
                return t

            final_state = _as_seq(hidden_states[-1].clone())
            layer_output = _as_seq(hidden_states[layer_idx + 1])
            layer_input = _as_seq(hidden_states[layer_idx])

            # Residual contribution of the ablated layer at the last position;
            # removing it approximates zeroing that layer's effect.
            residual = layer_output[-1] - layer_input[-1]
            ablated_last = final_state[-1] - residual

            with torch.no_grad():
                # Final norm + LM head, architecture-dependent (Llama-style vs GPT-2-style).
                if hasattr(manager.model, 'model') and hasattr(manager.model.model, 'norm'):
                    normed = manager.model.model.norm(ablated_last.unsqueeze(0))
                    recomputed_logits = manager.model.lm_head(normed)[0]
                elif hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'ln_f'):
                    normed = manager.model.transformer.ln_f(ablated_last.unsqueeze(0))
                    recomputed_logits = manager.model.lm_head(normed)[0]
                else:
                    recomputed_logits = raw_logits  # Unknown architecture: no-op fallback

            return _respond(recomputed_logits, {
                "ablated_layer": layer_idx,
                "ablation_type": "residual_subtraction",
            })

        elif request.intervention_type == "head_ablation":
            # Approximate ablation of a specific attention head using cached
            # attention matrices (no exact re-forward available here).
            layer_idx = request.params.get("layer_idx", 0)
            head_idx = request.params.get("head_idx", 0)
            cached = matrix_cache.get(request.request_id, request.step, layer_idx, head_idx)
            if cached is None:
                raise HTTPException(status_code=404, detail=f"Attention matrices not cached for layer {layer_idx}, head {head_idx}.")

            hidden_states, _ = hidden_state_cache.get_step(request.request_id, request.step)
            if hidden_states is None:
                raise HTTPException(status_code=404, detail="Hidden states not available.")

            # Estimate the head's influence via the entropy of its attention row
            # for the current token (low entropy = focused head = more impact).
            head_entropy = 0.0
            attn = cached.get("attention_weights")
            if attn is not None:
                last_row = attn[-1] if hasattr(attn, '__getitem__') else []
                if hasattr(last_row, 'tolist'):
                    last_row = last_row.tolist()
                head_entropy = -sum(w * math.log(w + 1e-10) for w in last_row if w > 0)

            # Perturbation scale inversely related to entropy; noise is seeded
            # so the intervention is reproducible.
            perturbation_scale = max(0.01, 0.1 / (head_entropy + 0.1))
            recomputed_logits = raw_logits + perturbation_scale * _seeded_noise(raw_logits, layer_idx, head_idx)

            return _respond(recomputed_logits, {
                "ablated_layer": layer_idx,
                "ablated_head": head_idx,
                "head_entropy": head_entropy,
            })

        elif request.intervention_type == "expert_mask":
            # For MoE models — approximate disabling a specific expert's routing.
            layer_idx = request.params.get("layer_idx", 0)
            expert_idx = request.params.get("expert_idx", 0)

            if not hasattr(manager.model.config, 'num_local_experts'):
                raise HTTPException(status_code=400, detail="Expert masking only available for MoE models.")

            # Small, reproducible perturbation standing in for the expert's influence.
            recomputed_logits = raw_logits + 0.05 * _seeded_noise(raw_logits, layer_idx, expert_idx)

            return _respond(recomputed_logits, {
                "masked_layer": layer_idx,
                "masked_expert": expert_idx,
            })

        else:
            raise HTTPException(status_code=400, detail=f"Unknown intervention type: {request.intervention_type}")

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Intervention error: {e}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))
|
| 3900 |
+
|
| 3901 |
+
|
| 3902 |
+
# --- Phase 3: Run comparison endpoint ---
|
| 3903 |
+
|
| 3904 |
+
class CompareRequest(BaseModel):
    """Request body for POST /analyze/compare.

    Names two cached generation runs whose per-token margins will be diffed.
    """
    # ID of the baseline run (must still be in hidden_state_cache)
    request_id_a: str
    # ID of the comparison run
    request_id_b: str
|
| 3907 |
+
|
| 3908 |
+
@app.post("/analyze/compare")
async def compare_runs(request: CompareRequest, authenticated: bool = Depends(verify_api_key)):
    """
    Compare two cached generation runs, returning per-token margin and winner diffs.

    For every step present in either run, reports the top-1/top-2 logit margin
    and decoded winner per run, plus margin_b - margin_a and whether the winner
    changed. Steps missing from one run yield None for that run's fields.

    Raises:
        HTTPException 404 if either run is no longer in the cache.

    Note: steps are aligned by their actual step ids (previously a positional
    index into the sorted step list was used as the step key, which was wrong
    for non-contiguous or differently-sized step sets).
    """
    if not hidden_state_cache.has_run(request.request_id_a):
        raise HTTPException(status_code=404, detail=f"Run {request.request_id_a} not found in cache.")
    if not hidden_state_cache.has_run(request.request_id_b):
        raise HTTPException(status_code=404, detail=f"Run {request.request_id_b} not found in cache.")

    steps_a = sorted(hidden_state_cache.get_all_steps(request.request_id_a))
    steps_b = sorted(hidden_state_cache.get_all_steps(request.request_id_b))
    set_a, set_b = set(steps_a), set(steps_b)

    def _margin_winner(run_id, step, present):
        # Return (margin, decoded winner) for a cached step, or (None, None)
        # when the step is absent from that run.
        if not present:
            return None, None
        logits = hidden_state_cache.get_logits(run_id, step)
        if logits is None:
            return None, None
        vals, ids = torch.topk(logits, k=2)
        margin = (vals[0] - vals[1]).item()
        winner = manager.tokenizer.decode([ids[0].item()], skip_special_tokens=False)
        return margin, winner

    per_token_diffs = []
    # Iterate the union of real step ids so the two runs stay aligned even when
    # their cached step sets differ or are non-contiguous.
    for step in sorted(set_a | set_b):
        margin_a, winner_a = _margin_winner(request.request_id_a, step, step in set_a)
        margin_b, winner_b = _margin_winner(request.request_id_b, step, step in set_b)

        entry = {
            "step": step,
            "margin_a": margin_a,
            "winner_a": winner_a,
            "margin_b": margin_b,
            "winner_b": winner_b,
        }

        if margin_a is not None and margin_b is not None:
            entry["margin_diff"] = margin_b - margin_a
            # Whitespace-insensitive winner comparison (tokenizers may differ in
            # leading-space conventions).
            entry["winner_changed"] = winner_a.strip() != winner_b.strip()
        else:
            entry["margin_diff"] = None
            entry["winner_changed"] = None

        per_token_diffs.append(entry)

    return {
        "request_id_a": request.request_id_a,
        "request_id_b": request.request_id_b,
        "steps_a": len(steps_a),
        "steps_b": len(steps_b),
        "per_token_diffs": per_token_diffs,
    }
|
| 3962 |
+
|
| 3963 |
+
|
| 3964 |
@app.post("/analyze/study")
|
| 3965 |
async def analyze_study(request: StudyRequest, authenticated: bool = Depends(verify_api_key)):
|
| 3966 |
"""
|