Adaptive RAG - Query Router Model

Model Description

This is a fine-tuned distilbert-base-uncased model trained to classify incoming user queries into one of three information-need categories:

  • factual: Queries that are simple entity-centric lookups (suited for sparse/BM25 retrieval).
  • abstractive: Queries requiring conceptual or thematic context (suited for dense/FAISS retrieval).
  • multi-hop: Queries that require synthesizing information across multiple documents (suited for iterative hybrid retrieval).

This router is the core decision engine of an Adaptive RAG pipeline. Rather than applying a single retrieval strategy to every query, which sacrifices either latency or accuracy, this lightweight router classifies each query and dispatches it to the retrieval strategy suited to its category.

Architecture & Training

  • Base Model: distilbert-base-uncased
  • Training Data: ~10,000 queries sampled from 7 benchmarks (HotpotQA, SciFact, NFCorpus, TREC-COVID, ArguAna, FiQA, TriviaQA), silver-labeled via Gemini 1.5 Pro.
  • Performance: Achieved ~87% accuracy on a held-out test split (213 queries).
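
With only 213 test queries, the ~87% figure carries noticeable sampling uncertainty. The sketch below computes a 95% Wilson score interval for a binomial accuracy. The count of 185 correct predictions is an assumption: it is simply the integer count closest to 87% of 213, not a number reported for this model.

```python
import math


def wilson_interval(correct: int, total: int, z: float = 1.96) -> tuple[float, float]:
    """Wilson score interval for a binomial proportion (z=1.96 gives ~95% coverage)."""
    if total <= 0:
        raise ValueError("total must be positive")
    p = correct / total
    denom = 1 + z**2 / total
    center = (p + z**2 / (2 * total)) / denom
    margin = z * math.sqrt(p * (1 - p) / total + z**2 / (4 * total**2)) / denom
    return center - margin, center + margin


# Assumed count: 185 / 213 is about 86.9%, consistent with the reported ~87%.
low, high = wilson_interval(185, 213)
print(f"accuracy={185 / 213:.3f}, 95% CI=({low:.3f}, {high:.3f})")
```

Under that assumption the interval spans roughly 0.82 to 0.91, so accuracy differences of a few points between router variants evaluated on this split would not be conclusive.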

Intended Use

This model is intended as a preprocessing step in a RAG pipeline. It predicts one of the three query categories, and each predicted label can be mapped to the corresponding retrieval strategy and downstream data store.
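
A minimal sketch of that mapping, assuming the router returns the pipeline-style output shown under Example Usage (a list whose first element is a `{'label': ..., 'score': ...}` dict) and that the model's label names are exactly `factual`, `abstractive`, and `multi-hop`. The retriever functions are hypothetical stubs standing in for real BM25, FAISS, and iterative hybrid backends.

```python
from typing import Callable, Dict, List


# Hypothetical retriever stubs; replace with real BM25 / FAISS / hybrid backends.
def bm25_search(query: str) -> List[str]:
    return [f"bm25:{query}"]


def faiss_search(query: str) -> List[str]:
    return [f"faiss:{query}"]


def iterative_hybrid_search(query: str) -> List[str]:
    return [f"hybrid:{query}"]


# Map each router label to the retrieval strategy described in this model card.
STRATEGIES: Dict[str, Callable[[str], List[str]]] = {
    "factual": bm25_search,                # sparse retrieval for entity lookups
    "abstractive": faiss_search,           # dense retrieval for conceptual queries
    "multi-hop": iterative_hybrid_search,  # iterative hybrid for cross-document synthesis
}


def route_and_retrieve(query: str, router: Callable[[str], List[dict]]) -> List[str]:
    """Classify the query, then call the retriever registered for the predicted label."""
    label = router(query)[0]["label"]
    try:
        retrieve = STRATEGIES[label]
    except KeyError:
        raise ValueError(f"unexpected router label: {label!r}") from None
    return retrieve(query)


# Usage with a stand-in classifier; in practice this would be the transformers pipeline.
def fake_router(query: str) -> List[dict]:
    return [{"label": "abstractive", "score": 0.98}]


print(route_and_retrieve("What is the role of Wingless signaling?", fake_router))
```

Keeping the mapping in a dictionary means that adding a class later, such as the time-sensitive or no-retrieval classes listed under Known Limitations, only requires registering one more entry (plus retraining the router).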

Example Usage:

from transformers import pipeline

router = pipeline("text-classification", model="your-username/adaptive-rag-router")
prediction = router("What is the role of Wingless signaling in Drosophila hematopoiesis?")
print(prediction) 
# [{'label': 'abstractive', 'score': 0.98}]
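
For a single-label classifier, the text-classification pipeline's `score` is by default the softmax probability of the returned label. One possible way to use it, sketched here as an assumption rather than the original project's logic, is to escalate low-confidence predictions to the multi-hop strategy, trading extra latency for recall on the class the router handles worst. The 0.6 threshold and the fallback label are illustrative values, not tuned on this model.

```python
from typing import List

FALLBACK_LABEL = "multi-hop"  # assumption: escalate uncertain queries to the most thorough strategy
CONFIDENCE_THRESHOLD = 0.6    # illustrative value, not tuned on this model


def choose_label(prediction: List[dict], threshold: float = CONFIDENCE_THRESHOLD) -> str:
    """Return the top-scoring label, or FALLBACK_LABEL when its score is below threshold."""
    top = max(prediction, key=lambda p: p["score"])
    return top["label"] if top["score"] >= threshold else FALLBACK_LABEL


print(choose_label([{"label": "abstractive", "score": 0.98}]))  # confident -> abstractive
print(choose_label([{"label": "factual", "score": 0.41}]))      # uncertain -> multi-hop
```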

Known Limitations and Caveats

When deploying or evaluating this model, please note the following limitations recorded during our evaluation phase:

  1. 3-Class Scope: The model was constrained to 3 classes (factual, abstractive, multi-hop) due to project scope. It does not predict time-sensitive (recency-biased) or no-retrieval-needed (direct LLM answering) classes.
  2. Multi-Hop Recall Limitation: On multi-hop datasets such as HotpotQA, the router achieved only about 50% recall on multi-hop queries (as a fallback, 47% of eligible non-multi-hop queries were escalated).
  3. Out-of-Domain Misclassification: The label skew in the training data strongly affects performance on domain-specific corpora. For instance, on FiQA (financial reasoning), the router systematically classified complex, multi-step financial questions as simple factual queries (a 0% multi-hop prediction rate).
  4. Hardware Inconsistency in Baselines: When reviewing the original project's latency metrics, note that the SciFact and ArguAna runs used Colab GPUs, while the others (NFCorpus, HotpotQA) ran locally. Cross-dataset latency comparisons are therefore invalid.
  5. Evaluation Scope:
    • The HotpotQA evaluation corpus was heavily downsampled to 5,000 documents plus the gold documents to fit within local RAM limits, which artificially inflates absolute NDCG scores to ~0.86-0.90 across all strategies.
    • The TREC-COVID evaluation only used a 5-query test set, making its metrics statistically noisy.
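
The multi-hop recall and the multi-hop prediction rate cited above are simple ratios over gold and predicted labels. A minimal sketch for recomputing them on any labeled query set follows; the toy labels are made up for illustration and are not the project's evaluation data.

```python
from typing import Sequence


def class_recall(gold: Sequence[str], pred: Sequence[str], label: str) -> float:
    """Fraction of queries whose gold label is `label` that the router also predicted as `label`."""
    if len(gold) != len(pred):
        raise ValueError("gold and pred must have the same length")
    hits = sum(1 for g, p in zip(gold, pred) if g == label and p == label)
    support = sum(1 for g in gold if g == label)
    return hits / support if support else float("nan")


def prediction_rate(pred: Sequence[str], label: str) -> float:
    """Share of all predictions assigned to `label` (the quantity reported as 0% multi-hop on FiQA)."""
    return sum(1 for p in pred if p == label) / len(pred) if pred else float("nan")


# Toy example (made-up labels): 2 of the 4 gold multi-hop queries are recovered.
gold = ["multi-hop", "multi-hop", "multi-hop", "multi-hop", "factual"]
pred = ["multi-hop", "factual", "multi-hop", "abstractive", "factual"]
print(class_recall(gold, pred, "multi-hop"))  # 0.5
print(prediction_rate(pred, "multi-hop"))     # 0.4
```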

Model Files

  • Format: Safetensors
  • Model size: 67M params
  • Tensor type: F32