text2sql-demo / src /train_sft_bart.py
tjhalanigrid's picture
Add src folder
dc59b01
from __future__ import annotations
import os
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import (
AutoModelForSeq2SeqLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
)
from prompting import clean_gold_sql, get_schema_text, build_prompt
# =====================================================
# SETTINGS
# =====================================================
BASE_MODEL = os.environ.get("BASE_MODEL", "facebook/bart-base")
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
OUT_DIR = os.path.join(PROJECT_ROOT, "checkpoints", "sft_best_bart_2")
TRAIN_SPLIT = "train[:7000]"
EPOCHS = 12
LR = 3e-4
PER_DEVICE_BATCH = 16
GRAD_ACCUM = 4
MAX_INPUT = 512
MAX_OUTPUT = 128
# =====================================================
# DEVICE
# =====================================================
os.environ["TOKENIZERS_PARALLELISM"] = "false"
device = torch.device("mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu"))
print("Using device:", device)
# =====================================================
# TOKENIZER
# =====================================================
print("Loading tokenizer/model:", BASE_MODEL)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
# =====================================================
# PREPROCESS FUNCTION
# =====================================================
def preprocess_function(example):
question = example["question"]
db_id = example["db_id"]
gold_sql = clean_gold_sql(example["query"])
# ---- Build Prompt ----
schema_text = get_schema_text(db_id)
prompt = build_prompt(question, db_id, schema_text=schema_text, training_sql=None)
model_inputs = tokenizer(
prompt,
max_length=MAX_INPUT,
truncation=True,
padding="max_length",
)
# ---- Target SQL ----
labels = tokenizer(
gold_sql,
max_length=MAX_OUTPUT,
truncation=True,
padding="max_length",
)["input_ids"]
# IMPORTANT: ignore padding in loss
labels = [
(tok if tok != tokenizer.pad_token_id else -400)
for tok in labels
]
model_inputs["labels"] = labels
return model_inputs
# =====================================================
# DATASET
# =====================================================
print("Loading Spider subset:", TRAIN_SPLIT)
dataset = load_dataset("spider", split=TRAIN_SPLIT)
dataset = dataset.train_test_split(test_size=0.1, seed=42)
train_ds = dataset["train"]
eval_ds = dataset["test"]
print("Tokenizing dataset (single process, stable)...")
train_tok = train_ds.map(
preprocess_function,
batched=False,
num_proc=1,
remove_columns=train_ds.column_names,
load_from_cache_file=False,
)
eval_tok = eval_ds.map(
preprocess_function,
batched=False,
num_proc=1,
remove_columns=eval_ds.column_names,
load_from_cache_file=False,
)
print("Train dataset size:", len(train_tok))
print("Eval dataset size:", len(eval_tok))
# =====================================================
# MODEL + LoRA
# =====================================================
base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)
base_model.config.use_cache = False
# πŸš€ UPGRADE 1: Expanded LoRA brainpower
lora_config = LoraConfig(
r=16, # Increased rank for more learning capacity
lora_alpha=32, # Alpha is typically 2x the rank
lora_dropout=0.1,
bias="none",
task_type="SEQ_2_SEQ_LM",
# Target all attention and dense layers in BART
target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
)
model = get_peft_model(base_model, lora_config)
model.to(device)
# =====================================================
# TRAINER
# =====================================================
data_collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer,
model=model,
padding=True,
)
args = Seq2SeqTrainingArguments(
output_dir=os.path.join(PROJECT_ROOT, "checkpoints", "sft_bart_runs"),
num_train_epochs=EPOCHS,
learning_rate=LR,
per_device_train_batch_size=PER_DEVICE_BATCH,
per_device_eval_batch_size=PER_DEVICE_BATCH,
gradient_accumulation_steps=GRAD_ACCUM,
dataloader_num_workers=0,
dataloader_pin_memory=False,
# πŸš€ UPGRADE 2 & 3: Better optimization & generalization
warmup_ratio=0.05, # Slowly ramp up learning rate
weight_decay=0.01, # Penalize over-reliance on single tokens
label_smoothing_factor=0.1, # Prevent overconfidence in SQL token matching
evaluation_strategy="epoch",
save_strategy="epoch",
save_total_limit=1,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
logging_steps=50,
report_to=[],
fp16=False,
bf16=False,
predict_with_generate=True,
)
trainer = Seq2SeqTrainer(
model=model,
args=args,
train_dataset=train_tok,
eval_dataset=eval_tok,
tokenizer=tokenizer,
data_collator=data_collator,
)
# =====================================================
# TRAIN
# =====================================================
trainer.train()
# =====================================================
# SAVE BEST MODEL
# =====================================================
print("Saving best BART LoRA adapter to:", OUT_DIR)
os.makedirs(OUT_DIR, exist_ok=True)
trainer.model.save_pretrained(OUT_DIR)
tokenizer.save_pretrained(OUT_DIR)
print("DONE βœ” SFT BART finished")