Spaces:
Paused
Paused
gary-boon Claude Opus 4.6 committed on
Commit ·
d8d197a
1
Parent(s): 9978aec
Change default model to Devstral and optimise attention extraction
Switch the default model fallback from codegen-350m to devstral-small,
matching the instruction-tuned model used for PhD research. Also includes
batched tensor operations for attention extraction (reduces GPU→CPU sync
points from ~4000 to ~40 per token).
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- backend/model_service.py +111 -61
backend/model_service.py
CHANGED
|
@@ -294,7 +294,7 @@ class ModelManager:
|
|
| 294 |
self.trace_buffer: List[TraceData] = []
|
| 295 |
|
| 296 |
# Read configuration from environment variables
|
| 297 |
-
self.model_id = os.environ.get("DEFAULT_MODEL", "
|
| 298 |
self.max_context = int(os.environ.get("MAX_CONTEXT", "8192"))
|
| 299 |
self.batch_size = int(os.environ.get("BATCH_SIZE", "1"))
|
| 300 |
|
|
@@ -2676,107 +2676,158 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
|
|
| 2676 |
"sampling": sampling_metadata
|
| 2677 |
})
|
| 2678 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2679 |
# === STAGE 3: EXTRACTING (per layer within each token) ===
|
|
|
|
|
|
|
| 2680 |
layer_data_this_token = []
|
|
|
|
| 2681 |
|
| 2682 |
-
for layer_idx in range(
|
| 2683 |
# Emit extraction progress (within generating stage for combined progress)
|
| 2684 |
if step == max_tokens - 1: # Only emit detailed layer progress on last token
|
| 2685 |
-
layer_progress = (layer_idx /
|
| 2686 |
-
overall_progress = 30 + (layer_idx /
|
| 2687 |
yield sse_event('extracting', stage=3, totalStages=5, progress=overall_progress,
|
| 2688 |
stageProgress=layer_progress,
|
| 2689 |
-
detail=f'Processing layer {layer_idx + 1}/{
|
| 2690 |
-
metadata={'layerIndex': layer_idx, 'totalLayers':
|
| 2691 |
'headsPerLayer': n_heads, 'stepIndex': step, 'totalSteps': max_tokens})
|
| 2692 |
-
if layer_idx % 5 == 0:
|
| 2693 |
await asyncio.sleep(0)
|
| 2694 |
|
| 2695 |
-
|
|
|
|
| 2696 |
current_hidden = outputs.hidden_states[layer_idx + 1]
|
| 2697 |
if current_hidden.dim() == 3:
|
| 2698 |
current_hidden = current_hidden[0]
|
| 2699 |
|
|
|
|
|
|
|
| 2700 |
if layer_idx > 0:
|
| 2701 |
prev_hidden = outputs.hidden_states[layer_idx]
|
| 2702 |
if prev_hidden.dim() == 3:
|
| 2703 |
prev_hidden = prev_hidden[0]
|
| 2704 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2705 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2706 |
delta_norm = None
|
| 2707 |
|
| 2708 |
-
|
| 2709 |
-
last_token_hidden = current_hidden[-1]
|
| 2710 |
-
activation_entropy = torch.std(last_token_hidden).item()
|
| 2711 |
-
hidden_state_norm = torch.norm(last_token_hidden).item()
|
| 2712 |
-
|
| 2713 |
-
# Sanitize
|
| 2714 |
activation_magnitude = 0.0 if math.isnan(activation_magnitude) or math.isinf(activation_magnitude) else activation_magnitude
|
| 2715 |
activation_entropy = 0.0 if math.isnan(activation_entropy) or math.isinf(activation_entropy) else activation_entropy
|
| 2716 |
hidden_state_norm = 0.0 if math.isnan(hidden_state_norm) or math.isinf(hidden_state_norm) else hidden_state_norm
|
| 2717 |
if delta_norm is not None:
|
| 2718 |
delta_norm = 0.0 if math.isnan(delta_norm) or math.isinf(delta_norm) else delta_norm
|
| 2719 |
|
| 2720 |
-
#
|
| 2721 |
-
|
| 2722 |
-
|
| 2723 |
-
|
| 2724 |
-
|
| 2725 |
-
|
|
|
|
|
|
|
| 2726 |
|
| 2727 |
-
|
| 2728 |
-
|
| 2729 |
-
# This produces values in [0,1] with better spread across heads
|
| 2730 |
-
head_attn = layer_attn[head_idx] # [q_len, k_len]
|
| 2731 |
-
q_len = head_attn.shape[0]
|
| 2732 |
|
| 2733 |
-
|
| 2734 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2735 |
|
| 2736 |
-
|
| 2737 |
-
|
| 2738 |
-
max_entropies = torch.log(positions + 1e-10)
|
| 2739 |
-
normalized_entropies = token_entropies / (max_entropies + 1e-10) # [0, 1] range
|
| 2740 |
|
| 2741 |
-
|
| 2742 |
-
|
| 2743 |
-
avg_entropy = normalized_entropies[start_idx:].mean().item() if start_idx < q_len else normalized_entropies.mean().item()
|
| 2744 |
|
| 2745 |
-
|
| 2746 |
-
|
| 2747 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2748 |
|
| 2749 |
pattern_type = None
|
| 2750 |
confidence = 0.0
|
| 2751 |
-
|
| 2752 |
-
if step > 0 and max_weight > 0.8:
|
| 2753 |
pattern_type = "induction"
|
| 2754 |
-
confidence =
|
| 2755 |
-
elif
|
| 2756 |
pattern_type = "positional"
|
| 2757 |
-
confidence = 1.0 -
|
| 2758 |
-
elif 1.0 <=
|
| 2759 |
pattern_type = "semantic"
|
| 2760 |
-
confidence = min(1.0,
|
| 2761 |
-
elif
|
| 2762 |
pattern_type = "previous_token"
|
| 2763 |
-
confidence =
|
| 2764 |
-
|
| 2765 |
confidence = 0.0 if math.isnan(confidence) or math.isinf(confidence) else confidence
|
| 2766 |
|
| 2767 |
-
|
| 2768 |
-
# ~7x more memory efficient: 4 bytes/float vs 28 bytes/float
|
| 2769 |
-
attention_matrix = layer_attn[head_idx].cpu().float().numpy()
|
| 2770 |
|
| 2771 |
q_matrix = None
|
| 2772 |
k_matrix = None
|
| 2773 |
v_matrix = None
|
| 2774 |
-
if
|
| 2775 |
-
q_matrix =
|
| 2776 |
-
k_matrix =
|
| 2777 |
-
v_matrix =
|
| 2778 |
|
| 2779 |
-
# Store matrices in cache for lazy loading (reduces response size)
|
| 2780 |
matrix_cache.store(request_id, step, layer_idx, head_idx, {
|
| 2781 |
"attention_weights": attention_matrix,
|
| 2782 |
"q_matrix": q_matrix,
|
|
@@ -2784,13 +2835,12 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
|
|
| 2784 |
"v_matrix": v_matrix
|
| 2785 |
})
|
| 2786 |
|
| 2787 |
-
# Return only metadata (matrices fetched on-demand via /matrix endpoint)
|
| 2788 |
critical_heads.append({
|
| 2789 |
"head_idx": head_idx,
|
| 2790 |
-
"entropy":
|
| 2791 |
-
"avg_entropy":
|
| 2792 |
-
"max_weight":
|
| 2793 |
-
"has_matrices": attention_matrix is not None,
|
| 2794 |
"pattern": {"type": pattern_type, "confidence": confidence} if pattern_type else None
|
| 2795 |
})
|
| 2796 |
|
|
|
|
| 294 |
self.trace_buffer: List[TraceData] = []
|
| 295 |
|
| 296 |
# Read configuration from environment variables
|
| 297 |
+
self.model_id = os.environ.get("DEFAULT_MODEL", "devstral-small")
|
| 298 |
self.max_context = int(os.environ.get("MAX_CONTEXT", "8192"))
|
| 299 |
self.batch_size = int(os.environ.get("BATCH_SIZE", "1"))
|
| 300 |
|
|
|
|
| 2676 |
"sampling": sampling_metadata
|
| 2677 |
})
|
| 2678 |
|
| 2679 |
+
# Emit generated token immediately so clients can show code progressively
|
| 2680 |
+
yield sse_event('generated_token', stage=2, totalStages=5,
|
| 2681 |
+
progress=10 + ((step + 1) / max_tokens) * 20,
|
| 2682 |
+
stageProgress=((step + 1) / max_tokens) * 100,
|
| 2683 |
+
detail=f'Generated token {step + 1}/{max_tokens}',
|
| 2684 |
+
metadata={
|
| 2685 |
+
'stepIndex': step,
|
| 2686 |
+
'totalSteps': max_tokens,
|
| 2687 |
+
'token': next_token_text,
|
| 2688 |
+
'tokenId': next_token_id,
|
| 2689 |
+
'generatedTokens': generated_tokens.copy(),
|
| 2690 |
+
})
|
| 2691 |
+
await asyncio.sleep(0)
|
| 2692 |
+
|
| 2693 |
# === STAGE 3: EXTRACTING (per layer within each token) ===
|
| 2694 |
+
# Optimised: batched tensor ops per layer instead of per-head Python loops
|
| 2695 |
+
# Reduces GPU→CPU sync points from ~4000 to ~40 per token
|
| 2696 |
layer_data_this_token = []
|
| 2697 |
+
n_total_layers = len(outputs.attentions)
|
| 2698 |
|
| 2699 |
+
for layer_idx in range(n_total_layers):
|
| 2700 |
# Emit extraction progress (within generating stage for combined progress)
|
| 2701 |
if step == max_tokens - 1: # Only emit detailed layer progress on last token
|
| 2702 |
+
layer_progress = (layer_idx / n_total_layers) * 100
|
| 2703 |
+
overall_progress = 30 + (layer_idx / n_total_layers) * 40 # 30-70%
|
| 2704 |
yield sse_event('extracting', stage=3, totalStages=5, progress=overall_progress,
|
| 2705 |
stageProgress=layer_progress,
|
| 2706 |
+
detail=f'Processing layer {layer_idx + 1}/{n_total_layers}',
|
| 2707 |
+
metadata={'layerIndex': layer_idx, 'totalLayers': n_total_layers,
|
| 2708 |
'headsPerLayer': n_heads, 'stepIndex': step, 'totalSteps': max_tokens})
|
| 2709 |
+
if layer_idx % 5 == 0:
|
| 2710 |
await asyncio.sleep(0)
|
| 2711 |
|
| 2712 |
+
# --- Per-layer: bulk GPU ops then single CPU transfer ---
|
| 2713 |
+
layer_attn = outputs.attentions[layer_idx][0] # [n_heads, seq_len, seq_len]
|
| 2714 |
current_hidden = outputs.hidden_states[layer_idx + 1]
|
| 2715 |
if current_hidden.dim() == 3:
|
| 2716 |
current_hidden = current_hidden[0]
|
| 2717 |
|
| 2718 |
+
# Hidden state metrics — 4 values, one .cpu() call
|
| 2719 |
+
last_token_hidden = current_hidden[-1]
|
| 2720 |
if layer_idx > 0:
|
| 2721 |
prev_hidden = outputs.hidden_states[layer_idx]
|
| 2722 |
if prev_hidden.dim() == 3:
|
| 2723 |
prev_hidden = prev_hidden[0]
|
| 2724 |
+
hidden_metrics = torch.stack([
|
| 2725 |
+
torch.norm(current_hidden - prev_hidden),
|
| 2726 |
+
torch.norm(current_hidden),
|
| 2727 |
+
torch.std(last_token_hidden),
|
| 2728 |
+
torch.norm(last_token_hidden),
|
| 2729 |
+
]).cpu().tolist()
|
| 2730 |
+
delta_norm, activation_magnitude, activation_entropy, hidden_state_norm = hidden_metrics
|
| 2731 |
else:
|
| 2732 |
+
hidden_metrics = torch.stack([
|
| 2733 |
+
torch.norm(current_hidden),
|
| 2734 |
+
torch.std(last_token_hidden),
|
| 2735 |
+
torch.norm(last_token_hidden),
|
| 2736 |
+
]).cpu().tolist()
|
| 2737 |
+
activation_magnitude, activation_entropy, hidden_state_norm = hidden_metrics
|
| 2738 |
delta_norm = None
|
| 2739 |
|
| 2740 |
+
# Sanitize hidden state metrics
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2741 |
activation_magnitude = 0.0 if math.isnan(activation_magnitude) or math.isinf(activation_magnitude) else activation_magnitude
|
| 2742 |
activation_entropy = 0.0 if math.isnan(activation_entropy) or math.isinf(activation_entropy) else activation_entropy
|
| 2743 |
hidden_state_norm = 0.0 if math.isnan(hidden_state_norm) or math.isinf(hidden_state_norm) else hidden_state_norm
|
| 2744 |
if delta_norm is not None:
|
| 2745 |
delta_norm = 0.0 if math.isnan(delta_norm) or math.isinf(delta_norm) else delta_norm
|
| 2746 |
|
| 2747 |
+
# --- Batched head processing: all heads at once on GPU ---
|
| 2748 |
+
num_heads_layer = layer_attn.shape[0]
|
| 2749 |
+
|
| 2750 |
+
# Last-row attention weights for all heads: [n_heads, seq_len]
|
| 2751 |
+
all_last_row = layer_attn[:, -1, :]
|
| 2752 |
+
|
| 2753 |
+
# Max weight per head: [n_heads] — single GPU op
|
| 2754 |
+
all_max_weights = all_last_row.max(dim=-1).values
|
| 2755 |
|
| 2756 |
+
# Entropy of last-row per head: [n_heads] — single GPU op
|
| 2757 |
+
all_entropies = -(all_last_row * torch.log(all_last_row + 1e-10)).sum(dim=-1)
|
|
|
|
|
|
|
|
|
|
| 2758 |
|
| 2759 |
+
# Normalized average entropy per head (latter half of query positions)
|
| 2760 |
+
# layer_attn: [n_heads, q_len, k_len]
|
| 2761 |
+
q_len = layer_attn.shape[1]
|
| 2762 |
+
# Raw entropy per query position per head: [n_heads, q_len]
|
| 2763 |
+
all_token_entropies = -(layer_attn * torch.log(layer_attn + 1e-10)).sum(dim=-1)
|
| 2764 |
+
# Normalize by log(position): [q_len]
|
| 2765 |
+
positions = torch.arange(1, q_len + 1, device=layer_attn.device, dtype=layer_attn.dtype)
|
| 2766 |
+
max_ents = torch.log(positions + 1e-10) # [q_len]
|
| 2767 |
+
all_normalized = all_token_entropies / (max_ents.unsqueeze(0) + 1e-10) # [n_heads, q_len]
|
| 2768 |
+
# Average over latter half: [n_heads]
|
| 2769 |
+
start_idx = q_len // 2
|
| 2770 |
+
if start_idx < q_len:
|
| 2771 |
+
all_avg_entropies = all_normalized[:, start_idx:].mean(dim=-1)
|
| 2772 |
+
else:
|
| 2773 |
+
all_avg_entropies = all_normalized.mean(dim=-1)
|
| 2774 |
+
|
| 2775 |
+
# Previous-token weights for pattern detection: [n_heads]
|
| 2776 |
+
all_prev_token_weights = all_last_row[:, -2] if all_last_row.shape[1] >= 2 else torch.zeros(num_heads_layer, device=layer_attn.device)
|
| 2777 |
+
|
| 2778 |
+
# Single bulk transfer: all head metrics to CPU
|
| 2779 |
+
head_metrics_gpu = torch.stack([all_max_weights, all_entropies, all_avg_entropies, all_prev_token_weights]) # [4, n_heads]
|
| 2780 |
+
head_metrics_cpu = head_metrics_gpu.cpu().tolist() # one sync point
|
| 2781 |
+
max_weights_list = head_metrics_cpu[0]
|
| 2782 |
+
entropies_list = head_metrics_cpu[1]
|
| 2783 |
+
avg_entropies_list = head_metrics_cpu[2]
|
| 2784 |
+
prev_token_list = head_metrics_cpu[3]
|
| 2785 |
|
| 2786 |
+
# Bulk transfer attention matrices to CPU: one .cpu() for entire layer
|
| 2787 |
+
layer_attn_cpu = layer_attn.cpu().float().numpy() # [n_heads, seq_len, seq_len]
|
|
|
|
|
|
|
| 2788 |
|
| 2789 |
+
# QKV matrices (already on CPU from hooks)
|
| 2790 |
+
qkv_layer = qkv_captures.get(layer_idx)
|
|
|
|
| 2791 |
|
| 2792 |
+
# Build per-head metadata from CPU-side data (no more GPU calls)
|
| 2793 |
+
critical_heads = []
|
| 2794 |
+
for head_idx in range(num_heads_layer):
|
| 2795 |
+
mw = max_weights_list[head_idx]
|
| 2796 |
+
ent = entropies_list[head_idx]
|
| 2797 |
+
avg_ent = avg_entropies_list[head_idx]
|
| 2798 |
+
ptw = prev_token_list[head_idx]
|
| 2799 |
+
|
| 2800 |
+
# Sanitize
|
| 2801 |
+
mw = 0.0 if math.isnan(mw) or math.isinf(mw) else mw
|
| 2802 |
+
ent = 0.0 if math.isnan(ent) or math.isinf(ent) else ent
|
| 2803 |
+
avg_ent = 0.0 if math.isnan(avg_ent) or math.isinf(avg_ent) else avg_ent
|
| 2804 |
|
| 2805 |
pattern_type = None
|
| 2806 |
confidence = 0.0
|
| 2807 |
+
if step > 0 and mw > 0.8:
|
|
|
|
| 2808 |
pattern_type = "induction"
|
| 2809 |
+
confidence = mw
|
| 2810 |
+
elif ent < 1.0:
|
| 2811 |
pattern_type = "positional"
|
| 2812 |
+
confidence = 1.0 - ent
|
| 2813 |
+
elif 1.0 <= ent < 2.5:
|
| 2814 |
pattern_type = "semantic"
|
| 2815 |
+
confidence = min(1.0, ent / 2.5)
|
| 2816 |
+
elif mw > 0.9 and ptw > 0.85:
|
| 2817 |
pattern_type = "previous_token"
|
| 2818 |
+
confidence = ptw
|
|
|
|
| 2819 |
confidence = 0.0 if math.isnan(confidence) or math.isinf(confidence) else confidence
|
| 2820 |
|
| 2821 |
+
attention_matrix = layer_attn_cpu[head_idx]
|
|
|
|
|
|
|
| 2822 |
|
| 2823 |
q_matrix = None
|
| 2824 |
k_matrix = None
|
| 2825 |
v_matrix = None
|
| 2826 |
+
if qkv_layer is not None:
|
| 2827 |
+
q_matrix = qkv_layer['q'][:, head_idx, :].float().numpy()
|
| 2828 |
+
k_matrix = qkv_layer['k'][:, head_idx, :].float().numpy()
|
| 2829 |
+
v_matrix = qkv_layer['v'][:, head_idx, :].float().numpy()
|
| 2830 |
|
|
|
|
| 2831 |
matrix_cache.store(request_id, step, layer_idx, head_idx, {
|
| 2832 |
"attention_weights": attention_matrix,
|
| 2833 |
"q_matrix": q_matrix,
|
|
|
|
| 2835 |
"v_matrix": v_matrix
|
| 2836 |
})
|
| 2837 |
|
|
|
|
| 2838 |
critical_heads.append({
|
| 2839 |
"head_idx": head_idx,
|
| 2840 |
+
"entropy": ent,
|
| 2841 |
+
"avg_entropy": avg_ent,
|
| 2842 |
+
"max_weight": mw,
|
| 2843 |
+
"has_matrices": attention_matrix is not None,
|
| 2844 |
"pattern": {"type": pattern_type, "confidence": confidence} if pattern_type else None
|
| 2845 |
})
|
| 2846 |
|