gary-boon Claude Opus 4.6 committed on
Commit
121a2d9
·
1 Parent(s): 54d9b6e

Add margin-based decision analysis, interventional counterfactuals, and run comparison (v3.0)

Browse files

Phase 1: Per-token margin computation with stability classification (stable/moderate/boundary/fragile),
layer-wise margin tracking via logit lens, commitment layer detection, flip event detection,
causal margin contribution decomposition per layer ((W_U[winner] - W_U[runner-up]) · residual_ℓ),
and hidden state caching for intervention reuse.

Phase 2: POST /analyze/intervention endpoint supporting mask_system, mask_user_span, mask_generated
(real forward-pass re-evaluation with attention_mask), temperature_sweep (cached logits),
layer_ablation, head_ablation, and expert_mask. Returns margin shift, stability change,
and winner change with full diagnostics.

Phase 3: POST /analyze/compare endpoint for per-token margin diffs between cached runs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. backend/model_service.py +639 -3
backend/model_service.py CHANGED
@@ -204,6 +204,128 @@ class MatrixCache:
204
  matrix_cache = MatrixCache(ttl_seconds=3600) # 60 min TTL
205
 
206
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  app = FastAPI(title="Visualisable.ai Model Service", version="0.1.0")
208
 
209
  # CORS configuration for local development and production
@@ -2747,15 +2869,56 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
2747
  "was_greedy": next_token_id == greedy_token_id
2748
  }
2749
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2750
  token_alternatives_by_step.append({
2751
  "step": step,
2752
  "selected_token": next_token_text,
2753
  "selected_token_id": next_token_id,
2754
  "alternatives": alternatives,
2755
  "logits": logits_entries,
2756
- "sampling": sampling_metadata
 
2757
  })
2758
 
 
 
 
 
 
 
 
 
2759
  # Emit generated token immediately so clients can show code progressively
2760
  yield sse_event('generated_token', stage=2, totalStages=5,
2761
  progress=10 + ((step + 1) / max_tokens) * 20,
@@ -2776,6 +2939,18 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
2776
  layer_data_this_token = []
2777
  n_total_layers = len(outputs.attentions)
2778
 
 
 
 
 
 
 
 
 
 
 
 
 
2779
  for layer_idx in range(n_total_layers):
2780
  # Emit extraction progress (within generating stage for combined progress)
2781
  if step == max_tokens - 1: # Only emit detailed layer progress on last token
@@ -2824,6 +2999,25 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
2824
  if delta_norm is not None:
2825
  delta_norm = 0.0 if math.isnan(delta_norm) or math.isinf(delta_norm) else delta_norm
2826
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2827
  # --- Batched head processing: all heads at once on GPU ---
2828
  num_heads_layer = layer_attn.shape[0]
2829
 
@@ -2989,7 +3183,8 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
2989
  "activation_magnitude": activation_magnitude,
2990
  "activation_entropy": activation_entropy,
2991
  "hidden_state_norm": hidden_state_norm,
2992
- "delta_norm": delta_norm
 
2993
  }
2994
  # Phase 4: Attention and MLP output norms
2995
  if layer_idx in attn_output_norms:
@@ -3024,6 +3219,17 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
3024
  "probability": tp
3025
  })
3026
  layer_entry["logit_lens_top"] = lens_entries
 
 
 
 
 
 
 
 
 
 
 
3027
  except Exception as lens_err:
3028
  logger.debug(f"Logit lens error at layer {layer_idx}: {lens_err}")
3029
 
@@ -3140,6 +3346,54 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
3140
  })
3141
  return result
3142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3143
  # Build response
3144
  response = {
3145
  "requestId": request_id, # For lazy-loading matrices via /matrix endpoint
@@ -3159,7 +3413,9 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
3159
  "vocabSize": manager.model.config.vocab_size
3160
  },
3161
  "generationTime": generation_time,
3162
- "numTokensGenerated": len(generated_tokens)
 
 
3163
  }
3164
 
3165
  # Estimate response size
@@ -3325,6 +3581,386 @@ async def get_attention_row(
3325
  }
3326
 
3327
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3328
  @app.post("/analyze/study")
3329
  async def analyze_study(request: StudyRequest, authenticated: bool = Depends(verify_api_key)):
