Delta-Vector committed
Commit a1e45db · verified · 1 parent: 93462da

Upload distill_sharded.py with huggingface_hub

Files changed (1)
  1. distill_sharded.py +807 -0
distill_sharded.py ADDED
@@ -0,0 +1,807 @@
#!/usr/bin/env python3
"""
Single-process KL distillation with a sharded frozen teacher and one trainable
student GPU.

This is a derivative of distill.py tailored for large-teacher / smaller-student
setups where replicating the teacher per process is wasteful or infeasible.
"""

from __future__ import annotations

import argparse
import gc
import json
import logging
import random
import re
import shutil
import time
import tomllib
from collections import OrderedDict
from pathlib import Path

import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint_utils
from torch.optim import AdamW


logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger("distill_sharded")


REQUIRED_SECTIONS = ("model", "data", "train", "eval", "log", "init")
REQUIRED_KEYS = {
    "model": ("teacher", "student", "tokenizer", "student_device", "teacher_devices", "teacher_max_memory_gb"),
    "data": ("min_chars", "max_seq_len", "kl_start_pos", "seed", "shuffle_buffer"),
    "train": (
        "seed",
        "lr",
        "schedule",
        "warmup_steps",
        "weight_decay",
        "grad_clip",
        "betas",
        "eps",
        "samples_per_step",
        "max_steps",
        "grad_checkpointing",
        "attn_implementation",
        "student_dtype",
        "teacher_dtype",
        "kl_chunk_size",
        "micro_batch_size",
        "new_layer_lr_mul",
    ),
    "eval": ("every_steps", "samples", "seed", "cache_path"),
    "log": ("wandb", "wandb_project", "wandb_run", "log_every", "output_dir", "experiment_log"),
    "init": ("zero_layers", "target_num_layers"),
}

DTYPE_MAP = {
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


def parse_dtype(s: str) -> torch.dtype:
    if s not in DTYPE_MAP:
        raise ValueError(f"unknown dtype {s!r}; must be one of {list(DTYPE_MAP)}")
    return DTYPE_MAP[s]


def load_config(path: str) -> dict:
    with open(path, "rb") as f:
        cfg = tomllib.load(f)
    for sec in REQUIRED_SECTIONS:
        if sec not in cfg:
            raise KeyError(f"config missing required section [{sec}]")
        for key in REQUIRED_KEYS[sec]:
            if key not in cfg[sec]:
                raise KeyError(f"config missing required key [{sec}].{key}")
    return cfg


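# Abridged config sketch for load_config (illustrative; every value below is a
# placeholder, and the [train]/[eval]/[log]/[init] keys listed in REQUIRED_KEYS
# are omitted here for brevity):
#
#   [model]
#   teacher = "org/big-teacher"
#   student = "org/small-student"
#   tokenizer = "org/small-student"
#   student_device = "cuda:0"
#   teacher_devices = [1, 2, 3]
#   teacher_max_memory_gb = 70
#
#   [data]
#   min_chars = 2000
#   max_seq_len = 4096
#   kl_start_pos = 0
#   seed = 0
#   shuffle_buffer = 10000
#   dataset = "org/streaming-corpus"
#   text_field = "text"

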
def get_inner_with_layers(model):
    seen = set()
    stack = [model]
    while stack:
        m = stack.pop()
        if id(m) in seen:
            continue
        seen.add(id(m))
        if hasattr(m, "layers"):
            return m
        for attr in ("model", "language_model", "transformer", "base_model"):
            child = getattr(m, attr, None)
            if child is not None:
                stack.append(child)
    raise RuntimeError(f"Could not locate `.layers` inside {type(model).__name__}")


def zero_layers(model, layer_indices):
    inner = get_inner_with_layers(model)
    layers = inner.layers
    n = len(layers)
    for idx in layer_indices:
        if idx < 0 or idx >= n:
            raise IndexError(f"layer {idx} out of range (0..{n - 1})")
        with torch.no_grad():
            for p in layers[idx].parameters():
                p.zero_()
    return n


def _zero_output_projections(layer):
    zeroed = []
    with torch.no_grad():
        if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "o_proj"):
            layer.self_attn.o_proj.weight.zero_()
            zeroed.append("self_attn.o_proj")
        if hasattr(layer, "linear_attn") and hasattr(layer.linear_attn, "out_proj"):
            layer.linear_attn.out_proj.weight.zero_()
            zeroed.append("linear_attn.out_proj")
        if hasattr(layer, "mlp") and hasattr(layer.mlp, "down_proj"):
            layer.mlp.down_proj.weight.zero_()
            zeroed.append("mlp.down_proj")
    return zeroed


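# Note on the zero-init above: zeroing only each branch's output projection
# (rather than the whole layer) makes a freshly appended block contribute
# exactly nothing to the residual stream on the first forward pass, so it
# starts as an identity mapping while its interior weights keep a useful
# random init to learn from.

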
def grow_layers(model, target_n):
    inner = get_inner_with_layers(model)
    cur_n = len(inner.layers)
    if target_n == cur_n:
        return cur_n, []
    if target_n < cur_n:
        raise ValueError(f"target_num_layers={target_n} < current {cur_n}; cannot shrink")

    cfg = model.config
    text_cfg = getattr(cfg, "text_config", cfg)
    if not hasattr(text_cfg, "layer_types") or not text_cfg.layer_types:
        raise RuntimeError("text config has no layer_types; cannot extend pattern")

    period = getattr(text_cfg, "full_attention_interval", 4)
    new_types = list(text_cfg.layer_types)
    while len(new_types) < target_n:
        new_types.append(new_types[len(new_types) % period])
    text_cfg.layer_types = new_types
    text_cfg.num_hidden_layers = target_n
    if hasattr(cfg, "num_hidden_layers") and cfg is not text_cfg:
        cfg.num_hidden_layers = target_n

    layer_cls = type(inner.layers[0])
    device = next(inner.parameters()).device
    dtype = next(inner.parameters()).dtype

    new_layer_zeroed = []
    for i in range(cur_n, target_n):
        new_layer = layer_cls(text_cfg, layer_idx=i)
        new_layer.apply(model._init_weights)
        new_layer.to(device=device, dtype=dtype)
        zeroed = _zero_output_projections(new_layer)
        new_layer_zeroed.append((i, zeroed))
        inner.layers.append(new_layer)

    return target_n, new_layer_zeroed


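# Worked example of the pattern extension above (hypothetical values): with
# layer_types = ["linear", "linear", "linear", "full"] and
# full_attention_interval = 4, growing to target_n = 6 appends
# new_types[4 % 4] = "linear" and then new_types[5 % 4] = "linear", so the
# periodic hybrid-attention pattern continues instead of restarting.

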
def detect_model_kind(model_id: str) -> str:
    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained(model_id)
    archs = list(getattr(cfg, "architectures", []) or [])
    arch = archs[0] if archs else ""
    if "ConditionalGeneration" in arch or "ImageText" in arch:
        return "image_text"
    return "causal_lm"


def load_student(model_id: str, dtype: torch.dtype, grad_ckpt: bool, attn_impl: str):
    kind = detect_model_kind(model_id)
    if kind == "image_text":
        from transformers import AutoModelForImageTextToText

        model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            dtype=dtype,
            low_cpu_mem_usage=True,
            attn_implementation=attn_impl,
        )
    else:
        from transformers import AutoModelForCausalLM

        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            dtype=dtype,
            low_cpu_mem_usage=True,
            attn_implementation=attn_impl,
        )
    model.config.use_cache = False
    if grad_ckpt:
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )
    return model


def load_teacher(model_id: str, dtype: torch.dtype, attn_impl: str, devices: list[int], max_mem_gb: int):
    kind = detect_model_kind(model_id)
    max_memory = {idx: f"{max_mem_gb}GiB" for idx in devices}
    max_memory["cpu"] = "256GiB"
    common = dict(
        dtype=dtype,
        low_cpu_mem_usage=True,
        attn_implementation=attn_impl,
        device_map="auto",
        max_memory=max_memory,
    )

    if kind == "image_text":
        from transformers import AutoModelForImageTextToText

        model = AutoModelForImageTextToText.from_pretrained(model_id, **common)
    else:
        from transformers import AutoModelForCausalLM

        model = AutoModelForCausalLM.from_pretrained(model_id, **common)
    model.config.use_cache = False
    model.eval()
    for p in model.parameters():
        p.requires_grad_(False)
    return model


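# For example (illustrative numbers), devices=[1, 2, 3] with max_mem_gb=70 gives
#   max_memory = {1: "70GiB", 2: "70GiB", 3: "70GiB", "cpu": "256GiB"}
# which device_map="auto" uses to shard the frozen teacher across those GPUs,
# spilling to CPU only if the per-GPU budgets are exhausted.

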
def get_teacher_devices(model) -> tuple[torch.device, torch.device]:
    device_map = getattr(model, "hf_device_map", None) or {}
    ordered = OrderedDict()
    for _, dev in device_map.items():
        if isinstance(dev, int):
            ordered.setdefault(f"cuda:{dev}", None)
        elif isinstance(dev, str) and dev.startswith("cuda:"):
            ordered.setdefault(dev, None)
    if not ordered:
        first = next(model.parameters()).device
        return first, first
    keys = list(ordered.keys())
    return torch.device(keys[0]), torch.device(keys[-1])


def teacher_forward(teacher, input_ids, attention_mask, out_device):
    out = teacher(input_ids=input_ids, attention_mask=attention_mask)
    logits = getattr(out, "logits", None)
    if logits is None:
        raise RuntimeError("teacher forward did not return .logits")
    if logits.device != out_device:
        logits = logits.to(out_device, non_blocking=True)
    return logits


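# get_teacher_devices returns the first and last CUDA shards of the teacher's
# hf_device_map in insertion order: inputs are fed to the first shard, while
# logits materialize on the last shard before teacher_forward copies them to
# the requested out_device (the student GPU in this script).

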
class StreamingTextLoader:
    def __init__(
        self,
        name,
        text_field,
        min_chars,
        max_seq_len,
        kl_start_pos,
        tokenizer,
        seed,
        shuffle_buffer,
    ):
        from datasets import load_dataset

        last_err = None
        for attempt in range(8):
            try:
                ds = load_dataset(name, split="train", streaming=True)
                break
            except Exception as e:
                last_err = e
                wait = min(2 ** attempt, 30)
                log.warning(
                    f"load_dataset({name!r}) failed (attempt {attempt + 1}/8): "
                    f"{type(e).__name__}: {e}; sleeping {wait}s"
                )
                time.sleep(wait)
        else:
            raise RuntimeError("load_dataset failed after 8 retries") from last_err
        ds = ds.shuffle(seed=seed, buffer_size=shuffle_buffer)
        self._ds = iter(ds)
        self._text_field = text_field
        self._min_chars = min_chars
        self._max_seq_len = max_seq_len
        self._min_tokens = kl_start_pos + 16
        self._tokenizer = tokenizer
        self._name = name

    def next_sample(self):
        scanned = 0
        while scanned < 100:
            try:
                item = next(self._ds)
            except StopIteration:
                return None
            scanned += 1
            text = item.get(self._text_field, "") or ""
            if len(text) < self._min_chars:
                continue
            ids = self._tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=self._max_seq_len,
            ).input_ids.squeeze(0)
            if ids.shape[0] < self._min_tokens:
                continue
            return ids
        return None


class MixedStreamingLoader:
    def __init__(self, specs, tokenizer, min_chars, max_seq_len, kl_start_pos, seed, shuffle_buffer):
        self._rng = random.Random(seed)
        self._weights = []
        self._loaders = []
        for spec in specs:
            self._weights.append(spec["weight"])
            self._loaders.append(
                StreamingTextLoader(
                    name=spec["name"],
                    text_field=spec["text_field"],
                    min_chars=min_chars,
                    max_seq_len=max_seq_len,
                    kl_start_pos=kl_start_pos,
                    tokenizer=tokenizer,
                    seed=seed + len(self._loaders),
                    shuffle_buffer=shuffle_buffer,
                )
            )

    def next_batch(self, n):
        out = []
        while len(out) < n:
            idx = self._rng.choices(range(len(self._loaders)), weights=self._weights, k=1)[0]
            sample = self._loaders[idx].next_sample()
            if sample is None:
                continue
            out.append(sample)
        return out


def collate_pad(token_lists, pad_id):
    max_len = max(t.shape[0] for t in token_lists)
    B = len(token_lists)
    input_ids = torch.full((B, max_len), pad_id, dtype=torch.long)
    attention_mask = torch.zeros((B, max_len), dtype=torch.long)
    for i, t in enumerate(token_lists):
        L = t.shape[0]
        input_ids[i, :L] = t
        attention_mask[i, :L] = 1
    return input_ids, attention_mask


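# Illustrative usage:
#   >>> a = torch.tensor([5, 6, 7])
#   >>> b = torch.tensor([8, 9])
#   >>> ids, mask = collate_pad([a, b], pad_id=0)
#   >>> ids.tolist()   # [[5, 6, 7], [8, 9, 0]]
#   >>> mask.tolist()  # [[1, 1, 1], [1, 1, 0]]

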
def _kl_chunk_sum(s_chunk, t_chunk, m_chunk):
    s = s_chunk.float()
    t = t_chunk.float()
    t_log_p = F.log_softmax(t, dim=-1)
    s_log_p = F.log_softmax(s, dim=-1)
    t_p = t_log_p.exp()
    per_token = (t_p * (t_log_p - s_log_p)).sum(-1)
    return (per_token * m_chunk).sum()


def kl_loss_masked(student_logits, teacher_logits, attention_mask, start_pos, chunk_size):
    s_full = student_logits[:, start_pos:, :]
    t_full = teacher_logits[:, start_pos:, :].detach()
    m_full = attention_mask[:, start_pos:].float()

    T = s_full.shape[1]
    if chunk_size <= 0 or chunk_size >= T:
        return _kl_chunk_sum(s_full, t_full, m_full) / m_full.sum().clamp_min(1.0)

    total_kl = torch.zeros((), device=s_full.device, dtype=torch.float32)
    for i in range(0, T, chunk_size):
        end = min(i + chunk_size, T)
        s_c = s_full[:, i:end, :]
        t_c = t_full[:, i:end, :]
        m_c = m_full[:, i:end]
        chunk_kl = checkpoint_utils.checkpoint(
            _kl_chunk_sum, s_c, t_c, m_c, use_reentrant=False
        )
        total_kl = total_kl + chunk_kl
    return total_kl / m_full.sum().clamp_min(1.0)


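# The loss above is forward KL averaged over valid tokens:
#   KL(p_t || p_s) = sum_v p_t(v) * (log p_t(v) - log p_s(v))
# summed over masked positions and divided by the mask total. Chunking along
# the sequence axis with activation checkpointing keeps only one
# [B, chunk, vocab] float32 softmax pair live at a time, which bounds peak
# memory when the vocabulary is large.

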
def apply_trainable_masks(model, train_cfg):
    trainable = train_cfg.get("trainable_patterns", [])
    frozen = train_cfg.get("freeze_patterns", [])
    if not trainable and not frozen:
        return

    trainable_re = [re.compile(p) for p in trainable]
    frozen_re = [re.compile(p) for p in frozen]
    for name, p in model.named_parameters():
        keep = True
        if trainable_re:
            keep = any(r.search(name) for r in trainable_re)
        if keep and frozen_re and any(r.search(name) for r in frozen_re):
            keep = False
        p.requires_grad_(keep)


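# Illustrative patterns (hypothetical): to train only two appended layers of a
# 32-layer student while keeping embeddings frozen, a config might set
#   trainable_patterns = ['layers\\.(30|31)\\.']
#   freeze_patterns = ['embed_tokens']
# Patterns are re.search'd against named_parameters() names, and freeze wins
# wherever the two lists overlap.

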
def make_optimizer(model, train_cfg, new_layer_indices=None):
    base_lr = train_cfg["lr"]
    mul = train_cfg["new_layer_lr_mul"]
    common = dict(
        weight_decay=train_cfg["weight_decay"],
        betas=tuple(train_cfg["betas"]),
        eps=train_cfg["eps"],
    )

    if not new_layer_indices or mul == 1.0:
        return AdamW(
            [p for p in model.parameters() if p.requires_grad],
            lr=base_lr,
            **common,
        )

    inner = get_inner_with_layers(model)
    new_pids = set()
    for idx in new_layer_indices:
        for p in inner.layers[idx].parameters():
            if p.requires_grad:
                new_pids.add(id(p))

    new_params = []
    rest_params = []
    for p in model.parameters():
        if not p.requires_grad:
            continue
        (new_params if id(p) in new_pids else rest_params).append(p)

    return AdamW(
        [
            {"params": rest_params, "lr": base_lr},
            {"params": new_params, "lr": base_lr * mul},
        ],
        **common,
    )


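# Example (hypothetical numbers): with lr = 1e-4 and new_layer_lr_mul = 4.0,
# parameters in freshly grown layers sit in their own AdamW param group at
# lr 4e-4, while every other trainable parameter stays at the base 1e-4.

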
def make_scheduler(optimizer, train_cfg):
    schedule = train_cfg["schedule"]
    warmup = train_cfg["warmup_steps"]
    total = train_cfg["max_steps"]

    if schedule == "constant":
        from transformers import get_constant_schedule_with_warmup

        return get_constant_schedule_with_warmup(optimizer, warmup)
    if schedule == "cosine":
        from transformers import get_cosine_schedule_with_warmup

        return get_cosine_schedule_with_warmup(optimizer, warmup, total)
    if schedule == "linear":
        from transformers import get_linear_schedule_with_warmup

        return get_linear_schedule_with_warmup(optimizer, warmup, total)
    raise ValueError(f"unknown schedule: {schedule!r}")


def build_dataset_specs(data_cfg):
    if "datasets" in data_cfg:
        names = data_cfg["datasets"]
        text_fields = data_cfg.get("text_fields", [data_cfg.get("text_field", "text")] * len(names))
        weights = data_cfg.get("dataset_weights", [1.0] * len(names))
        if not (len(names) == len(text_fields) == len(weights)):
            raise ValueError("datasets/text_fields/dataset_weights length mismatch")
        return [
            {"name": name, "text_field": field, "weight": weight}
            for name, field, weight in zip(names, text_fields, weights)
        ]
    return [
        {
            "name": data_cfg["dataset"],
            "text_field": data_cfg["text_field"],
            "weight": 1.0,
        }
    ]


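# Two accepted [data] shapes (dataset names here are placeholders):
#   single:  dataset = "org/corpus"           text_field = "text"
#   mixture: datasets = ["org/a", "org/b"]    text_fields = ["text", "content"]
#            dataset_weights = [0.8, 0.2]
# Weights are relative sampling probabilities used by MixedStreamingLoader.

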
def build_or_load_eval_cache(path, loader=None, samples=None):
    path = Path(path)
    if path.exists():
        log.info(f"Loading eval cache from {path}")
        raw = torch.load(path)
        return [torch.tensor(x, dtype=torch.long) for x in raw]
    if loader is None or samples is None:
        raise ValueError("loader and samples are required when building a new eval cache")
    path.parent.mkdir(parents=True, exist_ok=True)
    log.info(f"Building eval cache at {path}")
    batches = loader.next_batch(samples)
    torch.save([x.tolist() for x in batches], path)
    return batches


def log_jsonl(path: Path, record: dict):
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("a") as f:
        f.write(json.dumps(record, sort_keys=True) + "\n")


@torch.no_grad()
def evaluate(student, teacher, eval_batches, pad_id, kl_start_pos, kl_chunk_size, student_device, teacher_input_device):
    student.eval()
    total = 0.0
    n = 0
    for sample in eval_batches:
        ids, mask = collate_pad([sample], pad_id)
        teacher_ids = ids.to(teacher_input_device, non_blocking=True)
        teacher_mask = mask.to(teacher_input_device, non_blocking=True)
        student_ids = ids.to(student_device, non_blocking=True)
        student_mask = mask.to(student_device, non_blocking=True)
        t_logits = teacher_forward(teacher, teacher_ids, teacher_mask, student_device)
        s_logits = student(input_ids=student_ids, attention_mask=student_mask).logits
        loss = kl_loss_masked(
            s_logits,
            t_logits,
            student_mask,
            start_pos=kl_start_pos,
            chunk_size=kl_chunk_size,
        )
        total += loss.item()
        n += 1
        del t_logits, s_logits, loss, teacher_ids, teacher_mask, student_ids, student_mask
    student.train()
    return total / max(n, 1)


def save_best(student, tokenizer, output_dir, step, eval_kl):
    out_dir = Path(output_dir) / "best"
    if out_dir.exists():
        shutil.rmtree(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    student.save_pretrained(out_dir, safe_serialization=True)
    tokenizer.save_pretrained(out_dir)
    with (out_dir / "best.json").open("w") as f:
        json.dump({"step": step, "eval_kl": eval_kl}, f, indent=2)
    log.info(f"saved best @ step {step}: eval_kl={eval_kl:.6f} -> {out_dir}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    args = parser.parse_args()

    cfg = load_config(args.config)
    torch.manual_seed(cfg["train"]["seed"])
    random.seed(cfg["train"]["seed"])

    student_device = torch.device(cfg["model"]["student_device"])
    teacher_devices = list(cfg["model"]["teacher_devices"])

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["tokenizer"], trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    pad_id = tokenizer.pad_token_id

    student = load_student(
        cfg["model"]["student"],
        parse_dtype(cfg["train"]["student_dtype"]),
        grad_ckpt=cfg["train"]["grad_checkpointing"],
        attn_impl=cfg["train"]["attn_implementation"],
    )
    student.to(student_device)
    student.config.use_cache = False

    target_n = cfg["init"]["target_num_layers"]
    cur_n = len(get_inner_with_layers(student).layers)
    new_layer_indices = []
    if target_n != cur_n:
        new_n, new_zeroed = grow_layers(student, target_n)
        new_layer_indices = [idx for idx, _ in new_zeroed]
        log.info(f"Grew student from {cur_n} -> {new_n} layers")
        for idx, names in new_zeroed:
            log.info(f"  layer {idx}: zeroed {names}")

    zero_idx = cfg["init"]["zero_layers"]
    if zero_idx:
        n = zero_layers(student, zero_idx)
        log.info(f"Zeroed student layers {zero_idx} (model has {n} layers)")

    apply_trainable_masks(student, cfg["train"])
    trainable_params = sum(p.numel() for p in student.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in student.parameters())
    if trainable_params == 0:
        raise RuntimeError("No trainable parameters remain after applying trainable/freeze patterns")
    log.info(f"Student params: total={total_params/1e9:.3f}B trainable={trainable_params/1e9:.3f}B")

    teacher = load_teacher(
        cfg["model"]["teacher"],
        parse_dtype(cfg["train"]["teacher_dtype"]),
        attn_impl=cfg["train"]["attn_implementation"],
        devices=teacher_devices,
        max_mem_gb=cfg["model"]["teacher_max_memory_gb"],
    )
    teacher_input_device, _ = get_teacher_devices(teacher)
    log.info(f"Teacher input device: {teacher_input_device}")

    optimizer = make_optimizer(student, cfg["train"], new_layer_indices=new_layer_indices)
    scheduler = make_scheduler(optimizer, cfg["train"])

    output_dir = Path(cfg["log"]["output_dir"])
    output_dir.mkdir(parents=True, exist_ok=True)
    shutil.copy2(args.config, output_dir / "config.snapshot.toml")
    metrics_path = output_dir / "metrics.jsonl"
    experiment_log = Path(cfg["log"]["experiment_log"])

    use_wandb = cfg["log"]["wandb"]
    if use_wandb:
        import wandb

        wandb.init(
            project=cfg["log"]["wandb_project"],
            name=cfg["log"]["wandb_run"],
            config=cfg,
        )

    specs = build_dataset_specs(cfg["data"])
    train_loader = MixedStreamingLoader(
        specs=specs,
        tokenizer=tokenizer,
        min_chars=cfg["data"]["min_chars"],
        max_seq_len=cfg["data"]["max_seq_len"],
        kl_start_pos=cfg["data"]["kl_start_pos"],
        seed=cfg["data"]["seed"],
        shuffle_buffer=cfg["data"]["shuffle_buffer"],
    )
    eval_cache_path = Path(cfg["eval"]["cache_path"])
    if eval_cache_path.exists():
        eval_batches = build_or_load_eval_cache(eval_cache_path)
    else:
        eval_loader = MixedStreamingLoader(
            specs=specs,
            tokenizer=tokenizer,
            min_chars=cfg["data"]["min_chars"],
            max_seq_len=cfg["data"]["max_seq_len"],
            kl_start_pos=cfg["data"]["kl_start_pos"],
            seed=cfg["eval"]["seed"],
            shuffle_buffer=cfg["data"]["shuffle_buffer"],
        )
        eval_batches = build_or_load_eval_cache(eval_cache_path, eval_loader, cfg["eval"]["samples"])
    log.info(f"Eval samples: {len(eval_batches)}")

    samples_per_step = cfg["train"]["samples_per_step"]
    micro_batch_size = cfg["train"]["micro_batch_size"]
    grad_clip = cfg["train"]["grad_clip"]
    kl_start_pos = cfg["data"]["kl_start_pos"]
    kl_chunk_size = cfg["train"]["kl_chunk_size"]
    max_steps = cfg["train"]["max_steps"]
    eval_every = cfg["eval"]["every_steps"]
    log_every = cfg["log"]["log_every"]

    student.train()
    best_kl = float("inf")
    global_step = 0
    run_summary = {
        "config": args.config,
        "run_name": cfg["log"]["wandb_run"],
        "student": cfg["model"]["student"],
        "teacher": cfg["model"]["teacher"],
        "start_time": int(time.time()),
    }

    while global_step < max_steps:
        t0 = time.time()
        batch = train_loader.next_batch(samples_per_step)
        optimizer.zero_grad(set_to_none=True)
        batch_n = len(batch)
        kl_sum = 0.0

        for mb_start in range(0, batch_n, micro_batch_size):
            micro = batch[mb_start : mb_start + micro_batch_size]
            mb_n = len(micro)
            ids, mask = collate_pad(micro, pad_id)
            teacher_ids = ids.to(teacher_input_device, non_blocking=True)
            teacher_mask = mask.to(teacher_input_device, non_blocking=True)
            student_ids = ids.to(student_device, non_blocking=True)
            student_mask = mask.to(student_device, non_blocking=True)

            with torch.no_grad():
                t_logits = teacher_forward(teacher, teacher_ids, teacher_mask, student_device)
            s_logits = student(input_ids=student_ids, attention_mask=student_mask).logits
            loss = kl_loss_masked(
                s_logits,
                t_logits,
                student_mask,
                start_pos=kl_start_pos,
                chunk_size=kl_chunk_size,
            )
            scaled = loss * (mb_n / batch_n)
            scaled.backward()
            kl_sum += loss.item() * mb_n
            del teacher_ids, teacher_mask, student_ids, student_mask, t_logits, s_logits, loss, scaled

        if grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(student.parameters(), grad_clip)
        optimizer.step()
        scheduler.step()
        global_step += 1

        elapsed = time.time() - t0
        kl_avg = kl_sum / batch_n
        lr_now = scheduler.get_last_lr()[0]
        record = {
            "step": global_step,
            "train_kl": kl_avg,
            "lr": lr_now,
            "step_time_s": elapsed,
        }
        log_jsonl(metrics_path, record)

        if global_step % log_every == 0:
            log.info(
                f"step {global_step}/{max_steps} | kl {kl_avg:.6f} | "
                f"lr {lr_now:.2e} | {elapsed:.2f}s"
            )
            if use_wandb:
                import wandb

                wandb.log(
                    {
                        "train/kl": kl_avg,
                        "train/lr": lr_now,
                        "perf/step_time_s": elapsed,
                    },
                    step=global_step,
                )

        if global_step % eval_every == 0:
            eval_kl = evaluate(
                student,
                teacher,
                eval_batches,
                pad_id,
                kl_start_pos,
                kl_chunk_size,
                student_device,
                teacher_input_device,
            )
            log.info(f"eval @ step {global_step}: kl={eval_kl:.6f} (best={best_kl:.6f})")
            log_jsonl(metrics_path, {"step": global_step, "eval_kl": eval_kl})
            if use_wandb:
                import wandb

                wandb.log({"eval/kl": eval_kl}, step=global_step)
            if eval_kl < best_kl:
                best_kl = eval_kl
                save_best(student, tokenizer, output_dir, global_step, eval_kl)
            student.train()

        if global_step % 10 == 0:
            gc.collect()
            torch.cuda.empty_cache()

    final_eval = evaluate(
        student,
        teacher,
        eval_batches,
        pad_id,
        kl_start_pos,
        kl_chunk_size,
        student_device,
        teacher_input_device,
    )
    log.info(f"final eval: kl={final_eval:.6f} (best={best_kl:.6f})")
    if final_eval < best_kl:
        best_kl = final_eval
        save_best(student, tokenizer, output_dir, global_step, final_eval)

    run_summary.update(
        {
            "end_time": int(time.time()),
            "best_eval_kl": best_kl,
            "final_eval_kl": final_eval,
            "max_steps": max_steps,
            "student_total_params": total_params,
            "student_trainable_params": trainable_params,
        }
    )
    log_jsonl(experiment_log, run_summary)

    if use_wandb:
        import wandb

        wandb.log({"eval/final_kl": final_eval, "eval/best_kl": best_kl}, step=global_step)
        wandb.finish()


if __name__ == "__main__":
    main()
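

# Example launch (illustrative config path):
#   python distill_sharded.py --config configs/distill_sharded.toml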