# financial-task-env / train_sft.py
# bpHigh — train_sft: drop fp16, prefer bf16 (MPS-compatible without grad scaler)
# commit 35fa944
#!/usr/bin/env python3
"""SFT warm-start for the office-document task agent.
Trains a small student model (default: Qwen/Qwen2.5-Coder-3B-Instruct) on
teacher trajectories collected from a stronger model (Kimi-K2.5) and filtered
by `data_pipeline/build_sft_corpus.py`.
Output is a LoRA adapter saved to `--output-dir`, optionally pushed to HF
Hub for use as the base of GRPO continued training.
Designed for HF Jobs (1× A100 80GB, ~$2.50/hr, ~6 hours = ~$15) but runs
locally too.
Hardware sizing:
- 3B base + LoRA r=32 + bf16 + 8K context: ~24 GB VRAM
- Fits comfortably on A100 40GB / L40S 48GB / A100 80GB
- For OOM, drop --max-seq-len to 4096 or --lora-r to 16
Example:
pip install -U "trl>=0.11" "peft>=0.13" "transformers>=4.46" \
"datasets>=3.0" "accelerate>=1.0" "bitsandbytes>=0.43"
python train_sft.py \
--dataset data/sft_kimi_k25.jsonl \
--base-model Qwen/Qwen2.5-Coder-3B-Instruct \
--output-dir checkpoints/qwen3b-sft-kimi \
--epochs 2 --lora-r 32
HF Jobs:
hf jobs run \
--hardware "Nvidia A100 - large" \
--timeout 8h \
--image "huggingface/transformers-pytorch-gpu:latest" \
--secrets HF_TOKEN \
-- \
bash -c "pip install -U trl peft accelerate bitsandbytes && \
python train_sft.py --dataset data/sft_kimi_k25.jsonl \
--output-dir /tmp/qwen3b-sft \
--push-to-hub bpHigh/qwen3b-office-sft"
"""
from __future__ import annotations
import argparse
import json
import os
import sys
from pathlib import Path
def parse_args(argv=None):
    """Build and parse the command-line interface for the SFT run.

    Args:
        argv: optional list of argument strings (used in tests); ``None``
            means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding every training hyperparameter.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Data / model selection.
    add("--dataset", required=True,
        help="path to the SFT corpus JSONL (built by "
             "data_pipeline/build_sft_corpus.py)")
    add("--base-model", default="Qwen/Qwen2.5-Coder-3B-Instruct")
    add("--output-dir", default="checkpoints/qwen3b-sft")

    # Optimization schedule.
    add("--epochs", type=float, default=2.0)
    add("--lr", type=float, default=2e-4,
        help="learning rate (LoRA defaults are higher than full FT)")
    add("--warmup-ratio", type=float, default=0.05)

    # LoRA adapter shape.
    add("--lora-r", type=int, default=32)
    add("--lora-alpha", type=int, default=64)
    add("--lora-dropout", type=float, default=0.05)
    add("--target-modules", default="all-linear",
        help="LoRA target modules; 'all-linear' is the safe default")

    # Batching / sequence length.
    add("--per-device-batch-size", type=int, default=1)
    add("--gradient-accumulation", type=int, default=8,
        help="effective batch = per_device_bsz × grad_accum × n_gpus")
    add("--max-seq-len", type=int, default=8192,
        help="drop to 4096 if OOM on smaller GPUs")

    # Logging / checkpointing cadence.
    add("--logging-steps", type=int, default=2)
    add("--save-steps", type=int, default=50)

    # Memory and loss-masking switches.
    add("--use-qlora", action="store_true",
        help="4-bit quantization (slower, much less memory)")
    add("--no-assistant-only-loss", action="store_true",
        help="disable assistant-only loss masking; train on full "
             "conversation tokens (legacy behavior)")

    # Publishing / reproducibility.
    add("--push-to-hub", default="",
        help="HF Hub repo to push the LoRA adapter to "
             "(e.g., 'username/repo-name'). Optional.")
    add("--seed", type=int, default=42)
    add("--report-to", default="none",
        help="'none', 'wandb', 'tensorboard', or comma-separated")

    return parser.parse_args(argv)
def main() -> int:
    """Run LoRA SFT on the teacher-trajectory corpus.

    Loads the JSONL corpus, the base model and tokenizer, wires up a LoRA
    (optionally 4-bit QLoRA) config, and trains with TRL's SFTTrainer.
    The final adapter, tokenizer, and run arguments are written to
    ``--output-dir`` and optionally pushed to the HF Hub.

    Returns:
        0 on success; 1 when the dataset file is missing or lacks the
        'messages' column.
    """
    args = parse_args()
    # Heavy imports inside main so --help is fast and import failures get
    # reported with context.
    import torch
    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import LoraConfig
    from trl import SFTConfig, SFTTrainer

    # ---- 1. Dataset ----
    ds_path = Path(args.dataset)
    if not ds_path.exists():
        print(f"ERROR: dataset {ds_path} not found", file=sys.stderr)
        print("Run data_pipeline/build_sft_corpus.py first.", file=sys.stderr)
        return 1
    print(f"Loading SFT corpus from {ds_path}")
    raw = load_dataset("json", data_files=str(ds_path), split="train")
    print(f" rows: {len(raw)}")
    print(f" cols: {raw.column_names}")
    if "messages" not in raw.column_names:
        # Fix: plain string (was an f-string with no placeholders).
        print("ERROR: dataset is missing 'messages' column", file=sys.stderr)
        return 1
    # SFTTrainer wants ONLY the messages column. Extra columns such as
    # score/n_steps are dropped here so the collator sees a clean
    # conversational dataset — inspect them in the source JSONL instead.
    drop = [c for c in raw.column_names if c != "messages"]
    train_ds = raw.remove_columns(drop) if drop else raw

    # ---- 2. Tokenizer ----
    print(f"\nLoading tokenizer: {args.base_model}")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
    if tokenizer.pad_token is None:
        # Some chat tokenizers ship without a pad token; reuse EOS so
        # batch padding works.
        tokenizer.pad_token = tokenizer.eos_token

    # ---- 3. Precision detection ----
    # bf16 is strictly better than fp16 when available (Ampere+ CUDA, M-series
    # MPS). fp16 requires a grad scaler that needs PyTorch >= 2.8 on MPS, so
    # we drop fp16 entirely — fall back to fp32 on old hardware.
    bf16_ok = (
        (torch.cuda.is_available() and torch.cuda.is_bf16_supported())
        or (hasattr(torch.backends, "mps") and torch.backends.mps.is_available())
    )
    compute_dtype = torch.bfloat16 if bf16_ok else torch.float32

    # ---- 4. Model ----
    print(f"Loading base model: {args.base_model}")
    print(f" precision: {'bf16' if bf16_ok else 'fp32'} "
          f"(cuda={torch.cuda.is_available()}, "
          f"mps={hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()})")
    model_kwargs = dict(
        torch_dtype=compute_dtype,
        attn_implementation="sdpa",
    )
    if args.use_qlora:
        from transformers import BitsAndBytesConfig
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            # Fix: follow the detected precision instead of hardcoding bf16,
            # so fp32-only hardware gets a compute dtype it actually supports.
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=True,
        )
        print(" using 4-bit QLoRA")
    model = AutoModelForCausalLM.from_pretrained(args.base_model, **model_kwargs)
    if hasattr(model, "config"):
        model.config.use_cache = False  # required for grad checkpointing

    # ---- 5. LoRA ----
    # 'all-linear' is passed through verbatim to PEFT; a comma-separated
    # value is split into an explicit module-name list.
    target = args.target_modules
    if target != "all-linear" and "," in target:
        target = [t.strip() for t in target.split(",")]
    peft_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=target,
        task_type="CAUSAL_LM",
        bias="none",
    )

    # ---- 6. Trainer config ----
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    sft_config = SFTConfig(
        output_dir=str(out_dir),
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.per_device_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation,
        learning_rate=args.lr,
        warmup_ratio=args.warmup_ratio,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_strategy="steps",
        save_total_limit=2,
        bf16=bf16_ok,
        fp16=False,  # fp16 needs a grad scaler that doesn't play nice with MPS
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        max_length=args.max_seq_len,
        # Assistant-only loss: only compute loss on assistant tokens, mask
        # everything else. This is the right behavior for multi-turn agent
        # SFT — we don't want to train on tool-feedback (which the env
        # generates, not the model).
        assistant_only_loss=not args.no_assistant_only_loss,
        # Don't pack — multi-turn examples are long enough on their own
        packing=False,
        report_to=args.report_to.split(",") if args.report_to != "none" else "none",
        seed=args.seed,
        push_to_hub=bool(args.push_to_hub),
        hub_model_id=args.push_to_hub or None,
        hub_strategy="end",
        hub_private_repo=False,
        dataset_kwargs={"skip_prepare_dataset": False},
    )

    # ---- 7. Train ----
    print("\nStarting SFTTrainer...")
    trainer = SFTTrainer(
        model=model,
        args=sft_config,
        train_dataset=train_ds,
        processing_class=tokenizer,
        peft_config=peft_config,
    )
    trainer.train()

    # ---- 8. Save ----
    print(f"\nSaving final LoRA adapter to {out_dir}")
    trainer.save_model(str(out_dir))
    tokenizer.save_pretrained(str(out_dir))
    # Save the run args for reproducibility.
    with open(out_dir / "train_args.json", "w") as f:
        json.dump(vars(args), f, indent=2)
    if args.push_to_hub:
        print(f"Pushing to HF Hub: {args.push_to_hub}")
        trainer.push_to_hub()
    print("\nDone.")
    return 0
if __name__ == "__main__":
    # Propagate main()'s integer status code to the shell.
    sys.exit(main())