# codi-trace / code/train/train_codi.py
# Uploaded by sirui6011 — commit aedd6ab ("add code/ loader snapshot", verified), 15.5 kB.
"""Stage 2b: per-frame CODI self-distillation (multi-span).
Shared-weight teacher+student initialized from the Stage-1 SFT model
(or, with --frozen_teacher, a separate frozen teacher that only supplies KD targets).
- Teacher reads the full explicit trace (prompt+trace), CE = L_teacher
  (shared-weight mode only; a frozen teacher contributes no CE).
- Student replaces each LINE frame's $LOCALS with a latent block (latent_start +
  `latent_steps` recurrent latents + latent_end; last hidden -> prj -> next embed)
  and teacher-forces the rest, CE = L_student over the emitted (non-locals) text.
- KD aligns the hidden at each frame's `<|action_sep|>` (student after latents vs
  teacher after locals), teacher detached. L = a*Lt + b*Ls + g*Lkd.
  With --kd_target logit, KD is instead a temperature-scaled KL between the
  lm_head distributions at those same positions.
"""
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, Trainer, TrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import WEIGHTS_NAME
from data.dataset import IGNORE_INDEX, build_codi_dataset
from tokens import add_trace_tokens, token_ids
from wb import wandb_init
class CodiModel(nn.Module):
    """CODI teacher/student wrapper around a causal LM.

    The student replaces each span of dropped locals with a latent block
    (``latent_start`` + ``latent_steps`` projected hidden states + ``latent_end``).
    It is trained with CE on the remaining trace text, plus a KD term. The KD
    term aligns the student's representation at each frame's ``<|action_sep|>``
    with the teacher's at the same token of the full explicit trace.

    Args:
        base: causal LM exposing ``.model`` (decoder body), ``.lm_head`` and
            ``get_input_embeddings()``.
        latent_start_id, latent_end_id: token ids that bracket a latent block.
        latent_steps: number of recurrent latent steps per block.
        a, b, g: weights of the teacher CE, student CE and KD terms.
        kd_layers: hidden-state indices used for KD. ``None`` means all layers,
            excluding the embedding output. Forced to ``[-1]`` for logit KD, and
            when a frozen teacher is given without explicit layers.
        single_anchor: keep only the last frame's KD anchor (vanilla-CODI ablation).
        ss_prob, ss_ramp_frac: scheduled-sampling max probability, and the
            fraction of training over which the trainer ramps it up.
        teacher: optional frozen teacher model; it supplies KD targets only.
        kd_target: ``"hidden"`` (smooth-L1 on hidden states) or ``"logit"``
            (temperature-scaled KL on lm_head distributions).
        kd_temp: temperature for logit KD.
        line_sep_id, recon_w: accepted for interface compatibility; unused here.
    """

    def __init__(self, base, *, latent_start_id, latent_end_id, latent_steps,
                 a=1.0, b=1.0, g=1.0, kd_layers=None, single_anchor=False,
                 ss_prob=0.0, ss_ramp_frac=0.5, teacher=None, kd_target="hidden", kd_temp=2.0,
                 line_sep_id=None, recon_w=0.0):
        super().__init__()
        self.model = base
        h = base.config.hidden_size
        # CODI thought projector (last hidden -> next latent input).
        self.prj = nn.Sequential(
            nn.Linear(h, h, bias=False), nn.GELU(),
            nn.Linear(h, h, bias=False), nn.LayerNorm(h),
        )
        ref = base.get_input_embeddings().weight
        self.prj.to(device=ref.device, dtype=ref.dtype)
        self.latent_steps, self.a, self.b, self.g = latent_steps, a, b, g
        self.teacher = [teacher] if teacher is not None else None  # list -> hidden from state_dict/DDP/optim
        self.kd_target, self.kd_temp = kd_target, kd_temp  # hidden: smooth_l1 on kd_layers; logit: KL on lm_head
        if kd_target == "logit" or (teacher is not None and kd_layers is None):
            kd_layers = [-1]  # logit KD is defined on the last layer only; frozen default = key (last) hidden
        self.kd_layers = kd_layers  # None -> all layers
        self.single_anchor = single_anchor  # KD at last span only (vanilla-CODI ablation)
        # scheduled sampling: ss_p (ramped per step) of post-latent lines feed the student's own argmax
        self.ss_prob, self.ss_ramp_frac, self.ss_p = ss_prob, ss_ramp_frac, 0.0
        self.register_buffer("_ls_tok", torch.tensor([[latent_start_id]], dtype=torch.long), persistent=False)
        self.register_buffer("_le_tok", torch.tensor([[latent_end_id]], dtype=torch.long), persistent=False)
        self.body = base.model
        self.head = base.lm_head

    def _kd(self, hs):
        """Select the KD layers from a tuple of hidden states (index 0 = embeddings)."""
        return hs[1:] if self.kd_layers is None else tuple(hs[l] for l in self.kd_layers)

    def _emb(self, ids):
        """Embed token ids with the base model's input embedding table."""
        return self.model.get_input_embeddings()(ids)

    def _teacher(self, full_ids, labels, kd_pos):
        """Teacher pass over the full explicit trace.

        Returns ``(ce, kd_targets)``:

        - ``kd_targets`` is a list of detached tensors taken at positions ``kd_pos``.
          Hidden KD gives one tensor per KD layer; logit KD gives a single logits
          tensor.
        - ``ce`` is the teacher CE on ``labels``. It is ``None`` when a frozen
          teacher is used.
        """
        # Explicit long dtype: an empty kd_pos must still be a valid index tensor.
        pos = torch.tensor(kd_pos, dtype=torch.long, device=full_ids.device)
        if self.teacher is not None:  # frozen teacher: KD targets only, no teacher CE
            tch, dev = self.teacher[0], full_ids.device
            if next(tch.parameters()).device != dev:
                tch.to(dev)
            with torch.no_grad():
                if self.kd_target == "logit":  # target = teacher's own next-token logits
                    return None, [tch(input_ids=full_ids[None], use_cache=False).logits[0, pos]]
                hs = tch(input_ids=full_ids[None], use_cache=False, output_hidden_states=True).hidden_states
                return None, [l[0, pos] for l in self._kd(hs)]
        want_logits = self.kd_target == "logit"
        with torch.no_grad():  # KD targets are detached; take them without a backward graph
            out = self.model(input_ids=full_ids[None], use_cache=False,
                             output_hidden_states=not want_logits)
            # Logit KD compares lm_head distributions, so the target must be logits
            # (as in the frozen branch), not last-layer hidden states.
            kd = [out.logits[0, pos]] if want_logits else [l[0, pos] for l in self._kd(out.hidden_states)]
        del out  # drop full-sequence activations before the CE forward
        # CE forward without output_hidden_states so grad-checkpointing actually frees layer acts.
        self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
        try:
            logits = self.model(input_ids=full_ids[None], use_cache=False).logits
        finally:
            self.model.gradient_checkpointing_disable()  # teacher-only; student keeps KV cache
        ce = F.cross_entropy(logits[0, :-1], labels[1:], ignore_index=IGNORE_INDEX)
        return ce, kd

    def _latent_block(self, cache):
        """latent_start + `latent_steps` recurrent latents + latent_end on top of
        `cache`. Returns (new cache, logits predicting the next real token)."""
        o = self.body(inputs_embeds=self._emb(self._ls_tok), past_key_values=cache, use_cache=True)
        cache, h = o.past_key_values, o.last_hidden_state[:, -1:]
        for _ in range(self.latent_steps):
            o = self.body(inputs_embeds=self.prj(h), past_key_values=cache, use_cache=True)
            cache, h = o.past_key_values, o.last_hidden_state[:, -1:]
        o = self.body(inputs_embeds=self._emb(self._le_tok), past_key_values=cache, use_cache=True)
        return o.past_key_values, self.head(o.last_hidden_state[:, -1])

    def _student(self, prompt_ids, trace_ids, spans):
        """Student pass: text segments teacher-forced, each locals span replaced by latents.

        Returns ``(ce, s_kd)``. ``s_kd`` holds, per KD layer, the stacked hidden
        states at each post-latent ``<|action_sep|>``. It is an empty list when
        there are no spans.
        """
        # Segments cover trace_ids in order; locals (trace_ids[i+1:j]) are dropped
        # and replaced by a latent block. kd=True marks a frame's <|action_sep|>.
        segs, prev, kd = [], 0, False
        for i, j in spans:
            segs.append(("text", trace_ids[prev:i + 1], kd))
            segs.append(("latent", None, False))
            prev, kd = j, True
        segs.append(("text", trace_ids[prev:], kd))
        last = len(segs) - 1
        out = self.model(inputs_embeds=self._emb(prompt_ids[None]), use_cache=True)
        cache, prev_logits = out.past_key_values, out.logits[:, -1]  # predicts trace_ids[0]
        ce_logits, ce_targets, kd_vecs = [], [], []
        for s, (kind, ids, kd) in enumerate(segs):
            if kind == "latent":  # prev_logits predicted dropped locals; overwrite, no CE
                cache, prev_logits = self._latent_block(cache)
                continue
            inp = ids
            if kd and 0 < self.ss_p and random.random() < self.ss_p:
                # scheduled sampling: replace the code (not action_sep / line_sep) with the student's own
                # argmax via a no-grad pass on a detached cache clone; CE targets below stay GT.
                end = ids.numel() if s == last else ids.numel() - 1
                c = DynamicCache()
                for li, ly in enumerate(cache.layers):
                    c.update(ly.keys.detach(), ly.values.detach(), li)
                with torch.no_grad():
                    pred = self.model(inputs_embeds=self._emb(ids[None]), past_key_values=c,
                                      use_cache=True).logits[0].argmax(-1)
                inp = ids.clone()
                inp[1:end] = pred[:end - 1]
            ce_logits.append(prev_logits)
            ce_targets.append(ids[:1])
            out = self.model(inputs_embeds=self._emb(inp[None]), past_key_values=cache,
                             use_cache=True, output_hidden_states=kd)  # hiddens only for KD anchors
            cache, logits = out.past_key_values, out.logits[0]
            if ids.numel() > 1:
                ce_logits.append(logits[:-1])
                ce_targets.append(ids[1:])
            prev_logits = logits[-1:]
            if kd:  # action_sep is this segment's first token
                kd_vecs.append([hs[0, 0] for hs in self._kd(out.hidden_states)])
        ce = F.cross_entropy(torch.cat(ce_logits), torch.cat(ce_targets))
        if not kd_vecs:  # no latent spans -> no KD anchors
            return ce, []
        s_kd = [torch.stack(per_layer) for per_layer in zip(*kd_vecs)]
        return ce, s_kd

    def _kd_loss(self, s_kd, t_kd):
        """KD loss between student anchors ``s_kd`` and detached teacher targets ``t_kd``.

        With hidden KD this is smooth-L1. With logit KD, the student hiddens are
        projected through ``lm_head``, and the loss is KL(teacher || student) at
        temperature ``kd_temp``, scaled by T^2.
        """
        s, t = torch.stack(s_kd), torch.stack(t_kd).detach()
        if self.kd_target == "logit":  # s=student hidden, t=teacher logits; KL on distributions
            T = self.kd_temp
            # float32: a bf16 softmax over a full vocabulary loses too much precision for KL.
            sl = self.head(s).flatten(0, -2).float() / T
            tl = t.flatten(0, -2).float() / T
            return F.kl_div(F.log_softmax(sl, -1), F.log_softmax(tl, -1),
                            reduction="batchmean", log_target=True) * T * T
        return F.smooth_l1_loss(s, t)

    def forward(self, examples):
        """Weighted CODI loss averaged over ``examples``.

        Each example has ``prompt_ids`` and ``trace_ids`` (token-id sequences) and
        ``spans``, a list of ``(i, j)`` pairs. Trace tokens ``i+1 .. j-1`` are the
        dropped locals, and trace token ``j`` is the KD anchor.

        Returns a dict with ``loss`` and the detached per-term means
        ``teacher_loss`` / ``student_loss`` / ``kd_loss``. A term is reported as 0
        when it never applied.
        """
        dev = self.model.get_input_embeddings().weight.device
        tl = sl = kl = 0.0
        for ex in examples:
            prompt = torch.tensor(ex["prompt_ids"], device=dev)
            trace = torch.tensor(ex["trace_ids"], device=dev)
            spans = ex["spans"]
            full = torch.cat([prompt, trace])
            labels = None if self.teacher else torch.cat([full.new_full((len(prompt),), IGNORE_INDEX), trace])
            kd_pos = [len(prompt) + j for _, j in spans]
            t_ce, t_kd = self._teacher(full, labels, kd_pos)
            s_ce, s_kd = self._student(prompt, trace, spans)
            if self.single_anchor:  # keep only the last frame's anchor (per layer)
                t_kd, s_kd = [t[-1:] for t in t_kd], [s[-1:] for s in s_kd]
            if t_ce is not None:  # frozen teacher -> no teacher CE
                tl = tl + t_ce
            sl = sl + s_ce
            if s_kd:  # examples without latent spans contribute no KD term
                kl = kl + self._kd_loss(s_kd, t_kd)
        n = len(examples)
        loss = self.a * tl / n + self.b * sl / n + self.g * kl / n

        def _mean(total):  # detached per-example mean; 0 when the term never applied
            return (total / n).detach() if torch.is_tensor(total) else torch.tensor(0.0)

        return {"loss": loss, "teacher_loss": _mean(tl),
                "student_loss": _mean(sl), "kd_loss": _mean(kl)}
