Spaces:
Paused
Paused
gary-boon Claude Opus 4.6 commited on
Commit ·
54d9b6e
1
Parent(s): d8d197a
Add deep inspection: data-driven pattern classification, attention/MLP tracking, logit lens
Browse files- Replace hardcoded position-based head/layer pattern classification with
data-driven detection (sink, previous-token, local, induction, positional, semantic)
- Layer patterns computed as confidence-weighted majority vote of head patterns
- Add forward hooks for attention and MLP output norms per layer
- Add logit lens computation at sampled layers (every n_layers//5)
- Compute per-head sink weight, local weight, and induction weight metrics
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- backend/model_service.py +205 -43
backend/model_service.py
CHANGED
|
@@ -2065,26 +2065,51 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
|
|
| 2065 |
entropy = 0.0 if math.isnan(entropy) or math.isinf(entropy) else entropy
|
| 2066 |
avg_entropy = 0.0 if math.isnan(avg_entropy) or math.isinf(avg_entropy) else avg_entropy
|
| 2067 |
|
| 2068 |
-
#
|
|
|
|
| 2069 |
pattern_type = None
|
| 2070 |
confidence = 0.0
|
| 2071 |
|
| 2072 |
-
#
|
| 2073 |
-
|
| 2074 |
-
|
| 2075 |
-
|
| 2076 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2077 |
elif entropy < 1.0:
|
| 2078 |
pattern_type = "positional"
|
| 2079 |
confidence = 1.0 - entropy
|
| 2080 |
-
#
|
| 2081 |
-
elif
|
| 2082 |
pattern_type = "semantic"
|
| 2083 |
-
confidence = min(1.0,
|
| 2084 |
-
# Previous token pattern: sharp focus on immediate predecessor
|
| 2085 |
-
elif max_weight > 0.9 and head_weights[-2].item() > 0.85:
|
| 2086 |
-
pattern_type = "previous_token"
|
| 2087 |
-
confidence = head_weights[-2].item()
|
| 2088 |
|
| 2089 |
# Sanitize confidence
|
| 2090 |
confidence = 0.0 if math.isnan(confidence) or math.isinf(confidence) else confidence
|
|
@@ -2129,17 +2154,21 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
|
|
| 2129 |
# Sort by max_weight (return all heads, frontend will decide how many to display)
|
| 2130 |
critical_heads.sort(key=lambda h: h["max_weight"], reverse=True)
|
| 2131 |
|
| 2132 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2133 |
layer_pattern = None
|
| 2134 |
-
|
| 2135 |
-
|
| 2136 |
-
|
| 2137 |
-
|
| 2138 |
-
|
| 2139 |
-
|
| 2140 |
-
|
| 2141 |
-
else:
|
| 2142 |
-
layer_pattern = {"type": "semantic", "confidence": 0.92}
|
| 2143 |
|
| 2144 |
layer_data_this_token.append({
|
| 2145 |
"layer_idx": layer_idx,
|
|
@@ -2506,6 +2535,55 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
|
|
| 2506 |
except Exception as hook_error:
|
| 2507 |
logger.warning(f"Could not register QKV hooks: {hook_error}")
|
| 2508 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2509 |
with torch.no_grad():
|
| 2510 |
current_ids = inputs["input_ids"]
|
| 2511 |
|
|
@@ -2520,6 +2598,8 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
|
|
| 2520 |
await asyncio.sleep(0)
|
| 2521 |
|
| 2522 |
qkv_captures.clear()
|
|
|
|
|
|
|
| 2523 |
|
| 2524 |
# Forward pass with full outputs
|
| 2525 |
outputs = manager.model(
|
|
@@ -2775,13 +2855,37 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
|
|
| 2775 |
# Previous-token weights for pattern detection: [n_heads]
|
| 2776 |
all_prev_token_weights = all_last_row[:, -2] if all_last_row.shape[1] >= 2 else torch.zeros(num_heads_layer, device=layer_attn.device)
|
| 2777 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2778 |
# Single bulk transfer: all head metrics to CPU
|
| 2779 |
-
head_metrics_gpu = torch.stack([
|
|
|
|
|
|
|
|
|
|
| 2780 |
head_metrics_cpu = head_metrics_gpu.cpu().tolist() # one sync point
|
| 2781 |
max_weights_list = head_metrics_cpu[0]
|
| 2782 |
entropies_list = head_metrics_cpu[1]
|
| 2783 |
avg_entropies_list = head_metrics_cpu[2]
|
| 2784 |
prev_token_list = head_metrics_cpu[3]
|
|
|
|
|
|
|
|
|
|
| 2785 |
|
| 2786 |
# Bulk transfer attention matrices to CPU: one .cpu() for entire layer
|
| 2787 |
layer_attn_cpu = layer_attn.cpu().float().numpy() # [n_heads, seq_len, seq_len]
|
|
@@ -2796,26 +2900,42 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
|
|
| 2796 |
ent = entropies_list[head_idx]
|
| 2797 |
avg_ent = avg_entropies_list[head_idx]
|
| 2798 |
ptw = prev_token_list[head_idx]
|
|
|
|
|
|
|
|
|
|
| 2799 |
|
| 2800 |
# Sanitize
|
| 2801 |
mw = 0.0 if math.isnan(mw) or math.isinf(mw) else mw
|
| 2802 |
ent = 0.0 if math.isnan(ent) or math.isinf(ent) else ent
|
| 2803 |
avg_ent = 0.0 if math.isnan(avg_ent) or math.isinf(avg_ent) else avg_ent
|
| 2804 |
|
|
|
|
| 2805 |
pattern_type = None
|
| 2806 |
confidence = 0.0
|
| 2807 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2808 |
pattern_type = "induction"
|
| 2809 |
-
confidence =
|
|
|
|
| 2810 |
elif ent < 1.0:
|
| 2811 |
pattern_type = "positional"
|
| 2812 |
confidence = 1.0 - ent
|
| 2813 |
-
|
|
|
|
| 2814 |
pattern_type = "semantic"
|
| 2815 |
-
confidence = min(1.0,
|
| 2816 |
-
elif mw > 0.9 and ptw > 0.85:
|
| 2817 |
-
pattern_type = "previous_token"
|
| 2818 |
-
confidence = ptw
|
| 2819 |
confidence = 0.0 if math.isnan(confidence) or math.isinf(confidence) else confidence
|
| 2820 |
|
| 2821 |
attention_matrix = layer_attn_cpu[head_idx]
|
|
@@ -2846,18 +2966,23 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
|
|
| 2846 |
|
| 2847 |
critical_heads.sort(key=lambda h: h["max_weight"], reverse=True)
|
| 2848 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2849 |
layer_pattern = None
|
| 2850 |
-
|
| 2851 |
-
|
| 2852 |
-
|
| 2853 |
-
|
| 2854 |
-
|
| 2855 |
-
|
| 2856 |
-
|
| 2857 |
-
|
| 2858 |
-
|
| 2859 |
-
|
| 2860 |
-
layer_data_this_token.append({
|
| 2861 |
"layer_idx": layer_idx,
|
| 2862 |
"pattern": layer_pattern,
|
| 2863 |
"critical_heads": critical_heads,
|
|
@@ -2865,7 +2990,44 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
|
|
| 2865 |
"activation_entropy": activation_entropy,
|
| 2866 |
"hidden_state_norm": hidden_state_norm,
|
| 2867 |
"delta_norm": delta_norm
|
| 2868 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2869 |
|
| 2870 |
layer_data_by_token.append(layer_data_this_token)
|
| 2871 |
|
|
|
|
| 2065 |
entropy = 0.0 if math.isnan(entropy) or math.isinf(entropy) else entropy
|
| 2066 |
avg_entropy = 0.0 if math.isnan(avg_entropy) or math.isinf(avg_entropy) else avg_entropy
|
| 2067 |
|
| 2068 |
+
# Data-driven head pattern classification (priority order)
|
| 2069 |
+
seq_len_hw = head_weights.shape[0]
|
| 2070 |
pattern_type = None
|
| 2071 |
confidence = 0.0
|
| 2072 |
|
| 2073 |
+
# 1. Attention sink: >50% weight on positions 0-2
|
| 2074 |
+
sink_w = head_weights[:min(3, seq_len_hw)].sum().item()
|
| 2075 |
+
if sink_w > 0.5:
|
| 2076 |
+
pattern_type = "attention_sink"
|
| 2077 |
+
confidence = sink_w
|
| 2078 |
+
# 2. Previous token: sharp focus on immediate predecessor
|
| 2079 |
+
elif max_weight > 0.9 and head_weights[-2].item() > 0.85:
|
| 2080 |
+
pattern_type = "previous_token"
|
| 2081 |
+
confidence = head_weights[-2].item()
|
| 2082 |
+
# 3. Local: >80% weight within 5 positions of query
|
| 2083 |
+
elif seq_len_hw > 5 and head_weights[max(0, seq_len_hw - 5):].sum().item() > 0.8:
|
| 2084 |
+
pattern_type = "local"
|
| 2085 |
+
confidence = head_weights[max(0, seq_len_hw - 5):].sum().item()
|
| 2086 |
+
# 4. Induction: attends to positions following previous occurrences of current token
|
| 2087 |
+
elif step > 0:
|
| 2088 |
+
current_tok = current_ids[0, -1]
|
| 2089 |
+
prev_occ = (current_ids[0, :-1] == current_tok).nonzero(as_tuple=True)[0]
|
| 2090 |
+
if len(prev_occ) > 0:
|
| 2091 |
+
foll = prev_occ + 1
|
| 2092 |
+
foll = foll[foll < seq_len_hw]
|
| 2093 |
+
if len(foll) > 0:
|
| 2094 |
+
ind_w = head_weights[foll].sum().item()
|
| 2095 |
+
if ind_w > 0.3:
|
| 2096 |
+
pattern_type = "induction"
|
| 2097 |
+
confidence = min(1.0, ind_w)
|
| 2098 |
+
if pattern_type is None:
|
| 2099 |
+
if entropy < 1.0:
|
| 2100 |
+
pattern_type = "positional"
|
| 2101 |
+
confidence = 1.0 - entropy
|
| 2102 |
+
elif entropy >= 1.0:
|
| 2103 |
+
pattern_type = "semantic"
|
| 2104 |
+
confidence = min(1.0, 0.5)
|
| 2105 |
+
# 5. Positional: low entropy, focused attention
|
| 2106 |
elif entropy < 1.0:
|
| 2107 |
pattern_type = "positional"
|
| 2108 |
confidence = 1.0 - entropy
|
| 2109 |
+
# 6. Semantic: broad attention (fallback)
|
| 2110 |
+
elif entropy >= 1.0:
|
| 2111 |
pattern_type = "semantic"
|
| 2112 |
+
confidence = min(1.0, 0.5)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2113 |
|
| 2114 |
# Sanitize confidence
|
| 2115 |
confidence = 0.0 if math.isnan(confidence) or math.isinf(confidence) else confidence
|
|
|
|
| 2154 |
# Sort by max_weight (return all heads, frontend will decide how many to display)
|
| 2155 |
critical_heads.sort(key=lambda h: h["max_weight"], reverse=True)
|
| 2156 |
|
| 2157 |
+
# Layer-level pattern: majority vote of head patterns, weighted by confidence
|
| 2158 |
+
pattern_votes = {}
|
| 2159 |
+
for h in critical_heads:
|
| 2160 |
+
if h["pattern"] and h["pattern"]["type"]:
|
| 2161 |
+
pt = h["pattern"]["type"]
|
| 2162 |
+
pc = h["pattern"]["confidence"]
|
| 2163 |
+
pattern_votes[pt] = pattern_votes.get(pt, 0.0) + pc
|
| 2164 |
layer_pattern = None
|
| 2165 |
+
if pattern_votes:
|
| 2166 |
+
best_type = max(pattern_votes, key=pattern_votes.get)
|
| 2167 |
+
total_conf = sum(pattern_votes.values())
|
| 2168 |
+
layer_pattern = {
|
| 2169 |
+
"type": best_type,
|
| 2170 |
+
"confidence": round(pattern_votes[best_type] / total_conf, 3) if total_conf > 0 else 0.0
|
| 2171 |
+
}
|
|
|
|
|
|
|
| 2172 |
|
| 2173 |
layer_data_this_token.append({
|
| 2174 |
"layer_idx": layer_idx,
|
|
|
|
| 2535 |
except Exception as hook_error:
|
| 2536 |
logger.warning(f"Could not register QKV hooks: {hook_error}")
|
| 2537 |
|
| 2538 |
+
# Phase 4: Hooks for attention and MLP output norms
|
| 2539 |
+
attn_output_norms = {}
|
| 2540 |
+
mlp_output_norms = {}
|
| 2541 |
+
|
| 2542 |
+
def make_attn_output_hook(layer_idx):
|
| 2543 |
+
def hook(module, input, output):
|
| 2544 |
+
try:
|
| 2545 |
+
out = output[0] if isinstance(output, tuple) else output
|
| 2546 |
+
if out.dim() == 3:
|
| 2547 |
+
attn_output_norms[layer_idx] = torch.norm(out[0, -1]).item()
|
| 2548 |
+
except Exception:
|
| 2549 |
+
pass
|
| 2550 |
+
return hook
|
| 2551 |
+
|
| 2552 |
+
def make_mlp_output_hook(layer_idx):
|
| 2553 |
+
def hook(module, input, output):
|
| 2554 |
+
try:
|
| 2555 |
+
out = output[0] if isinstance(output, tuple) else output
|
| 2556 |
+
if out.dim() == 3:
|
| 2557 |
+
mlp_output_norms[layer_idx] = torch.norm(out[0, -1]).item()
|
| 2558 |
+
elif out.dim() == 2:
|
| 2559 |
+
mlp_output_norms[layer_idx] = torch.norm(out[-1]).item()
|
| 2560 |
+
except Exception:
|
| 2561 |
+
pass
|
| 2562 |
+
return hook
|
| 2563 |
+
|
| 2564 |
+
try:
|
| 2565 |
+
# CodeGen style
|
| 2566 |
+
if hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
|
| 2567 |
+
for layer_idx, layer in enumerate(manager.model.transformer.h):
|
| 2568 |
+
if hasattr(layer, 'attn'):
|
| 2569 |
+
hook = layer.attn.register_forward_hook(make_attn_output_hook(layer_idx))
|
| 2570 |
+
hooks.append(hook)
|
| 2571 |
+
if hasattr(layer, 'mlp'):
|
| 2572 |
+
hook = layer.mlp.register_forward_hook(make_mlp_output_hook(layer_idx))
|
| 2573 |
+
hooks.append(hook)
|
| 2574 |
+
# Mistral/LLaMA style
|
| 2575 |
+
elif hasattr(manager.model, 'model') and hasattr(manager.model.model, 'layers'):
|
| 2576 |
+
for layer_idx, layer in enumerate(manager.model.model.layers):
|
| 2577 |
+
if hasattr(layer, 'self_attn'):
|
| 2578 |
+
hook = layer.self_attn.register_forward_hook(make_attn_output_hook(layer_idx))
|
| 2579 |
+
hooks.append(hook)
|
| 2580 |
+
if hasattr(layer, 'mlp'):
|
| 2581 |
+
hook = layer.mlp.register_forward_hook(make_mlp_output_hook(layer_idx))
|
| 2582 |
+
hooks.append(hook)
|
| 2583 |
+
logger.info(f"Registered attn/MLP output hooks for contribution tracking")
|
| 2584 |
+
except Exception as hook_error:
|
| 2585 |
+
logger.warning(f"Could not register attn/MLP hooks: {hook_error}")
|
| 2586 |
+
|
| 2587 |
with torch.no_grad():
|
| 2588 |
current_ids = inputs["input_ids"]
|
| 2589 |
|
|
|
|
| 2598 |
await asyncio.sleep(0)
|
| 2599 |
|
| 2600 |
qkv_captures.clear()
|
| 2601 |
+
attn_output_norms.clear()
|
| 2602 |
+
mlp_output_norms.clear()
|
| 2603 |
|
| 2604 |
# Forward pass with full outputs
|
| 2605 |
outputs = manager.model(
|
|
|
|
| 2855 |
# Previous-token weights for pattern detection: [n_heads]
|
| 2856 |
all_prev_token_weights = all_last_row[:, -2] if all_last_row.shape[1] >= 2 else torch.zeros(num_heads_layer, device=layer_attn.device)
|
| 2857 |
|
| 2858 |
+
# Attention sink weights: sum of attention on positions 0-2 per head [n_heads]
|
| 2859 |
+
seq_len_attn = all_last_row.shape[1]
|
| 2860 |
+
all_sink_weights = all_last_row[:, :min(3, seq_len_attn)].sum(dim=-1)
|
| 2861 |
+
|
| 2862 |
+
# Local attention weights: sum within 5 positions of query per head [n_heads]
|
| 2863 |
+
all_local_weights = all_last_row[:, max(0, seq_len_attn - 5):].sum(dim=-1) if seq_len_attn > 5 else torch.ones(num_heads_layer, device=layer_attn.device)
|
| 2864 |
+
|
| 2865 |
+
# Induction detection: attention to positions following previous occurrences of current token
|
| 2866 |
+
all_induction_weights = torch.zeros(num_heads_layer, device=layer_attn.device)
|
| 2867 |
+
if step > 0:
|
| 2868 |
+
current_token = current_ids[0, -1]
|
| 2869 |
+
prev_occurrences = (current_ids[0, :-1] == current_token).nonzero(as_tuple=True)[0]
|
| 2870 |
+
if len(prev_occurrences) > 0:
|
| 2871 |
+
following_positions = prev_occurrences + 1
|
| 2872 |
+
following_positions = following_positions[following_positions < seq_len_attn]
|
| 2873 |
+
if len(following_positions) > 0:
|
| 2874 |
+
all_induction_weights = all_last_row[:, following_positions].sum(dim=-1)
|
| 2875 |
+
|
| 2876 |
# Single bulk transfer: all head metrics to CPU
|
| 2877 |
+
head_metrics_gpu = torch.stack([
|
| 2878 |
+
all_max_weights, all_entropies, all_avg_entropies, all_prev_token_weights,
|
| 2879 |
+
all_sink_weights, all_local_weights, all_induction_weights
|
| 2880 |
+
]) # [7, n_heads]
|
| 2881 |
head_metrics_cpu = head_metrics_gpu.cpu().tolist() # one sync point
|
| 2882 |
max_weights_list = head_metrics_cpu[0]
|
| 2883 |
entropies_list = head_metrics_cpu[1]
|
| 2884 |
avg_entropies_list = head_metrics_cpu[2]
|
| 2885 |
prev_token_list = head_metrics_cpu[3]
|
| 2886 |
+
sink_weights_list = head_metrics_cpu[4]
|
| 2887 |
+
local_weights_list = head_metrics_cpu[5]
|
| 2888 |
+
induction_weights_list = head_metrics_cpu[6]
|
| 2889 |
|
| 2890 |
# Bulk transfer attention matrices to CPU: one .cpu() for entire layer
|
| 2891 |
layer_attn_cpu = layer_attn.cpu().float().numpy() # [n_heads, seq_len, seq_len]
|
|
|
|
| 2900 |
ent = entropies_list[head_idx]
|
| 2901 |
avg_ent = avg_entropies_list[head_idx]
|
| 2902 |
ptw = prev_token_list[head_idx]
|
| 2903 |
+
skw = sink_weights_list[head_idx]
|
| 2904 |
+
lcw = local_weights_list[head_idx]
|
| 2905 |
+
idw = induction_weights_list[head_idx]
|
| 2906 |
|
| 2907 |
# Sanitize
|
| 2908 |
mw = 0.0 if math.isnan(mw) or math.isinf(mw) else mw
|
| 2909 |
ent = 0.0 if math.isnan(ent) or math.isinf(ent) else ent
|
| 2910 |
avg_ent = 0.0 if math.isnan(avg_ent) or math.isinf(avg_ent) else avg_ent
|
| 2911 |
|
| 2912 |
+
# Data-driven head pattern classification (priority order)
|
| 2913 |
pattern_type = None
|
| 2914 |
confidence = 0.0
|
| 2915 |
+
# 1. Attention sink: >50% weight on positions 0-2
|
| 2916 |
+
if skw > 0.5:
|
| 2917 |
+
pattern_type = "attention_sink"
|
| 2918 |
+
confidence = skw
|
| 2919 |
+
# 2. Previous token: sharp focus on immediate predecessor
|
| 2920 |
+
elif mw > 0.9 and ptw > 0.85:
|
| 2921 |
+
pattern_type = "previous_token"
|
| 2922 |
+
confidence = ptw
|
| 2923 |
+
# 3. Local: >80% weight within 5 positions of query
|
| 2924 |
+
elif seq_len_attn > 5 and lcw > 0.8:
|
| 2925 |
+
pattern_type = "local"
|
| 2926 |
+
confidence = lcw
|
| 2927 |
+
# 4. Induction: attends to positions following previous occurrences of current token
|
| 2928 |
+
elif step > 0 and idw > 0.3:
|
| 2929 |
pattern_type = "induction"
|
| 2930 |
+
confidence = min(1.0, idw)
|
| 2931 |
+
# 5. Positional: low entropy, focused attention
|
| 2932 |
elif ent < 1.0:
|
| 2933 |
pattern_type = "positional"
|
| 2934 |
confidence = 1.0 - ent
|
| 2935 |
+
# 6. Semantic: broad attention (fallback)
|
| 2936 |
+
elif ent >= 1.0:
|
| 2937 |
pattern_type = "semantic"
|
| 2938 |
+
confidence = min(1.0, 0.5)
|
|
|
|
|
|
|
|
|
|
| 2939 |
confidence = 0.0 if math.isnan(confidence) or math.isinf(confidence) else confidence
|
| 2940 |
|
| 2941 |
attention_matrix = layer_attn_cpu[head_idx]
|
|
|
|
| 2966 |
|
| 2967 |
critical_heads.sort(key=lambda h: h["max_weight"], reverse=True)
|
| 2968 |
|
| 2969 |
+
# Layer-level pattern: majority vote of head patterns, weighted by confidence
|
| 2970 |
+
pattern_votes = {}
|
| 2971 |
+
for h in critical_heads:
|
| 2972 |
+
if h["pattern"] and h["pattern"]["type"]:
|
| 2973 |
+
pt = h["pattern"]["type"]
|
| 2974 |
+
pc = h["pattern"]["confidence"]
|
| 2975 |
+
pattern_votes[pt] = pattern_votes.get(pt, 0.0) + pc
|
| 2976 |
layer_pattern = None
|
| 2977 |
+
if pattern_votes:
|
| 2978 |
+
best_type = max(pattern_votes, key=pattern_votes.get)
|
| 2979 |
+
total_conf = sum(pattern_votes.values())
|
| 2980 |
+
layer_pattern = {
|
| 2981 |
+
"type": best_type,
|
| 2982 |
+
"confidence": round(pattern_votes[best_type] / total_conf, 3) if total_conf > 0 else 0.0
|
| 2983 |
+
}
|
| 2984 |
+
|
| 2985 |
+
layer_entry = {
|
|
|
|
|
|
|
| 2986 |
"layer_idx": layer_idx,
|
| 2987 |
"pattern": layer_pattern,
|
| 2988 |
"critical_heads": critical_heads,
|
|
|
|
| 2990 |
"activation_entropy": activation_entropy,
|
| 2991 |
"hidden_state_norm": hidden_state_norm,
|
| 2992 |
"delta_norm": delta_norm
|
| 2993 |
+
}
|
| 2994 |
+
# Phase 4: Attention and MLP output norms
|
| 2995 |
+
if layer_idx in attn_output_norms:
|
| 2996 |
+
layer_entry["attn_output_norm"] = attn_output_norms[layer_idx]
|
| 2997 |
+
if layer_idx in mlp_output_norms:
|
| 2998 |
+
layer_entry["mlp_output_norm"] = mlp_output_norms[layer_idx]
|
| 2999 |
+
|
| 3000 |
+
# Phase 5: Logit lens at sampled layers (every 8th layer)
|
| 3001 |
+
logit_lens_stride = max(1, n_layers // 5)
|
| 3002 |
+
if layer_idx % logit_lens_stride == 0 or layer_idx == n_layers - 1:
|
| 3003 |
+
try:
|
| 3004 |
+
hidden_for_lens = current_hidden[-1].unsqueeze(0) # [1, hidden_dim]
|
| 3005 |
+
# Apply final layer norm then project through lm_head
|
| 3006 |
+
if hasattr(manager.model, 'model') and hasattr(manager.model.model, 'norm'):
|
| 3007 |
+
normed = manager.model.model.norm(hidden_for_lens)
|
| 3008 |
+
lens_logits = manager.model.lm_head(normed)[0] # [vocab_size]
|
| 3009 |
+
elif hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'ln_f'):
|
| 3010 |
+
normed = manager.model.transformer.ln_f(hidden_for_lens)
|
| 3011 |
+
lens_logits = manager.model.lm_head(normed)[0]
|
| 3012 |
+
else:
|
| 3013 |
+
lens_logits = None
|
| 3014 |
+
|
| 3015 |
+
if lens_logits is not None:
|
| 3016 |
+
lens_probs = torch.softmax(lens_logits, dim=-1)
|
| 3017 |
+
top_probs, top_ids = torch.topk(lens_probs, k=5)
|
| 3018 |
+
top_probs_list = top_probs.cpu().tolist()
|
| 3019 |
+
top_ids_list = top_ids.cpu().tolist()
|
| 3020 |
+
lens_entries = []
|
| 3021 |
+
for tp, tid in zip(top_probs_list, top_ids_list):
|
| 3022 |
+
lens_entries.append({
|
| 3023 |
+
"token": manager.tokenizer.decode([tid], skip_special_tokens=False),
|
| 3024 |
+
"probability": tp
|
| 3025 |
+
})
|
| 3026 |
+
layer_entry["logit_lens_top"] = lens_entries
|
| 3027 |
+
except Exception as lens_err:
|
| 3028 |
+
logger.debug(f"Logit lens error at layer {layer_idx}: {lens_err}")
|
| 3029 |
+
|
| 3030 |
+
layer_data_this_token.append(layer_entry)
|
| 3031 |
|
| 3032 |
layer_data_by_token.append(layer_data_this_token)
|
| 3033 |
|