#!/usr/bin/env python
# train_cuad_lora_efficient.py - FIXED VERSION
"""
CUAD fine-tune with LoRA - Fixed for realistic training times
"""
import os, json, random, gc, time
from collections import defaultdict
from pathlib import Path
import torch, numpy as np
from datasets import load_dataset, Dataset, disable_caching
from transformers import (
AutoTokenizer, AutoModelForQuestionAnswering,
TrainingArguments, default_data_collator, Trainer
)
from peft import LoraConfig, get_peft_model, TaskType
import evaluate
from huggingface_hub import login
disable_caching()
# Set tokenizers parallelism to avoid warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# ─────────────────────────────────────────────────────────────── config ──
MAX_LEN = 512 # Max tokens per chunk (question + context)
DOC_STRIDE = 256 # Overlap (in tokens) between consecutive chunks of a long contract
SEED = 42
BATCH_SIZE = 1000 # Batch size for datasets.map() preprocessing, not the training batch size
# Use a data subset: a 5k-example run has trained successfully before, so 7k is a
# reasonable step up that keeps training time manageable.
USE_SUBSET = True
SUBSET_SIZE = 7000
def set_seed(seed):
random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def balance_has_answer(dataset, ratio=2.0, max_samples=None):
"""Keep all has-answer rows, down-sample no-answer rows to `ratio`."""
has, no = [], []
for ex in dataset:
(has if ex["answers"]["text"] else no).append(ex)
print(f"πŸ“Š Original: {len(has)} has-answer, {len(no)} no-answer")
# FIXED: Apply max_samples FIRST, then balance
if max_samples:
total_available = len(has) + len(no)
if total_available > max_samples:
# Sample proportionally from original distribution
has_ratio = len(has) / total_available
target_has = int(max_samples * has_ratio)
target_no = max_samples - target_has
has = random.sample(has, min(target_has, len(has)))
no = random.sample(no, min(target_no, len(no)))
print(f"πŸ“‰ Pre-balance subset: {len(has)} has-answer, {len(no)} no-answer")
# Now balance within the subset
k = int(len(has) * ratio)
if len(no) > k:
no = random.sample(no, k)
balanced = has + no
random.shuffle(balanced) # Shuffle the final dataset
print(f"πŸ“Š Final balanced: {len([x for x in balanced if x['answers']['text']])} has-answer, {len([x for x in balanced if not x['answers']['text']])} no-answer")
print(f"πŸ“Š Total examples: {len(balanced)}")
return Dataset.from_list(balanced)
# ────────────────────────────────────────────────────────────── postproc ──
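# Note: the SQuAD metric and postprocess_qa() below are not wired into the Trainer
# (compute_metrics=None); they are kept for offline evaluation of the exported model.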
metric = evaluate.load("squad")
def postprocess_qa(examples, features, raw_predictions, tokenizer):
"""HF-style span extraction + n-best, returns SQuAD format dict."""
all_start, all_end = raw_predictions
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
features_per_example = defaultdict(list)
for i, feat_id in enumerate(features["example_id"]):
features_per_example[example_id_to_index[feat_id]].append(i)
predictions = []
for example_idx, example in enumerate(examples):
best_score = -1e9
best_span = ""
context = example["context"]
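        # Scan every chunk produced for this example and keep the highest-scoring
        # non-empty span (score = start logit + end logit).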
for feat_idx in features_per_example[example_idx]:
start_logit = all_start[feat_idx]
end_logit = all_end[feat_idx]
offset = features["offset_mapping"][feat_idx]
start_idx = int(np.argmax(start_logit))
end_idx = int(np.argmax(end_logit))
if start_idx <= end_idx < len(offset):
start_char, _ = offset[start_idx]
_, end_char = offset[end_idx]
span = context[start_char:end_char].strip()
score = start_logit[start_idx] + end_logit[end_idx]
if score > best_score and span:
best_score, best_span = score, span
predictions.append(
{"id": example["id"], "prediction_text": best_span}
)
return predictions
# ───────────────────────────────────────────────────────────── preprocessing ──
def preprocess_training_batch(examples, tokenizer):
"""Training preprocessing - NO offset_mapping included"""
questions = examples["question"]
contexts = examples["context"]
tokenized_examples = tokenizer(
questions,
contexts,
truncation="only_second",
max_length=MAX_LEN,
stride=DOC_STRIDE,
return_overflowing_tokens=True,
return_offsets_mapping=True,
padding="max_length",
)
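    # Long contracts overflow into several overlapping chunks; overflow_to_sample_mapping
    # records which original example each chunk came from.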
sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
offset_mapping = tokenized_examples.pop("offset_mapping")
start_positions = []
end_positions = []
for i, offsets in enumerate(offset_mapping):
cls_index = 0
sample_index = sample_mapping[i]
answers = examples["answers"][sample_index]
if not answers["text"] or not answers["text"][0]:
start_positions.append(cls_index)
end_positions.append(cls_index)
continue
answer_start_char = answers["answer_start"][0]
answer_text = answers["text"][0]
answer_end_char = answer_start_char + len(answer_text)
        # Map the answer's character span onto token indices. Only context tokens
        # (sequence id 1) are considered, since question/special-token offsets are
        # relative to the question string rather than the contract text.
        sequence_ids = tokenized_examples.sequence_ids(i)
        token_start_index = cls_index
        token_end_index = cls_index
        for token_index, (start_char, end_char) in enumerate(offsets):
            if sequence_ids[token_index] != 1:
                continue
            if start_char <= answer_start_char < end_char:
                token_start_index = token_index
            if start_char < answer_end_char <= end_char:
                token_end_index = token_index
                break
if token_start_index <= token_end_index and token_start_index > 0:
start_positions.append(token_start_index)
end_positions.append(token_end_index)
else:
start_positions.append(cls_index)
end_positions.append(cls_index)
tokenized_examples["start_positions"] = start_positions
tokenized_examples["end_positions"] = end_positions
return tokenized_examples
def preprocess_validation_batch(examples, tokenizer):
"""Validation preprocessing - INCLUDES offset_mapping and example_id"""
questions = examples["question"]
contexts = examples["context"]
tokenized_examples = tokenizer(
questions,
contexts,
truncation="only_second",
max_length=MAX_LEN,
stride=DOC_STRIDE,
return_overflowing_tokens=True,
return_offsets_mapping=True,
padding="max_length",
)
sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
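    # offset_mapping is kept (postprocess_qa needs it) and every chunk is tagged with
    # the id of its source example so predictions can be grouped per example later.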
tokenized_examples["example_id"] = [
examples["id"][sample_mapping[i]] for i in range(len(tokenized_examples["input_ids"]))
]
return tokenized_examples
def preprocess_dataset_streaming(dataset, tokenizer, desc="Processing", is_training=True):
"""Process dataset in batches using HuggingFace's map function with batching."""
print(f"πŸ”„ {desc} dataset with batch processing...")
if is_training:
preprocess_fn = preprocess_training_batch
else:
preprocess_fn = preprocess_validation_batch
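    # Map in batches; remove_columns drops the raw text columns so the output holds
    # only the features returned by the preprocessing function.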
processed = dataset.map(
lambda examples: preprocess_fn(examples, tokenizer),
batched=True,
batch_size=BATCH_SIZE,
remove_columns=dataset.column_names,
desc=desc,
num_proc=1,
)
print(f"βœ… {desc} completed: {len(processed)} features")
return processed
# ───────────────────────────────────────────────────────────────── main ──
def main():
set_seed(SEED)
model_repo = os.getenv("MODEL_NAME", "AvocadoMuffin/roberta-cuad-qa-v4")
if (tokn := os.getenv("roberta_token")):
try:
login(tokn)
print("πŸ”‘ HuggingFace Hub login OK")
except Exception as e:
print(f"⚠️ Hub login failed: {e}")
tokn = None
print("πŸ“š Loading CUAD…")
try:
cuad = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True)
print(f"βœ… Loaded {len(cuad)} examples")
except Exception as e:
print(f"❌ Dataset loading failed: {e}")
cuad = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True, download_mode="force_redownload")
cuad = cuad.shuffle(seed=SEED)
# FIXED: Apply subset reduction more aggressively
subset_size = SUBSET_SIZE if USE_SUBSET else None
cuad = balance_has_answer(cuad, ratio=1.5, max_samples=subset_size) # Reduced ratio too
print(f"πŸ“Š Final dataset size: {len(cuad)} examples")
# Estimate features after preprocessing
avg_features_per_example = 2.5 # Conservative estimate with stride
estimated_features = len(cuad) * avg_features_per_example
print(f"πŸ“Š Estimated training features: ~{int(estimated_features)}")
ds = cuad.train_test_split(test_size=0.1, seed=SEED)
train_raw, val_raw = ds["train"], ds["test"]
# ── tokeniser & model ──
base_ckpt = "deepset/roberta-base-squad2"
tok = AutoTokenizer.from_pretrained(base_ckpt, use_fast=True)
model = AutoModelForQuestionAnswering.from_pretrained(base_ckpt)
# FIXED: Lighter LoRA config for faster training
lora = LoraConfig(
task_type=TaskType.QUESTION_ANS,
r=16, # Reduced from 32
lora_alpha=32, # Reduced from 64
lora_dropout=0.1,
target_modules=["query", "value"], # Fewer modules
)
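    # Wrap the base model so that only the LoRA adapters (plus the QA output head,
    # which PEFT keeps trainable for TaskType.QUESTION_ANS) receive gradient updates.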
model = get_peft_model(model, lora)
model.print_trainable_parameters()
# ── preprocessing ─────────────────────────────────────────
print("πŸ”„ Starting preprocessing...")
train_feats = preprocess_dataset_streaming(train_raw, tok, "Training", is_training=True)
val_feats = preprocess_dataset_streaming(val_raw, tok, "Validation", is_training=False)
print(f"βœ… Preprocessing completed!")
print(f" Training features: {len(train_feats)}")
print(f" Validation features: {len(val_feats)}")
# ── training args - FIXED for reasonable training time ──
batch_size = 16 # Good balance
gradient_accumulation_steps = 2
effective_batch_size = batch_size * gradient_accumulation_steps
num_epochs = 3 # Keep it reasonable
steps_per_epoch = len(train_feats) // effective_batch_size
total_steps = steps_per_epoch * num_epochs
eval_steps = max(25, steps_per_epoch // 8) # More frequent eval
save_steps = eval_steps * 3
print(f"πŸ“Š Training configuration:")
print(f" Effective batch size: {effective_batch_size}")
print(f" Steps per epoch: {steps_per_epoch}")
print(f" Total steps: {total_steps}")
print(f" Estimated time: ~{total_steps/2.4/60:.1f} minutes")
print(f" Eval every: {eval_steps} steps")
args = TrainingArguments(
output_dir="./cuad_lora_out",
learning_rate=3e-5, # Slightly lower LR
num_train_epochs=num_epochs,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=8,
gradient_accumulation_steps=gradient_accumulation_steps,
        fp16=False, bf16=True, # bf16 requires an Ampere-or-newer GPU; switch to fp16 on older cards
eval_strategy="steps",
eval_steps=eval_steps,
save_steps=save_steps,
save_total_limit=2,
weight_decay=0.01,
lr_scheduler_type="cosine",
warmup_ratio=0.1,
load_best_model_at_end=False,
logging_steps=10, # More frequent logging
report_to="none",
dataloader_num_workers=2,
dataloader_pin_memory=True,
remove_unused_columns=True,
)
trainer = Trainer(
model=model,
args=args,
train_dataset=train_feats,
eval_dataset=val_feats,
tokenizer=tok,
data_collator=default_data_collator,
compute_metrics=None,
)
print("πŸš€ Training…")
try:
trainer.train()
print("βœ… Training completed successfully!")
except Exception as e:
print(f"❌ Training failed: {e}")
try:
trainer.save_model("./cuad_lora_out_partial")
tok.save_pretrained("./cuad_lora_out_partial")
print("πŸ’Ύ Partial model saved")
        except Exception:
            print("❌ Could not save partial model")
        raise
print("βœ… Done. Best eval_loss:", trainer.state.best_metric)
trainer.save_model("./cuad_lora_out")
tok.save_pretrained("./cuad_lora_out")
# Push to hub
if tokn:
for attempt in range(3):
try:
print(f"⬆️ Pushing to Hub (attempt {attempt + 1}/3)...")
                # Trainer.push_to_hub() takes a commit message as its first argument, so
                # push the PEFT adapter and tokenizer directly to the target repo instead.
                model.push_to_hub(model_repo, private=False)
                tok.push_to_hub(model_repo, private=False)
print("πŸš€ Pushed to:", f"https://huggingface.co/{model_repo}")
break
except Exception as e:
print(f"⚠️ Hub push failed: {e}")
if attempt < 2:
time.sleep(30)
else:
print("πŸ’Ύ Model saved locally (push failed)")
if __name__ == "__main__":
main()