gary-boon Claude Opus 4.6 committed on
Commit
d8d197a
·
1 Parent(s): 9978aec

Change default model to Devstral and optimise attention extraction

Browse files

Switch the default model fallback from codegen-350m to devstral-small,
matching the instruction-tuned model used for PhD research. Also includes
batched tensor operations for attention extraction (reduces GPU→CPU sync
points from ~4000 to ~40 per token).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. backend/model_service.py +111 -61
backend/model_service.py CHANGED
@@ -294,7 +294,7 @@ class ModelManager:
294
  self.trace_buffer: List[TraceData] = []
295
 
296
  # Read configuration from environment variables
297
- self.model_id = os.environ.get("DEFAULT_MODEL", "codegen-350m")
298
  self.max_context = int(os.environ.get("MAX_CONTEXT", "8192"))
299
  self.batch_size = int(os.environ.get("BATCH_SIZE", "1"))
300
 
@@ -2676,107 +2676,158 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
2676
  "sampling": sampling_metadata
2677
  })
2678
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2679
  # === STAGE 3: EXTRACTING (per layer within each token) ===
 
 
2680
  layer_data_this_token = []
 
2681
 
2682
- for layer_idx in range(len(outputs.attentions)):
2683
  # Emit extraction progress (within generating stage for combined progress)
2684
  if step == max_tokens - 1: # Only emit detailed layer progress on last token
2685
- layer_progress = (layer_idx / len(outputs.attentions)) * 100
2686
- overall_progress = 30 + (layer_idx / len(outputs.attentions)) * 40 # 30-70%
2687
  yield sse_event('extracting', stage=3, totalStages=5, progress=overall_progress,
2688
  stageProgress=layer_progress,
2689
- detail=f'Processing layer {layer_idx + 1}/{len(outputs.attentions)}',
2690
- metadata={'layerIndex': layer_idx, 'totalLayers': len(outputs.attentions),
2691
  'headsPerLayer': n_heads, 'stepIndex': step, 'totalSteps': max_tokens})
2692
- if layer_idx % 5 == 0: # Yield every 5 layers to avoid too many events
2693
  await asyncio.sleep(0)
2694
 
2695
- layer_attn = outputs.attentions[layer_idx][0]
 
2696
  current_hidden = outputs.hidden_states[layer_idx + 1]
2697
  if current_hidden.dim() == 3:
2698
  current_hidden = current_hidden[0]
2699
 
 
 
2700
  if layer_idx > 0:
2701
  prev_hidden = outputs.hidden_states[layer_idx]
2702
  if prev_hidden.dim() == 3:
2703
  prev_hidden = prev_hidden[0]
2704
- delta_norm = torch.norm(current_hidden - prev_hidden).item()
 
 
 
 
 
 
2705
  else:
 
 
 
 
 
 
2706
  delta_norm = None
2707
 
2708
- activation_magnitude = torch.norm(current_hidden).item()
2709
- last_token_hidden = current_hidden[-1]
2710
- activation_entropy = torch.std(last_token_hidden).item()
2711
- hidden_state_norm = torch.norm(last_token_hidden).item()
2712
-
2713
- # Sanitize
2714
  activation_magnitude = 0.0 if math.isnan(activation_magnitude) or math.isinf(activation_magnitude) else activation_magnitude
2715
  activation_entropy = 0.0 if math.isnan(activation_entropy) or math.isinf(activation_entropy) else activation_entropy
2716
  hidden_state_norm = 0.0 if math.isnan(hidden_state_norm) or math.isinf(hidden_state_norm) else hidden_state_norm
2717
  if delta_norm is not None:
2718
  delta_norm = 0.0 if math.isnan(delta_norm) or math.isinf(delta_norm) else delta_norm
2719
 
2720
- # Process heads
2721
- critical_heads = []
2722
- for head_idx in range(layer_attn.shape[0]):
2723
- head_weights = layer_attn[head_idx, -1, :]
2724
- max_weight = head_weights.max().item()
2725
- entropy = -(head_weights * torch.log(head_weights + 1e-10)).sum().item()
 
 
2726
 
2727
- # Normalized attention entropy averaged over latter half of query positions
2728
- # Normalized by log(k_i) where k_i = number of keys position i can attend to
2729
- # This produces values in [0,1] with better spread across heads
2730
- head_attn = layer_attn[head_idx] # [q_len, k_len]
2731
- q_len = head_attn.shape[0]
2732
 
2733
- # Compute raw entropy per query position
2734
- token_entropies = -(head_attn * torch.log(head_attn + 1e-10)).sum(dim=-1) # [q_len]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2735
 
