"""Train CodeBERT cross-encoder for SQL error classification with HF Trainer.""" from __future__ import annotations import argparse import json from pathlib import Path import numpy as np import pandas as pd import torch from transformers import ( AutoModelForSequenceClassification, AutoTokenizer, EarlyStoppingCallback, Trainer, TrainingArguments, ) from src.device_utils import get_device from src.codebert_dataset import ( SQLCodeBERTDataCollator, prepare_datasets, ) from src.codebert_labels import load_codebert_labels from src.hf_metrics import build_compute_metrics, compute_multilabel_metrics PROJECT_ROOT = Path(__file__).resolve().parent.parent DEFAULT_DATA = PROJECT_ROOT / "data" / "sql_errors_1m.parquet" DEFAULT_OUTPUT = PROJECT_ROOT / "models" / "codebert-cross-encoder" DEFAULT_MODEL = "microsoft/codebert-base" def train( data_path: Path | None = DEFAULT_DATA, dataframe: pd.DataFrame | None = None, output_dir: Path = DEFAULT_OUTPUT, model_name: str = DEFAULT_MODEL, epochs: float = 3.0, batch_size: int = 16, eval_batch_size: int = 32, learning_rate: float = 2e-5, weight_decay: float = 0.01, warmup_ratio: float = 0.06, max_length: int = 512, max_samples: int | None = None, test_size: float = 0.1, val_size: float = 0.1, threshold: float = 0.5, seed: int = 42, push_to_hub: bool = False, hub_model_id: str | None = None, fp16: bool = False, save_strategy: str = "no", hub_token: str | None = None, ) -> dict: if dataframe is not None: df = dataframe.copy() print(f"Loaded dataframe with {len(df):,} rows") elif data_path is not None: print(f"Loading dataset from {data_path}...") df = pd.read_parquet(data_path) else: raise ValueError("Either data_path or dataframe must be provided") if max_samples and len(df) > max_samples: df = df.sample(n=max_samples, random_state=seed) label_list = load_codebert_labels() num_labels = len(label_list) print(f"Labels ({num_labels}): {label_list}") print(f"Samples: {len(df):,}") device = get_device() use_fp16 = fp16 and device == "cuda" print(f"Device: {device} | fp16: {use_fp16}") tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForSequenceClassification.from_pretrained( model_name, num_labels=num_labels, problem_type="multi_label_classification", id2label={i: name for i, name in enumerate(label_list)}, label2id={name: i for i, name in enumerate(label_list)}, ) train_ds, val_ds, test_ds = prepare_datasets( df, tokenizer, test_size=test_size, val_size=val_size, max_length=max_length, seed=seed, ) print(f"Train: {len(train_ds):,} | Val: {len(val_ds):,} | Test: {len(test_ds):,}") output_dir.mkdir(parents=True, exist_ok=True) label_info = { "labels": label_list, "model_name": model_name, "architecture": "codebert-cross-encoder", "input_format": "QUESTION + SCHEMA + STUDENT_SQL + CORRECT_SQL", "max_length": max_length, "threshold": threshold, } with open(output_dir / "label_config.json", "w") as f: json.dump(label_info, f, indent=2) training_args = TrainingArguments( output_dir=str(output_dir), num_train_epochs=epochs, per_device_train_batch_size=batch_size, per_device_eval_batch_size=eval_batch_size, learning_rate=learning_rate, weight_decay=weight_decay, warmup_ratio=warmup_ratio, eval_strategy="epoch", save_strategy=save_strategy, logging_strategy="steps", logging_steps=50, load_best_model_at_end=save_strategy == "epoch", metric_for_best_model="f1_macro", greater_is_better=True, save_total_limit=1, seed=seed, report_to="none", fp16=use_fp16, use_mps_device=(device == "mps"), push_to_hub=push_to_hub, hub_model_id=hub_model_id, hub_token=hub_token, ) callbacks = [] if save_strategy == "epoch": callbacks.append(EarlyStoppingCallback(early_stopping_patience=2)) trainer_kwargs = dict( model=model, args=training_args, train_dataset=train_ds, eval_dataset=val_ds, data_collator=SQLCodeBERTDataCollator(tokenizer), compute_metrics=build_compute_metrics(threshold=threshold), callbacks=callbacks, ) try: trainer = Trainer(processing_class=tokenizer, **trainer_kwargs) except TypeError: trainer = Trainer(tokenizer=tokenizer, **trainer_kwargs) print("Starting CodeBERT cross-encoder training...") train_result = trainer.train() print("Evaluating on validation set...") val_metrics = trainer.evaluate() print("Evaluating on held-out test set...") test_output = trainer.predict(test_ds) test_metrics = compute_multilabel_metrics( test_output.predictions, test_output.label_ids, threshold=threshold, ) trainer.save_model(str(output_dir)) tokenizer.save_pretrained(str(output_dir)) metrics = { "train_samples": len(train_ds), "val_samples": len(val_ds), "test_samples": len(test_ds), "train_runtime": train_result.metrics.get("train_runtime"), "validation": val_metrics, "test": test_metrics, } with open(output_dir / "metrics.json", "w") as f: json.dump(metrics, f, indent=2, default=float) print(f"\nValidation F1 (macro): {val_metrics.get('eval_f1_macro', 0):.4f}") print(f"Test F1 (macro): {test_metrics['f1_macro']:.4f}") print(f"Test subset accuracy: {test_metrics['subset_accuracy']:.4f}") print(f"Model saved to {output_dir}") return metrics def main() -> None: parser = argparse.ArgumentParser( description="Train CodeBERT cross-encoder with Hugging Face Trainer" ) parser.add_argument("--data", type=Path, default=DEFAULT_DATA) parser.add_argument("--output-dir", type=Path, default=DEFAULT_OUTPUT) parser.add_argument("--model-name", type=str, default=DEFAULT_MODEL) parser.add_argument("--epochs", type=float, default=3.0) parser.add_argument("--batch-size", type=int, default=16) parser.add_argument("--eval-batch-size", type=int, default=32) parser.add_argument("--learning-rate", type=float, default=2e-5) parser.add_argument("--max-length", type=int, default=512) parser.add_argument("--max-samples", type=int, default=None) parser.add_argument("--threshold", type=float, default=0.5) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--push-to-hub", action="store_true") parser.add_argument("--hub-model-id", type=str, default=None) parser.add_argument("--fp16", action="store_true") parser.add_argument( "--save-strategy", choices=["no", "epoch"], default="no", help="Use 'no' to save only final model (saves disk space)", ) args = parser.parse_args() train( data_path=args.data, output_dir=args.output_dir, model_name=args.model_name, epochs=args.epochs, batch_size=args.batch_size, eval_batch_size=args.eval_batch_size, learning_rate=args.learning_rate, max_length=args.max_length, max_samples=args.max_samples, threshold=args.threshold, seed=args.seed, push_to_hub=args.push_to_hub, hub_model_id=args.hub_model_id, fp16=args.fp16, save_strategy=args.save_strategy, ) if __name__ == "__main__": main()