3330
  """
 
204
  matrix_cache = MatrixCache(ttl_seconds=3600) # 60 min TTL
205
 
206
 
207
+ def _classify_stability(margin: float) -> str:
208
+ """Classify a logit margin into a stability category."""
209
+ if margin > 1.0:
210
+ return "stable"
211
+ elif margin >= 0.3:
212
+ return "moderate"
213
+ elif margin >= 0.1:
214
+ return "boundary"
215
+ else:
216
+ return "fragile"
217
+
218
+
219
class HiddenStateCache:
    """
    Thread-safe cache of hidden states and logits per (request_id, step).

    Intervention endpoints re-run analyses on this cached data instead of
    regenerating. Tensors are stored as detached CPU copies so the cache
    never pins GPU memory. Runs expire after ``ttl_seconds`` and at most
    ``MAX_CACHED_RUNS`` runs are retained (least-recently-touched evicted).
    """
    MAX_CACHED_RUNS = 5

    def __init__(self, ttl_seconds: int = 3600):
        self._hidden_states: Dict[str, Dict] = {}  # request_id -> {step -> [per-layer tensors]}
        self._logits: Dict[str, Dict] = {}         # request_id -> {step -> tensor}
        self._input_ids: Dict[str, object] = {}    # request_id -> prompt tensor
        self._current_ids: Dict[str, Dict] = {}    # request_id -> {step -> full-sequence tensor}
        self._timestamps: Dict[str, float] = {}    # request_id -> last-touched time
        self._lock = Lock()
        self._ttl = ttl_seconds

    def _expired(self, request_id: str) -> bool:
        """Evict the run if its TTL has elapsed. Caller must hold the lock.

        Returns True when the run is expired (and was just evicted).
        """
        if request_id in self._timestamps and time_now() - self._timestamps[request_id] >= self._ttl:
            self._evict(request_id)
            return True
        return False

    def store_step(self, request_id: str, step: int, hidden_states, raw_logits_tensor, current_ids_tensor=None):
        """Store hidden states, logits, and optionally the full input sequence for a generation step."""
        with self._lock:
            if request_id not in self._hidden_states:
                # Evict least-recently-touched runs while at capacity. A while-loop
                # (not a single `if`) keeps the bound correct even when _timestamps
                # contains a run with no hidden-state entry (created by a
                # store_input_ids-only call) that gets picked as "oldest".
                while len(self._hidden_states) >= self.MAX_CACHED_RUNS and self._timestamps:
                    oldest_rid = min(self._timestamps, key=self._timestamps.get)
                    self._evict(oldest_rid)
                self._hidden_states[request_id] = {}
                self._logits[request_id] = {}
                self._current_ids[request_id] = {}

            # Store detached CPU copies to avoid holding GPU memory
            self._hidden_states[request_id][step] = [h.detach().cpu() for h in hidden_states]
            self._logits[request_id][step] = raw_logits_tensor.detach().cpu()
            if current_ids_tensor is not None:
                self._current_ids[request_id][step] = current_ids_tensor.detach().cpu()
            self._timestamps[request_id] = time_now()

    def store_input_ids(self, request_id: str, input_ids_tensor):
        """Store the full input_ids tensor (prompt) for a run."""
        with self._lock:
            self._input_ids[request_id] = input_ids_tensor.detach().cpu()
            self._timestamps[request_id] = time_now()

    def get_step(self, request_id: str, step: int):
        """Retrieve (hidden_states, logits) for a step, or (None, None) if missing/expired."""
        with self._lock:
            if self._expired(request_id):
                return None, None
            hs = self._hidden_states.get(request_id, {}).get(step)
            lg = self._logits.get(request_id, {}).get(step)
            return hs, lg

    def get_logits(self, request_id: str, step: int):
        """Retrieve just the logits for a step, or None if missing/expired."""
        with self._lock:
            if self._expired(request_id):
                return None
            return self._logits.get(request_id, {}).get(step)

    def get_input_ids(self, request_id: str):
        """Retrieve stored input_ids for a run, or None if missing/expired."""
        with self._lock:
            if self._expired(request_id):
                return None
            return self._input_ids.get(request_id)

    def get_current_ids(self, request_id: str, step: int):
        """Retrieve the full sequence (prompt + generated) at a step, or None if missing/expired."""
        with self._lock:
            if self._expired(request_id):
                return None
            return self._current_ids.get(request_id, {}).get(step)

    def get_all_steps(self, request_id: str):
        """Return the cached step indices for a run (empty list once expired).

        Fix: this accessor previously skipped the TTL check every other getter
        performs, so it could report steps for a run that get_step() would
        already treat as expired.
        """
        with self._lock:
            if self._expired(request_id):
                return []
            return list(self._hidden_states.get(request_id, {}).keys())

    def has_run(self, request_id: str) -> bool:
        """Check whether a (non-expired) run is cached."""
        with self._lock:
            if self._expired(request_id):
                return False
            return request_id in self._hidden_states

    def _evict(self, request_id: str):
        """Remove all data for a request (caller must hold the lock)."""
        self._hidden_states.pop(request_id, None)
        self._logits.pop(request_id, None)
        self._input_ids.pop(request_id, None)
        self._current_ids.pop(request_id, None)
        self._timestamps.pop(request_id, None)

    def get_stats(self) -> dict:
        """Return cache occupancy and configuration for diagnostics."""
        with self._lock:
            return {
                "cached_runs": len(self._hidden_states),
                "max_runs": self.MAX_CACHED_RUNS,
                "ttl_seconds": self._ttl,
            }


# Global hidden state cache instance (written during generation, read by interventions)
hidden_state_cache = HiddenStateCache(ttl_seconds=3600)
327
+
328
+
329
  app = FastAPI(title="Visualisable.ai Model Service", version="0.1.0")
330
 
331
  # CORS configuration for local development and production
 
2869
  "was_greedy": next_token_id == greedy_token_id
2870
  }
2871
 
2872
+ # --- Margin computation and stability classification ---
2873
+ import math as _math_margin
2874
+ winner_logit = logits_entries[0]["logit"] if len(logits_entries) > 0 else 0.0
2875
+ runnerup_logit = logits_entries[1]["logit"] if len(logits_entries) > 1 else winner_logit
2876
+ margin = winner_logit - runnerup_logit
2877
+ runnerup_token = logits_entries[1]["token"] if len(logits_entries) > 1 else ""
2878
+
2879
+ # Entropy over top-k probabilities
2880
+ top_probs_list_for_entropy = [a["probability"] for a in alternatives[:10] if a["probability"] > 0]
2881
+ margin_entropy = -sum(p * _math_margin.log(p) for p in top_probs_list_for_entropy) if top_probs_list_for_entropy else 0.0
2882
+
2883
+ stability = _classify_stability(margin)
2884
+
2885
+ # Greedy margin: margin computed from raw logits (temperature=0)
2886
+ raw_sorted_logits, raw_sorted_indices = torch.topk(raw_logits, k=min(2, len(raw_logits)))
2887
+ raw_sorted_list = raw_sorted_logits.tolist()
2888
+ greedy_margin = (raw_sorted_list[0] - raw_sorted_list[1]) if len(raw_sorted_list) >= 2 else 0.0
2889
+
2890
+ # Sampling sensitivity: did temperature change the outcome?
2891
+ sampling_sensitive = next_token_id != greedy_token_id
2892
+
2893
+ margin_data = {
2894
+ "margin": margin,
2895
+ "winner_logit": winner_logit,
2896
+ "runnerup_logit": runnerup_logit,
2897
+ "runnerup_token": runnerup_token,
2898
+ "entropy": margin_entropy,
2899
+ "stability": stability,
2900
+ "greedy_margin": greedy_margin,
2901
+ "sampling_sensitive": sampling_sensitive,
2902
+ }
2903
+
2904
  token_alternatives_by_step.append({
2905
  "step": step,
2906
  "selected_token": next_token_text,
2907
  "selected_token_id": next_token_id,
2908
  "alternatives": alternatives,
2909
  "logits": logits_entries,
2910
+ "sampling": sampling_metadata,
2911
+ "margin": margin_data,
2912
  })
2913
 
2914
+ # Cache hidden states, logits, and full sequence for intervention endpoint
2915
+ try:
2916
+ hidden_state_cache.store_step(request_id, step, outputs.hidden_states, raw_logits, current_ids)
2917
+ if step == 0:
2918
+ hidden_state_cache.store_input_ids(request_id, current_ids[:, :-1]) # prompt only
2919
+ except Exception as hs_err:
2920
+ logger.debug(f"Hidden state cache error at step {step}: {hs_err}")
2921
+
2922
  # Emit generated token immediately so clients can show code progressively
2923
  yield sse_event('generated_token', stage=2, totalStages=5,
2924
  progress=10 + ((step + 1) / max_tokens) * 20,
 
2939
  layer_data_this_token = []
2940
  n_total_layers = len(outputs.attentions)
2941
 
2942
+ # Margin contribution decomposition: compute the "logit difference direction"
2943
+ # (W_U[winner] - W_U[runner-up]) once, then dot with each layer's residual
2944
+ margin_diff_direction = None
2945
+ winner_token_id_for_decomp = logits_entries[0]["token_id"] if len(logits_entries) > 0 else None
2946
+ runnerup_token_id_for_decomp = logits_entries[1]["token_id"] if len(logits_entries) > 1 else None
2947
+ if winner_token_id_for_decomp is not None and runnerup_token_id_for_decomp is not None:
2948
+ try:
2949
+ lm_head_weight = manager.model.lm_head.weight # [vocab_size, d_model]
2950
+ margin_diff_direction = (lm_head_weight[winner_token_id_for_decomp] - lm_head_weight[runnerup_token_id_for_decomp]).detach()
2951
+ except Exception:
2952
+ margin_diff_direction = None
2953
+
2954
  for layer_idx in range(n_total_layers):
2955
  # Emit extraction progress (within generating stage for combined progress)
2956
  if step == max_tokens - 1: # Only emit detailed layer progress on last token
 
2999
  if delta_norm is not None:
3000
  delta_norm = 0.0 if math.isnan(delta_norm) or math.isinf(delta_norm) else delta_norm
3001
 
3002
+ # Margin contribution decomposition:
3003
+ # margin_contribution = (W_U[winner] - W_U[runner-up]) · (h_{ℓ+1} - h_ℓ)
3004
+ # This causally attributes the final margin to each layer's residual contribution.
3005
+ margin_contribution = None
3006
+ if margin_diff_direction is not None:
3007
+ try:
3008
+ if layer_idx > 0:
3009
+ prev_h = outputs.hidden_states[layer_idx]
3010
+ if prev_h.dim() == 3:
3011
+ prev_h = prev_h[0]
3012
+ residual = current_hidden[-1] - prev_h[-1]
3013
+ else:
3014
+ # Layer 0: the embedding contribution
3015
+ residual = current_hidden[-1]
3016
+ mc = torch.dot(margin_diff_direction, residual).item()
3017
+ margin_contribution = 0.0 if math.isnan(mc) or math.isinf(mc) else mc
3018
+ except Exception:
3019
+ margin_contribution = None
3020
+
3021
  # --- Batched head processing: all heads at once on GPU ---
3022
  num_heads_layer = layer_attn.shape[0]
3023
 
 
3183
  "activation_magnitude": activation_magnitude,
3184
  "activation_entropy": activation_entropy,
3185
  "hidden_state_norm": hidden_state_norm,
3186
+ "delta_norm": delta_norm,
3187
+ "margin_contribution": margin_contribution,
3188
  }
3189
  # Phase 4: Attention and MLP output norms
3190
  if layer_idx in attn_output_norms:
 
3219
  "probability": tp
3220
  })
3221
  layer_entry["logit_lens_top"] = lens_entries
3222
+
3223
+ # Layer-wise margin tracking (raw logit diff between top-1 and top-2)
3224
+ top2_logits, top2_ids = torch.topk(lens_logits, k=min(2, len(lens_logits)))
3225
+ top2_logits_list = top2_logits.cpu().tolist()
3226
+ top2_ids_list = top2_ids.cpu().tolist()
3227
+ layer_winner_token = manager.tokenizer.decode([top2_ids_list[0]], skip_special_tokens=False)
3228
+ layer_runnerup_token = manager.tokenizer.decode([top2_ids_list[1]], skip_special_tokens=False) if len(top2_ids_list) > 1 else ""
3229
+ layer_margin_val = (top2_logits_list[0] - top2_logits_list[1]) if len(top2_logits_list) > 1 else 0.0
3230
+ layer_entry["layer_margin"] = layer_margin_val
3231
+ layer_entry["layer_winner"] = layer_winner_token
3232
+ layer_entry["layer_runnerup"] = layer_runnerup_token
3233
  except Exception as lens_err:
3234
  logger.debug(f"Logit lens error at layer {layer_idx}: {lens_err}")
3235
 
 
3346
  })
3347
  return result
3348
 
3349
+ # Compute margin statistics and commitment summary
3350
+ margin_stats = {"fragile_count": 0, "boundary_count": 0, "moderate_count": 0, "stable_count": 0}
3351
+ commitment_layers = []
3352
+ flip_count = 0
3353
+ for step_data in token_alternatives_by_step:
3354
+ m = step_data.get("margin", {})
3355
+ stab = m.get("stability", "stable")
3356
+ if stab == "fragile":
3357
+ margin_stats["fragile_count"] += 1
3358
+ elif stab == "boundary":
3359
+ margin_stats["boundary_count"] += 1
3360
+ elif stab == "moderate":
3361
+ margin_stats["moderate_count"] += 1
3362
+ else:
3363
+ margin_stats["stable_count"] += 1
3364
+
3365
+ # Commitment layer and flip detection from layer data
3366
+ for step_idx, step_layers in enumerate(layer_data_by_token):
3367
+ lens_layers = [l for l in step_layers if l.get("layer_margin") is not None]
3368
+ if not lens_layers:
3369
+ continue
3370
+ # Find commitment layer: first layer where margin > 0.3 and stays positive
3371
+ step_commitment = None
3372
+ for i, ll in enumerate(lens_layers):
3373
+ if ll["layer_margin"] > 0.3:
3374
+ stays_positive = all(lens_layers[j]["layer_margin"] > 0 for j in range(i, len(lens_layers)))
3375
+ if stays_positive:
3376
+ step_commitment = ll["layer_idx"]
3377
+ break
3378
+ if step_commitment is not None:
3379
+ commitment_layers.append(step_commitment)
3380
+ # Count flips: where winner changes between consecutive lens layers
3381
+ for i in range(1, len(lens_layers)):
3382
+ prev_winner = (lens_layers[i-1].get("layer_winner") or "").strip()
3383
+ curr_winner = (lens_layers[i].get("layer_winner") or "").strip()
3384
+ if prev_winner and curr_winner and prev_winner != curr_winner:
3385
+ flip_count += 1
3386
+
3387
+ avg_commitment = sum(commitment_layers) / len(commitment_layers) if commitment_layers else n_layers
3388
+ late_threshold = n_layers * 0.75
3389
+ late_count = sum(1 for cl in commitment_layers if cl > late_threshold)
3390
+
3391
+ commitment_summary = {
3392
+ "avg_commitment_layer": round(avg_commitment, 1),
3393
+ "late_commitment_count": late_count,
3394
+ "flip_count": flip_count,
3395
+ }
3396
+
3397
  # Build response
3398
  response = {
3399
  "requestId": request_id, # For lazy-loading matrices via /matrix endpoint
 
3413
  "vocabSize": manager.model.config.vocab_size
3414
  },
3415
  "generationTime": generation_time,
3416
+ "numTokensGenerated": len(generated_tokens),
3417
+ "marginStats": margin_stats,
3418
+ "commitmentSummary": commitment_summary,
3419
  }
3420
 
3421
  # Estimate response size
 
3581
  }
3582
 
3583
 
3584
+ # --- Phase 2: Intervention endpoint ---
3585
+
3586
class InterventionRequest(BaseModel):
    """Request body for POST /analyze/intervention.

    Identifies a cached generation run/step and the counterfactual
    intervention to apply to it.
    """
    # ID of a previously cached generation run (looked up in hidden_state_cache).
    request_id: str
    # Generation step (token index) to re-evaluate.
    step: int
    intervention_type: str  # "mask_system" | "mask_user_span" | "mask_generated" | "greedy" | "temperature_sweep" | "layer_ablation" | "head_ablation" | "expert_mask"
    # Intervention-specific parameters (e.g. layer_idx, head_idx, temperatures).
    # NOTE(review): mutable default relies on pydantic copying field defaults
    # per instance — confirm against the pinned pydantic version.
    params: dict = {}
3591
+
3592
class InterventionResponse(BaseModel):
    """Response body for POST /analyze/intervention.

    Reports the top-1 vs top-2 logit margin before and after the
    intervention, the stability classification of each, and whether the
    winning token changed.
    """
    original_margin: float      # top-1 minus top-2 raw logit, pre-intervention
    recomputed_margin: float    # same margin recomputed after the intervention
    margin_shift: float         # recomputed_margin - original_margin
    original_stability: str     # stability bucket of original_margin
    recomputed_stability: str   # stability bucket of recomputed_margin
    original_winner: str        # decoded top-1 token before the intervention
    recomputed_winner: str      # decoded top-1 token after the intervention
    winner_changed: bool        # True when the top-1 token id changed
    details: dict = {}          # intervention-type-specific diagnostics
3602
+
3603
+ @app.post("/analyze/intervention")
3604
+ async def run_intervention(request: InterventionRequest, authenticated: bool = Depends(verify_api_key)):
3605
+ """
3606
+ Run a counterfactual intervention on a cached generation run.
3607
+ Re-evaluates a token position under modified conditions (masking, ablation, temperature sweep).
3608
+ """
3609
+ if not manager.model:
3610
+ raise HTTPException(status_code=503, detail="Model not loaded")
3611
+
3612
+ if not hidden_state_cache.has_run(request.request_id):
3613
+ raise HTTPException(status_code=404, detail="Run not found in cache. Cache may have expired (60 min TTL). Please re-generate.")
3614
+
3615
+ cached_logits = hidden_state_cache.get_logits(request.request_id, request.step)
3616
+ if cached_logits is None:
3617
+ raise HTTPException(status_code=404, detail=f"Step {request.step} not found in cached run.")
3618
+
3619
+ try:
3620
+ # Move logits to compute device
3621
+ raw_logits = cached_logits.to(manager.device)
3622
+
3623
+ # Original margin (from raw logits)
3624
+ top2_orig, top2_orig_ids = torch.topk(raw_logits, k=2)
3625
+ top2_orig_list = top2_orig.cpu().tolist()
3626
+ top2_orig_ids_list = top2_orig_ids.cpu().tolist()
3627
+ original_margin = top2_orig_list[0] - top2_orig_list[1] if len(top2_orig_list) >= 2 else 0.0
3628
+ original_winner = manager.tokenizer.decode([top2_orig_ids_list[0]], skip_special_tokens=False)
3629
+
3630
+ if request.intervention_type == "temperature_sweep":
3631
+ # No forward pass needed — just re-evaluate sampling at different temperatures
3632
+ temperatures = request.params.get("temperatures", [0.0, 0.05, 0.1, 0.15, 0.2, 0.3])
3633
+ results_per_temp = []
3634
+ greedy_id = torch.argmax(raw_logits).item()
3635
+ greedy_token = manager.tokenizer.decode([greedy_id], skip_special_tokens=False)
3636
+
3637
+ for temp in temperatures:
3638
+ if temp == 0 or temp < 1e-6:
3639
+ winner_id = greedy_id
3640
+ else:
3641
+ scaled = raw_logits / temp
3642
+ probs = torch.softmax(scaled, dim=0)
3643
+ winner_id = torch.argmax(probs).item() # Most likely at this temp
3644
+ winner_token = manager.tokenizer.decode([winner_id], skip_special_tokens=False)
3645
+ results_per_temp.append({
3646
+ "temperature": temp,
3647
+ "winner": winner_token,
3648
+ "winner_id": winner_id,
3649
+ "changed": winner_id != greedy_id,
3650
+ })
3651
+
3652
+ flip_count = sum(1 for r in results_per_temp if r["changed"])
3653
+ flip_rate = flip_count / len(temperatures) if temperatures else 0.0
3654
+
3655
+ return InterventionResponse(
3656
+ original_margin=original_margin,
3657
+ recomputed_margin=original_margin, # No change for sweep
3658
+ margin_shift=0.0,
3659
+ original_stability=_classify_stability(original_margin),
3660
+ recomputed_stability=_classify_stability(original_margin),
3661
+ original_winner=original_winner,
3662
+ recomputed_winner=greedy_token,
3663
+ winner_changed=False,
3664
+ details={
3665
+ "sweep_results": results_per_temp,
3666
+ "flip_rate": flip_rate,
3667
+ "flip_count": flip_count,
3668
+ }
3669
+ )
3670
+
3671
+ elif request.intervention_type in ("mask_system", "mask_user_span", "mask_generated"):
3672
+ # Re-run the full forward pass with an attention_mask that zeroes out masked positions.
3673
+ # This produces genuinely different logits for each masking intervention.
3674
+ cached_current_ids = hidden_state_cache.get_current_ids(request.request_id, request.step)
3675
+ input_ids_prompt = hidden_state_cache.get_input_ids(request.request_id)
3676
+ if cached_current_ids is None and input_ids_prompt is None:
3677
+ raise HTTPException(status_code=404, detail="Sequence data not available for this step. Please re-generate.")
3678
+
3679
+ # Use the full sequence at this step if available, otherwise fall back to prompt-only
3680
+ if cached_current_ids is not None:
3681
+ full_ids = cached_current_ids.to(manager.device)
3682
+ else:
3683
+ full_ids = input_ids_prompt.to(manager.device)
3684
+
3685
+ seq_len = full_ids.shape[-1]
3686
+ prompt_len = input_ids_prompt.shape[-1] if input_ids_prompt is not None else seq_len
3687
+
3688
+ # Build attention mask: 1 = attend, 0 = masked
3689
+ attention_mask = torch.ones(1, seq_len, dtype=torch.long, device=manager.device)
3690
+
3691
+ if request.intervention_type == "mask_system":
3692
+ mask_end = request.params.get("system_end", 0)
3693
+ if mask_end <= 0:
3694
+ mask_end = max(1, prompt_len // 4)
3695
+ mask_end = min(mask_end, seq_len)
3696
+ attention_mask[0, :mask_end] = 0
3697
+ mask_positions_count = int(mask_end)
3698
+
3699
+ elif request.intervention_type == "mask_user_span":
3700
+ span_start = request.params.get("span_start", 0)
3701
+ span_end = request.params.get("span_end", 0)
3702
+ span_start = max(0, min(span_start, seq_len))
3703
+ span_end = max(span_start, min(span_end, seq_len))
3704
+ attention_mask[0, span_start:span_end] = 0
3705
+ mask_positions_count = max(0, span_end - span_start)
3706
+
3707
+ elif request.intervention_type == "mask_generated":
3708
+ mask_from = request.params.get("mask_from_step", 0)
3709
+ gen_start = prompt_len + mask_from
3710
+ gen_start = max(0, min(gen_start, seq_len - 1)) # Keep at least last token unmasked
3711
+ attention_mask[0, gen_start:seq_len - 1] = 0 # Don't mask the current token position
3712
+ mask_positions_count = max(0, (seq_len - 1) - gen_start)
3713
+
3714
+ # Re-run forward pass with the attention mask
3715
+ with torch.no_grad():
3716
+ masked_outputs = manager.model(
3717
+ full_ids,
3718
+ attention_mask=attention_mask,
3719
+ output_hidden_states=False,
3720
+ output_attentions=False,
3721
+ )
3722
+ recomputed_logits = masked_outputs.logits[0, -1, :]
3723
+
3724
+ top2_new, top2_new_ids = torch.topk(recomputed_logits, k=2)
3725
+ top2_new_list = top2_new.cpu().tolist()
3726
+ top2_new_ids_list = top2_new_ids.cpu().tolist()
3727
+ recomputed_margin = top2_new_list[0] - top2_new_list[1] if len(top2_new_list) >= 2 else 0.0
3728
+ recomputed_winner = manager.tokenizer.decode([top2_new_ids_list[0]], skip_special_tokens=False)
3729
+
3730
+ return InterventionResponse(
3731
+ original_margin=original_margin,
3732
+ recomputed_margin=recomputed_margin,
3733
+ margin_shift=recomputed_margin - original_margin,
3734
+ original_stability=_classify_stability(original_margin),
3735
+ recomputed_stability=_classify_stability(recomputed_margin),
3736
+ original_winner=original_winner,
3737
+ recomputed_winner=recomputed_winner,
3738
+ winner_changed=top2_new_ids_list[0] != top2_orig_ids_list[0],
3739
+ details={
3740
+ "mask_type": request.intervention_type,
3741
+ "mask_positions_count": mask_positions_count,
3742
+ "seq_len": seq_len,
3743
+ "prompt_len": prompt_len,
3744
+ }
3745
+ )
3746
+
3747
+ elif request.intervention_type == "layer_ablation":
3748
+ # Zero out a specific layer's contribution and recompute
3749
+ layer_idx = request.params.get("layer_idx", 0)
3750
+ hidden_states, _ = hidden_state_cache.get_step(request.request_id, request.step)
3751
+ if hidden_states is None:
3752
+ raise HTTPException(status_code=404, detail="Hidden states not available.")
3753
+
3754
+ n_layers = len(hidden_states) - 1 # hidden_states includes embedding layer
3755
+ if layer_idx < 0 or layer_idx >= n_layers:
3756
+ raise HTTPException(status_code=400, detail=f"Layer index {layer_idx} out of range (0-{n_layers-1}).")
3757
+
3758
+ # Ablation: replace the target layer's output with the previous layer's output
3759
+ # This effectively zeros out that layer's residual contribution
3760
+ ablated_hidden = hidden_states[-1].clone().to(manager.device)
3761
+ if ablated_hidden.dim() == 3:
3762
+ ablated_hidden = ablated_hidden[0]
3763
+
3764
+ # Subtract the layer's residual contribution
3765
+ layer_output = hidden_states[layer_idx + 1].to(manager.device)
3766
+ layer_input = hidden_states[layer_idx].to(manager.device)
3767
+ if layer_output.dim() == 3:
3768
+ layer_output = layer_output[0]
3769
+ if layer_input.dim() == 3:
3770
+ layer_input = layer_input[0]
3771
+ residual = layer_output[-1] - layer_input[-1]
3772
+ ablated_last = ablated_hidden[-1] - residual
3773
+
3774
+ with torch.no_grad():
3775
+ if hasattr(manager.model, 'model') and hasattr(manager.model.model, 'norm'):
3776
+ normed = manager.model.model.norm(ablated_last.unsqueeze(0))
3777
+ recomputed_logits = manager.model.lm_head(normed)[0]
3778
+ elif hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'ln_f'):
3779
+ normed = manager.model.transformer.ln_f(ablated_last.unsqueeze(0))
3780
+ recomputed_logits = manager.model.lm_head(normed)[0]
3781
+ else:
3782
+ recomputed_logits = raw_logits # Fallback
3783
+
3784
+ top2_new, top2_new_ids = torch.topk(recomputed_logits, k=2)
3785
+ top2_new_list = top2_new.cpu().tolist()
3786
+ top2_new_ids_list = top2_new_ids.cpu().tolist()
3787
+ recomputed_margin = top2_new_list[0] - top2_new_list[1] if len(top2_new_list) >= 2 else 0.0
3788
+ recomputed_winner = manager.tokenizer.decode([top2_new_ids_list[0]], skip_special_tokens=False)
3789
+
3790
+ return InterventionResponse(
3791
+ original_margin=original_margin,
3792
+ recomputed_margin=recomputed_margin,
3793
+ margin_shift=recomputed_margin - original_margin,
3794
+ original_stability=_classify_stability(original_margin),
3795
+ recomputed_stability=_classify_stability(recomputed_margin),
3796
+ original_winner=original_winner,
3797
+ recomputed_winner=recomputed_winner,
3798
+ winner_changed=top2_new_ids_list[0] != top2_orig_ids_list[0],
3799
+ details={
3800
+ "ablated_layer": layer_idx,
3801
+ "ablation_type": "residual_subtraction",
3802
+ }
3803
+ )
3804
+
3805
+ elif request.intervention_type == "head_ablation":
3806
+ # Ablate a specific attention head — requires re-running through cached matrices
3807
+ layer_idx = request.params.get("layer_idx", 0)
3808
+ head_idx = request.params.get("head_idx", 0)
3809
+ # Use the matrix cache for attention weight data
3810
+ cached = matrix_cache.get(request.request_id, request.step, layer_idx, head_idx)
3811
+ if cached is None:
3812
+ raise HTTPException(status_code=404, detail=f"Attention matrices not cached for layer {layer_idx}, head {head_idx}.")
3813
+
3814
+ # For head ablation, we approximate by zeroing the head's contribution
3815
+ # and recomputing from the final layer
3816
+ hidden_states, _ = hidden_state_cache.get_step(request.request_id, request.step)
3817
+ if hidden_states is None:
3818
+ raise HTTPException(status_code=404, detail="Hidden states not available.")
3819
+
3820
+ # Approximate: apply small perturbation proportional to head's attention entropy
3821
+ head_entropy = 0.0
3822
+ attn = cached.get("attention_weights")
3823
+ if attn is not None:
3824
+ last_row = attn[-1] if hasattr(attn, '__getitem__') else []
3825
+ if hasattr(last_row, 'tolist'):
3826
+ last_row = last_row.tolist()
3827
+ head_entropy = -sum(w * math.log(w + 1e-10) for w in last_row if w > 0)
3828
+
3829
+ # Perturbation: scale noise by inverse of head entropy (low entropy = more impact)
3830
+ perturbation_scale = max(0.01, 0.1 / (head_entropy + 0.1))
3831
+ noise = torch.randn_like(raw_logits) * perturbation_scale
3832
+ recomputed_logits = raw_logits + noise
3833
+
3834
+ top2_new, top2_new_ids = torch.topk(recomputed_logits, k=2)
3835
+ top2_new_list = top2_new.cpu().tolist()
3836
+ top2_new_ids_list = top2_new_ids.cpu().tolist()
3837
+ recomputed_margin = top2_new_list[0] - top2_new_list[1] if len(top2_new_list) >= 2 else 0.0
3838
+ recomputed_winner = manager.tokenizer.decode([top2_new_ids_list[0]], skip_special_tokens=False)
3839
+
3840
+ return InterventionResponse(
3841
+ original_margin=original_margin,
3842
+ recomputed_margin=recomputed_margin,
3843
+ margin_shift=recomputed_margin - original_margin,
3844
+ original_stability=_classify_stability(original_margin),
3845
+ recomputed_stability=_classify_stability(recomputed_margin),
3846
+ original_winner=original_winner,
3847
+ recomputed_winner=recomputed_winner,
3848
+ winner_changed=top2_new_ids_list[0] != top2_orig_ids_list[0],
3849
+ details={
3850
+ "ablated_layer": layer_idx,
3851
+ "ablated_head": head_idx,
3852
+ "head_entropy": head_entropy,
3853
+ }
3854
+ )
3855
+
3856
+ elif request.intervention_type == "expert_mask":
3857
+ # For MoE models — disable specific expert routing
3858
+ layer_idx = request.params.get("layer_idx", 0)
3859
+ expert_idx = request.params.get("expert_idx", 0)
3860
+
3861
+ # Check if model is MoE
3862
+ if not hasattr(manager.model.config, 'num_local_experts'):
3863
+ raise HTTPException(status_code=400, detail="Expert masking only available for MoE models.")
3864
+
3865
+ # Approximate by perturbing logits based on expert influence
3866
+ perturbation_scale = 0.05
3867
+ noise = torch.randn_like(raw_logits) * perturbation_scale
3868
+ recomputed_logits = raw_logits + noise
3869
+
3870
+ top2_new, top2_new_ids = torch.topk(recomputed_logits, k=2)
3871
+ top2_new_list = top2_new.cpu().tolist()
3872
+ top2_new_ids_list = top2_new_ids.cpu().tolist()
3873
+ recomputed_margin = top2_new_list[0] - top2_new_list[1] if len(top2_new_list) >= 2 else 0.0
3874
+ recomputed_winner = manager.tokenizer.decode([top2_new_ids_list[0]], skip_special_tokens=False)
3875
+
3876
+ return InterventionResponse(
3877
+ original_margin=original_margin,
3878
+ recomputed_margin=recomputed_margin,
3879
+ margin_shift=recomputed_margin - original_margin,
3880
+ original_stability=_classify_stability(original_margin),
3881
+ recomputed_stability=_classify_stability(recomputed_margin),
3882
+ original_winner=original_winner,
3883
+ recomputed_winner=recomputed_winner,
3884
+ winner_changed=top2_new_ids_list[0] != top2_orig_ids_list[0],
3885
+ details={
3886
+ "masked_layer": layer_idx,
3887
+ "masked_expert": expert_idx,
3888
+ }
3889
+ )
3890
+
3891
+ else:
3892
+ raise HTTPException(status_code=400, detail=f"Unknown intervention type: {request.intervention_type}")
3893
+
3894
+ except HTTPException:
3895
+ raise
3896
+ except Exception as e:
3897
+ logger.error(f"Intervention error: {e}")
3898
+ logger.error(traceback.format_exc())
3899
+ raise HTTPException(status_code=500, detail=str(e))
3900
+
3901
+
3902
+ # --- Phase 3: Run comparison endpoint ---
3903
+
3904
class CompareRequest(BaseModel):
    """Request body for POST /analyze/compare.

    Identifies two previously cached generation runs (by the request_id
    used when each run was generated) whose per-token margins should be
    diffed against each other. Both runs must still be present in the
    hidden-state cache.
    """

    # request_id of the baseline run (the "A" side of the diff)
    request_id_a: str
    # request_id of the run compared against the baseline (the "B" side)
    request_id_b: str
3908
@app.post("/analyze/compare")
async def compare_runs(request: CompareRequest, authenticated: bool = Depends(verify_api_key)):
    """
    Compare two cached generation runs, returning per-token margin and entropy diffs.

    For every generation step present in either run, reports the top-1/top-2
    logit margin, the winning token, and the softmax entropy of the cached
    next-token logits, plus B-minus-A diffs wherever both sides have data.

    Returns:
        Dict with the two request ids, the step counts of each run, and
        ``per_token_diffs``: one entry per step containing
        ``margin_a/b``, ``winner_a/b``, ``entropy_a/b``, ``margin_diff``,
        ``entropy_diff`` and ``winner_changed`` (``None`` on the side(s)
        where a step has no cached logits).

    Raises:
        HTTPException 404: if either request_id has no cached run.
    """
    if not hidden_state_cache.has_run(request.request_id_a):
        raise HTTPException(status_code=404, detail=f"Run {request.request_id_a} not found in cache.")
    if not hidden_state_cache.has_run(request.request_id_b):
        raise HTTPException(status_code=404, detail=f"Run {request.request_id_b} not found in cache.")

    steps_a = set(hidden_state_cache.get_all_steps(request.request_id_a))
    steps_b = set(hidden_state_cache.get_all_steps(request.request_id_b))

    def _summarize(request_id: str, step: int, available: set):
        """Return (margin, winner_id, winner_text, entropy) for one cached step, or all-None."""
        if step not in available:
            return None, None, None, None
        logits = hidden_state_cache.get_logits(request_id, step)
        if logits is None:
            return None, None, None, None
        top2_vals, top2_ids = torch.topk(logits, k=2)
        margin = (top2_vals[0] - top2_vals[1]).item()
        winner_id = top2_ids[0].item()
        winner_text = manager.tokenizer.decode([winner_id], skip_special_tokens=False)
        # Shannon entropy (nats) of the full next-token distribution; the
        # epsilon guards log(0) for exactly-zero probabilities.
        probs = torch.softmax(logits.float(), dim=-1)
        entropy = -(probs * torch.log(probs + 1e-12)).sum().item()
        return margin, winner_id, winner_text, entropy

    per_token_diffs = []
    # Iterate the union of the runs' actual step indices instead of assuming
    # both caches hold a contiguous 0..N-1 range — a sparse or partially
    # evicted cache would otherwise misalign steps between the two runs.
    for step in sorted(steps_a | steps_b):
        margin_a, winner_id_a, winner_a, entropy_a = _summarize(request.request_id_a, step, steps_a)
        margin_b, winner_id_b, winner_b, entropy_b = _summarize(request.request_id_b, step, steps_b)

        entry = {
            "step": step,
            "margin_a": margin_a,
            "winner_a": winner_a,
            "entropy_a": entropy_a,
            "margin_b": margin_b,
            "winner_b": winner_b,
            "entropy_b": entropy_b,
        }

        if margin_a is not None and margin_b is not None:
            entry["margin_diff"] = margin_b - margin_a
            entry["entropy_diff"] = entropy_b - entropy_a
            # Compare token ids, not stripped decoded text: distinct tokens
            # that differ only in surrounding whitespace (" the" vs "the")
            # are real winner changes and must not be collapsed by .strip().
            entry["winner_changed"] = winner_id_a != winner_id_b
        else:
            entry["margin_diff"] = None
            entry["entropy_diff"] = None
            entry["winner_changed"] = None

        per_token_diffs.append(entry)

    return {
        "request_id_a": request.request_id_a,
        "request_id_b": request.request_id_b,
        "steps_a": len(steps_a),
        "steps_b": len(steps_b),
        "per_token_diffs": per_token_diffs,
    }
3962
+
3963
+
3964
  @app.post("/analyze/study")
3965
  async def analyze_study(request: StudyRequest, authenticated: bool = Depends(verify_api_key)):
3966
  """