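"""QLoRA fine-tuning script: load a base causal LM in 4-bit NF4, attach a LoRA
adapter with PEFT, and train it on a JSONL chat dataset with the HF Trainer.

Each JSONL line is expected to hold a "messages" list of {"role": ..., "content": ...}
turns. Rows are tokenized with the model's chat template when one is defined,
otherwise with a plain role/content fallback. The trained adapter and tokenizer
are written to --out.
"""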
from __future__ import annotations

import argparse
from pathlib import Path
from typing import Dict, List, Union

import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training


def parse_args():
    ap = argparse.ArgumentParser()
    ap.add_argument("--base", type=str, required=True, help="Base model id or path")
    ap.add_argument("--data", type=str, required=True, help="JSONL with chat messages")
    ap.add_argument("--out", type=str, required=True, help="Output dir for adapter")
    ap.add_argument("--epochs", type=int, default=2)
    ap.add_argument("--bsz", type=int, default=8)
    ap.add_argument("--grad_accum", type=int, default=1)
    ap.add_argument("--cutoff_len", type=int, default=2048)
    ap.add_argument("--lr", type=float, default=2e-4)
    ap.add_argument("--lora_r", type=int, default=16)
    ap.add_argument("--lora_alpha", type=int, default=32)
    ap.add_argument("--lora_dropout", type=float, default=0.05)
    ap.add_argument("--debug", action="store_true")
    return ap.parse_args()
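

# Example invocation with the flags above (script name, model id, and paths are illustrative):
#   python train_qlora.py \
#       --base meta-llama/Llama-3.1-8B-Instruct \
#       --data data/chat.jsonl \
#       --out out/adapter \
#       --epochs 2 --bsz 4 --grad_accum 4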


def device_supports_bf16() -> bool:
    """Return True if the first CUDA device supports bfloat16 (compute capability >= 8)."""
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability(0)
    return major >= 8


def build_tokenizer(base_id: str):
    tok = AutoTokenizer.from_pretrained(base_id, use_fast=True)
    if tok.pad_token is None:
        # Models without a dedicated pad token (e.g. Llama-style) reuse EOS for padding.
        tok.pad_token = tok.eos_token
    tok.padding_side = "right"
    return tok


def _to_ids(x: Union[torch.Tensor, List[int], Dict[str, List[int]]]) -> List[int]:
    """Normalize the possible return types of apply_chat_template to a flat list of token ids."""
    if isinstance(x, torch.Tensor):
        # return_tensors="pt" yields a (1, seq_len) tensor; drop the batch dimension.
        return x.detach().cpu().tolist()[0] if x.ndim == 2 else x.detach().cpu().tolist()
    if isinstance(x, dict) and "input_ids" in x:
        return x["input_ids"]
    if isinstance(x, (list, tuple)):
        return list(x)
    raise TypeError(f"Unsupported chat template return type: {type(x)}")


def chat_to_ids(tokenizer: AutoTokenizer, messages: List[Dict], max_len: int):
    # Preferred path: the model ships a chat template, so let the tokenizer format the turns.
    if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
        out = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=False,
            return_tensors="pt",
            max_length=max_len,
            truncation=True,
        )
        ids = _to_ids(out)
        attn = [1] * len(ids)
        return {"input_ids": ids, "attention_mask": attn}

    # Fallback for tokenizers without a chat template: plain "role:\ncontent" blocks.
    lines = []
    for m in messages:
        role = m.get("role", "user")
        content = m.get("content", "")
        lines.append(f"{role}:\n{content}\n")
    text = "\n".join(lines)
    enc = tokenizer(text, max_length=max_len, truncation=True)
    return {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}


def collate_pad(tokenizer: AutoTokenizer):
    """Build a collate_fn that right-pads a batch and masks padding out of the loss."""
    pad_id = tokenizer.pad_token_id

    def _fn(batch: List[Dict[str, List[int]]]):
        max_len = max(len(x["input_ids"]) for x in batch)
        input_ids, attn, labels = [], [], []
        for x in batch:
            ids = x["input_ids"]
            am = x["attention_mask"]
            pad_n = max_len - len(ids)
            input_ids.append(ids + [pad_id] * pad_n)
            attn.append(am + [0] * pad_n)
            # Labels mirror the inputs; -100 on padded positions is ignored by the loss.
            labels.append(ids + [-100] * pad_n)
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attn, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }

    return _fn


def guess_lora_targets(model: torch.nn.Module) -> Union[List[str], str]:
    """Collect the projection-module names present in this architecture for LoRA targeting."""
    prefs = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "wi",
        "wo",
        "w1",
        "w2",
        "w3",
        "out_proj",
    ]
    found = set()
    for name, _ in model.named_modules():
        for p in prefs:
            if p in name:
                found.add(p)
    # If no known projection name matches, fall back to PEFT's "all-linear"
    # shorthand, which targets every linear layer except the LM head.
    return sorted(found) if found else "all-linear"


def main():
    args = parse_args()
    base_id = args.base
    data_path = Path(args.data)
    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)

    tokenizer = build_tokenizer(base_id)

    # Each JSONL row must carry a "messages" list of chat turns.
    ds = load_dataset("json", data_files=str(data_path), split="train")

    def map_row(ex):
        return chat_to_ids(tokenizer, ex["messages"], args.cutoff_len)

    ds = ds.map(map_row, remove_columns=ds.column_names)

    collate = collate_pad(tokenizer)

    # Pick the training dtype first so the 4-bit compute dtype can match it;
    # leaving bnb_4bit_compute_dtype at its float32 default is noticeably slower.
    use_bf16 = device_supports_bf16()
    torch_dtype = torch.bfloat16 if use_bf16 else torch.float16
    torch.backends.cuda.matmul.allow_tf32 = True

    quant = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch_dtype,
    )

    model = AutoModelForCausalLM.from_pretrained(
        base_id,
        device_map="auto",
        quantization_config=quant,
        torch_dtype=torch_dtype,
    )

    # Casts norms/embeddings for stability and enables input grads for k-bit training.
    model = prepare_model_for_kbit_training(model)
    lconf = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=guess_lora_targets(model),
    )
    model = get_peft_model(model, lconf)

    train_args = TrainingArguments(
        output_dir=str(out_dir),
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.bsz,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        logging_steps=5,
        save_steps=100,
        bf16=use_bf16,
        fp16=not use_bf16,
        optim="paged_adamw_8bit",
        remove_unused_columns=False,  # keep input_ids/attention_mask for the custom collator
        dataloader_num_workers=2,
        report_to=[],
    )

    tr = Trainer(
        model=model,
        args=train_args,
        train_dataset=ds,
        data_collator=collate,
        tokenizer=tokenizer,
    )

    if args.debug:
        # Inspect one collated batch before committing to a full training run.
        batch = next(iter(tr.get_train_dataloader()))
        print("[debug] batch keys:", list(batch.keys()))
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                print(f"[debug] {k}: shape={tuple(v.shape)} dtype={v.dtype}")

    tr.train()

    # save_pretrained on a PEFT model writes only the adapter weights, not the base model.
    model.save_pretrained(str(out_dir))
    tokenizer.save_pretrained(str(out_dir))
    print("[ok] saved adapter to", out_dir.resolve())


if __name__ == "__main__":
    main()
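
# To reuse the adapter later, it can be loaded back onto the (4-bit) base model.
# Sketch, assuming the same --base id and --out directory used for training:
#   from peft import PeftModel
#   base = AutoModelForCausalLM.from_pretrained(base_id, quantization_config=quant, device_map="auto")
#   model = PeftModel.from_pretrained(base, out_dir)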