gary-boon Claude Opus 4.6 committed on
Commit
6f48db0
·
1 Parent(s): 0d76811

Add tuned lens as supplementary projection mode for logit lens

Browse files

Introduces per-layer affine probes (trained to minimise KL divergence with
the final layer) that correct for subspace mismatch in early transformer
layers. Positioned alongside the existing raw logit lens with a frontend
toggle, allowing developers to compare projections and assess whether
observed commitment patterns are robust to projection choice.

New: backend/tuned_lens.py (runtime), scripts/train_tuned_lens.py (training)
Modified: model_service.py (load at startup, tuned lens computation, tuned
commitment summary, modelInfo.tunedLensAvailable, health endpoint)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

backend/model_service.py CHANGED
@@ -526,6 +526,10 @@ class ModelManager:
526
 
527
  logger.info("✅ Model loaded successfully")
528
 
 
 
 
 
529
  except Exception as e:
530
  logger.error(f"Failed to load model: {e}")
531
  raise
@@ -1238,11 +1242,13 @@ async def root():
1238
  @app.get("/health")
1239
  async def health():
1240
  """Detailed health check - always returns 200 for Docker healthcheck"""
 
1241
  return {
1242
  "status": "healthy" if manager.model else "initializing",
1243
  "model_loaded": manager.model is not None,
1244
  "device": str(manager.device) if manager.device else "not set",
1245
  "websocket_clients": len(manager.websocket_clients),
 
1246
  "timestamp": datetime.now().isoformat()
1247
  }
1248
 
@@ -3233,6 +3239,42 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
3233
  layer_entry["layer_margin"] = layer_margin_val
3234
  layer_entry["layer_winner"] = layer_winner_token
3235
  layer_entry["layer_runnerup"] = layer_runnerup_token
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3236
  except Exception as lens_err:
3237
  logger.debug(f"Logit lens error at layer {layer_idx}: {lens_err}")
3238
 
@@ -3397,6 +3439,39 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
3397
  "flip_count": flip_count,
3398
  }
3399
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3400
  # Build response
3401
  response = {
3402
  "requestId": request_id, # For lazy-loading matrices via /matrix endpoint
@@ -3413,12 +3488,14 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
3413
  "numHeads": n_heads,
3414
  "modelDimension": d_model,
3415
  "headDim": head_dim,
3416
- "vocabSize": manager.model.config.vocab_size
 
3417
  },
3418
  "generationTime": generation_time,
3419
  "numTokensGenerated": len(generated_tokens),
3420
  "marginStats": margin_stats,
3421
  "commitmentSummary": commitment_summary,
 
3422
  }
3423
 
3424
  # Estimate response size
 
526
 
527
  logger.info("✅ Model loaded successfully")
528
 
529
+ # Load tuned lens probes (optional — falls back to raw logit lens if unavailable)
530
+ from .tuned_lens import tuned_lens_runtime
531
+ tuned_lens_runtime.load(self.model_id, self.device, self.dtype)
532
+
533
  except Exception as e:
534
  logger.error(f"Failed to load model: {e}")
535
  raise
 
1242
  @app.get("/health")
1243
  async def health():
1244
  """Detailed health check - always returns 200 for Docker healthcheck"""
1245
+ from .tuned_lens import tuned_lens_runtime
1246
  return {
1247
  "status": "healthy" if manager.model else "initializing",
1248
  "model_loaded": manager.model is not None,
1249
  "device": str(manager.device) if manager.device else "not set",
1250
  "websocket_clients": len(manager.websocket_clients),
1251
+ "tuned_lens_available": tuned_lens_runtime.available,
1252
  "timestamp": datetime.now().isoformat()
1253
  }
1254
 
 
3239
  layer_entry["layer_margin"] = layer_margin_val
3240
  layer_entry["layer_winner"] = layer_winner_token
3241
  layer_entry["layer_runnerup"] = layer_runnerup_token
