Spaces:
Running
Running
| """Train CodeBERT cross-encoder for SQL error classification with HF Trainer.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from transformers import ( | |
| AutoModelForSequenceClassification, | |
| AutoTokenizer, | |
| EarlyStoppingCallback, | |
| Trainer, | |
| TrainingArguments, | |
| ) | |
| from src.device_utils import get_device | |
| from src.codebert_dataset import ( | |
| SQLCodeBERTDataCollator, | |
| prepare_datasets, | |
| ) | |
| from src.codebert_labels import load_codebert_labels | |
| from src.hf_metrics import build_compute_metrics, compute_multilabel_metrics | |
# Parent of the directory that contains this script (resolved to an absolute path).
PROJECT_ROOT = Path(__file__).resolve().parent.parent

# Defaults used by train()/main() when the caller does not override them.
DEFAULT_DATA = PROJECT_ROOT.joinpath("data", "sql_errors_1m.parquet")
DEFAULT_OUTPUT = PROJECT_ROOT.joinpath("models", "codebert-cross-encoder")
DEFAULT_MODEL = "microsoft/codebert-base"
def _load_dataframe(
    data_path: Path | None,
    dataframe: pd.DataFrame | None,
    max_samples: int | None,
    seed: int,
) -> pd.DataFrame:
    """Return the training rows from *dataframe* (preferred) or *data_path*.

    An in-memory frame is copied so the caller's object is never mutated.
    When *max_samples* is truthy and smaller than the row count, a random
    subset of that size (seeded with *seed*) is returned instead.

    Raises:
        ValueError: if neither *dataframe* nor *data_path* is provided.
    """
    if dataframe is not None:
        df = dataframe.copy()
        print(f"Loaded dataframe with {len(df):,} rows")
    elif data_path is not None:
        print(f"Loading dataset from {data_path}...")
        df = pd.read_parquet(data_path)
    else:
        raise ValueError("Either data_path or dataframe must be provided")
    if max_samples and len(df) > max_samples:
        df = df.sample(n=max_samples, random_state=seed)
    return df


def _accepts_kwarg(func: object, name: str) -> bool:
    """Return True when the callable/class *func* declares a parameter called *name*."""
    import inspect

    try:
        return name in inspect.signature(func).parameters
    except (TypeError, ValueError):
        # No introspectable signature: treat the keyword as unsupported.
        return False


def _write_json(path: Path, payload: dict) -> None:
    """Write *payload* to *path* as indented JSON.

    Objects that ``json`` cannot encode natively are converted with ``float()``.
    """
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, default=float)


def train(
    data_path: Path | None = DEFAULT_DATA,
    dataframe: pd.DataFrame | None = None,
    output_dir: Path = DEFAULT_OUTPUT,
    model_name: str = DEFAULT_MODEL,
    epochs: float = 3.0,
    batch_size: int = 16,
    eval_batch_size: int = 32,
    learning_rate: float = 2e-5,
    weight_decay: float = 0.01,
    warmup_ratio: float = 0.06,
    max_length: int = 512,
    max_samples: int | None = None,
    test_size: float = 0.1,
    val_size: float = 0.1,
    threshold: float = 0.5,
    seed: int = 42,
    push_to_hub: bool = False,
    hub_model_id: str | None = None,
    fp16: bool = False,
    save_strategy: str = "no",
    hub_token: str | None = None,
) -> dict:
    """Fine-tune a CodeBERT cross-encoder for multi-label SQL error classification.

    Args:
        data_path: Parquet file to read when *dataframe* is None.
        dataframe: In-memory data; takes precedence over *data_path* (copied).
        output_dir: Directory (``Path`` or ``str``) that receives
            ``label_config.json``, ``metrics.json``, the model and tokenizer.
            Created if missing.
        model_name: Hugging Face checkpoint used for tokenizer and model.
        epochs, batch_size, eval_batch_size, learning_rate, weight_decay,
        warmup_ratio: Forwarded to ``TrainingArguments``.
        max_length: Passed to ``prepare_datasets`` and recorded in the label config.
        max_samples: If truthy and smaller than the data, train on a seeded
            random subset of this many rows.
        test_size, val_size: Forwarded to ``prepare_datasets``.
        threshold: Decision threshold for the multi-label metrics; also recorded.
        seed: Seed for sampling, dataset preparation and the Trainer.
        push_to_hub, hub_model_id, hub_token: Forwarded to ``TrainingArguments``.
        fp16: Request mixed precision; honoured only on CUDA devices.
        save_strategy: ``"no"`` (save only the final model) or ``"epoch"``, which
            also enables best-model reloading and early stopping (patience 2).

    Returns:
        Dict with train/val/test sample counts, ``train_runtime``, the
        validation metrics from ``Trainer.evaluate`` and the test metrics.

    Raises:
        ValueError: if neither *data_path* nor *dataframe* is provided.
    """
    df = _load_dataframe(data_path, dataframe, max_samples, seed)
    output_dir = Path(output_dir)  # accept plain strings as well as Path objects

    label_list = load_codebert_labels()
    num_labels = len(label_list)
    print(f"Labels ({num_labels}): {label_list}")
    print(f"Samples: {len(df):,}")

    device = get_device()
    # Compare on the device *type* so "cuda", "cuda:0" and torch.device values all work.
    device_type = str(device).split(":", 1)[0]
    use_fp16 = fp16 and device_type == "cuda"
    print(f"Device: {device} | fp16: {use_fp16}")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        problem_type="multi_label_classification",
        id2label={i: name for i, name in enumerate(label_list)},
        label2id={name: i for i, name in enumerate(label_list)},
    )

    train_ds, val_ds, test_ds = prepare_datasets(
        df,
        tokenizer,
        test_size=test_size,
        val_size=val_size,
        max_length=max_length,
        seed=seed,
    )
    print(f"Train: {len(train_ds):,} | Val: {len(val_ds):,} | Test: {len(test_ds):,}")

    output_dir.mkdir(parents=True, exist_ok=True)
    label_info = {
        "labels": label_list,
        "model_name": model_name,
        "architecture": "codebert-cross-encoder",
        "input_format": "QUESTION + SCHEMA + STUDENT_SQL + CORRECT_SQL",
        "max_length": max_length,
        "threshold": threshold,
    }
    _write_json(output_dir / "label_config.json", label_info)

    extra_args = {}
    if _accepts_kwarg(TrainingArguments, "use_mps_device"):
        # Deprecated in transformers 4.x (MPS is selected automatically) and
        # scheduled for removal in 5.0; only pass it where it is still accepted.
        extra_args["use_mps_device"] = device_type == "mps"

    training_args = TrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=eval_batch_size,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        warmup_ratio=warmup_ratio,
        eval_strategy="epoch",
        save_strategy=save_strategy,
        logging_strategy="steps",
        logging_steps=50,
        # Best-model reloading needs per-epoch checkpoints to reload from.
        load_best_model_at_end=save_strategy == "epoch",
        metric_for_best_model="f1_macro",
        greater_is_better=True,
        save_total_limit=1,
        seed=seed,
        report_to="none",
        fp16=use_fp16,
        push_to_hub=push_to_hub,
        hub_model_id=hub_model_id,
        hub_token=hub_token,
        **extra_args,
    )

    # EarlyStoppingCallback relies on load_best_model_at_end, which is only on with epoch saves.
    callbacks = []
    if save_strategy == "epoch":
        callbacks.append(EarlyStoppingCallback(early_stopping_patience=2))

    trainer_kwargs = dict(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=SQLCodeBERTDataCollator(tokenizer),
        compute_metrics=build_compute_metrics(threshold=threshold),
        callbacks=callbacks,
    )
    # Newer Trainer releases take the tokenizer as ``processing_class``; older ones
    # only know ``tokenizer``. Pick the keyword from the signature instead of
    # catching TypeError, which would also swallow unrelated TypeErrors raised
    # inside Trainer.__init__ and retry with the wrong keyword.
    tokenizer_kwarg = (
        "processing_class" if _accepts_kwarg(Trainer, "processing_class") else "tokenizer"
    )
    trainer = Trainer(**{tokenizer_kwarg: tokenizer}, **trainer_kwargs)

    print("Starting CodeBERT cross-encoder training...")
    train_result = trainer.train()

    print("Evaluating on validation set...")
    val_metrics = trainer.evaluate()

    print("Evaluating on held-out test set...")
    test_output = trainer.predict(test_ds)
    test_metrics = compute_multilabel_metrics(
        test_output.predictions,
        test_output.label_ids,
        threshold=threshold,
    )

    metrics = {
        "train_samples": len(train_ds),
        "val_samples": len(val_ds),
        "test_samples": len(test_ds),
        "train_runtime": train_result.metrics.get("train_runtime"),
        "validation": val_metrics,
        "test": test_metrics,
    }
    # Write metrics before save_model: with push_to_hub enabled, save_model uploads
    # output_dir, so files written afterwards would never reach the Hub.
    _write_json(output_dir / "metrics.json", metrics)

    trainer.save_model(str(output_dir))
    tokenizer.save_pretrained(str(output_dir))

    print(f"\nValidation F1 (macro): {val_metrics.get('eval_f1_macro', 0):.4f}")
    print(f"Test F1 (macro): {test_metrics['f1_macro']:.4f}")
    print(f"Test subset accuracy: {test_metrics['subset_accuracy']:.4f}")
    print(f"Model saved to {output_dir}")
    return metrics
def main() -> None:
    """Parse command-line options and run ``train`` with them.

    Every ``train`` hyper-parameter except ``hub_token`` and ``dataframe`` is
    exposed; defaults mirror ``train``'s own defaults.
    """
    parser = argparse.ArgumentParser(
        description="Train CodeBERT cross-encoder with Hugging Face Trainer"
    )
    parser.add_argument("--data", type=Path, default=DEFAULT_DATA)
    parser.add_argument("--output-dir", type=Path, default=DEFAULT_OUTPUT)
    parser.add_argument("--model-name", type=str, default=DEFAULT_MODEL)
    parser.add_argument("--epochs", type=float, default=3.0)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--eval-batch-size", type=int, default=32)
    parser.add_argument("--learning-rate", type=float, default=2e-5)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--warmup-ratio", type=float, default=0.06)
    parser.add_argument("--max-length", type=int, default=512)
    parser.add_argument("--max-samples", type=int, default=None)
    parser.add_argument(
        "--test-size",
        type=float,
        default=0.1,
        help="Forwarded to prepare_datasets as test_size",
    )
    parser.add_argument(
        "--val-size",
        type=float,
        default=0.1,
        help="Forwarded to prepare_datasets as val_size",
    )
    parser.add_argument("--threshold", type=float, default=0.5)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--push-to-hub", action="store_true")
    parser.add_argument("--hub-model-id", type=str, default=None)
    parser.add_argument(
        "--fp16", action="store_true", help="Mixed precision (only used on CUDA)"
    )
    parser.add_argument(
        "--save-strategy",
        choices=["no", "epoch"],
        default="no",
        help="Use 'no' to save only final model (saves disk space)",
    )
    args = parser.parse_args()

    train(
        data_path=args.data,
        output_dir=args.output_dir,
        model_name=args.model_name,
        epochs=args.epochs,
        batch_size=args.batch_size,
        eval_batch_size=args.eval_batch_size,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        max_length=args.max_length,
        max_samples=args.max_samples,
        test_size=args.test_size,
        val_size=args.val_size,
        threshold=args.threshold,
        seed=args.seed,
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
        fp16=args.fp16,
        save_strategy=args.save_strategy,
    )


if __name__ == "__main__":
    main()