gary-boon Claude Opus 4.6 (1M context) committed on
Commit
10baadf
·
1 Parent(s): 82349c1

Refactor head classification from cascade to score-all-then-rank

Browse files

Replace priority-ordered first-match cascade with two-dimension system:

Behaviour type (attention geometry) — all scored simultaneously:
- attention_sink, previous_token, local, induction, focused, diffuse
- Primary = highest qualifying score (with min thresholds), secondary = runner-up
- "focused" and "diffuse" replace weak "positional"/"semantic" catch-alls

Code cue (separate dimension — code token relevance):
- delimiter_sensitive: attention to (){}[]:;,
- keyword_sensitive: attention to language keywords (def, return, if, etc.)
- pattern_reuse: derived from induction signal (repeated spans)
- Each scored by proportion of attention mass on target token class
- Evidence text included (e.g. "34% attention on delimiter sensitive tokens")

Per-head output now includes:
- pattern (primary behaviour type + confidence)
- secondary_behaviour (type + score)
- code_cue (type + score + evidence text)
- secondary_cue (type + score)

Token texts decoded and cached per generation step for code-aware detection.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Files changed (1) hide show
  1. backend/model_service.py +118 -39
backend/model_service.py CHANGED
@@ -2193,26 +2193,28 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
2193
  entropy = 0.0 if math.isnan(entropy) or math.isinf(entropy) else entropy
2194
  avg_entropy = 0.0 if math.isnan(avg_entropy) or math.isinf(avg_entropy) else avg_entropy
2195
 
2196
- # Data-driven head pattern classification (priority order)
 
2197
  seq_len_hw = head_weights.shape[0]
2198
- pattern_type = None
2199
- confidence = 0.0
2200
 
2201
- # 1. Attention sink: >50% weight on positions 0-2
 
 
 
2202
  sink_w = head_weights[:min(3, seq_len_hw)].sum().item()
2203
- if sink_w > 0.5:
2204
- pattern_type = "attention_sink"
2205
- confidence = sink_w
2206
- # 2. Previous token: sharp focus on immediate predecessor
2207
- elif max_weight > 0.9 and head_weights[-2].item() > 0.85:
2208
- pattern_type = "previous_token"
2209
- confidence = head_weights[-2].item()
2210
- # 3. Local: >80% weight within 5 positions of query
2211
- elif seq_len_hw > 5 and head_weights[max(0, seq_len_hw - 5):].sum().item() > 0.8:
2212
- pattern_type = "local"
2213
- confidence = head_weights[max(0, seq_len_hw - 5):].sum().item()
2214
- # 4. Induction: attends to positions following previous occurrences of current token
2215
- elif step > 0:
2216
  current_tok = current_ids[0, -1]
2217
  prev_occ = (current_ids[0, :-1] == current_tok).nonzero(as_tuple=True)[0]
2218
  if len(prev_occ) > 0:
@@ -2220,24 +2222,81 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
2220
  foll = foll[foll < seq_len_hw]
2221
  if len(foll) > 0:
2222
  ind_w = head_weights[foll].sum().item()
2223
- if ind_w > 0.3:
2224
- pattern_type = "induction"
2225
- confidence = min(1.0, ind_w)
2226
- if pattern_type is None:
2227
- if entropy < 1.0:
2228
- pattern_type = "positional"
2229
- confidence = 1.0 - entropy
2230
- elif entropy >= 1.0:
2231
- pattern_type = "semantic"
2232
- confidence = min(1.0, 0.5)
2233
- # 5. Positional: low entropy, focused attention
2234
- elif entropy < 1.0:
2235
- pattern_type = "positional"
2236
- confidence = 1.0 - entropy
2237
- # 6. Semantic: broad attention (fallback)
2238
- elif entropy >= 1.0:
2239
- pattern_type = "semantic"
2240
- confidence = min(1.0, 0.5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2241
 
2242
  # Sanitize confidence
2243
  confidence = 0.0 if math.isnan(confidence) or math.isinf(confidence) else confidence
@@ -2267,7 +2326,7 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
2267
  })
2268
 
2269
  # Return only metadata (matrices fetched on-demand via /matrix endpoint)
