#!/usr/bin/env python3
"""SFT warm-start for the office-document task agent.

Trains a small student model (default: Qwen/Qwen2.5-Coder-3B-Instruct) on
teacher trajectories collected from a stronger model (Kimi-K2.5) and filtered
by `data_pipeline/build_sft_corpus.py`.

Output is a LoRA adapter saved to `--output-dir`, optionally pushed to HF
Hub for use as the base of GRPO continued training.

Designed for HF Jobs (1× A100 80GB, ~$2.50/hr, ~6 hours = ~$15) but runs
locally too.

Hardware sizing:
  - 3B base + LoRA r=32 + bf16 + 8K context: ~24 GB VRAM
  - Fits comfortably on A100 40GB / L40S 48GB / A100 80GB
  - For OOM, drop --max-seq-len to 4096 or --lora-r to 16

Example:

    pip install -U "trl>=0.11" "peft>=0.13" "transformers>=4.46" \
        "datasets>=3.0" "accelerate>=1.0" "bitsandbytes>=0.43"

    python train_sft.py \
        --dataset data/sft_kimi_k25.jsonl \
        --base-model Qwen/Qwen2.5-Coder-3B-Instruct \
        --output-dir checkpoints/qwen3b-sft-kimi \
        --epochs 2 --lora-r 32

HF Jobs:

    hf jobs run \
        --flavor a100-large \
        --timeout 8h \
        --secrets HF_TOKEN \
        huggingface/transformers-pytorch-gpu:latest \
        bash -c "pip install -U trl peft accelerate bitsandbytes && \
            python train_sft.py --dataset data/sft_kimi_k25.jsonl \
                --output-dir /tmp/qwen3b-sft \
                --push-to-hub bpHigh/qwen3b-office-sft"
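
Loading the adapter afterwards (a minimal sketch; the path is whatever you
passed as --output-dir or --push-to-hub):

    from peft import AutoPeftModelForCausalLM
    from transformers import AutoTokenizer

    model = AutoPeftModelForCausalLM.from_pretrained("checkpoints/qwen3b-sft-kimi")
    tokenizer = AutoTokenizer.from_pretrained("checkpoints/qwen3b-sft-kimi")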
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| import sys | |
| from pathlib import Path | |


def parse_args(argv=None):
    p = argparse.ArgumentParser()
    p.add_argument("--dataset", required=True,
                   help="path to the SFT corpus JSONL (built by "
                        "data_pipeline/build_sft_corpus.py)")
    p.add_argument("--base-model", default="Qwen/Qwen2.5-Coder-3B-Instruct")
    p.add_argument("--output-dir", default="checkpoints/qwen3b-sft")
    p.add_argument("--epochs", type=float, default=2.0)
    p.add_argument("--lr", type=float, default=2e-4,
                   help="learning rate (LoRA defaults are higher than full FT)")
    p.add_argument("--lora-r", type=int, default=32)
    p.add_argument("--lora-alpha", type=int, default=64)
    p.add_argument("--lora-dropout", type=float, default=0.05)
    p.add_argument("--target-modules", default="all-linear",
                   help="LoRA target modules; 'all-linear' is the safe default")
    p.add_argument("--per-device-batch-size", type=int, default=1)
    p.add_argument("--gradient-accumulation", type=int, default=8,
                   help="effective batch = per_device_bsz × grad_accum × n_gpus")
    p.add_argument("--max-seq-len", type=int, default=8192,
                   help="drop to 4096 if OOM on smaller GPUs")
    p.add_argument("--logging-steps", type=int, default=2)
    p.add_argument("--save-steps", type=int, default=50)
    p.add_argument("--warmup-ratio", type=float, default=0.05)
    p.add_argument("--use-qlora", action="store_true",
                   help="4-bit quantization (slower, much less memory)")
    p.add_argument("--no-assistant-only-loss", action="store_true",
                   help="disable assistant-only loss masking; train on full "
                        "conversation tokens (legacy behavior)")
    p.add_argument("--push-to-hub", default="",
                   help="HF Hub repo to push the LoRA adapter to "
                        "(e.g., 'username/repo-name'). Optional.")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--report-to", default="none",
                   help="'none', 'wandb', 'tensorboard', or comma-separated")
    return p.parse_args(argv)


def main() -> int:
    args = parse_args()

    # Heavy imports inside main so --help is fast and import failures get
    # reported with context.
    import torch
    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import LoraConfig
    from trl import SFTConfig, SFTTrainer

    # ---- 1. Dataset ----
    ds_path = Path(args.dataset)
    if not ds_path.exists():
        print(f"ERROR: dataset {ds_path} not found", file=sys.stderr)
        print("Run data_pipeline/build_sft_corpus.py first.", file=sys.stderr)
        return 1

    print(f"Loading SFT corpus from {ds_path}")
    raw = load_dataset("json", data_files=str(ds_path), split="train")
    print(f"  rows: {len(raw)}")
    print(f"  cols: {raw.column_names}")
| if "messages" not in raw.column_names: | |
| print(f"ERROR: dataset is missing 'messages' column", file=sys.stderr) | |
| return 1 | |

    # SFTTrainer only needs the "messages" column; extra columns such as
    # score/n_steps are tolerated but cleaner to drop (their names were
    # already echoed above for inspection).
    keep = [c for c in raw.column_names if c == "messages"]
    drop = [c for c in raw.column_names if c not in keep]
    train_ds = raw.remove_columns(drop) if drop else raw

    # ---- 2. Tokenizer ----
    print(f"\nLoading tokenizer: {args.base_model}")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
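    # Some base checkpoints ship without a dedicated pad token; reuse EOS so
    # the data collator can pad batches.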
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # ---- 3. Precision detection ----
    # bf16 is strictly better than fp16 when available (Ampere+ CUDA, M-series
    # MPS). fp16 requires a grad scaler that needs PyTorch >= 2.8 on MPS, so
    # we drop fp16 entirely and fall back to fp32 on old hardware.
    bf16_ok = (
        (torch.cuda.is_available() and torch.cuda.is_bf16_supported())
        or (hasattr(torch.backends, "mps") and torch.backends.mps.is_available())
    )
    compute_dtype = torch.bfloat16 if bf16_ok else torch.float32

    # ---- 4. Model ----
    print(f"Loading base model: {args.base_model}")
    print(f"  precision: {'bf16' if bf16_ok else 'fp32'} "
          f"(cuda={torch.cuda.is_available()}, "
          f"mps={hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()})")
    model_kwargs = dict(
        torch_dtype=compute_dtype,
        attn_implementation="sdpa",
    )
    if args.use_qlora:
        from transformers import BitsAndBytesConfig
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
        print("  using 4-bit QLoRA")
    model = AutoModelForCausalLM.from_pretrained(args.base_model, **model_kwargs)
    if hasattr(model, "config"):
        model.config.use_cache = False  # required for grad checkpointing

    # ---- 5. LoRA ----
    target = args.target_modules
    if target != "all-linear" and "," in target:
        target = [t.strip() for t in target.split(",")]
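    # "all-linear" asks PEFT to attach LoRA to every linear projection
    # (attention and MLP blocks); with the default alpha=64 and r=32 the
    # adapter update is scaled by alpha/r = 2.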
    peft_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=target,
        task_type="CAUSAL_LM",
        bias="none",
    )

    # ---- 6. Trainer config ----
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    sft_config = SFTConfig(
        output_dir=str(out_dir),
        num_train_epochs=args.epochs,
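        # With the defaults (batch 1 × grad-accum 8, single GPU) this gives an
        # effective batch of 8 sequences per optimizer step.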
        per_device_train_batch_size=args.per_device_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation,
        learning_rate=args.lr,
        warmup_ratio=args.warmup_ratio,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_strategy="steps",
        save_total_limit=2,
        bf16=bf16_ok,
        fp16=False,  # fp16 needs a grad scaler that doesn't play nice with MPS
        gradient_checkpointing=True,
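        # Non-reentrant checkpointing plays nicer with a frozen (LoRA) base
        # model; the reentrant variant can fail when no checkpointed input
        # requires grad.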
        gradient_checkpointing_kwargs={"use_reentrant": False},
        max_length=args.max_seq_len,
        # Assistant-only loss: only compute loss on assistant tokens, mask
        # everything else. This is the right behavior for multi-turn agent
        # SFT: we don't want to train on tool feedback (which the env
        # generates, not the model).
        assistant_only_loss=not args.no_assistant_only_loss,
        # Don't pack; multi-turn examples are long enough on their own
        packing=False,
        report_to=args.report_to.split(",") if args.report_to != "none" else "none",
        seed=args.seed,
        push_to_hub=bool(args.push_to_hub),
        hub_model_id=args.push_to_hub or None,
        hub_strategy="end",
        hub_private_repo=False,
        dataset_kwargs={"skip_prepare_dataset": False},
    )

    # ---- 7. Train ----
    print("\nStarting SFTTrainer...")
    trainer = SFTTrainer(
        model=model,
        args=sft_config,
        train_dataset=train_ds,
        processing_class=tokenizer,
        peft_config=peft_config,
    )
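    # Checkpoints are written to output_dir/checkpoint-<step>/ every
    # --save-steps steps; save_total_limit keeps only the two most recent.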
    trainer.train()

    # ---- 8. Save ----
    print(f"\nSaving final LoRA adapter to {out_dir}")
    trainer.save_model(str(out_dir))
    tokenizer.save_pretrained(str(out_dir))

    # Save the run args for reproducibility
    with open(out_dir / "train_args.json", "w") as f:
        json.dump(vars(args), f, indent=2)

    if args.push_to_hub:
        print(f"Pushing to HF Hub: {args.push_to_hub}")
        trainer.push_to_hub()

    print("\nDone.")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())