2736
- # Normalize by max possible entropy: log(k_i) where k_i = i + 1 (causal mask)
2737
- positions = torch.arange(1, q_len + 1, device=head_attn.device, dtype=head_attn.dtype)
2738
- max_entropies = torch.log(positions + 1e-10)
2739
- normalized_entropies = token_entropies / (max_entropies + 1e-10) # [0, 1] range
2740
 
2741
- # Average over latter half of positions
2742
- start_idx = q_len // 2
2743
- avg_entropy = normalized_entropies[start_idx:].mean().item() if start_idx < q_len else normalized_entropies.mean().item()
2744
 
2745
- max_weight = 0.0 if math.isnan(max_weight) or math.isinf(max_weight) else max_weight
2746
- entropy = 0.0 if math.isnan(entropy) or math.isinf(entropy) else entropy
2747
- avg_entropy = 0.0 if math.isnan(avg_entropy) or math.isinf(avg_entropy) else avg_entropy
 
 
 
 
 
 
 
 
 
2748
 
2749
  pattern_type = None
2750
  confidence = 0.0
2751
-
2752
- if step > 0 and max_weight > 0.8:
2753
  pattern_type = "induction"
2754
- confidence = max_weight
2755
- elif entropy < 1.0:
2756
  pattern_type = "positional"
2757
- confidence = 1.0 - entropy
2758
- elif 1.0 <= entropy < 2.5:
2759
  pattern_type = "semantic"
2760
- confidence = min(1.0, entropy / 2.5)
2761
- elif max_weight > 0.9 and head_weights[-2].item() > 0.85:
2762
  pattern_type = "previous_token"
2763
- confidence = head_weights[-2].item()
2764
-
2765
  confidence = 0.0 if math.isnan(confidence) or math.isinf(confidence) else confidence
2766
 
2767
- # Store as numpy arrays (not Python lists) to save memory
2768
- # ~7x more memory efficient: 4 bytes/float vs 28 bytes/float
2769
- attention_matrix = layer_attn[head_idx].cpu().float().numpy()
2770
 
2771
  q_matrix = None
2772
  k_matrix = None
2773
  v_matrix = None
2774
- if layer_idx in qkv_captures:
2775
- q_matrix = qkv_captures[layer_idx]['q'][:, head_idx, :].float().numpy()
2776
- k_matrix = qkv_captures[layer_idx]['k'][:, head_idx, :].float().numpy()
2777
- v_matrix = qkv_captures[layer_idx]['v'][:, head_idx, :].float().numpy()
2778
 
2779
- # Store matrices in cache for lazy loading (reduces response size)
2780
  matrix_cache.store(request_id, step, layer_idx, head_idx, {
2781
  "attention_weights": attention_matrix,
2782
  "q_matrix": q_matrix,
@@ -2784,13 +2835,12 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
2784
  "v_matrix": v_matrix
2785
  })
2786
 
2787
- # Return only metadata (matrices fetched on-demand via /matrix endpoint)
2788
  critical_heads.append({
2789
  "head_idx": head_idx,
2790
- "entropy": entropy,
2791
- "avg_entropy": avg_entropy, # Averaged over all query positions
2792
- "max_weight": max_weight,
2793
- "has_matrices": attention_matrix is not None, # Flag for frontend
2794
  "pattern": {"type": pattern_type, "confidence": confidence} if pattern_type else None
2795
  })
2796
 
 
294
  self.trace_buffer: List[TraceData] = []
295
 
296
  # Read configuration from environment variables
297
+ self.model_id = os.environ.get("DEFAULT_MODEL", "devstral-small")
298
  self.max_context = int(os.environ.get("MAX_CONTEXT", "8192"))
299
  self.batch_size = int(os.environ.get("BATCH_SIZE", "1"))
300
 
 
2676
  "sampling": sampling_metadata
2677
  })
2678
 
2679
+ # Emit generated token immediately so clients can show code progressively
2680
+ yield sse_event('generated_token', stage=2, totalStages=5,
2681
+ progress=10 + ((step + 1) / max_tokens) * 20,
2682
+ stageProgress=((step + 1) / max_tokens) * 100,
2683
+ detail=f'Generated token {step + 1}/{max_tokens}',
2684
+ metadata={
2685
+ 'stepIndex': step,
2686
+ 'totalSteps': max_tokens,
2687
+ 'token': next_token_text,
2688
+ 'tokenId': next_token_id,
2689
+ 'generatedTokens': generated_tokens.copy(),
2690
+ })
2691
+ await asyncio.sleep(0)
2692
+
2693
  # === STAGE 3: EXTRACTING (per layer within each token) ===
