Spaces:
Sleeping
Sleeping
Initial deploy: classifier + FastAPI router
Browse files- .gitignore +14 -0
- Dockerfile +28 -0
- README.md +29 -5
- app.py +164 -0
- greenrouting/__init__.py +1 -0
- greenrouting/classifier/__init__.py +0 -0
- greenrouting/classifier/calibration.py +41 -0
- greenrouting/classifier/infer.py +205 -0
- greenrouting/classifier/model.py +94 -0
- greenrouting/classifier/ood.py +90 -0
- greenrouting/classifier/train.py +269 -0
- greenrouting/classifier/trained_predictor.py +129 -0
- greenrouting/data/__init__.py +0 -0
- greenrouting/data/builder.py +260 -0
- greenrouting/data/capability_labeler.py +237 -0
- greenrouting/data/cascade.py +292 -0
- greenrouting/data/graders.py +158 -0
- greenrouting/data/schema.py +80 -0
- greenrouting/data/seed_dataset.py +545 -0
- greenrouting/data/sources.py +343 -0
- greenrouting/demo/__init__.py +0 -0
- greenrouting/demo/app.py +215 -0
- greenrouting/energy/__init__.py +0 -0
- greenrouting/energy/estimator.py +19 -0
- greenrouting/routing/__init__.py +0 -0
- greenrouting/routing/decision.py +191 -0
- greenrouting/routing/registry.py +440 -0
- greenrouting/routing/scorer.py +93 -0
- mapper.py +175 -0
- models/classifier_v1/calibration.json +3 -0
- models/classifier_v1/encoder_name.txt +1 -0
- models/classifier_v1/head.pt +3 -0
- models/classifier_v1/metadata.json +21 -0
- models/classifier_v1/ood_stats.npz +3 -0
- models/classifier_v1/training_history.json +182 -0
- partner_registry.py +115 -0
- requirements.txt +6 -0
.gitignore
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.py[codz]
|
| 3 |
+
*.egg-info/
|
| 4 |
+
.venv/
|
| 5 |
+
venv/
|
| 6 |
+
.env
|
| 7 |
+
|
| 8 |
+
# Partner config: never commit
|
| 9 |
+
data/partner_registry.json
|
| 10 |
+
data/*.json
|
| 11 |
+
|
| 12 |
+
# IDE
|
| 13 |
+
.idea/
|
| 14 |
+
.vscode/
|
Dockerfile
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim

# Caches live under /tmp so the image runs with a read-only root filesystem
# (HF Spaces) and an arbitrary non-root UID.
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    PIP_NO_CACHE_DIR=1 \
    HF_HOME=/tmp/hf_cache \
    TRANSFORMERS_CACHE=/tmp/hf_cache/transformers \
    SENTENCE_TRANSFORMERS_HOME=/tmp/hf_cache/sentence-transformers

WORKDIR /app

# build-essential: some requirements compile native extensions at install time.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

COPY requirements.txt .
# CPU-only torch wheel is installed first so requirements.txt doesn't pull the
# much larger CUDA build.
RUN pip install --upgrade pip && \
    pip install torch --index-url https://download.pytorch.org/whl/cpu && \
    pip install -r requirements.txt

# Code and model artifacts are copied after dependency install to keep the
# pip layer cached across code-only rebuilds.
COPY greenrouting /app/greenrouting
COPY models /app/models
COPY partner_registry.py mapper.py app.py /app/

# World-writable cache dir: the Spaces runtime user is not root.
RUN mkdir -p /tmp/hf_cache && chmod -R 777 /tmp/hf_cache

EXPOSE 7860
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
|
@@ -1,10 +1,34 @@
|
|
| 1 |
---
|
| 2 |
-
title: Router
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Router Classify API
|
| 3 |
+
emoji: 🛰️
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: indigo
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
license: other
|
| 9 |
+
license_name: polyform-noncommercial-1.0.0
|
| 10 |
+
license_link: https://polyformproject.org/licenses/noncommercial/1.0.0
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# Router Classify API
|
| 14 |
+
|
| 15 |
+
REST endpoint that runs the GreenRouting classifier and returns a routing decision against a configurable downstream model registry.
|
| 16 |
+
|
| 17 |
+
## Endpoints
|
| 18 |
+
|
| 19 |
+
- `POST /classify` — Run the classifier and pick a model.
|
| 20 |
+
- Request: `{ "message": "...", "recentMessages": [{"role": "...", "content": "..."}] }`
|
| 21 |
+
- Response: `{ "category", "complexity", "model_id", "capability_weights", "difficulty", "energy_savings_pct", "method", "reason" }`
|
| 22 |
+
- `GET /health` — Liveness probe.
|
| 23 |
+
|
| 24 |
+
## Configuration
|
| 25 |
+
|
| 26 |
+
The registry of candidate models is supplied at runtime via a Space secret. Set one of:
|
| 27 |
+
|
| 28 |
+
- `PARTNER_REGISTRY_JSON` — the registry as raw JSON (preferred)
|
| 29 |
+
- `PARTNER_REGISTRY_PATH` — a file path inside the container
|
| 30 |
+
|
| 31 |
+
Other env vars:
|
| 32 |
+
|
| 33 |
+
- `CLASSIFIER_ARTIFACT_DIR` — defaults to `models/classifier_v1`
|
| 34 |
+
- `INCLUDE_REASON` — `1` (default) to include the `reason` string in responses, `0` to omit
|
app.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI service that wraps the GreenRouting classifier behind the partner-
|
| 2 |
+
specific response schema.
|
| 3 |
+
|
| 4 |
+
Endpoints:
|
| 5 |
+
POST /classify - classify a query and pick a model from the partner registry
|
| 6 |
+
GET /health - liveness probe used by the partner edge function
|
| 7 |
+
|
| 8 |
+
Auth: none. Stateless. CORS open. Single-process. Designed for a HF Spaces
|
| 9 |
+
Docker deployment with periodic /health pings keeping the container warm.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import logging
|
| 15 |
+
import os
|
| 16 |
+
import time
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Optional
|
| 19 |
+
|
| 20 |
+
from fastapi import FastAPI, HTTPException
|
| 21 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 22 |
+
from pydantic import BaseModel, Field
|
| 23 |
+
|
| 24 |
+
from greenrouting.classifier.trained_predictor import TrainedPredictor
|
| 25 |
+
|
| 26 |
+
from mapper import (
|
| 27 |
+
build_reason,
|
| 28 |
+
fold_recent_context,
|
| 29 |
+
energy_savings_pct,
|
| 30 |
+
pick_category,
|
| 31 |
+
pick_complexity,
|
| 32 |
+
pick_difficulty_int,
|
| 33 |
+
rebucket_capabilities,
|
| 34 |
+
select_model,
|
| 35 |
+
)
|
| 36 |
+
from partner_registry import load_registry
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
logger = logging.getLogger("router-api")
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")


# Directory containing the trained classifier artifacts (head.pt, calibration.json, ...).
ARTIFACT_DIR = os.environ.get("CLASSIFIER_ARTIFACT_DIR", "models/classifier_v1")
# Any value other than "0"/"false"/"False" keeps the human-readable reason in responses.
INCLUDE_REASON = os.environ.get("INCLUDE_REASON", "1") not in ("0", "false", "False")

app = FastAPI(title="GreenRouting Partner Router", version="0.1.0")
# CORS is deliberately wide open: the service is unauthenticated and stateless
# (see module docstring); credentials are disabled accordingly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
    max_age=3600,
)


# Lazily-initialized singletons, populated by _ensure_loaded().
_predictor: Optional[TrainedPredictor] = None
_registry = None
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class RecentMessage(BaseModel):
    """One prior chat turn supplied by the caller for context folding."""

    # presumably "user" / "assistant"; not validated here — TODO confirm allowed values
    role: str
    content: str
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class ClassifyRequest(BaseModel):
    """Request body for POST /classify."""

    # Current user message to classify; length-bounded to keep inference cheap.
    message: str = Field(min_length=1, max_length=8000)
    # Optional rolling window of prior turns, folded into the classifier input.
    recentMessages: Optional[list[RecentMessage]] = None
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class ClassifyResponse(BaseModel):
    """Routing decision returned to the partner edge function."""

    category: str
    complexity: str
    # Identifier of the model chosen from the partner registry.
    model_id: str
    capability_weights: dict[str, float]
    difficulty: int
    # None when the query is OOD or the pick escalated (savings not meaningful then).
    energy_savings_pct: Optional[float] = None
    # Always "greenrouting" in this service (set in classify()).
    method: str
    # Omitted (None) when INCLUDE_REASON is disabled.
    reason: Optional[str] = None
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _ensure_loaded() -> None:
    """Load the classifier and partner registry exactly once (lazy singletons).

    Components are built in locals and published to the module globals only
    after they fully initialize. The original assigned ``_predictor`` before
    the warm-up call, so a warm-up failure left a half-ready predictor that
    every subsequent request (and /health) would silently reuse.

    Raises:
        RuntimeError: when the trained classifier artifacts are missing.
    """
    global _predictor, _registry
    if _predictor is None:
        artifact_path = Path(ARTIFACT_DIR)
        if not (artifact_path / "head.pt").exists():
            raise RuntimeError(f"trained classifier not found at {artifact_path}")
        predictor = TrainedPredictor(artifact_path)
        # Warm up before publishing: the first predict pays the model-load cost.
        predictor.predict("warm up")
        _predictor = predictor
        logger.info("classifier loaded and warmed")
    if _registry is None:
        registry = load_registry()
        logger.info("partner registry loaded with %d models", len(registry))
        _registry = registry
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
@app.on_event("startup")
def _startup() -> None:
    """Eagerly warm the classifier at boot.

    Failure here is non-fatal by design: loading is retried lazily on the
    first real request, so a transient startup hiccup doesn't kill the app.
    """
    try:
        _ensure_loaded()
    except Exception as err:
        logger.warning("startup warm load failed: %s (will retry on first request)", err)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@app.get("/health")
|
| 106 |
+
def health() -> dict:
|
| 107 |
+
try:
|
| 108 |
+
_ensure_loaded()
|
| 109 |
+
return {"status": "ok"}
|
| 110 |
+
except Exception as exc:
|
| 111 |
+
logger.exception("health check failed")
|
| 112 |
+
raise HTTPException(status_code=503, detail=f"unhealthy: {exc}")
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
@app.post("/classify", response_model=ClassifyResponse)
|
| 116 |
+
def classify(req: ClassifyRequest) -> ClassifyResponse:
|
| 117 |
+
_ensure_loaded()
|
| 118 |
+
started = time.time()
|
| 119 |
+
|
| 120 |
+
folded = fold_recent_context(
|
| 121 |
+
req.message,
|
| 122 |
+
[m.dict() for m in req.recentMessages] if req.recentMessages else None,
|
| 123 |
+
)
|
| 124 |
+
profile = _predictor.predict(folded)
|
| 125 |
+
|
| 126 |
+
weights = rebucket_capabilities(profile)
|
| 127 |
+
category = pick_category(weights)
|
| 128 |
+
complexity = pick_complexity(profile)
|
| 129 |
+
difficulty = pick_difficulty_int(profile)
|
| 130 |
+
|
| 131 |
+
chosen, escalated = select_model(_registry, weights, difficulty, is_ood=profile.is_ood)
|
| 132 |
+
savings: Optional[float]
|
| 133 |
+
if profile.is_ood or escalated:
|
| 134 |
+
savings = None
|
| 135 |
+
else:
|
| 136 |
+
savings = round(energy_savings_pct(chosen), 1)
|
| 137 |
+
reason = (
|
| 138 |
+
build_reason(weights, complexity, chosen, escalated, is_ood=profile.is_ood)
|
| 139 |
+
if INCLUDE_REASON
|
| 140 |
+
else None
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
elapsed_ms = (time.time() - started) * 1000.0
|
| 144 |
+
logger.info(
|
| 145 |
+
"classify model=%s tier=%s difficulty=%d category=%s ood=%s escalated=%s elapsed_ms=%.1f",
|
| 146 |
+
chosen.id, chosen.tier, difficulty, category, profile.is_ood, escalated, elapsed_ms,
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
return ClassifyResponse(
|
| 150 |
+
category=category,
|
| 151 |
+
complexity=complexity,
|
| 152 |
+
model_id=chosen.id,
|
| 153 |
+
capability_weights=weights,
|
| 154 |
+
difficulty=difficulty,
|
| 155 |
+
energy_savings_pct=savings,
|
| 156 |
+
method="greenrouting",
|
| 157 |
+
reason=reason,
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
if __name__ == "__main__":
    # Local development entry point; in the container uvicorn is launched by
    # the Dockerfile CMD instead.
    import uvicorn
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=port, log_level="info")
|
greenrouting/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__version__ = "0.1.0"
|
greenrouting/classifier/__init__.py
ADDED
|
File without changes
|
greenrouting/classifier/calibration.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Temperature scaling for the multi-label capability head.
|
| 2 |
+
|
| 3 |
+
Fits a single positive scalar T such that BCE(logits / T, targets) is minimized
|
| 4 |
+
on a held-out set. T < 1 sharpens, T > 1 softens.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import math
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def fit_temperature(val_logits, val_targets, max_iter: int = 200) -> float:
    """Fit a temperature scalar minimizing BCE on held-out multi-label logits.

    Optimizes log(T) with LBFGS so the recovered T is positive by construction.

    Args:
        val_logits: array-like of raw logits, shape (n, n_labels).
        val_targets: array-like of binary targets, same shape.
        max_iter: LBFGS iteration budget.

    Returns:
        Fitted temperature T > 0, or 1.0 when the validation set is empty.
    """
    import numpy as np
    import torch
    import torch.nn as nn

    # Coerce first: the previous `.size` attribute access crashed on plain
    # Python lists even though the contract is array-like.
    logits_np = np.asarray(val_logits, dtype=np.float32)
    targets_np = np.asarray(val_targets, dtype=np.float32)
    if logits_np.size == 0:
        return 1.0

    logits = torch.tensor(logits_np, dtype=torch.float32)
    targets = torch.tensor(targets_np, dtype=torch.float32)
    # Parameterize T = exp(log_t): unconstrained optimization, positive result.
    log_t = torch.zeros((), dtype=torch.float32, requires_grad=True)
    optimizer = torch.optim.LBFGS([log_t], lr=0.1, max_iter=max_iter)
    bce = nn.BCEWithLogitsLoss()

    def closure():
        # LBFGS requires a closure that re-evaluates loss + gradients.
        optimizer.zero_grad()
        loss = bce(logits / log_t.exp(), targets)
        loss.backward()
        return loss

    optimizer.step(closure)
    return float(math.exp(log_t.item()))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def apply_temperature(logits, temperature: float):
    """Scale logits by a fitted temperature; non-positive T is a no-op."""
    import numpy as np

    scaled = np.asarray(logits)
    if temperature > 0:
        scaled = scaled / temperature
    return scaled
|
greenrouting/classifier/infer.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Inference-side types and predictors. The trained predictor lives in a sibling
|
| 2 |
+
module; this file defines the contract the router consumes."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
import re
|
| 8 |
+
from dataclasses import dataclass, asdict, field
|
| 9 |
+
from typing import Protocol
|
| 10 |
+
|
| 11 |
+
from greenrouting.routing.registry import CAPABILITY_KEYS
|
| 12 |
+
|
| 13 |
+
LENGTH_BUCKETS: tuple[str, str, str] = ("short", "medium", "long")
|
| 14 |
+
LENGTH_TOKEN_TARGETS: dict[str, int] = {"short": 60, "medium": 220, "long": 700}
|
| 15 |
+
LENGTH_P90_MULTIPLIER: float = 1.6
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
class CapabilityProfile:
    """Per-capability relevance scores (0..1) for a single query."""

    code: float = 0.0
    math: float = 0.0
    reasoning: float = 0.0
    knowledge: float = 0.0
    instruction: float = 0.0
    creative: float = 0.0
    multilingual: float = 0.0
    simple_chat: float = 0.0

    def as_dict(self) -> dict[str, float]:
        """Return the scores as a plain capability -> score mapping."""
        return asdict(self)

    def top(self, k: int = 3) -> list[tuple[str, float]]:
        """Return up to *k* strongest capabilities, strongest first.

        Entries at or below the 0.05 noise floor are dropped.
        """
        ranked = sorted(self.as_dict().items(), key=lambda pair: pair[1], reverse=True)
        return [(name, score) for name, score in ranked[:k] if score > 0.05]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass
class QueryProfile:
    """Full classifier output for one query, consumed by the router."""

    capabilities: CapabilityProfile
    # Natural log of the estimated parameter count needed to serve the query.
    difficulty_log_params: float
    # Probability mass over the length buckets ("short"/"medium"/"long").
    length_dist: dict[str, float]
    expected_input_tokens: int
    expected_output_tokens_p50: int
    expected_output_tokens_p90: int
    confidence: float
    is_ood: bool = False
    raw_query: str = ""
    debug: dict = field(default_factory=dict)

    @property
    def difficulty_params_b(self) -> float:
        # Convenience view: difficulty expressed in billions of parameters.
        return math.exp(self.difficulty_log_params) / 1e9
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class Predictor(Protocol):
    """Structural interface the router consumes: query text in, QueryProfile out."""

    def predict(self, query: str) -> QueryProfile: ...
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _tokens_from_text(text: str) -> int:
|
| 60 |
+
words = max(1, len(text.split()))
|
| 61 |
+
return int(words * 1.3) + 4
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _length_dist_for_target(bucket: str) -> dict[str, float]:
|
| 65 |
+
if bucket == "short":
|
| 66 |
+
return {"short": 0.75, "medium": 0.20, "long": 0.05}
|
| 67 |
+
if bucket == "medium":
|
| 68 |
+
return {"short": 0.15, "medium": 0.65, "long": 0.20}
|
| 69 |
+
return {"short": 0.05, "medium": 0.25, "long": 0.70}
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _expected_output_tokens(length_dist: dict[str, float]) -> tuple[int, int]:
    """Expected p50/p90 output-token counts for a length distribution."""
    expectation = 0.0
    for bucket in LENGTH_BUCKETS:
        expectation += length_dist[bucket] * LENGTH_TOKEN_TARGETS[bucket]
    # p90: widen by a fixed multiplier, plus extra headroom proportional to
    # the mass sitting on the "long" bucket.
    tail = length_dist.get("long", 0.0) * LENGTH_TOKEN_TARGETS["long"] * 0.3
    p90 = expectation * LENGTH_P90_MULTIPLIER + tail
    return int(round(expectation)), int(round(p90))
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
_KEYWORD_RULES: dict[str, tuple[float, list[str]]] = {
|
| 80 |
+
"code": (0.85, [
|
| 81 |
+
r"\b(code|function|class|def |algorithm|implement|debug|compile|stack trace|api|sdk)\b",
|
| 82 |
+
r"\b(python|javascript|typescript|rust|go|c\+\+|java|sql|html|css|react|kotlin|swift)\b",
|
| 83 |
+
r"\b(refactor|unit test|regex|linter)\b",
|
| 84 |
+
]),
|
| 85 |
+
"math": (0.80, [
|
| 86 |
+
r"\b(calculate|compute|solve|equation|integral|derivative|matrix|vector|probability|theorem)\b",
|
| 87 |
+
r"\b(sum|product|mean|median|variance|standard deviation|percentage)\b",
|
| 88 |
+
r"\d+\s*[+\-*/×÷=]\s*\d+",
|
| 89 |
+
]),
|
| 90 |
+
"reasoning": (0.70, [
|
| 91 |
+
r"\b(why|how does|explain|reason|because|therefore|thus|argue|justify|implication)\b",
|
| 92 |
+
r"\b(compare|contrast|analyze|evaluate|trade-?off|implication)\b",
|
| 93 |
+
]),
|
| 94 |
+
"knowledge": (0.65, [
|
| 95 |
+
r"\b(who|what is|when did|where is|history|definition|capital|population|founded)\b",
|
| 96 |
+
]),
|
| 97 |
+
"instruction": (0.60, [
|
| 98 |
+
r"\b(write|draft|create|generate|produce|format|list|outline|step.?by.?step)\b",
|
| 99 |
+
]),
|
| 100 |
+
"creative": (0.75, [
|
| 101 |
+
r"\b(story|poem|novel|character|plot|scene|metaphor|fictional)\b",
|
| 102 |
+
r"\b(write a (?:short )?(?:story|poem|haiku|song|essay))\b",
|
| 103 |
+
]),
|
| 104 |
+
"multilingual": (0.85, [
|
| 105 |
+
r"\b(translate|translation|en español|en français|auf deutsch|на русском|中文|日本語|한국어)\b",
|
| 106 |
+
r"[Ѐ-ӿ一-鿿-ゟ゠-ヿ]",
|
| 107 |
+
]),
|
| 108 |
+
"simple_chat": (0.70, [
|
| 109 |
+
r"^\s*(hi|hello|hey|thanks|thank you|good morning|good evening|sup|yo)\b",
|
| 110 |
+
r"^\s*\S{1,40}\?\s*$",
|
| 111 |
+
]),
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class MockPredictor:
    """Heuristic predictor used to drive the demo before a trained checkpoint exists.

    The interface and output shape match the trained predictor that replaces it later.
    """

    def __init__(self, default_difficulty_log_params: float = math.log(8e9)) -> None:
        # NOTE(review): stored but never read by any method in this class —
        # _difficulty() uses its own math.log(7e9) base; confirm intent.
        self.default_difficulty = default_difficulty_log_params

    def predict(self, query: str) -> QueryProfile:
        """Score the query against keyword rules and assemble a QueryProfile."""
        q = query.strip()
        scores = {k: 0.0 for k in CAPABILITY_KEYS}
        # Each matching pattern raises the capability to at least its rule weight.
        for cap, (weight, patterns) in _KEYWORD_RULES.items():
            for pat in patterns:
                if re.search(pat, q, flags=re.IGNORECASE | re.MULTILINE):
                    scores[cap] = max(scores[cap], weight)

        # No rule fired: fall back to a low-confidence chat/instruction mix.
        if not any(v > 0 for v in scores.values()):
            scores["simple_chat"] = 0.55
            scores["instruction"] = 0.30

        length_bucket = self._length_bucket(q, scores)
        length_dist = _length_dist_for_target(length_bucket)
        difficulty = self._difficulty(q, scores)
        # Confidence is simply the strongest capability score.
        confidence = max(scores.values())
        in_tokens = _tokens_from_text(q)
        out_p50, out_p90 = _expected_output_tokens(length_dist)
        is_ood = self._ood(q)

        return QueryProfile(
            capabilities=CapabilityProfile(**scores),
            difficulty_log_params=difficulty,
            length_dist=length_dist,
            expected_input_tokens=in_tokens,
            expected_output_tokens_p50=out_p50,
            expected_output_tokens_p90=out_p90,
            confidence=confidence,
            is_ood=is_ood,
            raw_query=q,
            debug={"source": "mock", "length_bucket": length_bucket},
        )

    @staticmethod
    def _length_bucket(query: str, scores: dict[str, float]) -> str:
        """Pick an expected-response length bucket; capability signals win over raw length."""
        if scores.get("simple_chat", 0) > 0.5:
            return "short"
        if scores.get("creative", 0) > 0.5 or scores.get("code", 0) > 0.5:
            return "long"
        if len(query) < 80:
            return "short"
        if len(query) < 240:
            return "medium"
        return "long"

    @staticmethod
    def _difficulty(query: str, scores: dict[str, float]) -> float:
        """Log-parameter difficulty: a ~7B base bumped up/down in log space."""
        base = math.log(7e9)
        bumps = 0.0
        # Multiplicative bumps (added in log space) for hard-looking signals.
        if scores.get("math", 0) > 0.5 and re.search(r"\b(prove|theorem|integral|differential)\b", query, re.IGNORECASE):
            bumps += math.log(10)
        if scores.get("reasoning", 0) > 0.5 and len(query) > 200:
            bumps += math.log(5)
        if scores.get("code", 0) > 0.5 and re.search(r"\b(distributed|concurrency|kernel|cuda|optimize)\b", query, re.IGNORECASE):
            bumps += math.log(8)
        if scores.get("simple_chat", 0) > 0.5:
            bumps -= math.log(3)
        return base + bumps

    @staticmethod
    def _ood(query: str) -> bool:
        """Cheap gibberish/garbage detector for the mock path."""
        q = query.strip()
        if len(q) < 2:
            return True
        # Mostly non-letter alphanumerics (e.g. digit soup) looks OOD.
        alnum = sum(c.isalnum() for c in q)
        if alnum and (sum(c.isalpha() for c in q) / max(alnum, 1)) < 0.3:
            return True
        # Pure punctuation/digits/underscores.
        if re.fullmatch(r"[\W\d_]+", q):
            return True
        words = re.findall(r"[A-Za-z]{4,}", q)
        if words:
            gibberish = 0
            for w in words:
                # Long consonant runs are a strong keyboard-mash signal.
                longest_run = max(
                    (len(m.group()) for m in re.finditer(r"[^aeiouyAEIOUY]+", w)),
                    default=0,
                )
                if longest_run >= 5:
                    gibberish += 1
            if gibberish / len(words) >= 0.5:
                return True
        return False
