gary-boon Claude Opus 4.6 (1M context) committed
Commit e375e45 · 1 Parent(s): bfdde66

Fix score-all classification in vectorised path


The vectorised GPU path (used for most layers) still had the old
cascade with "semantic"/"positional" catch-alls. Updated to match
the score-all-then-rank system with behaviour types + code cues.

Also fixed a `self` reference error: the code runs in a standalone async
function, not a class method, so it now uses a dict-based cache instead.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
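
For quick orientation, here is a minimal, self-contained sketch of the score-all-then-rank idea that the diff below inlines per head. The `classify_head` helper and the example numbers are illustrative only and are not part of backend/model_service.py; the behaviour names and thresholds mirror the committed code.

    from typing import Dict, Optional, Tuple

    def classify_head(scores: Dict[str, float],
                      thresholds: Dict[str, float]) -> Tuple[Tuple[str, float], Optional[Tuple[str, float]]]:
        # Score every behaviour, keep those above their threshold, rank by score.
        qualified = {k: v for k, v in scores.items() if v >= thresholds.get(k, 0.3)}
        ranked = sorted(qualified.items(), key=lambda kv: kv[1], reverse=True)
        primary = ranked[0] if ranked else ("diffuse", scores.get("diffuse", 0.0))  # fallback behaviour
        secondary = ranked[1] if len(ranked) > 1 else None
        return primary, secondary

    # Example: a head with strong previous-token and moderate local attention.
    primary, secondary = classify_head(
        scores={"attention_sink": 0.1, "previous_token": 0.92, "local": 0.6,
                "induction": 0.05, "focused": 0.4, "diffuse": 0.0},
        thresholds={"attention_sink": 0.4, "previous_token": 0.7, "local": 0.5,
                    "induction": 0.2, "focused": 0.3, "diffuse": 0.3},
    )
    print(primary)    # ('previous_token', 0.92)
    print(secondary)  # ('local', 0.6)

Unlike the old cascade, every behaviour gets a score, so a head can surface a secondary behaviour alongside its primary label.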

Files changed (1)
  1. backend/model_service.py +91 -35
backend/model_service.py CHANGED
@@ -2254,15 +2254,15 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo

         # --- Code cue scores (what code tokens are attended to) ---
         # Decode token texts for code-aware detection (cached per step)
-        if not hasattr(self, '_step_token_texts') or self._step_token_texts_step != step:
+        if step_token_texts_cache.get('step') != step:
             try:
-                self._step_token_texts = [
+                step_token_texts_cache['texts'] = [
                     manager.tokenizer.decode([tid]) for tid in current_ids[0, :seq_len_hw].tolist()
                 ]
             except Exception:
-                self._step_token_texts = []
-            self._step_token_texts_step = step
-        token_texts = self._step_token_texts
+                step_token_texts_cache['texts'] = []
+            step_token_texts_cache['step'] = step
+        token_texts = step_token_texts_cache.get('texts', [])

         code_cues = {}
         if len(token_texts) == seq_len_hw:
@@ -2790,6 +2790,9 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
             pass
         return hook

+    # Cache for decoded token texts (reused across heads within a step)
+    step_token_texts_cache: Dict[str, Any] = {}
+
     # Detect FFN type from first layer
     ffn_type = "gelu" # default

@@ -3223,35 +3226,77 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
         ent = 0.0 if math.isnan(ent) or math.isinf(ent) else ent
         avg_ent = 0.0 if math.isnan(avg_ent) or math.isinf(avg_ent) else avg_ent

-        # Data-driven head pattern classification (priority order)
-        pattern_type = None
-        confidence = 0.0
-        # 1. Attention sink: >50% weight on positions 0-2
-        if skw > 0.5:
-            pattern_type = "attention_sink"
-            confidence = skw
-        # 2. Previous token: sharp focus on immediate predecessor
-        elif mw > 0.9 and ptw > 0.85:
-            pattern_type = "previous_token"
-            confidence = ptw
-        # 3. Local: >80% weight within 5 positions of query
-        elif seq_len_attn > 5 and lcw > 0.8:
-            pattern_type = "local"
-            confidence = lcw
-        # 4. Induction: attends to positions following previous occurrences of current token
-        elif step > 0 and idw > 0.3:
-            pattern_type = "induction"
-            confidence = min(1.0, idw)
-        # 5. Positional: low entropy, focused attention
-        elif ent < 1.0:
-            pattern_type = "positional"
-            confidence = 1.0 - ent
-        # 6. Semantic: broad attention (fallback)
-        elif ent >= 1.0:
-            pattern_type = "semantic"
-            confidence = min(1.0, 0.5)
+        # Score-all-then-rank head classification
+        # Behaviour type scores (attention geometry)
+        behaviour_scores = {
+            "attention_sink": skw,
+            "previous_token": ptw,
+            "local": lcw,
+            "induction": min(1.0, idw),
+            "focused": max(0.0, 1.0 - ent) if ent < 1.5 else 0.0,
+            "diffuse": min(1.0, max(0.0, (ent - 1.0) / 2.0)),
+        }
+        behaviour_thresholds = {
+            "attention_sink": 0.4,
+            "previous_token": 0.7,
+            "local": 0.5,
+            "induction": 0.2,
+            "focused": 0.3,
+            "diffuse": 0.3,
+        }
+        qualified = {
+            k: v for k, v in behaviour_scores.items()
+            if v >= behaviour_thresholds.get(k, 0.3)
+        }
+        sorted_behaviours = sorted(qualified.items(), key=lambda x: x[1], reverse=True)
+        primary = sorted_behaviours[0] if sorted_behaviours else ("diffuse", behaviour_scores["diffuse"])
+        secondary = sorted_behaviours[1] if len(sorted_behaviours) > 1 else None
+
+        pattern_type = primary[0]
+        confidence = primary[1]
         confidence = 0.0 if math.isnan(confidence) or math.isinf(confidence) else confidence

+        # Code cue scores (what code tokens are attended to)
+        # Decode token texts once per step (cached via nonlocal)
+        if step_token_texts_cache.get('step') != step:
+            try:
+                step_token_texts_cache['texts'] = [
+                    manager.tokenizer.decode([tid]) for tid in current_ids[0, :seq_len_attn].tolist()
+                ]
+            except Exception:
+                step_token_texts_cache['texts'] = []
+            step_token_texts_cache['step'] = step
+        token_texts = step_token_texts_cache.get('texts', [])
+
+        code_cues = {}
+        if len(token_texts) == seq_len_attn:
+            head_weights = all_last_row[head_idx].cpu()
+            delimiters = {'(', ')', '{', '}', '[', ']', ':', ';', ','}
+            delim_indices = [i for i, t in enumerate(token_texts) if t.strip() in delimiters]
+            if delim_indices:
+                code_cues["delimiter_sensitive"] = head_weights[delim_indices].sum().item()
+
+            keywords = {'def', 'return', 'if', 'else', 'elif', 'for', 'while', 'class',
+                        'import', 'from', 'try', 'except', 'with', 'as', 'in', 'not',
+                        'and', 'or', 'True', 'False', 'None', 'self', 'yield', 'async',
+                        'await', 'lambda', 'raise', 'pass', 'break', 'continue',
+                        'function', 'const', 'let', 'var', 'new', 'this',
+                        'public', 'private', 'static', 'void', 'int', 'string', 'bool',
+                        'namespace', 'using', 'class', 'interface', 'override', 'virtual'}
+            kw_indices = [i for i, t in enumerate(token_texts) if t.strip() in keywords]
+            if kw_indices:
+                code_cues["keyword_sensitive"] = head_weights[kw_indices].sum().item()
+
+        if idw > 0.15:
+            code_cues["pattern_reuse"] = min(1.0, idw * 1.5)
+
+        cue_threshold = 0.15
+        sorted_cues = sorted(
+            [(k, round(v, 4)) for k, v in code_cues.items() if v >= cue_threshold],
+            key=lambda x: x[1], reverse=True
+        )
+        primary_cue = sorted_cues[0] if sorted_cues else None
+
         attention_matrix = layer_attn_cpu[head_idx]

         q_matrix = None
@@ -3269,14 +3314,25 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
                 "v_matrix": v_matrix
             })

-        critical_heads.append({
+        head_entry = {
             "head_idx": head_idx,
             "entropy": ent,
             "avg_entropy": avg_ent,
             "max_weight": mw,
             "has_matrices": attention_matrix is not None,
-            "pattern": {"type": pattern_type, "confidence": confidence} if pattern_type else None
-        })
+            "pattern": {"type": pattern_type, "confidence": round(confidence, 4)} if pattern_type else None,
+        }
+        if secondary:
+            head_entry["secondary_behaviour"] = {"type": secondary[0], "score": round(secondary[1], 4)}
+        if primary_cue:
+            head_entry["code_cue"] = {
+                "type": primary_cue[0],
+                "score": primary_cue[1],
+                "evidence": f"{round(primary_cue[1] * 100)}% attention on {primary_cue[0].replace('_', ' ')} tokens",
+            }
+        if len(sorted_cues) > 1:
+            head_entry["secondary_cue"] = {"type": sorted_cues[1][0], "score": sorted_cues[1][1]}
+        critical_heads.append(head_entry)

         critical_heads.sort(key=lambda h: h["max_weight"], reverse=True)