2270
- critical_heads.append({
2271
  "head_idx": head_idx,
2272
  "entropy": entropy,
2273
  "avg_entropy": avg_entropy, # Averaged over all query positions
@@ -2275,9 +2334,29 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
2275
  "has_matrices": attention_matrix is not None, # Flag for frontend
2276
  "pattern": {
2277
  "type": pattern_type,
2278
- "confidence": confidence
2279
- } if pattern_type else None
2280
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2281
 
2282
  # Sort by max_weight (return all heads, frontend will decide how many to display)
2283
  critical_heads.sort(key=lambda h: h["max_weight"], reverse=True)
 
2193
  entropy = 0.0 if math.isnan(entropy) or math.isinf(entropy) else entropy
2194
  avg_entropy = 0.0 if math.isnan(avg_entropy) or math.isinf(avg_entropy) else avg_entropy
2195
 
2196
+ # Score-all-then-rank head classification
2197
+ # Two dimensions: behaviour type (attention geometry) + code cue (token relevance)
2198
  seq_len_hw = head_weights.shape[0]
 
 
2199
 
2200
+ # --- Behaviour type scores (attention geometry) ---
2201
+ behaviour_scores = {}
2202
+
2203
+ # Attention sink: weight on positions 0-2
2204
  sink_w = head_weights[:min(3, seq_len_hw)].sum().item()
2205
+ behaviour_scores["attention_sink"] = sink_w
2206
+
2207
+ # Previous token: weight on immediate predecessor
2208
+ prev_tok_w = head_weights[-2].item() if seq_len_hw >= 2 else 0.0
2209
+ behaviour_scores["previous_token"] = prev_tok_w
2210
+
2211
+ # Local: weight within last 5 positions
2212
+ local_w = head_weights[max(0, seq_len_hw - 5):].sum().item() if seq_len_hw > 5 else 0.0
2213
+ behaviour_scores["local"] = local_w
2214
+
2215
+ # Induction: weight on positions following previous occurrences of current token
2216
+ ind_w = 0.0
2217
+ if step > 0 and seq_len_hw > 1:
2218
  current_tok = current_ids[0, -1]
2219
  prev_occ = (current_ids[0, :-1] == current_tok).nonzero(as_tuple=True)[0]
2220
  if len(prev_occ) > 0:
 
2222
  foll = foll[foll < seq_len_hw]
2223
  if len(foll) > 0:
2224
  ind_w = head_weights[foll].sum().item()
2225
+ behaviour_scores["induction"] = min(1.0, ind_w)
2226
+
2227
+ # Focused: low entropy, concentrated attention (not captured by above)
2228
+ focused_score = max(0.0, 1.0 - entropy) if entropy < 1.5 else 0.0
2229
+ behaviour_scores["focused"] = focused_score
2230
+
2231
+ # Diffuse: high entropy, broad attention
2232
+ diffuse_score = min(1.0, max(0.0, (entropy - 1.0) / 2.0))
2233
+ behaviour_scores["diffuse"] = diffuse_score
2234
+
2235
+ # Pick primary behaviour (highest score, with minimum thresholds)
2236
+ behaviour_thresholds = {
2237
+ "attention_sink": 0.4,
2238
+ "previous_token": 0.7,
2239
+ "local": 0.5,
2240
+ "induction": 0.2,
2241
+ "focused": 0.3,
2242
+ "diffuse": 0.3,
2243
+ }
2244
+ qualified_behaviours = {
2245
+ k: v for k, v in behaviour_scores.items()
2246
+ if v >= behaviour_thresholds.get(k, 0.3)
2247
+ }
2248
+ sorted_behaviours = sorted(qualified_behaviours.items(), key=lambda x: x[1], reverse=True)
2249
+ primary_behaviour = sorted_behaviours[0] if sorted_behaviours else ("diffuse", diffuse_score)
2250
+ secondary_behaviour = sorted_behaviours[1] if len(sorted_behaviours) > 1 else None
2251
+
2252
+ pattern_type = primary_behaviour[0]
2253
+ confidence = primary_behaviour[1]
2254
+
2255
+ # --- Code cue scores (what code tokens are attended to) ---
2256
+ # Decode token texts for code-aware detection (cached per step)
2257
+ if not hasattr(self, '_step_token_texts') or self._step_token_texts_step != step:
2258
+ try:
2259
+ self._step_token_texts = [
2260
+ manager.tokenizer.decode([tid]) for tid in current_ids[0, :seq_len_hw].tolist()
2261
+ ]
2262
+ except Exception:
2263
+ self._step_token_texts = []
2264
+ self._step_token_texts_step = step
2265
+ token_texts = self._step_token_texts
2266
+
2267
+ code_cues = {}
2268
+ if len(token_texts) == seq_len_hw:
2269
+ # Delimiter-sensitive: attention to brackets, braces, parens
2270
+ delimiters = {'(', ')', '{', '}', '[', ']', ':', ';', ','}
2271
+ delim_indices = [i for i, t in enumerate(token_texts) if t.strip() in delimiters]
2272
+ if delim_indices:
2273
+ delim_w = head_weights[delim_indices].sum().item()
2274
+ code_cues["delimiter_sensitive"] = delim_w
2275
+
2276
+ # Keyword-sensitive: attention to language keywords
2277
+ keywords = {'def', 'return', 'if', 'else', 'elif', 'for', 'while', 'class',
2278
+ 'import', 'from', 'try', 'except', 'with', 'as', 'in', 'not',
2279
+ 'and', 'or', 'True', 'False', 'None', 'self', 'yield', 'async',
2280
+ 'await', 'lambda', 'raise', 'pass', 'break', 'continue',
2281
+ 'function', 'const', 'let', 'var', 'new', 'this'}
2282
+ kw_indices = [i for i, t in enumerate(token_texts) if t.strip() in keywords]
2283
+ if kw_indices:
2284
+ kw_w = head_weights[kw_indices].sum().item()
2285
+ code_cues["keyword_sensitive"] = kw_w
2286
+
2287
+ # Pattern reuse: attention to a contiguous span that appeared earlier
2288
+ # (broader than induction — checks for repeated multi-token sequences)
2289
+ if ind_w > 0.15:
2290
+ code_cues["pattern_reuse"] = min(1.0, ind_w * 1.5)
2291
+
2292
+ # Filter code cues by minimum threshold
2293
+ code_cue_threshold = 0.15
2294
+ qualified_cues = {
2295
+ k: round(v, 4) for k, v in code_cues.items()
2296
+ if v >= code_cue_threshold
2297
+ }
2298
+ sorted_cues = sorted(qualified_cues.items(), key=lambda x: x[1], reverse=True)
2299
+ primary_cue = sorted_cues[0] if sorted_cues else None
2300
 
2301
  # Sanitize confidence
2302
  confidence = 0.0 if math.isnan(confidence) or math.isinf(confidence) else confidence
 
2326
  })
