"""
R2-Router: LLM router with joint model-budget optimization.

Self-contained inference module. Routes each query to the optimal
(model, token_budget) pair by predicting per-query quality and output-token
cost with KNN regressors over query embeddings.

Usage:
    from router import R2Router
    router = R2Router.from_pretrained("jqxue1999/r2-router")
    result = router.route(embedding)  # embedding: np.ndarray of shape (1024,)

    # Or train the KNNs from scratch:
    router = R2Router.from_training_data("jqxue1999/r2-router")
"""
|
|
import json
import os
from typing import Any, Dict, List, Optional, Tuple, Union

import joblib
import numpy as np
from sklearn.neighbors import KNeighborsRegressor
|
|
|
|
class R2Router:
    """
    R2-Router: Routes queries to the optimal (LLM, token_budget) pair.

    For every (model, budget) combination a KNN regressor predicts the answer
    quality of a query from its embedding, and a per-model KNN regressor
    predicts the number of output tokens. The router then selects the pair
    that maximizes

        risk = (1 - lambda) * quality - lambda * tokens * price / 1e6

    where ``price`` is the model's output price per million tokens. Despite
    its name, ``risk`` is a utility score: the option with the HIGHEST value
    is chosen.
    """

    # Output-token estimate used for a model that has no token-count KNN.
    DEFAULT_TOKENS: float = 50.0

    def __init__(
        self,
        quality_knns: Dict[str, Dict[str, KNeighborsRegressor]],
        token_knns: Dict[str, KNeighborsRegressor],
        model_prices: Dict[str, float],
        model_names: Dict[str, str],
        budgets: Dict[str, int],
        lambda_val: float = 0.999,
    ):
        """
        Args:
            quality_knns: model key -> budget key -> fitted quality regressor.
            token_knns: model key -> fitted output-token-count regressor.
                Models missing here fall back to ``DEFAULT_TOKENS``.
            model_prices: model key -> output price per million tokens.
                Models missing here are treated as free (price 0).
            model_names: model key -> full model name reported in results.
            budgets: budget key -> token limit reported in results.
            lambda_val: Default cost-accuracy tradeoff in ``route`` /
                ``route_batch`` (higher = more cost-sensitive).
        """
        self.quality_knns = quality_knns
        self.token_knns = token_knns
        self.model_prices = model_prices
        self.model_names = model_names
        self.budgets = budgets
        self.lambda_val = lambda_val

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------

    @classmethod
    def from_pretrained(cls, path: str, lambda_val: float = 0.999) -> "R2Router":
        """
        Load pre-trained KNN checkpoints.

        Expects ``config.json`` and a ``checkpoints/`` directory containing
        ``quality_knn_<model>_<budget>.joblib`` and ``token_knn_<model>.joblib``
        files. Missing checkpoint files are skipped silently.

        Args:
            path: Local directory or HuggingFace repo ID (e.g., "jqxue1999/r2-router")
            lambda_val: Cost-accuracy tradeoff (higher = more cost-sensitive)
        """
        path = cls._resolve_path(path)
        config = cls._load_config(path)
        ckpt_dir = os.path.join(path, "checkpoints")

        quality_knns: Dict[str, Dict[str, KNeighborsRegressor]] = {}
        token_knns: Dict[str, KNeighborsRegressor] = {}

        # SECURITY: joblib.load unpickles arbitrary objects and can execute
        # code. Only load checkpoints from a trusted directory / HF repo.
        for model_name in config["models"]:
            quality_knns[model_name] = {}
            for budget_name in config["budgets"]:
                ckpt_path = os.path.join(
                    ckpt_dir, f"quality_knn_{model_name}_{budget_name}.joblib"
                )
                if os.path.exists(ckpt_path):
                    quality_knns[model_name][budget_name] = joblib.load(ckpt_path)

            tok_path = os.path.join(ckpt_dir, f"token_knn_{model_name}.joblib")
            if os.path.exists(tok_path):
                token_knns[model_name] = joblib.load(tok_path)

        model_prices, model_names = cls._model_metadata(config)
        return cls(
            quality_knns=quality_knns,
            token_knns=token_knns,
            model_prices=model_prices,
            model_names=model_names,
            budgets=config["budgets"],
            lambda_val=lambda_val,
        )

    @classmethod
    def from_training_data(
        cls,
        path: str,
        k: int = 80,
        lambda_val: float = 0.999,
        token_budget: str = "concise",
    ) -> "R2Router":
        """
        Train KNN from scratch using the provided training data.

        Reads ``training_data/embeddings.npy`` (one row per query) and
        ``training_data/labels.json`` shaped as
        ``{model: {budget: {"accuracy": [...], "output_tokens": [...]}}}``,
        where each list is aligned row-for-row with the embeddings and may
        contain ``null`` for missing labels.

        Args:
            path: Local directory or HuggingFace repo ID
            k: Number of KNN neighbors (capped by the number of labelled rows)
            lambda_val: Cost-accuracy tradeoff
            token_budget: Budget key whose ``output_tokens`` labels train the
                per-model token-count KNN.

        Raises:
            ValueError: if a label list's length differs from the number of
                embedding rows.
        """
        path = cls._resolve_path(path)
        config = cls._load_config(path)

        data_dir = os.path.join(path, "training_data")
        X_train = np.load(os.path.join(data_dir, "embeddings.npy"))
        with open(os.path.join(data_dir, "labels.json"), encoding="utf-8") as f:
            labels = json.load(f)

        quality_knns: Dict[str, Dict[str, KNeighborsRegressor]] = {}
        token_knns: Dict[str, KNeighborsRegressor] = {}

        for model_name, model_labels in labels.items():
            quality_knns[model_name] = {}
            for budget_name, bdata in model_labels.items():
                knn = cls._fit_knn(X_train, bdata["accuracy"], k)
                if knn is not None:
                    quality_knns[model_name][budget_name] = knn

            # Token counts are learned from a single budget's outputs.
            tok_labels = model_labels.get(token_budget, {})
            if "output_tokens" in tok_labels:
                tknn = cls._fit_knn(X_train, tok_labels["output_tokens"], k)
                if tknn is not None:
                    token_knns[model_name] = tknn

        model_prices, model_names = cls._model_metadata(config)
        return cls(
            quality_knns=quality_knns,
            token_knns=token_knns,
            model_prices=model_prices,
            model_names=model_names,
            budgets=config["budgets"],
            lambda_val=lambda_val,
        )

    # ------------------------------------------------------------------
    # Private loading helpers
    # ------------------------------------------------------------------

    @classmethod
    def _resolve_path(cls, path: str) -> str:
        """Return ``path`` if it is a local directory, else download it from the HF Hub."""
        return path if os.path.isdir(path) else cls._download_from_hf(path)

    @staticmethod
    def _load_config(path: str) -> Dict[str, Any]:
        """Read and parse ``config.json`` from a router directory."""
        with open(os.path.join(path, "config.json"), encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def _model_metadata(
        config: Dict[str, Any],
    ) -> Tuple[Dict[str, float], Dict[str, str]]:
        """Extract (output prices per million tokens, full names) keyed by model key."""
        models = config["models"]
        prices = {mn: cfg["output_price_per_million"] for mn, cfg in models.items()}
        names = {mn: cfg["full_name"] for mn, cfg in models.items()}
        return prices, names

    @staticmethod
    def _fit_knn(
        X: np.ndarray,
        values: List[Optional[float]],
        k: int,
    ) -> Optional[KNeighborsRegressor]:
        """
        Fit a distance-weighted cosine KNN on the rows whose label is present.

        ``None`` / NaN labels are dropped. Returns ``None`` when fewer than 3
        labelled rows remain.
        """
        y = np.array([np.nan if v is None else v for v in values], dtype=float)
        if len(y) != len(X):
            raise ValueError(
                f"label count ({len(y)}) does not match embedding count ({len(X)})"
            )
        valid = ~np.isnan(y)
        n_valid = int(valid.sum())
        if n_valid < 3:
            return None
        knn = KNeighborsRegressor(
            # Neighbour count is kept strictly below the labelled-row count.
            n_neighbors=min(k, n_valid - 1),
            metric="cosine",
            weights="distance",
        )
        knn.fit(X[valid], y[valid])
        return knn

    @staticmethod
    def _download_from_hf(repo_id: str) -> str:
        """Download a repository snapshot from the Hugging Face Hub; return its local path."""
        try:
            from huggingface_hub import snapshot_download
        except ImportError as err:
            raise ImportError(
                "huggingface_hub is required to download from HF. "
                "Install with: pip install huggingface_hub"
            ) from err
        return snapshot_download(repo_id)

    # ------------------------------------------------------------------
    # Routing
    # ------------------------------------------------------------------

    def _score_options(self, X: np.ndarray, lam: float) -> List[List[Dict[str, Any]]]:
        """
        Score every (model, budget) option for each row of ``X`` (shape (N, D)).

        Returns one list of option dicts per row, in model-then-budget order.
        Each regressor is queried once for the whole batch.
        """
        n_rows = X.shape[0]
        options: List[List[Dict[str, Any]]] = [[] for _ in range(n_rows)]

        for mn, budget_knns in self.quality_knns.items():
            if not budget_knns:
                continue
            price = self.model_prices.get(mn, 0)
            full_name = self.model_names.get(mn, mn)

            if mn in self.token_knns:
                # Clamp to at least one token per query.
                toks = [max(1.0, float(t)) for t in self.token_knns[mn].predict(X)]
            else:
                toks = [self.DEFAULT_TOKENS] * n_rows

            # NOTE(review): the same token estimate is applied to every budget
            # of this model, so the cost term does not vary with the budget.
            for budget_name, knn in budget_knns.items():
                token_limit = self.budgets.get(budget_name, budget_name)
                for row, (q_raw, tok) in enumerate(zip(knn.predict(X), toks)):
                    q = float(q_raw)
                    options[row].append({
                        "model": mn,
                        "model_full_name": full_name,
                        "budget": budget_name,
                        "token_limit": token_limit,
                        "predicted_quality": q,
                        "predicted_tokens": tok,
                        "predicted_cost": tok * price / 1e6,
                        "risk": (1 - lam) * q - lam * tok * price / 1e6,
                    })
        return options

    @staticmethod
    def _select_best(options: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Return a copy of the highest-``risk`` option with ``all_options`` attached."""
        if not options:
            raise RuntimeError("No valid routing options")
        # Copy so the result does not contain itself via ``all_options``
        # (a self-referential dict breaks json.dumps and repr).
        best = dict(max(options, key=lambda opt: opt["risk"]))
        best["all_options"] = options
        return best

    def route(
        self,
        embedding: np.ndarray,
        lambda_val: Optional[float] = None,
    ) -> Dict[str, Any]:
        """
        Route a query to the optimal (model, token_budget) pair.

        Args:
            embedding: Query embedding, shape (D,) or (1, D); array-likes are
                accepted. D must match the training embeddings.
            lambda_val: Override default lambda (higher = more cost-sensitive)

        Returns:
            Dict with keys: model, model_full_name, budget, token_limit,
            predicted_quality, predicted_tokens, predicted_cost, risk,
            all_options (every scored option, in model-then-budget order).

        Raises:
            ValueError: if ``embedding`` is not a single query vector.
            RuntimeError: if no (model, budget) option is available.
        """
        X = np.asarray(embedding)
        if X.ndim == 1:
            X = X.reshape(1, -1)
        if X.ndim != 2 or X.shape[0] != 1:
            raise ValueError(
                f"route expects a single embedding of shape (D,) or (1, D), got "
                f"{X.shape}; use route_batch for multiple queries"
            )
        lam = self.lambda_val if lambda_val is None else lambda_val
        return self._select_best(self._score_options(X, lam)[0])

    def route_batch(
        self,
        embeddings: np.ndarray,
        lambda_val: Optional[float] = None,
    ) -> List[Dict[str, Any]]:
        """
        Route a batch of queries.

        Args:
            embeddings: Query embeddings, shape (N, D); array-likes are accepted.
            lambda_val: Override default lambda

        Returns:
            List of routing decisions (same format as ``route``), one per row.

        Raises:
            ValueError: if ``embeddings`` is not 2-D.
            RuntimeError: if no (model, budget) option is available.
        """
        if len(embeddings) == 0:
            return []
        X = np.asarray(embeddings)
        if X.ndim != 2:
            raise ValueError(f"route_batch expects shape (N, D), got {X.shape}")
        lam = self.lambda_val if lambda_val is None else lambda_val
        return [self._select_best(opts) for opts in self._score_options(X, lam)]
|
|