dream-s1k-demo / train_lora.py
eval: greedy decode + numeric strict; system: force full decimals; regressions: A/B/C/noisy
e45d7fc
import os, json, random, math
from dataclasses import dataclass
from typing import Dict, List, Optional
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
import torch
from peft import LoraConfig, get_peft_model
# --------------------
# Config via env
# --------------------
BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
DATA_PATH = os.environ.get("DATA_PATH", "s1k_chat_1.1_small.jsonl")
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "./runs/qwen25-0p5b-lora")
SEED = int(os.environ.get("SEED", "42"))
EPOCHS = float(os.environ.get("EPOCHS", "3"))
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "4"))
GRAD_ACCUM = int(os.environ.get("GRAD_ACCUM", "8"))
LR = float(os.environ.get("LR", "2e-4"))
MAX_LEN = int(os.environ.get("MAX_LEN", "1024"))
WARMUP_RATIO = float(os.environ.get("WARMUP_RATIO", "0.05"))
LORA_R = int(os.environ.get("LORA_R", "16"))
LORA_ALPHA = int(os.environ.get("LORA_ALPHA", "32"))
LORA_DROPOUT = float(os.environ.get("LORA_DROPOUT", "0.05"))
SAVE_STEPS = int(os.environ.get("SAVE_STEPS", "0")) # 0 -> save per epoch instead of per step
VAL_RATIO = float(os.environ.get("VAL_RATIO", "0.1"))
random.seed(SEED)
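# Example invocation (hypothetical values; any of the env vars above can be overridden):
#   EPOCHS=1 BATCH_SIZE=2 LR=1e-4 python train_lora.py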
# --------------------
# Load & split dataset
# --------------------
def load_jsonl_messages(path: str) -> List[Dict]:
rows = []
with open(path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
obj = json.loads(line)
rows.append(obj)
return rows
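# Expected row shape, inferred from the flattening below (an assumption, not a schema check):
#   {"messages": [{"role": "user", "content": "..."},
#                 {"role": "assistant", "content": "..."}]}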
raw = load_jsonl_messages(DATA_PATH)
# Basic shuffle & split
random.shuffle(raw)
val_n = max(1, int(len(raw) * VAL_RATIO))
val_list = raw[:val_n]
train_list = raw[val_n:]
def messages_to_pairs(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
"""
将多轮 messages 压成若干 (prompt -> response) 对:
连续 user 合并成一个 prompt,遇到 assistant 产出一对。
"""
pairs = []
last_user = []
for m in messages:
role = m.get("role", "")
content = m.get("content", "")
if role == "user":
last_user.append(content)
elif role == "assistant" and last_user:
prompt = "\n\n".join(last_user)
pairs.append({"prompt": prompt, "response": content})
last_user = []
return pairs
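# For example (hypothetical input; two consecutive user turns merge into one prompt):
#   messages_to_pairs([{"role": "user", "content": "Hi"},
#                      {"role": "user", "content": "2+2?"},
#                      {"role": "assistant", "content": "4"}])
#   -> [{"prompt": "Hi\n\n2+2?", "response": "4"}]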
def flatten_jsonl_to_pairs(jsonl_rows: List[Dict]) -> List[Dict]:
pairs_all = []
for r in jsonl_rows:
msgs = r.get("messages", [])
pairs = messages_to_pairs(msgs)
pairs_all.extend(pairs)
return pairs_all
train_pairs = flatten_jsonl_to_pairs(train_list)
val_pairs = flatten_jsonl_to_pairs(val_list)
train_ds = Dataset.from_list(train_pairs)
val_ds = Dataset.from_list(val_pairs)
ds = DatasetDict({"train": train_ds, "validation": val_ds})
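# Log split sizes for a quick sanity check before tokenization.
print(f"[DATA] rows={len(raw)} train_pairs={len(train_pairs)} val_pairs={len(val_pairs)}")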
# --------------------
# Tokenizer & chat template
# --------------------
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True, trust_remote_code=True)
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
# === chat_template tokenization ===
def _sft_tokenize_with_chat_template(example):
    # Rebuild (prompt, response) into messages and render them through the chat template.
    ctx_msgs = [{"role": "user", "content": example["prompt"]}]
    tgt = example["response"]
    # Context only, ending with the generation prompt.
    prompt_text = tokenizer.apply_chat_template(
        ctx_msgs, tokenize=False, add_generation_prompt=True
    )
    # Answer only, plus eos.
    target_text = tgt + (tokenizer.eos_token or "")
    prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
    target_ids = tokenizer(target_text, add_special_tokens=False)["input_ids"]
    # Truncation: prefer keeping the answer; trim the prompt from the left.
    max_p = MAX_LEN - len(target_ids)
    if max_p <= 0:
        target_ids = target_ids[-(MAX_LEN - 1):]
        prompt_ids = []
    else:
        prompt_ids = prompt_ids[-max_p:]
    ids = prompt_ids + target_ids
    labels = [-100] * len(prompt_ids) + target_ids[:]  # loss only on the answer span
    attn = [1] * len(ids)
    return {"input_ids": ids, "labels": labels, "attention_mask": attn}
IGNORE_INDEX = -100
# Alternative masking path (not used by the map below; kept for reference).
# build_chat_prompt was referenced but never defined; this is a minimal
# reconstruction inferred from how tokenize() uses it.
def build_chat_prompt(prompt: str, response: Optional[str]) -> str:
    msgs = [{"role": "user", "content": prompt}]
    if response is None:
        # Context only, ending with the generation prompt so the rendered
        # prefix lines up with the full text for label masking.
        return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    msgs.append({"role": "assistant", "content": response})
    return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)

def tokenize(example: Dict) -> Dict:
    # Compute loss only on the assistant span.
    prompt_text = build_chat_prompt(example["prompt"], None)
    full_text = build_chat_prompt(example["prompt"], example["response"])
    prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
    full = tokenizer(
        full_text,
        max_length=MAX_LEN,
        truncation=True,
        padding=False,
        add_special_tokens=False,
    )["input_ids"]
    labels = [IGNORE_INDEX] * len(full)
    start = len(prompt_ids)
    for i in range(start, len(full)):
        labels[i] = full[i]
    return {
        "input_ids": full,
        "labels": labels,
        "attention_mask": [1] * len(full),
    }
tokenized = ds.map(_sft_tokenize_with_chat_template, remove_columns=ds["train"].column_names, desc="Tokenizing with chat_template")
# --------------------
# Model & LoRA (Mac/MPS-friendly: disable mixed precision, use fp32)
# --------------------
use_mps = torch.backends.mps.is_available()
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
compute_dtype = torch.bfloat16
elif torch.cuda.is_available():
compute_dtype = torch.float16
else:
    compute_dtype = torch.float32  # full precision on MPS/CPU
device_map = "auto" if torch.cuda.is_available() else None
model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
torch_dtype=compute_dtype,
device_map=device_map,
trust_remote_code=True,
)
if use_mps:
model.to("mps")
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
lora_cfg = LoraConfig(
r=LORA_R,
lora_alpha=LORA_ALPHA,
lora_dropout=LORA_DROPOUT,
bias="none",
task_type="CAUSAL_LM",
target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
)
model = get_peft_model(model, lora_cfg)
def print_trainable(model):
trainable = 0
total = 0
for n,p in model.named_parameters():
c = p.numel()
total += c
if p.requires_grad:
trainable += c
print(f"[PARAMS] total={total} trainable={trainable} ratio={trainable/max(total,1):.6f}")
print_trainable(model)
# --------------------
# Collator (pads labels with IGNORE_INDEX so they stay aligned with input_ids)
# --------------------
@dataclass
class DataCollatorForCausalLM:
tokenizer: AutoTokenizer
pad_to_multiple_of: Optional[int] = 8
def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
max_len = max(len(f["input_ids"]) for f in features)
if self.pad_to_multiple_of:
max_len = int(math.ceil(max_len / self.pad_to_multiple_of) * self.pad_to_multiple_of)
input_ids, labels, attention_mask = [], [], []
for f in features:
ids = f["input_ids"]
labs = f["labels"]
mask = f["attention_mask"]
pad_len = max_len - len(ids)
            input_ids.append(ids + [self.tokenizer.pad_token_id] * pad_len)
attention_mask.append(mask + [0] * pad_len)
labels.append(labs + [IGNORE_INDEX] * pad_len)
return {
"input_ids": torch.tensor(input_ids, dtype=torch.long),
"labels": torch.tensor(labels, dtype=torch.long),
"attention_mask": torch.tensor(attention_mask, dtype=torch.long)
}
collator = DataCollatorForCausalLM(tokenizer)
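# Shape check for the collator (a minimal sketch assuming at least two train rows;
# batch length pads up to a multiple of 8):
#   batch = collator([tokenized["train"][0], tokenized["train"][1]])
#   print(batch["input_ids"].shape, batch["labels"].shape)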
# --------------------
# Training (force-disable bf16/fp16 on Mac)
# --------------------
steps_per_epoch = max(1, len(tokenized["train"]) // (BATCH_SIZE * GRAD_ACCUM))
save_strategy = "steps" if SAVE_STEPS > 0 else "epoch"
training_args = TrainingArguments(
output_dir=OUTPUT_DIR,
num_train_epochs=EPOCHS,
per_device_train_batch_size=BATCH_SIZE,
per_device_eval_batch_size=BATCH_SIZE,
gradient_accumulation_steps=GRAD_ACCUM,
learning_rate=LR,
warmup_ratio=WARMUP_RATIO,
logging_steps=max(1, steps_per_epoch // 5),
evaluation_strategy="epoch",
save_strategy=save_strategy,
    save_steps=SAVE_STEPS if SAVE_STEPS > 0 else 500,  # ignored when saving by epoch
save_total_limit=2,
bf16=False,
fp16=False,
weight_decay=0.0,
lr_scheduler_type="cosine",
seed=SEED,
max_grad_norm=1.0,
remove_unused_columns=False,
report_to=["none"],
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized["train"],
eval_dataset=tokenized["validation"],
data_collator=collator,
tokenizer=tokenizer,
)
trainer.train()
metrics = trainer.evaluate()
trainer.save_model()
tokenizer.save_pretrained(OUTPUT_DIR)
with open(os.path.join(OUTPUT_DIR, "eval_metrics.json"), "w", encoding="utf-8") as f:
json.dump(metrics, f, indent=2, ensure_ascii=False)
print("==> Training done. Eval metrics:", metrics)
# --------------------
# Unused reference helper: full-sequence SFT without prompt masking
# --------------------
def _build_sft_examples(examples, tokenizer, max_len=1024):
    # Expects each row as {"messages":[{"role":"user"/"system"/"assistant","content":...}, ...]}
    texts = []
    for msgs in examples["messages"]:
        # Use the last assistant message as the supervision target; the rest is context.
        if not isinstance(msgs, list) or not msgs:
            continue
        # Context: every non-assistant turn (user/system).
        ctx = [m for m in msgs if m.get("role") != "assistant"]
        # Target: the last assistant message (skip the row if there is none).
        tgt = None
        for m in reversed(msgs):
            if m.get("role") == "assistant":
                tgt = m["content"]
                break
        if tgt is None:
            continue
        # Render context + target through the chat template.
        prompt = tokenizer.apply_chat_template(
            ctx + [{"role": "assistant", "content": tgt}],
            tokenize=False, add_generation_prompt=False,
        )
        texts.append(prompt)
    tokenized = tokenizer(texts, truncation=True, max_length=max_len)
    return tokenized