3242
+
3243
+ # Tuned lens: apply per-layer affine correction
3244
+ from .tuned_lens import tuned_lens_runtime
3245
+ if tuned_lens_runtime.available:
3246
+ try:
3247
+ corrected = tuned_lens_runtime.apply(layer_idx, hidden_for_lens)
3248
+ tuned_normed = normed.__class__ # reuse same LN path
3249
+ # Re-apply final LN + lm_head on corrected hidden
3250
+ if hasattr(manager.model, 'model') and hasattr(manager.model.model, 'norm'):
3251
+ tuned_normed = manager.model.model.norm(corrected)
3252
+ tuned_logits = manager.model.lm_head(tuned_normed)[0]
3253
+ elif hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'ln_f'):
3254
+ tuned_normed = manager.model.transformer.ln_f(corrected)
3255
+ tuned_logits = manager.model.lm_head(tuned_normed)[0]
3256
+ else:
3257
+ tuned_logits = None
3258
+
3259
+ if tuned_logits is not None:
3260
+ tuned_probs = torch.softmax(tuned_logits, dim=-1)
3261
+ tuned_top_probs, tuned_top_ids = torch.topk(tuned_probs, k=5)
3262
+ tuned_entries = []
3263
+ for tp, tid in zip(tuned_top_probs.cpu().tolist(), tuned_top_ids.cpu().tolist()):
3264
+ tuned_entries.append({
3265
+ "token": manager.tokenizer.decode([tid], skip_special_tokens=False),
3266
+ "probability": tp
3267
+ })
3268
+ layer_entry["tuned_lens_top"] = tuned_entries
3269
+
3270
+ tuned_top2_logits, tuned_top2_ids = torch.topk(tuned_logits, k=min(2, len(tuned_logits)))
3271
+ tuned_top2_logits_list = tuned_top2_logits.cpu().tolist()
3272
+ tuned_top2_ids_list = tuned_top2_ids.cpu().tolist()
3273
+ layer_entry["tuned_layer_winner"] = manager.tokenizer.decode([tuned_top2_ids_list[0]], skip_special_tokens=False)
3274
+ layer_entry["tuned_layer_runnerup"] = manager.tokenizer.decode([tuned_top2_ids_list[1]], skip_special_tokens=False) if len(tuned_top2_ids_list) > 1 else ""
3275
+ layer_entry["tuned_layer_margin"] = (tuned_top2_logits_list[0] - tuned_top2_logits_list[1]) if len(tuned_top2_logits_list) > 1 else 0.0
3276
+ except Exception as tuned_err:
3277
+ logger.debug(f"Tuned lens error at layer {layer_idx}: {tuned_err}")
3278
  except Exception as lens_err:
3279
  logger.debug(f"Logit lens error at layer {layer_idx}: {lens_err}")
3280
 
 
3439
  "flip_count": flip_count,
3440
  }
3441
 
3442
+ # Tuned lens commitment summary (parallel to raw)
3443
+ tuned_commitment_summary = None
3444
+ from .tuned_lens import tuned_lens_runtime
3445
+ if tuned_lens_runtime.available:
3446
+ tuned_commitment_layers = []
3447
+ tuned_flip_count = 0
3448
+ for step_idx, step_layers in enumerate(layer_data_by_token):
3449
+ tuned_lens_layers = [l for l in step_layers if l.get("tuned_layer_margin") is not None]
3450
+ if not tuned_lens_layers:
3451
+ continue
3452
+ step_commitment = None
3453
+ for i, ll in enumerate(tuned_lens_layers):
3454
+ if ll["tuned_layer_margin"] > 0.3:
3455
+ stays_positive = all(tuned_lens_layers[j]["tuned_layer_margin"] > 0 for j in range(i, len(tuned_lens_layers)))
3456
+ if stays_positive:
3457
+ step_commitment = ll["layer_idx"]
3458
+ break
3459
+ if step_commitment is not None:
3460
+ tuned_commitment_layers.append(step_commitment)
3461
+ for i in range(1, len(tuned_lens_layers)):
3462
+ prev_w = (tuned_lens_layers[i-1].get("tuned_layer_winner") or "").strip()
3463
+ curr_w = (tuned_lens_layers[i].get("tuned_layer_winner") or "").strip()
3464
+ if prev_w and curr_w and prev_w != curr_w:
3465
+ tuned_flip_count += 1
3466
+
3467
+ tuned_avg = sum(tuned_commitment_layers) / len(tuned_commitment_layers) if tuned_commitment_layers else n_layers
3468
+ tuned_late = sum(1 for cl in tuned_commitment_layers if cl > late_threshold)
3469
+ tuned_commitment_summary = {
3470
+ "avg_commitment_layer": round(tuned_avg, 1),
3471
+ "late_commitment_count": tuned_late,
3472
+ "flip_count": tuned_flip_count,
3473
+ }
3474
+
3475
  # Build response
