Spaces:
Paused
Refactor head classification from cascade to score-all-then-rank
Replace the priority-ordered first-match cascade with a two-dimension system:
Behaviour type (attention geometry) — all scored simultaneously:
- attention_sink, previous_token, local, induction, focused, diffuse
- Primary = highest qualifying score (with min thresholds), secondary = runner-up
- "focused" and "diffuse" replace weak "positional"/"semantic" catch-alls
Code cue (separate dimension — code token relevance):
- delimiter_sensitive: attention to (){}[]:;,
- keyword_sensitive: attention to language keywords (def, return, if, etc.)
- pattern_reuse: derived from induction signal (repeated spans)
- Each scored by proportion of attention mass on target token class
- Evidence text included (e.g. "34% attention on delimiter sensitive tokens")
Per-head output now includes:
- pattern (primary behaviour type + confidence)
- secondary_behaviour (type + score)
- code_cue (type + score + evidence text)
- secondary_cue (type + score)
Token texts decoded and cached per generation step for code-aware detection.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- backend/model_service.py +118 -39
|
@@ -2193,26 +2193,28 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
|
|
| 2193 |
entropy = 0.0 if math.isnan(entropy) or math.isinf(entropy) else entropy
|
| 2194 |
avg_entropy = 0.0 if math.isnan(avg_entropy) or math.isinf(avg_entropy) else avg_entropy
|
| 2195 |
|
| 2196 |
-
#
|
|
|
|
| 2197 |
seq_len_hw = head_weights.shape[0]
|
| 2198 |
-
pattern_type = None
|
| 2199 |
-
confidence = 0.0
|
| 2200 |
|
| 2201 |
-
#
|
|
|
|
|
|
|
|
|
|
| 2202 |
sink_w = head_weights[:min(3, seq_len_hw)].sum().item()
|
| 2203 |
-
|
| 2204 |
-
|
| 2205 |
-
|
| 2206 |
-
|
| 2207 |
-
|
| 2208 |
-
|
| 2209 |
-
|
| 2210 |
-
|
| 2211 |
-
|
| 2212 |
-
|
| 2213 |
-
|
| 2214 |
-
|
| 2215 |
-
|
| 2216 |
current_tok = current_ids[0, -1]
|
| 2217 |
prev_occ = (current_ids[0, :-1] == current_tok).nonzero(as_tuple=True)[0]
|
| 2218 |
if len(prev_occ) > 0:
|
|
@@ -2220,24 +2222,81 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
|
|
| 2220 |
foll = foll[foll < seq_len_hw]
|
| 2221 |
if len(foll) > 0:
|
| 2222 |
ind_w = head_weights[foll].sum().item()
|
| 2223 |
-
|
| 2224 |
-
|
| 2225 |
-
|
| 2226 |
-
|
| 2227 |
-
|
| 2228 |
-
|
| 2229 |
-
|
| 2230 |
-
|
| 2231 |
-
|
| 2232 |
-
|
| 2233 |
-
#
|
| 2234 |
-
|
| 2235 |
-
|
| 2236 |
-
|
| 2237 |
-
|
| 2238 |
-
|
| 2239 |
-
|
| 2240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2241 |
|
| 2242 |
# Sanitize confidence
|
| 2243 |
confidence = 0.0 if math.isnan(confidence) or math.isinf(confidence) else confidence
|
|
@@ -2267,7 +2326,7 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
|
|
| 2267 |
})
|
| 2268 |
|
| 2269 |
# Return only metadata (matrices fetched on-demand via /matrix endpoint)
|
| 2270 |
-
|
| 2271 |
"head_idx": head_idx,
|
| 2272 |
"entropy": entropy,
|
| 2273 |
"avg_entropy": avg_entropy, # Averaged over all query positions
|
|
@@ -2275,9 +2334,29 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
|
|
| 2275 |
"has_matrices": attention_matrix is not None, # Flag for frontend
|
| 2276 |
"pattern": {
|
| 2277 |
"type": pattern_type,
|
| 2278 |
-
"confidence": confidence
|
| 2279 |
-
} if pattern_type else None
|
| 2280 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2281 |
|
| 2282 |
# Sort by max_weight (return all heads, frontend will decide how many to display)
|
| 2283 |
critical_heads.sort(key=lambda h: h["max_weight"], reverse=True)
|
|
|
|
| 2193 |
entropy = 0.0 if math.isnan(entropy) or math.isinf(entropy) else entropy
|
| 2194 |
avg_entropy = 0.0 if math.isnan(avg_entropy) or math.isinf(avg_entropy) else avg_entropy
|
| 2195 |
|
| 2196 |
+
# Score-all-then-rank head classification
|
| 2197 |
+
# Two dimensions: behaviour type (attention geometry) + code cue (token relevance)
|
| 2198 |
seq_len_hw = head_weights.shape[0]
|
|
|
|
|
|
|
| 2199 |
|
| 2200 |
+
# --- Behaviour type scores (attention geometry) ---
|
| 2201 |
+
behaviour_scores = {}
|
| 2202 |
+
|
| 2203 |
+
# Attention sink: weight on positions 0-2
|
| 2204 |
sink_w = head_weights[:min(3, seq_len_hw)].sum().item()
|
| 2205 |
+
behaviour_scores["attention_sink"] = sink_w
|
| 2206 |
+
|
| 2207 |
+
# Previous token: weight on immediate predecessor
|
| 2208 |
+
prev_tok_w = head_weights[-2].item() if seq_len_hw >= 2 else 0.0
|
| 2209 |
+
behaviour_scores["previous_token"] = prev_tok_w
|
| 2210 |
+
|
| 2211 |
+
# Local: weight within last 5 positions
|
| 2212 |
+
local_w = head_weights[max(0, seq_len_hw - 5):].sum().item() if seq_len_hw > 5 else 0.0
|
| 2213 |
+
behaviour_scores["local"] = local_w
|
| 2214 |
+
|
| 2215 |
+
# Induction: weight on positions following previous occurrences of current token
|
| 2216 |
+
ind_w = 0.0
|
| 2217 |
+
if step > 0 and seq_len_hw > 1:
|
| 2218 |
current_tok = current_ids[0, -1]
|
| 2219 |
prev_occ = (current_ids[0, :-1] == current_tok).nonzero(as_tuple=True)[0]
|
| 2220 |
if len(prev_occ) > 0:
|
|
|
|
| 2222 |
foll = foll[foll < seq_len_hw]
|
| 2223 |
if len(foll) > 0:
|
| 2224 |
ind_w = head_weights[foll].sum().item()
|
| 2225 |
+
behaviour_scores["induction"] = min(1.0, ind_w)
|
| 2226 |
+
|
| 2227 |
+
# Focused: low entropy, concentrated attention (not captured by above)
|
| 2228 |
+
focused_score = max(0.0, 1.0 - entropy) if entropy < 1.5 else 0.0
|
| 2229 |
+
behaviour_scores["focused"] = focused_score
|
| 2230 |
+
|
| 2231 |
+
# Diffuse: high entropy, broad attention
|
| 2232 |
+
diffuse_score = min(1.0, max(0.0, (entropy - 1.0) / 2.0))
|
| 2233 |
+
behaviour_scores["diffuse"] = diffuse_score
|
| 2234 |
+
|
| 2235 |
+
# Pick primary behaviour (highest score, with minimum thresholds)
|
| 2236 |
+
behaviour_thresholds = {
|
| 2237 |
+
"attention_sink": 0.4,
|
| 2238 |
+
"previous_token": 0.7,
|
| 2239 |
+
"local": 0.5,
|
| 2240 |
+
"induction": 0.2,
|
| 2241 |
+
"focused": 0.3,
|
| 2242 |
+
"diffuse": 0.3,
|
| 2243 |
+
}
|
| 2244 |
+
qualified_behaviours = {
|
| 2245 |
+
k: v for k, v in behaviour_scores.items()
|
| 2246 |
+
if v >= behaviour_thresholds.get(k, 0.3)
|
| 2247 |
+
}
|
| 2248 |
+
sorted_behaviours = sorted(qualified_behaviours.items(), key=lambda x: x[1], reverse=True)
|
| 2249 |
+
primary_behaviour = sorted_behaviours[0] if sorted_behaviours else ("diffuse", diffuse_score)
|
| 2250 |
+
secondary_behaviour = sorted_behaviours[1] if len(sorted_behaviours) > 1 else None
|
| 2251 |
+
|
| 2252 |
+
pattern_type = primary_behaviour[0]
|
| 2253 |
+
confidence = primary_behaviour[1]
|
| 2254 |
+
|
| 2255 |
+
# --- Code cue scores (what code tokens are attended to) ---
|
| 2256 |
+
# Decode token texts for code-aware detection (cached per step)
|
| 2257 |
+
if not hasattr(self, '_step_token_texts') or self._step_token_texts_step != step:
|
| 2258 |
+
try:
|
| 2259 |
+
self._step_token_texts = [
|
| 2260 |
+
manager.tokenizer.decode([tid]) for tid in current_ids[0, :seq_len_hw].tolist()
|
| 2261 |
+
]
|
| 2262 |
+
except Exception:
|
| 2263 |
+
self._step_token_texts = []
|
| 2264 |
+
self._step_token_texts_step = step
|
| 2265 |
+
token_texts = self._step_token_texts
|
| 2266 |
+
|
| 2267 |
+
code_cues = {}
|
| 2268 |
+
if len(token_texts) == seq_len_hw:
|
| 2269 |
+
# Delimiter-sensitive: attention to brackets, braces, parens
|
| 2270 |
+
delimiters = {'(', ')', '{', '}', '[', ']', ':', ';', ','}
|
| 2271 |
+
delim_indices = [i for i, t in enumerate(token_texts) if t.strip() in delimiters]
|
| 2272 |
+
if delim_indices:
|
| 2273 |
+
delim_w = head_weights[delim_indices].sum().item()
|
| 2274 |
+
code_cues["delimiter_sensitive"] = delim_w
|
| 2275 |
+
|
| 2276 |
+
# Keyword-sensitive: attention to language keywords
|
| 2277 |
+
keywords = {'def', 'return', 'if', 'else', 'elif', 'for', 'while', 'class',
|
| 2278 |
+
'import', 'from', 'try', 'except', 'with', 'as', 'in', 'not',
|
| 2279 |
+
'and', 'or', 'True', 'False', 'None', 'self', 'yield', 'async',
|
| 2280 |
+
'await', 'lambda', 'raise', 'pass', 'break', 'continue',
|
| 2281 |
+
'function', 'const', 'let', 'var', 'new', 'this'}
|
| 2282 |
+
kw_indices = [i for i, t in enumerate(token_texts) if t.strip() in keywords]
|
| 2283 |
+
if kw_indices:
|
| 2284 |
+
kw_w = head_weights[kw_indices].sum().item()
|
| 2285 |
+
code_cues["keyword_sensitive"] = kw_w
|
| 2286 |
+
|
| 2287 |
+
# Pattern reuse: attention to a contiguous span that appeared earlier
|
| 2288 |
+
# (broader than induction — checks for repeated multi-token sequences)
|
| 2289 |
+
if ind_w > 0.15:
|
| 2290 |
+
code_cues["pattern_reuse"] = min(1.0, ind_w * 1.5)
|
| 2291 |
+
|
| 2292 |
+
# Filter code cues by minimum threshold
|
| 2293 |
+
code_cue_threshold = 0.15
|
| 2294 |
+
qualified_cues = {
|
| 2295 |
+
k: round(v, 4) for k, v in code_cues.items()
|
| 2296 |
+
if v >= code_cue_threshold
|
| 2297 |
+
}
|
| 2298 |
+
sorted_cues = sorted(qualified_cues.items(), key=lambda x: x[1], reverse=True)
|
| 2299 |
+
primary_cue = sorted_cues[0] if sorted_cues else None
|
| 2300 |
|
| 2301 |
# Sanitize confidence
|
| 2302 |
confidence = 0.0 if math.isnan(confidence) or math.isinf(confidence) else confidence
|
|
|
|
| 2326 |
})
|
| 2327 |
|
| 2328 |
# Return only metadata (matrices fetched on-demand via /matrix endpoint)
|
| 2329 |
+
head_entry = {
|
| 2330 |
"head_idx": head_idx,
|
| 2331 |
"entropy": entropy,
|
| 2332 |
"avg_entropy": avg_entropy, # Averaged over all query positions
|
|
|
|
| 2334 |
"has_matrices": attention_matrix is not None, # Flag for frontend
|
| 2335 |
"pattern": {
|
| 2336 |
"type": pattern_type,
|
| 2337 |
+
"confidence": round(confidence, 4),
|
| 2338 |
+
} if pattern_type else None,
|
| 2339 |
+
}
|
| 2340 |
+
# Secondary behaviour (if present and distinct from primary)
|
| 2341 |
+
if secondary_behaviour:
|
| 2342 |
+
head_entry["secondary_behaviour"] = {
|
| 2343 |
+
"type": secondary_behaviour[0],
|
| 2344 |
+
"score": round(secondary_behaviour[1], 4),
|
| 2345 |
+
}
|
| 2346 |
+
# Code cue (separate dimension from behaviour type)
|
| 2347 |
+
if primary_cue:
|
| 2348 |
+
head_entry["code_cue"] = {
|
| 2349 |
+
"type": primary_cue[0],
|
| 2350 |
+
"score": round(primary_cue[1], 4),
|
| 2351 |
+
"evidence": f"{round(primary_cue[1] * 100)}% attention on {primary_cue[0].replace('_', ' ')} tokens",
|
| 2352 |
+
}
|
| 2353 |
+
# Secondary code cue
|
| 2354 |
+
if len(sorted_cues) > 1:
|
| 2355 |
+
head_entry["secondary_cue"] = {
|
| 2356 |
+
"type": sorted_cues[1][0],
|
| 2357 |
+
"score": round(sorted_cues[1][1], 4),
|
| 2358 |
+
}
|
| 2359 |
+
critical_heads.append(head_entry)
|
| 2360 |
|
| 2361 |
# Sort by max_weight (return all heads, frontend will decide how many to display)
|
| 2362 |
critical_heads.sort(key=lambda h: h["max_weight"], reverse=True)
|