2694
+ # Optimised: batched tensor ops per layer instead of per-head Python loops
2695
+ # Reduces GPU→CPU sync points from ~4000 to ~40 per token
2696
  layer_data_this_token = []
2697
+ n_total_layers = len(outputs.attentions)
2698
 
2699
+ for layer_idx in range(n_total_layers):
2700
  # Emit extraction progress (within generating stage for combined progress)
2701
  if step == max_tokens - 1: # Only emit detailed layer progress on last token
2702
+ layer_progress = (layer_idx / n_total_layers) * 100
2703
+ overall_progress = 30 + (layer_idx / n_total_layers) * 40 # 30-70%
2704
  yield sse_event('extracting', stage=3, totalStages=5, progress=overall_progress,
2705
  stageProgress=layer_progress,
2706
+ detail=f'Processing layer {layer_idx + 1}/{n_total_layers}',
2707
+ metadata={'layerIndex': layer_idx, 'totalLayers': n_total_layers,
2708
  'headsPerLayer': n_heads, 'stepIndex': step, 'totalSteps': max_tokens})
2709
+ if layer_idx % 5 == 0:
2710
  await asyncio.sleep(0)
2711
 
2712
+ # --- Per-layer: bulk GPU ops then single CPU transfer ---
2713
+ layer_attn = outputs.attentions[layer_idx][0] # [n_heads, seq_len, seq_len]
2714
  current_hidden = outputs.hidden_states[layer_idx + 1]
2715
  if current_hidden.dim() == 3:
2716
  current_hidden = current_hidden[0]
2717
 
2718
+ # Hidden state metrics — 4 values, one .cpu() call
2719
+ last_token_hidden = current_hidden[-1]
2720
  if layer_idx > 0:
2721
  prev_hidden = outputs.hidden_states[layer_idx]
2722
  if prev_hidden.dim() == 3:
2723
  prev_hidden = prev_hidden[0]
2724
+ hidden_metrics = torch.stack([
2725
+ torch.norm(current_hidden - prev_hidden),
2726
+ torch.norm(current_hidden),
2727
+ torch.std(last_token_hidden),
2728
+ torch.norm(last_token_hidden),
2729
+ ]).cpu().tolist()
2730
+ delta_norm, activation_magnitude, activation_entropy, hidden_state_norm = hidden_metrics
2731
  else:
2732
+ hidden_metrics = torch.stack([
2733
+ torch.norm(current_hidden),
2734
+ torch.std(last_token_hidden),
2735
+ torch.norm(last_token_hidden),
2736
+ ]).cpu().tolist()
2737
+ activation_magnitude, activation_entropy, hidden_state_norm = hidden_metrics
2738
  delta_norm = None
2739
 
2740
+ # Sanitize hidden state metrics
 
 
 
 
 
2741
  activation_magnitude = 0.0 if math.isnan(activation_magnitude) or math.isinf(activation_magnitude) else activation_magnitude
2742
  activation_entropy = 0.0 if math.isnan(activation_entropy) or math.isinf(activation_entropy) else activation_entropy
2743
  hidden_state_norm = 0.0 if math.isnan(hidden_state_norm) or math.isinf(hidden_state_norm) else hidden_state_norm
2744
  if delta_norm is not None:
2745
  delta_norm = 0.0 if math.isnan(delta_norm) or math.isinf(delta_norm) else delta_norm
2746
 
2747
+ # --- Batched head processing: all heads at once on GPU ---
2748
+ num_heads_layer = layer_attn.shape[0]
2749
+
2750
+ # Last-row attention weights for all heads: [n_heads, seq_len]
2751
+ all_last_row = layer_attn[:, -1, :]
2752
+
2753
+ # Max weight per head: [n_heads] — single GPU op
2754
+ all_max_weights = all_last_row.max(dim=-1).values
2755
 
2756
+ # Entropy of last-row per head: [n_heads] single GPU op
2757
+ all_entropies = -(all_last_row * torch.log(all_last_row + 1e-10)).sum(dim=-1)
 
 
 
2758
 
