JiaqiXue commited on
Commit
bc1c255
·
verified ·
1 Parent(s): db29071

Initial release: R2-Router (#1 on RouterArena)

Browse files
README.md ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - llm-routing
5
+ - model-selection
6
+ - budget-optimization
7
+ - knn
8
+ language:
9
+ - en
10
+ library_name: sklearn
11
+ pipeline_tag: text-classification
12
+ ---
13
+
14
+ # R2-Router: LLM Router with Joint Model-Budget Optimization
15
+
16
+ **R2-Router** intelligently routes each query to the optimal (LLM, token budget) pair, jointly optimizing accuracy and inference cost. Ranked **#1** on the [RouterArena](https://routerarena.github.io/) leaderboard.
17
+
18
+ **Paper**: [R2-Router (arxiv)](https://arxiv.org/abs/TODO)
19
+
20
+ ## RouterArena Performance
21
+
22
+ Official leaderboard results on 8,400 queries:
23
+
24
+ | Metric | Value |
25
+ |--------|-------|
26
+ | Accuracy | 71.23% |
27
+ | Cost per 1K Queries | $0.061 |
28
+ | Arena Score (beta=0.1) | **71.60** |
29
+ | Robustness Score | 45.71% |
30
+ | Rank | **#1** |
31
+
32
+ ## Quick Start
33
+
34
+ ### Installation
35
+
36
+ ```bash
37
+ pip install scikit-learn numpy joblib huggingface_hub
38
+ ```
39
+
40
+ ### Load Pre-trained Checkpoints
41
+
42
+ ```python
43
+ from router import R2Router
44
+
45
+ # Load pre-trained KNN checkpoints (no training needed)
46
+ router = R2Router.from_pretrained("jqxue1999/r2-router")
47
+
48
+ # Route a query (requires 1024-dim embedding from Qwen3-0.6B)
49
+ result = router.route(embedding)
50
+ print(f"Model: {result['model_full_name']}")
51
+ print(f"Token Budget: {result['token_limit']}")
52
+ print(f"Predicted Quality: {result['predicted_quality']:.3f}")
53
+ ```
54
+
55
+ ### Train from Scratch
56
+
57
+ ```python
58
+ from router import R2Router
59
+
60
+ # Train KNN from the provided sub_10 training data
61
+ router = R2Router.from_training_data("jqxue1999/r2-router", k=80)
62
+
63
+ # Route a query
64
+ result = router.route(embedding)
65
+ ```
66
+
67
+ ### Get Query Embeddings
68
+
69
+ R2-Router uses [Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) embeddings (1024-dim). You can generate them with:
70
+
71
+ ```python
72
+ from sentence_transformers import SentenceTransformer
73
+
74
+ model = SentenceTransformer("Qwen/Qwen3-0.6B")
75
+ embedding = model.encode("What is the capital of France?")
76
+ ```
77
+
78
+ Or with vLLM for faster batch inference:
79
+
80
+ ```python
81
+ from vllm import LLM
82
+ llm = LLM(model="Qwen/Qwen3-0.6B", runner="pooling")
83
+ outputs = llm.embed(["What is the capital of France?"])
84
+ embedding = outputs[0].outputs.embedding
85
+ ```
86
+
87
+ ## Architecture
88
+
89
+ R2-Router jointly optimizes **which model** to use and **how many tokens** to allocate per query.
90
+
91
+ ### Routing Formula
92
+
93
+ ```
94
+ risk(M, b) = (1 - lambda) * predicted_quality(query, M, b) - lambda * predicted_tokens(query, M) * price_M / 1e6
95
+ (M*, b*) = argmax risk
96
+ ```
97
+
98
+ ### Pipeline
99
+
100
+ ```
101
+ Input Query
102
+ |
103
+ [1] Embed with Qwen3-0.6B -> 1024-dim vector
104
+ |
105
+ [2] For each (model, budget) pair:
106
+ - KNN predicts quality (accuracy)
107
+ - KNN predicts output token count
108
+ - Compute risk = (1-lambda) * quality - lambda * cost
109
+ |
110
+ [3] Select (model, budget) with highest risk
111
+ |
112
+ Output: (model_name, token_budget)
113
+ ```
114
+
115
+ ### Model Pool (6 LLMs)
116
+
117
+ | Model | Output $/M tokens |
118
+ |-------|------------------|
119
+ | Qwen3-235B-A22B | $0.463 |
120
+ | Qwen3-Next-80B-A3B | $1.10 |
121
+ | Qwen3-30B-A3B | $0.33 |
122
+ | Qwen3-Coder-Next | $0.30 |
123
+ | Gemini 2.5 Flash | $2.50 |
124
+ | Claude 3 Haiku | $1.25 |
125
+
126
+ ### Token Budgets
127
+
128
+ 4 output token limits: **100, 200, 400, 800** tokens.
129
+
130
+ ### Key Parameters
131
+
132
+ | Parameter | Value |
133
+ |-----------|-------|
134
+ | KNN K | 80 |
135
+ | Lambda | 0.999 |
136
+ | Distance Metric | Cosine |
137
+ | KNN Weights | Distance-weighted |
138
+ | Embedding Dim | 1024 |
139
+
140
+ ## Repository Contents
141
+
142
+ ```
143
+ config.json # Router configuration (models, budgets, prices, hyperparams)
144
+ router.py # Self-contained inference code
145
+ training_data/
146
+ embeddings.npy # Sub_10 training embeddings (809 x 1024)
147
+ labels.json # Per-(model, budget) accuracy & token labels
148
+ checkpoints/
149
+ quality_knn_*.joblib # Pre-fitted KNN quality predictors (18 total)
150
+ token_knn_*.joblib # Pre-fitted KNN token predictors (6 total)
151
+ ```
152
+
153
+ ### Two Ways to Use
154
+
155
+ 1. **Load checkpoints** (`from_pretrained`): Directly load pre-fitted KNN models. No training needed.
156
+ 2. **Train from data** (`from_training_data`): Use the provided training embeddings and labels to fit your own KNN with custom hyperparameters (e.g., different K, distance metric).
157
+
158
+ ## Training Details
159
+
160
+ - **Training Data**: RouterArena sub_10 split (809 queries, 10% of full 8,400)
161
+ - **Method**: KNeighborsRegressor with cosine distance, distance-weighted
162
+ - **Evaluation**: Full 8,400 RouterArena queries (no data leakage)
163
+ - **Training Time**: < 1 second (KNN fitting)
164
+
165
+ ## Citation
166
+
167
+ ```bibtex
168
+ @article{r2router2026,
169
+ title={R2-Router: A New Paradigm for LLM Routing with Reasoning},
170
+ author={TODO},
171
+ year={2026},
172
+ url={https://arxiv.org/abs/TODO}
173
+ }
174
+ ```
175
+
176
+ ## License
177
+
178
+ Apache 2.0
checkpoints/quality_knn_235b_budget_200.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f66789acf7675f2b4113c9ac173aaaa8199fdd270dfa6869609ed17b2feb7dbe
3
+ size 3317588
checkpoints/quality_knn_235b_budget_400.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c55dbaa6b9cda97db793b63471a95ff1b1389bc1e46acf8ee724f4deda1280b7
3
+ size 3317588
checkpoints/quality_knn_235b_budget_800.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51ef60d08cbb7e2c309af411054aaa1c2deee661f8dc0c1953228823fca56066
3
+ size 3317588
checkpoints/quality_knn_235b_concise.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:76a9a2f93c64e1eb8243a35b7b94b1dcb66deb7e0d49e095866f4e3057bd7ea5
3
+ size 3317588
checkpoints/quality_knn_30b_budget_200.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:101348a979019d849d4dc56b1629844bd6c5d77adf1746b6cd28dbec6702aecb
3
+ size 3317588
checkpoints/quality_knn_30b_budget_400.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb2d64cca2cf69d56d72e9d1c700fbc3247d04831df2f1a154212f3cc98202fd
3
+ size 3317588
checkpoints/quality_knn_30b_budget_800.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:136bea24e03a67f7580e21e0c54306abeb883aab4ef804bdf9b1be238ad3bdbd
3
+ size 3317588
checkpoints/quality_knn_30b_concise.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f7e69c75c0aeb15ae58298e34c42a834c1e69cce9b844ab25c38787384490550
3
+ size 3317588
checkpoints/quality_knn_80b_budget_200.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a6f86c6981e31a9057ee7a2112db13eaa34c459471dfffb52f8aa9a3bdcce29e
3
+ size 3317588
checkpoints/quality_knn_80b_budget_400.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:38cc8af776d03877dd0a513437e6637d9dd5a46b3a26ce47918548744de0f2e9
3
+ size 3317588
checkpoints/quality_knn_80b_budget_800.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a6702cd8020a07189ab5943c20061776941bc3eae3611920a0d07500cf57daea
3
+ size 3317588
checkpoints/quality_knn_80b_concise.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:002f5021ea92a5d846244ba1ec76a03d773c93174281c789284bd63aecd54348
3
+ size 3317588
checkpoints/quality_knn_coder-next_budget_200.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c3cea7b7f12e959e401319642d4aac688fb1a96f76f591466e5cfe4ad09b30ad
3
+ size 3317588
checkpoints/quality_knn_coder-next_budget_400.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:528b2e6e7ab9bde25d4d6639dd2b359d9e750a73ddb8a1695c496756e79cd8d3
3
+ size 3317588
checkpoints/quality_knn_coder-next_budget_800.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a4e599bf2490595a222792037329e97ef7630d917db67f05b4d3bfae17e60be
3
+ size 3317588
checkpoints/quality_knn_coder-next_concise.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4c2df7aabc3fee5ae12a05b3296eafd6f54053c6b1bf6a3cf5e6cec4cf4574c
3
+ size 3317588
checkpoints/quality_knn_gemini-flash_concise.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3ffefebf08d5de8c2784ee3074a38e470fe307085bc2f8a62f463c855ab0e9c6
3
+ size 3317588
checkpoints/quality_knn_haiku_concise.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22633da4eecf1d0b1b3da5fc1bc2e20889ec1271816b949909863ce6c5a624b0
3
+ size 3317588
checkpoints/token_knn_235b.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d8e62844ac6f97e8d4cdf7d85af9a801c3850032600634a41378bd5ae436900
3
+ size 3317588
checkpoints/token_knn_30b.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f1a7972693a4b5727324c401dbe8beefa187be05e07fefdc0ff2f490ee256b5
3
+ size 3317588
checkpoints/token_knn_80b.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aea91620048b766624696d0dfe2ad0998fceb854183aeb011a7c41d723b2e217
3
+ size 3317588
checkpoints/token_knn_coder-next.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8374d3e3159cf9583c3a4dbb71fd7c3c2f69664bd49cf40e05851922d3e65f4a
3
+ size 3317588
checkpoints/token_knn_gemini-flash.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d3bd50eb44333e5e7e8ff5cbcec995b673ac54224a3a86fe8dffdf8d4e1d129e
3
+ size 3317588
checkpoints/token_knn_haiku.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:89578d06d0aa2640b720f0599bb91d66e4a52b39f475475db1aa33a833f53178
3
+ size 3317588
config.json ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "router_name": "R2-Router",
3
+ "method": "Global KNN (cosine, distance-weighted)",
4
+ "embedding_model": "Qwen/Qwen3-0.6B",
5
+ "embedding_dim": 1024,
6
+ "knn_k": 80,
7
+ "lambda": 0.999,
8
+ "models": {
9
+ "235b": {
10
+ "full_name": "qwen/qwen3-235b-a22b-2507",
11
+ "output_price_per_million": 0.463,
12
+ "input_price_per_million": 0.071,
13
+ "type": "vllm"
14
+ },
15
+ "80b": {
16
+ "full_name": "qwen/qwen3-next-80b-a3b-instruct",
17
+ "output_price_per_million": 1.1,
18
+ "input_price_per_million": 0.09,
19
+ "type": "vllm"
20
+ },
21
+ "30b": {
22
+ "full_name": "qwen/qwen3-30b-a3b-instruct-2507",
23
+ "output_price_per_million": 0.33,
24
+ "input_price_per_million": 0.08,
25
+ "type": "vllm"
26
+ },
27
+ "coder-next": {
28
+ "full_name": "Qwen/Qwen3-Coder-Next",
29
+ "output_price_per_million": 0.3,
30
+ "input_price_per_million": 0.07,
31
+ "type": "vllm"
32
+ },
33
+ "gemini-flash": {
34
+ "full_name": "gemini-2.5-flash",
35
+ "output_price_per_million": 2.5,
36
+ "input_price_per_million": 0.3,
37
+ "type": "api"
38
+ },
39
+ "haiku": {
40
+ "full_name": "claude-3-haiku-20240307",
41
+ "output_price_per_million": 1.25,
42
+ "input_price_per_million": 0.25,
43
+ "type": "api"
44
+ }
45
+ },
46
+ "budgets": {
47
+ "concise": 100,
48
+ "budget_200": 200,
49
+ "budget_400": 400,
50
+ "budget_800": 800
51
+ },
52
+ "training_size": 809,
53
+ "training_source": "RouterArena sub_10 (10% official split)"
54
+ }
router.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ R2-Router: LLM Router with Joint Model-Budget Optimization
3
+
4
+ Self-contained inference module. Routes queries to the optimal (model, token_budget)
5
+ pair by predicting per-query quality and cost using KNN.
6
+
7
+ Usage:
8
+ from router import R2Router
9
+ router = R2Router.from_pretrained("jqxue1999/r2-router")
10
+ result = router.route(embedding) # embedding: np.ndarray (1024,)
11
+
12
+ # Or train from scratch:
13
+ router = R2Router.from_training_data("jqxue1999/r2-router")
14
+ """
15
+
16
+ import os
17
+ import json
18
+ import numpy as np
19
+ import joblib
20
+ from typing import Dict, List, Optional, Union
21
+ from sklearn.neighbors import KNeighborsRegressor
22
+
23
+
24
+ class R2Router:
25
+ """
26
+ R2-Router: Routes queries to optimal (LLM, token_budget) pair.
27
+
28
+ Uses KNN to predict quality for each (model, budget) combination,
29
+ then selects the pair that maximizes:
30
+ risk = (1 - lambda) * quality - lambda * tokens * price / 1e6
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ quality_knns: Dict[str, Dict[str, KNeighborsRegressor]],
36
+ token_knns: Dict[str, KNeighborsRegressor],
37
+ model_prices: Dict[str, float],
38
+ model_names: Dict[str, str],
39
+ budgets: Dict[str, int],
40
+ lambda_val: float = 0.999,
41
+ ):
42
+ self.quality_knns = quality_knns # {model: {budget: KNN}}
43
+ self.token_knns = token_knns # {model: KNN}
44
+ self.model_prices = model_prices # {model: price_per_million_output_tokens}
45
+ self.model_names = model_names # {short_name: full_name}
46
+ self.budgets = budgets # {budget_name: token_limit}
47
+ self.lambda_val = lambda_val
48
+
49
+ @classmethod
50
+ def from_pretrained(cls, path: str, lambda_val: float = 0.999) -> "R2Router":
51
+ """
52
+ Load pre-trained KNN checkpoints.
53
+
54
+ Args:
55
+ path: Local directory or HuggingFace repo ID (e.g., "jqxue1999/r2-router")
56
+ lambda_val: Cost-accuracy tradeoff (higher = more cost-sensitive)
57
+ """
58
+ # If HF repo ID, download first
59
+ if not os.path.isdir(path):
60
+ path = cls._download_from_hf(path)
61
+
62
+ with open(os.path.join(path, "config.json")) as f:
63
+ config = json.load(f)
64
+
65
+ ckpt_dir = os.path.join(path, "checkpoints")
66
+ quality_knns = {}
67
+ token_knns = {}
68
+
69
+ for model_name in config["models"]:
70
+ quality_knns[model_name] = {}
71
+ for budget_name in config["budgets"]:
72
+ ckpt_path = os.path.join(ckpt_dir, f"quality_knn_{model_name}_{budget_name}.joblib")
73
+ if os.path.exists(ckpt_path):
74
+ quality_knns[model_name][budget_name] = joblib.load(ckpt_path)
75
+
76
+ tok_path = os.path.join(ckpt_dir, f"token_knn_{model_name}.joblib")
77
+ if os.path.exists(tok_path):
78
+ token_knns[model_name] = joblib.load(tok_path)
79
+
80
+ model_prices = {
81
+ mn: cfg["output_price_per_million"]
82
+ for mn, cfg in config["models"].items()
83
+ }
84
+ model_names = {
85
+ mn: cfg["full_name"]
86
+ for mn, cfg in config["models"].items()
87
+ }
88
+
89
+ return cls(
90
+ quality_knns=quality_knns,
91
+ token_knns=token_knns,
92
+ model_prices=model_prices,
93
+ model_names=model_names,
94
+ budgets=config["budgets"],
95
+ lambda_val=lambda_val,
96
+ )
97
+
98
+ @classmethod
99
+ def from_training_data(
100
+ cls,
101
+ path: str,
102
+ k: int = 80,
103
+ lambda_val: float = 0.999,
104
+ ) -> "R2Router":
105
+ """
106
+ Train KNN from scratch using the provided training data.
107
+
108
+ Args:
109
+ path: Local directory or HuggingFace repo ID
110
+ k: Number of KNN neighbors
111
+ lambda_val: Cost-accuracy tradeoff
112
+ """
113
+ if not os.path.isdir(path):
114
+ path = cls._download_from_hf(path)
115
+
116
+ with open(os.path.join(path, "config.json")) as f:
117
+ config = json.load(f)
118
+
119
+ X_train = np.load(os.path.join(path, "training_data", "embeddings.npy"))
120
+ with open(os.path.join(path, "training_data", "labels.json")) as f:
121
+ labels = json.load(f)
122
+
123
+ quality_knns = {}
124
+ token_knns = {}
125
+
126
+ for model_name, model_labels in labels.items():
127
+ quality_knns[model_name] = {}
128
+ for budget_name, bdata in model_labels.items():
129
+ acc = np.array([x if x is not None else np.nan for x in bdata["accuracy"]])
130
+ valid = ~np.isnan(acc)
131
+ if valid.sum() < 3:
132
+ continue
133
+ knn = KNeighborsRegressor(
134
+ n_neighbors=min(k, int(valid.sum()) - 1),
135
+ metric="cosine",
136
+ weights="distance",
137
+ )
138
+ knn.fit(X_train[valid], acc[valid])
139
+ quality_knns[model_name][budget_name] = knn
140
+
141
+ # Token predictor (use concise budget's output_tokens)
142
+ if "concise" in model_labels and "output_tokens" in model_labels["concise"]:
143
+ tok = np.array([x if x is not None else np.nan for x in model_labels["concise"]["output_tokens"]])
144
+ valid = ~np.isnan(tok)
145
+ if valid.sum() >= 3:
146
+ tknn = KNeighborsRegressor(
147
+ n_neighbors=min(k, int(valid.sum()) - 1),
148
+ metric="cosine",
149
+ weights="distance",
150
+ )
151
+ tknn.fit(X_train[valid], tok[valid])
152
+ token_knns[model_name] = tknn
153
+
154
+ model_prices = {
155
+ mn: cfg["output_price_per_million"]
156
+ for mn, cfg in config["models"].items()
157
+ }
158
+ model_names = {
159
+ mn: cfg["full_name"]
160
+ for mn, cfg in config["models"].items()
161
+ }
162
+
163
+ return cls(
164
+ quality_knns=quality_knns,
165
+ token_knns=token_knns,
166
+ model_prices=model_prices,
167
+ model_names=model_names,
168
+ budgets=config["budgets"],
169
+ lambda_val=lambda_val,
170
+ )
171
+
172
+ @staticmethod
173
+ def _download_from_hf(repo_id: str) -> str:
174
+ """Download model from Hugging Face Hub."""
175
+ try:
176
+ from huggingface_hub import snapshot_download
177
+ return snapshot_download(repo_id)
178
+ except ImportError:
179
+ raise ImportError(
180
+ "huggingface_hub is required to download from HF. "
181
+ "Install with: pip install huggingface_hub"
182
+ )
183
+
184
+ def route(
185
+ self,
186
+ embedding: np.ndarray,
187
+ lambda_val: Optional[float] = None,
188
+ ) -> Dict:
189
+ """
190
+ Route a query to the optimal (model, token_budget) pair.
191
+
192
+ Args:
193
+ embedding: Query embedding vector, shape (1024,) or (1, 1024)
194
+ lambda_val: Override default lambda (higher = more cost-sensitive)
195
+
196
+ Returns:
197
+ Dict with keys: model, model_full_name, budget, token_limit,
198
+ predicted_quality, predicted_cost, risk, all_options
199
+ """
200
+ if embedding.ndim == 1:
201
+ embedding = embedding.reshape(1, -1)
202
+
203
+ lam = lambda_val if lambda_val is not None else self.lambda_val
204
+ all_options = []
205
+
206
+ for mn in self.quality_knns:
207
+ price = self.model_prices.get(mn, 0)
208
+
209
+ # Predict output tokens
210
+ if mn in self.token_knns:
211
+ tok = max(1.0, float(self.token_knns[mn].predict(embedding)[0]))
212
+ else:
213
+ tok = 50.0
214
+
215
+ for budget_name, knn in self.quality_knns[mn].items():
216
+ q = float(knn.predict(embedding)[0])
217
+ risk = (1 - lam) * q - lam * tok * price / 1e6
218
+
219
+ all_options.append({
220
+ "model": mn,
221
+ "model_full_name": self.model_names.get(mn, mn),
222
+ "budget": budget_name,
223
+ "token_limit": self.budgets.get(budget_name, budget_name),
224
+ "predicted_quality": q,
225
+ "predicted_tokens": tok,
226
+ "predicted_cost": tok * price / 1e6,
227
+ "risk": risk,
228
+ })
229
+
230
+ if not all_options:
231
+ raise RuntimeError("No valid routing options")
232
+
233
+ best = max(all_options, key=lambda x: x["risk"])
234
+ best["all_options"] = all_options
235
+ return best
236
+
237
+ def route_batch(
238
+ self,
239
+ embeddings: np.ndarray,
240
+ lambda_val: Optional[float] = None,
241
+ ) -> List[Dict]:
242
+ """
243
+ Route a batch of queries.
244
+
245
+ Args:
246
+ embeddings: Query embeddings, shape (N, 1024)
247
+ lambda_val: Override default lambda
248
+
249
+ Returns:
250
+ List of routing decisions
251
+ """
252
+ return [self.route(embeddings[i], lambda_val) for i in range(len(embeddings))]
training_data/embeddings.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:24f0f5116dc1d153739d6a30604dc4a6553bbd9d7d4708a1d14fa5b00041bd6c
3
+ size 3313792
training_data/labels.json ADDED
The diff for this file is too large to render. See raw diff