2327
 
2328
  # Return only metadata (matrices fetched on-demand via /matrix endpoint)
2329
+ head_entry = {
2330
  "head_idx": head_idx,
2331
  "entropy": entropy,
2332
  "avg_entropy": avg_entropy, # Averaged over all query positions
 
2334
  "has_matrices": attention_matrix is not None, # Flag for frontend
2335
  "pattern": {
2336
  "type": pattern_type,
2337
+ "confidence": round(confidence, 4),
2338
+ } if pattern_type else None,
2339
+ }
2340
+ # Secondary behaviour (if present and distinct from primary)
2341
+ if secondary_behaviour:
2342
+ head_entry["secondary_behaviour"] = {
2343
+ "type": secondary_behaviour[0],
2344
+ "score": round(secondary_behaviour[1], 4),
2345
+ }
2346
+ # Code cue (separate dimension from behaviour type)
2347
+ if primary_cue:
2348
+ head_entry["code_cue"] = {
2349
+ "type": primary_cue[0],
2350
+ "score": round(primary_cue[1], 4),
2351
+ "evidence": f"{round(primary_cue[1] * 100)}% attention on {primary_cue[0].replace('_', ' ')} tokens",
2352
+ }
2353
+ # Secondary code cue
2354
+ if len(sorted_cues) > 1:
2355
+ head_entry["secondary_cue"] = {
2356
+ "type": sorted_cues[1][0],
2357
+ "score": round(sorted_cues[1][1], 4),
2358
+ }
2359
+ critical_heads.append(head_entry)
2360
 
2361
  # Sort by max_weight (return all heads, frontend will decide how many to display)
2362
  critical_heads.sort(key=lambda h: h["max_weight"], reverse=True)