2759
+ # Normalized average entropy per head (latter half of query positions)
2760
+ # layer_attn: [n_heads, q_len, k_len]
2761
+ q_len = layer_attn.shape[1]
2762
+ # Raw entropy per query position per head: [n_heads, q_len]
2763
+ all_token_entropies = -(layer_attn * torch.log(layer_attn + 1e-10)).sum(dim=-1)
2764
+ # Normalize by log(position): [q_len]
2765
+ positions = torch.arange(1, q_len + 1, device=layer_attn.device, dtype=layer_attn.dtype)
2766
+ max_ents = torch.log(positions + 1e-10) # [q_len]
2767
+ all_normalized = all_token_entropies / (max_ents.unsqueeze(0) + 1e-10) # [n_heads, q_len]
2768
+ # Average over latter half: [n_heads]
2769
+ start_idx = q_len // 2
2770
+ if start_idx < q_len:
2771
+ all_avg_entropies = all_normalized[:, start_idx:].mean(dim=-1)
2772
+ else:
2773
+ all_avg_entropies = all_normalized.mean(dim=-1)
2774
+
2775
+ # Previous-token weights for pattern detection: [n_heads]
2776
+ all_prev_token_weights = all_last_row[:, -2] if all_last_row.shape[1] >= 2 else torch.zeros(num_heads_layer, device=layer_attn.device)
2777
+
2778
+ # Single bulk transfer: all head metrics to CPU
2779
+ head_metrics_gpu = torch.stack([all_max_weights, all_entropies, all_avg_entropies, all_prev_token_weights]) # [4, n_heads]
2780
+ head_metrics_cpu = head_metrics_gpu.cpu().tolist() # one sync point
2781
+ max_weights_list = head_metrics_cpu[0]
2782
+ entropies_list = head_metrics_cpu[1]
2783
+ avg_entropies_list = head_metrics_cpu[2]
2784
+ prev_token_list = head_metrics_cpu[3]
2785
 
2786
+ # Bulk transfer attention matrices to CPU: one .cpu() for entire layer
2787
+ layer_attn_cpu = layer_attn.cpu().float().numpy() # [n_heads, seq_len, seq_len]
 
 
2788
 
2789
+ # QKV matrices (already on CPU from hooks)
2790
+ qkv_layer = qkv_captures.get(layer_idx)
 
2791
 
2792
+ # Build per-head metadata from CPU-side data (no more GPU calls)
2793
+ critical_heads = []
2794
+ for head_idx in range(num_heads_layer):
2795
+ mw = max_weights_list[head_idx]
2796
+ ent = entropies_list[head_idx]
2797
+ avg_ent = avg_entropies_list[head_idx]
2798
+ ptw = prev_token_list[head_idx]
2799
+
2800
+ # Sanitize
2801
+ mw = 0.0 if math.isnan(mw) or math.isinf(mw) else mw
2802
+ ent = 0.0 if math.isnan(ent) or math.isinf(ent) else ent
2803
+ avg_ent = 0.0 if math.isnan(avg_ent) or math.isinf(avg_ent) else avg_ent
2804
 
2805
  pattern_type = None
2806
  confidence = 0.0
2807
+ if step > 0 and mw > 0.8:
 
2808
  pattern_type = "induction"
2809
+ confidence = mw
2810
+ elif ent < 1.0:
2811
  pattern_type = "positional"
2812
+ confidence = 1.0 - ent
2813
+ elif 1.0 <= ent < 2.5:
2814
  pattern_type = "semantic"
2815
+ confidence = min(1.0, ent / 2.5)
2816
+ elif mw > 0.9 and ptw > 0.85:
2817
  pattern_type = "previous_token"
2818
+ confidence = ptw
 
2819
  confidence = 0.0 if math.isnan(confidence) or math.isinf(confidence) else confidence
2820
 
2821
+ attention_matrix = layer_attn_cpu[head_idx]
 
 
2822
 
2823
  q_matrix = None
2824
  k_matrix = None
2825
  v_matrix = None
2826
+ if qkv_layer is not None:
2827
+ q_matrix = qkv_layer['q'][:, head_idx, :].float().numpy()
2828
+ k_matrix = qkv_layer['k'][:, head_idx, :].float().numpy()
2829
+ v_matrix = qkv_layer['v'][:, head_idx, :].float().numpy()
2830
 
 
2831
  matrix_cache.store(request_id, step, layer_idx, head_idx, {
2832
  "attention_weights": attention_matrix,
2833
  "q_matrix": q_matrix,
 
2835
  "v_matrix": v_matrix
2836
  })
2837
 
 
2838
  critical_heads.append({
2839
  "head_idx": head_idx,
2840
+ "entropy": ent,
2841
+ "avg_entropy": avg_ent,
2842
+ "max_weight": mw,
2843
+ "has_matrices": attention_matrix is not None,
2844
  "pattern": {"type": pattern_type, "confidence": confidence} if pattern_type else None
2845
  })
2846