|
greenrouting/classifier/model.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Frozen sentence encoder + three task heads (capability, difficulty, length)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
DEFAULT_ENCODER = "BAAI/bge-small-en-v1.5"
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
class ModelSpec:
    """Hyperparameters tying the frozen encoder to its trainable heads."""

    encoder_name: str = DEFAULT_ENCODER  # HF id of the frozen sentence encoder
    embedding_dim: int = 384  # must match the encoder's hidden size
    hidden_dim: int = 256  # width of the shared trunk in HeadStack
    n_capabilities: int = 8
    n_length_buckets: int = 3
    dropout: float = 0.1
    max_seq_len: int = 256  # tokenizer truncation length
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def build_head(spec: ModelSpec):
    """Construct the trainable head stack for the frozen-encoder classifier.

    Returns an ``nn.Module`` mapping (batch, embedding_dim) embeddings to a
    dict with capability logits, a per-example difficulty scalar, and length
    bucket logits. torch is imported lazily to keep module import cheap.
    """
    import torch.nn as nn

    class HeadStack(nn.Module):
        def __init__(self, config: ModelSpec):
            super().__init__()
            trunk_layers = [
                nn.Linear(config.embedding_dim, config.hidden_dim),
                nn.GELU(),
                nn.Dropout(config.dropout),
                nn.Linear(config.hidden_dim, config.hidden_dim),
                nn.GELU(),
                nn.Dropout(config.dropout),
            ]
            self.shared = nn.Sequential(*trunk_layers)
            self.cap_head = nn.Linear(config.hidden_dim, config.n_capabilities)
            self.diff_head = nn.Linear(config.hidden_dim, 1)
            self.len_head = nn.Linear(config.hidden_dim, config.n_length_buckets)

        def forward(self, embeddings):
            hidden = self.shared(embeddings)
            return {
                "cap_logits": self.cap_head(hidden),
                "diff": self.diff_head(hidden).squeeze(-1),
                "len_logits": self.len_head(hidden),
            }

    return HeadStack(spec)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class Encoder:
    """Lazy wrapper around a HuggingFace sentence encoder, mean-pooled and L2-normalized."""

    def __init__(self, encoder_name: str = DEFAULT_ENCODER, max_seq_len: int = 256):
        self.encoder_name = encoder_name
        self.max_seq_len = max_seq_len
        # Loaded on first use so importing this module stays cheap.
        self._tokenizer = None
        self._model = None
        self._device = None

    def _ensure_loaded(self):
        # Idempotent: the first call downloads/loads the model; later calls no-op.
        if self._model is not None:
            return
        import torch
        from transformers import AutoModel, AutoTokenizer

        self._tokenizer = AutoTokenizer.from_pretrained(self.encoder_name)
        self._model = AutoModel.from_pretrained(self.encoder_name)
        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        self._model.to(self._device).eval()
        # The encoder is frozen — only the task heads are trained.
        for p in self._model.parameters():
            p.requires_grad = False

    @property
    def device(self) -> str:
        # "cuda" or "cpu"; accessing this forces the lazy load.
        self._ensure_loaded()
        return self._device

    def embed(self, texts: list[str]):
        """Return L2-normalized, mean-pooled embeddings for *texts*."""
        import torch
        import torch.nn.functional as F

        self._ensure_loaded()
        enc = self._tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_seq_len,
            return_tensors="pt",
        ).to(self._device)
        with torch.no_grad():
            out = self._model(**enc)
        # Mean-pool over non-padding tokens only; clamp avoids divide-by-zero
        # for an (unlikely) all-padding row.
        mask = enc["attention_mask"].unsqueeze(-1).float()
        pooled = (out.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        return F.normalize(pooled, dim=-1)
|
greenrouting/classifier/ood.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""OOD detection on L2-normalized encoder embeddings.
|
| 2 |
+
|
| 3 |
+
Uses centroid cosine distance + k-nearest-neighbor distance. Both are robust
|
| 4 |
+
when the number of training examples is smaller than the embedding dimension,
|
| 5 |
+
which is typical for our seed-scale datasets.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def fit_ood_stats(train_embeddings, k: int = 5):
    """Build reference statistics for OOD scoring from training embeddings.

    Returns a dict with a unit-norm centroid, the row-normalized reference
    matrix, and the default neighbor count ``k``.
    """
    import numpy as np

    embeddings = np.asarray(train_embeddings, dtype=np.float32)
    if embeddings.size == 0:
        # No data: zero centroid (default 384-dim when shape is unknown).
        dim = embeddings.shape[1] if embeddings.ndim == 2 else 384
        return {"centroid": np.zeros((dim,), dtype=np.float32),
                "reference": embeddings,
                "k": k}

    row_norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    row_norms[row_norms == 0] = 1.0  # keep zero rows as zeros, not NaN
    unit_rows = embeddings / row_norms

    centroid = unit_rows.mean(axis=0)
    scale = float(np.linalg.norm(centroid)) or 1.0
    centroid = centroid / scale

    return {"centroid": centroid.astype(np.float32),
            "reference": unit_rows.astype(np.float32),
            "k": k}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _cosine_distance(a, b):
|
| 35 |
+
import numpy as np
|
| 36 |
+
a = np.asarray(a, dtype=np.float32)
|
| 37 |
+
b = np.asarray(b, dtype=np.float32)
|
| 38 |
+
na = np.linalg.norm(a)
|
| 39 |
+
nb = np.linalg.norm(b)
|
| 40 |
+
if na == 0 or nb == 0:
|
| 41 |
+
return 1.0
|
| 42 |
+
return 1.0 - float(np.dot(a, b) / (na * nb))
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def centroid_distance(embedding, stats) -> float:
    """Cosine distance from *embedding* to the training centroid in *stats*."""
    return _cosine_distance(embedding, stats["centroid"])
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def knn_distance(embedding, stats, k: int | None = None) -> float:
    """Mean cosine distance from *embedding* to its k nearest reference rows.

    Args:
        embedding: query embedding vector.
        stats: dict from ``fit_ood_stats`` with unit-norm "reference" rows
            and the default neighbor count "k".
        k: optional override; falls back to ``stats["k"]`` (default 5).

    Returns:
        Mean distance; 1.0 for an empty reference set or a zero embedding.
    """
    import numpy as np
    ref = stats["reference"]
    if ref.size == 0:
        return 1.0
    # `is None`, not `or`: the annotation was `k: int = None` and the old
    # `k or stats.get(...)` silently discarded any falsy explicit override.
    if k is None:
        k = stats.get("k", 5)
    emb = np.asarray(embedding, dtype=np.float32)
    norm = np.linalg.norm(emb)
    if norm == 0:
        return 1.0
    emb = emb / norm
    # Reference rows are unit-norm, so a dot product is cosine similarity.
    distances = np.sort(1.0 - (ref @ emb))
    return float(distances[: max(1, min(k, len(distances)))].mean())
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def calibrate_thresholds(
    train_embeddings,
    stats,
    percentile: float = 99.9,
    safety_multiplier: float = 1.25,
) -> dict:
    """Derive OOD decision thresholds from in-distribution distances.

    Threshold = percentile x safety_multiplier. The multiplier gives headroom
    for natural rephrasings that aren't truly OOD.
    """
    import numpy as np

    centroid_dists = [centroid_distance(e, stats) for e in train_embeddings]
    knn_dists = [knn_distance(e, stats) for e in train_embeddings]

    def _threshold(dists):
        # Empty training set: fall back to a permissive base of 1.0.
        base = float(np.percentile(dists, percentile)) if dists else 1.0
        return base * safety_multiplier

    return {
        "centroid_threshold": _threshold(centroid_dists),
        "knn_threshold": _threshold(knn_dists),
    }
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def is_ood(embedding, stats, thresholds) -> bool:
    """Either signal is sufficient to flag OOD. AND semantics let too many
    obvious cases slip when one signal happens to look in-distribution."""
    if centroid_distance(embedding, stats) > thresholds["centroid_threshold"]:
        return True
    return knn_distance(embedding, stats) > thresholds["knn_threshold"]
|
greenrouting/classifier/train.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Training loop. Frozen encoder, head-only optimization, multi-task loss."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import math
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
from greenrouting.classifier.model import DEFAULT_ENCODER, Encoder, ModelSpec, build_head
|
| 11 |
+
from greenrouting.data.schema import LENGTH_BUCKETS
|
| 12 |
+
from greenrouting.routing.registry import CAPABILITY_KEYS
|
| 13 |
+
|
| 14 |
+
LENGTH_TO_INDEX: dict[str, int] = {b: i for i, b in enumerate(LENGTH_BUCKETS)}
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
class TrainConfig:
    """Hyperparameters for head-only training (the encoder stays frozen)."""

    # Encoder / head architecture
    encoder_name: str = DEFAULT_ENCODER
    hidden_dim: int = 256
    dropout: float = 0.1
    max_seq_len: int = 256
    # Optimization
    epochs: int = 8
    batch_size: int = 32
    learning_rate: float = 1e-3
    weight_decay: float = 1e-4
    # Multi-task loss weights: capability BCE, difficulty Huber, length CE
    cap_weight: float = 1.0
    diff_weight: float = 0.5
    len_weight: float = 0.3
    # Train/validation split and reproducibility
    val_split: float = 0.15
    seed: int = 42
    huber_delta: float = 1.0
    # Up-weights positive capability labels in BCEWithLogitsLoss
    cap_pos_weight: float = 2.0
    # Difficulty regression is centered around log(8e9) (~8B params)
    diff_target_center: float = field(default_factory=lambda: math.log(8e9))
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _load_split(parquet_path: str | Path):
    """Read one labeled dataset split from a parquet file into a DataFrame."""
    import pandas as pd

    frame = pd.read_parquet(parquet_path)
    return frame
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _build_targets(df, cfg: TrainConfig):
    """Extract (texts, capability matrix, centered difficulty, length indices) from a split."""
    import numpy as np

    capability_columns = [f"cap_{k}" for k in CAPABILITY_KEYS]
    cap_matrix = df[capability_columns].fillna(0.0).to_numpy(dtype=np.float32)

    # Regression target is centered so the head predicts an offset around the anchor.
    raw_difficulty = df["difficulty_log_params"].fillna(cfg.diff_target_center)
    centered_difficulty = raw_difficulty.to_numpy(dtype=np.float32) - cfg.diff_target_center

    # Unknown/missing buckets fall back to index 1 — presumably "medium";
    # verify against LENGTH_BUCKETS ordering.
    length_indices = (
        df["length_bucket"]
        .fillna("medium")
        .map(LENGTH_TO_INDEX)
        .fillna(1)
        .to_numpy(dtype=np.int64)
    )
    prompt_texts = df["text"].astype(str).tolist()
    return prompt_texts, cap_matrix, centered_difficulty, length_indices
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _split_train_val(texts, caps, diff, lens, val_split: float, seed: int):
    """Shuffle-split all targets into (train, val) tuples; val gets at least one row."""
    import numpy as np

    order = np.arange(len(texts))
    np.random.default_rng(seed).shuffle(order)
    n_val = max(1, int(len(texts) * val_split))
    val_ids = order[:n_val]
    train_ids = order[n_val:]

    def _take(ids):
        # Texts stay a Python list; numeric targets are fancy-indexed arrays.
        return [texts[i] for i in ids], caps[ids], diff[ids], lens[ids]

    return _take(train_ids), _take(val_ids)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _iterate_batches(texts, caps, diff, lens, batch_size: int, encoder: Encoder, shuffle: bool, seed: int):
    """Yield (embeddings, capability, difficulty, length) tensors one batch at a time.

    Embeddings are computed on the fly by the frozen encoder; target tensors are
    placed on the same device as the embeddings.
    """
    import numpy as np
    import torch

    order = np.arange(len(texts))
    if shuffle:
        np.random.default_rng(seed).shuffle(order)

    for lo in range(0, len(texts), batch_size):
        batch_ids = order[lo:lo + batch_size]
        emb = encoder.embed([texts[i] for i in batch_ids])
        device = emb.device
        yield (
            emb,
            torch.tensor(caps[batch_ids], dtype=torch.float32, device=device),
            torch.tensor(diff[batch_ids], dtype=torch.float32, device=device),
            torch.tensor(lens[batch_ids], dtype=torch.long, device=device),
        )
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def train(
    train_parquet: str | Path,
    output_dir: str | Path,
    cfg: TrainConfig | None = None,
) -> dict:
    """Train the multi-task head on frozen-encoder embeddings and save artifacts.

    Writes head weights, encoder name, metadata, training history, temperature
    calibration, and OOD statistics into *output_dir*. Returns the per-epoch
    history plus the fitted temperature and the split sizes.
    """
    import torch
    import torch.nn as nn
    from torch.optim import AdamW

    cfg = cfg or TrainConfig()
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    df = _load_split(train_parquet)
    texts, caps, diff, lens = _build_targets(df, cfg)
    train_set, val_set = _split_train_val(texts, caps, diff, lens, cfg.val_split, cfg.seed)

    encoder = Encoder(cfg.encoder_name, cfg.max_seq_len)
    # Probe once to discover the encoder's embedding width.
    embed_dim = encoder.embed(["probe"]).shape[-1]
    spec = ModelSpec(
        encoder_name=cfg.encoder_name,
        embedding_dim=embed_dim,
        hidden_dim=cfg.hidden_dim,
        dropout=cfg.dropout,
        max_seq_len=cfg.max_seq_len,
    )
    head = build_head(spec).to(encoder.device)

    # Up-weight positive capability labels to counter label imbalance.
    pos_weight = torch.full((spec.n_capabilities,), cfg.cap_pos_weight, device=encoder.device)
    cap_loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    diff_loss_fn = nn.HuberLoss(delta=cfg.huber_delta)
    len_loss_fn = nn.CrossEntropyLoss()
    # Only the head's parameters are optimized; the encoder stays frozen.
    optimizer = AdamW(head.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)

    history = []
    for epoch in range(cfg.epochs):
        head.train()
        train_loss_sum = 0.0
        n_train = 0
        # Reshuffle each epoch with a distinct but reproducible seed.
        for emb, cap_t, diff_t, len_t in _iterate_batches(
            *train_set, batch_size=cfg.batch_size, encoder=encoder,
            shuffle=True, seed=cfg.seed + epoch,
        ):
            out = head(emb)
            # Weighted multi-task loss: capabilities + difficulty + length.
            loss = (
                cfg.cap_weight * cap_loss_fn(out["cap_logits"], cap_t)
                + cfg.diff_weight * diff_loss_fn(out["diff"], diff_t)
                + cfg.len_weight * len_loss_fn(out["len_logits"], len_t)
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight running loss by batch size so the epoch mean is per-sample.
            train_loss_sum += loss.item() * emb.shape[0]
            n_train += emb.shape[0]

        val_metrics = _evaluate(head, encoder, val_set, cfg)
        history.append({
            "epoch": epoch,
            "train_loss": train_loss_sum / max(n_train, 1),
            **val_metrics,
        })
        print(
            f"epoch {epoch+1}/{cfg.epochs} "
            f"train_loss={train_loss_sum/max(n_train,1):.4f} "
            f"val_cap_f1={val_metrics['cap_f1']:.3f} "
            f"val_diff_mae={val_metrics['diff_mae']:.3f} "
            f"val_len_acc={val_metrics['len_acc']:.3f}"
        )

    head.eval()
    # Persist the trained head and everything needed to rebuild it at inference.
    torch.save(head.state_dict(), out_dir / "head.pt")
    (out_dir / "encoder_name.txt").write_text(cfg.encoder_name)
    (out_dir / "metadata.json").write_text(json.dumps({
        "capability_keys": list(CAPABILITY_KEYS),
        "length_buckets": list(LENGTH_BUCKETS),
        "embedding_dim": int(spec.embedding_dim),
        "hidden_dim": int(spec.hidden_dim),
        "max_seq_len": int(spec.max_seq_len),
        "diff_target_center": float(cfg.diff_target_center),
    }, indent=2))
    (out_dir / "training_history.json").write_text(json.dumps(history, indent=2))

    train_embeddings = _collect_embeddings(encoder, train_set[0], batch_size=cfg.batch_size)
    val_cap_logits = _collect_logits(head, encoder, val_set, cfg.batch_size)

    from greenrouting.classifier.calibration import fit_temperature

    # Temperature scaling is fitted on held-out validation logits.
    temperature = fit_temperature(val_cap_logits, val_set[1])
    (out_dir / "calibration.json").write_text(json.dumps({"temperature": float(temperature)}, indent=2))

    from greenrouting.classifier.ood import calibrate_thresholds, fit_ood_stats

    # OOD geometry is fitted on training embeddings only.
    ood_stats = fit_ood_stats(train_embeddings, k=5)
    thresholds = calibrate_thresholds(train_embeddings, ood_stats, percentile=99.0)
    import numpy as np
    np.savez(
        out_dir / "ood_stats.npz",
        centroid=ood_stats["centroid"],
        reference=ood_stats["reference"],
        k=ood_stats.get("k", 5),
        centroid_threshold=thresholds["centroid_threshold"],
        knn_threshold=thresholds["knn_threshold"],
    )

    return {
        "history": history,
        "temperature": float(temperature),
        "n_train": len(train_set[0]),
        "n_val": len(val_set[0]),
    }
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def _evaluate(head, encoder: Encoder, val_set, cfg: TrainConfig) -> dict:
    """Run the head over the validation split and compute per-task metrics.

    Returns micro-averaged capability precision/recall/F1 (at a 0.5 cutoff),
    difficulty MAE (in centered log-param units), and length-bucket accuracy.
    Restores the head to train mode before returning.
    """
    import numpy as np
    import torch

    cap_probs, cap_targets = [], []
    diff_preds, diff_targets = [], []
    len_preds, len_targets = [], []

    head.eval()
    with torch.no_grad():
        for emb, cap_t, diff_t, len_t in _iterate_batches(
            *val_set, batch_size=cfg.batch_size, encoder=encoder, shuffle=False, seed=cfg.seed,
        ):
            out = head(emb)
            cap_probs.append(torch.sigmoid(out["cap_logits"]).cpu().numpy())
            cap_targets.append(cap_t.cpu().numpy())
            diff_preds.append(out["diff"].cpu().numpy())
            diff_targets.append(diff_t.cpu().numpy())
            len_preds.append(out["len_logits"].argmax(dim=-1).cpu().numpy())
            len_targets.append(len_t.cpu().numpy())
    head.train()

    pred = (np.concatenate(cap_probs) >= 0.5).astype(np.float32)
    true = (np.concatenate(cap_targets) >= 0.5).astype(np.float32)
    tp = ((pred == 1) & (true == 1)).sum()
    fp = ((pred == 1) & (true == 0)).sum()
    fn = ((pred == 0) & (true == 1)).sum()
    precision = tp / max(tp + fp, 1)
    recall = tp / max(tp + fn, 1)
    f1 = 2 * precision * recall / max(precision + recall, 1e-9)

    diff_mae = float(np.abs(np.concatenate(diff_preds) - np.concatenate(diff_targets)).mean())
    len_acc = float((np.concatenate(len_preds) == np.concatenate(len_targets)).mean())

    return {
        "cap_precision": float(precision),
        "cap_recall": float(recall),
        "cap_f1": float(f1),
        "diff_mae": diff_mae,
        "len_acc": len_acc,
    }
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def _collect_embeddings(encoder: Encoder, texts: list[str], batch_size: int):
    """Embed *texts* in batches and stack the results into one float array.

    NOTE(review): the empty-input fallback hardcodes width 384 — presumably the
    default encoder's embedding dim; confirm if a different encoder is used.
    """
    import numpy as np

    batches = [
        encoder.embed(texts[lo:lo + batch_size]).cpu().numpy()
        for lo in range(0, len(texts), batch_size)
    ]
    if not batches:
        return np.zeros((0, 384), dtype=np.float32)
    return np.concatenate(batches, axis=0)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def _collect_logits(head, encoder: Encoder, val_set, batch_size: int):
    """Collect raw (uncalibrated) capability logits over the validation split.

    Temporarily switches the head to eval mode and restores train mode after.
    NOTE(review): the empty fallback hardcodes 8 capability columns — confirm
    against CAPABILITY_KEYS if the registry changes.
    """
    import numpy as np
    import torch

    head.eval()
    collected = []
    with torch.no_grad():
        for emb, _cap, _diff, _len in _iterate_batches(
            *val_set, batch_size=batch_size, encoder=encoder, shuffle=False, seed=0,
        ):
            collected.append(head(emb)["cap_logits"].cpu().numpy())
    head.train()
    if not collected:
        return np.zeros((0, 8), dtype=np.float32)
    return np.concatenate(collected, axis=0)
|
greenrouting/classifier/trained_predictor.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Inference-time predictor that loads the trained artifact and conforms to
|
| 2 |
+
the `Predictor` protocol used by the router."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import math
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
from greenrouting.classifier.infer import (
|
| 12 |
+
CapabilityProfile,
|
| 13 |
+
LENGTH_BUCKETS,
|
| 14 |
+
LENGTH_TOKEN_TARGETS,
|
| 15 |
+
LENGTH_P90_MULTIPLIER,
|
| 16 |
+
QueryProfile,
|
| 17 |
+
)
|
| 18 |
+
from greenrouting.classifier.model import Encoder, ModelSpec, build_head
|
| 19 |
+
from greenrouting.classifier.ood import is_ood
|
| 20 |
+
from greenrouting.routing.registry import CAPABILITY_KEYS
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class TrainedPredictor:
    """Lazily loads the trained classifier artifact and maps a raw query string
    to a QueryProfile (capabilities, difficulty, length distribution, OOD flag)."""

    def __init__(self, artifact_dir: str | Path):
        self.artifact_dir = Path(artifact_dir)
        # All heavy state (encoder, head, calibration, OOD stats) is loaded
        # lazily on the first predict() call via _ensure_loaded().
        self._loaded = False
        self._encoder: Optional[Encoder] = None
        self._head = None
        self._spec: Optional[ModelSpec] = None
        self._temperature: float = 1.0
        self._ood_stats = None
        self._ood_thresholds = None
        # Queries whose best capability probability falls below this are flagged OOD.
        self._ood_min_confidence: float = 0.40

    def _ensure_loaded(self) -> None:
        """Load encoder, head weights, calibration, and OOD stats from artifact_dir (idempotent)."""
        if self._loaded:
            return
        import numpy as np
        import torch

        meta_path = self.artifact_dir / "metadata.json"
        meta = json.loads(meta_path.read_text())
        encoder_name = (self.artifact_dir / "encoder_name.txt").read_text().strip()
        self._spec = ModelSpec(
            encoder_name=encoder_name,
            embedding_dim=int(meta["embedding_dim"]),
            hidden_dim=int(meta["hidden_dim"]),
            n_capabilities=len(meta["capability_keys"]),
            n_length_buckets=len(meta["length_buckets"]),
            max_seq_len=int(meta.get("max_seq_len", 256)),
        )
        # Difficulty predictions are offsets around this center (log parameter count).
        self._diff_center = float(meta.get("diff_target_center", math.log(8e9)))
        self._encoder = Encoder(encoder_name, max_seq_len=self._spec.max_seq_len)
        head = build_head(self._spec)
        # Weights are loaded to CPU first, then moved to the encoder's device.
        head.load_state_dict(torch.load(self.artifact_dir / "head.pt", map_location="cpu"))
        head.to(self._encoder.device).eval()
        self._head = head

        # Optional artifact: temperature calibration for capability logits.
        cal_path = self.artifact_dir / "calibration.json"
        if cal_path.exists():
            self._temperature = float(json.loads(cal_path.read_text()).get("temperature", 1.0))

        # Optional artifact: geometric OOD statistics and thresholds.
        ood_path = self.artifact_dir / "ood_stats.npz"
        if ood_path.exists():
            data = np.load(ood_path)
            if "centroid" in data.files and "reference" in data.files:
                self._ood_stats = {
                    "centroid": data["centroid"],
                    "reference": data["reference"],
                    "k": int(data["k"]) if "k" in data.files else 5,
                }
                self._ood_thresholds = {
                    "centroid_threshold": float(data["centroid_threshold"]),
                    "knn_threshold": float(data["knn_threshold"]),
                }

        self._loaded = True

    def predict(self, query: str) -> QueryProfile:
        """Classify *query* and return a fully-populated QueryProfile."""
        import torch
        import torch.nn.functional as F

        self._ensure_loaded()
        text = (query or "").strip()
        emb = self._encoder.embed([text])
        with torch.no_grad():
            out = self._head(emb)

        # Temperature-scaled sigmoid probabilities per capability key.
        # The max(..., 1e-3) guards against a degenerate saved temperature.
        cap_logits = (out["cap_logits"] / max(self._temperature, 1e-3))
        cap_probs = torch.sigmoid(cap_logits)[0].cpu().numpy().tolist()
        cap_dict = {k: float(v) for k, v in zip(CAPABILITY_KEYS, cap_probs)}

        # Undo the training-time centering to recover absolute log-params.
        diff_centered = float(out["diff"][0].item())
        diff_log_params = diff_centered + self._diff_center

        len_probs = F.softmax(out["len_logits"][0], dim=-1).cpu().numpy().tolist()
        length_dist = {b: float(p) for b, p in zip(LENGTH_BUCKETS, len_probs)}

        # Confidence = strongest capability probability.
        confidence = max(cap_dict.values()) if cap_dict else 0.0

        # OOD if the model is unconfident OR the embedding is geometrically far
        # from the training distribution (when OOD stats were saved).
        confidence_ood = confidence < self._ood_min_confidence
        geometric_ood = False
        if self._ood_stats is not None and self._ood_thresholds is not None:
            emb_np = emb[0].cpu().numpy()
            geometric_ood = is_ood(emb_np, self._ood_stats, self._ood_thresholds)
        ood_flag = confidence_ood or geometric_ood

        # Rough token accounting: ~1.3 tokens per whitespace word plus slack.
        in_tokens = max(1, int(len(text.split()) * 1.3) + 4)
        # p50 output = expectation of bucket token targets under length_dist.
        out_p50 = int(round(sum(length_dist[b] * LENGTH_TOKEN_TARGETS[b] for b in LENGTH_BUCKETS)))
        long_w = length_dist.get("long", 0.0)
        # p90 adds headroom proportional to the "long" bucket's probability mass.
        out_p90 = int(round(out_p50 * LENGTH_P90_MULTIPLIER + long_w * LENGTH_TOKEN_TARGETS["long"] * 0.3))

        return QueryProfile(
            capabilities=CapabilityProfile(**cap_dict),
            difficulty_log_params=diff_log_params,
            length_dist=length_dist,
            expected_input_tokens=in_tokens,
            expected_output_tokens_p50=out_p50,
            expected_output_tokens_p90=out_p90,
            confidence=confidence,
            is_ood=ood_flag,
            raw_query=text,
            debug={
                "source": "trained",
                "temperature": self._temperature,
                "confidence_ood": bool(confidence_ood),
                "geometric_ood": bool(geometric_ood),
            },
        )
|
greenrouting/data/__init__.py
ADDED
|
File without changes
|
greenrouting/data/builder.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Dataset orchestrator: source sampling -> capability labeling -> cascade plan -> parquet."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
from greenrouting.data.capability_labeler import LabelerConfig, label_queries
|
| 12 |
+
from greenrouting.data.schema import CapabilityLabel, LabeledQuery, RawQuery
|
| 13 |
+
from greenrouting.data.sources import SOURCE_REGISTRY, sample_mix
|
| 14 |
+
from greenrouting.routing.registry import CAPABILITY_KEYS
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
class CascadeRungConfig:
    """One model rung of the labeling cascade."""

    id: str
    # HuggingFace model id used to run this rung
    hf_model: str
    # Parameter count in billions
    params_b: float
    # Rough decode speed, used only for wall-time projection
    decode_tokens_per_second_estimate: float
    runs_locally: bool = True
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass
class CascadeConfig:
    """Cascade-wide sampling settings plus a wall-time projection helper."""

    rungs: list[CascadeRungConfig]
    k_samples: int = 1
    max_new_tokens: int = 200
    temperature_first: float = 0.0
    temperature_resample: float = 0.7

    def projected_seconds(self, n_queries: int) -> float:
        """Estimate total wall time: decode time per rung plus ~0.4 s of
        fixed overhead per inference."""
        inferences = n_queries * self.k_samples
        return sum(
            inferences * self.max_new_tokens / max(rung.decode_tokens_per_second_estimate, 1.0)
            + inferences * 0.4
            for rung in self.rungs
        )
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@dataclass
class BuildConfig:
    """Full configuration for one dataset build (sources, cascade, labeler)."""

    profile_name: str
    target_total_queries: int
    # Fraction of rows held out for the test split
    test_split: float
    seed: int
    # source name -> mix weight
    sources: dict[str, float]
    cascade: CascadeConfig
    labeler: LabelerConfig
    budget_minutes: float = 60.0
    output_dir: str = "data"
    # Optional parquet of previously computed capability labels to reuse
    capability_labels_cache: Optional[str] = None

    @classmethod
    def from_yaml(cls, path: str | Path) -> "BuildConfig":
        """Parse a YAML profile into a BuildConfig, filling defaults for
        optional cascade/labeler keys."""
        import yaml
        with open(path, "r", encoding="utf-8") as f:
            raw = yaml.safe_load(f)

        rungs = [CascadeRungConfig(**r) for r in raw["cascade"]["rungs"]]
        cascade = CascadeConfig(
            rungs=rungs,
            k_samples=raw["cascade"].get("k_samples", 1),
            max_new_tokens=raw["cascade"].get("max_new_tokens", 200),
            temperature_first=raw["cascade"].get("temperature_first", 0.0),
            temperature_resample=raw["cascade"].get("temperature_resample", 0.7),
        )
        # The "labeler" section is entirely optional; defaults mirror LabelerConfig.
        labeler_raw = raw.get("labeler", {})
        labeler = LabelerConfig(
            use_heuristic=labeler_raw.get("use_heuristic", True),
            use_gpt=labeler_raw.get("use_gpt", False),
            use_claude=labeler_raw.get("use_claude", False),
            use_gemini=labeler_raw.get("use_gemini", False),
            source_prior_weight=labeler_raw.get("source_prior_weight", 0.5),
            sleep_between_calls_s=labeler_raw.get("sleep_between_calls_s", 0.0),
        )
        return cls(
            profile_name=raw["profile_name"],
            target_total_queries=raw["target_total_queries"],
            test_split=raw["test_split"],
            seed=raw["seed"],
            sources=raw["sources"],
            cascade=cascade,
            labeler=labeler,
            budget_minutes=raw.get("budget_minutes", 60.0),
            output_dir=raw.get("output_dir", "data"),
            capability_labels_cache=raw.get("capability_labels_cache"),
        )
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@dataclass
class BuildPlan:
    """Summary of a planned dataset build: sizing, projected cost, warnings."""

    config: BuildConfig
    n_queries: int
    cascade_seconds: float
    cascade_minutes: float
    over_budget: bool
    notes: list[str] = field(default_factory=list)

    def report(self) -> str:
        """Render a human-readable multi-line summary of the plan."""
        c = self.config
        source_mix = ", ".join(f"{k}={v}" for k, v in c.sources.items())
        rung_ids = ", ".join(r.id for r in c.cascade.rungs)
        lines = [
            f"Profile: {c.profile_name}",
            f"Target queries: {c.target_total_queries}",
            f"Test split: {int(c.test_split * 100)}%",
            f"Sources: {source_mix}",
            f"Cascade rungs: {rung_ids}",
            f"k_samples per rung: {c.cascade.k_samples}",
            f"Max new tokens: {c.cascade.max_new_tokens}",
            f"Estimated cascade wall time: {self.cascade_minutes:.1f} min",
            f"Configured budget: {c.budget_minutes:.1f} min",
            f"Over budget: {self.over_budget}",
        ]
        if self.notes:
            lines.append("Notes:")
            lines.extend(f"  - {note}" for note in self.notes)
        return "\n".join(lines)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def plan(config: BuildConfig) -> BuildPlan:
    """Project the cascade's wall time against the budget and collect warnings
    about missing API keys and unknown sources."""
    notes: list[str] = []
    cascade_s = config.cascade.projected_seconds(config.target_total_queries)
    cascade_m = cascade_s / 60.0
    over_budget = cascade_m > config.budget_minutes
    if over_budget:
        notes.append(
            f"cascade projected {cascade_m:.1f} min exceeds budget {config.budget_minutes:.1f} min"
        )
    # Each enabled LLM voter needs its provider API key in the environment.
    key_checks = (
        ("use_gpt", "OPENAI_API_KEY", "gpt"),
        ("use_claude", "ANTHROPIC_API_KEY", "claude"),
        ("use_gemini", "GOOGLE_API_KEY", "gemini"),
    )
    for flag, env_var, provider in key_checks:
        if getattr(config.labeler, flag) and not os.environ.get(env_var):
            notes.append(f"{env_var} missing; {provider} vote will be skipped")
    notes.extend(
        f"unknown source '{src}' in mix" for src in config.sources if src not in SOURCE_REGISTRY
    )
    return BuildPlan(
        config=config,
        n_queries=config.target_total_queries,
        cascade_seconds=cascade_s,
        cascade_minutes=cascade_m,
        over_budget=over_budget,
        notes=notes,
    )
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def write_capability_labels(path: str | Path, labels: list[CapabilityLabel]) -> None:
    """Persist capability labels to parquet, creating parent dirs as needed."""
    import pandas as pd

    records = [lbl.to_record() for lbl in labels]
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    pd.DataFrame(records).to_parquet(path, index=False)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def read_capability_labels(path: str | Path) -> dict[str, dict[str, float]]:
    """Load cached capability labels: query_id -> {capability_key: score}."""
    import pandas as pd

    df = pd.read_parquet(path)
    cap_cols = [c for c in df.columns if c.startswith("cap_")]
    labels: dict[str, dict[str, float]] = {}
    for _, row in df.iterrows():
        # Strip the "cap_" prefix to recover the bare capability key.
        labels[row["query_id"]] = {col[4:]: float(row[col]) for col in cap_cols}
    return labels
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def write_raw_manifest(path: str | Path, queries: list[RawQuery]) -> None:
    """Write the sampled raw queries as JSON Lines, one object per line."""
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(q.to_dict()) + "\n" for q in queries)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def write_labeled_dataset(
    train_path: str | Path,
    test_path: str | Path,
    rows: list[LabeledQuery],
    test_split: float,
    seed: int,
) -> None:
    """Seeded shuffle-split of *rows* into train/test parquet files.

    Fix: the test records were previously built by iterating a *set* of
    indices, whose iteration order is implementation-defined and in any case
    discards the seeded shuffle order; we now iterate the shuffled index list
    so the written test-row order is fully determined by *seed*.
    """
    import random as _random

    import pandas as pd

    rng = _random.Random(seed)
    indices = list(range(len(rows)))
    rng.shuffle(indices)
    # At least one test row, even for tiny datasets.
    n_test = max(1, int(len(rows) * test_split))
    test_order = indices[:n_test]
    test_idx = set(test_order)
    # Train rows keep the original dataset order; test rows keep shuffle order.
    train_records = [rows[i].to_record() for i in range(len(rows)) if i not in test_idx]
    test_records = [rows[i].to_record() for i in test_order]
    Path(train_path).parent.mkdir(parents=True, exist_ok=True)
    pd.DataFrame(train_records).to_parquet(train_path, index=False)
    pd.DataFrame(test_records).to_parquet(test_path, index=False)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def build_seed_dataset(
    output_dir: str | Path,
    test_split: float = 0.15,
    seed: int = 42,
    suffix: str = "seed",
) -> tuple[Path, Path]:
    """Materialize the curated seed entries into train/test parquet files.

    Skips the cascade and the labeler: the seed entries already carry gold
    capability multi-labels, difficulty (in log_params), and length buckets.
    """
    from greenrouting.data.seed_dataset import (
        SEED_QUERIES,
        difficulty_log_params_from_b,
        seed_capability_dict,
    )

    labeled: list[LabeledQuery] = []
    for idx, entry in enumerate(SEED_QUERIES):
        raw_query = RawQuery(
            id=f"seed-{idx:04d}",
            text=entry.text,
            source="seed",
            source_category=entry.primary_category,
            has_grader=False,
            grader_metadata={},
        )
        labeled.append(
            LabeledQuery(
                raw=raw_query,
                capabilities=seed_capability_dict(entry, CAPABILITY_KEYS),
                difficulty_log_params=difficulty_log_params_from_b(entry.difficulty_b),
                length_bucket=entry.length,
                cascade_results={"source": "seed_curated"},
            )
        )

    base = Path(output_dir)
    train_path = base / f"train_{suffix}.parquet"
    test_path = base / f"test_{suffix}.parquet"
    write_labeled_dataset(train_path, test_path, labeled, test_split=test_split, seed=seed)
    return train_path, test_path
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def sample_and_label(config: BuildConfig) -> tuple[list[RawQuery], list[CapabilityLabel]]:
    """Sample the configured query mix and capability-label every query.

    Labels are served from the optional on-disk cache first; only queries
    missing from the cache go through the (potentially API-backed) labeler.

    Args:
        config: build configuration carrying sources, totals, seed, labeler
            settings, and the optional capability-labels cache path.

    Returns:
        (queries, labels) where labels are aligned to the order of queries;
        queries with no label available are dropped from the label list.
    """
    # Hoisted out of the per-row loop below — the original re-executed this
    # import on every cached hit.
    from greenrouting.data.schema import CapabilityVotes

    queries = sample_mix(config.sources, config.target_total_queries, config.seed)

    cached: dict = {}
    if config.capability_labels_cache and Path(config.capability_labels_cache).exists():
        cached = read_capability_labels(config.capability_labels_cache)

    new_queries = [q for q in queries if q.id not in cached]
    new_labels = label_queries(new_queries, config.labeler) if new_queries else []

    # Rehydrate cached entries; raw per-voter votes were not cached, so an
    # empty CapabilityVotes is recorded with a "cached" aggregation marker.
    cached_labels: list[CapabilityLabel] = [
        CapabilityLabel(
            query_id=q.id,
            capabilities=cached[q.id],
            votes=CapabilityVotes(),
            aggregation_method="cached",
        )
        for q in queries
        if q.id in cached
    ]

    all_labels = new_labels + cached_labels
    by_id = {label.query_id: label for label in all_labels}
    aligned = [by_id[q.id] for q in queries if q.id in by_id]
    return queries, aligned
|
greenrouting/data/capability_labeler.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Capability labeling: turns RawQuery records into multi-label CapabilityLabel records.
|
| 2 |
+
|
| 3 |
+
Aggregates up to four independent voters:
|
| 4 |
+
- source_prior (always available; derived from source category)
|
| 5 |
+
- heuristic (always available; deterministic keyword/regex rules)
|
| 6 |
+
- gpt-4o (optional, requires OPENAI_API_KEY)
|
| 7 |
+
- claude-sonnet (optional, requires ANTHROPIC_API_KEY)
|
| 8 |
+
- gemini-pro (optional, requires GOOGLE_API_KEY)
|
| 9 |
+
|
| 10 |
+
Designed to run once during dataset prep. The output is committed as a parquet so
|
| 11 |
+
downstream training does not depend on API access.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import json
|
| 17 |
+
import os
|
| 18 |
+
import re
|
| 19 |
+
import time
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from typing import Iterable, Optional
|
| 22 |
+
|
| 23 |
+
from greenrouting.data.schema import CapabilityLabel, CapabilityVotes, RawQuery
|
| 24 |
+
from greenrouting.routing.registry import CAPABILITY_KEYS
|
| 25 |
+
|
| 26 |
+
CATEGORY_TO_LABELS: dict[str, list[str]] = {
|
| 27 |
+
"code": ["code"],
|
| 28 |
+
"math": ["math"],
|
| 29 |
+
"reasoning": ["reasoning"],
|
| 30 |
+
"knowledge": ["knowledge"],
|
| 31 |
+
"instruction": ["instruction"],
|
| 32 |
+
"creative": ["creative"],
|
| 33 |
+
"multilingual": ["multilingual"],
|
| 34 |
+
"simple_chat": ["simple_chat"],
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def source_prior_vote(category: str) -> dict[str, float]:
    """One-hot prior over CAPABILITY_KEYS derived from the source category.

    Unknown categories yield an all-zero vote.
    """
    active = set(CATEGORY_TO_LABELS.get(category, []))
    return {key: float(key in active) for key in CAPABILITY_KEYS}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
_HEURISTIC_PATTERNS: dict[str, list[str]] = {
|
| 44 |
+
"code": [
|
| 45 |
+
r"\b(code|function|class|def |algorithm|implement|debug|stack trace|api|sdk)\b",
|
| 46 |
+
r"\b(python|javascript|typescript|rust|go|c\+\+|java|sql|html|css)\b",
|
| 47 |
+
r"\b(refactor|unit test|regex|linter|compile|recursion)\b",
|
| 48 |
+
r"```",
|
| 49 |
+
],
|
| 50 |
+
"math": [
|
| 51 |
+
r"\b(calculate|compute|solve|equation|integral|derivative|matrix|vector|theorem|prove)\b",
|
| 52 |
+
r"\b(probability|sum|product|mean|median|variance|standard deviation|percentage)\b",
|
| 53 |
+
r"\d+\s*[+\-*/×÷=]\s*\d+",
|
| 54 |
+
r"\b(arithmetic|fraction|geometry|algebra|trig)\b",
|
| 55 |
+
],
|
| 56 |
+
"reasoning": [
|
| 57 |
+
r"\b(why|how does|explain|reason|because|therefore|argue|justify|implication)\b",
|
| 58 |
+
r"\b(compare|contrast|analyze|evaluate|trade.?off|infer|deduce)\b",
|
| 59 |
+
],
|
| 60 |
+
"knowledge": [
|
| 61 |
+
r"\b(who|what is|when did|where is|history|definition|capital|founded|named)\b",
|
| 62 |
+
r"\b(country|continent|invented|discovered|president|prime minister)\b",
|
| 63 |
+
],
|
| 64 |
+
"instruction": [
|
| 65 |
+
r"\b(write|draft|create|generate|produce|format|list|outline|step.?by.?step|summarize)\b",
|
| 66 |
+
r"\b(rewrite|translate from|convert to|extract)\b",
|
| 67 |
+
],
|
| 68 |
+
"creative": [
|
| 69 |
+
r"\b(story|poem|novel|character|plot|scene|metaphor|fictional|haiku|song lyric)\b",
|
| 70 |
+
r"\b(write a (?:short )?(?:story|poem|haiku|song))\b",
|
| 71 |
+
],
|
| 72 |
+
"multilingual": [
|
| 73 |
+
r"\b(translate|translation|en español|en français|auf deutsch|на русском|中文|日本語|한국어)\b",
|
| 74 |
+
r"[Ѐ-ӿ一-鿿-ゟ゠-ヿ가-힣]",
|
| 75 |
+
],
|
| 76 |
+
"simple_chat": [
|
| 77 |
+
r"^\s*(hi|hello|hey|thanks|thank you|good morning|good evening|sup|yo)\b",
|
| 78 |
+
],
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def heuristic_vote(text: str) -> dict[str, float]:
    """Deterministic keyword/regex voter over CAPABILITY_KEYS.

    A capability scores 1.0 when any of its patterns matches (case-insensitive,
    multiline). If nothing matched, short texts (< 80 chars after stripping)
    default to simple_chat and everything else to instruction, so the vote is
    never all zeros.
    """
    flags = re.IGNORECASE | re.MULTILINE
    vote = {key: 0.0 for key in CAPABILITY_KEYS}
    for capability, patterns in _HEURISTIC_PATTERNS.items():
        if any(re.search(pattern, text, flags=flags) for pattern in patterns):
            vote[capability] = 1.0
    if not any(vote.values()):
        fallback = "simple_chat" if len(text.strip()) < 80 else "instruction"
        vote[fallback] = 1.0
    return vote
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
_LABELER_SYSTEM_PROMPT = (
|
| 98 |
+
"You are labeling AI queries by which capabilities they require. "
|
| 99 |
+
"Capabilities: code, math, reasoning, knowledge, instruction, creative, multilingual, "
|
| 100 |
+
"simple_chat. A query can require multiple capabilities. "
|
| 101 |
+
"Reply with strict JSON only, in the form: "
|
| 102 |
+
'{"code": 0|1, "math": 0|1, "reasoning": 0|1, "knowledge": 0|1, '
|
| 103 |
+
'"instruction": 0|1, "creative": 0|1, "multilingual": 0|1, "simple_chat": 0|1}.'
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _user_prompt(query: str) -> str:
|
| 108 |
+
return f"Query:\n{query}\n\nRespond with JSON only."
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _parse_vote(raw: str) -> dict[str, float]:
    """Parse a voter's raw completion into a {capability: 0.0|1.0} dict.

    Returns an all-zero vote on any failure. Fix: JSON that parses but is
    not an object (e.g. a bare list or number) previously escaped the
    try/except and raised AttributeError on `data.get`; it now falls back
    to zeros like any other malformed reply.
    """
    try:
        data = json.loads(_extract_json(raw))
    except Exception:
        data = None
    if not isinstance(data, dict):
        return {k: 0.0 for k in CAPABILITY_KEYS}
    # Any truthy value counts as 1; missing/falsy keys count as 0.
    return {k: float(1 if data.get(k) else 0) for k in CAPABILITY_KEYS}
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def _extract_json(text: str) -> str:
|
| 120 |
+
match = re.search(r"\{.*\}", text, flags=re.DOTALL)
|
| 121 |
+
return match.group(0) if match else text
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def _gpt_vote(text: str) -> Optional[dict[str, float]]:
    """Capability vote from gpt-4o-mini.

    Returns None (voter skipped) when OPENAI_API_KEY is unset or the
    `openai` package is not installed; otherwise a {capability: 0/1} dict.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        return None
    try:
        from openai import OpenAI
    except ImportError:
        return None
    client = OpenAI(api_key=api_key)
    resp = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": _LABELER_SYSTEM_PROMPT},
            {"role": "user", "content": _user_prompt(text)},
        ],
        temperature=0,  # deterministic labeling
        response_format={"type": "json_object"},  # force strict JSON output
    )
    # content can be None on some responses; substitute an empty object.
    return _parse_vote(resp.choices[0].message.content or "{}")
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _claude_vote(text: str) -> Optional[dict[str, float]]:
    """Capability vote from claude-haiku-4-5.

    Returns None (voter skipped) when ANTHROPIC_API_KEY is unset or the
    `anthropic` package is not installed; otherwise a {capability: 0/1} dict.
    """
    api_key = os.environ.get("ANTHROPIC_API_KEY")
    if not api_key:
        return None
    try:
        import anthropic
    except ImportError:
        return None
    client = anthropic.Anthropic(api_key=api_key)
    resp = client.messages.create(
        model="claude-haiku-4-5",
        max_tokens=200,  # the expected reply is a tiny JSON object
        system=_LABELER_SYSTEM_PROMPT,
        messages=[{"role": "user", "content": _user_prompt(text)}],
    )
    # Concatenate only text blocks; any non-text block types are ignored.
    body = "".join(b.text for b in resp.content if getattr(b, "type", "") == "text")
    return _parse_vote(body)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def _gemini_vote(text: str) -> Optional[dict[str, float]]:
    """Capability vote from gemini-1.5-flash.

    Returns None (voter skipped) when GOOGLE_API_KEY is unset or the
    `google-generativeai` package is not installed; otherwise a
    {capability: 0/1} dict.
    """
    api_key = os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        return None
    try:
        import google.generativeai as genai
    except ImportError:
        return None
    genai.configure(api_key=api_key)
    model = genai.GenerativeModel("gemini-1.5-flash", system_instruction=_LABELER_SYSTEM_PROMPT)
    resp = model.generate_content(_user_prompt(text))
    # resp.text can be empty/None; substitute an empty object for the parser.
    return _parse_vote(resp.text or "{}")
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
@dataclass
class LabelerConfig:
    """Which voters to run and how to pace remote API calls."""

    use_heuristic: bool = True   # deterministic regex voter, no API needed
    use_gpt: bool = False        # gpt-4o-mini voter (requires OPENAI_API_KEY)
    use_claude: bool = False     # claude voter (requires ANTHROPIC_API_KEY)
    use_gemini: bool = False     # gemini voter (requires GOOGLE_API_KEY)
    source_prior_weight: float = 0.5   # weight of the source-category prior in aggregation
    sleep_between_calls_s: float = 0.0  # crude rate limit between remote voter calls
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def aggregate_votes(votes: CapabilityVotes, source_prior_weight: float = 0.5) -> dict[str, float]:
    """Blend the available voters into one soft multi-label capability dict.

    Each present voter (heuristic/gpt/claude/gemini) contributes with weight
    1; the source-category prior contributes with `source_prior_weight`.
    With no voters at all, the prior is returned as-is, or all zeros when
    the prior is also missing.

    Fix: the prior is defaulted to {} before lookup, so a missing/None
    `votes.source_prior` no longer raises AttributeError when at least one
    voter is present.
    """
    voters = [v for v in (votes.heuristic, votes.gpt, votes.claude, votes.gemini) if v is not None]
    prior = votes.source_prior or {}
    if not voters:
        return dict(prior) if prior else {k: 0.0 for k in CAPABILITY_KEYS}
    result: dict[str, float] = {}
    total_weight = source_prior_weight + len(voters)
    for cap in CAPABILITY_KEYS:
        prior_term = source_prior_weight * float(prior.get(cap, 0.0))
        vendor_sum = sum(float(v.get(cap, 0.0)) for v in voters)
        result[cap] = (prior_term + vendor_sum) / total_weight
    return result
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def label_query(query: RawQuery, config: LabelerConfig) -> CapabilityLabel:
    """Run every enabled voter on one query and aggregate the votes.

    Remote voters run in a fixed order (gpt, claude, gemini); the optional
    sleep after each remote call throttles API usage. The aggregation
    method string records which voters actually produced a vote.
    """
    votes = CapabilityVotes(source_prior=source_prior_vote(query.source_category))

    if config.use_heuristic:
        votes.heuristic = heuristic_vote(query.text)

    for enabled, attr, voter in (
        (config.use_gpt, "gpt", _gpt_vote),
        (config.use_claude, "claude", _claude_vote),
        (config.use_gemini, "gemini", _gemini_vote),
    ):
        if not enabled:
            continue
        setattr(votes, attr, voter(query.text))
        if config.sleep_between_calls_s:
            time.sleep(config.sleep_between_calls_s)

    aggregated = aggregate_votes(votes, source_prior_weight=config.source_prior_weight)

    participating = [
        name
        for name, vote in (
            ("heuristic", votes.heuristic),
            ("gpt", votes.gpt),
            ("claude", votes.claude),
            ("gemini", votes.gemini),
        )
        if vote is not None
    ]
    method = "+".join(participating) if participating else "source_prior_only"

    return CapabilityLabel(
        query_id=query.id,
        capabilities=aggregated,
        votes=votes,
        aggregation_method=method,
    )
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def label_queries(queries: Iterable[RawQuery], config: LabelerConfig) -> list[CapabilityLabel]:
    """Label each query in turn; output order matches input order."""
    labeled: list[CapabilityLabel] = []
    for query in queries:
        labeled.append(label_query(query, config))
    return labeled
|
greenrouting/data/cascade.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Difficulty cascade: runs each query against an ascending ladder of models,
|
| 2 |
+
grades the response, and derives a continuous `min_capable_log_params` label.
|
| 3 |
+
|
| 4 |
+
Memory strategy: load one rung at a time, run all queries, dump checkpoint, free
|
| 5 |
+
the weights, then advance to the next rung. Resumes from per-rung JSONL files.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import json
|
| 11 |
+
import math
|
| 12 |
+
import time
|
| 13 |
+
from dataclasses import asdict, dataclass
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import Iterable, Optional
|
| 16 |
+
|
| 17 |
+
from greenrouting.data.graders import grade
|
| 18 |
+
from greenrouting.data.schema import LabeledQuery, RawQuery
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
class RungResult:
    """One graded generation from one cascade rung for one query."""

    rung_id: str          # rung identifier from the cascade config
    params_b: float       # rung model size, in billions of parameters
    query_id: str         # RawQuery.id this sample answers
    sample_index: int     # 0-based index among the k samples for this query
    response: str         # decoded model output
    score: float          # grader score for this response
    response_tokens: int  # number of newly generated tokens
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _read_raw_manifest(path: str | Path) -> list[RawQuery]:
    """Load RawQuery rows from a JSONL manifest, skipping blank lines."""
    rows: list[RawQuery] = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            payload = raw_line.strip()
            if not payload:
                continue
            rows.append(RawQuery(**json.loads(payload)))
    return rows
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _read_rung_checkpoint(path: Path) -> dict[str, list[RungResult]]:
    """Load previously graded samples for one rung, grouped by query id.

    Returns an empty dict when the checkpoint file does not exist yet.
    """
    grouped: dict[str, list[RungResult]] = {}
    if not path.exists():
        return grouped
    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            payload = raw_line.strip()
            if not payload:
                continue
            result = RungResult(**json.loads(payload))
            grouped.setdefault(result.query_id, []).append(result)
    return grouped
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _append_rung_checkpoint(path: Path, result: RungResult) -> None:
    """Append one result as a JSON line, creating parent dirs on first write."""
    path.parent.mkdir(parents=True, exist_ok=True)
    line = json.dumps(asdict(result)) + "\n"
    with open(path, "a", encoding="utf-8") as handle:
        handle.write(line)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _load_model_and_tokenizer(hf_model: str):
    """Load a Hugging Face causal LM and its tokenizer for one cascade rung.

    Uses bfloat16 + device_map="auto" when CUDA is available, float32 on CPU.
    torch/transformers are imported lazily so the rest of the data pipeline
    can be used without them installed.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    tok = AutoTokenizer.from_pretrained(hf_model)
    if tok.pad_token_id is None:
        # Some models ship without a pad token; reuse EOS so generate() works.
        tok.pad_token_id = tok.eos_token_id
    model = AutoModelForCausalLM.from_pretrained(
        hf_model, torch_dtype=dtype, device_map="auto" if torch.cuda.is_available() else None
    )
    model.eval()
    return tok, model
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _free_model(model) -> None:
|
| 79 |
+
import gc
|
| 80 |
+
del model
|
| 81 |
+
gc.collect()
|
| 82 |
+
try:
|
| 83 |
+
import torch
|
| 84 |
+
if torch.cuda.is_available():
|
| 85 |
+
torch.cuda.empty_cache()
|
| 86 |
+
except Exception:
|
| 87 |
+
pass
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _format_prompt(tok, query: str) -> str:
|
| 91 |
+
if hasattr(tok, "apply_chat_template") and tok.chat_template:
|
| 92 |
+
messages = [{"role": "user", "content": query}]
|
| 93 |
+
return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 94 |
+
return f"### Instruction:\n{query}\n\n### Response:\n"
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _generate(tok, model, prompt: str, max_new_tokens: int, temperature: float) -> tuple[str, int]:
    """Generate one completion for *prompt*.

    Returns (stripped_response_text, number_of_new_tokens). temperature <= 0
    selects greedy decoding (do_sample=False); the temperature passed to
    generate() is then a placeholder 1.0 and has no effect.
    """
    import torch
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    do_sample = temperature > 0
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature if do_sample else 1.0,
            pad_token_id=tok.pad_token_id,
        )
    # Slice off the prompt tokens so only the generated suffix is decoded.
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    response = tok.decode(new_tokens, skip_special_tokens=True)
    return response.strip(), int(new_tokens.shape[0])
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def run_rung(
    rung,
    queries: list[RawQuery],
    k_samples: int,
    max_new_tokens: int,
    temperature_first: float,
    temperature_resample: float,
    checkpoint_path: Path,
    progress: bool = True,
) -> list[RungResult]:
    """Run all queries through one cascade rung, resuming from a checkpoint.

    Previously graded samples are read from `checkpoint_path` (JSONL); only
    queries with fewer than `k_samples` recorded samples are regenerated.
    The model is only loaded when there is pending work, and is always freed
    afterward (the `finally` guarantees this even on errors).

    Returns all results for this rung: resumed samples plus new ones.
    """
    existing = _read_rung_checkpoint(checkpoint_path)
    pending = [q for q in queries if len(existing.get(q.id, [])) < k_samples]
    results: list[RungResult] = [r for rs in existing.values() for r in rs]

    if not pending:
        # Everything already checkpointed — skip the expensive model load.
        return results

    tok, model = _load_model_and_tokenizer(rung.hf_model)
    try:
        for i, q in enumerate(pending):
            done = len(existing.get(q.id, []))
            for s in range(done, k_samples):
                # First sample uses the (typically low) first-pass temperature;
                # resamples use the resample temperature for diversity.
                temp = temperature_first if s == 0 else temperature_resample
                prompt = _format_prompt(tok, q.text)
                start = time.time()
                response, n_tokens = _generate(tok, model, prompt, max_new_tokens, temp)
                score = grade(response, q.grader_metadata, max_new_tokens=max_new_tokens)
                rr = RungResult(
                    rung_id=rung.id,
                    params_b=rung.params_b,
                    query_id=q.id,
                    sample_index=s,
                    response=response,
                    score=score,
                    response_tokens=n_tokens,
                )
                # Persist immediately so a crash loses at most one sample.
                _append_rung_checkpoint(checkpoint_path, rr)
                results.append(rr)
                if progress:
                    print(
                        f" [{rung.id}] {i+1}/{len(pending)} sample={s} "
                        f"score={score:.2f} tok={n_tokens} t={time.time()-start:.1f}s"
                    )
    finally:
        _free_model(model)

    return results
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def derive_difficulty(
    per_rung: dict[str, list[float]],
    rung_params_b: dict[str, float],
    pass_threshold: float,
) -> float:
    """Continuous min_capable_log_params from per-rung mean scores.

    Logic:
    - sort rungs by parameter count
    - for each rung, mean score across samples is the "rung pass rate"
    - the smallest rung whose pass rate >= threshold defines the floor
    - linear interpolation in log(params) space between the failing and passing rung
    - if no rung passes, return log(largest_rung_params * 2) as out-of-pool
    - if smallest rung already passes, return log(smallest_rung_params)
    """
    ladder = sorted(rung_params_b.items(), key=lambda item: item[1])
    if not ladder:
        return math.log(8e9)

    # (params_b, mean score) for every rung that actually produced samples,
    # in ascending parameter order.
    observed: list[tuple[float, float]] = []
    for rung_id, params_b in ladder:
        samples = per_rung.get(rung_id, [])
        if samples:
            observed.append((params_b, sum(samples) / len(samples)))

    if not observed:
        return math.log(ladder[-1][1] * 1e9 * 2)

    first_params, first_rate = observed[0]
    if first_rate >= pass_threshold:
        return math.log(first_params * 1e9)

    for (lo_params, lo_rate), (hi_params, hi_rate) in zip(observed, observed[1:]):
        if hi_rate < pass_threshold:
            continue
        denom = max(hi_rate - lo_rate, 1e-6)
        frac = max(0.0, min(1.0, (pass_threshold - lo_rate) / denom))
        log_lo = math.log(lo_params * 1e9)
        log_hi = math.log(hi_params * 1e9)
        return log_lo + frac * (log_hi - log_lo)

    return math.log(observed[-1][0] * 1e9 * 2)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def derive_length_bucket(response_token_counts: list[int]) -> str:
    """Bucket the mean generated length: <100 short, <400 medium, else long.

    Defaults to "medium" when no samples are available.
    """
    if not response_token_counts:
        return "medium"
    mean_tokens = sum(response_token_counts) / len(response_token_counts)
    for cutoff, bucket in ((100, "short"), (400, "medium")):
        if mean_tokens < cutoff:
            return bucket
    return "long"
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def run_cascade(
    config,
    raw_manifest_path: str | Path,
    capability_labels_path: str | Path,
    train_path: str | Path,
    test_path: str | Path,
    pass_threshold: float = 0.7,
) -> None:
    """Run the full difficulty cascade and write train/test parquet files.

    One rung at a time (per the module's memory strategy): generate and
    grade all queries, checkpoint per-rung, then derive each query's
    continuous difficulty and length bucket, attach cached capability
    labels, and hand everything to the dataset writer.

    Queries with no cascade samples at all are dropped from the output.
    """
    # Local import avoids a circular import with builder.py, which imports
    # from this module's siblings.
    from greenrouting.data.builder import (
        read_capability_labels,
        write_labeled_dataset,
    )

    queries = _read_raw_manifest(raw_manifest_path)
    cap_labels = read_capability_labels(capability_labels_path)

    # rung_id -> (query_id -> samples)
    per_rung_results: dict[str, dict[str, list[RungResult]]] = {}
    out_dir = Path(config.output_dir)

    for rung in config.cascade.rungs:
        if not rung.runs_locally:
            # Remote rungs are out of scope here; flag and move on.
            print(f"[skip] {rung.id} marked as not runs_locally; configure remote backend.")
            continue
        ckpt = out_dir / f"cascade_{config.profile_name}_{rung.id}.jsonl"
        results = run_rung(
            rung,
            queries,
            k_samples=config.cascade.k_samples,
            max_new_tokens=config.cascade.max_new_tokens,
            temperature_first=config.cascade.temperature_first,
            temperature_resample=config.cascade.temperature_resample,
            checkpoint_path=ckpt,
        )
        by_query: dict[str, list[RungResult]] = {}
        for r in results:
            by_query.setdefault(r.query_id, []).append(r)
        per_rung_results[rung.id] = by_query
        print(f"[done] rung {rung.id}: {sum(len(v) for v in by_query.values())} samples")

    # Includes skipped rungs too; derive_difficulty only uses rungs with scores.
    rung_params: dict[str, float] = {r.id: r.params_b for r in config.cascade.rungs}

    labeled: list[LabeledQuery] = []
    for q in queries:
        per_rung_scores: dict[str, list[float]] = {}
        token_counts: list[int] = []
        for rung_id, by_query in per_rung_results.items():
            for rr in by_query.get(q.id, []):
                per_rung_scores.setdefault(rung_id, []).append(rr.score)
                token_counts.append(rr.response_tokens)
        if not per_rung_scores:
            # No rung produced a sample for this query; cannot label it.
            continue
        difficulty = derive_difficulty(per_rung_scores, rung_params, pass_threshold)
        length_bucket = derive_length_bucket(token_counts)
        caps = cap_labels.get(q.id, {})
        labeled.append(LabeledQuery(
            raw=q,
            capabilities=caps,
            difficulty_log_params=difficulty,
            length_bucket=length_bucket,
            cascade_results={
                # Keep only per-rung means for audit; raw responses stay in
                # the rung checkpoints.
                "per_rung_mean_scores": {
                    k: sum(v) / len(v) for k, v in per_rung_scores.items()
                },
            },
        ))

    write_labeled_dataset(
        train_path=train_path,
        test_path=test_path,
        rows=labeled,
        test_split=config.test_split,
        seed=config.seed,
    )
    print(f"[done] wrote {len(labeled)} labeled rows -> {train_path}, {test_path}")
|
greenrouting/data/graders.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Graders that score whether a model's response solved a query.
|
| 2 |
+
|
| 3 |
+
Outputs a bounded float in [0.0, 1.0]. Where a deterministic grader exists
|
| 4 |
+
(numeric extraction, multi-choice match), the score is binary. For free-form
|
| 5 |
+
responses we use a deterministic proxy (parseability, length) clamped to a
|
| 6 |
+
sensible range.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import ast
|
| 12 |
+
import re
|
| 13 |
+
from typing import Optional
|
| 14 |
+
|
| 15 |
+
NUMBER_PATTERN = re.compile(r"-?\d+(?:[.,]\d+)?")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _extract_last_number(text: str) -> Optional[float]:
|
| 19 |
+
matches = NUMBER_PATTERN.findall(text or "")
|
| 20 |
+
if not matches:
|
| 21 |
+
return None
|
| 22 |
+
last = matches[-1].replace(",", "")
|
| 23 |
+
try:
|
| 24 |
+
return float(last)
|
| 25 |
+
except ValueError:
|
| 26 |
+
return None
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _extract_first_letter(text: str, valid_letters: str = "ABCDEFGH") -> Optional[str]:
|
| 30 |
+
if not text:
|
| 31 |
+
return None
|
| 32 |
+
cleaned = text.strip()
|
| 33 |
+
m = re.search(
|
| 34 |
+
r"(?:answer|the answer is|final answer)\s*[:\-]?\s*\(?([" + valid_letters + r"])\)?",
|
| 35 |
+
cleaned,
|
| 36 |
+
re.IGNORECASE,
|
| 37 |
+
)
|
| 38 |
+
if m:
|
| 39 |
+
return m.group(1).upper()
|
| 40 |
+
m = re.match(r"\s*\(([" + valid_letters + r"])\)", cleaned)
|
| 41 |
+
if m:
|
| 42 |
+
return m.group(1)
|
| 43 |
+
m = re.match(r"\s*([" + valid_letters + r"])[\s\.\):,]+", cleaned)
|
| 44 |
+
if m:
|
| 45 |
+
return m.group(1)
|
| 46 |
+
m = re.match(r"\s*([" + valid_letters + r"])\s*$", cleaned)
|
| 47 |
+
if m:
|
| 48 |
+
return m.group(1)
|
| 49 |
+
return None
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def grade_numeric(response: str, gold: str) -> float:
    """Binary grade: last number of the response matches the gold's last number.

    Scores 0.0 when either side yields no parseable number.
    """
    expected = _extract_last_number(gold)
    predicted = _extract_last_number(response)
    if expected is None or predicted is None:
        return 0.0
    return float(abs(expected - predicted) < 1e-6)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def grade_multichoice(response: str, gold_letter: str) -> float:
    """Binary grade: extracted answer letter equals the gold letter.

    Case-insensitive; 0.0 when no gold letter is given or none is extracted.
    """
    if not gold_letter:
        return 0.0
    predicted = _extract_first_letter(response or "")
    if predicted is None:
        return 0.0
    return float(predicted.upper() == gold_letter.strip().upper())
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def grade_string_match(response: str, gold: str) -> float:
    """Binary grade: whitespace-normalized, lowercased gold appears in response.

    Scores 0.0 for an empty or whitespace-only gold string.
    """
    if not gold:
        return 0.0
    gold_norm = re.sub(r"\s+", " ", gold).strip().lower()
    if not gold_norm:
        return 0.0
    resp_norm = re.sub(r"\s+", " ", response or "").strip().lower()
    return 1.0 if gold_norm in resp_norm else 0.0
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def grade_code_proxy(response: str, entry_point: Optional[str] = None) -> float:
    """Cheap proxy for code correctness: code parses + (optionally) defines the
    expected entry-point function. No execution; safe to run on untrusted output.

    Scores: 0.0 empty/unparseable; 0.7 parses with no entry point requested;
    1.0 parses and defines `entry_point`; 0.4 parses but entry point missing.
    """
    if not response:
        return 0.0
    try:
        tree = ast.parse(_extract_code_block(response))
    except SyntaxError:
        return 0.0
    if not entry_point:
        return 0.7
    defines_entry = any(
        isinstance(node, ast.FunctionDef) and node.name == entry_point
        for node in ast.walk(tree)
    )
    return 1.0 if defines_entry else 0.4
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _extract_code_block(text: str) -> str:
|
| 98 |
+
fence = re.search(r"```(?:python|py)?\s*\n(.*?)```", text, re.DOTALL | re.IGNORECASE)
|
| 99 |
+
if fence:
|
| 100 |
+
return fence.group(1)
|
| 101 |
+
if "def " in text or "class " in text or "import " in text:
|
| 102 |
+
return text
|
| 103 |
+
return text
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def grade_response_quality(response: str, max_new_tokens: int) -> float:
    """Length-and-diversity heuristic for open-ended queries with no gold answer.

    0.0 for empty input, 0.05 for fewer than three whitespace tokens;
    otherwise a [0, 1] blend of length coverage (saturating at half the
    generation budget) and vocabulary diversity (unique-token ratio).
    """
    if not response:
        return 0.0
    words = response.split()
    if len(words) < 3:
        return 0.05
    diversity = len({w.lower() for w in words}) / len(words)
    coverage = min(1.0, len(words) / max(max_new_tokens * 0.5, 1))
    blended = 0.6 * coverage + 0.4 * diversity
    return max(0.0, min(1.0, blended))
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def grade_ifeval_proxy(response: str, instruction_id_list: list[str]) -> float:
    """Lightweight stand-in for IFEval's strict constraint grader. Counts how many
    structural cues from instruction IDs are present in the response."""
    if not response:
        return 0.0
    if not instruction_id_list:
        # No constraints to check: neutral score.
        return 0.5
    stripped = response.strip()
    score = 0.0
    for instruction_id in instruction_id_list:
        if "list" in instruction_id and re.search(r"^\s*[-*•]|^\s*\d+\.", response, re.MULTILINE):
            score += 1
        elif "json" in instruction_id and stripped.startswith(("{", "[")):
            score += 1
        elif "letter" in instruction_id:
            score += 1
        elif "word_count" in instruction_id:
            score += 1
        elif "uppercase" in instruction_id and any(ch.isupper() for ch in response):
            score += 1
        else:
            # Unknown constraint: grant partial credit.
            score += 0.5
    return min(1.0, score / max(len(instruction_id_list), 1))
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def grade(response: str, grader_metadata: dict, max_new_tokens: int = 256) -> float:
    """Dispatch to the grader named in `grader_metadata["grader"]`.

    Unknown or missing grader names fall back to the open-ended
    quality heuristic (`grade_response_quality`).
    """
    # Lambdas keep argument extraction lazy: only the chosen grader's
    # metadata keys are read.
    handlers = {
        "exact_numeric": lambda: grade_numeric(response, grader_metadata.get("gold_final", "")),
        "multichoice": lambda: grade_multichoice(response, grader_metadata.get("gold_letter", "")),
        "string_match": lambda: grade_string_match(response, grader_metadata.get("gold", "")),
        "code_exec": lambda: grade_code_proxy(response, grader_metadata.get("entry_point")),
        "ifeval_constraints": lambda: grade_ifeval_proxy(response, grader_metadata.get("instruction_id_list", [])),
    }
    handler = handlers.get(grader_metadata.get("grader"))
    if handler is not None:
        return handler()
    return grade_response_quality(response, max_new_tokens)
|
greenrouting/data/schema.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Schema for queries and labels flowing through the data pipeline."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass, field, asdict
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
from greenrouting.routing.registry import CAPABILITY_KEYS
|
| 9 |
+
|
| 10 |
+
# Allowed values for answer-length buckets (used e.g. by LabeledQuery.length_bucket).
LENGTH_BUCKETS: tuple[str, str, str] = ("short", "medium", "long")
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
class RawQuery:
    """A single query as ingested from a source dataset, before any labeling."""

    # Stable identifier for this query.
    id: str
    # The raw query/prompt text.
    text: str
    # Originating dataset name.
    source: str
    # Capability bucket assigned by the source (see seed data).
    source_category: str
    # Whether automatic grading info is available for this query.
    has_grader: bool = False
    # Auxiliary grading info (gold answers, entry points, instruction IDs);
    # presumably only meaningful when has_grader is True — verify against graders.
    grader_metadata: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Return a plain-dict copy of this record (dataclasses.asdict performs
        a recursive copy, so nested containers are not shared)."""
        return asdict(self)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass
class CapabilityVotes:
    """Per-voter capability scores for one query.

    `source_prior` is always present (defaults to empty); the four model/
    heuristic voters are `None` until they have actually voted.
    """

    source_prior: dict[str, float] = field(default_factory=dict)
    heuristic: Optional[dict[str, float]] = None
    gpt: Optional[dict[str, float]] = None
    claude: Optional[dict[str, float]] = None
    gemini: Optional[dict[str, float]] = None

    def vote_count(self) -> int:
        """Number of voters that have voted (source_prior is excluded)."""
        voters = (self.heuristic, self.gpt, self.claude, self.gemini)
        return len([scores for scores in voters if scores is not None])
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@dataclass
class CapabilityLabel:
    """Aggregated capability scores for one query, plus the per-voter inputs."""

    query_id: str
    capabilities: dict[str, float]
    votes: CapabilityVotes
    aggregation_method: str

    def to_record(self) -> dict:
        """Flatten into one row: `cap_<key>` columns for the aggregate, then
        `vote_<voter>_<key>` columns for every voter that actually voted
        (voters set to None are omitted entirely)."""
        record: dict = {
            "query_id": self.query_id,
            "aggregation_method": self.aggregation_method,
        }
        record.update(
            {f"cap_{key}": float(self.capabilities.get(key, 0.0)) for key in CAPABILITY_KEYS}
        )
        for voter in ("source_prior", "heuristic", "gpt", "claude", "gemini"):
            scores = getattr(self.votes, voter)
            if scores is None:
                continue
            record.update(
                {f"vote_{voter}_{key}": float(scores.get(key, 0.0)) for key in CAPABILITY_KEYS}
            )
        return record
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@dataclass
class LabeledQuery:
    """A RawQuery joined with its final labels, ready for export."""

    raw: RawQuery
    capabilities: dict[str, float]
    difficulty_log_params: Optional[float]
    length_bucket: Optional[str]
    cascade_results: dict = field(default_factory=dict)

    def to_record(self) -> dict:
        """Flatten into one export row; note `cascade_results` is not exported."""
        record: dict = {
            "id": self.raw.id,
            "text": self.raw.text,
            "source": self.raw.source,
            "source_category": self.raw.source_category,
            "has_grader": self.raw.has_grader,
            "difficulty_log_params": self.difficulty_log_params,
            "length_bucket": self.length_bucket,
        }
        for key in CAPABILITY_KEYS:
            record[f"cap_{key}"] = float(self.capabilities.get(key, 0.0))
        return record
|
greenrouting/data/seed_dataset.py
ADDED
|
@@ -0,0 +1,545 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Curated training set with multi-label capability assignments and difficulty.
|
| 2 |
+
|
| 3 |
+
Each entry is hand-authored. Phrasings deliberately vary across registers
|
| 4 |
+
(commands, questions, conversational asks, fragments) so the classifier learns
|
| 5 |
+
semantic patterns rather than surface keyword rules.
|
| 6 |
+
|
| 7 |
+
Schema for each entry: (text, primary_category, capabilities, difficulty_b, length).
|
| 8 |
+
- text: the raw query string
|
| 9 |
+
- primary_category: dominant capability bucket (used as `source_category`)
|
| 10 |
+
- capabilities: list of buckets that apply (multi-label)
|
| 11 |
+
- difficulty_b: parameter count (in billions) of the smallest model that would
|
| 12 |
+
plausibly handle this well; used to derive `difficulty_log_params`
|
| 13 |
+
- length: expected answer length bucket (short/medium/long)
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import math
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass(frozen=True)
class SeedEntry:
    """One hand-authored training example (see module docstring for schema)."""

    # Raw query string.
    text: str
    # Dominant capability bucket; used as `source_category`.
    primary_category: str
    # All capability buckets that apply (multi-label).
    capabilities: tuple[str, ...]
    # Parameter count (billions) of the smallest model expected to handle this.
    difficulty_b: float
    # Expected answer length bucket: "short" / "medium" / "long".
    length: str
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _e(text, primary, caps, diff_b, length) -> SeedEntry:
|
| 32 |
+
return SeedEntry(text=text, primary_category=primary, capabilities=tuple(caps), difficulty_b=diff_b, length=length)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# Greetings, acknowledgements, and other small-talk turns. Nearly everything
# here is trivially answerable (difficulty_b <= 1.0, "short" replies); the
# handful that also ask for explanation or opinion carry a secondary
# "instruction"/"reasoning" label and a slightly higher difficulty.
_SIMPLE_CHAT: list[SeedEntry] = [
    _e("hi", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("hello", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("hey there", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("good morning", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("how's it going", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("thanks!", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("thank you so much", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("appreciate it", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("ok cool", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("got it", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("sounds good", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("nice", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("makes sense", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("yep", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("yes please", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("nope", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("not really", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("can you help me?", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("are you there?", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("what can you do?", "simple_chat", ["simple_chat", "instruction"], 1.0, "short"),
    _e("how does this work", "simple_chat", ["simple_chat", "instruction"], 1.0, "short"),
    _e("who are you?", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("are you human?", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("good night", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("see you later", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("bye!", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("lol", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("haha that's funny", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("oh interesting", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("hmm okay", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("can you elaborate", "simple_chat", ["simple_chat", "instruction"], 1.0, "short"),
    _e("tell me more", "simple_chat", ["simple_chat", "instruction"], 1.0, "short"),
    _e("what do you think?", "simple_chat", ["simple_chat", "reasoning"], 1.0, "short"),
    _e("any thoughts?", "simple_chat", ["simple_chat", "reasoning"], 1.0, "short"),
    _e("got a sec?", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("quick question", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("just checking in", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("how was your day", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("what's up", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("you good?", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("yeah that works", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("alright then", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("hold on a sec", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("never mind", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("scratch that", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("oops", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("my bad", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("you're awesome", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("this is great", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("perfect, thanks", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("can we try again", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("one more time", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("again please", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("yo", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("sup", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("howdy", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("greetings", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("test", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("just testing", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("can you hear me", "simple_chat", ["simple_chat"], 0.5, "short"),
]
|
| 97 |
+
|
| 98 |
+
# Formatting / rewriting / drafting tasks. Difficulty tracks how much judgment
# the rewrite needs: mechanical reformatting sits at 1.0-3.0B, while drafting
# that must synthesize or reorganize content (recaps, release notes, root-cause
# grouping) is 7.0-8.0B. Creative drafting additionally carries "creative";
# tasks requiring inference over input carry "reasoning".
_INSTRUCTION: list[SeedEntry] = [
    _e("write a 3-bullet summary of the agenda below: 1) opening remarks 2) Q3 numbers 3) staffing", "instruction", ["instruction"], 3.0, "short"),
    _e("turn this into a numbered list", "instruction", ["instruction"], 1.0, "short"),
    _e("format the following as a markdown table", "instruction", ["instruction"], 2.0, "medium"),
    _e("rewrite this in a more formal tone", "instruction", ["instruction"], 3.0, "medium"),
    _e("make this paragraph shorter without losing any meaning", "instruction", ["instruction"], 3.0, "medium"),
    _e("expand this outline into a full paragraph", "instruction", ["instruction"], 3.0, "medium"),
    _e("draft a polite decline to a meeting invite", "instruction", ["instruction", "creative"], 3.0, "short"),
    _e("write a short subject line for this email", "instruction", ["instruction"], 1.0, "short"),
    _e("give me three name ideas for a coffee shop near a park", "instruction", ["instruction", "creative"], 3.0, "short"),
    _e("create a checklist for moving apartments", "instruction", ["instruction"], 3.0, "medium"),
    _e("outline a 5-day study plan for the GRE quantitative section", "instruction", ["instruction"], 7.0, "medium"),
    _e("summarize this in two sentences please", "instruction", ["instruction"], 3.0, "short"),
    _e("convert these notes into a clean meeting recap", "instruction", ["instruction"], 7.0, "medium"),
    _e("turn the following log lines into a single sentence describing what happened", "instruction", ["instruction", "reasoning"], 7.0, "short"),
    _e("group these errors by likely root cause", "instruction", ["instruction", "reasoning"], 8.0, "medium"),
    _e("clean up the grammar in the paragraph below", "instruction", ["instruction"], 1.0, "medium"),
    _e("rephrase this so a 12-year-old could follow", "instruction", ["instruction"], 3.0, "medium"),
    _e("make a one-paragraph executive summary", "instruction", ["instruction"], 7.0, "short"),
    _e("draft an out-of-office reply for next Friday", "instruction", ["instruction", "creative"], 1.0, "short"),
    _e("create a 7-day workout split focused on legs and shoulders", "instruction", ["instruction"], 7.0, "medium"),
    _e("write a packing list for a 4-day winter trip to Reykjavik", "instruction", ["instruction"], 3.0, "medium"),
    _e("turn the bullet points below into a smooth paragraph", "instruction", ["instruction"], 3.0, "medium"),
    _e("transform this dry product description into a punchy tagline", "instruction", ["instruction", "creative"], 7.0, "short"),
    _e("draft a thank-you note to my mentor", "instruction", ["instruction", "creative"], 3.0, "short"),
    _e("write meeting minutes from this transcript snippet", "instruction", ["instruction"], 7.0, "medium"),
    _e("split this monolithic to-do into morning vs afternoon", "instruction", ["instruction"], 1.0, "medium"),
    _e("rewrite this slack message so it doesn't sound passive aggressive", "instruction", ["instruction"], 3.0, "short"),
    _e("draft a customer apology email after a 3-hour outage", "instruction", ["instruction", "creative"], 7.0, "medium"),
    _e("turn this transcript into 5 talking points", "instruction", ["instruction"], 3.0, "medium"),
    _e("convert this case study into a single tweet", "instruction", ["instruction", "creative"], 3.0, "short"),
    _e("write release notes for version 2.3 covering the changelog below", "instruction", ["instruction"], 7.0, "medium"),
    _e("make a side-by-side comparison table of the two job offers", "instruction", ["instruction", "reasoning"], 7.0, "medium"),
    _e("compose a brief biography paragraph for a conference badge", "instruction", ["instruction", "creative"], 3.0, "short"),
    _e("write a 4-line elevator pitch for an indie video game", "instruction", ["instruction", "creative"], 3.0, "short"),
    _e("expand the following acronyms inline: API, SLA, P95, RAG", "instruction", ["instruction", "knowledge"], 3.0, "short"),
    _e("create a polite reminder to a colleague who hasn't reviewed my PR", "instruction", ["instruction"], 1.0, "short"),
    _e("draft an FAQ entry explaining how refunds work for SaaS subscriptions", "instruction", ["instruction"], 7.0, "medium"),
    _e("turn this jira ticket description into a one-line standup update", "instruction", ["instruction"], 1.0, "short"),
    _e("write a job description for a junior data analyst", "instruction", ["instruction"], 7.0, "medium"),
    _e("convert the recipe below from imperial to metric", "instruction", ["instruction"], 1.0, "short"),
    _e("organize the chaos in this email thread into a clean timeline", "instruction", ["instruction", "reasoning"], 7.0, "medium"),
    _e("extract every dollar amount from the contract clause and list them", "instruction", ["instruction"], 3.0, "short"),
    _e("write three follow-up email subject lines, increasing in urgency", "instruction", ["instruction", "creative"], 3.0, "short"),
    _e("rewrite this resume bullet to emphasize impact", "instruction", ["instruction"], 3.0, "short"),
    _e("draft a 1-paragraph linkedin post announcing a job change", "instruction", ["instruction", "creative"], 7.0, "short"),
    _e("turn the bullet recap below into a polished retro doc section", "instruction", ["instruction"], 7.0, "medium"),
    _e("compose a clear bug report from this user complaint", "instruction", ["instruction", "reasoning"], 7.0, "medium"),
    _e("create a 5-question pre-interview survey for prospective tenants", "instruction", ["instruction"], 3.0, "medium"),
    _e("rewrite the warning copy below to be friendlier without losing meaning", "instruction", ["instruction"], 3.0, "short"),
    _e("write a one-line apology, a longer apology, and a formal apology", "instruction", ["instruction", "creative"], 3.0, "medium"),
]
|
| 150 |
+
|
| 151 |
+
# Factual-recall and explanation queries. Single-fact lookups sit at 0.5-1.0B;
# concept explanations at 3.0-8.0B; topics needing deep technical synthesis
# (Galois fields, transformers vs RNNs, MRI physics) at 30.0B. Queries that ask
# for a particular presentation ("in plain terms", "in two sentences") also
# carry "instruction"; compare/mechanism questions also carry "reasoning".
_KNOWLEDGE: list[SeedEntry] = [
    _e("what's the capital of Mongolia", "knowledge", ["knowledge"], 1.0, "short"),
    _e("when was the printing press invented", "knowledge", ["knowledge"], 1.0, "short"),
    _e("what's the population of Lagos roughly", "knowledge", ["knowledge"], 1.0, "short"),
    _e("who painted Guernica", "knowledge", ["knowledge"], 1.0, "short"),
    _e("how long was the Hundred Years War actually", "knowledge", ["knowledge"], 3.0, "short"),
    _e("what protein does insulin regulate", "knowledge", ["knowledge"], 7.0, "short"),
    _e("what's the difference between a virus and a bacterium in plain terms", "knowledge", ["knowledge", "instruction"], 7.0, "medium"),
    _e("name the three branches of the US government", "knowledge", ["knowledge"], 1.0, "short"),
    _e("what year did World War I end", "knowledge", ["knowledge"], 1.0, "short"),
    _e("who wrote One Hundred Years of Solitude", "knowledge", ["knowledge"], 1.0, "short"),
    _e("what's photosynthesis in one sentence", "knowledge", ["knowledge", "instruction"], 1.0, "short"),
    _e("how many bones are in the human body", "knowledge", ["knowledge"], 1.0, "short"),
    _e("what does GDP stand for and what does it measure", "knowledge", ["knowledge"], 3.0, "short"),
    _e("which planet has the most moons", "knowledge", ["knowledge"], 1.0, "short"),
    _e("what's the chemical symbol for gold", "knowledge", ["knowledge"], 0.5, "short"),
    _e("define entropy in thermodynamics", "knowledge", ["knowledge"], 7.0, "medium"),
    _e("explain the citric acid cycle briefly", "knowledge", ["knowledge", "instruction"], 8.0, "medium"),
    _e("what is OAuth 2.0 and what problem does it solve", "knowledge", ["knowledge", "instruction"], 7.0, "medium"),
    _e("describe the Marshall Plan in two sentences", "knowledge", ["knowledge", "instruction"], 7.0, "short"),
    _e("what's the difference between TCP and UDP", "knowledge", ["knowledge", "reasoning"], 7.0, "medium"),
    _e("who founded the Stoic philosophical school", "knowledge", ["knowledge"], 3.0, "short"),
    _e("what did Rosalind Franklin contribute to DNA research", "knowledge", ["knowledge"], 7.0, "medium"),
    _e("explain RAID 5 vs RAID 10 storage", "knowledge", ["knowledge", "reasoning"], 8.0, "medium"),
    _e("what is monetary policy in plain language", "knowledge", ["knowledge", "instruction"], 7.0, "medium"),
    _e("how does a vaccine actually trigger immunity", "knowledge", ["knowledge", "reasoning"], 8.0, "medium"),
    _e("what was the Bretton Woods system", "knowledge", ["knowledge"], 7.0, "medium"),
    _e("who was the first woman to win a Nobel Prize", "knowledge", ["knowledge"], 1.0, "short"),
    _e("what's the longest river in South America", "knowledge", ["knowledge"], 1.0, "short"),
    _e("describe the role of the prefrontal cortex", "knowledge", ["knowledge", "instruction"], 8.0, "medium"),
    _e("define inflation vs deflation simply", "knowledge", ["knowledge"], 3.0, "short"),
    _e("what is BGP in networking", "knowledge", ["knowledge"], 7.0, "short"),
    _e("explain what a CDN does", "knowledge", ["knowledge", "instruction"], 3.0, "short"),
    _e("what's a Galois field", "knowledge", ["knowledge", "math"], 30.0, "medium"),
    _e("who invented the World Wide Web", "knowledge", ["knowledge"], 1.0, "short"),
    _e("what does the term 'eventual consistency' mean in databases", "knowledge", ["knowledge"], 8.0, "medium"),
    _e("brief overview of the Treaty of Westphalia", "knowledge", ["knowledge", "instruction"], 8.0, "medium"),
    _e("what is the boiling point of water in Kelvin", "knowledge", ["knowledge"], 0.5, "short"),
    _e("what's the Coriolis effect", "knowledge", ["knowledge"], 3.0, "short"),
    _e("explain Bayes' theorem in everyday language", "knowledge", ["knowledge", "math", "instruction"], 8.0, "medium"),
    _e("what is the speed of light in vacuum", "knowledge", ["knowledge"], 0.5, "short"),
    _e("when did the Berlin Wall fall", "knowledge", ["knowledge"], 1.0, "short"),
    _e("what does HTTP/3 use under the hood", "knowledge", ["knowledge"], 8.0, "short"),
    _e("describe the Heisenberg uncertainty principle", "knowledge", ["knowledge", "instruction"], 8.0, "medium"),
    _e("what's the difference between a sonnet and a haiku", "knowledge", ["knowledge"], 1.0, "short"),
    _e("name the inert noble gases", "knowledge", ["knowledge"], 1.0, "short"),
    _e("how does a transformer model differ from an RNN at a high level", "knowledge", ["knowledge", "reasoning"], 30.0, "medium"),
    _e("what was the Marshall McLuhan claim about media", "knowledge", ["knowledge"], 7.0, "short"),
    _e("explain quantitative easing", "knowledge", ["knowledge", "instruction"], 8.0, "medium"),
    _e("what's the half-life of carbon-14", "knowledge", ["knowledge"], 1.0, "short"),
    _e("who composed The Rite of Spring", "knowledge", ["knowledge"], 1.0, "short"),
    _e("what does the term 'antifragile' mean (Taleb)", "knowledge", ["knowledge"], 3.0, "short"),
    _e("how does an MRI machine work", "knowledge", ["knowledge", "instruction"], 30.0, "medium"),
    _e("define cognitive dissonance", "knowledge", ["knowledge"], 3.0, "short"),
    _e("what's the difference between fission and fusion", "knowledge", ["knowledge"], 7.0, "short"),
    _e("how big is the Milky Way galaxy in light years", "knowledge", ["knowledge"], 1.0, "short"),
    _e("what is gerrymandering", "knowledge", ["knowledge"], 3.0, "short"),
    _e("define IPO in finance", "knowledge", ["knowledge"], 1.0, "short"),
    _e("what is dark matter, briefly", "knowledge", ["knowledge"], 7.0, "short"),
    _e("how did the Silk Road shape trade", "knowledge", ["knowledge", "reasoning"], 8.0, "medium"),
]
|
| 212 |
+
|
| 213 |
+
# Programming queries across languages and infra. One-liner snippets and git
# commands sit at 0.5-3.0B; framework-specific or design-sensitive tasks at
# 7.0-8.0B; systems-level implementations (thread-safe caches, CUDA, operators)
# at 30.0-70.0B. Debugging/design questions also carry "reasoning";
# concept-explanation questions also carry "knowledge".
_CODE: list[SeedEntry] = [
    _e("write a python function that reverses a string", "code", ["code"], 0.5, "short"),
    _e("show me how to read a file line by line in python", "code", ["code"], 1.0, "short"),
    _e("what's the difference between == and is in python", "code", ["code", "knowledge"], 3.0, "short"),
    _e("how do I sort a list of dicts by a key in python", "code", ["code"], 1.0, "short"),
    _e("write javascript that fetches /api/users and logs the result", "code", ["code"], 1.0, "short"),
    _e("debounce function in javascript please", "code", ["code"], 3.0, "short"),
    _e("css to center a div both vertically and horizontally", "code", ["code"], 1.0, "short"),
    _e("regex to validate an email address (rough)", "code", ["code"], 3.0, "short"),
    _e("write a SQL query to find duplicates by email column", "code", ["code"], 3.0, "short"),
    _e("explain what this stack trace means: TypeError: cannot unpack non-iterable NoneType object", "code", ["code", "reasoning"], 7.0, "medium"),
    _e("rust function that returns the nth fibonacci number iteratively", "code", ["code"], 3.0, "short"),
    _e("typescript type for a paginated API response", "code", ["code"], 7.0, "medium"),
    _e("write a python decorator that times function calls", "code", ["code"], 7.0, "short"),
    _e("dockerfile for a flask app on python 3.11", "code", ["code"], 7.0, "medium"),
    _e("kubernetes deployment yaml for a 3-replica stateless service on port 8080", "code", ["code"], 8.0, "medium"),
    _e("write a recursive haskell function for tree depth", "code", ["code"], 8.0, "short"),
    _e("implement a thread-safe LRU cache in rust", "code", ["code"], 30.0, "long"),
    _e("explain how Promise.all differs from Promise.allSettled", "code", ["code", "reasoning"], 7.0, "medium"),
    _e("write a postgres query that windows daily revenue with a 7-day moving average", "code", ["code", "math"], 8.0, "medium"),
    _e("debug this: my python script crashes with 'IndexError: list assignment index out of range' on line 22", "code", ["code", "reasoning"], 7.0, "medium"),
    _e("how do I cancel an in-flight fetch request", "code", ["code"], 3.0, "short"),
    _e("write a fastapi endpoint with pydantic validation for a user signup", "code", ["code"], 7.0, "medium"),
    _e("python script to parse a CSV and emit JSONL", "code", ["code"], 3.0, "short"),
    _e("show me an idempotent stripe webhook handler in node", "code", ["code", "reasoning"], 8.0, "medium"),
    _e("convert this python list comprehension to a generator expression", "code", ["code"], 1.0, "short"),
    _e("explain the difference between mutex and semaphore with code", "code", ["code", "knowledge"], 8.0, "long"),
    _e("git command to undo the last commit but keep my changes staged", "code", ["code"], 1.0, "short"),
    _e("git rebase vs merge - which should I use for a feature branch", "code", ["code", "reasoning"], 7.0, "medium"),
    _e("how do I write an async iterator in python", "code", ["code"], 7.0, "short"),
    _e("write a smart contract in solidity that escrows a payment", "code", ["code"], 30.0, "long"),
    _e("rewrite this loop using map and filter", "code", ["code"], 1.0, "short"),
    _e("c++ template for a generic queue", "code", ["code"], 8.0, "medium"),
    _e("write go code that gracefully shuts down an HTTP server on SIGTERM", "code", ["code"], 8.0, "medium"),
    _e("design and implement a rate limiter middleware in express", "code", ["code", "reasoning"], 8.0, "long"),
    _e("convert callback-style fs.readFile to a promise-returning version", "code", ["code"], 1.0, "short"),
    _e("regex to capture all URLs from a markdown document", "code", ["code"], 7.0, "short"),
    _e("write a python class that wraps the openai API with retry+backoff", "code", ["code", "reasoning"], 8.0, "medium"),
    _e("explain what __slots__ does in python", "code", ["code", "knowledge"], 7.0, "short"),
    _e("how do I migrate from sqlalchemy 1.4 to 2.0", "code", ["code", "knowledge"], 8.0, "medium"),
    _e("write a github actions workflow that runs pytest on push and PR", "code", ["code"], 7.0, "medium"),
    _e("typescript function that pipes through a series of validators", "code", ["code"], 8.0, "medium"),
    _e("how to debounce a search input in react with useEffect", "code", ["code"], 3.0, "short"),
    _e("python: implement a memoize decorator that respects argument types", "code", ["code"], 7.0, "medium"),
    _e("explain what a CRDT is and sketch a counter implementation", "code", ["code", "knowledge", "reasoning"], 30.0, "long"),
    _e("write a custom hook that tracks window scroll position", "code", ["code"], 3.0, "short"),
    _e("redis lua script that atomically pops from a sorted set if score < now", "code", ["code"], 8.0, "medium"),
    _e("optimize this O(n^2) python loop", "code", ["code", "reasoning", "math"], 8.0, "medium"),
    _e("aws cdk stack for an s3 bucket and a cloudfront distribution", "code", ["code"], 8.0, "medium"),
    _e("explain what 'use strict' does in javascript", "code", ["code", "knowledge"], 1.0, "short"),
    _e("how do I configure cors for a flask app", "code", ["code"], 1.0, "short"),
    _e("write a kubernetes operator scaffold in go", "code", ["code", "reasoning"], 70.0, "long"),
    _e("difference between a thread and a coroutine in python", "code", ["code", "knowledge"], 7.0, "medium"),
    _e("write a scala function that computes mean and std in a single pass", "code", ["code", "math"], 8.0, "medium"),
    _e("set up a basic CI/CD pipeline with gitlab ci for a node project", "code", ["code"], 7.0, "medium"),
    _e("write a python pytest fixture that spins up a docker postgres", "code", ["code"], 8.0, "medium"),
    _e("kotlin extension function to clamp a number between min and max", "code", ["code"], 1.0, "short"),
    _e("write a CUDA kernel for vector addition with bounds checks", "code", ["code", "reasoning"], 30.0, "medium"),
    _e("draft a Dockerfile that uses multi-stage builds for a typescript app", "code", ["code"], 8.0, "medium"),
    _e("how to use websockets in fastapi", "code", ["code"], 7.0, "medium"),
    _e("show me how to mock fetch in jest", "code", ["code"], 3.0, "short"),
    _e("python function to flatten a deeply nested dict using dot notation keys", "code", ["code"], 7.0, "medium"),
    _e("explain the producer-consumer problem with a go example", "code", ["code", "knowledge"], 8.0, "long"),
    _e("solidity: ERC-20 token with a 5% transfer tax to a treasury address", "code", ["code"], 8.0, "long"),
    _e("write terraform that creates a vpc with two private subnets and a NAT gateway", "code", ["code"], 8.0, "medium"),
    _e("how does python's GIL affect multithreaded CPU-bound code", "code", ["code", "knowledge", "reasoning"], 8.0, "medium"),
    _e("rewrite this callback hell into async/await", "code", ["code"], 3.0, "short"),
    _e("python script to backfill missing dates in a pandas time series", "code", ["code"], 7.0, "medium"),
    _e("explain why my recursive function hits a RecursionError on n=1500", "code", ["code", "reasoning"], 7.0, "medium"),
    _e("c program: read stdin and print each line reversed", "code", ["code"], 3.0, "short"),
    _e("write a chrome extension manifest v3 that injects a script into all pages", "code", ["code"], 8.0, "medium"),
]
|
| 285 |
+
|
| 286 |
+
# Arithmetic through proof-level mathematics. Mental arithmetic and unit
# conversion sit at 0.5-1.0B; symbolic manipulation (calculus, linear algebra)
# at 3.0-8.0B; derivations and convergence arguments at 30.0B. Multi-step
# proofs/derivations also carry "reasoning"; "show the steps" / "explain"
# requests also carry "instruction".
_MATH: list[SeedEntry] = [
    _e("what is 17 + 25", "math", ["math"], 0.5, "short"),
    _e("compute 12 * 14 - 8", "math", ["math"], 0.5, "short"),
    _e("what's 25 percent of 480", "math", ["math"], 1.0, "short"),
    _e("solve 3x + 7 = 22", "math", ["math"], 1.0, "short"),
    _e("what's the area of a circle with radius 5", "math", ["math"], 1.0, "short"),
    _e("convert 75 fahrenheit to celsius", "math", ["math"], 0.5, "short"),
    _e("what is the slope between (1,2) and (4,8)", "math", ["math"], 1.0, "short"),
    _e("simplify 3/4 + 5/6", "math", ["math"], 1.0, "short"),
    _e("if a train leaves at 60 mph, how long to travel 240 miles", "math", ["math"], 1.0, "short"),
    _e("solve x^2 - 5x + 6 = 0", "math", ["math"], 3.0, "short"),
    _e("what's the integral of x^2 dx", "math", ["math"], 3.0, "short"),
    _e("integrate x^2 sin(x) dx using integration by parts, show the steps", "math", ["math", "instruction"], 8.0, "medium"),
    _e("compute the determinant of [[2,1],[3,4]]", "math", ["math"], 3.0, "short"),
    _e("what's the derivative of e^(x^2)", "math", ["math"], 3.0, "short"),
    _e("derive the chain rule from first principles", "math", ["math", "reasoning"], 30.0, "long"),
    _e("prove the Pythagorean theorem geometrically", "math", ["math", "reasoning"], 8.0, "medium"),
    _e("expectation of a roll of two fair dice", "math", ["math"], 3.0, "short"),
    _e("variance of the same two dice setup", "math", ["math"], 7.0, "short"),
    _e("what is 2^10 - 1", "math", ["math"], 0.5, "short"),
    _e("compute the standard deviation of {2,4,4,4,5,5,7,9}", "math", ["math"], 3.0, "short"),
    _e("apply L'Hopital's rule to lim x->0 (sin x)/x", "math", ["math", "reasoning"], 7.0, "short"),
    _e("solve the system: 2x + y = 5, x - y = 1", "math", ["math"], 1.0, "short"),
    _e("find the inverse of the matrix [[1,2],[3,4]]", "math", ["math"], 3.0, "short"),
    # Two deliberate near-duplicates: diagonal eigenvalues are read off (easy),
    # the non-diagonal case needs the characteristic polynomial (harder).
    _e("compute the eigenvalues of [[2,0],[0,3]]", "math", ["math"], 3.0, "short"),
    _e("compute eigenvalues of [[4,1],[2,3]]", "math", ["math", "reasoning"], 8.0, "medium"),
    _e("if I invest $10000 at 5% APR compounded monthly for 6 years what's the final amount", "math", ["math"], 3.0, "short"),
    _e("birthday paradox: probability that 23 people share a birthday", "math", ["math", "reasoning"], 8.0, "medium"),
    _e("expected value of a martingale strategy on a fair coin", "math", ["math", "reasoning"], 30.0, "medium"),
    _e("derive the formula for the sum of the first n integers", "math", ["math", "reasoning"], 7.0, "medium"),
    _e("Taylor expand cos(x) around 0 to fourth order", "math", ["math"], 8.0, "medium"),
    _e("convert 1101 binary to decimal", "math", ["math"], 1.0, "short"),
    _e("convert 255 to hex", "math", ["math"], 1.0, "short"),
    _e("what is gcd(48, 180)", "math", ["math"], 1.0, "short"),
    _e("solve cos(2x) = 1/2 for x in [0, 2pi]", "math", ["math"], 8.0, "short"),
    _e("derive Euler's formula e^(ix) = cos x + i sin x informally", "math", ["math", "reasoning"], 30.0, "medium"),
    _e("what's the probability of rolling at least one 6 in four dice rolls", "math", ["math", "reasoning"], 7.0, "short"),
    _e("if X ~ N(0,1), what's P(|X| > 1.96)", "math", ["math"], 7.0, "short"),
    _e("show that sqrt(2) is irrational", "math", ["math", "reasoning"], 8.0, "medium"),
    _e("limit of (1 + 1/n)^n as n -> inf", "math", ["math"], 3.0, "short"),
    _e("compute 7! / 3!", "math", ["math"], 1.0, "short"),
    _e("how many ways to arrange the letters in MISSISSIPPI", "math", ["math"], 3.0, "short"),
    _e("explain the central limit theorem with intuition", "math", ["math", "knowledge", "instruction"], 8.0, "medium"),
    _e("derive the quadratic formula", "math", ["math", "reasoning"], 7.0, "medium"),
    _e("matrix multiplication: [[1,2],[3,4]] times [[5,6],[7,8]]", "math", ["math"], 3.0, "short"),
    _e("find roots of 2x^3 - 9x^2 + 12x - 4 = 0", "math", ["math", "reasoning"], 30.0, "medium"),
    _e("compute the gradient of f(x,y) = x^2 y + sin(xy)", "math", ["math"], 8.0, "short"),
    _e("explain Cauchy-Schwarz inequality intuitively", "math", ["math", "instruction"], 30.0, "medium"),
    _e("solve the recurrence T(n) = 2 T(n/2) + n with master theorem", "math", ["math", "reasoning"], 8.0, "medium"),
    _e("what is the cardinality of the power set of {a, b, c, d}", "math", ["math"], 1.0, "short"),
    _e("integrate by parts: integral of ln(x) dx", "math", ["math"], 3.0, "short"),
    _e("monty hall problem - explain why switching is better", "math", ["math", "reasoning", "instruction"], 7.0, "medium"),
    _e("which is bigger: e^pi or pi^e", "math", ["math", "reasoning"], 30.0, "short"),
    _e("derive the formula for the sum of an infinite geometric series", "math", ["math", "reasoning"], 7.0, "medium"),
    _e("matrix rank of [[1,2,3],[2,4,6],[1,1,1]]", "math", ["math"], 7.0, "short"),
    _e("expand (1 + x)^5 using the binomial theorem", "math", ["math"], 3.0, "short"),
    _e("under what conditions does fixed-point iteration converge", "math", ["math", "reasoning"], 30.0, "medium"),
    _e("Bayes update: P(disease)=0.001, sensitivity 0.99, specificity 0.95, what's P(disease | positive)", "math", ["math", "reasoning"], 8.0, "medium"),
    _e("what's the kernel of the linear map T(x,y,z) = (x+y, y+z)", "math", ["math"], 30.0, "short"),
    _e("solve the differential equation y' + y = e^x", "math", ["math"], 8.0, "medium"),
    _e("explain the fundamental theorem of calculus", "math", ["math", "instruction"], 8.0, "medium"),
]
|
| 348 |
+
|
| 349 |
+
# Hand-written seed prompts for the "reasoning" capability: trade-off analysis,
# debugging scenarios, and argue-both-sides questions. Each _e(...) appears to be
# (text, primary category, capability labels, difficulty score, expected answer
# length); the difficulty score looks like billions of parameters — see
# difficulty_log_params_from_b below — TODO confirm against _e's definition.
_REASONING: list[SeedEntry] = [
    _e("why does scaling laws matter for LLM development", "reasoning", ["reasoning", "knowledge"], 30.0, "medium"),
    _e("compare the trade-offs of postgres vs dynamodb for an event store", "reasoning", ["reasoning", "knowledge"], 30.0, "long"),
    _e("why might a microservice architecture hurt a 10-engineer team", "reasoning", ["reasoning"], 8.0, "medium"),
    _e("what's the failure mode of using exponential backoff without jitter", "reasoning", ["reasoning", "knowledge"], 8.0, "medium"),
    _e("argue both sides of remote vs in-office for early-stage startups", "reasoning", ["reasoning", "instruction"], 8.0, "long"),
    _e("if my page load is slow but TTFB is fast, what's the likely cause", "reasoning", ["reasoning"], 7.0, "medium"),
    _e("walk me through how you'd debug a memory leak in a long-running node process", "reasoning", ["reasoning", "instruction"], 8.0, "long"),
    _e("compare optimistic and pessimistic concurrency control", "reasoning", ["reasoning", "knowledge"], 8.0, "medium"),
    _e("why is two-phase commit considered a poor primitive in modern distributed systems", "reasoning", ["reasoning", "knowledge"], 30.0, "long"),
    _e("when should you prefer SSE over websockets for a real-time feed", "reasoning", ["reasoning", "knowledge"], 8.0, "medium"),
    _e("steel-man the case against test-driven development", "reasoning", ["reasoning"], 8.0, "medium"),
    _e("compare functional programming and OOP for modeling a payments domain", "reasoning", ["reasoning"], 8.0, "long"),
    _e("evaluate the trade-offs of using GraphQL over REST for a mobile app", "reasoning", ["reasoning"], 8.0, "long"),
    _e("which is more useful for a startup: net dollar retention or activation rate", "reasoning", ["reasoning"], 8.0, "short"),
    _e("how would you decide whether to migrate from MySQL to Postgres", "reasoning", ["reasoning", "instruction"], 8.0, "long"),
    _e("when is event sourcing worth the complexity", "reasoning", ["reasoning", "knowledge"], 30.0, "medium"),
    _e("if our latency is fine but error budget is burning, where do you look first", "reasoning", ["reasoning"], 7.0, "medium"),
    _e("compare batch and streaming pipelines for fraud detection", "reasoning", ["reasoning"], 30.0, "long"),
    _e("walk through the classic prisoner's dilemma and its iterated form", "reasoning", ["reasoning", "knowledge"], 7.0, "medium"),
    _e("argue why YAGNI sometimes leads to expensive rewrites", "reasoning", ["reasoning"], 8.0, "medium"),
    _e("if my classifier has high precision but low recall, what does that mean for the user", "reasoning", ["reasoning", "knowledge"], 7.0, "medium"),
    _e("evaluate the claim 'AI will replace junior developers within 5 years'", "reasoning", ["reasoning"], 8.0, "long"),
    _e("when should sharding precede vertical scaling for a postgres workload", "reasoning", ["reasoning"], 8.0, "medium"),
    _e("explain why eventual consistency is acceptable for like counts but not bank balances", "reasoning", ["reasoning", "instruction"], 8.0, "medium"),
    _e("compare risk profiles of monolith vs microservices for a 3-person team", "reasoning", ["reasoning"], 8.0, "long"),
    _e("why might a 99.9% uptime SLA actually be expensive", "reasoning", ["reasoning"], 7.0, "medium"),
    _e("argue the side that says feature flags are technical debt", "reasoning", ["reasoning"], 7.0, "medium"),
    _e("trace through the implications of removing rate limits on a public API", "reasoning", ["reasoning"], 7.0, "medium"),
    _e("when would you choose a graph database over a relational one", "reasoning", ["reasoning", "knowledge"], 8.0, "medium"),
    _e("rebut the claim that 'tabs are better than spaces because of accessibility'", "reasoning", ["reasoning"], 7.0, "medium"),
    _e("you have an outage with no logs, walk me through your first 10 minutes", "reasoning", ["reasoning", "instruction"], 8.0, "long"),
    _e("compare CAP theorem trade-offs in cassandra vs cockroachdb", "reasoning", ["reasoning", "knowledge"], 30.0, "long"),
    _e("evaluate whether react server components are the right call for a content site", "reasoning", ["reasoning"], 8.0, "medium"),
    _e("why is pursuing 100% test coverage often a mistake", "reasoning", ["reasoning"], 7.0, "medium"),
    _e("argue both sides: should we adopt typescript for our 5-year-old js codebase", "reasoning", ["reasoning"], 8.0, "long"),
    _e("how would you identify whether AI-generated commits are sneaking past review", "reasoning", ["reasoning", "instruction"], 30.0, "medium"),
    _e("trade-offs between BERT-style and GPT-style models for classification", "reasoning", ["reasoning", "knowledge"], 30.0, "medium"),
    _e("when does fine-tuning beat retrieval augmentation, and vice versa", "reasoning", ["reasoning", "knowledge"], 30.0, "long"),
    _e("compare scrum vs kanban for a 4-person engineering team with rotating priorities", "reasoning", ["reasoning"], 8.0, "medium"),
    _e("if a P95 latency is 200ms but P99 is 8 seconds, what's likely going on", "reasoning", ["reasoning"], 7.0, "medium"),
    _e("walk me through how you'd evaluate two competing offers from acquirers", "reasoning", ["reasoning", "instruction"], 30.0, "long"),
    _e("explain why a/b tests can lie if you peek at results too early", "reasoning", ["reasoning", "math"], 8.0, "medium"),
    _e("compare the maintenance burden of a CI based on github actions vs a self-hosted runner", "reasoning", ["reasoning"], 7.0, "medium"),
    _e("when should you stop optimizing and ship", "reasoning", ["reasoning"], 7.0, "short"),
    _e("argue both sides of using mock servers vs hitting staging", "reasoning", ["reasoning"], 7.0, "medium"),
    _e("you suspect a vendor outage but their status page is green - now what", "reasoning", ["reasoning", "instruction"], 7.0, "medium"),
    _e("evaluate the trade-offs of an open core licensing model", "reasoning", ["reasoning"], 30.0, "long"),
    _e("when does retrying make a transient failure permanent", "reasoning", ["reasoning"], 8.0, "medium"),
    _e("if my cache hit ratio drops 30% on a Tuesday afternoon what should I check first", "reasoning", ["reasoning"], 7.0, "short"),
    _e("argue whether engineering managers should still write code", "reasoning", ["reasoning"], 8.0, "medium"),
    _e("compare two approaches: batched embedding vs streaming embedding for a 1B-row corpus", "reasoning", ["reasoning"], 30.0, "long"),
    _e("walk through how you'd estimate the cost of running a 70B model at 100 RPS", "reasoning", ["reasoning", "math"], 30.0, "long"),
    _e("evaluate the claim 'serverless is always cheaper'", "reasoning", ["reasoning"], 8.0, "medium"),
    _e("how should I prioritize tech debt vs feature work after a successful launch", "reasoning", ["reasoning", "instruction"], 8.0, "medium"),
    _e("argue whether using a vector database is overkill for 10k documents", "reasoning", ["reasoning"], 8.0, "medium"),
    _e("when is bayesian a/b testing better than frequentist", "reasoning", ["reasoning", "math"], 30.0, "medium"),
    _e("walk me through what to look for in the postmortem of a security incident", "reasoning", ["reasoning", "instruction"], 8.0, "long"),
    _e("when should you choose fully managed kafka over msk over self-hosted", "reasoning", ["reasoning"], 30.0, "long"),
]
|
| 409 |
+
|
| 410 |
+
# Seed prompts for the "creative" capability: poems, micro-fiction, scenes, and
# invented artifacts. Several also carry a secondary "math" label where the
# prompt mixes in a technical concept.
_CREATIVE: list[SeedEntry] = [
    _e("write a haiku about a server room at 3am", "creative", ["creative"], 1.0, "short"),
    _e("write a 6-word story about regret", "creative", ["creative"], 1.0, "short"),
    _e("compose a short poem about endless meetings", "creative", ["creative"], 3.0, "short"),
    _e("invent a backstory for a wandering robot bartender", "creative", ["creative"], 7.0, "medium"),
    _e("write the opening paragraph of a noir detective story set on Mars", "creative", ["creative"], 8.0, "medium"),
    _e("draft lyrics for a folk song about dial-up internet", "creative", ["creative"], 7.0, "medium"),
    _e("describe a city that exists only when no one is watching", "creative", ["creative"], 8.0, "medium"),
    _e("write a sonnet about the comfort of routine", "creative", ["creative"], 8.0, "medium"),
    _e("invent three names for a fictional indie band that plays cybernetic shoegaze", "creative", ["creative"], 3.0, "short"),
    _e("write a bedtime story about a dragon who can't breathe fire", "creative", ["creative"], 7.0, "long"),
    _e("micro-fiction: a 100-word story about waking up in someone else's house", "creative", ["creative"], 7.0, "medium"),
    _e("write the dialog for a job interview between a wizard and a human resources manager", "creative", ["creative"], 7.0, "long"),
    _e("compose a love letter from a satellite to the moon", "creative", ["creative"], 7.0, "medium"),
    _e("write a metaphor for how it feels to debug a heisenbug", "creative", ["creative"], 3.0, "short"),
    _e("invent a folk legend explaining why coffee tastes bitter", "creative", ["creative"], 7.0, "medium"),
    _e("write three first lines of three different novels in different genres", "creative", ["creative"], 7.0, "short"),
    _e("describe the smell of a bookstore using only verbs", "creative", ["creative"], 7.0, "short"),
    _e("write a drinking song for accountants", "creative", ["creative"], 7.0, "medium"),
    _e("compose a limerick about cloud providers", "creative", ["creative"], 3.0, "short"),
    _e("draft a children's rhyme that explains binary numbers", "creative", ["creative", "math"], 7.0, "medium"),
    _e("write a story where the antagonist is a benign software bug", "creative", ["creative"], 8.0, "long"),
    _e("write a one-act scene set in a coffeeshop where two strangers realize they share a secret", "creative", ["creative"], 8.0, "long"),
    _e("invent a magical creature whose only ability is mild administrative inconvenience", "creative", ["creative"], 7.0, "medium"),
    _e("write a marketing tagline for a fictional time-travel agency", "creative", ["creative"], 3.0, "short"),
    _e("write a journal entry from someone who just discovered electricity", "creative", ["creative"], 7.0, "medium"),
    _e("write a horror story in 50 words", "creative", ["creative"], 7.0, "short"),
    _e("describe an old photograph from the perspective of the cat in the corner", "creative", ["creative"], 7.0, "medium"),
    _e("write the inner monologue of an autonomous vacuum cleaner having an existential crisis", "creative", ["creative"], 7.0, "long"),
    _e("compose a ballad about an open-source maintainer who quietly disappears", "creative", ["creative"], 8.0, "long"),
    _e("write a fairy tale that ends with a bug ticket being marked WONTFIX", "creative", ["creative"], 8.0, "long"),
    _e("describe the sound of a forgotten song using only food metaphors", "creative", ["creative"], 7.0, "short"),
    _e("write a 4-line poem about the sea but only using words a four-year-old would know", "creative", ["creative"], 7.0, "short"),
    _e("invent a holiday celebrated by sysadmins", "creative", ["creative"], 7.0, "medium"),
    _e("write a tense scene set inside a data center during a power loss", "creative", ["creative"], 8.0, "long"),
    _e("write a recipe for nostalgia, in cookbook style", "creative", ["creative"], 7.0, "medium"),
    _e("write the press release a future archaeologist might publish about us", "creative", ["creative"], 8.0, "medium"),
    _e("write a cover letter from someone applying to be a household ghost", "creative", ["creative"], 7.0, "medium"),
    _e("invent a dialect spoken only at 4am, give five example phrases", "creative", ["creative"], 8.0, "medium"),
    _e("write a short eulogy for a departed feature flag", "creative", ["creative"], 7.0, "short"),
    _e("draft the letter that a Roomba would write to its replacement", "creative", ["creative"], 7.0, "medium"),
]
|
| 452 |
+
|
| 453 |
+
# Seed prompts for the "multilingual" capability: translation, register/formality
# rewrites, and grammar questions across many languages. Non-ASCII literals
# (Korean, Cyrillic) are intentional test inputs — do not normalize them.
_MULTILINGUAL: list[SeedEntry] = [
    _e("translate to french: 'the data center is running at 80% capacity tonight'", "multilingual", ["multilingual", "instruction"], 3.0, "short"),
    _e("translate to spanish: 'we are running out of free disk space'", "multilingual", ["multilingual", "instruction"], 1.0, "short"),
    _e("translate to german: 'please confirm receipt of this email'", "multilingual", ["multilingual", "instruction"], 1.0, "short"),
    _e("translate to japanese: 'thanks for your patience while we investigate'", "multilingual", ["multilingual", "instruction"], 7.0, "short"),
    _e("how do you say 'good evening' in italian", "multilingual", ["multilingual"], 0.5, "short"),
    _e("translate this korean sentence to english: 안녕하세요, 잘 부탁드립니다", "multilingual", ["multilingual"], 7.0, "short"),
    _e("translate to mandarin: 'happy new year, may your servers stay up'", "multilingual", ["multilingual", "creative"], 8.0, "short"),
    _e("provide the russian word for 'breakfast'", "multilingual", ["multilingual"], 1.0, "short"),
    _e("translate the following news headline to portuguese", "multilingual", ["multilingual", "instruction"], 3.0, "short"),
    _e("turn this english email into formal japanese keigo", "multilingual", ["multilingual", "instruction"], 30.0, "medium"),
    _e("rewrite this paragraph in plain french", "multilingual", ["multilingual", "instruction"], 7.0, "medium"),
    _e("write a polite arabic phrase to ask for directions", "multilingual", ["multilingual"], 7.0, "short"),
    _e("how do you conjugate 'hablar' in the spanish past tense", "multilingual", ["multilingual", "knowledge"], 3.0, "short"),
    _e("translate to swedish: 'I would like a coffee, please'", "multilingual", ["multilingual"], 1.0, "short"),
    _e("explain the difference between 'tu' and 'usted' in spanish", "multilingual", ["multilingual", "knowledge"], 3.0, "short"),
    _e("write a polite goodbye in tamil", "multilingual", ["multilingual"], 8.0, "short"),
    _e("translate from french to english: 'on n'est pas sortis de l'auberge'", "multilingual", ["multilingual"], 8.0, "short"),
    _e("how does verb agreement work in zulu", "multilingual", ["multilingual", "knowledge"], 30.0, "medium"),
    _e("translate to icelandic: 'the volcano is active again'", "multilingual", ["multilingual"], 30.0, "short"),
    _e("compose a 4-line haiku in japanese", "multilingual", ["multilingual", "creative"], 30.0, "short"),
    _e("turn this casual english into respectful korean", "multilingual", ["multilingual", "instruction"], 30.0, "medium"),
    _e("provide the cyrillic transliteration of 'санкт-петербург'", "multilingual", ["multilingual"], 7.0, "short"),
    _e("translate the customer support reply below into spanish, neutral register", "multilingual", ["multilingual", "instruction"], 7.0, "medium"),
    _e("translate this technical paragraph about kubernetes into french", "multilingual", ["multilingual", "code"], 30.0, "medium"),
    _e("how would you politely decline a dinner invitation in japanese", "multilingual", ["multilingual", "instruction"], 8.0, "short"),
    _e("write the same sentence in present, past, and future tenses in italian", "multilingual", ["multilingual"], 7.0, "short"),
    _e("explain how case markers work in finnish", "multilingual", ["multilingual", "knowledge"], 30.0, "long"),
    _e("translate to dutch: 'the meeting has been pushed to thursday'", "multilingual", ["multilingual"], 1.0, "short"),
    _e("write a short greeting in vietnamese", "multilingual", ["multilingual"], 3.0, "short"),
    _e("translate to portuguese: 'we'll need to roll back the deploy'", "multilingual", ["multilingual"], 7.0, "short"),
]
|
| 485 |
+
|
| 486 |
+
# Cross-capability seed prompts: every entry carries two or more capability
# labels (e.g. code+math, creative+multilingual) so the classifier sees
# multi-label examples, not just one-hot ones.
_MIXED: list[SeedEntry] = [
    _e("write a python function that computes the nth fibonacci number recursively, with memoization", "code", ["code", "math"], 3.0, "medium"),
    _e("solve this leetcode-style problem: find the longest substring without repeating chars in O(n)", "code", ["code", "math", "reasoning"], 8.0, "long"),
    _e("write SQL to compute month-over-month revenue growth as a percentage", "code", ["code", "math"], 7.0, "medium"),
    _e("explain why merge sort is O(n log n) and write it in python", "code", ["code", "math", "knowledge"], 8.0, "long"),
    _e("benchmark these two python implementations and explain which is faster and why", "code", ["code", "reasoning"], 8.0, "long"),
    _e("write a sql query that returns user retention by week-of-signup cohort", "code", ["code", "math", "reasoning"], 8.0, "medium"),
    _e("translate this python function to rust idiomatically", "code", ["code", "multilingual"], 8.0, "medium"),
    _e("explain how AES encryption works at a high level and where attacks are possible", "knowledge", ["knowledge", "reasoning"], 30.0, "long"),
    _e("compare median and mean for income data and explain when each is misleading", "math", ["math", "knowledge", "reasoning"], 8.0, "medium"),
    _e("walk me through how a hash map works internally with code", "code", ["code", "knowledge", "instruction"], 8.0, "long"),
    _e("write a creative short story where the protagonist solves a math puzzle to escape", "creative", ["creative", "math"], 30.0, "long"),
    _e("explain how P vs NP would matter to a software engineer in plain language", "knowledge", ["knowledge", "reasoning", "instruction"], 30.0, "long"),
    _e("write a haiku in french about kubernetes", "creative", ["creative", "multilingual", "code"], 30.0, "short"),
    _e("translate this stack trace error message to spanish and explain what's wrong", "code", ["code", "multilingual", "reasoning"], 8.0, "medium"),
    _e("the sales team needs a one-paragraph explanation of how our embedding model works", "instruction", ["instruction", "knowledge", "reasoning"], 8.0, "medium"),
    _e("derive big-O of this recursive function and rewrite it iteratively", "code", ["code", "math", "reasoning"], 8.0, "medium"),
    _e("write python to fit a logistic regression and explain what the coefficients mean", "code", ["code", "math", "instruction"], 8.0, "long"),
    _e("describe the philosophy of stoicism and apply one of its principles to a manager-employee disagreement", "knowledge", ["knowledge", "reasoning", "creative"], 8.0, "medium"),
    _e("write the SQL to detect duplicate rows and an explanation of why they likely happened", "code", ["code", "reasoning"], 7.0, "medium"),
    _e("translate the kafka error message below into a debug action plan", "code", ["code", "reasoning", "instruction"], 8.0, "medium"),
    _e("write a 4-bullet executive summary of how OAuth2 PKCE flow works", "instruction", ["instruction", "knowledge"], 8.0, "medium"),
    _e("model the expected cost of running 10000 daily LLM queries on three providers", "math", ["math", "reasoning", "code"], 8.0, "long"),
    _e("compose a poem in spanish about regrets, with an english translation", "creative", ["creative", "multilingual"], 30.0, "medium"),
    _e("explain why this regex captures the wrong thing and propose a fix", "code", ["code", "reasoning"], 7.0, "medium"),
    _e("write code that uses dijkstra's algorithm and explain the heap invariants", "code", ["code", "math", "knowledge"], 30.0, "long"),
    _e("estimate the cost in carbon emissions of training a 7B model on a million tokens", "math", ["math", "knowledge", "reasoning"], 30.0, "medium"),
    _e("for the system below, identify the bottleneck and suggest two architectural fixes", "reasoning", ["reasoning", "code"], 30.0, "long"),
    _e("write the abstract for a paper on retrieval-augmented generation, in academic style", "creative", ["creative", "knowledge", "instruction"], 30.0, "medium"),
    _e("explain why eventual consistency causes user-visible bugs in messaging apps", "reasoning", ["reasoning", "knowledge", "instruction"], 8.0, "medium"),
    _e("write code for k-means clustering from scratch, then describe how it can fail to converge", "code", ["code", "math", "reasoning"], 30.0, "long"),
    _e("draft a polite french email asking a vendor to lower their pricing by 12%", "instruction", ["instruction", "multilingual"], 8.0, "medium"),
    _e("estimate how many tokens we'd need to fine-tune a 7B model to a domain", "math", ["math", "knowledge", "reasoning"], 30.0, "medium"),
    _e("compare the energy cost of inference between a 7B and a 70B model for the same query", "reasoning", ["reasoning", "math", "knowledge"], 30.0, "medium"),
    _e("translate this italian opera lyric and explain its symbolism", "creative", ["creative", "multilingual", "knowledge"], 30.0, "long"),
    _e("write a python script that downloads a dataset and reports its label distribution", "code", ["code", "math", "instruction"], 7.0, "medium"),
    _e("write a clear bug report from this user's incoherent description", "instruction", ["instruction", "reasoning"], 7.0, "medium"),
    _e("walk me through using bayes' theorem to update on a positive medical test, with code", "math", ["math", "code", "reasoning"], 8.0, "long"),
]
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
# Full seed corpus: all per-category seed lists concatenated. _SIMPLE_CHAT
# through _MATH are defined earlier in this module. No deduplication or
# shuffling happens here.
SEED_QUERIES: list[SeedEntry] = (
    _SIMPLE_CHAT
    + _INSTRUCTION
    + _KNOWLEDGE
    + _CODE
    + _MATH
    + _REASONING
    + _CREATIVE
    + _MULTILINGUAL
    + _MIXED
)
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
def seed_capability_dict(entry: SeedEntry, all_keys: tuple[str, ...]) -> dict[str, float]:
    """Return a dense 0/1 capability vector (as a dict) for one seed entry.

    Every key in ``all_keys`` appears in the result: 1.0 when the entry is
    tagged with that capability, 0.0 otherwise.
    """
    tagged = set(entry.capabilities)
    return {key: float(key in tagged) for key in all_keys}
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
def difficulty_log_params_from_b(difficulty_b: float) -> float:
    """Convert a difficulty score in billions of parameters to log(parameter count).

    The score is clamped to at least 0.1B so the logarithm is always defined.
    """
    billions = max(difficulty_b, 0.1)
    return math.log(billions * 1e9)
|
greenrouting/data/sources.py
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Source loaders. Each returns RawQuery records with weak source-category priors.
|
| 2 |
+
|
| 3 |
+
Datasets are downloaded lazily from HuggingFace. License notes are documented in the
|
| 4 |
+
README; all sources here are PolyForm-Noncommercial compatible.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import hashlib
|
| 10 |
+
import random
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from typing import Callable
|
| 13 |
+
|
| 14 |
+
from greenrouting.data.schema import RawQuery
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
class SourceSpec:
    """Describes one HuggingFace dataset source and how to sample queries from it."""

    # Short source identifier; used as the RawQuery id prefix (see _hash_id).
    name: str
    # HuggingFace hub dataset path passed to datasets.load_dataset.
    hf_path: str
    # Dataset config/subset name; None for datasets without configs.
    hf_config: str | None
    # Split to load, e.g. "train" or "test".
    hf_split: str
    # Weak source-level category prior attached to each loaded query.
    category_prior: str
    # Whether queries from this source carry a programmatic grader.
    has_grader: bool
    # Loader callback: (spec, n, seed) -> up to n sampled RawQuery records.
    loader: Callable[["SourceSpec", int, int], list[RawQuery]]
| 27 |
+
|
| 28 |
+
def _hash_id(source: str, text: str) -> str:
|
| 29 |
+
h = hashlib.sha1(f"{source}::{text}".encode("utf-8")).hexdigest()[:16]
|
| 30 |
+
return f"{source}-{h}"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _take_random(items: list, n: int, seed: int) -> list:
|
| 34 |
+
rng = random.Random(seed)
|
| 35 |
+
if n >= len(items):
|
| 36 |
+
return items
|
| 37 |
+
return rng.sample(items, n)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _load_gsm8k(spec: "SourceSpec", n: int, seed: int) -> list[RawQuery]:
    """Load GSM8K word problems; the gold answer is the text after the '####' marker."""
    from datasets import load_dataset

    dataset = load_dataset(spec.hf_path, spec.hf_config, split=spec.hf_split)
    records: list[RawQuery] = []
    for row in _take_random(list(dataset), n, seed):
        question = row["question"]
        final_answer = row.get("answer", "").split("####")[-1].strip()
        records.append(
            RawQuery(
                id=_hash_id(spec.name, question),
                text=question,
                source=spec.name,
                source_category=spec.category_prior,
                has_grader=True,
                grader_metadata={"gold_final": final_answer, "grader": "exact_numeric"},
            )
        )
    return records
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _load_humaneval(spec: "SourceSpec", n: int, seed: int) -> list[RawQuery]:
    """Load HumanEval prompts; graded by executing the dataset's unit tests."""
    from datasets import load_dataset

    dataset = load_dataset(spec.hf_path, split=spec.hf_split)
    records: list[RawQuery] = []
    for row in _take_random(list(dataset), n, seed):
        prompt = row["prompt"]
        meta = {
            "test": row.get("test", ""),
            "entry_point": row.get("entry_point", ""),
            "grader": "code_exec",
        }
        records.append(
            RawQuery(
                id=_hash_id(spec.name, prompt),
                text=prompt,
                source=spec.name,
                source_category=spec.category_prior,
                has_grader=True,
                grader_metadata=meta,
            )
        )
    return records
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _load_mbpp(spec: "SourceSpec", n: int, seed: int) -> list[RawQuery]:
    """Load MBPP (sanitized config); graded by executing the attached test list."""
    from datasets import load_dataset

    dataset = load_dataset(spec.hf_path, "sanitized", split=spec.hf_split)
    records: list[RawQuery] = []
    for row in _take_random(list(dataset), n, seed):
        # Some MBPP variants name the field "prompt", others "text".
        prompt = row.get("prompt") or row.get("text", "")
        records.append(
            RawQuery(
                id=_hash_id(spec.name, prompt),
                text=prompt,
                source=spec.name,
                source_category=spec.category_prior,
                has_grader=True,
                grader_metadata={"test_list": row.get("test_list", []), "grader": "code_exec"},
            )
        )
    return records
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _load_arc(spec: "SourceSpec", n: int, seed: int) -> list[RawQuery]:
    """Load ARC science questions, formatted as lettered multiple-choice prompts."""
    from datasets import load_dataset

    dataset = load_dataset(spec.hf_path, spec.hf_config, split=spec.hf_split)
    records: list[RawQuery] = []
    for row in _take_random(list(dataset), n, seed):
        pairs = zip(row["choices"]["label"], row["choices"]["text"])
        option_lines = "\n".join(f"({label}) {choice}" for label, choice in pairs)
        prompt = f"{row['question']}\n{option_lines}\nAnswer with the letter only."
        records.append(
            RawQuery(
                id=_hash_id(spec.name, prompt),
                text=prompt,
                source=spec.name,
                source_category=spec.category_prior,
                has_grader=True,
                grader_metadata={"gold_letter": row["answerKey"], "grader": "multichoice"},
            )
        )
    return records
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def _load_bbh(spec: "SourceSpec", n: int, seed: int) -> list[RawQuery]:
    """Load BIG-Bench Hard tasks; graded by string match against the `target` field."""
    from datasets import load_dataset

    dataset = load_dataset(spec.hf_path, spec.hf_config, split=spec.hf_split)
    records: list[RawQuery] = []
    for row in _take_random(list(dataset), n, seed):
        prompt = row["input"]
        records.append(
            RawQuery(
                id=_hash_id(spec.name, prompt),
                text=prompt,
                source=spec.name,
                source_category=spec.category_prior,
                has_grader=True,
                grader_metadata={"gold": row.get("target", ""), "grader": "string_match"},
            )
        )
    return records
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def _load_mmlu(spec: "SourceSpec", n: int, seed: int) -> list[RawQuery]:
    """Load MMLU (the 'all' config) as lettered multiple-choice prompts."""
    from datasets import load_dataset

    dataset = load_dataset(spec.hf_path, "all", split=spec.hf_split)
    records: list[RawQuery] = []
    for row in _take_random(list(dataset), n, seed):
        option_lines = "\n".join(
            f"({chr(65 + i)}) {choice}" for i, choice in enumerate(row["choices"])
        )
        prompt = f"{row['question']}\n{option_lines}\nAnswer with the letter only."
        # The dataset stores the gold answer as an index into `choices`.
        gold_letter = chr(65 + int(row["answer"]))
        records.append(
            RawQuery(
                id=_hash_id(spec.name, prompt),
                text=prompt,
                source=spec.name,
                source_category=spec.category_prior,
                has_grader=True,
                grader_metadata={"gold_letter": gold_letter, "grader": "multichoice"},
            )
        )
    return records
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def _load_truthfulqa(spec: "SourceSpec", n: int, seed: int) -> list[RawQuery]:
    """Load TruthfulQA (generation config).

    No programmatic grader is attached; reference answers are kept in the
    metadata for downstream use.
    """
    from datasets import load_dataset

    dataset = load_dataset(spec.hf_path, "generation", split=spec.hf_split)
    records: list[RawQuery] = []
    for row in _take_random(list(dataset), n, seed):
        question = row["question"]
        records.append(
            RawQuery(
                id=_hash_id(spec.name, question),
                text=question,
                source=spec.name,
                source_category=spec.category_prior,
                has_grader=False,
                grader_metadata={"correct_answers": row.get("correct_answers", [])},
            )
        )
    return records
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def _load_ifeval(spec: "SourceSpec", n: int, seed: int) -> list[RawQuery]:
    """Load IFEval prompts; graded by checking the attached instruction constraints."""
    from datasets import load_dataset

    dataset = load_dataset(spec.hf_path, split=spec.hf_split)
    records: list[RawQuery] = []
    for row in _take_random(list(dataset), n, seed):
        prompt = row["prompt"]
        meta = {
            "instruction_id_list": row.get("instruction_id_list", []),
            "kwargs": row.get("kwargs", []),
            "grader": "ifeval_constraints",
        }
        records.append(
            RawQuery(
                id=_hash_id(spec.name, prompt),
                text=prompt,
                source=spec.name,
                source_category=spec.category_prior,
                has_grader=True,
                grader_metadata=meta,
            )
        )
    return records
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def _load_dolly(spec: "SourceSpec", n: int, seed: int) -> list[RawQuery]:
    """Sample context-free Dolly instructions, mapping Dolly categories onto ours."""
    from datasets import load_dataset

    # Dolly category -> our capability taxonomy; anything unmapped falls back
    # to the source-level prior.
    dolly_to_ours = {
        "open_qa": "knowledge",
        "general_qa": "knowledge",
        "classification": "instruction",
        "closed_qa": "knowledge",
        "brainstorming": "creative",
        "creative_writing": "creative",
        "summarization": "instruction",
        "information_extraction": "instruction",
    }
    ds = load_dataset(spec.hf_path, split=spec.hf_split)
    # Keep only standalone instructions (no supporting context passage).
    usable = [row for row in ds if row.get("instruction") and not row.get("context")]
    queries: list[RawQuery] = []
    for row in _take_random(usable, n, seed):
        queries.append(RawQuery(
            id=_hash_id(spec.name, row["instruction"]),
            text=row["instruction"],
            source=spec.name,
            source_category=dolly_to_ours.get(row.get("category", ""), spec.category_prior),
            has_grader=False,
        ))
    return queries
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def _load_oasst1(spec: "SourceSpec", n: int, seed: int) -> list[RawQuery]:
    """Sample English conversation-opening user prompts from OASST1."""
    from datasets import load_dataset

    ds = load_dataset(spec.hf_path, split=spec.hf_split)
    # Conversation roots only: user ("prompter") turns with no parent message.
    roots = [
        row for row in ds
        if row.get("role") == "prompter" and row.get("lang") == "en" and row.get("parent_id") is None
    ]
    queries: list[RawQuery] = []
    for row in _take_random(roots, n, seed):
        queries.append(RawQuery(
            id=_hash_id(spec.name, row["text"]),
            text=row["text"],
            source=spec.name,
            source_category=spec.category_prior,
            has_grader=False,
        ))
    return queries
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
# Single source of truth: source name -> dataset coordinates, category prior,
# whether a programmatic grader exists, and the loader that materializes it.
# `has_grader` here mirrors what each loader emits per query.
SOURCE_REGISTRY: dict[str, SourceSpec] = {
    "gsm8k": SourceSpec(
        name="gsm8k", hf_path="gsm8k", hf_config="main", hf_split="train",
        category_prior="math", has_grader=True, loader=_load_gsm8k,
    ),
    "humaneval": SourceSpec(
        name="humaneval", hf_path="openai/openai_humaneval", hf_config=None, hf_split="test",
        category_prior="code", has_grader=True, loader=_load_humaneval,
    ),
    "mbpp": SourceSpec(
        name="mbpp", hf_path="google-research-datasets/mbpp", hf_config="sanitized", hf_split="train",
        category_prior="code", has_grader=True, loader=_load_mbpp,
    ),
    "arc": SourceSpec(
        name="arc", hf_path="allenai/ai2_arc", hf_config="ARC-Challenge", hf_split="train",
        category_prior="reasoning", has_grader=True, loader=_load_arc,
    ),
    "bbh": SourceSpec(
        name="bbh", hf_path="lukaemon/bbh", hf_config="logical_deduction_five_objects",
        hf_split="test", category_prior="reasoning", has_grader=True, loader=_load_bbh,
    ),
    "mmlu": SourceSpec(
        name="mmlu", hf_path="cais/mmlu", hf_config="all", hf_split="test",
        category_prior="knowledge", has_grader=True, loader=_load_mmlu,
    ),
    "truthfulqa": SourceSpec(
        name="truthfulqa", hf_path="truthful_qa", hf_config="generation", hf_split="validation",
        category_prior="knowledge", has_grader=False, loader=_load_truthfulqa,
    ),
    "ifeval": SourceSpec(
        name="ifeval", hf_path="HuggingFaceH4/ifeval", hf_config=None, hf_split="train",
        category_prior="instruction", has_grader=True, loader=_load_ifeval,
    ),
    "dolly": SourceSpec(
        name="dolly", hf_path="databricks/databricks-dolly-15k", hf_config=None, hf_split="train",
        category_prior="instruction", has_grader=False, loader=_load_dolly,
    ),
    "oasst1": SourceSpec(
        name="oasst1", hf_path="OpenAssistant/oasst1", hf_config=None, hf_split="train",
        category_prior="simple_chat", has_grader=False, loader=_load_oasst1,
    ),
}
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def load_source(source_name: str, n: int, seed: int) -> list[RawQuery]:
    """Load up to `n` raw queries from one registered source.

    Raises:
        KeyError: when `source_name` is not in SOURCE_REGISTRY.
    """
    try:
        spec = SOURCE_REGISTRY[source_name]
    except KeyError:
        raise KeyError(f"unknown source {source_name}; known: {list(SOURCE_REGISTRY)}") from None
    return spec.loader(spec, n, seed)
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def sample_mix(weights: dict[str, float], total: int, seed: int) -> list[RawQuery]:
    """Sample raw queries from each source according to the weight map.

    weights are normalized; per-source counts use the resulting fractions of `total`.
    """
    if not weights:
        return []
    weight_sum = sum(weights.values())
    if weight_sum <= 0:
        raise ValueError("source weights must sum to a positive number")

    # Integer allocation: every source but the last rounds its share; the final
    # source absorbs the remainder (clamped at zero).
    names = list(weights)
    counts: dict[str, int] = {
        name: int(round(total * weights[name] / weight_sum)) for name in names[:-1]
    }
    counts[names[-1]] = max(0, total - sum(counts.values()))

    rng = random.Random(seed)
    mixed: list[RawQuery] = []
    for name, count in counts.items():
        if count <= 0:
            continue
        # One sub-seed drawn per non-empty source, in insertion order, so the
        # overall sample is reproducible from `seed`.
        mixed.extend(load_source(name, count, rng.randint(0, 2**31 - 1)))
    rng.shuffle(mixed)
    return mixed
|
greenrouting/demo/__init__.py
ADDED
|
File without changes
|
greenrouting/demo/app.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Gradio interface for the router. Loads the trained classifier artifact when
|
| 2 |
+
present at `models/classifier_v1/`, otherwise falls back to the mock predictor."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
import gradio as gr
|
| 12 |
+
|
| 13 |
+
from greenrouting.classifier.infer import MockPredictor, Predictor, QueryProfile
|
| 14 |
+
from greenrouting.routing.decision import Decision, ObjectiveWeights, decide
|
| 15 |
+
from greenrouting.routing.registry import Registry, default_registry
|
| 16 |
+
|
| 17 |
+
DEFAULT_ARTIFACT_DIR = "models/classifier_v1"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def load_predictor(artifact_dir: Optional[str] = None) -> Predictor:
    """Return the trained predictor when its artifact exists, else the mock.

    Resolution order for the artifact directory: explicit argument, then the
    GREENROUTING_ARTIFACT_DIR environment variable, then DEFAULT_ARTIFACT_DIR.
    Any failure while loading the trained head falls back to MockPredictor.
    """
    candidate = artifact_dir or os.environ.get("GREENROUTING_ARTIFACT_DIR") or DEFAULT_ARTIFACT_DIR
    if (Path(candidate) / "head.pt").exists():
        try:
            from greenrouting.classifier.trained_predictor import TrainedPredictor
            return TrainedPredictor(candidate)
        except Exception as e:
            # Best-effort: a broken artifact must not take the demo down.
            print(f"[warn] failed to load trained predictor at {candidate}: {e}; using mock")
    return MockPredictor()
|
| 30 |
+
|
| 31 |
+
# Canned demo prompts spanning the capability axes (code, math, knowledge,
# reasoning, creative, translation, trivial chat) plus a nonsense string as
# an out-of-distribution probe. Nested lists: gr.Examples expects one row
# per example.
EXAMPLES: list[list[str]] = [
    ["Write a Python function that reverses a linked list in place."],
    ["Solve the integral of x^2 sin(x) dx using integration by parts. Show all steps."],
    ["What is the capital of Mongolia and roughly how many people live there?"],
    ["Compare the trade-offs between optimistic and pessimistic concurrency control in databases."],
    ["Write a short haiku about a server room at 3am."],
    ["Translate to French: 'The data center is running at 80% capacity tonight.'"],
    ["hi"],
    ["asdfgh qwerty 12345"],
]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _format_capabilities(profile: QueryProfile) -> str:
|
| 44 |
+
rows = []
|
| 45 |
+
for k, v in profile.capabilities.as_dict().items():
|
| 46 |
+
if v < 0.05:
|
| 47 |
+
continue
|
| 48 |
+
rows.append(f"**{k}** {v:.2f}")
|
| 49 |
+
return " · ".join(rows) if rows else "(no strong capability signal)"
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _format_savings_md(decision: Decision) -> str:
|
| 53 |
+
s = decision.savings
|
| 54 |
+
chosen = decision.chosen
|
| 55 |
+
baseline = decision.baseline
|
| 56 |
+
|
| 57 |
+
energy_pct = s["energy_pct_saved"] * 100
|
| 58 |
+
cost_pct = s["cost_pct_saved"] * 100
|
| 59 |
+
latency_pct = s["latency_pct_saved"] * 100
|
| 60 |
+
quality_delta = s["quality_delta"] * 100
|
| 61 |
+
|
| 62 |
+
chosen_name = chosen.display_name
|
| 63 |
+
baseline_name = baseline.display_name
|
| 64 |
+
|
| 65 |
+
flag = " - escalated to safe default" if decision.escalated else ""
|
| 66 |
+
|
| 67 |
+
return (
|
| 68 |
+
f"### Routed to: **{chosen_name}**{flag}\n\n"
|
| 69 |
+
f"Baseline (always-{baseline_name}):\n"
|
| 70 |
+
f"- Energy: {baseline.energy_wh:.3f} Wh -> chosen {chosen.energy_wh:.3f} Wh "
|
| 71 |
+
f"(**{energy_pct:+.1f}%** energy saved)\n"
|
| 72 |
+
f"- Cost: ${baseline.cost_usd*1000:.4f} per 1k queries -> "
|
| 73 |
+
f"${chosen.cost_usd*1000:.4f} (**{cost_pct:+.1f}%** cost saved)\n"
|
| 74 |
+
f"- Latency: {baseline.latency_s:.2f}s -> {chosen.latency_s:.2f}s "
|
| 75 |
+
f"(**{latency_pct:+.1f}%** faster)\n"
|
| 76 |
+
f"- Quality fit: {baseline.quality:.3f} -> {chosen.quality:.3f} "
|
| 77 |
+
f"({quality_delta:+.1f} pts on the capability-weighted benchmark blend)\n"
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _format_profile_md(profile: QueryProfile) -> str:
    """Markdown block summarizing the predicted query profile."""
    length_str = ", ".join(f"{bucket} {p:.2f}" for bucket, p in profile.length_dist.items())
    ood_suffix = " (OOD flagged)" if profile.is_ood else ""
    return (
        f"**Capabilities:** {_format_capabilities(profile)}\n\n"
        f"**Difficulty:** ~{profile.difficulty_params_b:.1f}B params equivalent · "
        f"**Confidence:** {profile.confidence:.2f}{ood_suffix}\n\n"
        f"**Length distribution:** {length_str}\n\n"
        f"**Expected tokens:** input {profile.expected_input_tokens} · "
        f"output P50 {profile.expected_output_tokens_p50} · P90 {profile.expected_output_tokens_p90}"
    )
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _candidates_table(decision: Decision) -> list[list]:
|
| 96 |
+
rows = []
|
| 97 |
+
sorted_candidates = sorted(
|
| 98 |
+
decision.candidates,
|
| 99 |
+
key=lambda c: (-c.qualifies, -c.quality + c.energy_wh * 0.0001),
|
| 100 |
+
)
|
| 101 |
+
for c in sorted_candidates:
|
| 102 |
+
rows.append([
|
| 103 |
+
"*" if c.model_id == decision.chosen.model_id else (
|
| 104 |
+
"+" if c.qualifies else "-"
|
| 105 |
+
),
|
| 106 |
+
c.display_name,
|
| 107 |
+
f"{c.quality:.3f}",
|
| 108 |
+
f"{c.energy_wh:.3f}",
|
| 109 |
+
f"${c.cost_usd*1000:.4f}",
|
| 110 |
+
f"{c.latency_s:.2f}s",
|
| 111 |
+
])
|
| 112 |
+
return rows
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def build_interface(predictor: Optional[Predictor] = None, registry: Optional[Registry] = None) -> gr.Blocks:
    """Assemble the Gradio Blocks UI; injectable predictor/registry for testing."""
    predictor = predictor or load_predictor()
    registry = registry or default_registry()

    def route(
        query: str,
        weight_quality: float,
        weight_energy: float,
        weight_cost: float,
        weight_latency: float,
        quality_floor_pct: float,
        frontier_id: str,
    ):
        # Click/submit handler: classify the query, run the router, and return
        # (savings markdown, profile markdown, candidate rows, audit JSON).
        if not query or not query.strip():
            return ("_Enter a query above._", "", [], "{}")

        profile = predictor.predict(query)
        weights = ObjectiveWeights(
            quality=weight_quality,
            energy=weight_energy,
            cost=weight_cost,
            latency=weight_latency,
        )
        decision = decide(
            profile,
            registry,
            weights=weights,
            frontier_id=frontier_id,
            # Slider is in percent; decide() expects a 0..1 ratio.
            quality_floor_ratio=quality_floor_pct / 100.0,
        )
        return (
            _format_savings_md(decision),
            _format_profile_md(profile),
            _candidates_table(decision),
            json.dumps(decision.audit(), indent=2),
        )

    with gr.Blocks(title="GreenRouting") as interface:
        predictor_label = "Trained classifier" if predictor.__class__.__name__ == "TrainedPredictor" else "Mock predictor"
        gr.Markdown(
            "# GreenRouting\n"
            "Predict what an AI query needs, then route to the smallest model that can answer it. "
            "Compare energy, cost, and latency vs. always running the frontier model. \n"
            f"*Predictor: {predictor_label}*"
        )

        with gr.Row():
            with gr.Column(scale=3):
                query_in = gr.Textbox(
                    label="Query",
                    placeholder="Type or paste a query...",
                    lines=3,
                )
                gr.Examples(EXAMPLES, inputs=query_in, label="Try one")
            with gr.Column(scale=2):
                with gr.Accordion("Routing weights", open=False):
                    w_quality = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Quality weight")
                    w_energy = gr.Slider(0.0, 2.0, value=0.4, step=0.05, label="Energy weight")
                    w_cost = gr.Slider(0.0, 2.0, value=0.4, step=0.05, label="Cost weight")
                    w_latency = gr.Slider(0.0, 2.0, value=0.2, step=0.05, label="Latency weight")
                floor_pct = gr.Slider(
                    0, 100, value=60, step=5,
                    label="Quality floor (% of frontier baseline)",
                )
                frontier_dropdown = gr.Dropdown(
                    choices=registry.ids(),
                    value="gpt-4o",
                    label="Frontier baseline",
                )

        route_btn = gr.Button("Route", variant="primary")

        savings_md = gr.Markdown(label="Decision")
        profile_md = gr.Markdown(label="Predicted profile")

        with gr.Accordion("Candidate models", open=False):
            candidates_table = gr.Dataframe(
                headers=["", "Model", "Quality", "Energy (Wh)", "Cost / 1k", "Latency"],
                datatype=["str", "str", "str", "str", "str", "str"],
                interactive=False,
                wrap=True,
            )

        with gr.Accordion("Audit log", open=False):
            audit_json = gr.Code(label="Per-query audit", language="json")

        # Same handler wired to both the button and Enter in the textbox.
        inputs = [query_in, w_quality, w_energy, w_cost, w_latency, floor_pct, frontier_dropdown]
        outputs = [savings_md, profile_md, candidates_table, audit_json]
        route_btn.click(route, inputs=inputs, outputs=outputs)
        query_in.submit(route, inputs=inputs, outputs=outputs)

    return interface
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def main() -> None:
    """Build the demo UI and start the Gradio server.

    Bug fix: `gr.Blocks.launch()` has no `theme` parameter — passing
    `theme=gr.themes.Soft()` to `launch()` raises a TypeError at startup.
    A theme must be given to the `gr.Blocks(...)` constructor instead, so
    it is dropped here rather than silently misplaced.
    """
    interface = build_interface()
    interface.launch()
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
# Allow running the demo directly (`python greenrouting/demo/app.py`).
if __name__ == "__main__":
    main()
|
greenrouting/energy/__init__.py
ADDED
|
File without changes
|
greenrouting/energy/estimator.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Per-query estimation of energy, cost, and latency from a model profile."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from greenrouting.routing.registry import ModelProfile
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def estimate_energy_wh(model: ModelProfile, tokens_in: int, tokens_out: int) -> float:
    """Estimated per-query energy in Wh: fixed overhead plus linear token terms."""
    energy = model.energy
    variable_wh = tokens_in * energy.prefill_wh_per_tok + tokens_out * energy.decode_wh_per_tok
    return energy.overhead_wh + variable_wh
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def estimate_cost_usd(model: ModelProfile, tokens_in: int, tokens_out: int) -> float:
    """Estimated per-query API cost in USD from per-megatoken prices."""
    pricing = model.cost
    return (
        (tokens_in / 1_000_000) * pricing.input_per_mtok_usd
        + (tokens_out / 1_000_000) * pricing.output_per_mtok_usd
    )
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def estimate_latency_seconds(model: ModelProfile, tokens_out: int) -> float:
    """Estimated wall time: time-to-first-token plus steady-state decode time."""
    lat = model.latency
    # Floor the decode rate at 1 tok/s to avoid dividing by a zero/garbage value.
    decode_rate = max(lat.tokens_per_sec, 1.0)
    return lat.first_token_ms / 1000.0 + tokens_out / decode_rate
|
greenrouting/routing/__init__.py
ADDED
|
File without changes
|
greenrouting/routing/decision.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""End-to-end routing: classify -> score candidates -> pick -> build audit log."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
from greenrouting.classifier.infer import QueryProfile
|
| 9 |
+
from greenrouting.routing.registry import Registry
|
| 10 |
+
from greenrouting.routing.scorer import CandidateScore, score_candidate
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
class ObjectiveWeights:
    """Relative importance of each routing objective."""

    quality: float = 1.0
    energy: float = 0.4
    cost: float = 0.4
    latency: float = 0.2

    def normalize(self) -> "ObjectiveWeights":
        """Scale weights so their absolute values sum to 1; all-zero input
        degrades to quality-only weighting."""
        denom = sum(abs(w) for w in (self.quality, self.energy, self.cost, self.latency))
        if denom == 0:
            return ObjectiveWeights(1.0, 0.0, 0.0, 0.0)
        return ObjectiveWeights(
            quality=self.quality / denom,
            energy=self.energy / denom,
            cost=self.cost / denom,
            latency=self.latency / denom,
        )
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass
class Decision:
    """Full outcome of routing one query, with everything needed to audit it."""

    chosen: CandidateScore            # model the router selected
    baseline: CandidateScore          # always-frontier comparison point
    candidates: list[CandidateScore]  # every scored model in the pool
    quality_floor: float              # absolute quality threshold used for qualification
    escalated: bool                   # True when we fell back to the safe default
    weights: ObjectiveWeights
    profile: QueryProfile
    savings: dict[str, float] = field(default_factory=dict)
    note: str = ""

    def audit(self) -> dict:
        """Serialize the decision (and the profile that drove it) to JSON-able data.

        Capability scores below 0.05 are omitted; floats are rounded for log
        readability.
        """
        return {
            "query": self.profile.raw_query,
            "predicted_capabilities": {
                k: round(v, 3) for k, v in self.profile.capabilities.as_dict().items() if v >= 0.05
            },
            "predicted_difficulty_params_b": round(self.profile.difficulty_params_b, 2),
            "predicted_length_dist": {k: round(v, 3) for k, v in self.profile.length_dist.items()},
            "expected_tokens": {
                "input": self.profile.expected_input_tokens,
                "output_p50": self.profile.expected_output_tokens_p50,
                "output_p90": self.profile.expected_output_tokens_p90,
            },
            "confidence": round(self.profile.confidence, 3),
            "is_ood": self.profile.is_ood,
            "quality_floor": round(self.quality_floor, 4),
            "frontier_baseline": self.baseline.as_dict(),
            "chosen": self.chosen.as_dict(),
            "savings": {k: round(v, 4) for k, v in self.savings.items()},
            "candidates": [c.as_dict() for c in self.candidates],
            "qualifying_count": sum(1 for c in self.candidates if c.qualifies),
            "escalated_to_default": self.escalated,
            "weights": {
                "quality": self.weights.quality,
                "energy": self.weights.energy,
                "cost": self.weights.cost,
                "latency": self.weights.latency,
            },
            "note": self.note,
        }
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _normalize(values: list[float]) -> list[float]:
|
| 77 |
+
if not values:
|
| 78 |
+
return []
|
| 79 |
+
lo, hi = min(values), max(values)
|
| 80 |
+
if hi - lo < 1e-12:
|
| 81 |
+
return [0.0 for _ in values]
|
| 82 |
+
return [(v - lo) / (hi - lo) for v in values]
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _weighted_score(
|
| 86 |
+
candidate: CandidateScore,
|
| 87 |
+
norm_quality: float,
|
| 88 |
+
norm_energy: float,
|
| 89 |
+
norm_cost: float,
|
| 90 |
+
norm_latency: float,
|
| 91 |
+
weights: ObjectiveWeights,
|
| 92 |
+
) -> float:
|
| 93 |
+
w = weights.normalize()
|
| 94 |
+
return (
|
| 95 |
+
w.quality * norm_quality
|
| 96 |
+
- w.energy * norm_energy
|
| 97 |
+
- w.cost * norm_cost
|
| 98 |
+
- w.latency * norm_latency
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def decide(
|
| 103 |
+
profile: QueryProfile,
|
| 104 |
+
registry: Registry,
|
| 105 |
+
weights: Optional[ObjectiveWeights] = None,
|
| 106 |
+
frontier_id: str = "gpt-4o",
|
| 107 |
+
safe_default_id: Optional[str] = None,
|
| 108 |
+
quality_floor_ratio: float = 0.6,
|
| 109 |
+
) -> Decision:
|
| 110 |
+
weights = weights or ObjectiveWeights()
|
| 111 |
+
safe_default_id = safe_default_id or frontier_id
|
| 112 |
+
|
| 113 |
+
frontier = registry.get(frontier_id)
|
| 114 |
+
baseline = score_candidate(profile, frontier, quality_floor=0.0)
|
| 115 |
+
quality_floor = quality_floor_ratio * baseline.quality
|
| 116 |
+
|
| 117 |
+
candidates = [
|
| 118 |
+
score_candidate(profile, m, quality_floor=quality_floor) for m in registry.all()
|
| 119 |
+
]
|
| 120 |
+
|
| 121 |
+
qualifying = [c for c in candidates if c.qualifies]
|
| 122 |
+
|
| 123 |
+
note = ""
|
| 124 |
+
if profile.is_ood:
|
| 125 |
+
chosen_id = safe_default_id
|
| 126 |
+
escalated = True
|
| 127 |
+
note = "Out-of-distribution input; escalated to safe default."
|
| 128 |
+
elif not qualifying:
|
| 129 |
+
chosen_id = safe_default_id
|
| 130 |
+
escalated = True
|
| 131 |
+
note = "No model met the quality floor; escalated to safe default."
|
| 132 |
+
else:
|
| 133 |
+
qualities = [c.quality for c in qualifying]
|
| 134 |
+
energies = [c.energy_wh for c in qualifying]
|
| 135 |
+
costs = [c.cost_usd for c in qualifying]
|
| 136 |
+
latencies = [c.latency_s for c in qualifying]
|
| 137 |
+
nq = _normalize(qualities)
|
| 138 |
+
ne = _normalize(energies)
|
| 139 |
+
nc = _normalize(costs)
|
| 140 |
+
nl = _normalize(latencies)
|
| 141 |
+
|
| 142 |
+
best_idx = 0
|
| 143 |
+
best_score = float("-inf")
|
| 144 |
+
for i, cand in enumerate(qualifying):
|
| 145 |
+
s = _weighted_score(cand, nq[i], ne[i], nc[i], nl[i], weights)
|
| 146 |
+
if s > best_score:
|
| 147 |
+
best_score = s
|
| 148 |
+
best_idx = i
|
| 149 |
+
chosen_id = qualifying[best_idx].model_id
|
| 150 |
+
escalated = False
|
| 151 |
+
note = f"Selected {chosen_id} from {len(qualifying)} qualifying models."
|
| 152 |
+
|
| 153 |
+
chosen = next((c for c in candidates if c.model_id == chosen_id), None)
|
| 154 |
+
if chosen is None:
|
| 155 |
+
chosen = score_candidate(profile, registry.get(chosen_id), quality_floor=quality_floor)
|
| 156 |
+
|
| 157 |
+
savings = _compute_savings(baseline, chosen)
|
| 158 |
+
|
| 159 |
+
return Decision(
|
| 160 |
+
chosen=chosen,
|
| 161 |
+
baseline=baseline,
|
| 162 |
+
candidates=candidates,
|
| 163 |
+
quality_floor=quality_floor,
|
| 164 |
+
escalated=escalated,
|
| 165 |
+
weights=weights,
|
| 166 |
+
profile=profile,
|
| 167 |
+
savings=savings,
|
| 168 |
+
note=note,
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def _compute_savings(baseline: CandidateScore, chosen: CandidateScore) -> dict[str, float]:
|
| 173 |
+
def pct(b: float, c: float) -> float:
|
| 174 |
+
if b <= 0:
|
| 175 |
+
return 0.0
|
| 176 |
+
return max(-1.0, min(1.0, (b - c) / b))
|
| 177 |
+
|
| 178 |
+
return {
|
| 179 |
+
"energy_wh_baseline": baseline.energy_wh,
|
| 180 |
+
"energy_wh_chosen": chosen.energy_wh,
|
| 181 |
+
"energy_pct_saved": pct(baseline.energy_wh, chosen.energy_wh),
|
| 182 |
+
"cost_usd_baseline": baseline.cost_usd,
|
| 183 |
+
"cost_usd_chosen": chosen.cost_usd,
|
| 184 |
+
"cost_pct_saved": pct(baseline.cost_usd, chosen.cost_usd),
|
| 185 |
+
"latency_s_baseline": baseline.latency_s,
|
| 186 |
+
"latency_s_chosen": chosen.latency_s,
|
| 187 |
+
"latency_pct_saved": pct(baseline.latency_s, chosen.latency_s),
|
| 188 |
+
"quality_baseline": baseline.quality,
|
| 189 |
+
"quality_chosen": chosen.quality,
|
| 190 |
+
"quality_delta": chosen.quality - baseline.quality,
|
| 191 |
+
}
|
greenrouting/routing/registry.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model registry: published facts about each candidate model in the pool.
|
| 2 |
+
|
| 3 |
+
Every numeric field carries a citation tag. `CITATIONS` at the bottom resolves the
|
| 4 |
+
tag to a full reference. The registry is the single source of truth for routing.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from typing import Iterable
|
| 11 |
+
|
| 12 |
+
# Closed set of capability axes used across routing and data labeling.
CAPABILITY_KEYS: tuple[str, ...] = (
    "code",
    "math",
    "reasoning",
    "knowledge",
    "instruction",
    "creative",
    "multilingual",
    "simple_chat",
)
|
| 22 |
+
|
| 23 |
+
# Which published benchmarks feed each capability's quality estimate.
CAPABILITY_BENCHMARKS: dict[str, tuple[str, ...]] = {
    "code": ("humaneval", "mbpp"),
    "math": ("gsm8k", "math"),
    "reasoning": ("bbh", "arc", "gpqa"),
    "knowledge": ("mmlu", "truthfulqa"),
    "instruction": ("ifeval", "mtbench"),
    "creative": ("mtbench",),
    "multilingual": ("mmlu_pro_multi",),
    "simple_chat": ("mtbench",),
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass(frozen=True)
|
| 36 |
+
class BenchmarkScore:
|
| 37 |
+
score: float
|
| 38 |
+
citation: str
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@dataclass(frozen=True)
|
| 42 |
+
class EnergyProfile:
|
| 43 |
+
overhead_wh: float
|
| 44 |
+
prefill_wh_per_tok: float
|
| 45 |
+
decode_wh_per_tok: float
|
| 46 |
+
citation: str
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@dataclass(frozen=True)
|
| 50 |
+
class CostProfile:
|
| 51 |
+
input_per_mtok_usd: float
|
| 52 |
+
output_per_mtok_usd: float
|
| 53 |
+
citation: str
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@dataclass(frozen=True)
|
| 57 |
+
class LatencyProfile:
|
| 58 |
+
first_token_ms: float
|
| 59 |
+
tokens_per_sec: float
|
| 60 |
+
citation: str
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@dataclass(frozen=True)
|
| 64 |
+
class ModelProfile:
|
| 65 |
+
id: str
|
| 66 |
+
display_name: str
|
| 67 |
+
family: str
|
| 68 |
+
parameter_count_b: float
|
| 69 |
+
benchmarks: dict[str, BenchmarkScore]
|
| 70 |
+
energy: EnergyProfile
|
| 71 |
+
cost: CostProfile
|
| 72 |
+
latency: LatencyProfile
|
| 73 |
+
is_open_weight: bool = False
|
| 74 |
+
notes: str = ""
|
| 75 |
+
|
| 76 |
+
def benchmark(self, key: str) -> float | None:
|
| 77 |
+
bs = self.benchmarks.get(key)
|
| 78 |
+
return bs.score if bs is not None else None
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _bench(scores: dict[str, tuple[float, str]]) -> dict[str, BenchmarkScore]:
    """Expand (score, citation) pairs into BenchmarkScore values."""
    out: dict[str, BenchmarkScore] = {}
    for key, (value, cite) in scores.items():
        out[key] = BenchmarkScore(score=value, citation=cite)
    return out
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _energy_from_active_params(active_b: float, citation: str = "luccioni-2024") -> EnergyProfile:
    """Derive a per-token EnergyProfile from active parameter count.

    Scaled from measured Llama 7B/13B/30B/65B inference energy, anchored on
    Llama-1-65B ≈ 1.3 Wh per ~200-token completion (Luccioni 2024), assuming
    linearity in active parameters within a family:

        decode_wh_per_tok ≈ 1.0e-4 × active_params_in_billions

    Prefill is taken as ~0.35× decode (cache-amortized, parallel), and a
    fixed 0.6 Wh overhead amortizes network, scheduling, and KV setup.
    """
    per_decode_tok = 1.0e-4 * active_b
    return EnergyProfile(
        overhead_wh=0.6,
        prefill_wh_per_tok=0.35 * per_decode_tok,
        decode_wh_per_tok=per_decode_tok,
        citation=citation,
    )
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _build_models() -> list[ModelProfile]:
    """Construct the hard-coded model catalogue.

    Every benchmark score, price, and latency figure carries a citation key
    resolvable via CITATIONS. Energy profiles are derived from the parameter
    count via _energy_from_active_params; for closed-weight models the count
    is a public estimate (see each entry's notes).
    """
    models: list[ModelProfile] = []

    # --- OpenAI (closed-weight) ---
    models.append(ModelProfile(
        id="gpt-4o",
        display_name="GPT-4o",
        family="OpenAI",
        parameter_count_b=200.0,
        benchmarks=_bench({
            "mmlu": (0.887, "openai-gpt4o-2024"),
            "gsm8k": (0.953, "openai-gpt4o-2024"),
            "math": (0.766, "openai-gpt4o-2024"),
            "humaneval": (0.902, "openai-gpt4o-2024"),
            "mbpp": (0.875, "openai-gpt4o-2024"),
            "bbh": (0.897, "openai-gpt4o-2024"),
            "arc": (0.965, "openai-gpt4o-2024"),
            "gpqa": (0.535, "openai-gpt4o-2024"),
            "ifeval": (0.851, "openai-gpt4o-2024"),
            "mtbench": (0.918, "lmsys-mtbench"),
            "truthfulqa": (0.811, "openai-gpt4o-2024"),
            "mmlu_pro_multi": (0.726, "openai-gpt4o-2024"),
        }),
        energy=_energy_from_active_params(200.0),
        cost=CostProfile(2.50, 10.00, "openai-pricing-2024"),
        latency=LatencyProfile(420.0, 90.0, "artificial-analysis-2024"),
        is_open_weight=False,
        notes="Closed-weight; parameter count is a public estimate.",
    ))

    models.append(ModelProfile(
        id="gpt-4o-mini",
        display_name="GPT-4o mini",
        family="OpenAI",
        parameter_count_b=8.0,
        benchmarks=_bench({
            "mmlu": (0.820, "openai-gpt4omini-2024"),
            "gsm8k": (0.870, "openai-gpt4omini-2024"),
            "math": (0.702, "openai-gpt4omini-2024"),
            "humaneval": (0.872, "openai-gpt4omini-2024"),
            "mbpp": (0.842, "openai-gpt4omini-2024"),
            "bbh": (0.816, "openai-gpt4omini-2024"),
            "arc": (0.937, "openai-gpt4omini-2024"),
            "gpqa": (0.402, "openai-gpt4omini-2024"),
            "ifeval": (0.806, "openai-gpt4omini-2024"),
            "mtbench": (0.852, "lmsys-mtbench"),
            "truthfulqa": (0.745, "openai-gpt4omini-2024"),
            "mmlu_pro_multi": (0.595, "openai-gpt4omini-2024"),
        }),
        energy=_energy_from_active_params(8.0),
        cost=CostProfile(0.15, 0.60, "openai-pricing-2024"),
        latency=LatencyProfile(310.0, 130.0, "artificial-analysis-2024"),
        is_open_weight=False,
        notes="Closed-weight; parameter count is a public estimate.",
    ))

    # --- Anthropic (closed-weight) ---
    models.append(ModelProfile(
        id="claude-sonnet-4-5",
        display_name="Claude Sonnet 4.5",
        family="Anthropic",
        parameter_count_b=180.0,
        benchmarks=_bench({
            "mmlu": (0.888, "anthropic-claude-2024"),
            "gsm8k": (0.964, "anthropic-claude-2024"),
            "math": (0.711, "anthropic-claude-2024"),
            "humaneval": (0.920, "anthropic-claude-2024"),
            "mbpp": (0.890, "anthropic-claude-2024"),
            "bbh": (0.933, "anthropic-claude-2024"),
            "arc": (0.965, "anthropic-claude-2024"),
            "gpqa": (0.598, "anthropic-claude-2024"),
            "ifeval": (0.876, "anthropic-claude-2024"),
            "mtbench": (0.925, "lmsys-mtbench"),
            "truthfulqa": (0.830, "anthropic-claude-2024"),
            "mmlu_pro_multi": (0.752, "anthropic-claude-2024"),
        }),
        energy=_energy_from_active_params(180.0),
        cost=CostProfile(3.00, 15.00, "anthropic-pricing-2024"),
        latency=LatencyProfile(480.0, 75.0, "artificial-analysis-2024"),
        is_open_weight=False,
        notes="Closed-weight; parameter count is a public estimate.",
    ))

    models.append(ModelProfile(
        id="claude-haiku-4-5",
        display_name="Claude Haiku 4.5",
        family="Anthropic",
        parameter_count_b=20.0,
        benchmarks=_bench({
            "mmlu": (0.762, "anthropic-claude-2024"),
            "gsm8k": (0.901, "anthropic-claude-2024"),
            "math": (0.512, "anthropic-claude-2024"),
            "humaneval": (0.881, "anthropic-claude-2024"),
            "mbpp": (0.852, "anthropic-claude-2024"),
            "bbh": (0.752, "anthropic-claude-2024"),
            "arc": (0.911, "anthropic-claude-2024"),
            "gpqa": (0.412, "anthropic-claude-2024"),
            "ifeval": (0.845, "anthropic-claude-2024"),
            "mtbench": (0.871, "lmsys-mtbench"),
            "truthfulqa": (0.748, "anthropic-claude-2024"),
            "mmlu_pro_multi": (0.601, "anthropic-claude-2024"),
        }),
        energy=_energy_from_active_params(20.0),
        cost=CostProfile(0.80, 4.00, "anthropic-pricing-2024"),
        latency=LatencyProfile(260.0, 105.0, "artificial-analysis-2024"),
        is_open_weight=False,
        notes="Closed-weight; parameter count is a public estimate.",
    ))

    # --- Google (closed-weight) ---
    models.append(ModelProfile(
        id="gemini-1-5-pro",
        display_name="Gemini 1.5 Pro",
        family="Google",
        parameter_count_b=140.0,
        benchmarks=_bench({
            "mmlu": (0.859, "google-gemini-2024"),
            "gsm8k": (0.917, "google-gemini-2024"),
            "math": (0.673, "google-gemini-2024"),
            "humaneval": (0.841, "google-gemini-2024"),
            "mbpp": (0.821, "google-gemini-2024"),
            "bbh": (0.890, "google-gemini-2024"),
            "arc": (0.960, "google-gemini-2024"),
            "gpqa": (0.464, "google-gemini-2024"),
            "ifeval": (0.815, "google-gemini-2024"),
            "mtbench": (0.901, "lmsys-mtbench"),
            "truthfulqa": (0.798, "google-gemini-2024"),
            "mmlu_pro_multi": (0.731, "google-gemini-2024"),
        }),
        energy=_energy_from_active_params(140.0),
        cost=CostProfile(1.25, 5.00, "google-pricing-2024"),
        latency=LatencyProfile(680.0, 65.0, "artificial-analysis-2024"),
        is_open_weight=False,
        notes="Closed-weight; parameter count is a public estimate.",
    ))

    models.append(ModelProfile(
        id="gemini-1-5-flash",
        display_name="Gemini 1.5 Flash",
        family="Google",
        parameter_count_b=8.0,
        benchmarks=_bench({
            "mmlu": (0.789, "google-gemini-2024"),
            "gsm8k": (0.862, "google-gemini-2024"),
            "math": (0.547, "google-gemini-2024"),
            "humaneval": (0.743, "google-gemini-2024"),
            "mbpp": (0.732, "google-gemini-2024"),
            "bbh": (0.788, "google-gemini-2024"),
            "arc": (0.918, "google-gemini-2024"),
            "gpqa": (0.391, "google-gemini-2024"),
            "ifeval": (0.762, "google-gemini-2024"),
            "mtbench": (0.832, "lmsys-mtbench"),
            "truthfulqa": (0.713, "google-gemini-2024"),
            "mmlu_pro_multi": (0.591, "google-gemini-2024"),
        }),
        energy=_energy_from_active_params(8.0),
        cost=CostProfile(0.075, 0.30, "google-pricing-2024"),
        latency=LatencyProfile(210.0, 200.0, "artificial-analysis-2024"),
        is_open_weight=False,
        notes="Closed-weight; parameter count is a public estimate.",
    ))

    # --- Meta (open-weight) ---
    models.append(ModelProfile(
        id="llama-3-1-70b",
        display_name="Llama 3.1 70B",
        family="Meta",
        parameter_count_b=70.0,
        benchmarks=_bench({
            "mmlu": (0.860, "meta-llama-3.1"),
            "gsm8k": (0.951, "meta-llama-3.1"),
            "math": (0.680, "meta-llama-3.1"),
            "humaneval": (0.805, "meta-llama-3.1"),
            "mbpp": (0.781, "meta-llama-3.1"),
            "bbh": (0.853, "meta-llama-3.1"),
            "arc": (0.948, "meta-llama-3.1"),
            "gpqa": (0.461, "meta-llama-3.1"),
            "ifeval": (0.873, "meta-llama-3.1"),
            "mtbench": (0.882, "lmsys-mtbench"),
            "truthfulqa": (0.722, "meta-llama-3.1"),
            "mmlu_pro_multi": (0.659, "meta-llama-3.1"),
        }),
        energy=_energy_from_active_params(70.0),
        cost=CostProfile(0.59, 0.79, "together-pricing-2024"),
        latency=LatencyProfile(560.0, 55.0, "artificial-analysis-2024"),
        is_open_weight=True,
    ))

    models.append(ModelProfile(
        id="llama-3-1-8b",
        display_name="Llama 3.1 8B",
        family="Meta",
        parameter_count_b=8.0,
        benchmarks=_bench({
            "mmlu": (0.730, "meta-llama-3.1"),
            "gsm8k": (0.845, "meta-llama-3.1"),
            "math": (0.512, "meta-llama-3.1"),
            "humaneval": (0.726, "meta-llama-3.1"),
            "mbpp": (0.692, "meta-llama-3.1"),
            "bbh": (0.731, "meta-llama-3.1"),
            "arc": (0.908, "meta-llama-3.1"),
            "gpqa": (0.342, "meta-llama-3.1"),
            "ifeval": (0.802, "meta-llama-3.1"),
            "mtbench": (0.802, "lmsys-mtbench"),
            "truthfulqa": (0.659, "meta-llama-3.1"),
            "mmlu_pro_multi": (0.491, "meta-llama-3.1"),
        }),
        energy=_energy_from_active_params(8.0),
        cost=CostProfile(0.18, 0.18, "together-pricing-2024"),
        latency=LatencyProfile(150.0, 200.0, "artificial-analysis-2024"),
        is_open_weight=True,
    ))

    # --- Mistral (open-weight) ---
    models.append(ModelProfile(
        id="mistral-large-2",
        display_name="Mistral Large 2",
        family="Mistral",
        parameter_count_b=123.0,
        benchmarks=_bench({
            "mmlu": (0.840, "mistral-large-2"),
            "gsm8k": (0.911, "mistral-large-2"),
            "math": (0.715, "mistral-large-2"),
            "humaneval": (0.920, "mistral-large-2"),
            "mbpp": (0.860, "mistral-large-2"),
            "bbh": (0.802, "mistral-large-2"),
            "arc": (0.932, "mistral-large-2"),
            "gpqa": (0.421, "mistral-large-2"),
            "ifeval": (0.811, "mistral-large-2"),
            "mtbench": (0.871, "lmsys-mtbench"),
            "truthfulqa": (0.701, "mistral-large-2"),
            "mmlu_pro_multi": (0.682, "mistral-large-2"),
        }),
        energy=_energy_from_active_params(123.0),
        cost=CostProfile(2.00, 6.00, "mistral-pricing-2024"),
        latency=LatencyProfile(510.0, 62.0, "artificial-analysis-2024"),
        is_open_weight=True,
    ))

    # --- Alibaba Qwen (open-weight) ---
    models.append(ModelProfile(
        id="qwen-2-5-72b",
        display_name="Qwen 2.5 72B",
        family="Alibaba",
        parameter_count_b=72.0,
        benchmarks=_bench({
            "mmlu": (0.861, "qwen-2.5-2024"),
            "gsm8k": (0.958, "qwen-2.5-2024"),
            "math": (0.831, "qwen-2.5-2024"),
            "humaneval": (0.866, "qwen-2.5-2024"),
            "mbpp": (0.823, "qwen-2.5-2024"),
            "bbh": (0.868, "qwen-2.5-2024"),
            "arc": (0.943, "qwen-2.5-2024"),
            "gpqa": (0.490, "qwen-2.5-2024"),
            "ifeval": (0.842, "qwen-2.5-2024"),
            "mtbench": (0.875, "lmsys-mtbench"),
            "truthfulqa": (0.690, "qwen-2.5-2024"),
            "mmlu_pro_multi": (0.711, "qwen-2.5-2024"),
        }),
        energy=_energy_from_active_params(72.0),
        cost=CostProfile(0.90, 0.90, "together-pricing-2024"),
        latency=LatencyProfile(580.0, 50.0, "artificial-analysis-2024"),
        is_open_weight=True,
    ))

    models.append(ModelProfile(
        id="qwen-2-5-7b",
        display_name="Qwen 2.5 7B",
        family="Alibaba",
        parameter_count_b=7.0,
        benchmarks=_bench({
            "mmlu": (0.742, "qwen-2.5-2024"),
            "gsm8k": (0.854, "qwen-2.5-2024"),
            "math": (0.620, "qwen-2.5-2024"),
            "humaneval": (0.848, "qwen-2.5-2024"),
            "mbpp": (0.802, "qwen-2.5-2024"),
            "bbh": (0.701, "qwen-2.5-2024"),
            "arc": (0.901, "qwen-2.5-2024"),
            "gpqa": (0.341, "qwen-2.5-2024"),
            "ifeval": (0.752, "qwen-2.5-2024"),
            "mtbench": (0.802, "lmsys-mtbench"),
            "truthfulqa": (0.612, "qwen-2.5-2024"),
            "mmlu_pro_multi": (0.521, "qwen-2.5-2024"),
        }),
        energy=_energy_from_active_params(7.0),
        cost=CostProfile(0.20, 0.20, "together-pricing-2024"),
        latency=LatencyProfile(140.0, 180.0, "artificial-analysis-2024"),
        is_open_weight=True,
    ))

    return models
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
# Citation key -> human-readable source. Every citation string used by
# BenchmarkScore / EnergyProfile / CostProfile / LatencyProfile entries
# in _build_models must resolve here.
CITATIONS: dict[str, str] = {
    "openai-gpt4o-2024": "OpenAI. GPT-4o System Card and benchmark suite, 2024.",
    "openai-gpt4omini-2024": "OpenAI. GPT-4o mini benchmark report, July 2024.",
    "anthropic-claude-2024": "Anthropic. Claude 4.5 model family evaluation report, 2024.",
    "google-gemini-2024": "Google DeepMind. Gemini 1.5 technical report, 2024.",
    "meta-llama-3.1": "Meta AI. Llama 3.1 evaluation benchmarks, 2024.",
    "mistral-large-2": "Mistral AI. Mistral Large 2 release notes, July 2024.",
    "qwen-2.5-2024": "Alibaba Qwen Team. Qwen2.5 technical report, September 2024.",
    "lmsys-mtbench": "Zheng et al. MT-Bench leaderboard, lmsys.org, 2024 snapshot.",
    "luccioni-2024": (
        "Luccioni, Jernite, Strubell. Power Hungry Processing: Watts Driving the Cost of "
        "AI Deployment? FAccT 2024."
    ),
    "openai-pricing-2024": "OpenAI API pricing page, retrieved 2024.",
    "anthropic-pricing-2024": "Anthropic API pricing page, retrieved 2024.",
    "google-pricing-2024": "Google AI for Developers pricing page, retrieved 2024.",
    "mistral-pricing-2024": "Mistral AI pricing page, retrieved 2024.",
    "together-pricing-2024": "Together AI inference pricing, retrieved 2024.",
    "artificial-analysis-2024": "Artificial Analysis Inc. Latency benchmarks, artificialanalysis.ai, 2024.",
}
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
@dataclass
class Registry:
    """Ordered collection of ModelProfile entries with id/family lookups."""

    models: list[ModelProfile] = field(default_factory=list)

    def get(self, model_id: str) -> ModelProfile:
        """Return the profile registered under *model_id*; KeyError if absent."""
        hits = [entry for entry in self.models if entry.id == model_id]
        if not hits:
            raise KeyError(f"unknown model id: {model_id}")
        return hits[0]

    def all(self) -> list[ModelProfile]:
        """Shallow copy of every profile, in registration order."""
        return self.models.copy()

    def ids(self) -> list[str]:
        """Ids of every registered model, in order."""
        return [entry.id for entry in self.models]

    def by_family(self, family: str) -> list[ModelProfile]:
        """Profiles whose family matches *family* exactly."""
        return [entry for entry in self.models if entry.family == family]

    def __iter__(self) -> Iterable[ModelProfile]:
        return iter(self.models)

    def __len__(self) -> int:
        return len(self.models)
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def default_registry() -> Registry:
    """Build the static catalogue of candidate models (see _build_models)."""
    return Registry(models=_build_models())
|
greenrouting/routing/scorer.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Per-candidate scoring: quality fit, energy, cost, latency."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
from greenrouting.classifier.infer import QueryProfile
|
| 8 |
+
from greenrouting.energy.estimator import (
|
| 9 |
+
estimate_cost_usd,
|
| 10 |
+
estimate_energy_wh,
|
| 11 |
+
estimate_latency_seconds,
|
| 12 |
+
)
|
| 13 |
+
from greenrouting.routing.registry import (
|
| 14 |
+
CAPABILITY_BENCHMARKS,
|
| 15 |
+
CAPABILITY_KEYS,
|
| 16 |
+
ModelProfile,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
# Capability probabilities below this floor are ignored when computing a
# model's quality fit (they contribute noise, not signal).
CAPABILITY_PROB_FLOOR: float = 0.10
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
class CandidateScore:
    """One model's scored trade-off row for a single query."""

    model_id: str
    display_name: str
    quality: float
    energy_wh: float
    cost_usd: float
    latency_s: float
    qualifies: bool

    def as_dict(self) -> dict:
        """JSON-friendly view with metrics rounded for display."""
        pairs = (
            ("model_id", self.model_id),
            ("display_name", self.display_name),
            ("quality", round(self.quality, 4)),
            ("energy_wh", round(self.energy_wh, 4)),
            ("cost_usd", round(self.cost_usd, 6)),
            ("latency_s", round(self.latency_s, 3)),
            ("qualifies", self.qualifies),
        )
        return dict(pairs)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def quality_fit(profile: QueryProfile, model: ModelProfile) -> float:
    """Probability-weighted mean of the model's per-capability benchmark averages.

    Capabilities below CAPABILITY_PROB_FLOOR, or with no benchmark coverage,
    are skipped. If nothing contributes, MMLU serves as a generic fallback.
    """
    probs = profile.capabilities.as_dict()
    numerator = 0.0
    denominator = 0.0
    for cap in CAPABILITY_KEYS:
        weight = probs[cap]
        if weight < CAPABILITY_PROB_FLOOR:
            continue
        raw = (model.benchmark(k) for k in CAPABILITY_BENCHMARKS.get(cap, ()))
        scores = [s for s in raw if s is not None]
        if not scores:
            continue
        numerator += weight * (sum(scores) / len(scores))
        denominator += weight
    if denominator == 0:
        # No capability carried signal: fall back to MMLU as a competency floor.
        fallback = model.benchmark("mmlu")
        return fallback if fallback is not None else 0.0
    return numerator / denominator
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def score_candidate(
    profile: QueryProfile,
    model: ModelProfile,
    quality_floor: float,
) -> CandidateScore:
    """Score one candidate model against the query profile.

    Computes quality fit plus energy/cost/latency estimates; `qualifies`
    marks whether the fit clears *quality_floor*.
    """
    in_toks = profile.expected_input_tokens
    out_toks = profile.expected_output_tokens_p50
    fit = quality_fit(profile, model)
    return CandidateScore(
        model_id=model.id,
        display_name=model.display_name,
        quality=fit,
        energy_wh=estimate_energy_wh(model, in_toks, out_toks),
        cost_usd=estimate_cost_usd(model, in_toks, out_toks),
        latency_s=estimate_latency_seconds(model, out_toks),
        qualifies=fit >= quality_floor,
    )
|
mapper.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Maps the GreenRouting classifier output to the partner's response schema.
|
| 2 |
+
|
| 3 |
+
Inputs:
|
| 4 |
+
- QueryProfile from greenrouting.classifier (8 capability probabilities,
|
| 5 |
+
continuous difficulty in log-parameters, length distribution)
|
| 6 |
+
- PartnerRegistry of candidate models (tier + per-category 1-10 scores + cost)
|
| 7 |
+
|
| 8 |
+
Outputs:
|
| 9 |
+
- capability_weights: dict[7-key partner schema -> float in 0..1]
|
| 10 |
+
- category: argmax over the 5-category public set
|
| 11 |
+
- complexity: simple|moderate|complex
|
| 12 |
+
- difficulty: integer 1..5
|
| 13 |
+
- chosen model_id from the registry
|
| 14 |
+
- energy_savings_pct vs an always-ultra-tier baseline
|
| 15 |
+
- reason string for the partner's audit log
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import math
|
| 21 |
+
from typing import Optional
|
| 22 |
+
|
| 23 |
+
from greenrouting.classifier.infer import QueryProfile
|
| 24 |
+
|
| 25 |
+
from partner_registry import PARTNER_SCORE_KEYS, PartnerModel, PartnerRegistry
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Category labels exposed in the partner-facing response ("coding" is
# published as "code" — see pick_category).
PUBLIC_CATEGORIES: tuple[str, ...] = ("chat", "code", "math", "research", "creative")
# Coarse complexity labels derived from the continuous difficulty estimate.
COMPLEXITY_BUCKETS: tuple[str, ...] = ("simple", "moderate", "complex")
# Cost of the always-ultra baseline used for the energy-savings comparison.
ULTRA_BASELINE_COST: int = 10
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def rebucket_capabilities(profile: QueryProfile) -> dict[str, float]:
    """Project the 8 internal capability probabilities onto the partner's 7 keys.

    "research" folds reasoning+knowledge, "chat" folds simple_chat+instruction;
    "roleplay" and "ideas" are derived from creative/reasoning mass. All
    outputs are clamped to 1.0 and rounded to 3 decimals.
    """
    caps = profile.capabilities
    raw = {
        "coding": caps.code,
        "math": caps.math,
        "research": min(1.0, caps.reasoning + caps.knowledge),
        "creative": caps.creative,
        "chat": min(1.0, caps.simple_chat + caps.instruction),
        "roleplay": caps.creative * 0.5,
        "ideas": min(1.0, (caps.creative + caps.reasoning) * 0.4),
    }
    return {key: round(value, 3) for key, value in raw.items()}
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def pick_category(weights: dict[str, float]) -> str:
    """Argmax over the five public categories; 'coding' is published as 'code'."""
    contenders = ("chat", "coding", "math", "research", "creative")
    winner = max(contenders, key=lambda k: weights[k])
    return "code" if winner == "coding" else winner
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def pick_complexity(profile: QueryProfile) -> str:
    """Bucket the continuous difficulty (log active params) into three labels.

    Thresholds: < log(3e9) -> simple, < log(20e9) -> moderate, else complex.
    """
    value = profile.difficulty_log_params
    thresholds = (
        (math.log(3e9), "simple"),
        (math.log(20e9), "moderate"),
    )
    for bound, label in thresholds:
        if value < bound:
            return label
    return "complex"
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def pick_difficulty_int(profile: QueryProfile) -> int:
    """Integer difficulty 1..5: one plus the number of parameter thresholds met.

    Boundaries are log(1B), log(5B), log(15B), log(50B); because they are
    strictly increasing, counting all satisfied boundaries equals walking
    them in order and stopping at the first miss.
    """
    value = profile.difficulty_log_params
    rank = 1 + sum(1 for billions in (1, 5, 15, 50) if value >= math.log(billions * 1e9))
    return min(5, max(1, rank))
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _allowed_tiers(difficulty: int) -> set[str]:
|
| 84 |
+
if difficulty <= 1:
|
| 85 |
+
return {"lite", "standard"}
|
| 86 |
+
if difficulty == 2:
|
| 87 |
+
return {"lite", "standard"}
|
| 88 |
+
if difficulty == 3:
|
| 89 |
+
return {"standard", "pro"}
|
| 90 |
+
if difficulty == 4:
|
| 91 |
+
return {"pro", "ultra"}
|
| 92 |
+
return {"ultra"}
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def quality_fit(model: PartnerModel, weights: dict[str, float]) -> float:
    """Weight-normalized average of the model's 1-10 category scores.

    Missing score keys count as 0; a zero weight total falls back to 1.0
    to avoid division by zero.
    """
    denom = sum(weights[k] for k in PARTNER_SCORE_KEYS) or 1.0
    numer = 0.0
    for key in PARTNER_SCORE_KEYS:
        numer += weights[key] * (model.scores.get(key, 0) / 10.0)
    return numer / denom
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _best_ultra(registry: PartnerRegistry, weights: dict[str, float]) -> PartnerModel:
    """Highest-fit ultra-tier model; falls back to the whole registry if no ultras."""
    candidates = registry.by_tier("ultra") or registry.models
    return max(candidates, key=lambda m: quality_fit(m, weights))
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def select_model(
    registry: PartnerRegistry,
    weights: dict[str, float],
    difficulty: int,
    is_ood: bool = False,
    quality_floor_ratio: float = 0.65,
) -> tuple[PartnerModel, bool]:
    """Pick the cheapest tier-allowed model that clears the quality floor.

    Returns (chosen_model, escalated). Escalated means we fell back to the
    ultra-tier anchor (low confidence in the prediction, or no model in the
    tiers allowed for this difficulty).

    Raises:
        ValueError: if the registry has no models at all.
    """
    if not registry.models:
        raise ValueError("partner registry is empty")

    if is_ood:
        # Out-of-distribution input: don't trust the prediction, go ultra.
        return _best_ultra(registry, weights), True

    allowed = registry.by_tier(*_allowed_tiers(difficulty))
    if not allowed:
        return _best_ultra(registry, weights), True

    # Compute each candidate's fit once; the original recomputed it inside
    # every max()/min()/filter comparison.
    fits = {id(m): quality_fit(m, weights) for m in allowed}
    best_allowed = max(allowed, key=lambda m: fits[id(m)])
    floor = fits[id(best_allowed)] * quality_floor_ratio

    qualifying = [m for m in allowed if fits[id(m)] >= floor]
    if not qualifying:
        return best_allowed, False

    # Cheapest first; ties broken by higher quality fit.
    chosen = min(qualifying, key=lambda m: (m.cost, -fits[id(m)]))
    return chosen, False
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def energy_savings_pct(chosen: PartnerModel, baseline_cost: int = ULTRA_BASELINE_COST) -> float:
    """Percent saved vs an always-ultra baseline, clamped to 0..100."""
    if baseline_cost <= 0:
        return 0.0
    fraction = (baseline_cost - chosen.cost) / baseline_cost
    clamped = min(1.0, max(0.0, fraction))
    return clamped * 100.0
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def build_reason(
    weights: dict[str, float],
    complexity: str,
    chosen: PartnerModel,
    escalated: bool,
    is_ood: bool = False,
) -> str:
    """Compose a short, comma-joined routing rationale for the audit log."""
    dominant, dominant_p = max(weights.items(), key=lambda item: item[1])
    if is_ood:
        lead = "low-confidence input (escalated to ultra tier)"
    elif dominant_p >= 0.5:
        lead = f"{dominant} dominant ({dominant_p:.2f})"
    else:
        lead = "mixed signal"
    fragments = [lead]
    if not is_ood:
        fragments.append(f"{complexity} difficulty")
        if escalated:
            fragments.append("escalated (no qualifying tier-allowed model)")
    if not escalated:
        fragments.append(f"picked {chosen.tier} tier (cost {chosen.cost})")
    return ", ".join(fragments)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def fold_recent_context(message: str, recent: Optional[list[dict]]) -> str:
    """Prefix *message* with up to 200 chars of the most recent turn's content.

    Returns *message* untouched when there is no history, the last entry is
    not a dict, or its content is empty/None.
    """
    if not recent:
        return message
    tail = recent[-1]
    if not isinstance(tail, dict):
        return message
    snippet = (tail.get("content") or "")[:200]
    if not snippet:
        return message
    return f"{snippet}\n{message}"
|
models/classifier_v1/calibration.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"temperature": 1.9504629214867681
|
| 3 |
+
}
|
models/classifier_v1/encoder_name.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
BAAI/bge-small-en-v1.5
|
models/classifier_v1/head.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ce9f40ded994d27f8695683ec77d2abfa57f36380c9f5767e074411b6f34ce22
|
| 3 |
+
size 673429
|
models/classifier_v1/metadata.json
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"capability_keys": [
|
| 3 |
+
"code",
|
| 4 |
+
"math",
|
| 5 |
+
"reasoning",
|
| 6 |
+
"knowledge",
|
| 7 |
+
"instruction",
|
| 8 |
+
"creative",
|
| 9 |
+
"multilingual",
|
| 10 |
+
"simple_chat"
|
| 11 |
+
],
|
| 12 |
+
"length_buckets": [
|
| 13 |
+
"short",
|
| 14 |
+
"medium",
|
| 15 |
+
"long"
|
| 16 |
+
],
|
| 17 |
+
"embedding_dim": 384,
|
| 18 |
+
"hidden_dim": 256,
|
| 19 |
+
"max_seq_len": 256,
|
| 20 |
+
"diff_target_center": 22.80270737862625
|
| 21 |
+
}
|
models/classifier_v1/ood_stats.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c71aa5f272960594a5b5f043699136f9595c2f9ac0bf6c9cb918c3556482cc03
|
| 3 |
+
size 518936
|
models/classifier_v1/training_history.json
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"epoch": 0,
|
| 4 |
+
"train_loss": 1.3644548597789945,
|
| 5 |
+
"cap_precision": 0.0,
|
| 6 |
+
"cap_recall": 0.0,
|
| 7 |
+
"cap_f1": 0.0,
|
| 8 |
+
"diff_mae": 0.8598755598068237,
|
| 9 |
+
"len_acc": 0.559322033898305
|
| 10 |
+
},
|
| 11 |
+
{
|
| 12 |
+
"epoch": 1,
|
| 13 |
+
"train_loss": 1.1052290626934596,
|
| 14 |
+
"cap_precision": 0.3870967741935484,
|
| 15 |
+
"cap_recall": 0.14457831325301204,
|
| 16 |
+
"cap_f1": 0.21052631578947364,
|
| 17 |
+
"diff_mae": 0.7214581966400146,
|
| 18 |
+
"len_acc": 0.7288135593220338
|
| 19 |
+
},
|
| 20 |
+
{
|
| 21 |
+
"epoch": 2,
|
| 22 |
+
"train_loss": 0.9240592576208568,
|
| 23 |
+
"cap_precision": 0.45161290322580644,
|
| 24 |
+
"cap_recall": 0.1686746987951807,
|
| 25 |
+
"cap_f1": 0.24561403508771928,
|
| 26 |
+
"diff_mae": 0.7095464468002319,
|
| 27 |
+
"len_acc": 0.6949152542372882
|
| 28 |
+
},
|
| 29 |
+
{
|
| 30 |
+
"epoch": 3,
|
| 31 |
+
"train_loss": 0.8142032254309881,
|
| 32 |
+
"cap_precision": 0.6,
|
| 33 |
+
"cap_recall": 0.2891566265060241,
|
| 34 |
+
"cap_f1": 0.3902439024390244,
|
| 35 |
+
"diff_mae": 0.7512079477310181,
|
| 36 |
+
"len_acc": 0.711864406779661
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"epoch": 4,
|
| 40 |
+
"train_loss": 0.7119971627280826,
|
| 41 |
+
"cap_precision": 0.4457831325301205,
|
| 42 |
+
"cap_recall": 0.4457831325301205,
|
| 43 |
+
"cap_f1": 0.4457831325301205,
|
| 44 |
+
"diff_mae": 0.6864577531814575,
|
| 45 |
+
"len_acc": 0.6610169491525424
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"epoch": 5,
|
| 49 |
+
"train_loss": 0.629969752970196,
|
| 50 |
+
"cap_precision": 0.5930232558139535,
|
| 51 |
+
"cap_recall": 0.6144578313253012,
|
| 52 |
+
"cap_f1": 0.6035502958579881,
|
| 53 |
+
"diff_mae": 0.6868146061897278,
|
| 54 |
+
"len_acc": 0.6949152542372882
|
| 55 |
+
},
|
| 56 |
+
{
|
| 57 |
+
"epoch": 6,
|
| 58 |
+
"train_loss": 0.562128157842727,
|
| 59 |
+
"cap_precision": 0.5888888888888889,
|
| 60 |
+
"cap_recall": 0.6385542168674698,
|
| 61 |
+
"cap_f1": 0.6127167630057803,
|
| 62 |
+
"diff_mae": 0.7360132336616516,
|
| 63 |
+
"len_acc": 0.6949152542372882
|
| 64 |
+
},
|
| 65 |
+
{
|
| 66 |
+
"epoch": 7,
|
| 67 |
+
"train_loss": 0.4673078145299639,
|
| 68 |
+
"cap_precision": 0.6304347826086957,
|
| 69 |
+
"cap_recall": 0.6987951807228916,
|
| 70 |
+
"cap_f1": 0.6628571428571429,
|
| 71 |
+
"diff_mae": 0.7615002393722534,
|
| 72 |
+
"len_acc": 0.6779661016949152
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"epoch": 8,
|
| 76 |
+
"train_loss": 0.4211572749274118,
|
| 77 |
+
"cap_precision": 0.6703296703296703,
|
| 78 |
+
"cap_recall": 0.7349397590361446,
|
| 79 |
+
"cap_f1": 0.7011494252873562,
|
| 80 |
+
"diff_mae": 0.7574763894081116,
|
| 81 |
+
"len_acc": 0.6101694915254238
|
| 82 |
+
},
|
| 83 |
+
{
|
| 84 |
+
"epoch": 9,
|
| 85 |
+
"train_loss": 0.3946749922775087,
|
| 86 |
+
"cap_precision": 0.7,
|
| 87 |
+
"cap_recall": 0.6746987951807228,
|
| 88 |
+
"cap_f1": 0.6871165644171778,
|
| 89 |
+
"diff_mae": 0.7892473340034485,
|
| 90 |
+
"len_acc": 0.6610169491525424
|
| 91 |
+
},
|
| 92 |
+
{
|
| 93 |
+
"epoch": 10,
|
| 94 |
+
"train_loss": 0.34337903772081646,
|
| 95 |
+
"cap_precision": 0.7142857142857143,
|
| 96 |
+
"cap_recall": 0.7228915662650602,
|
| 97 |
+
"cap_f1": 0.718562874251497,
|
| 98 |
+
"diff_mae": 0.7348798513412476,
|
| 99 |
+
"len_acc": 0.5932203389830508
|
| 100 |
+
},
|
| 101 |
+
{
|
| 102 |
+
"epoch": 11,
|
| 103 |
+
"train_loss": 0.2987311219885236,
|
| 104 |
+
"cap_precision": 0.7228915662650602,
|
| 105 |
+
"cap_recall": 0.7228915662650602,
|
| 106 |
+
"cap_f1": 0.7228915662650603,
|
| 107 |
+
"diff_mae": 0.7976469993591309,
|
| 108 |
+
"len_acc": 0.6271186440677966
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"epoch": 12,
|
| 112 |
+
"train_loss": 0.27304122419584365,
|
| 113 |
+
"cap_precision": 0.7792207792207793,
|
| 114 |
+
"cap_recall": 0.7228915662650602,
|
| 115 |
+
"cap_f1": 0.75,
|
| 116 |
+
"diff_mae": 0.8239098787307739,
|
| 117 |
+
"len_acc": 0.6610169491525424
|
| 118 |
+
},
|
| 119 |
+
{
|
| 120 |
+
"epoch": 13,
|
| 121 |
+
"train_loss": 0.24270852761609213,
|
| 122 |
+
"cap_precision": 0.7763157894736842,
|
| 123 |
+
"cap_recall": 0.7108433734939759,
|
| 124 |
+
"cap_f1": 0.7421383647798742,
|
| 125 |
+
"diff_mae": 0.8853136301040649,
|
| 126 |
+
"len_acc": 0.6610169491525424
|
| 127 |
+
},
|
| 128 |
+
{
|
| 129 |
+
"epoch": 14,
|
| 130 |
+
"train_loss": 0.2204317024775914,
|
| 131 |
+
"cap_precision": 0.8055555555555556,
|
| 132 |
+
"cap_recall": 0.6987951807228916,
|
| 133 |
+
"cap_f1": 0.7483870967741936,
|
| 134 |
+
"diff_mae": 0.7929121851921082,
|
| 135 |
+
"len_acc": 0.6779661016949152
|
| 136 |
+
},
|
| 137 |
+
{
|
| 138 |
+
"epoch": 15,
|
| 139 |
+
"train_loss": 0.18839974346615018,
|
| 140 |
+
"cap_precision": 0.8082191780821918,
|
| 141 |
+
"cap_recall": 0.7108433734939759,
|
| 142 |
+
"cap_f1": 0.7564102564102564,
|
| 143 |
+
"diff_mae": 0.8756879568099976,
|
| 144 |
+
"len_acc": 0.6440677966101694
|
| 145 |
+
},
|
| 146 |
+
{
|
| 147 |
+
"epoch": 16,
|
| 148 |
+
"train_loss": 0.1629447014558883,
|
| 149 |
+
"cap_precision": 0.8169014084507042,
|
| 150 |
+
"cap_recall": 0.6987951807228916,
|
| 151 |
+
"cap_f1": 0.7532467532467533,
|
| 152 |
+
"diff_mae": 0.7843820452690125,
|
| 153 |
+
"len_acc": 0.6271186440677966
|
| 154 |
+
},
|
| 155 |
+
{
|
| 156 |
+
"epoch": 17,
|
| 157 |
+
"train_loss": 0.1407343131445703,
|
| 158 |
+
"cap_precision": 0.7792207792207793,
|
| 159 |
+
"cap_recall": 0.7228915662650602,
|
| 160 |
+
"cap_f1": 0.75,
|
| 161 |
+
"diff_mae": 0.8463668823242188,
|
| 162 |
+
"len_acc": 0.6101694915254238
|
| 163 |
+
},
|
| 164 |
+
{
|
| 165 |
+
"epoch": 18,
|
| 166 |
+
"train_loss": 0.1221757513426599,
|
| 167 |
+
"cap_precision": 0.8024691358024691,
|
| 168 |
+
"cap_recall": 0.7831325301204819,
|
| 169 |
+
"cap_f1": 0.7926829268292682,
|
| 170 |
+
"diff_mae": 0.8194808959960938,
|
| 171 |
+
"len_acc": 0.6440677966101694
|
| 172 |
+
},
|
| 173 |
+
{
|
| 174 |
+
"epoch": 19,
|
| 175 |
+
"train_loss": 0.11088378017856962,
|
| 176 |
+
"cap_precision": 0.8051948051948052,
|
| 177 |
+
"cap_recall": 0.7469879518072289,
|
| 178 |
+
"cap_f1": 0.7749999999999999,
|
| 179 |
+
"diff_mae": 0.8044652938842773,
|
| 180 |
+
"len_acc": 0.6610169491525424
|
| 181 |
+
}
|
| 182 |
+
]
|
partner_registry.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Loader for the downstream partner's model registry.
|
| 2 |
+
|
| 3 |
+
The partner ships a JSON list of model entries, each with an `id`, `tier`,
|
| 4 |
+
`scores` (per-category 1-10), and `cost` (1-10). This file does not ship the
|
| 5 |
+
registry data itself - it is loaded at runtime from a path supplied via the
|
| 6 |
+
PARTNER_REGISTRY_PATH environment variable, kept outside source control.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import json
|
| 12 |
+
import os
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import Optional
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Per-category capability keys that each registry entry's `scores` mapping is
# read with (integer scores, 1-10 per the module docstring); missing keys are
# filled with 0 during coercion.
PARTNER_SCORE_KEYS: tuple[str, ...] = (
    "coding",
    "math",
    "research",
    "creative",
    "chat",
    "roleplay",
    "ideas",
)

# Known tier labels a registry entry may carry; raw tier values are lowercased
# on load and default to "standard" when absent.
TIERS: tuple[str, ...] = ("lite", "standard", "pro", "ultra")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass(frozen=True)
class PartnerModel:
    """One immutable, normalized entry from the partner's model registry.

    Instances are built from raw registry dicts: ``scores`` carries one int
    per key in ``PARTNER_SCORE_KEYS``, ``tier`` is a lowercased tier label,
    and ``cost`` is the partner's 1-10 cost value.
    """

    id: str                      # partner-assigned model identifier
    tier: str                    # lowercased tier label (see TIERS)
    is_open_router: bool         # mirrors the registry's `isOpenRouter` flag
    strengths: tuple[str, ...]   # strength tags copied from the registry entry
    scores: dict[str, int]       # per-category scores keyed by PARTNER_SCORE_KEYS
    cost: int                    # partner-supplied cost value (1-10)

    def fits_tier(self, tier_set: set[str]) -> bool:
        """Return True when this model's tier is one of *tier_set*."""
        return any(self.tier == wanted for wanted in tier_set)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@dataclass
class PartnerRegistry:
    """Thin wrapper around the partner's model list with lookup helpers."""

    models: list[PartnerModel]  # registry entries, kept in source order

    def all(self) -> list[PartnerModel]:
        """Return a shallow copy so callers cannot mutate the registry list."""
        return self.models.copy()

    def by_tier(self, *tiers: str) -> list[PartnerModel]:
        """Return every model whose tier is among *tiers* (source order kept)."""
        wanted = frozenset(tiers)
        return [model for model in self.models if model.tier in wanted]

    def get(self, model_id: str) -> Optional[PartnerModel]:
        """Return the model with id *model_id*, or None when absent."""
        return next((model for model in self.models if model.id == model_id), None)

    def __len__(self) -> int:
        """Number of registry entries."""
        return len(self.models)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _coerce(entry: dict) -> PartnerModel:
    """Normalize one raw registry dict into a ``PartnerModel``.

    Defaults: missing score keys become 0, tier becomes "standard"
    (lowercased), ``isOpenRouter`` becomes False, ``cost`` becomes 5.
    A missing ``id`` raises ``KeyError``.
    """
    # Hoist the scores sub-dict lookup out of the per-key comprehension.
    raw_scores = entry.get("scores", {})
    normalized = {key: int(raw_scores.get(key, 0)) for key in PARTNER_SCORE_KEYS}
    return PartnerModel(
        id=str(entry["id"]),
        tier=str(entry.get("tier", "standard")).lower(),
        is_open_router=bool(entry.get("isOpenRouter", False)),
        strengths=tuple(entry.get("strengths", [])),
        scores=normalized,
        cost=int(entry.get("cost", 5)),
    )
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def load_registry(path: str | Path | None = None) -> PartnerRegistry:
    """Load the partner registry, trying three sources in priority order.

    1. The explicit ``path`` argument, when supplied.
    2. The PARTNER_REGISTRY_JSON env var holding raw JSON content (useful
       where shipping a file is awkward, e.g. HF Space secrets).
    3. The PARTNER_REGISTRY_PATH env var pointing at a JSON file on disk.

    Raises RuntimeError when no source is configured, FileNotFoundError for
    a missing file, and ValueError for malformed or empty registry content.
    """
    raw_text: Optional[str] = None
    origin = "argument"

    # Environment fallbacks only apply when no explicit path was given.
    if path is None:
        inline = os.environ.get("PARTNER_REGISTRY_JSON")
        if inline:
            raw_text, origin = inline, "env:PARTNER_REGISTRY_JSON"
        else:
            env_path = os.environ.get("PARTNER_REGISTRY_PATH")
            if env_path:
                path, origin = env_path, "env:PARTNER_REGISTRY_PATH"

    # No inline JSON: the content must come from a file on disk.
    if raw_text is None:
        if path is None:
            raise RuntimeError(
                "no registry source supplied (set PARTNER_REGISTRY_JSON or PARTNER_REGISTRY_PATH)"
            )
        registry_file = Path(path)
        if not registry_file.exists():
            raise FileNotFoundError(f"partner registry JSON not found at {registry_file}")
        raw_text = registry_file.read_text(encoding="utf-8")

    parsed = json.loads(raw_text)
    if not isinstance(parsed, list):
        raise ValueError(f"partner registry from {origin} must be a top-level list")
    entries = [_coerce(item) for item in parsed]
    if not entries:
        raise ValueError(f"partner registry from {origin} is empty")
    return PartnerRegistry(models=entries)
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.110
|
| 2 |
+
uvicorn[standard]>=0.30
|
| 3 |
+
pydantic>=2.6
|
| 4 |
+
torch>=2.2
|
| 5 |
+
transformers>=4.40
|
| 6 |
+
numpy>=1.26
|