---
license: apache-2.0
language:
- en
tags:
- prompt-routing
- complexity-classifier
- deberta-v3
- llm-router
- cost-optimization
datasets:
- RowRed/ComplexityRouter
- OpenAssistant/oasst2
base_model:
- microsoft/deberta-v3-base
---
# ComplexityRouter: A Complexity-Based LLM Router
Introducing ComplexityRouter, a lightweight prompt-complexity classifier fine-tuned from **microsoft/deberta-v3-base**. I drew prompts from the [Open Assistant Conversations Dataset Release 2 (OASST2)](https://huggingface.co/datasets/OpenAssistant/oasst2), wrote some myself, and had Qwen3.5-4B (non-thinking mode) rewrite others to be more complex. I then built a synthetic dataset by labeling 4,400 of these prompts with [Qwen3.5-4B in non-thinking mode](https://huggingface.co/Qwen/Qwen3.5-4B).
It assigns each prompt one of four complexity levels, which makes it useful for routing queries to an appropriate LLM tier.
## Model Details
### Model Description
- **Model type:** Text Classification (multi‑class)
- **Language:** English
- **Backbone:** microsoft/deberta-v3-base
- **License:** Apache‑2.0
- **Finetuned from model:** microsoft/deberta-v3-base
- **Training data:** OASST2 + synthetic augmentations + manually created prompts
- **Labels:** generated by **Qwen3.5-4B** (non-thinking mode)
### Model Sources
- **Dataset repository:** https://huggingface.co/datasets/RowRed/ComplexityRouter
## Uses
### Direct Use
Route prompts to appropriate LLM tiers based on predicted complexity:
| Level | Meaning | Suggested LLM Tier |
|-------|---------|-------------------|
| 0 (Trivial) | Simple lookups, basic Q&A (e.g., “What is 2+2?”) | Fast/cheap local model |
| 1 (Simple) | Moderate reasoning, basic domain knowledge | Mid‑tier model |
| 2 (Moderate) | Complex reasoning, deep knowledge required | Strong model |
| 3 (Complex) | Very complex reasoning, niche expertise | Frontier API model |
**Recommended routing strategy:** Group levels **0 and 1** into a single fast/cheap tier, send level 2 to a standard model, and send level 3 to a premium model. On the internal test set the model reaches **93.0% adjacent accuracy**: 93.0% of predictions fall within one level of the label, so only about 7% are off by two or more levels.
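The grouping above amounts to a small lookup from predicted level to tier. This is a minimal sketch, assuming hypothetical tier names `cheap`, `standard`, and `premium` (substitute whatever models you actually serve) and a probability vector like the one produced by the softmax in the getting-started code below.

```python
from typing import Sequence

# Hypothetical tier names; levels 0 and 1 share the fast/cheap tier.
TIER_FOR_LEVEL = {0: "cheap", 1: "cheap", 2: "standard", 3: "premium"}


def route(level: int) -> str:
    """Map a predicted complexity level (0-3) to an LLM tier."""
    if level not in TIER_FOR_LEVEL:
        raise ValueError(f"expected a level in 0-3, got {level}")
    return TIER_FOR_LEVEL[level]


def route_from_probs(probs: Sequence[float]) -> str:
    """Pick the most probable level (argmax) and route it."""
    level = max(range(len(probs)), key=lambda i: probs[i])
    return route(level)


print(route_from_probs([0.1, 0.2, 0.6, 0.1]))
```

Keeping the mapping in one dictionary makes it easy to try a different grouping without touching the classifier.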
### Out‑of‑Scope Use
- Multi‑turn conversation routing (single prompts only).
- Non‑English prompts (training data was English‑only).
- Prompts requiring image or multimodal understanding.
## Bias, Risks, and Limitations
- Training data is synthetic and may not represent all real‑world prompt distributions.
- Level 1 (Simple, F1 0.529) and Level 2 (Moderate, F1 0.622) have the lowest per-class F1 scores; prompts near these boundaries are inherently ambiguous.
- The model may struggle with very domain‑specific technical jargon.
- Performance may degrade on prompts that are very different from the training distribution.
## Notice
This is my first attempt at a publicly released fine-tune. There are probably plenty of issues, but I thought the idea was sound. I may release a second (hopefully better) version eventually, but I'm not sure where to find large amounts of high-quality open-source data.
## Training Details
### Training Data
| Split | Samples | Source File | Notes |
|-------------|---------|----------------------|-------|
| Training | 2,800 | TRAINING.jsonl | Used for model training |
| Validation | 600 | TRAINING.jsonl | Used for early stopping / hyperparameter tuning |
| Test (internal) | 600 | TRAINING.jsonl | Used for in‑distribution evaluation |
| Test (held-out) | 400 | TEST.jsonl | Fully independent test set |
**Total unique prompts:** 4,400
Class distribution (training):
Level 0: 762 (27.2%) • Level 1: 674 (24.1%) • Level 2: 795 (28.4%) • Level 3: 569 (20.3%)
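The Training Procedure lists sqrt-scaled class weights and a weighted random sampler for class balancing, but not the exact formula. A minimal sketch, assuming the weights are square-rooted inverse frequencies, `sqrt(N / (K * n_c))`, computed from the training counts above; the helper names are hypothetical.

```python
import math

# Training-split class counts for levels 0-3 (from the class distribution above).
CLASS_COUNTS = [762, 674, 795, 569]


def sqrt_class_weights(counts):
    """Inverse-frequency weights, square-rooted to soften the correction (assumed formula)."""
    total, k = sum(counts), len(counts)
    return [math.sqrt(total / (k * n)) for n in counts]


def sample_weights(labels, class_weights):
    """Per-example weights for a weighted sampler: each example gets its class's weight."""
    return [class_weights[y] for y in labels]


weights = sqrt_class_weights(CLASS_COUNTS)
print([round(w, 3) for w in weights])
```

In PyTorch, the class weights would go to `torch.nn.CrossEntropyLoss(weight=...)` and the per-example weights to `torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)`. The square root keeps the rarest class (level 3) upweighted only modestly relative to the most common one (level 2).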
### Training Procedure
- Hardware: NVIDIA T4 (16 GB VRAM, Google Colab)
- Framework: PyTorch 2.11 + Hugging Face Transformers
- Optimizer: AdamW (lr=2e-5, weight_decay=0.01)
- Scheduler: Linear warmup (10% of steps) → linear decay
- Loss: Weighted Cross‑Entropy (classification) + MSE (regression)
- Batch size: 16 (effective 32 with gradient accumulation)
- Epochs: 7 (early stopping patience = 3)
- Training time: ~18 minutes
- Class balancing: sqrt‑scaled class weights + weighted random sampler
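The loss and schedule above can be sketched as follows. This is an illustration under assumptions, not the training script: it assumes a hypothetical extra regression head that outputs one scalar level per prompt (the released `PromptComplexityRouter` class has only the 4-way classifier), a hypothetical MSE mixing weight of 0.5, and the same assumed sqrt class-weight formula as before. The step count uses the card's numbers: 2,800 training examples, micro-batches of 16, 2 accumulation steps (effective 32), and 7 epochs.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumed sqrt-scaled inverse-frequency weights from the training class counts.
CLASS_COUNTS = torch.tensor([762.0, 674.0, 795.0, 569.0])
CLASS_WEIGHTS = torch.sqrt(CLASS_COUNTS.sum() / (len(CLASS_COUNTS) * CLASS_COUNTS))
MSE_WEIGHT = 0.5  # hypothetical; the card does not state the mixing coefficient


def combined_loss(logits, level_pred, labels, class_weights=CLASS_WEIGHTS, mse_weight=MSE_WEIGHT):
    """Weighted cross-entropy on 4-way logits plus MSE on a scalar level prediction."""
    ce = F.cross_entropy(logits, labels, weight=class_weights)
    mse = F.mse_loss(level_pred, labels.float())  # treats levels 0-3 as ordinal values
    return ce + mse_weight * mse


def linear_warmup_decay(total_steps, warmup_frac=0.1):
    """LR multiplier: linear ramp 0 -> 1 over warmup, then linear decay 1 -> 0."""
    warmup = int(warmup_frac * total_steps)

    def lr_lambda(step):
        if step < warmup:
            return step / max(1, warmup)
        return max(0.0, (total_steps - step) / max(1, total_steps - warmup))

    return lr_lambda, warmup


micro_batches = math.ceil(2800 / 16)              # forward/backward passes per epoch
steps_per_epoch = math.ceil(micro_batches / 2)    # optimizer steps with 2x accumulation
total_steps = steps_per_epoch * 7
lr_lambda, warmup_steps = linear_warmup_decay(total_steps)

# Stand-in parameters so the optimizer/scheduler can be built; use the real model instead.
stand_in = nn.Linear(16, 4)
optimizer = torch.optim.AdamW(stand_in.parameters(), lr=2e-5, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
```

The multiplier has the same shape as `transformers.get_linear_schedule_with_warmup`; writing it out just makes the warmup length explicit.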
## Evaluation Results
Results below are reported on the 600-sample internal test split held out from TRAINING.jsonl, not on the 400-sample TEST.jsonl set.
|Metric|Value|
|----|----|
|Exact Match Accuracy|64.5%|
|Adjacent (±1) Accuracy|93.0%|
|Macro F1|0.663|
|Weighted F1|0.653|
Per‑Class Performance (internal test, 600 samples)
|Level|Precision|Recall|F1|Support|
|----|----|----|----|----|
|L0 (Trivial)|0.658|0.626|0.642|163|
|L1 (Simple)|0.457|0.628|0.529|145|
|L2 (Moderate)|0.683|0.571|0.622|170|
|L3 (Complex)|0.933|0.795|0.858|122|
Confusion Matrix (internal test, 600 samples)
| |Pred L0|Pred L1|Pred L2|Pred L3|
|----|----|----|----|----|
|True L0|102|46|13|2|
|True L1|35|91|18|1|
|True L2|15|54|97|4|
|True L3|3|8|14|97|
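The headline metrics can be recomputed directly from the confusion matrix. A minimal sketch: exact accuracy counts the diagonal, adjacent accuracy counts cells where the predicted level is within one of the true level, and tier accuracy (a derived figure not reported above) counts predictions that land in the same tier under the recommended 0+1 / 2 / 3 grouping. The `TIER_FOR_LEVEL` names are hypothetical.

```python
# Rows: true level, columns: predicted level (internal test, 600 samples).
CONFUSION = [
    [102, 46, 13, 2],
    [35, 91, 18, 1],
    [15, 54, 97, 4],
    [3, 8, 14, 97],
]
TIER_FOR_LEVEL = {0: "cheap", 1: "cheap", 2: "standard", 3: "premium"}  # hypothetical names


def _total(cm):
    return sum(sum(row) for row in cm)


def exact_accuracy(cm):
    return sum(cm[i][i] for i in range(len(cm))) / _total(cm)


def adjacent_accuracy(cm):
    """Fraction of predictions within one level of the true label."""
    k = len(cm)
    hits = sum(cm[i][j] for i in range(k) for j in range(k) if abs(i - j) <= 1)
    return hits / _total(cm)


def per_class_f1(cm):
    """F1 per class; 2PR/(P+R) simplifies to 2*TP / (predicted + actual)."""
    k = len(cm)
    scores = []
    for c in range(k):
        predicted = sum(cm[r][c] for r in range(k))
        actual = sum(cm[c])
        scores.append(2 * cm[c][c] / (predicted + actual) if predicted + actual else 0.0)
    return scores


def weighted_f1(cm):
    f1 = per_class_f1(cm)
    return sum(f * sum(row) for f, row in zip(f1, cm)) / _total(cm)


def tier_accuracy(cm, tier_for_level):
    """Fraction of predictions routed to the same tier as the true level."""
    k = len(cm)
    hits = sum(cm[i][j] for i in range(k) for j in range(k)
               if tier_for_level[i] == tier_for_level[j])
    return hits / _total(cm)


f1 = per_class_f1(CONFUSION)
print(exact_accuracy(CONFUSION), adjacent_accuracy(CONFUSION), sum(f1) / len(f1))
print(tier_accuracy(CONFUSION, TIER_FOR_LEVEL))
```

Applied to the matrix above, these functions reproduce the reported 64.5% exact accuracy, 93.0% adjacent accuracy, and macro F1 of 0.663. With levels 0 and 1 grouped as recommended, 468 of 600 prompts (78.0%) land in the correct tier.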
## How to Get Started with the Model
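```python
# Requires: torch, transformers, huggingface_hub, sentencepiece (for the DeBERTa-v3 tokenizer)
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoModel, AutoTokenizer

REPO_ID = "RowRed/ComplexityRouter"


class PromptComplexityRouter(nn.Module):
    """DeBERTa-v3 encoder with a small MLP head over the first ([CLS]) token."""

    def __init__(self, backbone="microsoft/deberta-v3-base", num_labels=4):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(backbone)
        hidden_size = self.backbone.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(hidden_size, 256),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_labels),
        )

    def forward(self, input_ids, attention_mask):
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        cls_output = outputs.last_hidden_state[:, 0, :]  # first-token embedding
        return self.classifier(cls_output)  # logits over levels 0-3


# Load
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
model = PromptComplexityRouter()

# Download the fine-tuned weights from the Hub (cached after the first call).
weights_path = hf_hub_download(repo_id=REPO_ID, filename="pytorch_model.bin")
state_dict = torch.load(weights_path, map_location=device)
# strict=False tolerates mismatched keys, so report them: a missing classifier.*
# key would silently leave that layer randomly initialized.
result = model.load_state_dict(state_dict, strict=False)
if result.missing_keys or result.unexpected_keys:
    print("Missing keys:", result.missing_keys)
    print("Unexpected keys:", result.unexpected_keys)
model.to(device)
model.eval()

# Predict
prompts = ["What is 2+2?", "Explain quantum entanglement in detail."]
encoded = tokenizer(prompts, padding=True, truncation=True, return_tensors="pt").to(device)
with torch.no_grad():
    logits = model(encoded["input_ids"], encoded["attention_mask"])
probs = torch.softmax(logits, dim=-1)
predictions = torch.argmax(probs, dim=-1)
for prompt, level in zip(prompts, predictions):
    print(f"Level {level.item()}: {prompt}")
```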
## Citation
If you use this model, please cite:
```bibtex
@software{ComplexityRouter,
author = {RowRed},
title = {ComplexityRouter},
year = {2026},
url = {https://huggingface.co/RowRed/ComplexityRouter}
}
```
Additionally, acknowledge the base dataset and labeling model:
```bibtex
@dataset{oasst2,
author = {OpenAssistant Contributors},
title = {Open Assistant Conversations Dataset Release 2},
year = {2023},
url = {https://huggingface.co/datasets/OpenAssistant/oasst2}
}
@software{qwen3.5-4b,
author = {Qwen Team},
title = {Qwen3.5-4B},
year = {2026},
url = {https://huggingface.co/Qwen/Qwen3.5-4B}
}
```
## License
This model is released under Apache‑2.0.
The backbone (microsoft/deberta-v3-base) is MIT‑licensed.
The training dataset is derived from OASST2 (Apache‑2.0) and Qwen3.5‑4B outputs (Apache‑2.0).