# forecast-extractor / train.py — uploaded by philippotiger via huggingface_hub (commit a987f51, verified)
"""
Fine-tuning Qwen2.5-3B-Instruct for football prediction extraction
Fixes from original: target_modules, validation split, scheduler, checkpoint saving
"""
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
import torch
# ─────────────────────────────────────────────
# CONFIG
# ─────────────────────────────────────────────
MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"   # base model fine-tuned below
OUTPUT_DIR = "./football-extractor"        # where the LoRA adapter + tokenizer are saved
TRAIN_FILE = "train_dataset.jsonl"         # JSONL; each row must carry a "messages" chat list
VAL_FILE = "val_dataset.jsonl"             # same schema as TRAIN_FILE, used for epoch-end eval
# ─────────────────────────────────────────────
# LOAD DATA
# ─────────────────────────────────────────────
# Builds a DatasetDict with explicit "train"/"validation" splits from local JSONL files.
dataset = load_dataset("json", data_files={"train": TRAIN_FILE, "validation": VAL_FILE})
print(f"Train: {len(dataset['train'])} | Val: {len(dataset['validation'])}")
# ─────────────────────────────────────────────
# TOKENIZER
# ─────────────────────────────────────────────
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Reuse EOS as the pad token so padded batches can be built
# (presumably this checkpoint ships without a dedicated pad token — verify).
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right" # important for causal LM training
# ─────────────────────────────────────────────
# QUANTIZATION (4-bit QLoRA)
# ─────────────────────────────────────────────
# NF4 4-bit weights with bf16 compute: the standard QLoRA recipe.
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16, # bfloat16 is more stable than float16
bnb_4bit_use_double_quant=True, # saves a bit more VRAM
)
# device_map="auto" lets accelerate place layers on the available device(s).
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
quantization_config=bnb_config,
device_map="auto",
attn_implementation="eager", # avoids flash-attn issues on Colab
)
# KV-cache is incompatible with gradient checkpointing (enabled in TrainingArguments).
model.config.use_cache = False # required for gradient checkpointing
# ─────────────────────────────────────────────
# LORA CONFIG
# ─────────────────────────────────────────────
# Low-rank adapters (r=8, alpha=16 → scaling 2.0) on every attention and
# MLP projection; SFTTrainer wraps the quantized model with these via peft.
lora_config = LoraConfig(
r=8, # smaller r is fine for simple extraction
lora_alpha=16,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
# Explicitly target attention + MLP layers for Qwen2.5
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"
],
)
# ─────────────────────────────────────────────
# FORMAT FUNCTION
# ─────────────────────────────────────────────
def format_example(example):
    """Render training example(s) with the Qwen2.5 chat template.

    Depending on the TRL version, ``formatting_func`` is called either with a
    single example (``example["messages"]`` is one conversation — a list of
    ``{"role": ..., "content": ...}`` dicts) or with a batch (a list of such
    conversations). Handle both shapes: return one string for a single
    conversation, a list of strings for a batch.
    """
    messages = example["messages"]
    # Single conversation: its first element is a role/content dict.
    if messages and isinstance(messages[0], dict):
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )
    # Batched call: render each conversation independently.
    return [
        tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=False,
        )
        for conversation in messages
    ]
# ─────────────────────────────────────────────
# TRAINING ARGS
# ─────────────────────────────────────────────
# NOTE(review): `eval_strategy` is the post-rename spelling (older transformers
# used `evaluation_strategy`) — confirm the installed transformers accepts it.
training_args = TrainingArguments(
output_dir=OUTPUT_DIR,
per_device_train_batch_size=1,
gradient_accumulation_steps=4, # effective batch = 4
gradient_checkpointing=True, # saves VRAM
learning_rate=2e-4,
num_train_epochs=3,
lr_scheduler_type="cosine", # smooth decay
warmup_ratio=0.05, # 5% warmup steps
logging_steps=10,
eval_strategy="epoch", # evaluate after each epoch
save_strategy="epoch", # save checkpoint each epoch
save_total_limit=2, # keep only last 2 checkpoints
load_best_model_at_end=True, # restore checkpoint with best eval_loss after training
metric_for_best_model="eval_loss",
fp16=False,
bf16=True, # use bfloat16 if your GPU supports it
report_to="none", # set to "wandb" if you want tracking
)
# ─────────────────────────────────────────────
# TRAINER
# ─────────────────────────────────────────────
# Pass our configured tokenizer explicitly: we set pad_token above, and
# without this SFTTrainer reloads a fresh tokenizer that lacks that fix.
# NOTE(review): DataCollatorForCompletionOnlyLM is imported but unused — if
# loss should only be computed on assistant completions, wire it in here.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    peft_config=lora_config,
    args=training_args,
    formatting_func=format_example,
    max_seq_length=512,  # extraction tasks are short
)
trainer.train()
# Save only the LoRA adapter (base weights untouched) plus the tokenizer,
# so OUTPUT_DIR is a self-contained adapter checkpoint.
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"✅ Adapter saved to {OUTPUT_DIR}")