3476
  response = {
3477
  "requestId": request_id, # For lazy-loading matrices via /matrix endpoint
 
3488
  "numHeads": n_heads,
3489
  "modelDimension": d_model,
3490
  "headDim": head_dim,
3491
+ "vocabSize": manager.model.config.vocab_size,
3492
+ "tunedLensAvailable": tuned_lens_runtime.available,
3493
  },
3494
  "generationTime": generation_time,
3495
  "numTokensGenerated": len(generated_tokens),
3496
  "marginStats": margin_stats,
3497
  "commitmentSummary": commitment_summary,
3498
+ **({"tunedCommitmentSummary": tuned_commitment_summary} if tuned_commitment_summary else {}),
3499
  }
3500
 
3501
  # Estimate response size
backend/tuned_lens.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tuned Lens Runtime — load and apply per-layer affine probes for improved
3
+ intermediate-layer predictions.
4
+
5
+ Each probe applies a learned linear correction A_l(x) = x @ W_l^T + b_l
6
+ (initialised to identity + zero during training) that is trained to minimise
7
+ KL divergence between the corrected layer's predictions and the model's
8
+ final-layer predictions.
9
+
10
+ See scripts/train_tuned_lens.py for the training pipeline.
11
+ """
12
+
13
+ import json
14
+ import logging
15
+ import os
16
+ from pathlib import Path
17
+ from typing import Dict, List, Optional, Tuple
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ TUNED_LENS_DIR = os.environ.get("TUNED_LENS_DIR", "./tuned_lens_weights")
25
+
26
+
27
class TunedLensRuntime:
    """Load, cache, and apply per-layer affine probes at inference time.

    Probes are stored as ``(weight, bias)`` tensor pairs keyed by layer index,
    already moved to the serving model's device/dtype. Every failure path is
    non-fatal: ``available`` simply stays False and callers fall back to the
    raw logit lens.
    """

    def __init__(self):
        # layer_idx -> (weight, bias) pair for the affine map x @ W^T + b
        self._probes: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}
        # Parsed contents of metadata.json, when present beside the checkpoint
        self._metadata: Optional[dict] = None
        self._available = False

    @property
    def available(self) -> bool:
        """True once a checkpoint with at least one layer probe is loaded."""
        return self._available

    def load(self, model_id: str, device: torch.device, dtype: torch.dtype,
             weights_dir: Optional[str] = None) -> bool:
        """Load tuned lens checkpoint for *model_id*.

        Returns True if weights were loaded successfully, False otherwise.
        Failure is non-fatal — the system falls back to raw logit lens.
        """
        # Reset state up front so a failed (re)load can never leave stale
        # probes from a previously loaded model marked as available. (The
        # early returns below would otherwise skip the cleanup in `except`.)
        self._probes = {}
        self._metadata = None
        self._available = False

        base_dir = Path(weights_dir or TUNED_LENS_DIR)
        model_dir = base_dir / model_id

        if not model_dir.exists():
            logger.info(f"Tuned lens: no weights directory for {model_id} at {model_dir}")
            return False

        # Find the checkpoint — pick the first .pt file (sorted for determinism)
        pt_files = sorted(model_dir.glob("tuned_lens_*.pt"))
        if not pt_files:
            logger.info(f"Tuned lens: no .pt checkpoint found in {model_dir}")
            return False

        checkpoint_path = pt_files[0]
        metadata_path = model_dir / "metadata.json"

        try:
            # Load metadata (optional — absence is not an error)
            if metadata_path.exists():
                with open(metadata_path, "r") as f:
                    self._metadata = json.load(f)
                logger.info(f"Tuned lens: metadata loaded — {self._metadata.get('n_layers')} layers, "
                            f"d_model={self._metadata.get('d_model')}")
            else:
                self._metadata = {}

            # weights_only=True: checkpoint contains raw tensors only, so we
            # avoid arbitrary-code pickle deserialisation.
            state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

            # Parse layer_N.weight / layer_N.bias entries; ignore keys whose
            # suffix is not an integer rather than aborting the whole load.
            layer_indices = set()
            for key in state_dict:
                parts = key.split(".")
                if len(parts) == 2 and parts[0].startswith("layer_"):
                    try:
                        layer_indices.add(int(parts[0].split("_")[1]))
                    except ValueError:
                        logger.debug(f"Tuned lens: skipping unrecognised key {key}")

            for idx in sorted(layer_indices):
                w_key = f"layer_{idx}.weight"
                b_key = f"layer_{idx}.bias"
                # Require the full (weight, bias) pair; a lone half is useless
                if w_key in state_dict and b_key in state_dict:
                    weight = state_dict[w_key].to(device=device, dtype=dtype)
                    bias = state_dict[b_key].to(device=device, dtype=dtype)
                    self._probes[idx] = (weight, bias)

            if not self._probes:
                logger.warning("Tuned lens: checkpoint loaded but no layer probes found")
                return False

            self._available = True
            logger.info(f"Tuned lens: loaded {len(self._probes)} layer probes from {checkpoint_path} "
                        f"(device={device}, dtype={dtype})")
            return True

        except Exception as e:
            # Broad catch is deliberate: loading is best-effort and must never
            # take the service down.
            logger.warning(f"Tuned lens: failed to load checkpoint — {e}")
            self._probes = {}
            self._metadata = None
            self._available = False
            return False

    def apply(self, layer_idx: int, hidden_state: torch.Tensor) -> torch.Tensor:
        """Apply the affine probe for *layer_idx*: hidden @ W^T + b.

        If no probe exists for this layer, returns the hidden state unchanged
        (identity fallback). Assumes *hidden_state* matches the device/dtype
        passed to :meth:`load`.
        """
        if layer_idx not in self._probes:
            return hidden_state
        weight, bias = self._probes[layer_idx]
        return hidden_state @ weight.T + bias

    def get_info(self) -> dict:
        """Return metadata dict for health/debug endpoints."""
        return {
            "available": self._available,
            "num_probes": len(self._probes),
            "layer_indices": sorted(self._probes.keys()),
            "metadata": self._metadata or {},
        }
127
+
128
+
129
# Global singleton — shared process-wide; model_service loads it once at
# startup and the request handlers read it via `available` / `apply`.
tuned_lens_runtime = TunedLensRuntime()
scripts/__init__.py ADDED
File without changes
scripts/train_tuned_lens.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train Tuned Lens probes — per-layer affine corrections that minimise
3
+ KL divergence between an intermediate layer's corrected predictions and
4
+ the model's final-layer predictions.
5
+
6
+ Usage:
7
+ python -m scripts.train_tuned_lens \
8
+ --model-id codegen-350m \
9
+ --corpus-file calibration_data.txt \
10
+ --output-dir ./tuned_lens_weights/ \
11
+ --max-samples 2000 --epochs 5
12
+
13
+ Each probe is a simple affine map A_l(x) = x @ W_l^T + b_l
14
+ initialised to identity + zero so that the untrained probe reproduces
15
+ the raw logit lens exactly.
16
+ """
17
+
18
+ import argparse
19
+ import hashlib
20
+ import json
21
+ import logging
22
+ import os
23
+ import sys
24
+ import time
25
+ from pathlib import Path
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+ import torch.nn.functional as F
30
+ from transformers import AutoModelForCausalLM, AutoTokenizer
31
+
32
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
+ # ---------------------------------------------------------------------------
37
+ # AffineProbe
38
+ # ---------------------------------------------------------------------------
39
+
40
class AffineProbe(nn.Module):
    """Learned affine correction for one layer, initialised to the identity.

    Parameters are deliberately named ``weight`` / ``bias`` so a probe's
    state serialises as ``layer_N.weight`` / ``layer_N.bias`` in checkpoints.
    An untrained probe reproduces the raw logit lens exactly.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.weight = nn.Parameter(torch.eye(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Equivalent to x @ self.weight.T + self.bias, written as the
        # standard functional affine map.
        return F.linear(x, self.weight, self.bias)
50
+
51
+
52
+ # ---------------------------------------------------------------------------
53
+ # Architecture detection — mirrors model_service.py
54
+ # ---------------------------------------------------------------------------
55
+
56
def get_final_ln_and_lm_head(model):
    """Return (final_layer_norm, lm_head) for the loaded model."""
    # Two supported layouts, probed in order:
    #   Mistral / LLaMA / CodeGen-style -> model.model.norm
    #   GPT-style                        -> model.transformer.ln_f
    # Both expose the unembedding matrix as model.lm_head.
    inner = getattr(model, "model", None)
    if inner is not None and hasattr(inner, "norm"):
        return inner.norm, model.lm_head

    transformer = getattr(model, "transformer", None)
    if transformer is not None and hasattr(transformer, "ln_f"):
        return transformer.ln_f, model.lm_head

    raise RuntimeError(
        "Cannot detect final layer norm — model architecture not recognised. "
        "Supported: Mistral/LLaMA (.model.norm), GPT (.transformer.ln_f)"
    )
68
+
69
+
70
+ # ---------------------------------------------------------------------------
71
+ # Model hash — ties checkpoint to exact model weights
72
+ # ---------------------------------------------------------------------------
73
+
74
def compute_model_hash(model, n_tensors: int = 20) -> str:
    """SHA-256 of the first *n_tensors* parameter tensors' bytes.

    Ties a tuned-lens checkpoint to the exact model weights it was trained
    against. Tensors are reinterpreted as uint8 before the NumPy round-trip
    because NumPy has no bfloat16 dtype — `.numpy()` on bf16 parameters
    raises, yet bfloat16 is a supported `--dtype`. For NumPy-representable
    dtypes the digest is byte-identical to hashing `.numpy().tobytes()`
    directly, so existing checkpoints keep their hashes.
    """
    h = hashlib.sha256()
    for i, (_, param) in enumerate(model.named_parameters()):
        if i >= n_tensors:
            break
        # view(torch.uint8) reinterprets the raw bytes in place (requires a
        # contiguous tensor) and works for every torch dtype.
        raw = param.data.detach().cpu().contiguous().view(torch.uint8)
        h.update(raw.numpy().tobytes())
    return h.hexdigest()
82
+
83
+
84
+ # ---------------------------------------------------------------------------
85
+ # Corpus loader
86
+ # ---------------------------------------------------------------------------
87
+
88
def load_corpus(path: str, max_samples: int, max_seq_len: int, tokenizer) -> list:
    """Load and tokenize a plain-text corpus (one sample per line or paragraph)."""
    # Blank lines delimit paragraphs; a paragraph becomes one training sample.
    texts = []
    with open(path, "r", encoding="utf-8") as f:
        pending = []
        for raw_line in f:
            stripped = raw_line.rstrip("\n")
            if not stripped.strip() and pending:
                texts.append("\n".join(pending))
                pending = []
                if len(texts) >= max_samples:
                    break
            else:
                pending.append(stripped)
    # Flush the trailing paragraph if the file didn't end on a blank line.
    if pending and len(texts) < max_samples:
        texts.append("\n".join(pending))

    # Tokenize, dropping sequences too short to be informative.
    samples = []
    for text in texts[:max_samples]:
        token_ids = tokenizer.encode(text, add_special_tokens=False,
                                     truncation=True, max_length=max_seq_len)
        if len(token_ids) >= 8:  # skip very short sequences
            samples.append(torch.tensor(token_ids, dtype=torch.long))
    logger.info(f"Loaded {len(samples)} samples from {path} (max_seq_len={max_seq_len})")
    return samples
114
+
115
+
116
+ # ---------------------------------------------------------------------------
117
+ # Training
118
+ # ---------------------------------------------------------------------------
119
+
120
def train_tuned_lens(
    model,
    tokenizer,
    samples: list,
    device: torch.device,
    lr: float = 1e-3,
    l2_weight: float = 1e-4,
    epochs: int = 5,
):
    """Train one AffineProbe per layer, streaming hidden states (no disk storage).

    Args:
        model: Causal LM already on *device*, in eval mode.
        tokenizer: Unused here; kept for interface symmetry with callers.
        samples: List of 1-D LongTensors of token ids.
        device: Device the model lives on; probes are created there too.
        lr: AdamW learning rate for every probe.
        l2_weight: Weight of the regulariser pulling each probe toward identity.
        epochs: Number of passes over *samples*.

    Returns:
        Dict mapping layer index -> trained AffineProbe.
    """
    final_ln, lm_head = get_final_ln_and_lm_head(model)
    config = model.config
    d_model = getattr(config, "hidden_size", None) or getattr(config, "n_embd")
    n_layers = getattr(config, "num_hidden_layers", None) or getattr(config, "n_layer")

    # Probes stay in float32 regardless of the model dtype (torch.eye default):
    # AdamW on fp16 parameters is numerically unstable. Activations are cast
    # at the probe boundary below.
    probes = {}
    optimizers = {}
    for l in range(n_layers):
        probe = AffineProbe(d_model).to(device)
        probes[l] = probe
        optimizers[l] = torch.optim.AdamW(probe.parameters(), lr=lr, weight_decay=0.0)

    logger.info(f"Training {n_layers} probes (d_model={d_model}, {len(samples)} samples, {epochs} epochs)")

    # Loop-invariant: identity target for the L2 pull-toward-identity term.
    identity = torch.eye(d_model, device=device)

    for epoch in range(epochs):
        epoch_losses = {l: 0.0 for l in range(n_layers)}
        epoch_count = 0

        for si, sample_ids in enumerate(samples):
            input_ids = sample_ids.unsqueeze(0).to(device)

            with torch.no_grad():
                outputs = model(input_ids, output_hidden_states=True)
                hidden_states = outputs.hidden_states  # tuple of (n_layers+1) tensors

                # Reference distribution from the final layer. log_softmax in
                # float32 for a numerically stable target even on fp16 models.
                ref_hidden = hidden_states[-1]
                ref_normed = final_ln(ref_hidden)
                ref_logits = lm_head(ref_normed)
                ref_log_probs = F.log_softmax(ref_logits.float(), dim=-1)

            # Train each layer's probe independently
            for l in range(n_layers):
                probe = probes[l]
                optimizer = optimizers[l]

                # hidden_states[0] = embedding, hidden_states[l+1] = after layer l
                h = hidden_states[l + 1].detach()

                # Run the probe in float32, then cast back so the (possibly
                # half-precision) final LN + lm_head see their native dtype.
                # Without these casts, fp16/bf16 models fail with a dtype
                # mismatch in the probe matmul (fp16 hidden @ fp32 weight).
                corrected = probe(h.float()).to(h.dtype)
                corrected_normed = final_ln(corrected)
                probe_logits = lm_head(corrected_normed)
                probe_log_probs = F.log_softmax(probe_logits.float(), dim=-1)

                # KL(ref || probe) — ref is the target distribution.
                # log_target=True consumes the target in log-space directly,
                # avoiding the exp() round-trip.
                kl = F.kl_div(probe_log_probs, ref_log_probs, reduction="batchmean", log_target=True)

                # L2 regularisation toward identity: ||W - I||^2 + ||b||^2
                l2_reg = ((probe.weight - identity) ** 2).sum() + (probe.bias ** 2).sum()

                loss = kl + l2_weight * l2_reg

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                epoch_losses[l] += loss.item()

            epoch_count += 1

            # Free per-sample activations promptly to cap peak memory
            del outputs, hidden_states, ref_hidden, ref_normed, ref_logits, ref_log_probs

            if (si + 1) % 100 == 0:
                avg_loss = sum(epoch_losses[l] for l in range(n_layers)) / (n_layers * epoch_count)
                logger.info(f"  Epoch {epoch+1}, sample {si+1}/{len(samples)}, avg loss: {avg_loss:.4f}")

        avg_epoch_loss = sum(epoch_losses[l] for l in range(n_layers)) / (n_layers * max(epoch_count, 1))
        logger.info(f"Epoch {epoch+1}/{epochs} complete — avg loss: {avg_epoch_loss:.4f}")

    return probes
203
+
204
+
205
+ # ---------------------------------------------------------------------------
206
+ # Checkpoint saving
207
+ # ---------------------------------------------------------------------------
208
+
209
def save_checkpoint(probes: dict, model, model_id: str, output_dir: str,
                    training_config: dict):
    """Save probe state dicts and metadata."""
    # The hash ties this checkpoint to the exact weights it was trained on.
    model_hash = compute_model_hash(model)
    cfg = model.config
    d_model = getattr(cfg, "hidden_size", None) or getattr(cfg, "n_embd")
    n_layers = getattr(cfg, "num_hidden_layers", None) or getattr(cfg, "n_layer")

    save_dir = Path(output_dir) / model_id
    save_dir.mkdir(parents=True, exist_ok=True)

    # One flat state dict: layer_N.weight / layer_N.bias for every probe,
    # detached to CPU so the file is device-agnostic.
    state_dict = {
        f"layer_{layer_idx}.{field}": getattr(probe, field).data.cpu()
        for layer_idx, probe in probes.items()
        for field in ("weight", "bias")
    }

    checkpoint_path = save_dir / f"tuned_lens_{model_hash[:16]}.pt"
    torch.save(state_dict, checkpoint_path)

    metadata = {
        "model_id": model_id,
        "model_hash": model_hash,
        "n_layers": n_layers,
        "d_model": d_model,
        "training_config": training_config,
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    }
    metadata_path = save_dir / "metadata.json"
    with open(metadata_path, "w") as f:
        json.dump(metadata, f, indent=2)

    logger.info(f"Saved checkpoint to {checkpoint_path} ({checkpoint_path.stat().st_size / 1024 / 1024:.1f}MB)")
    logger.info(f"Saved metadata to {metadata_path}")
    return checkpoint_path
244
+
245
+
246
+ # ---------------------------------------------------------------------------
247
+ # CLI
248
+ # ---------------------------------------------------------------------------
249
+
250
def _pick_device(requested) -> torch.device:
    """Resolve compute device: explicit flag, else CUDA, else MPS, else CPU."""
    if requested:
        return torch.device(requested)
    if torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def main():
    """CLI entry point: parse args, load model + corpus, train probes, save."""
    parser = argparse.ArgumentParser(description="Train tuned lens probes for a model")
    parser.add_argument("--model-id", required=True, help="Model identifier (e.g. codegen-350m)")
    parser.add_argument("--model-name", default=None,
                        help="HuggingFace model name (defaults to model-id)")
    parser.add_argument("--corpus-file", required=True, help="Plain-text calibration corpus")
    parser.add_argument("--output-dir", default="./tuned_lens_weights/",
                        help="Output directory for checkpoints")
    parser.add_argument("--max-samples", type=int, default=2000)
    parser.add_argument("--max-seq-len", type=int, default=512)
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--l2-weight", type=float, default=1e-4)
    parser.add_argument("--device", default=None, help="Device (auto-detected if omitted)")
    parser.add_argument("--dtype", default="float16", choices=["float16", "bfloat16", "float32"])
    args = parser.parse_args()

    model_name = args.model_name or args.model_id
    dtype = {"float16": torch.float16,
             "bfloat16": torch.bfloat16,
             "float32": torch.float32}[args.dtype]
    device = _pick_device(args.device)

    logger.info(f"Device: {device}, dtype: {dtype}")
    logger.info(f"Loading model: {model_name}")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device)
    model.eval()

    samples = load_corpus(args.corpus_file, args.max_samples, args.max_seq_len, tokenizer)
    if not samples:
        logger.error("No valid samples loaded — aborting")
        sys.exit(1)

    # Recorded verbatim in metadata.json for reproducibility.
    training_config = {
        "lr": args.lr,
        "l2_weight": args.l2_weight,
        "epochs": args.epochs,
        "max_samples": args.max_samples,
        "max_seq_len": args.max_seq_len,
        "dtype": args.dtype,
        "num_samples_used": len(samples),
    }

    probes = train_tuned_lens(
        model, tokenizer, samples, device,
        lr=args.lr, l2_weight=args.l2_weight, epochs=args.epochs,
    )

    save_checkpoint(probes, model, args.model_id, args.output_dir, training_config)
    logger.info("Done.")
309
+
310
+
311
# Guarded entry point so the module can be imported without side effects.
if __name__ == "__main__":
    main()