class CodiTrainer(Trainer):
    """Trainer for ``CodiModel``.

    It feeds raw example lists to the model, ramps scheduled sampling, surfaces
    the sub-losses in logs, and saves checkpoints with ``torch.save``.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kw):
        """Run the CODI forward; remember detached sub-losses for ``log``."""
        core = model.module if hasattr(model, "module") else model
        if core.ss_prob:  # linear ramp 0 -> ss_prob over the first ss_ramp_frac of training
            ramp_steps = max(1.0, core.ss_ramp_frac * self.state.max_steps)
            core.ss_p = self._ss = core.ss_prob * min(1.0, self.state.global_step / ramp_steps)
        out = model(inputs["examples"])
        self._sub = {name: out[name].detach() for name in ("teacher_loss", "student_loss", "kd_loss")}
        return (out["loss"], out) if return_outputs else out["loss"]

    def log(self, logs, *args, **kwargs):
        """Add sub-losses (from the most recent micro-batch) and ss_p to console + wandb logs."""
        if hasattr(self, "_sub"):
            logs.update({name: val.item() for name, val in self._sub.items()})
        if hasattr(self, "_ss"):
            logs["ss_p"] = self._ss
        super().log(logs, *args, **kwargs)

    def _save(self, output_dir=None, state_dict=None):
        """Save an eval-loadable checkpoint into ``output_dir``.

        Writes the wrapper state dict, training args, base config, tokenizer (when
        available) and the thought projector.
        """
        # tied backbone weights -> safetensors (5.x default) rejects shared tensors; torch.save instead.
        output_dir = output_dir or self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        if state_dict is None:
            state_dict = self.model.state_dict()
        torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
        # also write config/tokenizer/projector so each ckpt is eval-loadable (small, no weight dup).
        self.model.model.config.save_pretrained(output_dir)
        # `tok` is attached by main(); fall back to the Trainer's processing_class if it is absent.
        tok = getattr(self, "tok", None)
        if tok is None:
            tok = getattr(self, "processing_class", None)
        if tok is not None:
            tok.save_pretrained(output_dir)
        torch.save(self.model.prj.state_dict(), os.path.join(output_dir, "thought_projector.pt"))
def main():
    """Parse CLI args, build the CODI model and dataset, and run (auto-resuming) training."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", required=True)  # Stage-1 SFT dir
    ap.add_argument("--output_dir", required=True)
    ap.add_argument("--sources", nargs="+", default=["mbpp", "humaneval", "pyx"])
    ap.add_argument("--cache_dir", default="data/cache/codi_train")  # offline tokenized examples from precompute.py
    ap.add_argument("--n_samples", type=int, default=-1)
    ap.add_argument("--max_seq_len", type=int, default=4096)
    ap.add_argument("--max_frames", type=int, default=-1)
    ap.add_argument("--latent_steps", type=int, default=1)
    ap.add_argument("--epochs", type=float, default=10.0)
    ap.add_argument("--lr", type=float, default=1e-5)
    ap.add_argument("--batch_size", type=int, default=1)
    ap.add_argument("--grad_accum", type=int, default=4)
    ap.add_argument("--max_steps", type=int, default=-1)
    ap.add_argument("--save_steps", type=int, default=500)
    ap.add_argument("--alpha", type=float, default=1.0)
    ap.add_argument("--beta", type=float, default=1.0)
    ap.add_argument("--gamma", type=float, default=1.0)
    ap.add_argument("--kd_layers", nargs="+", type=int, default=None)  # default: all layers (frozen -> last)
    ap.add_argument("--frozen_teacher", default="")  # path to frozen SFT teacher; "" -> shared-weight (legacy)
    ap.add_argument("--kd_target", default="hidden", choices=["hidden", "logit"])  # key-hidden align: smooth_l1 vs KL
    ap.add_argument("--kd_temp", type=float, default=2.0)  # logit-KD temperature
    ap.add_argument("--single_anchor", action="store_true")  # KD at last frame only (vanilla CODI)
    ap.add_argument("--ss_prob", type=float, default=0.0)  # scheduled-sampling max prob (0 = off)
    ap.add_argument("--ss_ramp_frac", type=float, default=0.5)  # ramp ss_prob over this frac of steps
    args = ap.parse_args()

    tok = AutoTokenizer.from_pretrained(args.model, use_fast=True)
    add_trace_tokens(tok)  # idempotent
    ids = token_ids(tok)
    base = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.bfloat16)
    # If add_trace_tokens had to add tokens this checkpoint lacks, their ids would index past
    # the embedding table; grow it (new rows are initialised by resize_token_embeddings).
    n_emb = base.get_input_embeddings().weight.shape[0]
    if len(tok) > n_emb:
        print(f"resizing token embeddings {n_emb} -> {len(tok)}")
        base.resize_token_embeddings(len(tok))
    base.config.use_cache = True

    teacher = None
    if args.frozen_teacher:
        teacher = AutoModelForCausalLM.from_pretrained(args.frozen_teacher, torch_dtype=torch.bfloat16)
        teacher.config.use_cache = False
        teacher.eval().requires_grad_(False)

    model = CodiModel(base, latent_start_id=ids["<|latent_start|>"], latent_end_id=ids["<|latent_end|>"],
                      latent_steps=args.latent_steps, a=args.alpha, b=args.beta, g=args.gamma,
                      kd_layers=args.kd_layers, single_anchor=args.single_anchor,
                      ss_prob=args.ss_prob, ss_ramp_frac=args.ss_ramp_frac,
                      teacher=teacher, kd_target=args.kd_target, kd_temp=args.kd_temp)
    ds = build_codi_dataset(tok, sources=args.sources, cache_dir=args.cache_dir,
                            n_samples=args.n_samples, max_seq_len=args.max_seq_len, max_frames=args.max_frames)
    print(f"{len(ds)} codi examples, latent_steps={args.latent_steps}")

    report_to = wandb_init(args, "codi")
    targs = TrainingArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        num_train_epochs=args.epochs,
        max_steps=args.max_steps,
        learning_rate=args.lr,
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        weight_decay=0.1,
        max_grad_norm=1.0,
        bf16=True,
        optim="paged_adamw_8bit",
        ddp_find_unused_parameters=False,
        logging_steps=5,
        save_strategy="steps",
        save_steps=args.save_steps,
        save_total_limit=None,
        report_to=report_to,
        remove_unused_columns=False,
        label_names=[],
    )
    trainer = CodiTrainer(
        model=model, args=targs, train_dataset=ds,
        data_collator=lambda b: {"examples": b},  # keep raw example dicts; CodiModel tensorizes them
    )
    trainer.tok = tok
    # Native checkpoints (CodiModel wrapper + optimizer) auto-resume if interrupted.
    ckpt = get_last_checkpoint(args.output_dir) if os.path.isdir(args.output_dir) else None
    trainer.train(resume_from_checkpoint=ckpt)
    trainer._save_checkpoint(trainer.model, trial=None)  # final step as a resumable, eval-loadable checkpoint-<step>
if __name__ == "__main__":  # run training only when executed as a script, not on import
    main()