| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Soft-Label Cross-Encoder Reranker Training |
| |
| Trains a reranker using continuous relevance scores (soft labels). |
| Dataset format: {"query": "...", "text": "...", "score": 0.0-1.0} |
| """ |
|
|
| import logging |
| import os |
| import math |
| from collections import defaultdict |
| import trackio |
| import numpy as np |
| from datasets import load_dataset |
| from sentence_transformers.cross_encoder import ( |
| CrossEncoder, |
| CrossEncoderTrainer, |
| CrossEncoderTrainingArguments, |
| ) |
| from sentence_transformers.cross_encoder.evaluation import CrossEncoderNanoBEIREvaluator |
| from scipy.stats import spearmanr |
| from transformers import TrainerCallback |
|
|
| logging.basicConfig(level=logging.INFO) |
| logger = logging.getLogger(__name__) |
|
|
| |
# ---- Configuration (every value overridable via environment variable) ----

# HF dataset of {"query", "text", "score"} soft-label pairs.
DATASET_NAME = os.environ.get("DATASET_NAME", "amanwithaplan/arcade-reranker-data")
# Hub repository the trained reranker is pushed to.
HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "amanwithaplan/arcade-reranker")
# Cross-encoder checkpoint used as the fine-tuning starting point.
BASE_MODEL = os.environ.get("BASE_MODEL", "Alibaba-NLP/gte-reranker-modernbert-base")
NUM_EPOCHS = int(os.environ.get("NUM_EPOCHS", "10"))
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "16"))
LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "2e-5"))
MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", "512"))
RUN_NAME = os.environ.get("RUN_NAME", "reranker-03130903")
# Trackio Space used for experiment logging.
SPACE_ID = os.environ.get("TRACKIO_SPACE_ID", "amanwithaplan/trackio")
|
|
|
|
def dcg_at_k(relevances, k):
    """Return the Discounted Cumulative Gain of the first k relevances.

    The relevance at 0-based position i contributes rel_i / log2(i + 2),
    so the top position is undiscounted. An empty prefix yields 0.0.
    """
    top = np.asarray(relevances)[:k]
    if top.size == 0:
        return 0.0
    # log2(position + 2) gives discounts 1, log2(3), 2, ... per rank.
    positions = np.arange(top.size)
    return np.sum(top / np.log2(positions + 2))
|
|
|
|
def ndcg_at_k(predicted_order, true_relevances, k):
    """Normalized DCG@k for a single ranked list.

    predicted_order: doc indices sorted by model score, best first.
    true_relevances: ground-truth relevance score per doc index.

    Returns the DCG of the predicted ordering divided by the ideal
    (descending-relevance) DCG, or 0.0 when the ideal DCG is zero.
    """
    # Relevances in the order the model ranked the docs.
    ranked = [true_relevances[idx] for idx in predicted_order]
    # Best achievable ordering: highest relevance first.
    best = sorted(true_relevances, reverse=True)

    actual = dcg_at_k(ranked, k)
    ideal = dcg_at_k(best, k)

    # Guard the all-zero-relevance case (division by zero).
    return actual / ideal if ideal != 0 else 0.0
|
|
|
|
def mrr(predicted_order, true_relevances, threshold=0.5):
    """Reciprocal rank of the first relevant doc in the predicted order.

    A doc counts as relevant when its relevance is strictly greater
    than *threshold*. Returns 0.0 if nothing in the list is relevant.
    """
    rank = 1
    for doc_idx in predicted_order:
        if true_relevances[doc_idx] > threshold:
            return 1.0 / rank
        rank += 1
    return 0.0
|
|
|
|
def evaluate_ranking(model, eval_dataset):
    """Ranking-quality evaluation grouped by query.

    Buckets eval rows by query text, scores each query's candidate docs
    with the model, and reports the mean NDCG@3/@5, MRR, Spearman rank
    correlation, and the number of multi-doc queries evaluated.
    """
    # Collect every (doc, label) pair under its query.
    grouped = defaultdict(list)
    for row in eval_dataset:
        grouped[row["sentence1"]].append(
            {"text": row["sentence2"], "label": row["label"]}
        )

    # Ranking metrics only make sense with >= 2 candidates per query.
    multi_doc = {q: ds for q, ds in grouped.items() if len(ds) >= 2}

    if not multi_doc:
        return {"ndcg@3": 0.0, "ndcg@5": 0.0, "mrr": 0.0, "n_queries": 0}

    ndcg3, ndcg5, mrrs, corrs = [], [], [], []

    for query, docs in multi_doc.items():
        # Model score for each (query, doc) pair in this group.
        scores = model.predict(
            [(query, d["text"]) for d in docs], show_progress_bar=False
        )
        labels = [d["label"] for d in docs]

        # Doc indices sorted best-scored first.
        order = np.argsort(scores)[::-1].tolist()

        ndcg3.append(ndcg_at_k(order, labels, k=3))
        ndcg5.append(ndcg_at_k(order, labels, k=5))
        mrrs.append(mrr(order, labels, threshold=0.5))

        # Spearman is undefined when every label is identical; skip those.
        if len(set(labels)) > 1:
            rho = spearmanr(scores, labels).correlation
            if not math.isnan(rho):
                corrs.append(rho)

    return {
        "ndcg@3": np.mean(ndcg3),
        "ndcg@5": np.mean(ndcg5),
        "mrr": np.mean(mrrs),
        "rank_corr": np.mean(corrs) if corrs else 0.0,
        "n_queries": len(multi_doc),
    }
|
|
|
|
class DomainEvalCallback(TrainerCallback):
    """Logs domain-specific ranking metrics alongside trainer evals."""

    def __init__(self, model, eval_dataset_full):
        # Keep references so each eval pass can re-score the full split.
        self.model = model
        self.eval_dataset_full = eval_dataset_full

    def on_evaluate(self, args, state, control, **kwargs):
        """Hook fired by the trainer after every evaluation pass."""
        results = evaluate_ranking(self.model, self.eval_dataset_full)

        # Mirror the four headline metrics into the tracker.
        trackio.log({
            f"domain/{key}": results[key]
            for key in ("ndcg@3", "ndcg@5", "mrr", "rank_corr")
        })

        logger.info(
            f"Domain eval - NDCG@3: {results['ndcg@3']:.4f}, "
            f"NDCG@5: {results['ndcg@5']:.4f}, "
            f"MRR: {results['mrr']:.4f}, "
            f"RankCorr: {results['rank_corr']:.4f} "
            f"(n={results['n_queries']} queries)"
        )
|
|
|
|
def evaluate_by_type(model, eval_dataset, type_column="type"):
    """Evaluate ranking metrics separately for each content type.

    Returns a flat dict with "{type}_ndcg@5", "{type}_mrr", and
    "{type}_n_queries" entries for every type that yields at least 2
    evaluable queries, or an empty dict when the dataset has no
    type column.
    """
    if type_column not in eval_dataset.column_names:
        return {}

    # Bucket eval rows by their content type.
    by_type = defaultdict(list)
    for item in eval_dataset:
        by_type[item[type_column]].append(item)

    # Lightweight wrapper so evaluate_ranking sees a dataset-like object.
    # Fix: defined once here instead of being re-created on every loop
    # iteration as before.
    class TypeDataset:
        def __init__(self, items):
            self.items = items

        def __iter__(self):
            return iter(self.items)

        @property
        def column_names(self):
            return ["sentence1", "sentence2", "label"]

    results = {}
    for content_type, items in by_type.items():
        type_metrics = evaluate_ranking(model, TypeDataset(items))

        # Skip types with too few queries to give a stable signal.
        if type_metrics["n_queries"] >= 2:
            results[f"{content_type}_ndcg@5"] = type_metrics["ndcg@5"]
            results[f"{content_type}_mrr"] = type_metrics["mrr"]
            results[f"{content_type}_n_queries"] = type_metrics["n_queries"]

    return results
|
|
|
|
def main():
    """Run the full soft-label reranker training pipeline.

    Steps: init tracking -> load model and data -> baseline ranking eval
    -> train with NanoBEIR + domain eval callbacks -> final / per-type
    eval -> log improvement deltas -> push the model to the Hub.
    """
    # Start experiment tracking; hyperparameters are recorded up front.
    trackio.init(
        project="arcade-reranker",
        name=RUN_NAME,
        space_id=SPACE_ID,
        config={
            "model": BASE_MODEL,
            "dataset": DATASET_NAME,
            "learning_rate": LEARNING_RATE,
            "num_epochs": NUM_EPOCHS,
            "batch_size": BATCH_SIZE,
            "max_seq_length": MAX_SEQ_LENGTH,
        }
    )

    logger.info(f"Configuration:")
    logger.info(f" Dataset: {DATASET_NAME}")
    logger.info(f" Base model: {BASE_MODEL}")
    logger.info(f" Epochs: {NUM_EPOCHS}")
    logger.info(f" Run name: {RUN_NAME}")
    logger.info(f" Trackio space: {SPACE_ID}")

    # Cross-encoder: scores each (query, doc) pair jointly.
    model = CrossEncoder(BASE_MODEL, max_length=MAX_SEQ_LENGTH)

    logger.info(f"Loading dataset: {DATASET_NAME}")
    dataset = load_dataset(DATASET_NAME, split="train")

    # Log dataset composition by content type, when the column exists.
    type_counts = defaultdict(int)
    if "type" in dataset.column_names:
        for item in dataset:
            type_counts[item["type"]] += 1
        logger.info(f"Dataset composition: {dict(type_counts)}")

    for content_type, count in type_counts.items():
        trackio.log({f"data/{content_type}_count": count})

    trackio.log({"data/total_examples": len(dataset)})
    logger.info(f"Total examples: {len(dataset)}")

    # Rename to the column names CrossEncoderTrainer expects.
    dataset = dataset.rename_columns({
        "query": "sentence1",
        "text": "sentence2",
        "score": "label"
    })

    # Hold out at most 400 examples (capped at 15% of the data) for eval.
    eval_size = min(400, int(len(dataset) * 0.15))
    splits = dataset.train_test_split(test_size=eval_size, seed=42)

    # Keep the unprojected eval split (retains extra columns such as
    # "type") for the custom ranking / per-type evaluations.
    eval_dataset_full = splits["test"]

    # The trainer only needs the three training columns.
    train_dataset = splits["train"].select_columns(["sentence1", "sentence2", "label"])
    eval_dataset = splits["test"].select_columns(["sentence1", "sentence2", "label"])

    trackio.log({
        "data/train_size": len(train_dataset),
        "data/eval_size": len(eval_dataset),
    })
    logger.info(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")

    # Baseline ranking metrics before any fine-tuning, for deltas below.
    logger.info("Evaluating base model on eval set...")
    base_metrics = evaluate_ranking(model, eval_dataset_full)
    for key, value in base_metrics.items():
        trackio.log({f"base_model/{key}": value})
    logger.info(f"Base model metrics: {base_metrics}")

    # Public-benchmark evaluator to watch for regressions on general
    # retrieval quality during fine-tuning.
    evaluator = CrossEncoderNanoBEIREvaluator(
        dataset_names=["msmarco", "nfcorpus", "nq"],
        batch_size=BATCH_SIZE,
    )

    args = CrossEncoderTrainingArguments(
        output_dir="models/reranker",
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        warmup_ratio=0.1,
        bf16=True,
        eval_strategy="steps",
        eval_steps=25,
        save_strategy="steps",
        save_steps=25,
        save_total_limit=5,
        logging_steps=25,
        logging_first_step=True,
        # Best checkpoint (lowest eval loss) is restored after training.
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        # Every checkpoint save is also pushed to the Hub.
        push_to_hub=True,
        hub_model_id=HUB_MODEL_ID,
        hub_strategy="every_save",
        report_to="trackio",
        run_name=RUN_NAME,
    )

    # Logs NDCG/MRR on the held-out split after each trainer evaluation.
    domain_callback = DomainEvalCallback(model, eval_dataset_full)

    trainer = CrossEncoderTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        evaluator=evaluator,
        callbacks=[domain_callback],
    )

    logger.info("Starting training...")
    trainer.train()

    # Final ranking metrics (model is the best checkpoint at this point).
    logger.info("Running final ranking evaluation...")
    final_metrics = evaluate_ranking(model, eval_dataset_full)
    for key, value in final_metrics.items():
        trackio.log({f"final/{key}": value})
    logger.info(f"Final metrics: {final_metrics}")

    # Break metrics down per content type.
    logger.info("Evaluating by content type...")
    type_metrics = evaluate_by_type(model, eval_dataset_full)
    for key, value in type_metrics.items():
        trackio.log({f"final/by_type/{key}": value})
    logger.info(f"Per-type metrics: {type_metrics}")

    # Headline improvement deltas vs the untrained base model.
    trackio.log({
        "improvement/ndcg5_delta": final_metrics["ndcg@5"] - base_metrics["ndcg@5"],
        "improvement/mrr_delta": final_metrics["mrr"] - base_metrics["mrr"],
    })

    logger.info(f"Pushing final model to {HUB_MODEL_ID}")
    model.push_to_hub(HUB_MODEL_ID, exist_ok=True)

    trackio.finish()
    logger.info("Done!")




if __name__ == "__main__":
    main()
|
|