gary-boon Claude Opus 4.6 committed on
Commit
54d9b6e
·
1 Parent(s): d8d197a

Add deep inspection: data-driven pattern classification, attention/MLP tracking, logit lens

Browse files

- Replace hardcoded position-based head/layer pattern classification with
data-driven detection (sink, previous-token, local, induction, positional, semantic)
- Layer patterns computed as confidence-weighted majority vote of head patterns
- Add forward hooks for attention and MLP output norms per layer
- Add logit lens computation at sampled layers (every n_layers//5)
- Compute per-head sink weight, local weight, and induction weight metrics

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. backend/model_service.py +205 -43
backend/model_service.py CHANGED
@@ -2065,26 +2065,51 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
2065
  entropy = 0.0 if math.isnan(entropy) or math.isinf(entropy) else entropy
2066
  avg_entropy = 0.0 if math.isnan(avg_entropy) or math.isinf(avg_entropy) else avg_entropy
2067
 
2068
- # Classify pattern
 
2069
  pattern_type = None
2070
  confidence = 0.0
2071
 
2072
- # Induction pattern: high attention to previous similar tokens
2073
- if step > 0 and max_weight > 0.8:
2074
- pattern_type = "induction"
2075
- confidence = max_weight
2076
- # Positional pattern: attention focused on nearby tokens
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2077
  elif entropy < 1.0:
2078
  pattern_type = "positional"
2079
  confidence = 1.0 - entropy
2080
- # Semantic pattern: broader attention with moderate entropy
2081
- elif 1.0 <= entropy < 2.5:
2082
  pattern_type = "semantic"
2083
- confidence = min(1.0, entropy / 2.5)
2084
- # Previous token pattern: sharp focus on immediate predecessor
2085
- elif max_weight > 0.9 and head_weights[-2].item() > 0.85:
2086
- pattern_type = "previous_token"
2087
- confidence = head_weights[-2].item()
2088
 
2089
  # Sanitize confidence
2090
  confidence = 0.0 if math.isnan(confidence) or math.isinf(confidence) else confidence
@@ -2129,17 +2154,21 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
2129
  # Sort by max_weight (return all heads, frontend will decide how many to display)
2130
  critical_heads.sort(key=lambda h: h["max_weight"], reverse=True)
2131
 
2132
- # Detect layer-level pattern (percentage-based for any layer count)
 
 
 
 
 
 
2133
  layer_pattern = None
2134
- layer_fraction = (layer_idx + 1) / n_layers # 1-indexed fraction
2135
- if layer_idx == 0:
2136
- layer_pattern = {"type": "positional", "confidence": 0.78}
2137
- elif layer_fraction <= 0.25 and step > 0:
2138
- layer_pattern = {"type": "previous_token", "confidence": 0.65}
2139
- elif layer_fraction <= 0.75:
2140
- layer_pattern = {"type": "induction", "confidence": 0.87}
2141
- else:
2142
- layer_pattern = {"type": "semantic", "confidence": 0.92}
2143
 
2144
  layer_data_this_token.append({
2145
  "layer_idx": layer_idx,
@@ -2506,6 +2535,55 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
2506
  except Exception as hook_error:
2507
  logger.warning(f"Could not register QKV hooks: {hook_error}")
2508
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2509
  with torch.no_grad():
2510
  current_ids = inputs["input_ids"]
2511
 
@@ -2520,6 +2598,8 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
2520
  await asyncio.sleep(0)
2521
 
2522
  qkv_captures.clear()
 
 
2523
 
2524
  # Forward pass with full outputs
2525
  outputs = manager.model(
@@ -2775,13 +2855,37 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
2775
  # Previous-token weights for pattern detection: [n_heads]
2776
  all_prev_token_weights = all_last_row[:, -2] if all_last_row.shape[1] >= 2 else torch.zeros(num_heads_layer, device=layer_attn.device)
2777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2778
  # Single bulk transfer: all head metrics to CPU
2779
- head_metrics_gpu = torch.stack([all_max_weights, all_entropies, all_avg_entropies, all_prev_token_weights]) # [4, n_heads]
 
 
 
2780
  head_metrics_cpu = head_metrics_gpu.cpu().tolist() # one sync point
2781
  max_weights_list = head_metrics_cpu[0]
2782
  entropies_list = head_metrics_cpu[1]
2783
  avg_entropies_list = head_metrics_cpu[2]
2784
  prev_token_list = head_metrics_cpu[3]
 
 
 
2785
 
2786
  # Bulk transfer attention matrices to CPU: one .cpu() for entire layer
2787
  layer_attn_cpu = layer_attn.cpu().float().numpy() # [n_heads, seq_len, seq_len]
@@ -2796,26 +2900,42 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
2796
  ent = entropies_list[head_idx]
2797
  avg_ent = avg_entropies_list[head_idx]
2798
  ptw = prev_token_list[head_idx]
 
 
 
2799
 
2800
  # Sanitize
2801
  mw = 0.0 if math.isnan(mw) or math.isinf(mw) else mw
2802
  ent = 0.0 if math.isnan(ent) or math.isinf(ent) else ent
2803
  avg_ent = 0.0 if math.isnan(avg_ent) or math.isinf(avg_ent) else avg_ent
2804
 
 
2805
  pattern_type = None
2806
  confidence = 0.0
2807
- if step > 0 and mw > 0.8:
 
 
 
 
 
 
 
 
 
 
 
 
 
2808
  pattern_type = "induction"
2809
- confidence = mw
 
2810
  elif ent < 1.0:
2811
  pattern_type = "positional"
2812
  confidence = 1.0 - ent
2813
- elif 1.0 <= ent < 2.5:
 
2814
  pattern_type = "semantic"
2815
- confidence = min(1.0, ent / 2.5)
2816
- elif mw > 0.9 and ptw > 0.85:
2817
- pattern_type = "previous_token"
2818
- confidence = ptw
2819
  confidence = 0.0 if math.isnan(confidence) or math.isinf(confidence) else confidence
2820
 
2821
  attention_matrix = layer_attn_cpu[head_idx]
@@ -2846,18 +2966,23 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
2846
 
2847
  critical_heads.sort(key=lambda h: h["max_weight"], reverse=True)
2848
 
 
 
 
 
 
 
 
2849
  layer_pattern = None
2850
- layer_fraction = (layer_idx + 1) / n_layers
2851
- if layer_idx == 0:
2852
- layer_pattern = {"type": "positional", "confidence": 0.78}
2853
- elif layer_fraction <= 0.25 and step > 0:
2854
- layer_pattern = {"type": "previous_token", "confidence": 0.65}
2855
- elif layer_fraction <= 0.75:
2856
- layer_pattern = {"type": "induction", "confidence": 0.87}
2857
- else:
2858
- layer_pattern = {"type": "semantic", "confidence": 0.92}
2859
-
2860
- layer_data_this_token.append({
2861
  "layer_idx": layer_idx,
2862
  "pattern": layer_pattern,
2863
  "critical_heads": critical_heads,
@@ -2865,7 +2990,44 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
2865
  "activation_entropy": activation_entropy,
2866
  "hidden_state_norm": hidden_state_norm,
2867
  "delta_norm": delta_norm
2868
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2869
 
2870
  layer_data_by_token.append(layer_data_this_token)
2871
 
 
2065
  entropy = 0.0 if math.isnan(entropy) or math.isinf(entropy) else entropy
2066
  avg_entropy = 0.0 if math.isnan(avg_entropy) or math.isinf(avg_entropy) else avg_entropy
2067
 
2068
+ # Data-driven head pattern classification (priority order)
2069
+ seq_len_hw = head_weights.shape[0]
2070
  pattern_type = None
2071
  confidence = 0.0
2072
 
2073
+ # 1. Attention sink: >50% weight on positions 0-2
2074
+ sink_w = head_weights[:min(3, seq_len_hw)].sum().item()
2075
+ if sink_w > 0.5:
2076
+ pattern_type = "attention_sink"
2077
+ confidence = sink_w
2078
+ # 2. Previous token: sharp focus on immediate predecessor
2079
+ elif max_weight > 0.9 and head_weights[-2].item() > 0.85:
2080
+ pattern_type = "previous_token"
2081
+ confidence = head_weights[-2].item()
2082
+ # 3. Local: >80% weight within 5 positions of query
2083
+ elif seq_len_hw > 5 and head_weights[max(0, seq_len_hw - 5):].sum().item() > 0.8:
2084
+ pattern_type = "local"
2085
+ confidence = head_weights[max(0, seq_len_hw - 5):].sum().item()
2086
+ # 4. Induction: attends to positions following previous occurrences of current token
2087
+ elif step > 0:
2088
+ current_tok = current_ids[0, -1]
2089
+ prev_occ = (current_ids[0, :-1] == current_tok).nonzero(as_tuple=True)[0]
2090
+ if len(prev_occ) > 0:
2091
+ foll = prev_occ + 1
2092
+ foll = foll[foll < seq_len_hw]
2093
+ if len(foll) > 0:
2094
+ ind_w = head_weights[foll].sum().item()
2095
+ if ind_w > 0.3:
2096
+ pattern_type = "induction"
2097
+ confidence = min(1.0, ind_w)
2098
+ if pattern_type is None:
2099
+ if entropy < 1.0:
2100
+ pattern_type = "positional"
2101
+ confidence = 1.0 - entropy
2102
+ elif entropy >= 1.0:
2103
+ pattern_type = "semantic"
2104
+ confidence = min(1.0, 0.5)
2105
+ # 5. Positional: low entropy, focused attention
2106
  elif entropy < 1.0:
2107
  pattern_type = "positional"
2108
  confidence = 1.0 - entropy
2109
+ # 6. Semantic: broad attention (fallback)
2110
+ elif entropy >= 1.0:
2111
  pattern_type = "semantic"
2112
+ confidence = min(1.0, 0.5)
 
 
 
 
2113
 
2114
  # Sanitize confidence
2115
  confidence = 0.0 if math.isnan(confidence) or math.isinf(confidence) else confidence
 
2154
  # Sort by max_weight (return all heads, frontend will decide how many to display)
2155
  critical_heads.sort(key=lambda h: h["max_weight"], reverse=True)
2156
 
2157
+ # Layer-level pattern: majority vote of head patterns, weighted by confidence
2158
+ pattern_votes = {}
2159
+ for h in critical_heads:
2160
+ if h["pattern"] and h["pattern"]["type"]:
2161
+ pt = h["pattern"]["type"]
2162
+ pc = h["pattern"]["confidence"]
2163
+ pattern_votes[pt] = pattern_votes.get(pt, 0.0) + pc
2164
  layer_pattern = None
2165
+ if pattern_votes:
2166
+ best_type = max(pattern_votes, key=pattern_votes.get)
2167
+ total_conf = sum(pattern_votes.values())
2168
+ layer_pattern = {
2169
+ "type": best_type,
2170
+ "confidence": round(pattern_votes[best_type] / total_conf, 3) if total_conf > 0 else 0.0
2171
+ }
 
 
2172
 
2173
  layer_data_this_token.append({
2174
  "layer_idx": layer_idx,
 
2535
  except Exception as hook_error:
2536
  logger.warning(f"Could not register QKV hooks: {hook_error}")
2537
 
2538
+ # Phase 4: Hooks for attention and MLP output norms
2539
+ attn_output_norms = {}
2540
+ mlp_output_norms = {}
2541
+
2542
+ def make_attn_output_hook(layer_idx):
2543
+ def hook(module, input, output):
2544
+ try:
2545
+ out = output[0] if isinstance(output, tuple) else output
2546
+ if out.dim() == 3:
2547
+ attn_output_norms[layer_idx] = torch.norm(out[0, -1]).item()
2548
+ except Exception:
2549
+ pass
2550
+ return hook
2551
+
2552
+ def make_mlp_output_hook(layer_idx):
2553
+ def hook(module, input, output):
2554
+ try:
2555
+ out = output[0] if isinstance(output, tuple) else output
2556
+ if out.dim() == 3:
2557
+ mlp_output_norms[layer_idx] = torch.norm(out[0, -1]).item()
2558
+ elif out.dim() == 2:
2559
+ mlp_output_norms[layer_idx] = torch.norm(out[-1]).item()
2560
+ except Exception:
2561
+ pass
2562
+ return hook
2563
+
2564
+ try:
2565
+ # CodeGen style
2566
+ if hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
2567
+ for layer_idx, layer in enumerate(manager.model.transformer.h):
2568
+ if hasattr(layer, 'attn'):
2569
+ hook = layer.attn.register_forward_hook(make_attn_output_hook(layer_idx))
2570
+ hooks.append(hook)
2571
+ if hasattr(layer, 'mlp'):
2572
+ hook = layer.mlp.register_forward_hook(make_mlp_output_hook(layer_idx))
2573
+ hooks.append(hook)
2574
+ # Mistral/LLaMA style
2575
+ elif hasattr(manager.model, 'model') and hasattr(manager.model.model, 'layers'):
2576
+ for layer_idx, layer in enumerate(manager.model.model.layers):
2577
+ if hasattr(layer, 'self_attn'):
2578
+ hook = layer.self_attn.register_forward_hook(make_attn_output_hook(layer_idx))
2579
+ hooks.append(hook)
2580
+ if hasattr(layer, 'mlp'):
2581
+ hook = layer.mlp.register_forward_hook(make_mlp_output_hook(layer_idx))
2582
+ hooks.append(hook)
2583
+ logger.info(f"Registered attn/MLP output hooks for contribution tracking")
2584
+ except Exception as hook_error:
2585
+ logger.warning(f"Could not register attn/MLP hooks: {hook_error}")
2586
+
2587
  with torch.no_grad():
2588
  current_ids = inputs["input_ids"]
2589
 
 
2598
  await asyncio.sleep(0)
2599
 
2600
  qkv_captures.clear()
2601
+ attn_output_norms.clear()
2602
+ mlp_output_norms.clear()
2603
 
2604
  # Forward pass with full outputs
2605
  outputs = manager.model(
 
2855
  # Previous-token weights for pattern detection: [n_heads]
2856
  all_prev_token_weights = all_last_row[:, -2] if all_last_row.shape[1] >= 2 else torch.zeros(num_heads_layer, device=layer_attn.device)
2857
 
2858
+ # Attention sink weights: sum of attention on positions 0-2 per head [n_heads]
2859
+ seq_len_attn = all_last_row.shape[1]
2860
+ all_sink_weights = all_last_row[:, :min(3, seq_len_attn)].sum(dim=-1)
2861
+
2862
+ # Local attention weights: sum within 5 positions of query per head [n_heads]
2863
+ all_local_weights = all_last_row[:, max(0, seq_len_attn - 5):].sum(dim=-1) if seq_len_attn > 5 else torch.ones(num_heads_layer, device=layer_attn.device)
2864
+
2865
+ # Induction detection: attention to positions following previous occurrences of current token
2866
+ all_induction_weights = torch.zeros(num_heads_layer, device=layer_attn.device)
2867
+ if step > 0:
2868
+ current_token = current_ids[0, -1]
2869
+ prev_occurrences = (current_ids[0, :-1] == current_token).nonzero(as_tuple=True)[0]
2870
+ if len(prev_occurrences) > 0:
2871
+ following_positions = prev_occurrences + 1
2872
+ following_positions = following_positions[following_positions < seq_len_attn]
2873
+ if len(following_positions) > 0:
2874
+ all_induction_weights = all_last_row[:, following_positions].sum(dim=-1)
2875
+
2876
  # Single bulk transfer: all head metrics to CPU
2877
+ head_metrics_gpu = torch.stack([
2878
+ all_max_weights, all_entropies, all_avg_entropies, all_prev_token_weights,
2879
+ all_sink_weights, all_local_weights, all_induction_weights
2880
+ ]) # [7, n_heads]
2881
  head_metrics_cpu = head_metrics_gpu.cpu().tolist() # one sync point
2882
  max_weights_list = head_metrics_cpu[0]
2883
  entropies_list = head_metrics_cpu[1]
2884
  avg_entropies_list = head_metrics_cpu[2]
2885
  prev_token_list = head_metrics_cpu[3]
2886
+ sink_weights_list = head_metrics_cpu[4]
2887
+ local_weights_list = head_metrics_cpu[5]
2888
+ induction_weights_list = head_metrics_cpu[6]
2889
 
2890
  # Bulk transfer attention matrices to CPU: one .cpu() for entire layer
2891
  layer_attn_cpu = layer_attn.cpu().float().numpy() # [n_heads, seq_len, seq_len]
 
2900
  ent = entropies_list[head_idx]
2901
  avg_ent = avg_entropies_list[head_idx]
2902
  ptw = prev_token_list[head_idx]
2903
+ skw = sink_weights_list[head_idx]
2904
+ lcw = local_weights_list[head_idx]
2905
+ idw = induction_weights_list[head_idx]
2906
 
2907
  # Sanitize
2908
  mw = 0.0 if math.isnan(mw) or math.isinf(mw) else mw
2909
  ent = 0.0 if math.isnan(ent) or math.isinf(ent) else ent
2910
  avg_ent = 0.0 if math.isnan(avg_ent) or math.isinf(avg_ent) else avg_ent
2911
 
2912
+ # Data-driven head pattern classification (priority order)
2913
  pattern_type = None
2914
  confidence = 0.0
2915
+ # 1. Attention sink: >50% weight on positions 0-2
2916
+ if skw > 0.5:
2917
+ pattern_type = "attention_sink"
2918
+ confidence = skw
2919
+ # 2. Previous token: sharp focus on immediate predecessor
2920
+ elif mw > 0.9 and ptw > 0.85:
2921
+ pattern_type = "previous_token"
2922
+ confidence = ptw
2923
+ # 3. Local: >80% weight within 5 positions of query
2924
+ elif seq_len_attn > 5 and lcw > 0.8:
2925
+ pattern_type = "local"
2926
+ confidence = lcw
2927
+ # 4. Induction: attends to positions following previous occurrences of current token
2928
+ elif step > 0 and idw > 0.3:
2929
  pattern_type = "induction"
2930
+ confidence = min(1.0, idw)
2931
+ # 5. Positional: low entropy, focused attention
2932
  elif ent < 1.0:
2933
  pattern_type = "positional"
2934
  confidence = 1.0 - ent
2935
+ # 6. Semantic: broad attention (fallback)
2936
+ elif ent >= 1.0:
2937
  pattern_type = "semantic"
2938
+ confidence = min(1.0, 0.5)
 
 
 
2939
  confidence = 0.0 if math.isnan(confidence) or math.isinf(confidence) else confidence
2940
 
2941
  attention_matrix = layer_attn_cpu[head_idx]
 
2966
 
2967
  critical_heads.sort(key=lambda h: h["max_weight"], reverse=True)
2968
 
2969
+ # Layer-level pattern: majority vote of head patterns, weighted by confidence
2970
+ pattern_votes = {}
2971
+ for h in critical_heads:
2972
+ if h["pattern"] and h["pattern"]["type"]:
2973
+ pt = h["pattern"]["type"]
2974
+ pc = h["pattern"]["confidence"]
2975
+ pattern_votes[pt] = pattern_votes.get(pt, 0.0) + pc
2976
  layer_pattern = None
2977
+ if pattern_votes:
2978
+ best_type = max(pattern_votes, key=pattern_votes.get)
2979
+ total_conf = sum(pattern_votes.values())
2980
+ layer_pattern = {
2981
+ "type": best_type,
2982
+ "confidence": round(pattern_votes[best_type] / total_conf, 3) if total_conf > 0 else 0.0
2983
+ }
2984
+
2985
+ layer_entry = {
 
 
2986
  "layer_idx": layer_idx,
2987
  "pattern": layer_pattern,
2988
  "critical_heads": critical_heads,
 
2990
  "activation_entropy": activation_entropy,
2991
  "hidden_state_norm": hidden_state_norm,
2992
  "delta_norm": delta_norm
2993
+ }
2994
+ # Phase 4: Attention and MLP output norms
2995
+ if layer_idx in attn_output_norms:
2996
+ layer_entry["attn_output_norm"] = attn_output_norms[layer_idx]
2997
+ if layer_idx in mlp_output_norms:
2998
+ layer_entry["mlp_output_norm"] = mlp_output_norms[layer_idx]
2999
+
3000
+ # Phase 5: Logit lens at sampled layers (every 8th layer)
3001
+ logit_lens_stride = max(1, n_layers // 5)
3002
+ if layer_idx % logit_lens_stride == 0 or layer_idx == n_layers - 1:
3003
+ try:
3004
+ hidden_for_lens = current_hidden[-1].unsqueeze(0) # [1, hidden_dim]
3005
+ # Apply final layer norm then project through lm_head
3006
+ if hasattr(manager.model, 'model') and hasattr(manager.model.model, 'norm'):
3007
+ normed = manager.model.model.norm(hidden_for_lens)
3008
+ lens_logits = manager.model.lm_head(normed)[0] # [vocab_size]
3009
+ elif hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'ln_f'):
3010
+ normed = manager.model.transformer.ln_f(hidden_for_lens)
3011
+ lens_logits = manager.model.lm_head(normed)[0]
3012
+ else:
3013
+ lens_logits = None
3014
+
3015
+ if lens_logits is not None:
3016
+ lens_probs = torch.softmax(lens_logits, dim=-1)
3017
+ top_probs, top_ids = torch.topk(lens_probs, k=5)
3018
+ top_probs_list = top_probs.cpu().tolist()
3019
+ top_ids_list = top_ids.cpu().tolist()
3020
+ lens_entries = []
3021
+ for tp, tid in zip(top_probs_list, top_ids_list):
3022
+ lens_entries.append({
3023
+ "token": manager.tokenizer.decode([tid], skip_special_tokens=False),
3024
+ "probability": tp
3025
+ })
3026
+ layer_entry["logit_lens_top"] = lens_entries
3027
+ except Exception as lens_err:
3028
+ logger.debug(f"Logit lens error at layer {layer_idx}: {lens_err}")
3029
+
3030
+ layer_data_this_token.append(layer_entry)
3031
 
3032
  layer_data_by_token.append(layer_data_this_token)
3033