ASTERIZER committed
Commit 1dec56b · verified · 1 Parent(s): 8f79c18

Upload sft_train.py with huggingface_hub

Files changed (1)
  1. sft_train.py +778 -0
sft_train.py ADDED
@@ -0,0 +1,778 @@
"""
LUNA 100M - SFT Fine-Tuning Script
====================================
Fine-tunes the pretrained LUNA-100M on instruction-following (SFT) data.

Features:
- Loads pretrained checkpoint (latest.pt from pretraining)
- JSON-based SFT dataset (instruction/input/output format)
- Prompt masking: loss computed only on the output portion
- Checkpoint eval: runs identity + knowledge prompts after each save
- Cosine LR with warmup
- Auto hardware detection (same as train.py)

Usage:
    python sft_train.py                                 # uses sft_config.yaml
    python sft_train.py --config sft_config.yaml        # explicit config
    python sft_train.py --train_json /data/train.json   # override data path
"""

import os
import gc
import sys
import math
import time
import json
import argparse
import yaml
import psutil
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast, GradScaler
from pathlib import Path

os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")


# ─── Model (identical to train.py) ───────────────────────────────────────────

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=1024):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        t = torch.arange(max_seq_len).float()
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cached", emb.cos())
        self.register_buffer("sin_cached", emb.sin())

    def forward(self, seq_len):
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)


def apply_rotary(x, cos, sin):
    c = cos.unsqueeze(0).unsqueeze(0)
    s = sin.unsqueeze(0).unsqueeze(0)
    return x * c + rotate_half(x) * s

class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd, n_head, block_size, rotary_pct=0.25):
        super().__init__()
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.rot_dim = int(self.head_dim * rotary_pct)
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=True)
        self.c_proj = nn.Linear(n_embd, n_embd, bias=True)
        self.rotary = RotaryEmbedding(self.rot_dim, block_size)

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.c_attn(x).reshape(B, T, 3, self.n_head, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        cos, sin = self.rotary(T)
        q = torch.cat([apply_rotary(q[..., :self.rot_dim], cos, sin), q[..., self.rot_dim:]], dim=-1)
        k = torch.cat([apply_rotary(k[..., :self.rot_dim], cos, sin), k[..., self.rot_dim:]], dim=-1)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.c_proj(y.transpose(1, 2).contiguous().view(B, T, C))


class MLP(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.fc = nn.Linear(n_embd, 4 * n_embd, bias=True)
        self.gelu = nn.GELU()
        self.proj = nn.Linear(4 * n_embd, n_embd, bias=True)

    def forward(self, x):
        return self.proj(self.gelu(self.fc(x)))


class Block(nn.Module):
    def __init__(self, n_embd, n_head, block_size):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(n_embd)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

class LUNAModel(nn.Module):
    def __init__(self, vocab_size, block_size, n_layer, n_embd, n_head):
        super().__init__()
        self.block_size = block_size
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList([Block(n_embd, n_head, block_size) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        self.lm_head.weight = self.wte.weight  # tied

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            m.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            m.bias.data.zero_()

    def forward(self, idx, targets=None, loss_mask=None, return_logits=True):
        x = self.wte(idx)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            shift_logits = logits[:, :-1, :].contiguous()
            shift_targets = targets[:, 1:].contiguous()
            if loss_mask is not None:
                shift_mask = loss_mask[:, 1:].contiguous()
                # Only compute loss on output tokens
                flat_logits = shift_logits.view(-1, shift_logits.size(-1))
                flat_targets = shift_targets.view(-1)
                flat_mask = shift_mask.view(-1).float()
                per_token_loss = F.cross_entropy(flat_logits, flat_targets, reduction='none')
                loss = (per_token_loss * flat_mask).sum() / flat_mask.sum().clamp(min=1)
            else:
                loss = F.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_targets.view(-1)
                )
        if not return_logits:
            logits = None
        return logits, loss

    @property
    def num_params(self):
        return sum(p.numel() for p in self.parameters()) - self.wte.weight.numel()


# ─── SFT Dataset ──────────────────────────────────────────────────────────────

class SFTDataset(torch.utils.data.Dataset):
    """
    Loads JSON SFT data (instruction/input/output) and tokenizes with prompt masking.
    Format per entry: {"instruction": "...", "input": "...", "output": "..."}

    Prompt template (Alpaca-style):
        ### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n{output}<|endoftext|>

    Loss mask: 0 for prompt tokens, 1 for response tokens (including EOS).
    """

    def __init__(self, json_path, tokenizer, max_len=1024):
        with open(json_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.eos_id = tokenizer.eos_token_id or 0

    def __len__(self):
        return len(self.data)

    def _format_prompt(self, entry):
        inst = entry.get("instruction", "").strip()
        inp = entry.get("input", "").strip()
        out = entry.get("output", "").strip()

        if inst and inp:
            prompt = f"### Instruction:\n{inst}\n\n### Input:\n{inp}\n\n### Response:\n"
        elif inst:
            prompt = f"### Instruction:\n{inst}\n\n### Response:\n"
        else:
            # input-only format
            prompt = f"### Input:\n{inp}\n\n### Response:\n"

        return prompt, out

    def __getitem__(self, idx):
        entry = self.data[idx]
        prompt, response = self._format_prompt(entry)

        prompt_ids = self.tokenizer.encode(prompt)
        response_ids = self.tokenizer.encode(response) + [self.eos_id]

        total_ids = prompt_ids + response_ids

        # Truncate to max_len
        if len(total_ids) > self.max_len:
            total_ids = total_ids[:self.max_len]
            # Ensure EOS at end
            total_ids[-1] = self.eos_id
            # Recalculate prompt boundary
            prompt_len = min(len(prompt_ids), self.max_len)
        else:
            prompt_len = len(prompt_ids)

        # Build loss mask: 0 for prompt, 1 for response
        loss_mask = [0] * prompt_len + [1] * (len(total_ids) - prompt_len)

        # Pad to max_len
        pad_len = self.max_len - len(total_ids)
        total_ids = total_ids + [self.eos_id] * pad_len
        loss_mask = loss_mask + [0] * pad_len  # don't compute loss on padding

        input_ids = torch.tensor(total_ids, dtype=torch.long)
        loss_mask = torch.tensor(loss_mask, dtype=torch.long)

        return input_ids, loss_mask

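# Worked example (added sketch for clarity; the entry below is hypothetical,
# not taken from the actual dataset). Given
#     {"instruction": "Who are you?", "input": "", "output": "I am LUNA."}
# _format_prompt() would return
#     prompt   = "### Instruction:\nWho are you?\n\n### Response:\n"
#     response = "I am LUNA."
# and __getitem__ would build
#     total_ids = encode(prompt) + encode(response) + [eos_id]
#     loss_mask = [0] * len(prompt_ids) + [1] * len(response_ids)   # response + EOS
# so the masked cross-entropy in LUNAModel.forward averages only over response
# and EOS positions; padding (eos_id repeated up to max_len) stays masked out.
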
# ─── Generation (for eval) ───────────────────────────────────────────────────

@torch.no_grad()
def generate(model, input_ids, max_new=150, temperature=0.7,
             top_p=0.9, top_k=40, device="cpu"):
    model.eval()
    ids = input_ids.clone().to(device)
    for _ in range(max_new):
        ctx = ids[:, -model.block_size:]
        logits, _ = model(ctx)
        logits = logits[:, -1, :] / max(temperature, 1e-8)
        if top_k > 0:
            vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < vals[:, -1:]] = -float("inf")
        probs = torch.softmax(logits, dim=-1)
        if top_p < 1.0:
            sorted_probs, sorted_idx = torch.sort(probs, descending=True)
            cum = torch.cumsum(sorted_probs, dim=-1)
            mask = cum - sorted_probs > top_p
            sorted_probs[mask] = 0.0
            sorted_probs /= sorted_probs.sum()
            next_token = sorted_idx[0, torch.multinomial(sorted_probs[0], 1)]
        else:
            next_token = torch.multinomial(probs[0], 1)
        ids = torch.cat([ids, next_token.view(1, 1)], dim=1)
        if next_token.item() == 0:  # EOS
            break
    model.train()
    return ids[0, input_ids.size(1):]


def run_eval_prompts(model, tokenizer, prompts, device, step, out_dir):
    """Run eval prompts and print + log results."""
    model.eval()
    results = []
    sep = "─" * 60

    print(f"\n{sep}")
    print(f" EVAL @ step {step}")
    print(sep)

    for prompt_text in prompts:
        # Format as instruction
        formatted = f"### Instruction:\n{prompt_text}\n\n### Response:\n"
        ids = tokenizer.encode(formatted, return_tensors="pt").to(device)
        out_ids = generate(model, ids, max_new=150, temperature=0.7, device=device)
        response = tokenizer.decode(out_ids.tolist(), skip_special_tokens=True).strip()

        print(f" Q: {prompt_text}")
        print(f" A: {response[:200]}")
        print()
        results.append({"prompt": prompt_text, "response": response[:500]})

    print(sep)

    # Save eval log
    eval_dir = Path(out_dir) / "evals"
    eval_dir.mkdir(parents=True, exist_ok=True)
    with open(eval_dir / f"eval_step_{step:06d}.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    model.train()
    return results

# ─── Hardware Detection (same as train.py) ────────────────────────────────────

def probe_hardware():
    info = {
        "cpu_cores": os.cpu_count() or 4,
        "ram_gb": psutil.virtual_memory().total / 1024**3,
    }
    if torch.cuda.is_available():
        props = torch.cuda.get_device_properties(0)
        info.update({
            "device": "cuda",
            "gpu_name": props.name,
            "vram_gb": props.total_memory / 1024**3,
            "sm_major": props.major,
        })
        if props.major >= 8:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            info["precision"] = "bf16"
            info["dtype"] = torch.bfloat16
        else:
            info["precision"] = "fp16"
            info["dtype"] = torch.float16
    else:
        info.update({
            "device": "cpu", "gpu_name": "CPU", "vram_gb": 0,
            "sm_major": 0, "precision": "fp32", "dtype": torch.float32,
        })
    return info


def probe_max_batch(model, device, dtype, seq_len, vocab_size, grad_accum_sim=4):
    """Binary search for max micro_batch. Safety: x0.70."""
    tmp_opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    lo, hi, best = 1, 512, 1
    while lo <= hi:
        mid = (lo + hi) // 2
        try:
            torch.cuda.empty_cache(); gc.collect()
            tmp_opt.zero_grad(set_to_none=True)
            for _ in range(grad_accum_sim):
                x = torch.randint(0, vocab_size, (mid, seq_len), device=device)
                mask = torch.ones_like(x)
                with autocast(device_type="cuda", dtype=dtype):
                    _, loss = model(x, x, loss_mask=mask, return_logits=False)
                    loss = loss / grad_accum_sim
                loss.backward()
                del x, mask, loss
            tmp_opt.step()
            tmp_opt.zero_grad(set_to_none=True)
            best = mid; lo = mid + 1
            torch.cuda.empty_cache()
        except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
            if "out of memory" in str(e).lower() or isinstance(e, torch.cuda.OutOfMemoryError):
                try: del x, mask, loss
                except: pass
                torch.cuda.empty_cache()
                tmp_opt.zero_grad(set_to_none=True)
                hi = mid - 1
            else:
                raise
    del tmp_opt; torch.cuda.empty_cache(); gc.collect()
    safe = max(1, int(best * 0.70))
    print(f" Probe: max_batch={best}, using {safe} (70% safety)")
    return safe

# ─── LR Schedule ──────────────────────────────────────────────────────────────

def cosine_lr(step, warmup, total, lr_max, lr_min):
    if step < warmup:
        return lr_max * (step + 1) / warmup
    p = (step - warmup) / max(1, total - warmup)
    return lr_min + 0.5 * (1 + math.cos(math.pi * p)) * (lr_max - lr_min)

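# Illustrative numbers (added sketch; the values are assumptions, not taken
# from the config): with warmup=100, total=1000, lr_max=3e-4, lr_min=3e-5:
#     cosine_lr(0,    100, 1000, 3e-4, 3e-5) -> 3.0e-6   (linear warmup ramp)
#     cosine_lr(99,   100, 1000, 3e-4, 3e-5) -> 3.0e-4   (peak at end of warmup)
#     cosine_lr(550,  100, 1000, 3e-4, 3e-5) -> 1.65e-4  (cosine midpoint)
#     cosine_lr(1000, 100, 1000, 3e-4, 3e-5) -> 3.0e-5   (decayed to lr_min)
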
# ─── Config ───────────────────────────────────────────────────────────────────

def load_sft_config(config_path):
    with open(config_path, encoding="utf-8") as f:
        raw = yaml.safe_load(f)

    cfg = {
        "auto_config": raw.get("auto_config", True),
        "hf_model_repo": raw.get("hf_model_repo", "ASTERIZER/LUNA-100M"),
        "hf_dataset_repo": raw.get("hf_dataset_repo", "ASTERIZER/Luna_Dataset"),
        "pretrained_ckpt": raw.get("pretrained_ckpt", "Base/out/pretrain/luna_100m/latest.pt"),
        "train_json": raw.get("train_json", "Base/Datasets/sft_clean/train.json"),
        "val_json": raw.get("val_json", "Base/Datasets/sft_clean/val.json"),
        "out_dir": raw.get("out_dir", "Base/out/sft/luna_100m_sft"),
        "tokenizer_dir": raw.get("tokenizer_dir", "Base/checkpoints/EleutherAI/pythia-160m"),
        # model
        "vocab_size": raw["model"]["vocab_size"],
        "seq_len": raw["model"]["seq_len"],
        "n_layer": raw["model"]["n_layer"],
        "n_embd": raw["model"]["n_embd"],
        "n_head": raw["model"]["n_head"],
        # train
        "epochs": raw["train"]["epochs"],
        "max_tokens": raw["train"].get("max_tokens", 0),
        "lr_warmup_steps": raw["train"]["lr_warmup_steps"],
        "save_interval": raw["train"]["save_interval"],
        "log_interval": raw["train"]["log_interval"],
        "eval_interval": raw["train"]["eval_interval"],
        "max_norm": raw["train"]["max_norm"],
        # optimizer
        "lr": raw["optimizer"]["lr"],
        "min_lr": raw["optimizer"]["min_lr"],
        "weight_decay": raw["optimizer"]["weight_decay"],
        "betas": tuple(raw["optimizer"]["betas"]),
        "eps": raw["optimizer"]["eps"],
        # batch
        "global_batch": raw["batch"]["global_batch"],
        "micro_batch": raw["batch"]["micro_batch"],
        "grad_accum": raw["batch"]["grad_accum"],
        # dataloader
        "num_workers": raw["dataloader"]["num_workers"],
        "pin_memory": raw["dataloader"]["pin_memory"],
        # hardware
        "precision": raw["hardware"]["precision"],
        # eval prompts
        "eval_prompts": raw.get("eval_prompts", []),
    }
    return cfg

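# A minimal sft_config.yaml sketch, inferred from the keys read above. This is
# an added illustration, not the repo's actual config: the path defaults mirror
# the fallbacks in load_sft_config(), while every numeric value is a placeholder.
#
#   auto_config: true
#   pretrained_ckpt: Base/out/pretrain/luna_100m/latest.pt
#   train_json: Base/Datasets/sft_clean/train.json
#   val_json: Base/Datasets/sft_clean/val.json
#   out_dir: Base/out/sft/luna_100m_sft
#   tokenizer_dir: Base/checkpoints/EleutherAI/pythia-160m
#   model:
#     vocab_size: 50304    # placeholder
#     seq_len: 1024        # placeholder
#     n_layer: 12          # placeholder
#     n_embd: 768          # placeholder
#     n_head: 12           # placeholder
#   train:
#     epochs: 3
#     lr_warmup_steps: 100
#     save_interval: 500
#     log_interval: 10
#     eval_interval: 500
#     max_norm: 1.0
#   optimizer:
#     lr: 2.0e-5
#     min_lr: 2.0e-6
#     weight_decay: 0.1
#     betas: [0.9, 0.95]
#     eps: 1.0e-8
#   batch:
#     global_batch: 64
#     micro_batch: 8
#     grad_accum: 8
#   dataloader:
#     num_workers: 4
#     pin_memory: true
#   hardware:
#     precision: bf16
#   eval_prompts:
#     - "Who are you?"
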
# ─── Training ─────────────────────────────────────────────────────────────────

SEP = "=" * 72

def sft_train(cfg):
    hw = probe_hardware()
    device = torch.device(hw["device"])

    if device.type == "cuda":
        torch.cuda.empty_cache(); gc.collect()

    # Precision
    if cfg["auto_config"]:
        dtype = hw.get("dtype", torch.float32)
        cfg["precision"] = hw["precision"]
    else:
        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16,
                 "fp32": torch.float32}.get(cfg["precision"], torch.float32)

    print(SEP)
    print(" LUNA 100M - SFT Fine-Tuning")
    print(SEP)
    print(f" GPU : {hw['gpu_name']} ({hw['vram_gb']:.1f} GB)")
    print(f" RAM : {hw['ram_gb']:.1f} GB CPU: {hw['cpu_cores']} cores")
    print(f" Precision : {cfg['precision']} dtype={dtype}")
    print(f" Pretrained : {cfg['pretrained_ckpt']}")

    # ── Tokenizer ─────────────────────────────────────────────────────────────
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(cfg["tokenizer_dir"])
    print(f" Tokenizer : {cfg['tokenizer_dir']} (vocab={tokenizer.vocab_size})")

    # ── Model ─────────────────────────────────────────────────────────────────
    print(f"\n Building LUNA-100M...")
    model = LUNAModel(
        vocab_size=cfg["vocab_size"],
        block_size=cfg["seq_len"],
        n_layer=cfg["n_layer"],
        n_embd=cfg["n_embd"],
        n_head=cfg["n_head"],
    ).to(device)
    print(f" Parameters: {model.num_params:,} (unique)")

    # ── Load pretrained weights ───────────────────────────────────────────────
    ckpt_path = Path(cfg["pretrained_ckpt"])
    if not ckpt_path.exists() and cfg.get("hf_model_repo"):
        # Auto-download from HuggingFace model repo
        print(f"\n Pretrained checkpoint not found locally.")
        print(f" Downloading from HuggingFace: {cfg['hf_model_repo']}")
        from huggingface_hub import hf_hub_download
        ckpt_path.parent.mkdir(parents=True, exist_ok=True)
        hf_hub_download(
            repo_id=cfg["hf_model_repo"],
            filename="latest.pt",
            local_dir=str(ckpt_path.parent),
            token=os.environ.get("HF_TOKEN"),
        )
        print(f" Downloaded to: {ckpt_path}")

    if ckpt_path.exists():
        print(f"\n Loading pretrained checkpoint: {ckpt_path}")
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        state = ckpt["model"] if "model" in ckpt else ckpt
        model.load_state_dict(state, strict=True)
        pretrain_step = ckpt.get("step", "?")
        pretrain_tokens = ckpt.get("tokens_seen", 0)
        print(f" Pretrained @ step {pretrain_step}, tokens seen: {pretrain_tokens:,}")
        # Do NOT load optimizer state - we start fresh for SFT
    else:
        print(f"\n WARNING: No pretrained checkpoint at {ckpt_path}")
        print(f" Training from scratch (not recommended for SFT)!")

    # ── Dataset (auto-download from HF if missing) ─────────────────────────────
    train_path = Path(cfg["train_json"])
    val_path = Path(cfg["val_json"]) if cfg["val_json"] else None

    if not train_path.exists() and cfg.get("hf_dataset_repo"):
        print(f"\n SFT dataset not found locally.")
        print(f" Downloading from HuggingFace: {cfg['hf_dataset_repo']}")
        from huggingface_hub import hf_hub_download
        train_path.parent.mkdir(parents=True, exist_ok=True)
        hf_hub_download(
            repo_id=cfg["hf_dataset_repo"],
            repo_type="dataset",
            filename="train.json",
            local_dir=str(train_path.parent),
            token=os.environ.get("HF_TOKEN"),
        )
        print(f" Downloaded train.json")
        if val_path:
            hf_hub_download(
                repo_id=cfg["hf_dataset_repo"],
                repo_type="dataset",
                filename="val.json",
                local_dir=str(val_path.parent),
                token=os.environ.get("HF_TOKEN"),
            )
            print(f" Downloaded val.json")

    print(f"\n Train data: {cfg['train_json']}")
    train_dataset = SFTDataset(cfg["train_json"], tokenizer, max_len=cfg["seq_len"])
    print(f" Train entries: {len(train_dataset):,}")

    val_dataset = None
    if cfg["val_json"] and Path(cfg["val_json"]).exists():
        val_dataset = SFTDataset(cfg["val_json"], tokenizer, max_len=cfg["seq_len"])
        print(f" Val entries: {len(val_dataset):,}")

    # ── Batch sizing ──────────────────────────────────────────────────────────
    if cfg["auto_config"] and device.type == "cuda":
        print(f"\n Probing max micro_batch_size...")
        max_mbs = probe_max_batch(model, device, dtype, cfg["seq_len"], cfg["vocab_size"])
        if ckpt_path.exists():
            model.load_state_dict(state, strict=True)  # restore pretrained weights after probe
        torch.cuda.empty_cache(); gc.collect()
        grad_accum = max(1, math.ceil(cfg["global_batch"] / max_mbs))
        effective_batch = max_mbs * grad_accum
    else:
        max_mbs = cfg["micro_batch"]
        grad_accum = cfg["grad_accum"]
        effective_batch = max_mbs * grad_accum

    print(f" micro_batch={max_mbs}, grad_accum={grad_accum}, effective={effective_batch}")

    # ── DataLoader ────────────────────────────────────────────────────────────
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=max_mbs,
        shuffle=True,
        num_workers=cfg["num_workers"],
        pin_memory=cfg["pin_memory"],
        drop_last=True,
        prefetch_factor=4 if cfg["num_workers"] > 0 else None,
        persistent_workers=cfg["num_workers"] > 0,
    )

    val_loader = None
    if val_dataset:
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=max_mbs, shuffle=False,
            num_workers=min(2, cfg["num_workers"]),
            pin_memory=cfg["pin_memory"], drop_last=False,
        )

    # ── Optimizer ─────────────────────────────────────────────────────────────
    try:
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=cfg["lr"],
            weight_decay=cfg["weight_decay"],
            betas=cfg["betas"], eps=cfg["eps"], fused=True,
        )
    except TypeError:
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=cfg["lr"],
            weight_decay=cfg["weight_decay"],
            betas=cfg["betas"], eps=cfg["eps"],
        )

    use_scaler = dtype == torch.float16
    scaler = GradScaler(enabled=use_scaler)

    # ── Schedule ──────────────────────────────────────────────────────────────
    steps_per_epoch = len(train_loader) // grad_accum
    total_steps = steps_per_epoch * cfg["epochs"]
    warmup_steps = min(cfg["lr_warmup_steps"], total_steps // 5)

    out_dir = Path(cfg["out_dir"])
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n Epochs : {cfg['epochs']}")
    print(f" Steps/epoch : {steps_per_epoch:,}")
    print(f" Total steps : {total_steps:,}")
    print(f" Warmup steps : {warmup_steps}")
    print(f" LR : {cfg['lr']:.2e} -> {cfg['min_lr']:.2e}")
    print(f" Save every : {cfg['save_interval']} steps")
    print(f" Eval every : {cfg['eval_interval']} steps")
    print(f" Eval prompts : {len(cfg['eval_prompts'])}")
    print(f" Out dir : {out_dir}")
    print(SEP)

    # ── Resume SFT ────────────────────────────────────────────────────────────
    start_step = 0
    sft_ckpt_path = out_dir / "latest.pt"
    if sft_ckpt_path.exists():
        print(f"\n Resuming SFT from {sft_ckpt_path}...")
        sft_ckpt = torch.load(sft_ckpt_path, map_location=device, weights_only=True)
        model.load_state_dict(sft_ckpt["model"])
        optimizer.load_state_dict(sft_ckpt["optimizer"])
        start_step = sft_ckpt["step"]
        print(f" Resumed at SFT step {start_step}")

    # ── Eval at start ─────────────────────────────────────────────────────────
    if cfg["eval_prompts"] and start_step == 0:
        print("\n Running initial eval (before SFT)...")
        run_eval_prompts(model, tokenizer, cfg["eval_prompts"], device, 0, out_dir)

    # ── Training loop ─────────────────────────────────────────────────────────
    model.train()
    run_t0 = time.perf_counter()
    step = start_step
    best_val_loss = float("inf")

    print(f"\n Starting SFT training (step {start_step} -> {total_steps})...")

    for epoch in range(cfg["epochs"]):
        data_iter = iter(train_loader)
        micro_step = 0

        for batch_idx, (input_ids, loss_mask) in enumerate(data_iter):
            # Skip already-done steps on resume
            current_global_step = epoch * steps_per_epoch + (micro_step // grad_accum)
            if current_global_step < start_step:
                micro_step += 1
                continue
            if current_global_step >= total_steps:
                break

            input_ids = input_ids.to(device, non_blocking=True)
            loss_mask = loss_mask.to(device, non_blocking=True)

            t0 = time.perf_counter()

            # Accumulation step
            with autocast(device_type=device.type, dtype=dtype, enabled=(device.type == "cuda")):
                _, loss = model(input_ids, targets=input_ids, loss_mask=loss_mask, return_logits=False)
                loss = loss / grad_accum

            scaler.scale(loss).backward()
            micro_step += 1

            # Optimizer step after grad_accum micro-batches
            if micro_step % grad_accum == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["max_norm"])

                # LR schedule
                lr_now = cosine_lr(step, warmup_steps, total_steps, cfg["lr"], cfg["min_lr"])
                for pg in optimizer.param_groups:
                    pg["lr"] = lr_now

                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)

                if device.type == "cuda":
                    torch.cuda.synchronize()

                dt = time.perf_counter() - t0
                step += 1

                # ── Log ───────────────────────────────────────────────────────
                if step % cfg["log_interval"] == 0 or step <= 3:
                    tokens_step = effective_batch * cfg["seq_len"]
                    tps = tokens_step / dt if dt > 0 else 0
                    vram = torch.cuda.max_memory_allocated() / 1024**3 if device.type == "cuda" else 0
                    eta_h = (total_steps - step) * dt / 3600
                    print(f" step {step:6d}/{total_steps} | epoch {epoch+1}/{cfg['epochs']} | "
                          f"loss {loss.item()*grad_accum:.4f} | lr {lr_now:.2e} | "
                          f"{tps:,.0f} tok/s | VRAM {vram:.1f}GB | ETA {eta_h:.1f}h")

                # ── Save checkpoint ───────────────────────────────────────────
                if step % cfg["save_interval"] == 0 or step == total_steps:
                    raw_model = model._orig_mod if hasattr(model, "_orig_mod") else model
                    step_dir = out_dir / f"step-{step:06d}"
                    step_dir.mkdir(parents=True, exist_ok=True)
                    torch.save(raw_model.state_dict(), step_dir / "model.pth")
                    torch.save({
                        "step": step,
                        "model": raw_model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "epoch": epoch,
                        "sft_loss": loss.item() * grad_accum,
                    }, out_dir / "latest.pt")
                    print(f" Saved -> {step_dir}")

                # ── Eval ──────────────────────────────────────────────────────
                if step % cfg["eval_interval"] == 0 or step == total_steps:
                    # Validation loss
                    if val_loader:
                        model.eval()
                        val_loss_sum = 0.0
                        val_count = 0
                        with torch.no_grad():
                            for val_ids, val_mask in val_loader:
                                val_ids = val_ids.to(device, non_blocking=True)
                                val_mask = val_mask.to(device, non_blocking=True)
                                with autocast(device_type=device.type, dtype=dtype, enabled=(device.type == "cuda")):
                                    _, vl = model(val_ids, targets=val_ids, loss_mask=val_mask, return_logits=False)
                                val_loss_sum += vl.item()
                                val_count += 1
                                if val_count >= 50:  # cap eval to 50 batches
                                    break
                        avg_val = val_loss_sum / max(val_count, 1)
                        print(f" Val loss: {avg_val:.4f}")
                        if avg_val < best_val_loss:
                            best_val_loss = avg_val
                            raw_model = model._orig_mod if hasattr(model, "_orig_mod") else model
                            torch.save(raw_model.state_dict(), out_dir / "best_model.pth")
                            print(f" New best! Saved best_model.pth")
                        model.train()

                    # Run eval prompts
                    if cfg["eval_prompts"]:
                        run_eval_prompts(model, tokenizer, cfg["eval_prompts"],
                                         device, step, out_dir)

    # ── Final ─────────────────────────────────────────────────────────────────
    final_dir = out_dir / "final"
    final_dir.mkdir(parents=True, exist_ok=True)
    raw_model = model._orig_mod if hasattr(model, "_orig_mod") else model
    torch.save(raw_model.state_dict(), final_dir / "model.pth")
    torch.save({
        "step": step,
        "model": raw_model.state_dict(),
        "sft_complete": True,
    }, out_dir / "latest.pt")

    # Copy tokenizer
    import shutil
    tok_src = Path(cfg["tokenizer_dir"])
    if tok_src.exists():
        shutil.copytree(tok_src, final_dir / "tokenizer", dirs_exist_ok=True)

    total_h = (time.perf_counter() - run_t0) / 3600
    print(SEP)
    print(f" SFT Complete! {total_h:.2f}h -> {final_dir}")
    print(f" Best val loss: {best_val_loss:.4f}")
    print(SEP)


# ─── Entry ────────────────────────────────────────────────────────────────────

def parse_args():
    p = argparse.ArgumentParser(description="LUNA 100M - SFT Fine-Tuning")
    p.add_argument("--config", default="sft_config.yaml")
    p.add_argument("--pretrained_ckpt", default=None)
    p.add_argument("--train_json", default=None)
    p.add_argument("--val_json", default=None)
    p.add_argument("--out_dir", default=None)
    p.add_argument("--epochs", type=int, default=None)
    p.add_argument("--lr", type=float, default=None)
    p.add_argument("--micro_batch", type=int, default=None)
    p.add_argument("--global_batch", type=int, default=None)
    p.add_argument("--save_interval", type=int, default=None)
    p.add_argument("--eval_interval", type=int, default=None)
    p.add_argument("--auto_config", type=lambda x: x.lower() in ("1", "true", "yes"),
                   default=None)
    return p.parse_args()


if __name__ == "__main__":
    args = parse_args()
    cfg = load_sft_config(args.config)
    # CLI overrides
    for key, val in vars(args).items():
        if key != "config" and val is not None:
            cfg[key] = val
    sft_train(cfg)