spectralman commited on
Commit
6f0ff99
·
verified ·
1 Parent(s): e9564be

Initial deploy: classifier + FastAPI router

Browse files
.gitignore ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.py[codz]
3
+ *.egg-info/
4
+ .venv/
5
+ venv/
6
+ .env
7
+
8
+ # Partner config: never commit
9
+ data/partner_registry.json
10
+ data/*.json
11
+
12
+ # IDE
13
+ .idea/
14
+ .vscode/
Dockerfile ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.11-slim

# Unbuffered logs + no .pyc files for container friendliness; all HF caches go
# under /tmp because the Space container's home may not be writable at runtime.
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    PIP_NO_CACHE_DIR=1 \
    HF_HOME=/tmp/hf_cache \
    TRANSFORMERS_CACHE=/tmp/hf_cache/transformers \
    SENTENCE_TRANSFORMERS_HOME=/tmp/hf_cache/sentence-transformers

WORKDIR /app

# build-essential lets pip compile wheels that ship no binary for slim images.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Install dependencies before copying sources so code-only edits reuse the
# cached pip layer. CPU-only torch wheel keeps the image small.
COPY requirements.txt .
RUN pip install --upgrade pip && \
    pip install torch --index-url https://download.pytorch.org/whl/cpu && \
    pip install -r requirements.txt

COPY greenrouting /app/greenrouting
COPY models /app/models
COPY partner_registry.py mapper.py app.py /app/

# World-writable cache dir: the container may run under an arbitrary UID.
RUN mkdir -p /tmp/hf_cache && chmod -R 777 /tmp/hf_cache

# 7860 matches the port uvicorn binds below (standard for HF Docker Spaces).
EXPOSE 7860
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,10 +1,34 @@
1
  ---
2
- title: Router Api
3
- emoji: 📈
4
- colorFrom: gray
5
- colorTo: pink
6
  sdk: docker
7
  pinned: false
 
 
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Router Classify API
3
+ emoji: 🛰️
4
+ colorFrom: green
5
+ colorTo: indigo
6
  sdk: docker
7
  pinned: false
8
+ license: other
9
+ license_name: polyform-noncommercial-1.0.0
10
+ license_link: https://polyformproject.org/licenses/noncommercial/1.0.0
11
  ---
12
 
13
+ # Router Classify API
14
+
15
+ REST endpoint that runs the GreenRouting classifier and returns a routing decision against a configurable downstream model registry.
16
+
17
+ ## Endpoints
18
+
19
+ - `POST /classify` — Run the classifier and pick a model.
20
+ - Request: `{ "message": "...", "recentMessages": [{"role": "...", "content": "..."}] }`
21
+ - Response: `{ "category", "complexity", "model_id", "capability_weights", "difficulty", "energy_savings_pct", "method", "reason" }`
22
+ - `GET /health` — Liveness probe.
23
+
24
+ ## Configuration
25
+
26
+ The registry of candidate models is supplied at runtime via a Space secret. Set one of:
27
+
28
+ - `PARTNER_REGISTRY_JSON` — the registry as raw JSON (preferred)
29
+ - `PARTNER_REGISTRY_PATH` — a file path inside the container
30
+
31
+ Other env vars:
32
+
33
+ - `CLASSIFIER_ARTIFACT_DIR` — defaults to `models/classifier_v1`
34
+ - `INCLUDE_REASON` — `1` (default) to include the `reason` string in responses, `0` to omit
app.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI service that wraps the GreenRouting classifier behind the partner-
2
+ specific response schema.
3
+
4
+ Endpoints:
5
+ POST /classify - classify a query and pick a model from the partner registry
6
+ GET /health - liveness probe used by the partner edge function
7
+
8
+ Auth: none. Stateless. CORS open. Single-process. Designed for a HF Spaces
9
+ Docker deployment with periodic /health pings keeping the container warm.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import logging
15
+ import os
16
+ import time
17
+ from pathlib import Path
18
+ from typing import Optional
19
+
20
+ from fastapi import FastAPI, HTTPException
21
+ from fastapi.middleware.cors import CORSMiddleware
22
+ from pydantic import BaseModel, Field
23
+
24
+ from greenrouting.classifier.trained_predictor import TrainedPredictor
25
+
26
+ from mapper import (
27
+ build_reason,
28
+ fold_recent_context,
29
+ energy_savings_pct,
30
+ pick_category,
31
+ pick_complexity,
32
+ pick_difficulty_int,
33
+ rebucket_capabilities,
34
+ select_model,
35
+ )
36
+ from partner_registry import load_registry
37
+
38
+
39
# Module-wide logger; basicConfig here is acceptable for a single-process service.
logger = logging.getLogger("router-api")
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")


# Runtime configuration via environment variables (see README).
ARTIFACT_DIR = os.environ.get("CLASSIFIER_ARTIFACT_DIR", "models/classifier_v1")
# Any value other than "0"/"false"/"False" enables the `reason` response field.
INCLUDE_REASON = os.environ.get("INCLUDE_REASON", "1") not in ("0", "false", "False")

app = FastAPI(title="GreenRouting Partner Router", version="0.1.0")
# Fully open CORS: the API is unauthenticated and called cross-origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
    max_age=3600,
)


# Lazily-initialized singletons, populated by _ensure_loaded().
_predictor: Optional[TrainedPredictor] = None
_registry = None
60
+
61
+
62
# One prior conversation turn as sent by the partner edge function.
class RecentMessage(BaseModel):
    role: str
    content: str
65
+
66
+
67
# Inbound payload: the new user message plus optional recent turns for context.
class ClassifyRequest(BaseModel):
    # Length cap guards the encoder against pathological inputs.
    message: str = Field(min_length=1, max_length=8000)
    # camelCase field name matches the partner's JSON contract.
    recentMessages: Optional[list[RecentMessage]] = None
70
+
71
+
72
# Outbound routing decision, shaped to the partner schema.
class ClassifyResponse(BaseModel):
    category: str
    complexity: str
    model_id: str
    capability_weights: dict[str, float]
    difficulty: int
    # None when the query is OOD or the router escalated (see classify()).
    energy_savings_pct: Optional[float] = None
    method: str
    # Omitted (None) when INCLUDE_REASON is disabled.
    reason: Optional[str] = None
81
+
82
+
83
def _ensure_loaded() -> None:
    """Lazily initialize the classifier and partner registry (idempotent).

    Raises RuntimeError when the trained artifact is missing so callers can
    surface a 503 instead of serving garbage.
    """
    global _predictor, _registry

    if _predictor is None:
        artifact_path = Path(ARTIFACT_DIR)
        if not (artifact_path / "head.pt").exists():
            raise RuntimeError(f"trained classifier not found at {artifact_path}")
        _predictor = TrainedPredictor(artifact_path)
        # One dummy query so the first real request doesn't pay lazy-init cost.
        _predictor.predict("warm up")
        logger.info("classifier loaded and warmed")

    if _registry is None:
        _registry = load_registry()
        logger.info("partner registry loaded with %d models", len(_registry))
95
+
96
+
97
@app.on_event("startup")
def _startup() -> None:
    """Best-effort warm load at boot; on failure, loading retries per-request."""
    try:
        _ensure_loaded()
    except Exception as exc:
        logger.warning("startup warm load failed: %s (will retry on first request)", exc)
103
+
104
+
105
@app.get("/health")
def health() -> dict:
    """Liveness probe: 200 once the model and registry load, 503 otherwise.

    Also serves as the warm-up trigger for the partner's keep-alive pings.
    """
    try:
        _ensure_loaded()
    except Exception as exc:
        logger.exception("health check failed")
        # Fix: chain the cause so tracebacks show the underlying load failure.
        raise HTTPException(status_code=503, detail=f"unhealthy: {exc}") from exc
    return {"status": "ok"}
113
+
114
+
115
@app.post("/classify", response_model=ClassifyResponse)
def classify(req: ClassifyRequest) -> ClassifyResponse:
    """Classify one query and pick a model from the partner registry."""
    _ensure_loaded()
    started = time.time()

    # Fold recent turns into the query text so the classifier sees context.
    recent = [m.dict() for m in req.recentMessages] if req.recentMessages else None
    folded = fold_recent_context(req.message, recent)
    profile = _predictor.predict(folded)

    weights = rebucket_capabilities(profile)
    category = pick_category(weights)
    complexity = pick_complexity(profile)
    difficulty = pick_difficulty_int(profile)

    chosen, escalated = select_model(_registry, weights, difficulty, is_ood=profile.is_ood)

    # Savings are only reported for non-OOD, non-escalated decisions.
    savings: Optional[float] = None
    if not (profile.is_ood or escalated):
        savings = round(energy_savings_pct(chosen), 1)

    reason = None
    if INCLUDE_REASON:
        reason = build_reason(weights, complexity, chosen, escalated, is_ood=profile.is_ood)

    elapsed_ms = (time.time() - started) * 1000.0
    logger.info(
        "classify model=%s tier=%s difficulty=%d category=%s ood=%s escalated=%s elapsed_ms=%.1f",
        chosen.id, chosen.tier, difficulty, category, profile.is_ood, escalated, elapsed_ms,
    )

    return ClassifyResponse(
        category=category,
        complexity=complexity,
        model_id=chosen.id,
        capability_weights=weights,
        difficulty=difficulty,
        energy_savings_pct=savings,
        method="greenrouting",
        reason=reason,
    )
159
+
160
+
161
if __name__ == "__main__":
    # Local dev entry point; in the container uvicorn is launched by CMD instead.
    import uvicorn

    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
        log_level="info",
    )
greenrouting/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ __version__ = "0.1.0"
greenrouting/classifier/__init__.py ADDED
File without changes
greenrouting/classifier/calibration.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Temperature scaling for the multi-label capability head.
2
+
3
+ Fits a single positive scalar T such that BCE(logits / T, targets) is minimized
4
+ on a held-out set. T < 1 sharpens, T > 1 softens.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import math
10
+
11
+
12
def fit_temperature(val_logits, val_targets, max_iter: int = 200) -> float:
    """Fit a single positive temperature T minimizing BCE(logits / T, targets).

    Args:
        val_logits: array-like of shape (n, n_labels) — raw head logits.
        val_targets: array-like of the same shape — {0, 1} (or soft) targets.
        max_iter: LBFGS iteration budget.

    Returns:
        The fitted temperature; 1.0 (identity scaling) when there is no data.
    """
    import numpy as np
    import torch
    import torch.nn as nn

    # Fix: coerce first so plain Python lists (including []) work. The original
    # called .size on the raw argument, which crashes on non-array inputs.
    logits_arr = np.asarray(val_logits, dtype=np.float32)
    targets_arr = np.asarray(val_targets, dtype=np.float32)
    if logits_arr.size == 0:
        return 1.0

    logits = torch.tensor(logits_arr, dtype=torch.float32)
    targets = torch.tensor(targets_arr, dtype=torch.float32)
    # Optimize log T so T = exp(log_t) stays positive without constraints.
    log_t = torch.zeros((), dtype=torch.float32, requires_grad=True)
    optimizer = torch.optim.LBFGS([log_t], lr=0.1, max_iter=max_iter)
    bce = nn.BCEWithLogitsLoss()

    def closure():
        optimizer.zero_grad()
        loss = bce(logits / log_t.exp(), targets)
        loss.backward()
        return loss

    optimizer.step(closure)
    return float(math.exp(log_t.item()))
34
+
35
+
36
def apply_temperature(logits, temperature: float):
    """Scale logits by 1/temperature; a non-positive T is treated as a no-op."""
    import numpy as np

    scaled = np.asarray(logits)
    if temperature > 0:
        scaled = scaled / temperature
    return scaled
greenrouting/classifier/infer.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Inference-side types and predictors. The trained predictor lives in a sibling
2
+ module; this file defines the contract the router consumes."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import math
7
+ import re
8
+ from dataclasses import dataclass, asdict, field
9
+ from typing import Protocol
10
+
11
+ from greenrouting.routing.registry import CAPABILITY_KEYS
12
+
13
+ LENGTH_BUCKETS: tuple[str, str, str] = ("short", "medium", "long")
14
+ LENGTH_TOKEN_TARGETS: dict[str, int] = {"short": 60, "medium": 220, "long": 700}
15
+ LENGTH_P90_MULTIPLIER: float = 1.6
16
+
17
+
18
@dataclass
class CapabilityProfile:
    """Per-capability scores in [0, 1]; one attribute per routing capability."""

    code: float = 0.0
    math: float = 0.0
    reasoning: float = 0.0
    knowledge: float = 0.0
    instruction: float = 0.0
    creative: float = 0.0
    multilingual: float = 0.0
    simple_chat: float = 0.0

    def as_dict(self) -> dict[str, float]:
        """Return the scores as a plain ``{name: score}`` mapping."""
        return asdict(self)

    def top(self, k: int = 3) -> list[tuple[str, float]]:
        """Top-``k`` capabilities by score, dropping anything at or below 0.05."""
        ranked = sorted(self.as_dict().items(), key=lambda item: item[1], reverse=True)
        return [(name, score) for name, score in ranked[:k] if score > 0.05]
35
+
36
+
37
@dataclass
class QueryProfile:
    """Everything the router needs to know about a single query."""

    capabilities: CapabilityProfile          # per-capability scores
    difficulty_log_params: float             # log parameter-count difficulty estimate
    length_dist: dict[str, float]            # soft distribution over LENGTH_BUCKETS
    expected_input_tokens: int
    expected_output_tokens_p50: int
    expected_output_tokens_p90: int
    confidence: float
    is_ood: bool = False
    raw_query: str = ""
    debug: dict = field(default_factory=dict)

    @property
    def difficulty_params_b(self) -> float:
        """Difficulty expressed in billions of parameters."""
        return math.exp(self.difficulty_log_params) / 1e9
53
+
54
+
55
class Predictor(Protocol):
    """Structural interface: anything exposing ``predict(query) -> QueryProfile``."""

    def predict(self, query: str) -> QueryProfile: ...
57
+
58
+
59
+ def _tokens_from_text(text: str) -> int:
60
+ words = max(1, len(text.split()))
61
+ return int(words * 1.3) + 4
62
+
63
+
64
+ def _length_dist_for_target(bucket: str) -> dict[str, float]:
65
+ if bucket == "short":
66
+ return {"short": 0.75, "medium": 0.20, "long": 0.05}
67
+ if bucket == "medium":
68
+ return {"short": 0.15, "medium": 0.65, "long": 0.20}
69
+ return {"short": 0.05, "medium": 0.25, "long": 0.70}
70
+
71
+
72
def _expected_output_tokens(length_dist: dict[str, float]) -> tuple[int, int]:
    """Expected output tokens (p50, p90) under the soft length distribution.

    p90 inflates p50 by LENGTH_P90_MULTIPLIER plus an extra term for mass on
    the "long" bucket, whose tail dominates worst-case generation length.
    """
    p50 = 0.0
    for bucket in LENGTH_BUCKETS:
        p50 += length_dist[bucket] * LENGTH_TOKEN_TARGETS[bucket]
    tail_bonus = length_dist.get("long", 0.0) * LENGTH_TOKEN_TARGETS["long"] * 0.3
    p90 = p50 * LENGTH_P90_MULTIPLIER + tail_bonus
    return int(round(p50)), int(round(p90))
77
+
78
+
79
+ _KEYWORD_RULES: dict[str, tuple[float, list[str]]] = {
80
+ "code": (0.85, [
81
+ r"\b(code|function|class|def |algorithm|implement|debug|compile|stack trace|api|sdk)\b",
82
+ r"\b(python|javascript|typescript|rust|go|c\+\+|java|sql|html|css|react|kotlin|swift)\b",
83
+ r"\b(refactor|unit test|regex|linter)\b",
84
+ ]),
85
+ "math": (0.80, [
86
+ r"\b(calculate|compute|solve|equation|integral|derivative|matrix|vector|probability|theorem)\b",
87
+ r"\b(sum|product|mean|median|variance|standard deviation|percentage)\b",
88
+ r"\d+\s*[+\-*/×÷=]\s*\d+",
89
+ ]),
90
+ "reasoning": (0.70, [
91
+ r"\b(why|how does|explain|reason|because|therefore|thus|argue|justify|implication)\b",
92
+ r"\b(compare|contrast|analyze|evaluate|trade-?off|implication)\b",
93
+ ]),
94
+ "knowledge": (0.65, [
95
+ r"\b(who|what is|when did|where is|history|definition|capital|population|founded)\b",
96
+ ]),
97
+ "instruction": (0.60, [
98
+ r"\b(write|draft|create|generate|produce|format|list|outline|step.?by.?step)\b",
99
+ ]),
100
+ "creative": (0.75, [
101
+ r"\b(story|poem|novel|character|plot|scene|metaphor|fictional)\b",
102
+ r"\b(write a (?:short )?(?:story|poem|haiku|song|essay))\b",
103
+ ]),
104
+ "multilingual": (0.85, [
105
+ r"\b(translate|translation|en español|en français|auf deutsch|на русском|中文|日本語|한국어)\b",
106
+ r"[Ѐ-ӿ一-鿿぀-ゟ゠-ヿ]",
107
+ ]),
108
+ "simple_chat": (0.70, [
109
+ r"^\s*(hi|hello|hey|thanks|thank you|good morning|good evening|sup|yo)\b",
110
+ r"^\s*\S{1,40}\?\s*$",
111
+ ]),
112
+ }
113
+
114
+
115
class MockPredictor:
    """Keyword-rule predictor standing in for the trained checkpoint.

    Interface and output shape are identical to the trained predictor, so the
    router can swap between them freely.
    """

    def __init__(self, default_difficulty_log_params: float = math.log(8e9)) -> None:
        self.default_difficulty = default_difficulty_log_params

    def predict(self, query: str) -> QueryProfile:
        text = query.strip()

        # Score each capability as the max weight among its matching rules.
        scores = {key: 0.0 for key in CAPABILITY_KEYS}
        for capability, (weight, patterns) in _KEYWORD_RULES.items():
            for pattern in patterns:
                if re.search(pattern, text, flags=re.IGNORECASE | re.MULTILINE):
                    scores[capability] = max(scores[capability], weight)

        # No rule fired: treat as light chat with a dash of instruction.
        if all(v <= 0 for v in scores.values()):
            scores["simple_chat"] = 0.55
            scores["instruction"] = 0.30

        bucket = self._length_bucket(text, scores)
        dist = _length_dist_for_target(bucket)
        out_p50, out_p90 = _expected_output_tokens(dist)

        return QueryProfile(
            capabilities=CapabilityProfile(**scores),
            difficulty_log_params=self._difficulty(text, scores),
            length_dist=dist,
            expected_input_tokens=_tokens_from_text(text),
            expected_output_tokens_p50=out_p50,
            expected_output_tokens_p90=out_p90,
            confidence=max(scores.values()),
            is_ood=self._ood(text),
            raw_query=text,
            debug={"source": "mock", "length_bucket": bucket},
        )

    @staticmethod
    def _length_bucket(query: str, scores: dict[str, float]) -> str:
        """Pick the target length bucket from the capability mix, then raw length."""
        if scores.get("simple_chat", 0) > 0.5:
            return "short"
        if scores.get("creative", 0) > 0.5 or scores.get("code", 0) > 0.5:
            return "long"
        n = len(query)
        if n < 80:
            return "short"
        return "medium" if n < 240 else "long"

    @staticmethod
    def _difficulty(query: str, scores: dict[str, float]) -> float:
        """Log-parameter difficulty: a 7B baseline nudged by capability signals."""
        difficulty = math.log(7e9)
        if scores.get("math", 0) > 0.5 and re.search(
            r"\b(prove|theorem|integral|differential)\b", query, re.IGNORECASE
        ):
            difficulty += math.log(10)
        if scores.get("reasoning", 0) > 0.5 and len(query) > 200:
            difficulty += math.log(5)
        if scores.get("code", 0) > 0.5 and re.search(
            r"\b(distributed|concurrency|kernel|cuda|optimize)\b", query, re.IGNORECASE
        ):
            difficulty += math.log(8)
        if scores.get("simple_chat", 0) > 0.5:
            difficulty -= math.log(3)
        return difficulty

    @staticmethod
    def _ood(query: str) -> bool:
        """Cheap out-of-distribution heuristics: too short, symbol soup, gibberish."""
        text = query.strip()
        if len(text) < 2:
            return True
        alnum = sum(c.isalnum() for c in text)
        if alnum and sum(c.isalpha() for c in text) / max(alnum, 1) < 0.3:
            return True
        if re.fullmatch(r"[\W\d_]+", text):
            return True
        words = re.findall(r"[A-Za-z]{4,}", text)
        if not words:
            return False
        # A word counts as gibberish when it has 5+ consecutive consonants.
        gibberish = sum(
            1
            for w in words
            if max((len(m.group()) for m in re.finditer(r"[^aeiouyAEIOUY]+", w)), default=0) >= 5
        )
        return gibberish / len(words) >= 0.5
greenrouting/classifier/model.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Frozen sentence encoder + three task heads (capability, difficulty, length)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+
7
# Sentence encoder used unless a config overrides it.
DEFAULT_ENCODER = "BAAI/bge-small-en-v1.5"


@dataclass
class ModelSpec:
    """Hyperparameters describing the encoder + head architecture."""

    encoder_name: str = DEFAULT_ENCODER
    embedding_dim: int = 384      # matches bge-small's output width
    hidden_dim: int = 256
    n_capabilities: int = 8
    n_length_buckets: int = 3
    dropout: float = 0.1
    max_seq_len: int = 256
+
20
+
21
def build_head(spec: ModelSpec):
    """Construct the three-task head stack (capabilities, difficulty, length).

    torch is imported lazily so importing this module never pulls it in.
    """
    import torch.nn as nn

    class HeadStack(nn.Module):
        """Shared 2-layer MLP trunk feeding three linear task heads."""

        def __init__(self, s: ModelSpec):
            super().__init__()
            trunk = [
                nn.Linear(s.embedding_dim, s.hidden_dim),
                nn.GELU(),
                nn.Dropout(s.dropout),
                nn.Linear(s.hidden_dim, s.hidden_dim),
                nn.GELU(),
                nn.Dropout(s.dropout),
            ]
            self.shared = nn.Sequential(*trunk)
            self.cap_head = nn.Linear(s.hidden_dim, s.n_capabilities)
            self.diff_head = nn.Linear(s.hidden_dim, 1)
            self.len_head = nn.Linear(s.hidden_dim, s.n_length_buckets)

        def forward(self, embeddings):
            hidden = self.shared(embeddings)
            return {
                "cap_logits": self.cap_head(hidden),
                "diff": self.diff_head(hidden).squeeze(-1),
                "len_logits": self.len_head(hidden),
            }

    return HeadStack(spec)
48
+
49
+
50
+ class Encoder:
51
+ """Lazy wrapper around a HuggingFace sentence encoder, mean-pooled and L2-normalized."""
52
+
53
+ def __init__(self, encoder_name: str = DEFAULT_ENCODER, max_seq_len: int = 256):
54
+ self.encoder_name = encoder_name
55
+ self.max_seq_len = max_seq_len
56
+ self._tokenizer = None
57
+ self._model = None
58
+ self._device = None
59
+
60
+ def _ensure_loaded(self):
61
+ if self._model is not None:
62
+ return
63
+ import torch
64
+ from transformers import AutoModel, AutoTokenizer
65
+
66
+ self._tokenizer = AutoTokenizer.from_pretrained(self.encoder_name)
67
+ self._model = AutoModel.from_pretrained(self.encoder_name)
68
+ self._device = "cuda" if torch.cuda.is_available() else "cpu"
69
+ self._model.to(self._device).eval()
70
+ for p in self._model.parameters():
71
+ p.requires_grad = False
72
+
73
+ @property
74
+ def device(self) -> str:
75
+ self._ensure_loaded()
76
+ return self._device
77
+
78
+ def embed(self, texts: list[str]):
79
+ import torch
80
+ import torch.nn.functional as F
81
+
82
+ self._ensure_loaded()
83
+ enc = self._tokenizer(
84
+ texts,
85
+ padding=True,
86
+ truncation=True,
87
+ max_length=self.max_seq_len,
88
+ return_tensors="pt",
89
+ ).to(self._device)
90
+ with torch.no_grad():
91
+ out = self._model(**enc)
92
+ mask = enc["attention_mask"].unsqueeze(-1).float()
93
+ pooled = (out.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
94
+ return F.normalize(pooled, dim=-1)
greenrouting/classifier/ood.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OOD detection on L2-normalized encoder embeddings.
2
+
3
+ Uses centroid cosine distance + k-nearest-neighbor distance. Both are robust
4
+ when the number of training examples is smaller than the embedding dimension,
5
+ which is typical for our seed-scale datasets.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+
11
def fit_ood_stats(train_embeddings, k: int = 5):
    """Build OOD reference statistics from training embeddings.

    Returns a dict with a unit-norm ``centroid``, the row-normalized
    ``reference`` matrix, and the neighbor count ``k`` used by knn_distance.
    """
    import numpy as np

    arr = np.asarray(train_embeddings, dtype=np.float32)
    if arr.size == 0:
        # No data: zero centroid, falling back to 384 dims when width is unknown.
        dim = arr.shape[1] if arr.ndim == 2 else 384
        return {
            "centroid": np.zeros((dim,), dtype=np.float32),
            "reference": arr,
            "k": k,
        }

    # Row-normalize; guard zero rows against divide-by-zero.
    row_norms = np.linalg.norm(arr, axis=1, keepdims=True)
    row_norms[row_norms == 0] = 1.0
    unit_rows = arr / row_norms

    centroid = unit_rows.mean(axis=0)
    scale = float(np.linalg.norm(centroid)) or 1.0
    centroid = centroid / scale

    return {
        "centroid": centroid.astype(np.float32),
        "reference": unit_rows.astype(np.float32),
        "k": k,
    }
+ "k": k}
32
+
33
+
34
+ def _cosine_distance(a, b):
35
+ import numpy as np
36
+ a = np.asarray(a, dtype=np.float32)
37
+ b = np.asarray(b, dtype=np.float32)
38
+ na = np.linalg.norm(a)
39
+ nb = np.linalg.norm(b)
40
+ if na == 0 or nb == 0:
41
+ return 1.0
42
+ return 1.0 - float(np.dot(a, b) / (na * nb))
43
+
44
+
45
+ def centroid_distance(embedding, stats) -> float:
46
+ return _cosine_distance(embedding, stats["centroid"])
47
+
48
+
49
def knn_distance(embedding, stats, k: int | None = None) -> float:
    """Mean cosine distance from *embedding* to its k nearest reference rows.

    Args:
        embedding: query vector (any norm; normalized internally).
        stats: dict from fit_ood_stats with unit-norm ``reference`` rows.
        k: neighbor-count override; falls back to ``stats["k"]`` (default 5).
           (Fix: annotation was ``k: int = None``, which mistyped the default.)

    Returns:
        Mean distance over the k nearest neighbors, or the "maximally far"
        sentinel 1.0 for empty reference sets / zero-norm embeddings.
    """
    import numpy as np

    ref = stats["reference"]
    if ref.size == 0:
        return 1.0
    k = k or stats.get("k", 5)

    emb = np.asarray(embedding, dtype=np.float32)
    norm = np.linalg.norm(emb)
    if norm == 0:
        return 1.0
    emb = emb / norm

    # Reference rows are unit-norm, so the dot product is cosine similarity.
    distances = 1.0 - (ref @ emb)
    distances.sort()
    return float(distances[: max(1, min(k, len(distances)))].mean())
64
+
65
+
66
def calibrate_thresholds(
    train_embeddings,
    stats,
    percentile: float = 99.9,
    safety_multiplier: float = 1.25,
) -> dict:
    """Derive OOD thresholds from in-distribution distances.

    Threshold = (percentile of training distances) x safety_multiplier; the
    multiplier leaves headroom for natural rephrasings that are not truly OOD.
    """
    import numpy as np

    centroid_dists = [centroid_distance(e, stats) for e in train_embeddings]
    knn_dists = [knn_distance(e, stats) for e in train_embeddings]

    def _pct(values) -> float:
        # Empty distance lists degrade to the neutral threshold base 1.0.
        return float(np.percentile(values, percentile)) if values else 1.0

    return {
        "centroid_threshold": _pct(centroid_dists) * safety_multiplier,
        "knn_threshold": _pct(knn_dists) * safety_multiplier,
    }
83
+
84
+
85
def is_ood(embedding, stats, thresholds) -> bool:
    """Flag OOD when EITHER distance signal exceeds its threshold.

    OR semantics are deliberate: requiring both signals lets too many obvious
    outliers through whenever one signal happens to look in-distribution.
    """
    if centroid_distance(embedding, stats) > thresholds["centroid_threshold"]:
        return True
    return knn_distance(embedding, stats) > thresholds["knn_threshold"]
greenrouting/classifier/train.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Training loop. Frozen encoder, head-only optimization, multi-task loss."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import math
7
+ from dataclasses import dataclass, field
8
+ from pathlib import Path
9
+
10
+ from greenrouting.classifier.model import DEFAULT_ENCODER, Encoder, ModelSpec, build_head
11
+ from greenrouting.data.schema import LENGTH_BUCKETS
12
+ from greenrouting.routing.registry import CAPABILITY_KEYS
13
+
14
+ LENGTH_TO_INDEX: dict[str, int] = {b: i for i, b in enumerate(LENGTH_BUCKETS)}
15
+
16
+
17
@dataclass
class TrainConfig:
    """All knobs for the head-only training run."""

    # Architecture
    encoder_name: str = DEFAULT_ENCODER
    hidden_dim: int = 256
    dropout: float = 0.1
    max_seq_len: int = 256
    # Optimization
    epochs: int = 8
    batch_size: int = 32
    learning_rate: float = 1e-3
    weight_decay: float = 1e-4
    # Multi-task loss mixing weights
    cap_weight: float = 1.0
    diff_weight: float = 0.5
    len_weight: float = 0.3
    # Data handling
    val_split: float = 0.15
    seed: int = 42
    # Loss shaping
    huber_delta: float = 1.0
    cap_pos_weight: float = 2.0
    # Difficulty targets are centered on log(8e9) so the head learns residuals.
    diff_target_center: float = field(default_factory=lambda: math.log(8e9))
35
+
36
+
37
def _load_split(parquet_path: str | Path):
    """Read one dataset split from parquet (pandas imported lazily)."""
    import pandas as pd

    return pd.read_parquet(parquet_path)
40
+
41
+
42
def _build_targets(df, cfg: TrainConfig):
    """Extract (texts, capability matrix, centered difficulty, length indices).

    Missing values fall back to 0.0 capability, the configured difficulty
    center (yielding a 0 residual), and the "medium" length bucket.
    """
    import numpy as np

    cap_columns = [f"cap_{key}" for key in CAPABILITY_KEYS]
    caps = df[cap_columns].fillna(0.0).to_numpy(dtype=np.float32)

    raw_diff = df["difficulty_log_params"].fillna(cfg.diff_target_center)
    diff_centered = raw_diff.to_numpy(dtype=np.float32) - cfg.diff_target_center

    lens = (
        df["length_bucket"]
        .fillna("medium")
        .map(LENGTH_TO_INDEX)
        .fillna(1)  # unknown bucket labels also collapse to "medium"
        .to_numpy(dtype=np.int64)
    )
    texts = df["text"].astype(str).tolist()
    return texts, caps, diff_centered, lens
51
+
52
+
53
+ def _split_train_val(texts, caps, diff, lens, val_split: float, seed: int):
54
+ import numpy as np
55
+ rng = np.random.default_rng(seed)
56
+ n = len(texts)
57
+ indices = np.arange(n)
58
+ rng.shuffle(indices)
59
+ n_val = max(1, int(n * val_split))
60
+ val_idx = indices[:n_val]
61
+ train_idx = indices[n_val:]
62
+ return (
63
+ ([texts[i] for i in train_idx], caps[train_idx], diff[train_idx], lens[train_idx]),
64
+ ([texts[i] for i in val_idx], caps[val_idx], diff[val_idx], lens[val_idx]),
65
+ )
66
+
67
+
68
+ def _iterate_batches(texts, caps, diff, lens, batch_size: int, encoder: Encoder, shuffle: bool, seed: int):
69
+ import numpy as np
70
+ import torch
71
+
72
+ n = len(texts)
73
+ indices = np.arange(n)
74
+ if shuffle:
75
+ np.random.default_rng(seed).shuffle(indices)
76
+
77
+ for start in range(0, n, batch_size):
78
+ idx = indices[start:start + batch_size]
79
+ batch_texts = [texts[i] for i in idx]
80
+ emb = encoder.embed(batch_texts)
81
+ cap_t = torch.tensor(caps[idx], dtype=torch.float32, device=emb.device)
82
+ diff_t = torch.tensor(diff[idx], dtype=torch.float32, device=emb.device)
83
+ len_t = torch.tensor(lens[idx], dtype=torch.long, device=emb.device)
84
+ yield emb, cap_t, diff_t, len_t
85
+
86
+
87
def train(
    train_parquet: str | Path,
    output_dir: str | Path,
    cfg: TrainConfig | None = None,
) -> dict:
    """Train the head stack on a parquet split and write all inference artifacts.

    Writes into *output_dir*: head.pt, encoder_name.txt, metadata.json,
    training_history.json, calibration.json, ood_stats.npz.

    Returns a summary dict: per-epoch history, fitted temperature, split sizes.
    """
    import torch
    import torch.nn as nn
    from torch.optim import AdamW

    cfg = cfg or TrainConfig()
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Data: build targets then split deterministically.
    df = _load_split(train_parquet)
    texts, caps, diff, lens = _build_targets(df, cfg)
    train_set, val_set = _split_train_val(texts, caps, diff, lens, cfg.val_split, cfg.seed)

    # Probe the encoder once to discover the actual embedding width.
    encoder = Encoder(cfg.encoder_name, cfg.max_seq_len)
    embed_dim = encoder.embed(["probe"]).shape[-1]
    spec = ModelSpec(
        encoder_name=cfg.encoder_name,
        embedding_dim=embed_dim,
        hidden_dim=cfg.hidden_dim,
        dropout=cfg.dropout,
        max_seq_len=cfg.max_seq_len,
    )
    head = build_head(spec).to(encoder.device)

    # Multi-task losses; pos_weight counteracts sparse positive capability labels.
    pos_weight = torch.full((spec.n_capabilities,), cfg.cap_pos_weight, device=encoder.device)
    cap_loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    diff_loss_fn = nn.HuberLoss(delta=cfg.huber_delta)
    len_loss_fn = nn.CrossEntropyLoss()
    optimizer = AdamW(head.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)

    history = []
    for epoch in range(cfg.epochs):
        head.train()
        train_loss_sum = 0.0
        n_train = 0
        # Per-epoch seed: each epoch shuffles differently but reproducibly.
        for emb, cap_t, diff_t, len_t in _iterate_batches(
            *train_set, batch_size=cfg.batch_size, encoder=encoder,
            shuffle=True, seed=cfg.seed + epoch,
        ):
            out = head(emb)
            # Weighted sum of the three task losses.
            loss = (
                cfg.cap_weight * cap_loss_fn(out["cap_logits"], cap_t)
                + cfg.diff_weight * diff_loss_fn(out["diff"], diff_t)
                + cfg.len_weight * len_loss_fn(out["len_logits"], len_t)
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch average is per-example.
            train_loss_sum += loss.item() * emb.shape[0]
            n_train += emb.shape[0]

        val_metrics = _evaluate(head, encoder, val_set, cfg)
        history.append({
            "epoch": epoch,
            "train_loss": train_loss_sum / max(n_train, 1),
            **val_metrics,
        })
        print(
            f"epoch {epoch+1}/{cfg.epochs} "
            f"train_loss={train_loss_sum/max(n_train,1):.4f} "
            f"val_cap_f1={val_metrics['cap_f1']:.3f} "
            f"val_diff_mae={val_metrics['diff_mae']:.3f} "
            f"val_len_acc={val_metrics['len_acc']:.3f}"
        )

    # Persist weights plus everything inference needs to rebuild the model.
    head.eval()
    torch.save(head.state_dict(), out_dir / "head.pt")
    (out_dir / "encoder_name.txt").write_text(cfg.encoder_name)
    (out_dir / "metadata.json").write_text(json.dumps({
        "capability_keys": list(CAPABILITY_KEYS),
        "length_buckets": list(LENGTH_BUCKETS),
        "embedding_dim": int(spec.embedding_dim),
        "hidden_dim": int(spec.hidden_dim),
        "max_seq_len": int(spec.max_seq_len),
        "diff_target_center": float(cfg.diff_target_center),
    }, indent=2))
    (out_dir / "training_history.json").write_text(json.dumps(history, indent=2))

    train_embeddings = _collect_embeddings(encoder, train_set[0], batch_size=cfg.batch_size)
    val_cap_logits = _collect_logits(head, encoder, val_set, cfg.batch_size)

    # Calibrate capability probabilities on the held-out logits.
    from greenrouting.classifier.calibration import fit_temperature

    temperature = fit_temperature(val_cap_logits, val_set[1])
    (out_dir / "calibration.json").write_text(json.dumps({"temperature": float(temperature)}, indent=2))

    # Fit OOD reference statistics and thresholds on the training embeddings.
    from greenrouting.classifier.ood import calibrate_thresholds, fit_ood_stats

    ood_stats = fit_ood_stats(train_embeddings, k=5)
    thresholds = calibrate_thresholds(train_embeddings, ood_stats, percentile=99.0)
    import numpy as np
    np.savez(
        out_dir / "ood_stats.npz",
        centroid=ood_stats["centroid"],
        reference=ood_stats["reference"],
        k=ood_stats.get("k", 5),
        centroid_threshold=thresholds["centroid_threshold"],
        knn_threshold=thresholds["knn_threshold"],
    )

    return {
        "history": history,
        "temperature": float(temperature),
        "n_train": len(train_set[0]),
        "n_val": len(val_set[0]),
    }
197
+
198
+
199
def _evaluate(head, encoder: Encoder, val_set, cfg: TrainConfig) -> dict:
    """Validation metrics: micro-averaged capability P/R/F1, difficulty MAE, length accuracy."""
    import torch
    head.eval()
    all_cap_pred, all_cap_true = [], []
    all_diff_pred, all_diff_true = [], []
    all_len_pred, all_len_true = [], []
    with torch.no_grad():
        for emb, cap_t, diff_t, len_t in _iterate_batches(
            *val_set, batch_size=cfg.batch_size, encoder=encoder, shuffle=False, seed=cfg.seed,
        ):
            out = head(emb)
            all_cap_pred.append(torch.sigmoid(out["cap_logits"]).cpu().numpy())
            all_cap_true.append(cap_t.cpu().numpy())
            all_diff_pred.append(out["diff"].cpu().numpy())
            all_diff_true.append(diff_t.cpu().numpy())
            all_len_pred.append(out["len_logits"].argmax(dim=-1).cpu().numpy())
            all_len_true.append(len_t.cpu().numpy())
    # Restore training mode for the next epoch.
    head.train()

    import numpy as np
    cap_pred = np.concatenate(all_cap_pred)
    cap_true = np.concatenate(all_cap_true)
    diff_pred = np.concatenate(all_diff_pred)
    diff_true = np.concatenate(all_diff_true)
    len_pred = np.concatenate(all_len_pred)
    len_true = np.concatenate(all_len_true)

    # Binarize predictions and (possibly soft) targets at 0.5, then compute
    # micro-averaged precision/recall/F1 over all capability slots.
    cap_pred_bin = (cap_pred >= 0.5).astype(np.float32)
    cap_true_bin = (cap_true >= 0.5).astype(np.float32)
    tp = ((cap_pred_bin == 1) & (cap_true_bin == 1)).sum()
    fp = ((cap_pred_bin == 1) & (cap_true_bin == 0)).sum()
    fn = ((cap_pred_bin == 0) & (cap_true_bin == 1)).sum()
    # max(..., 1) / 1e-9 guards avoid division by zero on degenerate splits.
    precision = tp / max(tp + fp, 1)
    recall = tp / max(tp + fn, 1)
    f1 = 2 * precision * recall / max(precision + recall, 1e-9)

    diff_mae = float(np.abs(diff_pred - diff_true).mean())
    len_acc = float((len_pred == len_true).mean())

    return {
        "cap_precision": float(precision),
        "cap_recall": float(recall),
        "cap_f1": float(f1),
        "diff_mae": diff_mae,
        "len_acc": len_acc,
    }
245
+
246
+
247
+ def _collect_embeddings(encoder: Encoder, texts: list[str], batch_size: int):
248
+ import numpy as np
249
+ chunks = []
250
+ for start in range(0, len(texts), batch_size):
251
+ chunk = texts[start:start + batch_size]
252
+ emb = encoder.embed(chunk).cpu().numpy()
253
+ chunks.append(emb)
254
+ return np.concatenate(chunks, axis=0) if chunks else np.zeros((0, 384), dtype=np.float32)
255
+
256
+
257
def _collect_logits(head, encoder: Encoder, val_set, batch_size: int):
    """Run the head over the validation set and stack the capability logits.

    Leaves the head back in train() mode, mirroring _evaluate's contract.
    Empty input yields a (0, 8) placeholder matching the capability count.
    """
    import numpy as np
    import torch

    head.eval()
    collected = []
    with torch.no_grad():
        for emb, _cap, _diff, _len in _iterate_batches(
            *val_set, batch_size=batch_size, encoder=encoder, shuffle=False, seed=0,
        ):
            collected.append(head(emb)["cap_logits"].cpu().numpy())
    head.train()

    if not collected:
        return np.zeros((0, 8), dtype=np.float32)
    return np.concatenate(collected, axis=0)
greenrouting/classifier/trained_predictor.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Inference-time predictor that loads the trained artifact and conforms to
2
+ the `Predictor` protocol used by the router."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import json
7
+ import math
8
+ from pathlib import Path
9
+ from typing import Optional
10
+
11
+ from greenrouting.classifier.infer import (
12
+ CapabilityProfile,
13
+ LENGTH_BUCKETS,
14
+ LENGTH_TOKEN_TARGETS,
15
+ LENGTH_P90_MULTIPLIER,
16
+ QueryProfile,
17
+ )
18
+ from greenrouting.classifier.model import Encoder, ModelSpec, build_head
19
+ from greenrouting.classifier.ood import is_ood
20
+ from greenrouting.routing.registry import CAPABILITY_KEYS
21
+
22
+
23
class TrainedPredictor:
    """Lazily loads a trained classifier artifact and serves QueryProfile predictions.

    All heavy work (encoder, head weights, calibration, OOD stats) happens on
    the first `predict` call via `_ensure_loaded`, so constructing the object
    is cheap.
    """

    def __init__(self, artifact_dir: str | Path):
        # Directory containing metadata.json, encoder_name.txt, head.pt and
        # optionally calibration.json / ood_stats.npz.
        self.artifact_dir = Path(artifact_dir)
        self._loaded = False
        self._encoder: Optional[Encoder] = None
        self._head = None
        self._spec: Optional[ModelSpec] = None
        # Temperature for calibrating capability logits; 1.0 = uncalibrated.
        self._temperature: float = 1.0
        self._ood_stats = None
        self._ood_thresholds = None
        # Below this max-capability probability the query is flagged OOD.
        self._ood_min_confidence: float = 0.40

    def _ensure_loaded(self) -> None:
        """Load encoder, head, calibration and OOD artifacts exactly once."""
        if self._loaded:
            return
        import numpy as np
        import torch

        meta_path = self.artifact_dir / "metadata.json"
        meta = json.loads(meta_path.read_text())
        encoder_name = (self.artifact_dir / "encoder_name.txt").read_text().strip()
        self._spec = ModelSpec(
            encoder_name=encoder_name,
            embedding_dim=int(meta["embedding_dim"]),
            hidden_dim=int(meta["hidden_dim"]),
            n_capabilities=len(meta["capability_keys"]),
            n_length_buckets=len(meta["length_buckets"]),
            max_seq_len=int(meta.get("max_seq_len", 256)),
        )
        # Center subtracted from difficulty targets at training time; added
        # back in predict(). Defaults to log(8e9) when metadata omits it.
        self._diff_center = float(meta.get("diff_target_center", math.log(8e9)))
        self._encoder = Encoder(encoder_name, max_seq_len=self._spec.max_seq_len)
        head = build_head(self._spec)
        head.load_state_dict(torch.load(self.artifact_dir / "head.pt", map_location="cpu"))
        head.to(self._encoder.device).eval()
        self._head = head

        # Optional temperature calibration for the capability logits.
        cal_path = self.artifact_dir / "calibration.json"
        if cal_path.exists():
            self._temperature = float(json.loads(cal_path.read_text()).get("temperature", 1.0))

        # Optional geometric OOD statistics (centroid distance + kNN distance).
        ood_path = self.artifact_dir / "ood_stats.npz"
        if ood_path.exists():
            data = np.load(ood_path)
            if "centroid" in data.files and "reference" in data.files:
                self._ood_stats = {
                    "centroid": data["centroid"],
                    "reference": data["reference"],
                    "k": int(data["k"]) if "k" in data.files else 5,
                }
                self._ood_thresholds = {
                    "centroid_threshold": float(data["centroid_threshold"]),
                    "knn_threshold": float(data["knn_threshold"]),
                }

        self._loaded = True

    def predict(self, query: str) -> QueryProfile:
        """Classify one query into capabilities, difficulty, length and OOD flags.

        Empty/None queries are treated as the empty string. Confidence is the
        maximum calibrated capability probability; the OOD flag is the union
        of the low-confidence check and the geometric check (when stats exist).
        """
        import torch
        import torch.nn.functional as F

        self._ensure_loaded()
        text = (query or "").strip()
        emb = self._encoder.embed([text])
        with torch.no_grad():
            out = self._head(emb)

        # Temperature-scaled, independently sigmoided multi-label probabilities.
        cap_logits = (out["cap_logits"] / max(self._temperature, 1e-3))
        cap_probs = torch.sigmoid(cap_logits)[0].cpu().numpy().tolist()
        cap_dict = {k: float(v) for k, v in zip(CAPABILITY_KEYS, cap_probs)}

        # Difficulty head predicts a centered value; undo the centering.
        diff_centered = float(out["diff"][0].item())
        diff_log_params = diff_centered + self._diff_center

        len_probs = F.softmax(out["len_logits"][0], dim=-1).cpu().numpy().tolist()
        length_dist = {b: float(p) for b, p in zip(LENGTH_BUCKETS, len_probs)}

        confidence = max(cap_dict.values()) if cap_dict else 0.0

        confidence_ood = confidence < self._ood_min_confidence
        geometric_ood = False
        if self._ood_stats is not None and self._ood_thresholds is not None:
            emb_np = emb[0].cpu().numpy()
            geometric_ood = is_ood(emb_np, self._ood_stats, self._ood_thresholds)
        ood_flag = confidence_ood or geometric_ood

        # Rough token estimates: input from whitespace word count (x1.3 + 4),
        # output p50 from the expected bucket target, p90 inflated further
        # when mass sits on the "long" bucket.
        in_tokens = max(1, int(len(text.split()) * 1.3) + 4)
        out_p50 = int(round(sum(length_dist[b] * LENGTH_TOKEN_TARGETS[b] for b in LENGTH_BUCKETS)))
        long_w = length_dist.get("long", 0.0)
        out_p90 = int(round(out_p50 * LENGTH_P90_MULTIPLIER + long_w * LENGTH_TOKEN_TARGETS["long"] * 0.3))

        return QueryProfile(
            capabilities=CapabilityProfile(**cap_dict),
            difficulty_log_params=diff_log_params,
            length_dist=length_dist,
            expected_input_tokens=in_tokens,
            expected_output_tokens_p50=out_p50,
            expected_output_tokens_p90=out_p90,
            confidence=confidence,
            is_ood=ood_flag,
            raw_query=text,
            debug={
                "source": "trained",
                "temperature": self._temperature,
                "confidence_ood": bool(confidence_ood),
                "geometric_ood": bool(geometric_ood),
            },
        )
greenrouting/data/__init__.py ADDED
File without changes
greenrouting/data/builder.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Dataset orchestrator: source sampling -> capability labeling -> cascade plan -> parquet."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ from dataclasses import dataclass, field
8
+ from pathlib import Path
9
+ from typing import Optional
10
+
11
+ from greenrouting.data.capability_labeler import LabelerConfig, label_queries
12
+ from greenrouting.data.schema import CapabilityLabel, LabeledQuery, RawQuery
13
+ from greenrouting.data.sources import SOURCE_REGISTRY, sample_mix
14
+ from greenrouting.routing.registry import CAPABILITY_KEYS
15
+
16
+
17
@dataclass
class CascadeRungConfig:
    """One rung of the difficulty cascade: a single model to evaluate queries on."""

    id: str  # short identifier used in checkpoint filenames and progress logs
    hf_model: str  # Hugging Face model id loaded for this rung
    params_b: float  # model size in billions of parameters (difficulty axis)
    decode_tokens_per_second_estimate: float  # throughput estimate for wall-time projection
    runs_locally: bool = True  # False -> rung is skipped by run_cascade (needs remote backend)
24
+
25
+
26
@dataclass
class CascadeConfig:
    """Settings for a full cascade run across every rung."""

    rungs: list[CascadeRungConfig]
    k_samples: int = 1
    max_new_tokens: int = 200
    temperature_first: float = 0.0
    temperature_resample: float = 0.7

    def projected_seconds(self, n_queries: int) -> float:
        """Estimate wall-clock seconds to push `n_queries` through every rung.

        Per rung: decode time for `max_new_tokens` at the rung's estimated
        tokens/sec (floored at 1.0) plus a flat 0.4 s overhead per inference.
        """
        inferences = n_queries * self.k_samples
        return sum(
            inferences * self.max_new_tokens / max(rung.decode_tokens_per_second_estimate, 1.0)
            + inferences * 0.4
            for rung in self.rungs
        )
41
+
42
+
43
@dataclass
class BuildConfig:
    """Complete dataset-build configuration, typically loaded from a YAML profile."""

    profile_name: str  # used in output/checkpoint file names
    target_total_queries: int  # total queries sampled across all sources
    test_split: float  # fraction of labeled rows routed to the test parquet
    seed: int  # RNG seed for source sampling and the train/test split
    sources: dict[str, float]  # source name -> mix weight
    cascade: CascadeConfig
    labeler: LabelerConfig
    budget_minutes: float = 60.0  # advisory wall-time budget checked by plan()
    output_dir: str = "data"
    capability_labels_cache: Optional[str] = None  # parquet of labels to reuse instead of re-labeling

    @classmethod
    def from_yaml(cls, path: str | Path) -> "BuildConfig":
        """Parse a YAML profile into a BuildConfig; optional keys fall back to defaults."""
        import yaml
        with open(path, "r", encoding="utf-8") as f:
            raw = yaml.safe_load(f)

        rungs = [CascadeRungConfig(**r) for r in raw["cascade"]["rungs"]]
        cascade = CascadeConfig(
            rungs=rungs,
            k_samples=raw["cascade"].get("k_samples", 1),
            max_new_tokens=raw["cascade"].get("max_new_tokens", 200),
            temperature_first=raw["cascade"].get("temperature_first", 0.0),
            temperature_resample=raw["cascade"].get("temperature_resample", 0.7),
        )
        # The whole "labeler" section is optional; defaults mirror LabelerConfig.
        labeler_raw = raw.get("labeler", {})
        labeler = LabelerConfig(
            use_heuristic=labeler_raw.get("use_heuristic", True),
            use_gpt=labeler_raw.get("use_gpt", False),
            use_claude=labeler_raw.get("use_claude", False),
            use_gemini=labeler_raw.get("use_gemini", False),
            source_prior_weight=labeler_raw.get("source_prior_weight", 0.5),
            sleep_between_calls_s=labeler_raw.get("sleep_between_calls_s", 0.0),
        )
        return cls(
            profile_name=raw["profile_name"],
            target_total_queries=raw["target_total_queries"],
            test_split=raw["test_split"],
            seed=raw["seed"],
            sources=raw["sources"],
            cascade=cascade,
            labeler=labeler,
            budget_minutes=raw.get("budget_minutes", 60.0),
            output_dir=raw.get("output_dir", "data"),
            capability_labels_cache=raw.get("capability_labels_cache"),
        )
91
+
92
+
93
@dataclass
class BuildPlan:
    """Outcome of a dry-run `plan()` call: time projections plus advisory notes."""

    config: BuildConfig
    n_queries: int
    cascade_seconds: float
    cascade_minutes: float
    over_budget: bool
    notes: list[str] = field(default_factory=list)

    def report(self) -> str:
        """Render a human-readable multi-line summary of the plan."""
        cfg = self.config
        lines = [
            f"Profile: {cfg.profile_name}",
            f"Target queries: {cfg.target_total_queries}",
            f"Test split: {int(cfg.test_split * 100)}%",
            f"Sources: {', '.join(f'{k}={v}' for k, v in cfg.sources.items())}",
            f"Cascade rungs: {', '.join(r.id for r in cfg.cascade.rungs)}",
            f"k_samples per rung: {cfg.cascade.k_samples}",
            f"Max new tokens: {cfg.cascade.max_new_tokens}",
            f"Estimated cascade wall time: {self.cascade_minutes:.1f} min",
            f"Configured budget: {cfg.budget_minutes:.1f} min",
            f"Over budget: {self.over_budget}",
        ]
        if self.notes:
            lines.append("Notes:")
            lines.extend(f"  - {note}" for note in self.notes)
        return "\n".join(lines)
120
+
121
+
122
def plan(config: BuildConfig) -> BuildPlan:
    """Dry-run the build: project cascade wall time and collect config warnings.

    Notes flag an over-budget projection, missing API keys for enabled
    vendor voters, and unknown sources in the mix. Nothing is executed.
    """
    notes: list[str] = []
    projected_s = config.cascade.projected_seconds(config.target_total_queries)
    projected_m = projected_s / 60.0
    over = projected_m > config.budget_minutes
    if over:
        notes.append(
            f"cascade projected {projected_m:.1f} min exceeds budget {config.budget_minutes:.1f} min"
        )
    key_checks = (
        (config.labeler.use_gpt, "OPENAI_API_KEY", "gpt"),
        (config.labeler.use_claude, "ANTHROPIC_API_KEY", "claude"),
        (config.labeler.use_gemini, "GOOGLE_API_KEY", "gemini"),
    )
    for enabled, env_var, vendor in key_checks:
        if enabled and not os.environ.get(env_var):
            notes.append(f"{env_var} missing; {vendor} vote will be skipped")
    for src in config.sources:
        if src not in SOURCE_REGISTRY:
            notes.append(f"unknown source '{src}' in mix")
    return BuildPlan(
        config=config,
        n_queries=config.target_total_queries,
        cascade_seconds=projected_s,
        cascade_minutes=projected_m,
        over_budget=over,
        notes=notes,
    )
148
+
149
+
150
def write_capability_labels(path: str | Path, labels: list[CapabilityLabel]) -> None:
    """Persist capability labels as a parquet file, one `to_record()` row each."""
    import pandas as pd

    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    records = [lbl.to_record() for lbl in labels]
    pd.DataFrame(records).to_parquet(target, index=False)
155
+
156
+
157
def read_capability_labels(path: str | Path) -> dict[str, dict[str, float]]:
    """Load a labels parquet into {query_id: {capability: score}}.

    Capability columns are recognized by their `cap_` prefix, which is
    stripped in the returned inner dicts.
    """
    import pandas as pd

    df = pd.read_parquet(path)
    cap_cols = [col for col in df.columns if col.startswith("cap_")]
    result: dict[str, dict[str, float]] = {}
    for _, row in df.iterrows():
        result[row["query_id"]] = {col[len("cap_"):]: float(row[col]) for col in cap_cols}
    return result
165
+
166
+
167
def write_raw_manifest(path: str | Path, queries: list[RawQuery]) -> None:
    """Write the sampled queries as JSONL, one `to_dict()` payload per line."""
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(target, "w", encoding="utf-8") as fh:
        for query in queries:
            fh.write(json.dumps(query.to_dict()))
            fh.write("\n")
172
+
173
+
174
def write_labeled_dataset(
    train_path: str | Path,
    test_path: str | Path,
    rows: list[LabeledQuery],
    test_split: float,
    seed: int,
) -> None:
    """Shuffle-split `rows` into train/test parquet files.

    The split is seeded and at least one row always lands in the test set.
    Train rows keep their original dataset order; test rows follow the
    shuffled order deterministically.
    """
    import pandas as pd
    import random as _random
    rng = _random.Random(seed)
    indices = list(range(len(rows)))
    rng.shuffle(indices)
    n_test = max(1, int(len(rows) * test_split))
    test_order = indices[:n_test]
    test_idx = set(test_order)
    train_records = [rows[i].to_record() for i in range(len(rows)) if i not in test_idx]
    # Iterate the shuffled list, not the set: set iteration order is not a
    # language guarantee, so the old code could emit test rows in a different
    # order between runs even with a fixed seed.
    test_records = [rows[i].to_record() for i in test_order]
    Path(train_path).parent.mkdir(parents=True, exist_ok=True)
    pd.DataFrame(train_records).to_parquet(train_path, index=False)
    pd.DataFrame(test_records).to_parquet(test_path, index=False)
193
+
194
+
195
def build_seed_dataset(
    output_dir: str | Path,
    test_split: float = 0.15,
    seed: int = 42,
    suffix: str = "seed",
) -> tuple[Path, Path]:
    """Materialize the curated seed entries into train/test parquet files.

    Skips the cascade and the labeler: the seed entries already carry gold
    capability multi-labels, difficulty (in log_params), and length buckets.
    """
    from greenrouting.data.seed_dataset import (
        SEED_QUERIES,
        difficulty_log_params_from_b,
        seed_capability_dict,
    )

    labeled: list[LabeledQuery] = []
    for index, entry in enumerate(SEED_QUERIES):
        raw_query = RawQuery(
            id=f"seed-{index:04d}",
            text=entry.text,
            source="seed",
            source_category=entry.primary_category,
            has_grader=False,
            grader_metadata={},
        )
        labeled.append(
            LabeledQuery(
                raw=raw_query,
                capabilities=seed_capability_dict(entry, CAPABILITY_KEYS),
                difficulty_log_params=difficulty_log_params_from_b(entry.difficulty_b),
                length_bucket=entry.length,
                cascade_results={"source": "seed_curated"},
            )
        )

    out_root = Path(output_dir)
    train_path = out_root / f"train_{suffix}.parquet"
    test_path = out_root / f"test_{suffix}.parquet"
    write_labeled_dataset(train_path, test_path, labeled, test_split=test_split, seed=seed)
    return train_path, test_path
235
+
236
+
237
def sample_and_label(config: BuildConfig) -> tuple[list[RawQuery], list[CapabilityLabel]]:
    """Sample the configured source mix, then label only uncached queries.

    Returns (queries, labels) with labels aligned to the query order.
    Queries whose ids appear in the cache parquet get a synthetic "cached"
    label instead of being sent through the labeler again.
    """
    # Hoisted out of the per-query loop below, where the old code re-executed
    # the import statement on every cached hit.
    from greenrouting.data.schema import CapabilityVotes

    queries = sample_mix(config.sources, config.target_total_queries, config.seed)

    cached: dict[str, dict[str, float]] = {}
    if config.capability_labels_cache and Path(config.capability_labels_cache).exists():
        cached = read_capability_labels(config.capability_labels_cache)

    new_queries = [q for q in queries if q.id not in cached]
    new_labels = label_queries(new_queries, config.labeler) if new_queries else []

    cached_labels: list[CapabilityLabel] = []
    for q in queries:
        if q.id in cached:
            cached_labels.append(CapabilityLabel(
                query_id=q.id,
                capabilities=cached[q.id],
                votes=CapabilityVotes(),
                aggregation_method="cached",
            ))
    all_labels = new_labels + cached_labels
    # `lbl` instead of the ambiguous single-letter `l` (easily misread as 1).
    by_id = {lbl.query_id: lbl for lbl in all_labels}
    aligned = [by_id[q.id] for q in queries if q.id in by_id]
    return queries, aligned
greenrouting/data/capability_labeler.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Capability labeling: turns RawQuery records into multi-label CapabilityLabel records.
2
+
3
+ Aggregates up to four independent voters:
4
+ - source_prior (always available; derived from source category)
5
+ - heuristic (always available; deterministic keyword/regex rules)
6
+ - gpt-4o (optional, requires OPENAI_API_KEY)
7
+ - claude-sonnet (optional, requires ANTHROPIC_API_KEY)
8
+ - gemini-pro (optional, requires GOOGLE_API_KEY)
9
+
10
+ Designed to run once during dataset prep. The output is committed as a parquet so
11
+ downstream training does not depend on API access.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import json
17
+ import os
18
+ import re
19
+ import time
20
+ from dataclasses import dataclass
21
+ from typing import Iterable, Optional
22
+
23
+ from greenrouting.data.schema import CapabilityLabel, CapabilityVotes, RawQuery
24
+ from greenrouting.routing.registry import CAPABILITY_KEYS
25
+
26
# Maps a sampling-source category straight onto the capability label of the
# same name. Categories missing from this table yield an all-zero source
# prior in `source_prior_vote`.
CATEGORY_TO_LABELS: dict[str, list[str]] = {
    "code": ["code"],
    "math": ["math"],
    "reasoning": ["reasoning"],
    "knowledge": ["knowledge"],
    "instruction": ["instruction"],
    "creative": ["creative"],
    "multilingual": ["multilingual"],
    "simple_chat": ["simple_chat"],
}
36
+
37
+
38
def source_prior_vote(category: str) -> dict[str, float]:
    """Binary prior over capabilities implied by the query's source category.

    Unknown categories produce an all-zero vote.
    """
    implied = set(CATEGORY_TO_LABELS.get(category, []))
    return {cap: (1.0 if cap in implied else 0.0) for cap in CAPABILITY_KEYS}
41
+
42
+
43
# Per-capability regex lists for the deterministic heuristic voter. A single
# match from any pattern in a list sets that capability to 1.0 (see
# `heuristic_vote`). Searched with IGNORECASE | MULTILINE.
_HEURISTIC_PATTERNS: dict[str, list[str]] = {
    "code": [
        r"\b(code|function|class|def |algorithm|implement|debug|stack trace|api|sdk)\b",
        r"\b(python|javascript|typescript|rust|go|c\+\+|java|sql|html|css)\b",
        r"\b(refactor|unit test|regex|linter|compile|recursion)\b",
        r"```",
    ],
    "math": [
        r"\b(calculate|compute|solve|equation|integral|derivative|matrix|vector|theorem|prove)\b",
        r"\b(probability|sum|product|mean|median|variance|standard deviation|percentage)\b",
        r"\d+\s*[+\-*/×÷=]\s*\d+",
        r"\b(arithmetic|fraction|geometry|algebra|trig)\b",
    ],
    "reasoning": [
        r"\b(why|how does|explain|reason|because|therefore|argue|justify|implication)\b",
        r"\b(compare|contrast|analyze|evaluate|trade.?off|infer|deduce)\b",
    ],
    "knowledge": [
        r"\b(who|what is|when did|where is|history|definition|capital|founded|named)\b",
        r"\b(country|continent|invented|discovered|president|prime minister)\b",
    ],
    "instruction": [
        r"\b(write|draft|create|generate|produce|format|list|outline|step.?by.?step|summarize)\b",
        r"\b(rewrite|translate from|convert to|extract)\b",
    ],
    "creative": [
        r"\b(story|poem|novel|character|plot|scene|metaphor|fictional|haiku|song lyric)\b",
        r"\b(write a (?:short )?(?:story|poem|haiku|song))\b",
    ],
    "multilingual": [
        r"\b(translate|translation|en español|en français|auf deutsch|на русском|中文|日本語|한국어)\b",
        # Character ranges covering Cyrillic, CJK ideographs, kana and hangul.
        r"[Ѐ-ӿ一-鿿぀-ゟ゠-ヿ가-힣]",
    ],
    "simple_chat": [
        r"^\s*(hi|hello|hey|thanks|thank you|good morning|good evening|sup|yo)\b",
    ],
}
80
+
81
+
82
def heuristic_vote(text: str) -> dict[str, float]:
    """Deterministic regex voter: 1.0 for each capability whose pattern list matches.

    When nothing matches at all, short texts (< 80 chars after stripping)
    default to simple_chat, longer ones to instruction, so the vote is never
    entirely zero.
    """
    vote = {cap: 0.0 for cap in CAPABILITY_KEYS}
    for cap, patterns in _HEURISTIC_PATTERNS.items():
        if any(re.search(p, text, flags=re.IGNORECASE | re.MULTILINE) for p in patterns):
            vote[cap] = 1.0
    if not any(vote.values()):
        fallback = "simple_chat" if len(text.strip()) < 80 else "instruction"
        vote[fallback] = 1.0
    return vote
95
+
96
+
97
# Shared system prompt for the API voters (gpt/claude/gemini). It demands
# strict JSON with a 0/1 flag per capability so `_parse_vote` can consume the
# reply directly.
_LABELER_SYSTEM_PROMPT = (
    "You are labeling AI queries by which capabilities they require. "
    "Capabilities: code, math, reasoning, knowledge, instruction, creative, multilingual, "
    "simple_chat. A query can require multiple capabilities. "
    "Reply with strict JSON only, in the form: "
    '{"code": 0|1, "math": 0|1, "reasoning": 0|1, "knowledge": 0|1, '
    '"instruction": 0|1, "creative": 0|1, "multilingual": 0|1, "simple_chat": 0|1}.'
)
105
+
106
+
107
+ def _user_prompt(query: str) -> str:
108
+ return f"Query:\n{query}\n\nRespond with JSON only."
109
+
110
+
111
def _parse_vote(raw: str) -> dict[str, float]:
    """Parse a labeler reply into {capability: 0.0|1.0}; all zeros on bad JSON.

    Any truthy JSON value counts as 1; missing keys count as 0.
    """
    try:
        parsed = json.loads(_extract_json(raw))
    except Exception:
        return {k: 0.0 for k in CAPABILITY_KEYS}
    return {k: (1.0 if parsed.get(k) else 0.0) for k in CAPABILITY_KEYS}
117
+
118
+
119
+ def _extract_json(text: str) -> str:
120
+ match = re.search(r"\{.*\}", text, flags=re.DOTALL)
121
+ return match.group(0) if match else text
122
+
123
+
124
def _gpt_vote(text: str) -> Optional[dict[str, float]]:
    """Label `text` with gpt-4o-mini; None when the key or SDK is unavailable.

    Uses temperature 0 and JSON response mode; the reply goes through
    `_parse_vote`, which returns all zeros on malformed JSON. API errors are
    not caught here.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        return None
    try:
        from openai import OpenAI
    except ImportError:
        return None
    client = OpenAI(api_key=api_key)
    resp = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": _LABELER_SYSTEM_PROMPT},
            {"role": "user", "content": _user_prompt(text)},
        ],
        temperature=0,
        response_format={"type": "json_object"},
    )
    return _parse_vote(resp.choices[0].message.content or "{}")
143
+
144
+
145
def _claude_vote(text: str) -> Optional[dict[str, float]]:
    """Label `text` with claude-haiku-4-5; None when the key or SDK is unavailable.

    Concatenates only the text-type content blocks from the reply before
    parsing. API errors are not caught here.
    """
    api_key = os.environ.get("ANTHROPIC_API_KEY")
    if not api_key:
        return None
    try:
        import anthropic
    except ImportError:
        return None
    client = anthropic.Anthropic(api_key=api_key)
    resp = client.messages.create(
        model="claude-haiku-4-5",
        max_tokens=200,
        system=_LABELER_SYSTEM_PROMPT,
        messages=[{"role": "user", "content": _user_prompt(text)}],
    )
    body = "".join(b.text for b in resp.content if getattr(b, "type", "") == "text")
    return _parse_vote(body)
162
+
163
+
164
def _gemini_vote(text: str) -> Optional[dict[str, float]]:
    """Label `text` with gemini-1.5-flash; None when the key or SDK is unavailable.

    The shared labeler prompt is passed as the system instruction. API errors
    are not caught here.
    """
    api_key = os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        return None
    try:
        import google.generativeai as genai
    except ImportError:
        return None
    genai.configure(api_key=api_key)
    model = genai.GenerativeModel("gemini-1.5-flash", system_instruction=_LABELER_SYSTEM_PROMPT)
    resp = model.generate_content(_user_prompt(text))
    return _parse_vote(resp.text or "{}")
176
+
177
+
178
@dataclass
class LabelerConfig:
    """Which voters to run, how to weigh the source prior, and how to pace API calls."""

    use_heuristic: bool = True  # deterministic regex voter (no API needed)
    use_gpt: bool = False  # requires OPENAI_API_KEY
    use_claude: bool = False  # requires ANTHROPIC_API_KEY
    use_gemini: bool = False  # requires GOOGLE_API_KEY
    source_prior_weight: float = 0.5  # weight of the source prior in aggregate_votes
    sleep_between_calls_s: float = 0.0  # pause after each API vote (rate limiting)
186
+
187
+
188
def aggregate_votes(votes: CapabilityVotes, source_prior_weight: float = 0.5) -> dict[str, float]:
    """Weighted mean over available voters, with the prior at `source_prior_weight`.

    Each present voter (heuristic/gpt/claude/gemini) has weight 1. When no
    voter is present, the source prior is returned as-is (or all zeros when
    even that is empty).
    """
    present = [v for v in (votes.heuristic, votes.gpt, votes.claude, votes.gemini) if v is not None]
    if not present:
        if votes.source_prior:
            return dict(votes.source_prior)
        return {k: 0.0 for k in CAPABILITY_KEYS}
    denom = source_prior_weight + len(present)
    blended: dict[str, float] = {}
    for cap in CAPABILITY_KEYS:
        weighted = source_prior_weight * float(votes.source_prior.get(cap, 0.0))
        weighted += sum(float(v.get(cap, 0.0)) for v in present)
        blended[cap] = weighted / denom
    return blended
199
+
200
+
201
def label_query(query: RawQuery, config: LabelerConfig) -> CapabilityLabel:
    """Run every enabled voter on one query and aggregate into a CapabilityLabel.

    The aggregation_method string records which voters actually produced a
    vote, e.g. "heuristic+gpt"; "source_prior_only" when none did.
    """
    votes = CapabilityVotes(source_prior=source_prior_vote(query.source_category))
    if config.use_heuristic:
        votes.heuristic = heuristic_vote(query.text)
    # API voters run sequentially; an optional pause follows each call.
    for attr, enabled, voter in (
        ("gpt", config.use_gpt, _gpt_vote),
        ("claude", config.use_claude, _claude_vote),
        ("gemini", config.use_gemini, _gemini_vote),
    ):
        if not enabled:
            continue
        setattr(votes, attr, voter(query.text))
        if config.sleep_between_calls_s:
            time.sleep(config.sleep_between_calls_s)

    aggregated = aggregate_votes(votes, source_prior_weight=config.source_prior_weight)
    used = [
        name
        for name, vote in (
            ("heuristic", votes.heuristic),
            ("gpt", votes.gpt),
            ("claude", votes.claude),
            ("gemini", votes.gemini),
        )
        if vote is not None
    ]
    method = "+".join(used) if used else "source_prior_only"

    return CapabilityLabel(
        query_id=query.id,
        capabilities=aggregated,
        votes=votes,
        aggregation_method=method,
    )
234
+
235
+
236
def label_queries(queries: Iterable[RawQuery], config: LabelerConfig) -> list[CapabilityLabel]:
    """Label each query independently; output order follows the input iterable."""
    labeled: list[CapabilityLabel] = []
    for query in queries:
        labeled.append(label_query(query, config))
    return labeled
greenrouting/data/cascade.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Difficulty cascade: runs each query against an ascending ladder of models,
2
+ grades the response, and derives a continuous `min_capable_log_params` label.
3
+
4
+ Memory strategy: load one rung at a time, run all queries, dump checkpoint, free
5
+ the weights, then advance to the next rung. Resumes from per-rung JSONL files.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import json
11
+ import math
12
+ import time
13
+ from dataclasses import asdict, dataclass
14
+ from pathlib import Path
15
+ from typing import Iterable, Optional
16
+
17
+ from greenrouting.data.graders import grade
18
+ from greenrouting.data.schema import LabeledQuery, RawQuery
19
+
20
+
21
@dataclass
class RungResult:
    """One graded generation: a single (rung, query, sample) attempt."""

    rung_id: str  # cascade rung that produced the response
    params_b: float  # rung model size in billions of parameters
    query_id: str  # id of the query answered
    sample_index: int  # 0 uses temperature_first; >=1 use temperature_resample
    response: str  # decoded model output (special tokens stripped)
    score: float  # grader verdict, bounded to [0.0, 1.0]
    response_tokens: int  # number of newly generated tokens
+
31
+
32
def _read_raw_manifest(path: str | Path) -> list[RawQuery]:
    """Load the JSONL manifest written by the builder back into RawQuery objects.

    Blank lines are ignored.
    """
    loaded: list[RawQuery] = []
    with open(path, "r", encoding="utf-8") as fh:
        for raw_line in fh:
            if not raw_line.strip():
                continue
            loaded.append(RawQuery(**json.loads(raw_line)))
    return loaded
41
+
42
+
43
def _read_rung_checkpoint(path: Path) -> dict[str, list[RungResult]]:
    """Group previously checkpointed RungResults by query id; {} when the file is absent.

    Blank lines are ignored; insertion order within each query is file order.
    """
    if not path.exists():
        return {}
    grouped: dict[str, list[RungResult]] = {}
    with open(path, "r", encoding="utf-8") as fh:
        for raw_line in fh:
            if not raw_line.strip():
                continue
            record = RungResult(**json.loads(raw_line))
            grouped.setdefault(record.query_id, []).append(record)
    return grouped
55
+
56
+
57
def _append_rung_checkpoint(path: Path, result: RungResult) -> None:
    """Append one result as a JSON line, creating parent directories on demand."""
    path.parent.mkdir(parents=True, exist_ok=True)
    line = json.dumps(asdict(result))
    with open(path, "a", encoding="utf-8") as fh:
        fh.write(line + "\n")
61
+
62
+
63
def _load_model_and_tokenizer(hf_model: str):
    """Load `hf_model` for generation: bf16 + device_map="auto" on CUDA, fp32 on CPU.

    Ensures the tokenizer has a pad token (falls back to EOS) because
    `_generate` passes pad_token_id explicitly. Returns (tokenizer, model)
    with the model in eval mode.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    tok = AutoTokenizer.from_pretrained(hf_model)
    if tok.pad_token_id is None:
        tok.pad_token_id = tok.eos_token_id
    model = AutoModelForCausalLM.from_pretrained(
        hf_model, torch_dtype=dtype, device_map="auto" if torch.cuda.is_available() else None
    )
    model.eval()
    return tok, model
76
+
77
+
78
+ def _free_model(model) -> None:
79
+ import gc
80
+ del model
81
+ gc.collect()
82
+ try:
83
+ import torch
84
+ if torch.cuda.is_available():
85
+ torch.cuda.empty_cache()
86
+ except Exception:
87
+ pass
88
+
89
+
90
+ def _format_prompt(tok, query: str) -> str:
91
+ if hasattr(tok, "apply_chat_template") and tok.chat_template:
92
+ messages = [{"role": "user", "content": query}]
93
+ return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
94
+ return f"### Instruction:\n{query}\n\n### Response:\n"
95
+
96
+
97
def _generate(tok, model, prompt: str, max_new_tokens: int, temperature: float) -> tuple[str, int]:
    """Generate a completion for `prompt`; returns (stripped_text, n_new_tokens).

    temperature > 0 enables sampling; otherwise greedy decoding with
    temperature pinned to 1.0 (presumably to avoid HF warnings about
    temperature with do_sample=False -- confirm). Only tokens beyond the
    prompt are decoded and counted.
    """
    import torch
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    do_sample = temperature > 0
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature if do_sample else 1.0,
            pad_token_id=tok.pad_token_id,
        )
    # Slice off the prompt tokens: only the newly generated suffix matters.
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    response = tok.decode(new_tokens, skip_special_tokens=True)
    return response.strip(), int(new_tokens.shape[0])
112
+
113
+
114
def run_rung(
    rung,
    queries: list[RawQuery],
    k_samples: int,
    max_new_tokens: int,
    temperature_first: float,
    temperature_resample: float,
    checkpoint_path: Path,
    progress: bool = True,
) -> list[RungResult]:
    """Run every query through one rung, grading and checkpointing each sample.

    Resume semantics: samples already present in `checkpoint_path` are kept
    and only the missing (query, sample) pairs are generated, so a crashed
    run continues where it stopped. Sample 0 uses `temperature_first`, later
    samples `temperature_resample`. The model is loaded once for the whole
    rung and always freed, even if generation raises.
    """
    existing = _read_rung_checkpoint(checkpoint_path)
    pending = [q for q in queries if len(existing.get(q.id, [])) < k_samples]
    results: list[RungResult] = [r for rs in existing.values() for r in rs]

    if not pending:
        # Everything already checkpointed: skip the expensive model load.
        return results

    tok, model = _load_model_and_tokenizer(rung.hf_model)
    try:
        for i, q in enumerate(pending):
            done = len(existing.get(q.id, []))
            for s in range(done, k_samples):
                temp = temperature_first if s == 0 else temperature_resample
                prompt = _format_prompt(tok, q.text)
                start = time.time()
                response, n_tokens = _generate(tok, model, prompt, max_new_tokens, temp)
                score = grade(response, q.grader_metadata, max_new_tokens=max_new_tokens)
                rr = RungResult(
                    rung_id=rung.id,
                    params_b=rung.params_b,
                    query_id=q.id,
                    sample_index=s,
                    response=response,
                    score=score,
                    response_tokens=n_tokens,
                )
                # Checkpoint before moving on so a crash loses at most one sample.
                _append_rung_checkpoint(checkpoint_path, rr)
                results.append(rr)
                if progress:
                    print(
                        f"  [{rung.id}] {i+1}/{len(pending)} sample={s} "
                        f"score={score:.2f} tok={n_tokens} t={time.time()-start:.1f}s"
                    )
    finally:
        _free_model(model)

    return results
161
+
162
+
163
def derive_difficulty(
    per_rung: dict[str, list[float]],
    rung_params_b: dict[str, float],
    pass_threshold: float,
) -> float:
    """Continuous min_capable_log_params from per-rung mean scores.

    Rungs are ordered by parameter count; the first rung whose mean score
    reaches `pass_threshold` sets the ceiling, and the label interpolates
    linearly in log(params) between the last failing and first passing rung.
    Edge cases: no rungs configured -> log(8e9); no scored rungs or no
    passing rung -> log(2 * largest rung params); smallest rung already
    passing -> log(smallest rung params).
    """
    ladder = sorted(rung_params_b.items(), key=lambda kv: kv[1])
    if not ladder:
        return math.log(8e9)

    # (rung_id, params_b, mean score) for every rung that has samples.
    scored: list[tuple[str, float, float]] = []
    for rung_id, params_b in ladder:
        samples = per_rung.get(rung_id, [])
        if samples:
            scored.append((rung_id, params_b, sum(samples) / len(samples)))

    if not scored:
        return math.log(ladder[-1][1] * 1e9 * 2)

    if scored[0][2] >= pass_threshold:
        return math.log(scored[0][1] * 1e9)

    for (_, lo_params, lo_score), (_, hi_params, hi_score) in zip(scored, scored[1:]):
        if hi_score < pass_threshold:
            continue
        span = max(hi_score - lo_score, 1e-6)
        frac = max(0.0, min(1.0, (pass_threshold - lo_score) / span))
        log_lo = math.log(lo_params * 1e9)
        log_hi = math.log(hi_params * 1e9)
        return log_lo + frac * (log_hi - log_lo)

    return math.log(scored[-1][1] * 1e9 * 2)
206
+
207
+
208
def derive_length_bucket(response_token_counts: list[int]) -> str:
    """Bucket the mean sampled response length: <100 short, <400 medium, else long.

    An empty sample list defaults to "medium".
    """
    if not response_token_counts:
        return "medium"
    mean_tokens = sum(response_token_counts) / len(response_token_counts)
    if mean_tokens >= 400:
        return "long"
    return "medium" if mean_tokens >= 100 else "short"
217
+
218
+
219
def run_cascade(
    config,
    raw_manifest_path: str | Path,
    capability_labels_path: str | Path,
    train_path: str | Path,
    test_path: str | Path,
    pass_threshold: float = 0.7,
) -> None:
    """End-to-end cascade: generate+grade per rung, derive labels, write parquet.

    One rung is processed at a time (run_rung loads and frees its model) with
    a per-rung JSONL checkpoint under config.output_dir. Rungs flagged
    runs_locally=False are skipped with a notice. Queries that end up with no
    scored samples at all are dropped from the labeled output.
    """
    # Imported here to avoid a circular import with the builder module.
    from greenrouting.data.builder import (
        read_capability_labels,
        write_labeled_dataset,
    )

    queries = _read_raw_manifest(raw_manifest_path)
    cap_labels = read_capability_labels(capability_labels_path)

    # rung_id -> query_id -> results for that (rung, query) pair.
    per_rung_results: dict[str, dict[str, list[RungResult]]] = {}
    out_dir = Path(config.output_dir)

    for rung in config.cascade.rungs:
        if not rung.runs_locally:
            print(f"[skip] {rung.id} marked as not runs_locally; configure remote backend.")
            continue
        ckpt = out_dir / f"cascade_{config.profile_name}_{rung.id}.jsonl"
        results = run_rung(
            rung,
            queries,
            k_samples=config.cascade.k_samples,
            max_new_tokens=config.cascade.max_new_tokens,
            temperature_first=config.cascade.temperature_first,
            temperature_resample=config.cascade.temperature_resample,
            checkpoint_path=ckpt,
        )
        by_query: dict[str, list[RungResult]] = {}
        for r in results:
            by_query.setdefault(r.query_id, []).append(r)
        per_rung_results[rung.id] = by_query
        print(f"[done] rung {rung.id}: {sum(len(v) for v in by_query.values())} samples")

    rung_params: dict[str, float] = {r.id: r.params_b for r in config.cascade.rungs}

    labeled: list[LabeledQuery] = []
    for q in queries:
        # Collect this query's scores per rung plus all its token counts.
        per_rung_scores: dict[str, list[float]] = {}
        token_counts: list[int] = []
        for rung_id, by_query in per_rung_results.items():
            for rr in by_query.get(q.id, []):
                per_rung_scores.setdefault(rung_id, []).append(rr.score)
                token_counts.append(rr.response_tokens)
        if not per_rung_scores:
            # No rung produced a sample for this query: nothing to label.
            continue
        difficulty = derive_difficulty(per_rung_scores, rung_params, pass_threshold)
        length_bucket = derive_length_bucket(token_counts)
        caps = cap_labels.get(q.id, {})
        labeled.append(LabeledQuery(
            raw=q,
            capabilities=caps,
            difficulty_log_params=difficulty,
            length_bucket=length_bucket,
            cascade_results={
                "per_rung_mean_scores": {
                    k: sum(v) / len(v) for k, v in per_rung_scores.items()
                },
            },
        ))

    write_labeled_dataset(
        train_path=train_path,
        test_path=test_path,
        rows=labeled,
        test_split=config.test_split,
        seed=config.seed,
    )
    print(f"[done] wrote {len(labeled)} labeled rows -> {train_path}, {test_path}")
greenrouting/data/graders.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Graders that score whether a model's response solved a query.
2
+
3
+ Outputs a bounded float in [0.0, 1.0]. Where a deterministic grader exists
4
+ (numeric extraction, multi-choice match), the score is binary. For free-form
5
+ responses we use a deterministic proxy (parseability, length) clamped to a
6
+ sensible range.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import ast
12
+ import re
13
+ from typing import Optional
14
+
15
+ NUMBER_PATTERN = re.compile(r"-?\d+(?:[.,]\d+)?")
16
+
17
+
18
+ def _extract_last_number(text: str) -> Optional[float]:
19
+ matches = NUMBER_PATTERN.findall(text or "")
20
+ if not matches:
21
+ return None
22
+ last = matches[-1].replace(",", "")
23
+ try:
24
+ return float(last)
25
+ except ValueError:
26
+ return None
27
+
28
+
29
+ def _extract_first_letter(text: str, valid_letters: str = "ABCDEFGH") -> Optional[str]:
30
+ if not text:
31
+ return None
32
+ cleaned = text.strip()
33
+ m = re.search(
34
+ r"(?:answer|the answer is|final answer)\s*[:\-]?\s*\(?([" + valid_letters + r"])\)?",
35
+ cleaned,
36
+ re.IGNORECASE,
37
+ )
38
+ if m:
39
+ return m.group(1).upper()
40
+ m = re.match(r"\s*\(([" + valid_letters + r"])\)", cleaned)
41
+ if m:
42
+ return m.group(1)
43
+ m = re.match(r"\s*([" + valid_letters + r"])[\s\.\):,]+", cleaned)
44
+ if m:
45
+ return m.group(1)
46
+ m = re.match(r"\s*([" + valid_letters + r"])\s*$", cleaned)
47
+ if m:
48
+ return m.group(1)
49
+ return None
50
+
51
+
52
def grade_numeric(response: str, gold: str) -> float:
    """Binary grade for numeric answers.

    Extracts the last number from both *gold* and *response* and returns
    1.0 when they agree, else 0.0 (also 0.0 when either side contains no
    number).

    The comparison keeps the original absolute 1e-6 floor but adds a small
    relative term, so large magnitudes that differ only by floating-point
    representation error still count as equal; for everyday magnitudes this
    behaves exactly like the previous absolute-only check.
    """
    g = _extract_last_number(gold)
    r = _extract_last_number(response)
    if g is None or r is None:
        return 0.0
    # Relative component guards against spurious mismatches such as
    # 1e12 vs 1e12 + float noise, which a fixed 1e-6 cutoff rejects.
    tol = max(1e-6, 1e-9 * max(abs(g), abs(r)))
    return 1.0 if abs(g - r) < tol else 0.0
58
+
59
+
60
def grade_multichoice(response: str, gold_letter: str) -> float:
    """Binary grade for multiple-choice answers.

    Extracts an answer letter from *response* and compares it against
    *gold_letter*, ignoring case and surrounding whitespace on the gold
    side. A missing gold letter or an unextractable prediction scores 0.0.
    """
    if not gold_letter:
        return 0.0
    predicted = _extract_first_letter(response or "")
    if predicted is None:
        return 0.0
    return float(predicted.upper() == gold_letter.strip().upper())
67
+
68
+
69
def grade_string_match(response: str, gold: str) -> float:
    """Binary grade: 1.0 when the normalized gold string occurs as a
    substring of the normalized response, else 0.0.

    Normalization collapses whitespace runs to single spaces, strips the
    ends, and lowercases. An empty gold (before or after normalization)
    scores 0.0.
    """
    if not gold:
        return 0.0

    def normalize(s: str) -> str:
        return re.sub(r"\s+", " ", s).strip().lower()

    needle = normalize(gold)
    if not needle:
        return 0.0
    return 1.0 if needle in normalize(response or "") else 0.0
77
+
78
+
79
def grade_code_proxy(response: str, entry_point: Optional[str] = None) -> float:
    """Cheap proxy for code correctness: never executes the candidate code.

    Scoring:
      * 0.0 - empty response, or the extracted code does not parse
      * 1.0 - parses and defines the requested entry-point function
      * 0.4 - parses but the requested entry point is missing
      * 0.7 - parses (no entry point requested)
    """
    if not response:
        return 0.0
    code = _extract_code_block(response)
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        # ValueError covers inputs ast.parse rejects outside SyntaxError
        # (e.g. source containing null bytes); previously this propagated.
        return 0.0
    if entry_point:
        for node in ast.walk(tree):
            # Accept async definitions too; `async def <entry_point>` used
            # to fall through to the 0.4 "missing entry point" branch.
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == entry_point:
                return 1.0
        return 0.4
    return 0.7
95
+
96
+
97
+ def _extract_code_block(text: str) -> str:
98
+ fence = re.search(r"```(?:python|py)?\s*\n(.*?)```", text, re.DOTALL | re.IGNORECASE)
99
+ if fence:
100
+ return fence.group(1)
101
+ if "def " in text or "class " in text or "import " in text:
102
+ return text
103
+ return text
104
+
105
+
106
def grade_response_quality(response: str, max_new_tokens: int) -> float:
    """Heuristic score in [0, 1] for open-ended answers with no gold target.

    Blends answer length (saturating at half of *max_new_tokens* words)
    with vocabulary diversity (fraction of distinct lowercased words).
    An empty response scores 0.0; fewer than three words scores a flat
    0.05.
    """
    if not response:
        return 0.0
    words = response.split()
    n = len(words)
    if n < 3:
        return 0.05
    diversity = len({w.lower() for w in words}) / n
    target_len = max(max_new_tokens * 0.5, 1)
    fullness = min(1.0, n / target_len)
    blended = 0.6 * fullness + 0.4 * diversity
    return max(0.0, min(1.0, blended))
120
+
121
+
122
def grade_ifeval_proxy(response: str, instruction_id_list: list[str]) -> float:
    """Rough stand-in for IFEval's strict constraint grader.

    Awards one hit per instruction whose structural cue appears in the
    response (bullets/numbering for list-type ids, a leading brace or
    bracket for json-type ids, any uppercase char for uppercase-type ids);
    letter and word-count constraints are credited unconditionally, and
    unrecognized ids earn half credit. Empty response scores 0.0; an empty
    constraint list scores a neutral 0.5.
    """
    if not response:
        return 0.0
    if not instruction_id_list:
        return 0.5
    stripped = response.strip()
    score = 0.0
    for instruction_id in instruction_id_list:
        if "list" in instruction_id and re.search(r"^\s*[-*•]|^\s*\d+\.", response, re.MULTILINE):
            score += 1
        elif "json" in instruction_id and stripped.startswith(("{", "[")):
            score += 1
        elif "letter" in instruction_id:
            score += 1
        elif "word_count" in instruction_id:
            score += 1
        elif "uppercase" in instruction_id and any(ch.isupper() for ch in response):
            score += 1
        else:
            score += 0.5
    return min(1.0, score / max(len(instruction_id_list), 1))
144
+
145
+
146
def grade(response: str, grader_metadata: dict, max_new_tokens: int = 256) -> float:
    """Route *response* to the grader named in ``grader_metadata["grader"]``.

    An unknown or missing grader name falls back to the open-ended
    response-quality heuristic.
    """
    grader_name = grader_metadata.get("grader")
    dispatch = (
        ("exact_numeric", lambda: grade_numeric(response, grader_metadata.get("gold_final", ""))),
        ("multichoice", lambda: grade_multichoice(response, grader_metadata.get("gold_letter", ""))),
        ("string_match", lambda: grade_string_match(response, grader_metadata.get("gold", ""))),
        ("code_exec", lambda: grade_code_proxy(response, grader_metadata.get("entry_point"))),
        ("ifeval_constraints", lambda: grade_ifeval_proxy(
            response, grader_metadata.get("instruction_id_list", [])
        )),
    )
    for name, handler in dispatch:
        if grader_name == name:
            return handler()
    return grade_response_quality(response, max_new_tokens)
greenrouting/data/schema.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Schema for queries and labels flowing through the data pipeline."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field, asdict
6
+ from typing import Optional
7
+
8
+ from greenrouting.routing.registry import CAPABILITY_KEYS
9
+
10
# Allowed answer-length bucket values, ordered shortest to longest.
LENGTH_BUCKETS: tuple[str, str, str] = ("short", "medium", "long")
11
+
12
+
13
@dataclass
class RawQuery:
    """A single unlabeled query as ingested from an upstream source."""

    id: str                  # unique query identifier
    text: str                # raw query string
    source: str              # name of the originating dataset
    source_category: str     # coarse category assigned by the source
    has_grader: bool = False                              # True when a deterministic grader applies
    grader_metadata: dict = field(default_factory=dict)   # grader name plus gold data (see graders.grade)

    def to_dict(self) -> dict:
        """Serialize all fields (nested dicts included) to a plain dict."""
        return asdict(self)
24
+
25
+
26
@dataclass
class CapabilityVotes:
    """Per-annotator capability score vectors collected for one query."""

    source_prior: dict[str, float] = field(default_factory=dict)  # prior implied by the source dataset
    heuristic: Optional[dict[str, float]] = None
    gpt: Optional[dict[str, float]] = None
    claude: Optional[dict[str, float]] = None
    gemini: Optional[dict[str, float]] = None

    def vote_count(self) -> int:
        """Count the annotators (excluding the source prior) that voted."""
        annotators = (self.heuristic, self.gpt, self.claude, self.gemini)
        return len([vote for vote in annotators if vote is not None])
38
+
39
+
40
@dataclass
class CapabilityLabel:
    """Aggregated capability label plus the individual votes behind it."""

    query_id: str
    capabilities: dict[str, float]   # aggregated per-capability scores
    votes: CapabilityVotes           # raw per-annotator score vectors
    aggregation_method: str          # how `capabilities` was derived from the votes

    def to_record(self) -> dict:
        """Flatten to one row: cap_* columns plus vote_<vendor>_* columns."""
        record: dict = {
            "query_id": self.query_id,
            "aggregation_method": self.aggregation_method,
        }
        for key in CAPABILITY_KEYS:
            record[f"cap_{key}"] = float(self.capabilities.get(key, 0.0))
        for vendor in ("source_prior", "heuristic", "gpt", "claude", "gemini"):
            vector = getattr(self.votes, vendor)
            if vector is None:
                continue
            for key in CAPABILITY_KEYS:
                record[f"vote_{vendor}_{key}"] = float(vector.get(key, 0.0))
        return record
58
+
59
+
60
@dataclass
class LabeledQuery:
    """A query with its final training labels attached."""

    raw: RawQuery                           # the underlying query record
    capabilities: dict[str, float]          # aggregated capability scores
    difficulty_log_params: Optional[float]  # log-scale model-size difficulty label
    length_bucket: Optional[str]            # "short" / "medium" / "long" (see LENGTH_BUCKETS)
    cascade_results: dict = field(default_factory=dict)  # extra cascade diagnostics

    def to_record(self) -> dict:
        """Flatten to a training row: query fields, labels, and cap_* columns."""
        record = {
            "id": self.raw.id,
            "text": self.raw.text,
            "source": self.raw.source,
            "source_category": self.raw.source_category,
            "has_grader": self.raw.has_grader,
            "difficulty_log_params": self.difficulty_log_params,
            "length_bucket": self.length_bucket,
        }
        for key in CAPABILITY_KEYS:
            record[f"cap_{key}"] = float(self.capabilities.get(key, 0.0))
        return record
greenrouting/data/seed_dataset.py ADDED
@@ -0,0 +1,545 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Curated training set with multi-label capability assignments and difficulty.
2
+
3
+ Each entry is hand-authored. Phrasings deliberately vary across registers
4
+ (commands, questions, conversational asks, fragments) so the classifier learns
5
+ semantic patterns rather than surface keyword rules.
6
+
7
+ Schema for each entry: (text, primary_category, capabilities, difficulty_b, length).
8
+ - text: the raw query string
9
+ - primary_category: dominant capability bucket (used as `source_category`)
10
+ - capabilities: list of buckets that apply (multi-label)
11
+ - difficulty_b: parameter count (in billions) of the smallest model that would
12
+ plausibly handle this well; used to derive `difficulty_log_params`
13
+ - length: expected answer length bucket (short/medium/long)
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import math
19
+ from dataclasses import dataclass
20
+
21
+
22
@dataclass(frozen=True)
class SeedEntry:
    """One hand-authored training example; see the module docstring for field semantics."""

    text: str                      # raw query string
    primary_category: str          # dominant capability bucket
    capabilities: tuple[str, ...]  # every applicable bucket (multi-label)
    difficulty_b: float            # params (billions) of the smallest plausibly-capable model
    length: str                    # expected answer length bucket (short/medium/long)
29
+
30
+
31
def _e(text, primary, caps, diff_b, length) -> SeedEntry:
    """Terse SeedEntry constructor used by the seed tables below."""
    return SeedEntry(text, primary, tuple(caps), diff_b, length)
33
+
34
+
35
# Greetings, acknowledgements, and other micro-turns. Nearly every entry is
# pure simple_chat at 0.5B difficulty; a handful that ask for elaboration or
# an opinion also carry instruction/reasoning at 1.0B.
_SIMPLE_CHAT: list[SeedEntry] = [
    _e("hi", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("hello", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("hey there", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("good morning", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("how's it going", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("thanks!", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("thank you so much", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("appreciate it", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("ok cool", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("got it", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("sounds good", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("nice", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("makes sense", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("yep", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("yes please", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("nope", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("not really", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("can you help me?", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("are you there?", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("what can you do?", "simple_chat", ["simple_chat", "instruction"], 1.0, "short"),
    _e("how does this work", "simple_chat", ["simple_chat", "instruction"], 1.0, "short"),
    _e("who are you?", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("are you human?", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("good night", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("see you later", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("bye!", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("lol", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("haha that's funny", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("oh interesting", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("hmm okay", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("can you elaborate", "simple_chat", ["simple_chat", "instruction"], 1.0, "short"),
    _e("tell me more", "simple_chat", ["simple_chat", "instruction"], 1.0, "short"),
    _e("what do you think?", "simple_chat", ["simple_chat", "reasoning"], 1.0, "short"),
    _e("any thoughts?", "simple_chat", ["simple_chat", "reasoning"], 1.0, "short"),
    _e("got a sec?", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("quick question", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("just checking in", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("how was your day", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("what's up", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("you good?", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("yeah that works", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("alright then", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("hold on a sec", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("never mind", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("scratch that", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("oops", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("my bad", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("you're awesome", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("this is great", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("perfect, thanks", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("can we try again", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("one more time", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("again please", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("yo", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("sup", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("howdy", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("greetings", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("test", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("just testing", "simple_chat", ["simple_chat"], 0.5, "short"),
    _e("can you hear me", "simple_chat", ["simple_chat"], 0.5, "short"),
]
97
+
98
# Formatting, rewriting, summarizing, and drafting requests. Several entries
# are multi-label, additionally tagged creative / reasoning / knowledge.
_INSTRUCTION: list[SeedEntry] = [
    _e("write a 3-bullet summary of the agenda below: 1) opening remarks 2) Q3 numbers 3) staffing", "instruction", ["instruction"], 3.0, "short"),
    _e("turn this into a numbered list", "instruction", ["instruction"], 1.0, "short"),
    _e("format the following as a markdown table", "instruction", ["instruction"], 2.0, "medium"),
    _e("rewrite this in a more formal tone", "instruction", ["instruction"], 3.0, "medium"),
    _e("make this paragraph shorter without losing any meaning", "instruction", ["instruction"], 3.0, "medium"),
    _e("expand this outline into a full paragraph", "instruction", ["instruction"], 3.0, "medium"),
    _e("draft a polite decline to a meeting invite", "instruction", ["instruction", "creative"], 3.0, "short"),
    _e("write a short subject line for this email", "instruction", ["instruction"], 1.0, "short"),
    _e("give me three name ideas for a coffee shop near a park", "instruction", ["instruction", "creative"], 3.0, "short"),
    _e("create a checklist for moving apartments", "instruction", ["instruction"], 3.0, "medium"),
    _e("outline a 5-day study plan for the GRE quantitative section", "instruction", ["instruction"], 7.0, "medium"),
    _e("summarize this in two sentences please", "instruction", ["instruction"], 3.0, "short"),
    _e("convert these notes into a clean meeting recap", "instruction", ["instruction"], 7.0, "medium"),
    _e("turn the following log lines into a single sentence describing what happened", "instruction", ["instruction", "reasoning"], 7.0, "short"),
    _e("group these errors by likely root cause", "instruction", ["instruction", "reasoning"], 8.0, "medium"),
    _e("clean up the grammar in the paragraph below", "instruction", ["instruction"], 1.0, "medium"),
    _e("rephrase this so a 12-year-old could follow", "instruction", ["instruction"], 3.0, "medium"),
    _e("make a one-paragraph executive summary", "instruction", ["instruction"], 7.0, "short"),
    _e("draft an out-of-office reply for next Friday", "instruction", ["instruction", "creative"], 1.0, "short"),
    _e("create a 7-day workout split focused on legs and shoulders", "instruction", ["instruction"], 7.0, "medium"),
    _e("write a packing list for a 4-day winter trip to Reykjavik", "instruction", ["instruction"], 3.0, "medium"),
    _e("turn the bullet points below into a smooth paragraph", "instruction", ["instruction"], 3.0, "medium"),
    _e("transform this dry product description into a punchy tagline", "instruction", ["instruction", "creative"], 7.0, "short"),
    _e("draft a thank-you note to my mentor", "instruction", ["instruction", "creative"], 3.0, "short"),
    _e("write meeting minutes from this transcript snippet", "instruction", ["instruction"], 7.0, "medium"),
    _e("split this monolithic to-do into morning vs afternoon", "instruction", ["instruction"], 1.0, "medium"),
    _e("rewrite this slack message so it doesn't sound passive aggressive", "instruction", ["instruction"], 3.0, "short"),
    _e("draft a customer apology email after a 3-hour outage", "instruction", ["instruction", "creative"], 7.0, "medium"),
    _e("turn this transcript into 5 talking points", "instruction", ["instruction"], 3.0, "medium"),
    _e("convert this case study into a single tweet", "instruction", ["instruction", "creative"], 3.0, "short"),
    _e("write release notes for version 2.3 covering the changelog below", "instruction", ["instruction"], 7.0, "medium"),
    _e("make a side-by-side comparison table of the two job offers", "instruction", ["instruction", "reasoning"], 7.0, "medium"),
    _e("compose a brief biography paragraph for a conference badge", "instruction", ["instruction", "creative"], 3.0, "short"),
    _e("write a 4-line elevator pitch for an indie video game", "instruction", ["instruction", "creative"], 3.0, "short"),
    _e("expand the following acronyms inline: API, SLA, P95, RAG", "instruction", ["instruction", "knowledge"], 3.0, "short"),
    _e("create a polite reminder to a colleague who hasn't reviewed my PR", "instruction", ["instruction"], 1.0, "short"),
    _e("draft an FAQ entry explaining how refunds work for SaaS subscriptions", "instruction", ["instruction"], 7.0, "medium"),
    _e("turn this jira ticket description into a one-line standup update", "instruction", ["instruction"], 1.0, "short"),
    _e("write a job description for a junior data analyst", "instruction", ["instruction"], 7.0, "medium"),
    _e("convert the recipe below from imperial to metric", "instruction", ["instruction"], 1.0, "short"),
    _e("organize the chaos in this email thread into a clean timeline", "instruction", ["instruction", "reasoning"], 7.0, "medium"),
    _e("extract every dollar amount from the contract clause and list them", "instruction", ["instruction"], 3.0, "short"),
    _e("write three follow-up email subject lines, increasing in urgency", "instruction", ["instruction", "creative"], 3.0, "short"),
    _e("rewrite this resume bullet to emphasize impact", "instruction", ["instruction"], 3.0, "short"),
    _e("draft a 1-paragraph linkedin post announcing a job change", "instruction", ["instruction", "creative"], 7.0, "short"),
    _e("turn the bullet recap below into a polished retro doc section", "instruction", ["instruction"], 7.0, "medium"),
    _e("compose a clear bug report from this user complaint", "instruction", ["instruction", "reasoning"], 7.0, "medium"),
    _e("create a 5-question pre-interview survey for prospective tenants", "instruction", ["instruction"], 3.0, "medium"),
    _e("rewrite the warning copy below to be friendlier without losing meaning", "instruction", ["instruction"], 3.0, "short"),
    _e("write a one-line apology, a longer apology, and a formal apology", "instruction", ["instruction", "creative"], 3.0, "medium"),
]
150
+
151
# Factual-recall and explanation queries. difficulty_b follows the module
# docstring's convention (smallest model, in billions of params, expected to
# answer well); some entries are also tagged instruction/reasoning/math.
_KNOWLEDGE: list[SeedEntry] = [
    _e("what's the capital of Mongolia", "knowledge", ["knowledge"], 1.0, "short"),
    _e("when was the printing press invented", "knowledge", ["knowledge"], 1.0, "short"),
    _e("what's the population of Lagos roughly", "knowledge", ["knowledge"], 1.0, "short"),
    _e("who painted Guernica", "knowledge", ["knowledge"], 1.0, "short"),
    _e("how long was the Hundred Years War actually", "knowledge", ["knowledge"], 3.0, "short"),
    _e("what protein does insulin regulate", "knowledge", ["knowledge"], 7.0, "short"),
    _e("what's the difference between a virus and a bacterium in plain terms", "knowledge", ["knowledge", "instruction"], 7.0, "medium"),
    _e("name the three branches of the US government", "knowledge", ["knowledge"], 1.0, "short"),
    _e("what year did World War I end", "knowledge", ["knowledge"], 1.0, "short"),
    _e("who wrote One Hundred Years of Solitude", "knowledge", ["knowledge"], 1.0, "short"),
    _e("what's photosynthesis in one sentence", "knowledge", ["knowledge", "instruction"], 1.0, "short"),
    _e("how many bones are in the human body", "knowledge", ["knowledge"], 1.0, "short"),
    _e("what does GDP stand for and what does it measure", "knowledge", ["knowledge"], 3.0, "short"),
    _e("which planet has the most moons", "knowledge", ["knowledge"], 1.0, "short"),
    _e("what's the chemical symbol for gold", "knowledge", ["knowledge"], 0.5, "short"),
    _e("define entropy in thermodynamics", "knowledge", ["knowledge"], 7.0, "medium"),
    _e("explain the citric acid cycle briefly", "knowledge", ["knowledge", "instruction"], 8.0, "medium"),
    _e("what is OAuth 2.0 and what problem does it solve", "knowledge", ["knowledge", "instruction"], 7.0, "medium"),
    _e("describe the Marshall Plan in two sentences", "knowledge", ["knowledge", "instruction"], 7.0, "short"),
    _e("what's the difference between TCP and UDP", "knowledge", ["knowledge", "reasoning"], 7.0, "medium"),
    _e("who founded the Stoic philosophical school", "knowledge", ["knowledge"], 3.0, "short"),
    _e("what did Rosalind Franklin contribute to DNA research", "knowledge", ["knowledge"], 7.0, "medium"),
    _e("explain RAID 5 vs RAID 10 storage", "knowledge", ["knowledge", "reasoning"], 8.0, "medium"),
    _e("what is monetary policy in plain language", "knowledge", ["knowledge", "instruction"], 7.0, "medium"),
    _e("how does a vaccine actually trigger immunity", "knowledge", ["knowledge", "reasoning"], 8.0, "medium"),
    _e("what was the Bretton Woods system", "knowledge", ["knowledge"], 7.0, "medium"),
    _e("who was the first woman to win a Nobel Prize", "knowledge", ["knowledge"], 1.0, "short"),
    _e("what's the longest river in South America", "knowledge", ["knowledge"], 1.0, "short"),
    _e("describe the role of the prefrontal cortex", "knowledge", ["knowledge", "instruction"], 8.0, "medium"),
    _e("define inflation vs deflation simply", "knowledge", ["knowledge"], 3.0, "short"),
    _e("what is BGP in networking", "knowledge", ["knowledge"], 7.0, "short"),
    _e("explain what a CDN does", "knowledge", ["knowledge", "instruction"], 3.0, "short"),
    _e("what's a Galois field", "knowledge", ["knowledge", "math"], 30.0, "medium"),
    _e("who invented the World Wide Web", "knowledge", ["knowledge"], 1.0, "short"),
    _e("what does the term 'eventual consistency' mean in databases", "knowledge", ["knowledge"], 8.0, "medium"),
    _e("brief overview of the Treaty of Westphalia", "knowledge", ["knowledge", "instruction"], 8.0, "medium"),
    _e("what is the boiling point of water in Kelvin", "knowledge", ["knowledge"], 0.5, "short"),
    _e("what's the Coriolis effect", "knowledge", ["knowledge"], 3.0, "short"),
    _e("explain Bayes' theorem in everyday language", "knowledge", ["knowledge", "math", "instruction"], 8.0, "medium"),
    _e("what is the speed of light in vacuum", "knowledge", ["knowledge"], 0.5, "short"),
    _e("when did the Berlin Wall fall", "knowledge", ["knowledge"], 1.0, "short"),
    _e("what does HTTP/3 use under the hood", "knowledge", ["knowledge"], 8.0, "short"),
    _e("describe the Heisenberg uncertainty principle", "knowledge", ["knowledge", "instruction"], 8.0, "medium"),
    _e("what's the difference between a sonnet and a haiku", "knowledge", ["knowledge"], 1.0, "short"),
    _e("name the inert noble gases", "knowledge", ["knowledge"], 1.0, "short"),
    _e("how does a transformer model differ from an RNN at a high level", "knowledge", ["knowledge", "reasoning"], 30.0, "medium"),
    _e("what was the Marshall McLuhan claim about media", "knowledge", ["knowledge"], 7.0, "short"),
    _e("explain quantitative easing", "knowledge", ["knowledge", "instruction"], 8.0, "medium"),
    _e("what's the half-life of carbon-14", "knowledge", ["knowledge"], 1.0, "short"),
    _e("who composed The Rite of Spring", "knowledge", ["knowledge"], 1.0, "short"),
    _e("what does the term 'antifragile' mean (Taleb)", "knowledge", ["knowledge"], 3.0, "short"),
    _e("how does an MRI machine work", "knowledge", ["knowledge", "instruction"], 30.0, "medium"),
    _e("define cognitive dissonance", "knowledge", ["knowledge"], 3.0, "short"),
    _e("what's the difference between fission and fusion", "knowledge", ["knowledge"], 7.0, "short"),
    _e("how big is the Milky Way galaxy in light years", "knowledge", ["knowledge"], 1.0, "short"),
    _e("what is gerrymandering", "knowledge", ["knowledge"], 3.0, "short"),
    _e("define IPO in finance", "knowledge", ["knowledge"], 1.0, "short"),
    _e("what is dark matter, briefly", "knowledge", ["knowledge"], 7.0, "short"),
    _e("how did the Silk Road shape trade", "knowledge", ["knowledge", "reasoning"], 8.0, "medium"),
]
212
+
213
+ _CODE: list[SeedEntry] = [
214
+ _e("write a python function that reverses a string", "code", ["code"], 0.5, "short"),
215
+ _e("show me how to read a file line by line in python", "code", ["code"], 1.0, "short"),
216
+ _e("what's the difference between == and is in python", "code", ["code", "knowledge"], 3.0, "short"),
217
+ _e("how do I sort a list of dicts by a key in python", "code", ["code"], 1.0, "short"),
218
+ _e("write javascript that fetches /api/users and logs the result", "code", ["code"], 1.0, "short"),
219
+ _e("debounce function in javascript please", "code", ["code"], 3.0, "short"),
220
+ _e("css to center a div both vertically and horizontally", "code", ["code"], 1.0, "short"),
221
+ _e("regex to validate an email address (rough)", "code", ["code"], 3.0, "short"),
222
+ _e("write a SQL query to find duplicates by email column", "code", ["code"], 3.0, "short"),
223
+ _e("explain what this stack trace means: TypeError: cannot unpack non-iterable NoneType object", "code", ["code", "reasoning"], 7.0, "medium"),
224
+ _e("rust function that returns the nth fibonacci number iteratively", "code", ["code"], 3.0, "short"),
225
+ _e("typescript type for a paginated API response", "code", ["code"], 7.0, "medium"),
226
+ _e("write a python decorator that times function calls", "code", ["code"], 7.0, "short"),
227
+ _e("dockerfile for a flask app on python 3.11", "code", ["code"], 7.0, "medium"),
228
+ _e("kubernetes deployment yaml for a 3-replica stateless service on port 8080", "code", ["code"], 8.0, "medium"),
229
+ _e("write a recursive haskell function for tree depth", "code", ["code"], 8.0, "short"),
230
+ _e("implement a thread-safe LRU cache in rust", "code", ["code"], 30.0, "long"),
231
+ _e("explain how Promise.all differs from Promise.allSettled", "code", ["code", "reasoning"], 7.0, "medium"),
232
+ _e("write a postgres query that windows daily revenue with a 7-day moving average", "code", ["code", "math"], 8.0, "medium"),
233
+ _e("debug this: my python script crashes with 'IndexError: list assignment index out of range' on line 22", "code", ["code", "reasoning"], 7.0, "medium"),
234
+ _e("how do I cancel an in-flight fetch request", "code", ["code"], 3.0, "short"),
235
+ _e("write a fastapi endpoint with pydantic validation for a user signup", "code", ["code"], 7.0, "medium"),
236
+ _e("python script to parse a CSV and emit JSONL", "code", ["code"], 3.0, "short"),
237
+ _e("show me an idempotent stripe webhook handler in node", "code", ["code", "reasoning"], 8.0, "medium"),
238
+ _e("convert this python list comprehension to a generator expression", "code", ["code"], 1.0, "short"),
239
+ _e("explain the difference between mutex and semaphore with code", "code", ["code", "knowledge"], 8.0, "long"),
240
+ _e("git command to undo the last commit but keep my changes staged", "code", ["code"], 1.0, "short"),
241
+ _e("git rebase vs merge - which should I use for a feature branch", "code", ["code", "reasoning"], 7.0, "medium"),
242
+ _e("how do I write an async iterator in python", "code", ["code"], 7.0, "short"),
243
+ _e("write a smart contract in solidity that escrows a payment", "code", ["code"], 30.0, "long"),
244
+ _e("rewrite this loop using map and filter", "code", ["code"], 1.0, "short"),
245
+ _e("c++ template for a generic queue", "code", ["code"], 8.0, "medium"),
246
+ _e("write go code that gracefully shuts down an HTTP server on SIGTERM", "code", ["code"], 8.0, "medium"),
247
+ _e("design and implement a rate limiter middleware in express", "code", ["code", "reasoning"], 8.0, "long"),
248
+ _e("convert callback-style fs.readFile to a promise-returning version", "code", ["code"], 1.0, "short"),
249
+ _e("regex to capture all URLs from a markdown document", "code", ["code"], 7.0, "short"),
250
+ _e("write a python class that wraps the openai API with retry+backoff", "code", ["code", "reasoning"], 8.0, "medium"),
251
+ _e("explain what __slots__ does in python", "code", ["code", "knowledge"], 7.0, "short"),
252
+ _e("how do I migrate from sqlalchemy 1.4 to 2.0", "code", ["code", "knowledge"], 8.0, "medium"),
253
+ _e("write a github actions workflow that runs pytest on push and PR", "code", ["code"], 7.0, "medium"),
254
+ _e("typescript function that pipes through a series of validators", "code", ["code"], 8.0, "medium"),
255
+ _e("how to debounce a search input in react with useEffect", "code", ["code"], 3.0, "short"),
256
+ _e("python: implement a memoize decorator that respects argument types", "code", ["code"], 7.0, "medium"),
257
+ _e("explain what a CRDT is and sketch a counter implementation", "code", ["code", "knowledge", "reasoning"], 30.0, "long"),
258
+ _e("write a custom hook that tracks window scroll position", "code", ["code"], 3.0, "short"),
259
+ _e("redis lua script that atomically pops from a sorted set if score < now", "code", ["code"], 8.0, "medium"),
260
+ _e("optimize this O(n^2) python loop", "code", ["code", "reasoning", "math"], 8.0, "medium"),
261
+ _e("aws cdk stack for an s3 bucket and a cloudfront distribution", "code", ["code"], 8.0, "medium"),
262
+ _e("explain what 'use strict' does in javascript", "code", ["code", "knowledge"], 1.0, "short"),
263
+ _e("how do I configure cors for a flask app", "code", ["code"], 1.0, "short"),
264
+ _e("write a kubernetes operator scaffold in go", "code", ["code", "reasoning"], 70.0, "long"),
265
+ _e("difference between a thread and a coroutine in python", "code", ["code", "knowledge"], 7.0, "medium"),
266
+ _e("write a scala function that computes mean and std in a single pass", "code", ["code", "math"], 8.0, "medium"),
267
+ _e("set up a basic CI/CD pipeline with gitlab ci for a node project", "code", ["code"], 7.0, "medium"),
268
+ _e("write a python pytest fixture that spins up a docker postgres", "code", ["code"], 8.0, "medium"),
269
+ _e("kotlin extension function to clamp a number between min and max", "code", ["code"], 1.0, "short"),
270
+ _e("write a CUDA kernel for vector addition with bounds checks", "code", ["code", "reasoning"], 30.0, "medium"),
271
+ _e("draft a Dockerfile that uses multi-stage builds for a typescript app", "code", ["code"], 8.0, "medium"),
272
+ _e("how to use websockets in fastapi", "code", ["code"], 7.0, "medium"),
273
+ _e("show me how to mock fetch in jest", "code", ["code"], 3.0, "short"),
274
+ _e("python function to flatten a deeply nested dict using dot notation keys", "code", ["code"], 7.0, "medium"),
275
+ _e("explain the producer-consumer problem with a go example", "code", ["code", "knowledge"], 8.0, "long"),
276
+ _e("solidity: ERC-20 token with a 5% transfer tax to a treasury address", "code", ["code"], 8.0, "long"),
277
+ _e("write terraform that creates a vpc with two private subnets and a NAT gateway", "code", ["code"], 8.0, "medium"),
278
+ _e("how does python's GIL affect multithreaded CPU-bound code", "code", ["code", "knowledge", "reasoning"], 8.0, "medium"),
279
+ _e("rewrite this callback hell into async/await", "code", ["code"], 3.0, "short"),
280
+ _e("python script to backfill missing dates in a pandas time series", "code", ["code"], 7.0, "medium"),
281
+ _e("explain why my recursive function hits a RecursionError on n=1500", "code", ["code", "reasoning"], 7.0, "medium"),
282
+ _e("c program: read stdin and print each line reversed", "code", ["code"], 3.0, "short"),
283
+ _e("write a chrome extension manifest v3 that injects a script into all pages", "code", ["code"], 8.0, "medium"),
284
+ ]
285
+
286
# Hand-written seed queries whose primary label is "math".
# Each _e(...) row appears to carry: query text, primary category, capability
# tags, a numeric difficulty value, and a length bucket ("short"/"medium"/
# "long") — _e is defined above this chunk; confirm the field meanings there.
_MATH: list[SeedEntry] = [
    _e("what is 17 + 25", "math", ["math"], 0.5, "short"),
    _e("compute 12 * 14 - 8", "math", ["math"], 0.5, "short"),
    _e("what's 25 percent of 480", "math", ["math"], 1.0, "short"),
    _e("solve 3x + 7 = 22", "math", ["math"], 1.0, "short"),
    _e("what's the area of a circle with radius 5", "math", ["math"], 1.0, "short"),
    _e("convert 75 fahrenheit to celsius", "math", ["math"], 0.5, "short"),
    _e("what is the slope between (1,2) and (4,8)", "math", ["math"], 1.0, "short"),
    _e("simplify 3/4 + 5/6", "math", ["math"], 1.0, "short"),
    _e("if a train leaves at 60 mph, how long to travel 240 miles", "math", ["math"], 1.0, "short"),
    _e("solve x^2 - 5x + 6 = 0", "math", ["math"], 3.0, "short"),
    _e("what's the integral of x^2 dx", "math", ["math"], 3.0, "short"),
    _e("integrate x^2 sin(x) dx using integration by parts, show the steps", "math", ["math", "instruction"], 8.0, "medium"),
    _e("compute the determinant of [[2,1],[3,4]]", "math", ["math"], 3.0, "short"),
    _e("what's the derivative of e^(x^2)", "math", ["math"], 3.0, "short"),
    _e("derive the chain rule from first principles", "math", ["math", "reasoning"], 30.0, "long"),
    _e("prove the Pythagorean theorem geometrically", "math", ["math", "reasoning"], 8.0, "medium"),
    _e("expectation of a roll of two fair dice", "math", ["math"], 3.0, "short"),
    _e("variance of the same two dice setup", "math", ["math"], 7.0, "short"),
    _e("what is 2^10 - 1", "math", ["math"], 0.5, "short"),
    _e("compute the standard deviation of {2,4,4,4,5,5,7,9}", "math", ["math"], 3.0, "short"),
    _e("apply L'Hopital's rule to lim x->0 (sin x)/x", "math", ["math", "reasoning"], 7.0, "short"),
    _e("solve the system: 2x + y = 5, x - y = 1", "math", ["math"], 1.0, "short"),
    _e("find the inverse of the matrix [[1,2],[3,4]]", "math", ["math"], 3.0, "short"),
    _e("compute the eigenvalues of [[2,0],[0,3]]", "math", ["math"], 3.0, "short"),
    _e("compute eigenvalues of [[4,1],[2,3]]", "math", ["math", "reasoning"], 8.0, "medium"),
    _e("if I invest $10000 at 5% APR compounded monthly for 6 years what's the final amount", "math", ["math"], 3.0, "short"),
    _e("birthday paradox: probability that 23 people share a birthday", "math", ["math", "reasoning"], 8.0, "medium"),
    _e("expected value of a martingale strategy on a fair coin", "math", ["math", "reasoning"], 30.0, "medium"),
    _e("derive the formula for the sum of the first n integers", "math", ["math", "reasoning"], 7.0, "medium"),
    _e("Taylor expand cos(x) around 0 to fourth order", "math", ["math"], 8.0, "medium"),
    _e("convert 1101 binary to decimal", "math", ["math"], 1.0, "short"),
    _e("convert 255 to hex", "math", ["math"], 1.0, "short"),
    _e("what is gcd(48, 180)", "math", ["math"], 1.0, "short"),
    _e("solve cos(2x) = 1/2 for x in [0, 2pi]", "math", ["math"], 8.0, "short"),
    _e("derive Euler's formula e^(ix) = cos x + i sin x informally", "math", ["math", "reasoning"], 30.0, "medium"),
    _e("what's the probability of rolling at least one 6 in four dice rolls", "math", ["math", "reasoning"], 7.0, "short"),
    _e("if X ~ N(0,1), what's P(|X| > 1.96)", "math", ["math"], 7.0, "short"),
    _e("show that sqrt(2) is irrational", "math", ["math", "reasoning"], 8.0, "medium"),
    _e("limit of (1 + 1/n)^n as n -> inf", "math", ["math"], 3.0, "short"),
    _e("compute 7! / 3!", "math", ["math"], 1.0, "short"),
    _e("how many ways to arrange the letters in MISSISSIPPI", "math", ["math"], 3.0, "short"),
    _e("explain the central limit theorem with intuition", "math", ["math", "knowledge", "instruction"], 8.0, "medium"),
    _e("derive the quadratic formula", "math", ["math", "reasoning"], 7.0, "medium"),
    _e("matrix multiplication: [[1,2],[3,4]] times [[5,6],[7,8]]", "math", ["math"], 3.0, "short"),
    _e("find roots of 2x^3 - 9x^2 + 12x - 4 = 0", "math", ["math", "reasoning"], 30.0, "medium"),
    _e("compute the gradient of f(x,y) = x^2 y + sin(xy)", "math", ["math"], 8.0, "short"),
    _e("explain Cauchy-Schwarz inequality intuitively", "math", ["math", "instruction"], 30.0, "medium"),
    _e("solve the recurrence T(n) = 2 T(n/2) + n with master theorem", "math", ["math", "reasoning"], 8.0, "medium"),
    _e("what is the cardinality of the power set of {a, b, c, d}", "math", ["math"], 1.0, "short"),
    _e("integrate by parts: integral of ln(x) dx", "math", ["math"], 3.0, "short"),
    _e("monty hall problem - explain why switching is better", "math", ["math", "reasoning", "instruction"], 7.0, "medium"),
    _e("which is bigger: e^pi or pi^e", "math", ["math", "reasoning"], 30.0, "short"),
    _e("derive the formula for the sum of an infinite geometric series", "math", ["math", "reasoning"], 7.0, "medium"),
    _e("matrix rank of [[1,2,3],[2,4,6],[1,1,1]]", "math", ["math"], 7.0, "short"),
    _e("expand (1 + x)^5 using the binomial theorem", "math", ["math"], 3.0, "short"),
    _e("under what conditions does fixed-point iteration converge", "math", ["math", "reasoning"], 30.0, "medium"),
    _e("Bayes update: P(disease)=0.001, sensitivity 0.99, specificity 0.95, what's P(disease | positive)", "math", ["math", "reasoning"], 8.0, "medium"),
    _e("what's the kernel of the linear map T(x,y,z) = (x+y, y+z)", "math", ["math"], 30.0, "short"),
    _e("solve the differential equation y' + y = e^x", "math", ["math"], 8.0, "medium"),
    _e("explain the fundamental theorem of calculus", "math", ["math", "instruction"], 8.0, "medium"),
]
348
+
349
# Hand-written seed queries whose primary label is "reasoning".
# Each _e(...) row appears to carry: query text, primary category, capability
# tags, a numeric difficulty value, and a length bucket — _e is defined above
# this chunk; confirm the field meanings there.
_REASONING: list[SeedEntry] = [
    _e("why does scaling laws matter for LLM development", "reasoning", ["reasoning", "knowledge"], 30.0, "medium"),
    _e("compare the trade-offs of postgres vs dynamodb for an event store", "reasoning", ["reasoning", "knowledge"], 30.0, "long"),
    _e("why might a microservice architecture hurt a 10-engineer team", "reasoning", ["reasoning"], 8.0, "medium"),
    _e("what's the failure mode of using exponential backoff without jitter", "reasoning", ["reasoning", "knowledge"], 8.0, "medium"),
    _e("argue both sides of remote vs in-office for early-stage startups", "reasoning", ["reasoning", "instruction"], 8.0, "long"),
    _e("if my page load is slow but TTFB is fast, what's the likely cause", "reasoning", ["reasoning"], 7.0, "medium"),
    _e("walk me through how you'd debug a memory leak in a long-running node process", "reasoning", ["reasoning", "instruction"], 8.0, "long"),
    _e("compare optimistic and pessimistic concurrency control", "reasoning", ["reasoning", "knowledge"], 8.0, "medium"),
    _e("why is two-phase commit considered a poor primitive in modern distributed systems", "reasoning", ["reasoning", "knowledge"], 30.0, "long"),
    _e("when should you prefer SSE over websockets for a real-time feed", "reasoning", ["reasoning", "knowledge"], 8.0, "medium"),
    _e("steel-man the case against test-driven development", "reasoning", ["reasoning"], 8.0, "medium"),
    _e("compare functional programming and OOP for modeling a payments domain", "reasoning", ["reasoning"], 8.0, "long"),
    _e("evaluate the trade-offs of using GraphQL over REST for a mobile app", "reasoning", ["reasoning"], 8.0, "long"),
    _e("which is more useful for a startup: net dollar retention or activation rate", "reasoning", ["reasoning"], 8.0, "short"),
    _e("how would you decide whether to migrate from MySQL to Postgres", "reasoning", ["reasoning", "instruction"], 8.0, "long"),
    _e("when is event sourcing worth the complexity", "reasoning", ["reasoning", "knowledge"], 30.0, "medium"),
    _e("if our latency is fine but error budget is burning, where do you look first", "reasoning", ["reasoning"], 7.0, "medium"),
    _e("compare batch and streaming pipelines for fraud detection", "reasoning", ["reasoning"], 30.0, "long"),
    _e("walk through the classic prisoner's dilemma and its iterated form", "reasoning", ["reasoning", "knowledge"], 7.0, "medium"),
    _e("argue why YAGNI sometimes leads to expensive rewrites", "reasoning", ["reasoning"], 8.0, "medium"),
    _e("if my classifier has high precision but low recall, what does that mean for the user", "reasoning", ["reasoning", "knowledge"], 7.0, "medium"),
    _e("evaluate the claim 'AI will replace junior developers within 5 years'", "reasoning", ["reasoning"], 8.0, "long"),
    _e("when should sharding precede vertical scaling for a postgres workload", "reasoning", ["reasoning"], 8.0, "medium"),
    _e("explain why eventual consistency is acceptable for like counts but not bank balances", "reasoning", ["reasoning", "instruction"], 8.0, "medium"),
    _e("compare risk profiles of monolith vs microservices for a 3-person team", "reasoning", ["reasoning"], 8.0, "long"),
    _e("why might a 99.9% uptime SLA actually be expensive", "reasoning", ["reasoning"], 7.0, "medium"),
    _e("argue the side that says feature flags are technical debt", "reasoning", ["reasoning"], 7.0, "medium"),
    _e("trace through the implications of removing rate limits on a public API", "reasoning", ["reasoning"], 7.0, "medium"),
    _e("when would you choose a graph database over a relational one", "reasoning", ["reasoning", "knowledge"], 8.0, "medium"),
    _e("rebut the claim that 'tabs are better than spaces because of accessibility'", "reasoning", ["reasoning"], 7.0, "medium"),
    _e("you have an outage with no logs, walk me through your first 10 minutes", "reasoning", ["reasoning", "instruction"], 8.0, "long"),
    _e("compare CAP theorem trade-offs in cassandra vs cockroachdb", "reasoning", ["reasoning", "knowledge"], 30.0, "long"),
    _e("evaluate whether react server components are the right call for a content site", "reasoning", ["reasoning"], 8.0, "medium"),
    _e("why is pursuing 100% test coverage often a mistake", "reasoning", ["reasoning"], 7.0, "medium"),
    _e("argue both sides: should we adopt typescript for our 5-year-old js codebase", "reasoning", ["reasoning"], 8.0, "long"),
    _e("how would you identify whether AI-generated commits are sneaking past review", "reasoning", ["reasoning", "instruction"], 30.0, "medium"),
    _e("trade-offs between BERT-style and GPT-style models for classification", "reasoning", ["reasoning", "knowledge"], 30.0, "medium"),
    _e("when does fine-tuning beat retrieval augmentation, and vice versa", "reasoning", ["reasoning", "knowledge"], 30.0, "long"),
    _e("compare scrum vs kanban for a 4-person engineering team with rotating priorities", "reasoning", ["reasoning"], 8.0, "medium"),
    _e("if a P95 latency is 200ms but P99 is 8 seconds, what's likely going on", "reasoning", ["reasoning"], 7.0, "medium"),
    _e("walk me through how you'd evaluate two competing offers from acquirers", "reasoning", ["reasoning", "instruction"], 30.0, "long"),
    _e("explain why a/b tests can lie if you peek at results too early", "reasoning", ["reasoning", "math"], 8.0, "medium"),
    _e("compare the maintenance burden of a CI based on github actions vs a self-hosted runner", "reasoning", ["reasoning"], 7.0, "medium"),
    _e("when should you stop optimizing and ship", "reasoning", ["reasoning"], 7.0, "short"),
    _e("argue both sides of using mock servers vs hitting staging", "reasoning", ["reasoning"], 7.0, "medium"),
    _e("you suspect a vendor outage but their status page is green - now what", "reasoning", ["reasoning", "instruction"], 7.0, "medium"),
    _e("evaluate the trade-offs of an open core licensing model", "reasoning", ["reasoning"], 30.0, "long"),
    _e("when does retrying make a transient failure permanent", "reasoning", ["reasoning"], 8.0, "medium"),
    _e("if my cache hit ratio drops 30% on a Tuesday afternoon what should I check first", "reasoning", ["reasoning"], 7.0, "short"),
    _e("argue whether engineering managers should still write code", "reasoning", ["reasoning"], 8.0, "medium"),
    _e("compare two approaches: batched embedding vs streaming embedding for a 1B-row corpus", "reasoning", ["reasoning"], 30.0, "long"),
    _e("walk through how you'd estimate the cost of running a 70B model at 100 RPS", "reasoning", ["reasoning", "math"], 30.0, "long"),
    _e("evaluate the claim 'serverless is always cheaper'", "reasoning", ["reasoning"], 8.0, "medium"),
    _e("how should I prioritize tech debt vs feature work after a successful launch", "reasoning", ["reasoning", "instruction"], 8.0, "medium"),
    _e("argue whether using a vector database is overkill for 10k documents", "reasoning", ["reasoning"], 8.0, "medium"),
    _e("when is bayesian a/b testing better than frequentist", "reasoning", ["reasoning", "math"], 30.0, "medium"),
    _e("walk me through what to look for in the postmortem of a security incident", "reasoning", ["reasoning", "instruction"], 8.0, "long"),
    _e("when should you choose fully managed kafka over msk over self-hosted", "reasoning", ["reasoning"], 30.0, "long"),
]
409
+
410
# Hand-written seed queries whose primary label is "creative".
# Each _e(...) row appears to carry: query text, primary category, capability
# tags, a numeric difficulty value, and a length bucket — _e is defined above
# this chunk; confirm the field meanings there.
_CREATIVE: list[SeedEntry] = [
    _e("write a haiku about a server room at 3am", "creative", ["creative"], 1.0, "short"),
    _e("write a 6-word story about regret", "creative", ["creative"], 1.0, "short"),
    _e("compose a short poem about endless meetings", "creative", ["creative"], 3.0, "short"),
    _e("invent a backstory for a wandering robot bartender", "creative", ["creative"], 7.0, "medium"),
    _e("write the opening paragraph of a noir detective story set on Mars", "creative", ["creative"], 8.0, "medium"),
    _e("draft lyrics for a folk song about dial-up internet", "creative", ["creative"], 7.0, "medium"),
    _e("describe a city that exists only when no one is watching", "creative", ["creative"], 8.0, "medium"),
    _e("write a sonnet about the comfort of routine", "creative", ["creative"], 8.0, "medium"),
    _e("invent three names for a fictional indie band that plays cybernetic shoegaze", "creative", ["creative"], 3.0, "short"),
    _e("write a bedtime story about a dragon who can't breathe fire", "creative", ["creative"], 7.0, "long"),
    _e("micro-fiction: a 100-word story about waking up in someone else's house", "creative", ["creative"], 7.0, "medium"),
    _e("write the dialog for a job interview between a wizard and a human resources manager", "creative", ["creative"], 7.0, "long"),
    _e("compose a love letter from a satellite to the moon", "creative", ["creative"], 7.0, "medium"),
    _e("write a metaphor for how it feels to debug a heisenbug", "creative", ["creative"], 3.0, "short"),
    _e("invent a folk legend explaining why coffee tastes bitter", "creative", ["creative"], 7.0, "medium"),
    _e("write three first lines of three different novels in different genres", "creative", ["creative"], 7.0, "short"),
    _e("describe the smell of a bookstore using only verbs", "creative", ["creative"], 7.0, "short"),
    _e("write a drinking song for accountants", "creative", ["creative"], 7.0, "medium"),
    _e("compose a limerick about cloud providers", "creative", ["creative"], 3.0, "short"),
    _e("draft a children's rhyme that explains binary numbers", "creative", ["creative", "math"], 7.0, "medium"),
    _e("write a story where the antagonist is a benign software bug", "creative", ["creative"], 8.0, "long"),
    _e("write a one-act scene set in a coffeeshop where two strangers realize they share a secret", "creative", ["creative"], 8.0, "long"),
    _e("invent a magical creature whose only ability is mild administrative inconvenience", "creative", ["creative"], 7.0, "medium"),
    _e("write a marketing tagline for a fictional time-travel agency", "creative", ["creative"], 3.0, "short"),
    _e("write a journal entry from someone who just discovered electricity", "creative", ["creative"], 7.0, "medium"),
    _e("write a horror story in 50 words", "creative", ["creative"], 7.0, "short"),
    _e("describe an old photograph from the perspective of the cat in the corner", "creative", ["creative"], 7.0, "medium"),
    _e("write the inner monologue of an autonomous vacuum cleaner having an existential crisis", "creative", ["creative"], 7.0, "long"),
    _e("compose a ballad about an open-source maintainer who quietly disappears", "creative", ["creative"], 8.0, "long"),
    _e("write a fairy tale that ends with a bug ticket being marked WONTFIX", "creative", ["creative"], 8.0, "long"),
    _e("describe the sound of a forgotten song using only food metaphors", "creative", ["creative"], 7.0, "short"),
    _e("write a 4-line poem about the sea but only using words a four-year-old would know", "creative", ["creative"], 7.0, "short"),
    _e("invent a holiday celebrated by sysadmins", "creative", ["creative"], 7.0, "medium"),
    _e("write a tense scene set inside a data center during a power loss", "creative", ["creative"], 8.0, "long"),
    _e("write a recipe for nostalgia, in cookbook style", "creative", ["creative"], 7.0, "medium"),
    _e("write the press release a future archaeologist might publish about us", "creative", ["creative"], 8.0, "medium"),
    _e("write a cover letter from someone applying to be a household ghost", "creative", ["creative"], 7.0, "medium"),
    _e("invent a dialect spoken only at 4am, give five example phrases", "creative", ["creative"], 8.0, "medium"),
    _e("write a short eulogy for a departed feature flag", "creative", ["creative"], 7.0, "short"),
    _e("draft the letter that a Roomba would write to its replacement", "creative", ["creative"], 7.0, "medium"),
]
452
+
453
# Hand-written seed queries whose primary label is "multilingual" (translation
# and language-usage requests).  Each _e(...) row appears to carry: query text,
# primary category, capability tags, a numeric difficulty value, and a length
# bucket — _e is defined above this chunk; confirm the field meanings there.
_MULTILINGUAL: list[SeedEntry] = [
    _e("translate to french: 'the data center is running at 80% capacity tonight'", "multilingual", ["multilingual", "instruction"], 3.0, "short"),
    _e("translate to spanish: 'we are running out of free disk space'", "multilingual", ["multilingual", "instruction"], 1.0, "short"),
    _e("translate to german: 'please confirm receipt of this email'", "multilingual", ["multilingual", "instruction"], 1.0, "short"),
    _e("translate to japanese: 'thanks for your patience while we investigate'", "multilingual", ["multilingual", "instruction"], 7.0, "short"),
    _e("how do you say 'good evening' in italian", "multilingual", ["multilingual"], 0.5, "short"),
    _e("translate this korean sentence to english: 안녕하세요, 잘 부탁드립니다", "multilingual", ["multilingual"], 7.0, "short"),
    _e("translate to mandarin: 'happy new year, may your servers stay up'", "multilingual", ["multilingual", "creative"], 8.0, "short"),
    _e("provide the russian word for 'breakfast'", "multilingual", ["multilingual"], 1.0, "short"),
    _e("translate the following news headline to portuguese", "multilingual", ["multilingual", "instruction"], 3.0, "short"),
    _e("turn this english email into formal japanese keigo", "multilingual", ["multilingual", "instruction"], 30.0, "medium"),
    _e("rewrite this paragraph in plain french", "multilingual", ["multilingual", "instruction"], 7.0, "medium"),
    _e("write a polite arabic phrase to ask for directions", "multilingual", ["multilingual"], 7.0, "short"),
    _e("how do you conjugate 'hablar' in the spanish past tense", "multilingual", ["multilingual", "knowledge"], 3.0, "short"),
    _e("translate to swedish: 'I would like a coffee, please'", "multilingual", ["multilingual"], 1.0, "short"),
    _e("explain the difference between 'tu' and 'usted' in spanish", "multilingual", ["multilingual", "knowledge"], 3.0, "short"),
    _e("write a polite goodbye in tamil", "multilingual", ["multilingual"], 8.0, "short"),
    _e("translate from french to english: 'on n'est pas sortis de l'auberge'", "multilingual", ["multilingual"], 8.0, "short"),
    _e("how does verb agreement work in zulu", "multilingual", ["multilingual", "knowledge"], 30.0, "medium"),
    _e("translate to icelandic: 'the volcano is active again'", "multilingual", ["multilingual"], 30.0, "short"),
    _e("compose a 4-line haiku in japanese", "multilingual", ["multilingual", "creative"], 30.0, "short"),
    _e("turn this casual english into respectful korean", "multilingual", ["multilingual", "instruction"], 30.0, "medium"),
    _e("provide the cyrillic transliteration of 'санкт-петербург'", "multilingual", ["multilingual"], 7.0, "short"),
    _e("translate the customer support reply below into spanish, neutral register", "multilingual", ["multilingual", "instruction"], 7.0, "medium"),
    _e("translate this technical paragraph about kubernetes into french", "multilingual", ["multilingual", "code"], 30.0, "medium"),
    _e("how would you politely decline a dinner invitation in japanese", "multilingual", ["multilingual", "instruction"], 8.0, "short"),
    _e("write the same sentence in present, past, and future tenses in italian", "multilingual", ["multilingual"], 7.0, "short"),
    _e("explain how case markers work in finnish", "multilingual", ["multilingual", "knowledge"], 30.0, "long"),
    _e("translate to dutch: 'the meeting has been pushed to thursday'", "multilingual", ["multilingual"], 1.0, "short"),
    _e("write a short greeting in vietnamese", "multilingual", ["multilingual"], 3.0, "short"),
    _e("translate to portuguese: 'we'll need to roll back the deploy'", "multilingual", ["multilingual"], 7.0, "short"),
]
485
+
486
# Cross-category seed queries: each entry keeps a single primary label but
# lists several capability tags, exercising the classifier on blended prompts.
# Field meanings as in the other seed lists — confirm against _e, defined
# above this chunk.
_MIXED: list[SeedEntry] = [
    _e("write a python function that computes the nth fibonacci number recursively, with memoization", "code", ["code", "math"], 3.0, "medium"),
    _e("solve this leetcode-style problem: find the longest substring without repeating chars in O(n)", "code", ["code", "math", "reasoning"], 8.0, "long"),
    _e("write SQL to compute month-over-month revenue growth as a percentage", "code", ["code", "math"], 7.0, "medium"),
    _e("explain why merge sort is O(n log n) and write it in python", "code", ["code", "math", "knowledge"], 8.0, "long"),
    _e("benchmark these two python implementations and explain which is faster and why", "code", ["code", "reasoning"], 8.0, "long"),
    _e("write a sql query that returns user retention by week-of-signup cohort", "code", ["code", "math", "reasoning"], 8.0, "medium"),
    _e("translate this python function to rust idiomatically", "code", ["code", "multilingual"], 8.0, "medium"),
    _e("explain how AES encryption works at a high level and where attacks are possible", "knowledge", ["knowledge", "reasoning"], 30.0, "long"),
    _e("compare median and mean for income data and explain when each is misleading", "math", ["math", "knowledge", "reasoning"], 8.0, "medium"),
    _e("walk me through how a hash map works internally with code", "code", ["code", "knowledge", "instruction"], 8.0, "long"),
    _e("write a creative short story where the protagonist solves a math puzzle to escape", "creative", ["creative", "math"], 30.0, "long"),
    _e("explain how P vs NP would matter to a software engineer in plain language", "knowledge", ["knowledge", "reasoning", "instruction"], 30.0, "long"),
    _e("write a haiku in french about kubernetes", "creative", ["creative", "multilingual", "code"], 30.0, "short"),
    _e("translate this stack trace error message to spanish and explain what's wrong", "code", ["code", "multilingual", "reasoning"], 8.0, "medium"),
    _e("the sales team needs a one-paragraph explanation of how our embedding model works", "instruction", ["instruction", "knowledge", "reasoning"], 8.0, "medium"),
    _e("derive big-O of this recursive function and rewrite it iteratively", "code", ["code", "math", "reasoning"], 8.0, "medium"),
    _e("write python to fit a logistic regression and explain what the coefficients mean", "code", ["code", "math", "instruction"], 8.0, "long"),
    _e("describe the philosophy of stoicism and apply one of its principles to a manager-employee disagreement", "knowledge", ["knowledge", "reasoning", "creative"], 8.0, "medium"),
    _e("write the SQL to detect duplicate rows and an explanation of why they likely happened", "code", ["code", "reasoning"], 7.0, "medium"),
    _e("translate the kafka error message below into a debug action plan", "code", ["code", "reasoning", "instruction"], 8.0, "medium"),
    _e("write a 4-bullet executive summary of how OAuth2 PKCE flow works", "instruction", ["instruction", "knowledge"], 8.0, "medium"),
    _e("model the expected cost of running 10000 daily LLM queries on three providers", "math", ["math", "reasoning", "code"], 8.0, "long"),
    _e("compose a poem in spanish about regrets, with an english translation", "creative", ["creative", "multilingual"], 30.0, "medium"),
    _e("explain why this regex captures the wrong thing and propose a fix", "code", ["code", "reasoning"], 7.0, "medium"),
    _e("write code that uses dijkstra's algorithm and explain the heap invariants", "code", ["code", "math", "knowledge"], 30.0, "long"),
    _e("estimate the cost in carbon emissions of training a 7B model on a million tokens", "math", ["math", "knowledge", "reasoning"], 30.0, "medium"),
    _e("for the system below, identify the bottleneck and suggest two architectural fixes", "reasoning", ["reasoning", "code"], 30.0, "long"),
    _e("write the abstract for a paper on retrieval-augmented generation, in academic style", "creative", ["creative", "knowledge", "instruction"], 30.0, "medium"),
    _e("explain why eventual consistency causes user-visible bugs in messaging apps", "reasoning", ["reasoning", "knowledge", "instruction"], 8.0, "medium"),
    _e("write code for k-means clustering from scratch, then describe how it can fail to converge", "code", ["code", "math", "reasoning"], 30.0, "long"),
    _e("draft a polite french email asking a vendor to lower their pricing by 12%", "instruction", ["instruction", "multilingual"], 8.0, "medium"),
    _e("estimate how many tokens we'd need to fine-tune a 7B model to a domain", "math", ["math", "knowledge", "reasoning"], 30.0, "medium"),
    _e("compare the energy cost of inference between a 7B and a 70B model for the same query", "reasoning", ["reasoning", "math", "knowledge"], 30.0, "medium"),
    _e("translate this italian opera lyric and explain its symbolism", "creative", ["creative", "multilingual", "knowledge"], 30.0, "long"),
    _e("write a python script that downloads a dataset and reports its label distribution", "code", ["code", "math", "instruction"], 7.0, "medium"),
    _e("write a clear bug report from this user's incoherent description", "instruction", ["instruction", "reasoning"], 7.0, "medium"),
    _e("walk me through using bayes' theorem to update on a positive medical test, with code", "math", ["math", "code", "reasoning"], 8.0, "long"),
]
525
+
526
+
527
# Flat pool of every hand-written seed query, in category-group order.
SEED_QUERIES: list[SeedEntry] = [
    *_SIMPLE_CHAT,
    *_INSTRUCTION,
    *_KNOWLEDGE,
    *_CODE,
    *_MATH,
    *_REASONING,
    *_CREATIVE,
    *_MULTILINGUAL,
    *_MIXED,
]
538
+
539
+
540
def seed_capability_dict(entry: SeedEntry, all_keys: tuple[str, ...]) -> dict[str, float]:
    """Expand an entry's capability tags into a dense 0.0/1.0 mapping over all_keys.

    Keys present in entry.capabilities map to 1.0, all others to 0.0.
    """
    tagged = set(entry.capabilities)
    return {key: float(key in tagged) for key in all_keys}
542
+
543
+
544
def difficulty_log_params_from_b(difficulty_b: float) -> float:
    """Return log of a parameter count derived from *difficulty_b*.

    The input is multiplied by 1e9 (i.e. treated as billions, per the "_b"
    suffix) and floored at 0.1 so the logarithm is always defined.
    """
    clamped_b = max(difficulty_b, 0.1)
    return math.log(clamped_b * 1e9)
greenrouting/data/sources.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Source loaders. Each returns RawQuery records with weak source-category priors.
2
+
3
+ Datasets are downloaded lazily from HuggingFace. License notes are documented in the
4
+ README; all sources here are PolyForm-Noncommercial compatible.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import hashlib
10
+ import random
11
+ from dataclasses import dataclass
12
+ from typing import Callable
13
+
14
+ from greenrouting.data.schema import RawQuery
15
+
16
+
17
@dataclass
class SourceSpec:
    """Static description of one upstream dataset and how to load it."""

    name: str  # registry key; also the id prefix for queries from this source
    hf_path: str  # HuggingFace dataset path passed to load_dataset
    hf_config: str | None  # dataset config name, when the dataset has one
    hf_split: str  # split to draw samples from
    category_prior: str  # weak source-level capability category attached to each query
    has_grader: bool  # whether an automatic grader exists for this source
    loader: Callable[["SourceSpec", int, int], list[RawQuery]]  # (spec, n, seed) -> queries
26
+
27
+
28
+ def _hash_id(source: str, text: str) -> str:
29
+ h = hashlib.sha1(f"{source}::{text}".encode("utf-8")).hexdigest()[:16]
30
+ return f"{source}-{h}"
31
+
32
+
33
+ def _take_random(items: list, n: int, seed: int) -> list:
34
+ rng = random.Random(seed)
35
+ if n >= len(items):
36
+ return items
37
+ return rng.sample(items, n)
38
+
39
+
40
def _load_gsm8k(spec: "SourceSpec", n: int, seed: int) -> list[RawQuery]:
    """GSM8K word problems; the gold final answer follows '####' in the solution."""
    from datasets import load_dataset

    records = list(load_dataset(spec.hf_path, spec.hf_config, split=spec.hf_split))
    return [
        RawQuery(
            id=_hash_id(spec.name, row["question"]),
            text=row["question"],
            source=spec.name,
            source_category=spec.category_prior,
            has_grader=True,
            grader_metadata={
                "gold_final": row.get("answer", "").split("####")[-1].strip(),
                "grader": "exact_numeric",
            },
        )
        for row in _take_random(records, n, seed)
    ]
58
+
59
+
60
def _load_humaneval(spec: "SourceSpec", n: int, seed: int) -> list[RawQuery]:
    """HumanEval function-completion prompts; graded by executing the unit tests."""
    from datasets import load_dataset

    records = list(load_dataset(spec.hf_path, split=spec.hf_split))
    return [
        RawQuery(
            id=_hash_id(spec.name, row["prompt"]),
            text=row["prompt"],
            source=spec.name,
            source_category=spec.category_prior,
            has_grader=True,
            grader_metadata={
                "test": row.get("test", ""),
                "entry_point": row.get("entry_point", ""),
                "grader": "code_exec",
            },
        )
        for row in _take_random(records, n, seed)
    ]
81
+
82
+
83
def _load_mbpp(spec: "SourceSpec", n: int, seed: int) -> list[RawQuery]:
    """MBPP programming problems; graded by executing the provided test list.

    Fix: use `spec.hf_config` instead of the hard-coded "sanitized" config so
    the SourceSpec field is authoritative and the loader matches its siblings.
    The registry already sets hf_config="sanitized", so default behavior is
    unchanged.
    """
    from datasets import load_dataset
    ds = load_dataset(spec.hf_path, spec.hf_config, split=spec.hf_split)
    rows = list(ds)
    sampled = _take_random(rows, n, seed)
    out: list[RawQuery] = []
    for r in sampled:
        # The sanitized config names the field "prompt"; the classic one uses "text".
        prompt = r.get("prompt") or r.get("text", "")
        out.append(RawQuery(
            id=_hash_id(spec.name, prompt),
            text=prompt,
            source=spec.name,
            source_category=spec.category_prior,
            has_grader=True,
            grader_metadata={
                "test_list": r.get("test_list", []),
                "grader": "code_exec",
            },
        ))
    return out
103
+
104
+
105
def _load_arc(spec: "SourceSpec", n: int, seed: int) -> list[RawQuery]:
    """ARC multiple-choice science questions, rendered as letter-answer prompts."""
    from datasets import load_dataset

    records = list(load_dataset(spec.hf_path, spec.hf_config, split=spec.hf_split))
    out: list[RawQuery] = []
    for row in _take_random(records, n, seed):
        option_lines = "\n".join(
            f"({lab}) {ch}"
            for lab, ch in zip(row["choices"]["label"], row["choices"]["text"])
        )
        prompt = row["question"] + "\n" + option_lines + "\nAnswer with the letter only."
        out.append(RawQuery(
            id=_hash_id(spec.name, prompt),
            text=prompt,
            source=spec.name,
            source_category=spec.category_prior,
            has_grader=True,
            grader_metadata={"gold_letter": row["answerKey"], "grader": "multichoice"},
        ))
    return out
130
+
131
+
132
def _load_bbh(spec: "SourceSpec", n: int, seed: int) -> list[RawQuery]:
    """BIG-Bench Hard prompts; graded by string match against the target answer."""
    from datasets import load_dataset

    records = list(load_dataset(spec.hf_path, spec.hf_config, split=spec.hf_split))
    return [
        RawQuery(
            id=_hash_id(spec.name, row["input"]),
            text=row["input"],
            source=spec.name,
            source_category=spec.category_prior,
            has_grader=True,
            grader_metadata={"gold": row.get("target", ""), "grader": "string_match"},
        )
        for row in _take_random(records, n, seed)
    ]
150
+
151
+
152
def _load_mmlu(spec: "SourceSpec", n: int, seed: int) -> list[RawQuery]:
    """MMLU multiple-choice questions rendered as letter-answer prompts.

    Fix: honor `spec.hf_config` instead of hard-coding "all", so the registry's
    config field is authoritative and the loader matches its siblings. The
    registry already sets hf_config="all", so default behavior is unchanged.
    """
    from datasets import load_dataset
    ds = load_dataset(spec.hf_path, spec.hf_config, split=spec.hf_split)
    rows = list(ds)
    sampled = _take_random(rows, n, seed)
    out: list[RawQuery] = []
    for r in sampled:
        choices = r["choices"]
        question = r["question"]
        # Options are lettered A, B, C, ... by position in the choices list.
        formatted = (
            question + "\n"
            + "\n".join(f"({chr(65+i)}) {c}" for i, c in enumerate(choices))
            + "\nAnswer with the letter only."
        )
        gold_idx = int(r["answer"])
        out.append(RawQuery(
            id=_hash_id(spec.name, formatted),
            text=formatted,
            source=spec.name,
            source_category=spec.category_prior,
            has_grader=True,
            grader_metadata={"gold_letter": chr(65 + gold_idx), "grader": "multichoice"},
        ))
    return out
176
+
177
+
178
def _load_truthfulqa(spec: "SourceSpec", n: int, seed: int) -> list[RawQuery]:
    """TruthfulQA generation questions; no automatic grader, reference answers kept.

    Fix: honor `spec.hf_config` instead of hard-coding "generation", so the
    registry's config field is authoritative and the loader matches its
    siblings. The registry already sets hf_config="generation", so default
    behavior is unchanged.
    """
    from datasets import load_dataset
    ds = load_dataset(spec.hf_path, spec.hf_config, split=spec.hf_split)
    rows = list(ds)
    sampled = _take_random(rows, n, seed)
    out: list[RawQuery] = []
    for r in sampled:
        text = r["question"]
        out.append(RawQuery(
            id=_hash_id(spec.name, text),
            text=text,
            source=spec.name,
            source_category=spec.category_prior,
            has_grader=False,
            # Reference answers are stored for later human/LLM evaluation.
            grader_metadata={"correct_answers": r.get("correct_answers", [])},
        ))
    return out
195
+
196
+
197
def _load_ifeval(spec: "SourceSpec", n: int, seed: int) -> list[RawQuery]:
    """IFEval instruction-following prompts; graded by programmatic constraint checks."""
    from datasets import load_dataset

    records = list(load_dataset(spec.hf_path, split=spec.hf_split))
    return [
        RawQuery(
            id=_hash_id(spec.name, row["prompt"]),
            text=row["prompt"],
            source=spec.name,
            source_category=spec.category_prior,
            has_grader=True,
            grader_metadata={
                "instruction_id_list": row.get("instruction_id_list", []),
                "kwargs": row.get("kwargs", []),
                "grader": "ifeval_constraints",
            },
        )
        for row in _take_random(records, n, seed)
    ]
218
+
219
+
220
def _load_dolly(spec: "SourceSpec", n: int, seed: int) -> list[RawQuery]:
    """Dolly-15k instructions (context-free rows only).

    Dolly's own category labels are mapped onto router capability categories
    where a sensible mapping exists; otherwise the spec prior is used.
    """
    from datasets import load_dataset

    category_map = {
        "open_qa": "knowledge",
        "general_qa": "knowledge",
        "classification": "instruction",
        "closed_qa": "knowledge",
        "brainstorming": "creative",
        "creative_writing": "creative",
        "summarization": "instruction",
        "information_extraction": "instruction",
    }
    records = [
        row for row in load_dataset(spec.hf_path, split=spec.hf_split)
        if row.get("instruction") and not row.get("context")
    ]
    return [
        RawQuery(
            id=_hash_id(spec.name, row["instruction"]),
            text=row["instruction"],
            source=spec.name,
            source_category=category_map.get(row.get("category", ""), spec.category_prior),
            has_grader=False,
        )
        for row in _take_random(records, n, seed)
    ]
246
+
247
+
248
def _load_oasst1(spec: "SourceSpec", n: int, seed: int) -> list[RawQuery]:
    """OASST1 conversation openers: English, prompter-role, thread-root messages only."""
    from datasets import load_dataset

    records = [
        row for row in load_dataset(spec.hf_path, split=spec.hf_split)
        if row.get("role") == "prompter" and row.get("lang") == "en" and row.get("parent_id") is None
    ]
    return [
        RawQuery(
            id=_hash_id(spec.name, row["text"]),
            text=row["text"],
            source=spec.name,
            source_category=spec.category_prior,
            has_grader=False,
        )
        for row in _take_random(records, n, seed)
    ]
263
+
264
+
265
# Single source of truth for loadable datasets. Keys match SourceSpec.name and
# are the identifiers accepted by load_source() / sample_mix().
SOURCE_REGISTRY: dict[str, SourceSpec] = {
    "gsm8k": SourceSpec(
        name="gsm8k", hf_path="gsm8k", hf_config="main", hf_split="train",
        category_prior="math", has_grader=True, loader=_load_gsm8k,
    ),
    "humaneval": SourceSpec(
        name="humaneval", hf_path="openai/openai_humaneval", hf_config=None, hf_split="test",
        category_prior="code", has_grader=True, loader=_load_humaneval,
    ),
    "mbpp": SourceSpec(
        name="mbpp", hf_path="google-research-datasets/mbpp", hf_config="sanitized", hf_split="train",
        category_prior="code", has_grader=True, loader=_load_mbpp,
    ),
    "arc": SourceSpec(
        name="arc", hf_path="allenai/ai2_arc", hf_config="ARC-Challenge", hf_split="train",
        category_prior="reasoning", has_grader=True, loader=_load_arc,
    ),
    "bbh": SourceSpec(
        name="bbh", hf_path="lukaemon/bbh", hf_config="logical_deduction_five_objects",
        hf_split="test", category_prior="reasoning", has_grader=True, loader=_load_bbh,
    ),
    "mmlu": SourceSpec(
        name="mmlu", hf_path="cais/mmlu", hf_config="all", hf_split="test",
        category_prior="knowledge", has_grader=True, loader=_load_mmlu,
    ),
    "truthfulqa": SourceSpec(
        name="truthfulqa", hf_path="truthful_qa", hf_config="generation", hf_split="validation",
        category_prior="knowledge", has_grader=False, loader=_load_truthfulqa,
    ),
    "ifeval": SourceSpec(
        name="ifeval", hf_path="HuggingFaceH4/ifeval", hf_config=None, hf_split="train",
        category_prior="instruction", has_grader=True, loader=_load_ifeval,
    ),
    "dolly": SourceSpec(
        name="dolly", hf_path="databricks/databricks-dolly-15k", hf_config=None, hf_split="train",
        category_prior="instruction", has_grader=False, loader=_load_dolly,
    ),
    "oasst1": SourceSpec(
        name="oasst1", hf_path="OpenAssistant/oasst1", hf_config=None, hf_split="train",
        category_prior="simple_chat", has_grader=False, loader=_load_oasst1,
    ),
}
307
+
308
+
309
def load_source(source_name: str, n: int, seed: int) -> list[RawQuery]:
    """Load up to *n* raw queries from a registered source.

    Raises KeyError with the known source names when *source_name* is unknown.
    """
    try:
        spec = SOURCE_REGISTRY[source_name]
    except KeyError:
        raise KeyError(f"unknown source {source_name}; known: {list(SOURCE_REGISTRY)}") from None
    return spec.loader(spec, n, seed)
314
+
315
+
316
def sample_mix(weights: dict[str, float], total: int, seed: int) -> list[RawQuery]:
    """Sample raw queries from each source according to the weight map.

    weights are normalized; per-source counts use the resulting fractions of `total`.

    Fix: the original rounded each share independently and dumped the remainder
    on the last key, so accumulated rounding could make the counts sum to more
    (or fewer) than `total`. Largest-remainder allocation guarantees the counts
    always sum to exactly `total`.

    Raises ValueError when the weights do not sum to a positive number.
    """
    if not weights:
        return []
    s = sum(weights.values())
    if s <= 0:
        raise ValueError("source weights must sum to a positive number")

    # Largest-remainder method: floor each proportional share, then hand the
    # leftover units to the sources with the biggest fractional remainders.
    shares = {k: total * w / s for k, w in weights.items()}
    counts = {k: int(v) for k, v in shares.items()}
    leftover = total - sum(counts.values())
    by_remainder = sorted(shares, key=lambda k: shares[k] - counts[k], reverse=True)
    for k in by_remainder[:leftover]:
        counts[k] += 1

    rng = random.Random(seed)
    queries: list[RawQuery] = []
    for src, cnt in counts.items():
        if cnt <= 0:
            continue
        # Independent sub-seed per source keeps each loader deterministic.
        sub_seed = rng.randint(0, 2**31 - 1)
        queries.extend(load_source(src, cnt, sub_seed))
    rng.shuffle(queries)
    return queries
greenrouting/demo/__init__.py ADDED
File without changes
greenrouting/demo/app.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Gradio interface for the router. Loads the trained classifier artifact when
2
+ present at `models/classifier_v1/`, otherwise falls back to the mock predictor."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import json
7
+ import os
8
+ from pathlib import Path
9
+ from typing import Optional
10
+
11
+ import gradio as gr
12
+
13
+ from greenrouting.classifier.infer import MockPredictor, Predictor, QueryProfile
14
+ from greenrouting.routing.decision import Decision, ObjectiveWeights, decide
15
+ from greenrouting.routing.registry import Registry, default_registry
16
+
17
+ DEFAULT_ARTIFACT_DIR = "models/classifier_v1"
18
+
19
+
20
def load_predictor(artifact_dir: Optional[str] = None) -> Predictor:
    """Return the trained predictor when its artifact exists, else the mock.

    The artifact directory is resolved in order: explicit argument, the
    GREENROUTING_ARTIFACT_DIR environment variable, then the packaged default.
    Any failure while loading the trained head degrades gracefully to the mock.
    """
    target = artifact_dir or os.environ.get("GREENROUTING_ARTIFACT_DIR") or DEFAULT_ARTIFACT_DIR
    if (Path(target) / "head.pt").exists():
        try:
            from greenrouting.classifier.trained_predictor import TrainedPredictor
            return TrainedPredictor(target)
        except Exception as e:
            print(f"[warn] failed to load trained predictor at {target}: {e}; using mock")
    return MockPredictor()
30
+
31
# Demo queries spanning the capability axes (code, math, knowledge, reasoning,
# creative, multilingual); the final two probe trivial chat and gibberish/OOD
# handling. Each entry is a one-element list, the shape gr.Examples expects.
EXAMPLES: list[list[str]] = [
    ["Write a Python function that reverses a linked list in place."],
    ["Solve the integral of x^2 sin(x) dx using integration by parts. Show all steps."],
    ["What is the capital of Mongolia and roughly how many people live there?"],
    ["Compare the trade-offs between optimistic and pessimistic concurrency control in databases."],
    ["Write a short haiku about a server room at 3am."],
    ["Translate to French: 'The data center is running at 80% capacity tonight.'"],
    ["hi"],
    ["asdfgh qwerty 12345"],
]
41
+
42
+
43
+ def _format_capabilities(profile: QueryProfile) -> str:
44
+ rows = []
45
+ for k, v in profile.capabilities.as_dict().items():
46
+ if v < 0.05:
47
+ continue
48
+ rows.append(f"**{k}** {v:.2f}")
49
+ return " · ".join(rows) if rows else "(no strong capability signal)"
50
+
51
+
52
def _format_savings_md(decision: Decision) -> str:
    """Markdown summary of the routing outcome vs. the always-frontier baseline."""
    s = decision.savings
    chosen = decision.chosen
    baseline = decision.baseline

    # Savings arrive as fractions (clamped to [-1, 1] by _compute_savings);
    # render them as signed percentages.
    energy_pct = s["energy_pct_saved"] * 100
    cost_pct = s["cost_pct_saved"] * 100
    latency_pct = s["latency_pct_saved"] * 100
    quality_delta = s["quality_delta"] * 100

    chosen_name = chosen.display_name
    baseline_name = baseline.display_name

    flag = " - escalated to safe default" if decision.escalated else ""

    return (
        f"### Routed to: **{chosen_name}**{flag}\n\n"
        f"Baseline (always-{baseline_name}):\n"
        f"- Energy: {baseline.energy_wh:.3f} Wh -> chosen {chosen.energy_wh:.3f} Wh "
        f"(**{energy_pct:+.1f}%** energy saved)\n"
        f"- Cost: ${baseline.cost_usd*1000:.4f} per 1k queries -> "
        f"${chosen.cost_usd*1000:.4f} (**{cost_pct:+.1f}%** cost saved)\n"
        f"- Latency: {baseline.latency_s:.2f}s -> {chosen.latency_s:.2f}s "
        f"(**{latency_pct:+.1f}%** faster)\n"
        f"- Quality fit: {baseline.quality:.3f} -> {chosen.quality:.3f} "
        f"({quality_delta:+.1f} pts on the capability-weighted benchmark blend)\n"
    )
79
+
80
+
81
def _format_profile_md(profile: QueryProfile) -> str:
    """Markdown block describing the classifier's predicted query profile."""
    caps = _format_capabilities(profile)
    length = ", ".join(f"{k} {v:.2f}" for k, v in profile.length_dist.items())
    # Flag out-of-distribution inputs inline next to the confidence score.
    ood = " (OOD flagged)" if profile.is_ood else ""
    return (
        f"**Capabilities:** {caps}\n\n"
        f"**Difficulty:** ~{profile.difficulty_params_b:.1f}B params equivalent · "
        f"**Confidence:** {profile.confidence:.2f}{ood}\n\n"
        f"**Length distribution:** {length}\n\n"
        f"**Expected tokens:** input {profile.expected_input_tokens} · "
        f"output P50 {profile.expected_output_tokens_p50} · P90 {profile.expected_output_tokens_p90}"
    )
93
+
94
+
95
+ def _candidates_table(decision: Decision) -> list[list]:
96
+ rows = []
97
+ sorted_candidates = sorted(
98
+ decision.candidates,
99
+ key=lambda c: (-c.qualifies, -c.quality + c.energy_wh * 0.0001),
100
+ )
101
+ for c in sorted_candidates:
102
+ rows.append([
103
+ "*" if c.model_id == decision.chosen.model_id else (
104
+ "+" if c.qualifies else "-"
105
+ ),
106
+ c.display_name,
107
+ f"{c.quality:.3f}",
108
+ f"{c.energy_wh:.3f}",
109
+ f"${c.cost_usd*1000:.4f}",
110
+ f"{c.latency_s:.2f}s",
111
+ ])
112
+ return rows
113
+
114
+
115
def build_interface(predictor: Optional[Predictor] = None, registry: Optional[Registry] = None) -> gr.Blocks:
    """Assemble the Gradio Blocks UI around a predictor and a model registry.

    Both collaborators default to the module-level factories so the demo can be
    launched with no arguments.
    """
    predictor = predictor or load_predictor()
    registry = registry or default_registry()

    def route(
        query: str,
        weight_quality: float,
        weight_energy: float,
        weight_cost: float,
        weight_latency: float,
        quality_floor_pct: float,
        frontier_id: str,
    ):
        # Single callback behind both the Route button and textbox submit.
        # Returns (savings md, profile md, table rows, audit JSON string).
        if not query or not query.strip():
            return ("_Enter a query above._", "", [], "{}")

        profile = predictor.predict(query)
        weights = ObjectiveWeights(
            quality=weight_quality,
            energy=weight_energy,
            cost=weight_cost,
            latency=weight_latency,
        )
        decision = decide(
            profile,
            registry,
            weights=weights,
            frontier_id=frontier_id,
            # Slider is 0-100; decide() expects a 0-1 ratio.
            quality_floor_ratio=quality_floor_pct / 100.0,
        )
        return (
            _format_savings_md(decision),
            _format_profile_md(profile),
            _candidates_table(decision),
            json.dumps(decision.audit(), indent=2),
        )

    with gr.Blocks(title="GreenRouting") as interface:
        # Surface which predictor backend is active in the page header.
        predictor_label = "Trained classifier" if predictor.__class__.__name__ == "TrainedPredictor" else "Mock predictor"
        gr.Markdown(
            "# GreenRouting\n"
            "Predict what an AI query needs, then route to the smallest model that can answer it. "
            "Compare energy, cost, and latency vs. always running the frontier model. \n"
            f"*Predictor: {predictor_label}*"
        )

        with gr.Row():
            with gr.Column(scale=3):
                query_in = gr.Textbox(
                    label="Query",
                    placeholder="Type or paste a query...",
                    lines=3,
                )
                gr.Examples(EXAMPLES, inputs=query_in, label="Try one")
            with gr.Column(scale=2):
                with gr.Accordion("Routing weights", open=False):
                    w_quality = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Quality weight")
                    w_energy = gr.Slider(0.0, 2.0, value=0.4, step=0.05, label="Energy weight")
                    w_cost = gr.Slider(0.0, 2.0, value=0.4, step=0.05, label="Cost weight")
                    w_latency = gr.Slider(0.0, 2.0, value=0.2, step=0.05, label="Latency weight")
                    floor_pct = gr.Slider(
                        0, 100, value=60, step=5,
                        label="Quality floor (% of frontier baseline)",
                    )
                    frontier_dropdown = gr.Dropdown(
                        choices=registry.ids(),
                        value="gpt-4o",
                        label="Frontier baseline",
                    )

        route_btn = gr.Button("Route", variant="primary")

        savings_md = gr.Markdown(label="Decision")
        profile_md = gr.Markdown(label="Predicted profile")

        with gr.Accordion("Candidate models", open=False):
            candidates_table = gr.Dataframe(
                headers=["", "Model", "Quality", "Energy (Wh)", "Cost / 1k", "Latency"],
                datatype=["str", "str", "str", "str", "str", "str"],
                interactive=False,
                wrap=True,
            )

        with gr.Accordion("Audit log", open=False):
            audit_json = gr.Code(label="Per-query audit", language="json")

        # Wire both trigger events to the same callback and output set.
        inputs = [query_in, w_quality, w_energy, w_cost, w_latency, floor_pct, frontier_dropdown]
        outputs = [savings_md, profile_md, candidates_table, audit_json]
        route_btn.click(route, inputs=inputs, outputs=outputs)
        query_in.submit(route, inputs=inputs, outputs=outputs)

    return interface
207
+
208
+
209
def main() -> None:
    """Build the demo UI and start the Gradio server."""
    interface = build_interface()
    # Bug fix: `Blocks.launch()` does not accept a `theme` keyword -- themes
    # are passed to the `gr.Blocks(...)` constructor instead, so the original
    # `interface.launch(theme=gr.themes.Soft())` raises TypeError on current
    # Gradio releases. Launch with defaults; to theme the app, pass
    # `theme=gr.themes.Soft()` where `gr.Blocks` is constructed.
    interface.launch()


if __name__ == "__main__":
    main()
greenrouting/energy/__init__.py ADDED
File without changes
greenrouting/energy/estimator.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Per-query estimation of energy, cost, and latency from a model profile."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from greenrouting.routing.registry import ModelProfile
6
+
7
+
8
def estimate_energy_wh(model: "ModelProfile", tokens_in: int, tokens_out: int) -> float:
    """Estimate per-query energy: fixed overhead plus linear prefill/decode terms."""
    profile = model.energy
    variable_wh = tokens_in * profile.prefill_wh_per_tok + tokens_out * profile.decode_wh_per_tok
    return profile.overhead_wh + variable_wh
11
+
12
+
13
def estimate_cost_usd(model: "ModelProfile", tokens_in: int, tokens_out: int) -> float:
    """Estimate per-query API cost from per-million-token input/output pricing."""
    pricing = model.cost
    input_cost = (tokens_in / 1_000_000) * pricing.input_per_mtok_usd
    output_cost = (tokens_out / 1_000_000) * pricing.output_per_mtok_usd
    return input_cost + output_cost
16
+
17
+
18
def estimate_latency_seconds(model: "ModelProfile", tokens_out: int) -> float:
    """Estimate wall time: first-token latency plus decode time at steady throughput.

    Throughput is floored at 1 token/sec to avoid division by zero or absurd
    estimates from bad profile data.
    """
    ttft_s = model.latency.first_token_ms / 1000.0
    throughput = max(model.latency.tokens_per_sec, 1.0)
    return ttft_s + tokens_out / throughput
greenrouting/routing/__init__.py ADDED
File without changes
greenrouting/routing/decision.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """End-to-end routing: classify -> score candidates -> pick -> build audit log."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Optional
7
+
8
+ from greenrouting.classifier.infer import QueryProfile
9
+ from greenrouting.routing.registry import Registry
10
+ from greenrouting.routing.scorer import CandidateScore, score_candidate
11
+
12
+
13
@dataclass
class ObjectiveWeights:
    """Relative importance of each routing objective; larger means it matters more."""

    quality: float = 1.0
    energy: float = 0.4
    cost: float = 0.4
    latency: float = 0.2

    def normalize(self) -> "ObjectiveWeights":
        """Return a copy scaled so the absolute weights sum to 1.

        An all-zero vector degenerates to quality-only so the router always
        has a well-defined objective.
        """
        components = (self.quality, self.energy, self.cost, self.latency)
        total = sum(abs(v) for v in components)
        if total == 0:
            return ObjectiveWeights(1.0, 0.0, 0.0, 0.0)
        scaled = tuple(v / total for v in components)
        return ObjectiveWeights(
            quality=scaled[0],
            energy=scaled[1],
            cost=scaled[2],
            latency=scaled[3],
        )
30
+
31
+
32
@dataclass
class Decision:
    """Full outcome of one routing call, with everything needed to audit it."""

    chosen: CandidateScore  # the model the router picked
    baseline: CandidateScore  # the always-frontier reference point
    candidates: list[CandidateScore]  # every scored model, qualifying or not
    quality_floor: float  # absolute quality threshold applied to candidates
    escalated: bool  # True when routing fell back to the safe default
    weights: ObjectiveWeights  # objective weights used for selection
    profile: QueryProfile  # classifier output the decision was based on
    savings: dict[str, float] = field(default_factory=dict)  # baseline-vs-chosen deltas
    note: str = ""  # human-readable one-liner explaining the outcome

    def audit(self) -> dict:
        """Serialize the decision into a JSON-friendly dict for logging/UI display."""
        return {
            "query": self.profile.raw_query,
            # Only report capabilities with meaningful signal (>= 0.05).
            "predicted_capabilities": {
                k: round(v, 3) for k, v in self.profile.capabilities.as_dict().items() if v >= 0.05
            },
            "predicted_difficulty_params_b": round(self.profile.difficulty_params_b, 2),
            "predicted_length_dist": {k: round(v, 3) for k, v in self.profile.length_dist.items()},
            "expected_tokens": {
                "input": self.profile.expected_input_tokens,
                "output_p50": self.profile.expected_output_tokens_p50,
                "output_p90": self.profile.expected_output_tokens_p90,
            },
            "confidence": round(self.profile.confidence, 3),
            "is_ood": self.profile.is_ood,
            "quality_floor": round(self.quality_floor, 4),
            "frontier_baseline": self.baseline.as_dict(),
            "chosen": self.chosen.as_dict(),
            "savings": {k: round(v, 4) for k, v in self.savings.items()},
            "candidates": [c.as_dict() for c in self.candidates],
            "qualifying_count": sum(1 for c in self.candidates if c.qualifies),
            "escalated_to_default": self.escalated,
            "weights": {
                "quality": self.weights.quality,
                "energy": self.weights.energy,
                "cost": self.weights.cost,
                "latency": self.weights.latency,
            },
            "note": self.note,
        }
74
+
75
+
76
+ def _normalize(values: list[float]) -> list[float]:
77
+ if not values:
78
+ return []
79
+ lo, hi = min(values), max(values)
80
+ if hi - lo < 1e-12:
81
+ return [0.0 for _ in values]
82
+ return [(v - lo) / (hi - lo) for v in values]
83
+
84
+
85
+ def _weighted_score(
86
+ candidate: CandidateScore,
87
+ norm_quality: float,
88
+ norm_energy: float,
89
+ norm_cost: float,
90
+ norm_latency: float,
91
+ weights: ObjectiveWeights,
92
+ ) -> float:
93
+ w = weights.normalize()
94
+ return (
95
+ w.quality * norm_quality
96
+ - w.energy * norm_energy
97
+ - w.cost * norm_cost
98
+ - w.latency * norm_latency
99
+ )
100
+
101
+
102
def decide(
    profile: QueryProfile,
    registry: Registry,
    weights: Optional[ObjectiveWeights] = None,
    frontier_id: str = "gpt-4o",
    safe_default_id: Optional[str] = None,
    quality_floor_ratio: float = 0.6,
) -> Decision:
    """Route one query: score every registry model and pick the best trade-off.

    Args:
        profile: classifier output for the query.
        registry: pool of candidate models.
        weights: objective weights; defaults to ObjectiveWeights().
        frontier_id: model used as the always-frontier baseline.
        safe_default_id: escalation target; defaults to the frontier model.
        quality_floor_ratio: fraction of the baseline's quality a candidate
            must reach to qualify.

    Returns:
        A Decision holding the chosen model, the baseline, all scored
        candidates, and the computed savings.
    """
    weights = weights or ObjectiveWeights()
    safe_default_id = safe_default_id or frontier_id

    # Baseline is scored with a zero floor so it always qualifies; the real
    # floor for everyone else is derived from its quality.
    frontier = registry.get(frontier_id)
    baseline = score_candidate(profile, frontier, quality_floor=0.0)
    quality_floor = quality_floor_ratio * baseline.quality

    candidates = [
        score_candidate(profile, m, quality_floor=quality_floor) for m in registry.all()
    ]

    qualifying = [c for c in candidates if c.qualifies]

    note = ""
    if profile.is_ood:
        # Out-of-distribution inputs bypass selection entirely.
        chosen_id = safe_default_id
        escalated = True
        note = "Out-of-distribution input; escalated to safe default."
    elif not qualifying:
        chosen_id = safe_default_id
        escalated = True
        note = "No model met the quality floor; escalated to safe default."
    else:
        # Normalize each objective column across the qualifying set, then pick
        # the candidate with the best weighted score.
        qualities = [c.quality for c in qualifying]
        energies = [c.energy_wh for c in qualifying]
        costs = [c.cost_usd for c in qualifying]
        latencies = [c.latency_s for c in qualifying]
        nq = _normalize(qualities)
        ne = _normalize(energies)
        nc = _normalize(costs)
        nl = _normalize(latencies)

        best_idx = 0
        best_score = float("-inf")
        for i, cand in enumerate(qualifying):
            s = _weighted_score(cand, nq[i], ne[i], nc[i], nl[i], weights)
            if s > best_score:
                best_score = s
                best_idx = i
        chosen_id = qualifying[best_idx].model_id
        escalated = False
        note = f"Selected {chosen_id} from {len(qualifying)} qualifying models."

    # Reuse the already-scored candidate; only re-score if the safe default
    # was somehow absent from the candidate list.
    chosen = next((c for c in candidates if c.model_id == chosen_id), None)
    if chosen is None:
        chosen = score_candidate(profile, registry.get(chosen_id), quality_floor=quality_floor)

    savings = _compute_savings(baseline, chosen)

    return Decision(
        chosen=chosen,
        baseline=baseline,
        candidates=candidates,
        quality_floor=quality_floor,
        escalated=escalated,
        weights=weights,
        profile=profile,
        savings=savings,
        note=note,
    )
170
+
171
+
172
+ def _compute_savings(baseline: CandidateScore, chosen: CandidateScore) -> dict[str, float]:
173
+ def pct(b: float, c: float) -> float:
174
+ if b <= 0:
175
+ return 0.0
176
+ return max(-1.0, min(1.0, (b - c) / b))
177
+
178
+ return {
179
+ "energy_wh_baseline": baseline.energy_wh,
180
+ "energy_wh_chosen": chosen.energy_wh,
181
+ "energy_pct_saved": pct(baseline.energy_wh, chosen.energy_wh),
182
+ "cost_usd_baseline": baseline.cost_usd,
183
+ "cost_usd_chosen": chosen.cost_usd,
184
+ "cost_pct_saved": pct(baseline.cost_usd, chosen.cost_usd),
185
+ "latency_s_baseline": baseline.latency_s,
186
+ "latency_s_chosen": chosen.latency_s,
187
+ "latency_pct_saved": pct(baseline.latency_s, chosen.latency_s),
188
+ "quality_baseline": baseline.quality,
189
+ "quality_chosen": chosen.quality,
190
+ "quality_delta": chosen.quality - baseline.quality,
191
+ }
greenrouting/routing/registry.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model registry: published facts about each candidate model in the pool.
2
+
3
+ Every numeric field carries a citation tag. `CITATIONS` at the bottom resolves the
4
+ tag to a full reference. The registry is the single source of truth for routing.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass, field
10
+ from typing import Iterable
11
+
12
# Closed set of capability axes used across the router (classifier outputs,
# seed tagging, and quality blending all share these keys).
CAPABILITY_KEYS: tuple[str, ...] = (
    "code",
    "math",
    "reasoning",
    "knowledge",
    "instruction",
    "creative",
    "multilingual",
    "simple_chat",
)

# Which registry benchmark keys contribute to each capability's quality blend.
CAPABILITY_BENCHMARKS: dict[str, tuple[str, ...]] = {
    "code": ("humaneval", "mbpp"),
    "math": ("gsm8k", "math"),
    "reasoning": ("bbh", "arc", "gpqa"),
    "knowledge": ("mmlu", "truthfulqa"),
    "instruction": ("ifeval", "mtbench"),
    "creative": ("mtbench",),
    "multilingual": ("mmlu_pro_multi",),
    "simple_chat": ("mtbench",),
}
33
+
34
+
35
@dataclass(frozen=True)
class BenchmarkScore:
    """A published benchmark result plus the citation tag that backs it."""

    score: float  # benchmark score (registry entries use 0-1 fractions)
    citation: str  # tag resolved by the module-level CITATIONS map


@dataclass(frozen=True)
class EnergyProfile:
    """Per-query energy model: fixed overhead plus linear per-token terms."""

    overhead_wh: float  # fixed Wh charged once per query
    prefill_wh_per_tok: float  # Wh per input (prefill) token
    decode_wh_per_tok: float  # Wh per output (decode) token
    citation: str


@dataclass(frozen=True)
class CostProfile:
    """API pricing in USD per million tokens."""

    input_per_mtok_usd: float
    output_per_mtok_usd: float
    citation: str


@dataclass(frozen=True)
class LatencyProfile:
    """Latency model: time to first token plus steady decode throughput."""

    first_token_ms: float  # time to first token, milliseconds
    tokens_per_sec: float  # steady-state decode throughput
    citation: str
61
+
62
+
63
@dataclass(frozen=True)
class ModelProfile:
    """Everything routing needs to know about one candidate model."""

    id: str  # stable registry key
    display_name: str  # human-readable name for the UI
    family: str  # vendor / model family label
    parameter_count_b: float  # total parameters in billions (estimate for closed models)
    benchmarks: dict[str, BenchmarkScore]  # benchmark key -> cited score
    energy: EnergyProfile
    cost: CostProfile
    latency: LatencyProfile
    is_open_weight: bool = False
    notes: str = ""

    def benchmark(self, key: str) -> float | None:
        """Return the bare score for *key*, or None when this model has no entry."""
        bs = self.benchmarks.get(key)
        return bs.score if bs is not None else None
79
+
80
+
81
def _bench(scores: dict[str, tuple[float, str]]) -> dict[str, BenchmarkScore]:
    """Expand {key: (score, citation)} shorthand into BenchmarkScore objects."""
    out: dict[str, BenchmarkScore] = {}
    for key, (score, citation) in scores.items():
        out[key] = BenchmarkScore(score=score, citation=citation)
    return out
83
+
84
+
85
def _energy_from_active_params(active_b: float, citation: str = "luccioni-2024") -> EnergyProfile:
    """Per-token energy scaled from measured Llama 7B/13B/30B/65B inference energy.

    Anchor: Llama-1-65B ≈ 1.3 Wh per ~200-token completion (Luccioni 2024),
    assuming energy is linear in active parameters within a family:

        decode_wh_per_tok ≈ 1.0e-4 × active_params_in_billions

    Prefill is taken as 0.35× decode (cache-amortized, parallel), and a fixed
    0.6 Wh overhead covers network, scheduling, and KV setup per query.
    """
    decode_wh = 1.0e-4 * active_b
    return EnergyProfile(
        overhead_wh=0.6,
        prefill_wh_per_tok=0.35 * decode_wh,
        decode_wh_per_tok=decode_wh,
        citation=citation,
    )
102
+
103
+
104
def _build_models() -> list[ModelProfile]:
    """Construct the built-in model catalog.

    Benchmark numbers are vendor-reported (keys resolve via CITATIONS);
    parameter counts for closed-weight models are public estimates and feed
    only the energy scaling, as flagged in each entry's `notes`.
    """
    models: list[ModelProfile] = []

    # --- OpenAI ---
    models.append(ModelProfile(
        id="gpt-4o",
        display_name="GPT-4o",
        family="OpenAI",
        parameter_count_b=200.0,
        benchmarks=_bench({
            "mmlu": (0.887, "openai-gpt4o-2024"),
            "gsm8k": (0.953, "openai-gpt4o-2024"),
            "math": (0.766, "openai-gpt4o-2024"),
            "humaneval": (0.902, "openai-gpt4o-2024"),
            "mbpp": (0.875, "openai-gpt4o-2024"),
            "bbh": (0.897, "openai-gpt4o-2024"),
            "arc": (0.965, "openai-gpt4o-2024"),
            "gpqa": (0.535, "openai-gpt4o-2024"),
            "ifeval": (0.851, "openai-gpt4o-2024"),
            "mtbench": (0.918, "lmsys-mtbench"),
            "truthfulqa": (0.811, "openai-gpt4o-2024"),
            "mmlu_pro_multi": (0.726, "openai-gpt4o-2024"),
        }),
        energy=_energy_from_active_params(200.0),
        cost=CostProfile(2.50, 10.00, "openai-pricing-2024"),
        latency=LatencyProfile(420.0, 90.0, "artificial-analysis-2024"),
        is_open_weight=False,
        notes="Closed-weight; parameter count is a public estimate.",
    ))

    models.append(ModelProfile(
        id="gpt-4o-mini",
        display_name="GPT-4o mini",
        family="OpenAI",
        parameter_count_b=8.0,
        benchmarks=_bench({
            "mmlu": (0.820, "openai-gpt4omini-2024"),
            "gsm8k": (0.870, "openai-gpt4omini-2024"),
            "math": (0.702, "openai-gpt4omini-2024"),
            "humaneval": (0.872, "openai-gpt4omini-2024"),
            "mbpp": (0.842, "openai-gpt4omini-2024"),
            "bbh": (0.816, "openai-gpt4omini-2024"),
            "arc": (0.937, "openai-gpt4omini-2024"),
            "gpqa": (0.402, "openai-gpt4omini-2024"),
            "ifeval": (0.806, "openai-gpt4omini-2024"),
            "mtbench": (0.852, "lmsys-mtbench"),
            "truthfulqa": (0.745, "openai-gpt4omini-2024"),
            "mmlu_pro_multi": (0.595, "openai-gpt4omini-2024"),
        }),
        energy=_energy_from_active_params(8.0),
        cost=CostProfile(0.15, 0.60, "openai-pricing-2024"),
        latency=LatencyProfile(310.0, 130.0, "artificial-analysis-2024"),
        is_open_weight=False,
        notes="Closed-weight; parameter count is a public estimate.",
    ))

    # --- Anthropic ---
    models.append(ModelProfile(
        id="claude-sonnet-4-5",
        display_name="Claude Sonnet 4.5",
        family="Anthropic",
        parameter_count_b=180.0,
        benchmarks=_bench({
            "mmlu": (0.888, "anthropic-claude-2024"),
            "gsm8k": (0.964, "anthropic-claude-2024"),
            "math": (0.711, "anthropic-claude-2024"),
            "humaneval": (0.920, "anthropic-claude-2024"),
            "mbpp": (0.890, "anthropic-claude-2024"),
            "bbh": (0.933, "anthropic-claude-2024"),
            "arc": (0.965, "anthropic-claude-2024"),
            "gpqa": (0.598, "anthropic-claude-2024"),
            "ifeval": (0.876, "anthropic-claude-2024"),
            "mtbench": (0.925, "lmsys-mtbench"),
            "truthfulqa": (0.830, "anthropic-claude-2024"),
            "mmlu_pro_multi": (0.752, "anthropic-claude-2024"),
        }),
        energy=_energy_from_active_params(180.0),
        cost=CostProfile(3.00, 15.00, "anthropic-pricing-2024"),
        latency=LatencyProfile(480.0, 75.0, "artificial-analysis-2024"),
        is_open_weight=False,
        notes="Closed-weight; parameter count is a public estimate.",
    ))

    models.append(ModelProfile(
        id="claude-haiku-4-5",
        display_name="Claude Haiku 4.5",
        family="Anthropic",
        parameter_count_b=20.0,
        benchmarks=_bench({
            "mmlu": (0.762, "anthropic-claude-2024"),
            "gsm8k": (0.901, "anthropic-claude-2024"),
            "math": (0.512, "anthropic-claude-2024"),
            "humaneval": (0.881, "anthropic-claude-2024"),
            "mbpp": (0.852, "anthropic-claude-2024"),
            "bbh": (0.752, "anthropic-claude-2024"),
            "arc": (0.911, "anthropic-claude-2024"),
            "gpqa": (0.412, "anthropic-claude-2024"),
            "ifeval": (0.845, "anthropic-claude-2024"),
            "mtbench": (0.871, "lmsys-mtbench"),
            "truthfulqa": (0.748, "anthropic-claude-2024"),
            "mmlu_pro_multi": (0.601, "anthropic-claude-2024"),
        }),
        energy=_energy_from_active_params(20.0),
        cost=CostProfile(0.80, 4.00, "anthropic-pricing-2024"),
        latency=LatencyProfile(260.0, 105.0, "artificial-analysis-2024"),
        is_open_weight=False,
        notes="Closed-weight; parameter count is a public estimate.",
    ))

    # --- Google ---
    models.append(ModelProfile(
        id="gemini-1-5-pro",
        display_name="Gemini 1.5 Pro",
        family="Google",
        parameter_count_b=140.0,
        benchmarks=_bench({
            "mmlu": (0.859, "google-gemini-2024"),
            "gsm8k": (0.917, "google-gemini-2024"),
            "math": (0.673, "google-gemini-2024"),
            "humaneval": (0.841, "google-gemini-2024"),
            "mbpp": (0.821, "google-gemini-2024"),
            "bbh": (0.890, "google-gemini-2024"),
            "arc": (0.960, "google-gemini-2024"),
            "gpqa": (0.464, "google-gemini-2024"),
            "ifeval": (0.815, "google-gemini-2024"),
            "mtbench": (0.901, "lmsys-mtbench"),
            "truthfulqa": (0.798, "google-gemini-2024"),
            "mmlu_pro_multi": (0.731, "google-gemini-2024"),
        }),
        energy=_energy_from_active_params(140.0),
        cost=CostProfile(1.25, 5.00, "google-pricing-2024"),
        latency=LatencyProfile(680.0, 65.0, "artificial-analysis-2024"),
        is_open_weight=False,
        notes="Closed-weight; parameter count is a public estimate.",
    ))

    models.append(ModelProfile(
        id="gemini-1-5-flash",
        display_name="Gemini 1.5 Flash",
        family="Google",
        parameter_count_b=8.0,
        benchmarks=_bench({
            "mmlu": (0.789, "google-gemini-2024"),
            "gsm8k": (0.862, "google-gemini-2024"),
            "math": (0.547, "google-gemini-2024"),
            "humaneval": (0.743, "google-gemini-2024"),
            "mbpp": (0.732, "google-gemini-2024"),
            "bbh": (0.788, "google-gemini-2024"),
            "arc": (0.918, "google-gemini-2024"),
            "gpqa": (0.391, "google-gemini-2024"),
            "ifeval": (0.762, "google-gemini-2024"),
            "mtbench": (0.832, "lmsys-mtbench"),
            "truthfulqa": (0.713, "google-gemini-2024"),
            "mmlu_pro_multi": (0.591, "google-gemini-2024"),
        }),
        energy=_energy_from_active_params(8.0),
        cost=CostProfile(0.075, 0.30, "google-pricing-2024"),
        latency=LatencyProfile(210.0, 200.0, "artificial-analysis-2024"),
        is_open_weight=False,
        notes="Closed-weight; parameter count is a public estimate.",
    ))

    # --- Open-weight families (Meta, Mistral, Alibaba); costs are hosted-API
    # rates from Together/Mistral, not self-hosting costs. ---
    models.append(ModelProfile(
        id="llama-3-1-70b",
        display_name="Llama 3.1 70B",
        family="Meta",
        parameter_count_b=70.0,
        benchmarks=_bench({
            "mmlu": (0.860, "meta-llama-3.1"),
            "gsm8k": (0.951, "meta-llama-3.1"),
            "math": (0.680, "meta-llama-3.1"),
            "humaneval": (0.805, "meta-llama-3.1"),
            "mbpp": (0.781, "meta-llama-3.1"),
            "bbh": (0.853, "meta-llama-3.1"),
            "arc": (0.948, "meta-llama-3.1"),
            "gpqa": (0.461, "meta-llama-3.1"),
            "ifeval": (0.873, "meta-llama-3.1"),
            "mtbench": (0.882, "lmsys-mtbench"),
            "truthfulqa": (0.722, "meta-llama-3.1"),
            "mmlu_pro_multi": (0.659, "meta-llama-3.1"),
        }),
        energy=_energy_from_active_params(70.0),
        cost=CostProfile(0.59, 0.79, "together-pricing-2024"),
        latency=LatencyProfile(560.0, 55.0, "artificial-analysis-2024"),
        is_open_weight=True,
    ))

    models.append(ModelProfile(
        id="llama-3-1-8b",
        display_name="Llama 3.1 8B",
        family="Meta",
        parameter_count_b=8.0,
        benchmarks=_bench({
            "mmlu": (0.730, "meta-llama-3.1"),
            "gsm8k": (0.845, "meta-llama-3.1"),
            "math": (0.512, "meta-llama-3.1"),
            "humaneval": (0.726, "meta-llama-3.1"),
            "mbpp": (0.692, "meta-llama-3.1"),
            "bbh": (0.731, "meta-llama-3.1"),
            "arc": (0.908, "meta-llama-3.1"),
            "gpqa": (0.342, "meta-llama-3.1"),
            "ifeval": (0.802, "meta-llama-3.1"),
            "mtbench": (0.802, "lmsys-mtbench"),
            "truthfulqa": (0.659, "meta-llama-3.1"),
            "mmlu_pro_multi": (0.491, "meta-llama-3.1"),
        }),
        energy=_energy_from_active_params(8.0),
        cost=CostProfile(0.18, 0.18, "together-pricing-2024"),
        latency=LatencyProfile(150.0, 200.0, "artificial-analysis-2024"),
        is_open_weight=True,
    ))

    models.append(ModelProfile(
        id="mistral-large-2",
        display_name="Mistral Large 2",
        family="Mistral",
        parameter_count_b=123.0,
        benchmarks=_bench({
            "mmlu": (0.840, "mistral-large-2"),
            "gsm8k": (0.911, "mistral-large-2"),
            "math": (0.715, "mistral-large-2"),
            "humaneval": (0.920, "mistral-large-2"),
            "mbpp": (0.860, "mistral-large-2"),
            "bbh": (0.802, "mistral-large-2"),
            "arc": (0.932, "mistral-large-2"),
            "gpqa": (0.421, "mistral-large-2"),
            "ifeval": (0.811, "mistral-large-2"),
            "mtbench": (0.871, "lmsys-mtbench"),
            "truthfulqa": (0.701, "mistral-large-2"),
            "mmlu_pro_multi": (0.682, "mistral-large-2"),
        }),
        energy=_energy_from_active_params(123.0),
        cost=CostProfile(2.00, 6.00, "mistral-pricing-2024"),
        latency=LatencyProfile(510.0, 62.0, "artificial-analysis-2024"),
        is_open_weight=True,
    ))

    models.append(ModelProfile(
        id="qwen-2-5-72b",
        display_name="Qwen 2.5 72B",
        family="Alibaba",
        parameter_count_b=72.0,
        benchmarks=_bench({
            "mmlu": (0.861, "qwen-2.5-2024"),
            "gsm8k": (0.958, "qwen-2.5-2024"),
            "math": (0.831, "qwen-2.5-2024"),
            "humaneval": (0.866, "qwen-2.5-2024"),
            "mbpp": (0.823, "qwen-2.5-2024"),
            "bbh": (0.868, "qwen-2.5-2024"),
            "arc": (0.943, "qwen-2.5-2024"),
            "gpqa": (0.490, "qwen-2.5-2024"),
            "ifeval": (0.842, "qwen-2.5-2024"),
            "mtbench": (0.875, "lmsys-mtbench"),
            "truthfulqa": (0.690, "qwen-2.5-2024"),
            "mmlu_pro_multi": (0.711, "qwen-2.5-2024"),
        }),
        energy=_energy_from_active_params(72.0),
        cost=CostProfile(0.90, 0.90, "together-pricing-2024"),
        latency=LatencyProfile(580.0, 50.0, "artificial-analysis-2024"),
        is_open_weight=True,
    ))

    models.append(ModelProfile(
        id="qwen-2-5-7b",
        display_name="Qwen 2.5 7B",
        family="Alibaba",
        parameter_count_b=7.0,
        benchmarks=_bench({
            "mmlu": (0.742, "qwen-2.5-2024"),
            "gsm8k": (0.854, "qwen-2.5-2024"),
            "math": (0.620, "qwen-2.5-2024"),
            "humaneval": (0.848, "qwen-2.5-2024"),
            "mbpp": (0.802, "qwen-2.5-2024"),
            "bbh": (0.701, "qwen-2.5-2024"),
            "arc": (0.901, "qwen-2.5-2024"),
            "gpqa": (0.341, "qwen-2.5-2024"),
            "ifeval": (0.752, "qwen-2.5-2024"),
            "mtbench": (0.802, "lmsys-mtbench"),
            "truthfulqa": (0.612, "qwen-2.5-2024"),
            "mmlu_pro_multi": (0.521, "qwen-2.5-2024"),
        }),
        energy=_energy_from_active_params(7.0),
        cost=CostProfile(0.20, 0.20, "together-pricing-2024"),
        latency=LatencyProfile(140.0, 180.0, "artificial-analysis-2024"),
        is_open_weight=True,
    ))

    return models
389
+
390
+
391
# Human-readable provenance for every `citation` key used by the profiles
# above. Every BenchmarkScore/EnergyProfile/CostProfile/LatencyProfile
# citation is expected to resolve to an entry here.
CITATIONS: dict[str, str] = {
    "openai-gpt4o-2024": "OpenAI. GPT-4o System Card and benchmark suite, 2024.",
    "openai-gpt4omini-2024": "OpenAI. GPT-4o mini benchmark report, July 2024.",
    "anthropic-claude-2024": "Anthropic. Claude 4.5 model family evaluation report, 2024.",
    "google-gemini-2024": "Google DeepMind. Gemini 1.5 technical report, 2024.",
    "meta-llama-3.1": "Meta AI. Llama 3.1 evaluation benchmarks, 2024.",
    "mistral-large-2": "Mistral AI. Mistral Large 2 release notes, July 2024.",
    "qwen-2.5-2024": "Alibaba Qwen Team. Qwen2.5 technical report, September 2024.",
    "lmsys-mtbench": "Zheng et al. MT-Bench leaderboard, lmsys.org, 2024 snapshot.",
    "luccioni-2024": (
        "Luccioni, Jernite, Strubell. Power Hungry Processing: Watts Driving the Cost of "
        "AI Deployment? FAccT 2024."
    ),
    "openai-pricing-2024": "OpenAI API pricing page, retrieved 2024.",
    "anthropic-pricing-2024": "Anthropic API pricing page, retrieved 2024.",
    "google-pricing-2024": "Google AI for Developers pricing page, retrieved 2024.",
    "mistral-pricing-2024": "Mistral AI pricing page, retrieved 2024.",
    "together-pricing-2024": "Together AI inference pricing, retrieved 2024.",
    "artificial-analysis-2024": "Artificial Analysis Inc. Latency benchmarks, artificialanalysis.ai, 2024.",
}
411
+
412
+
413
@dataclass
class Registry:
    """In-memory collection of ModelProfile entries with simple lookup helpers."""

    models: list[ModelProfile] = field(default_factory=list)

    def get(self, model_id: str) -> ModelProfile:
        """Return the profile with id *model_id*; raise KeyError if absent."""
        for profile in self.models:
            if profile.id == model_id:
                return profile
        raise KeyError(f"unknown model id: {model_id}")

    def all(self) -> list[ModelProfile]:
        """Return a shallow copy of every profile."""
        return self.models[:]

    def ids(self) -> list[str]:
        """Return all model ids, in registration order."""
        return [profile.id for profile in self.models]

    def by_family(self, family: str) -> list[ModelProfile]:
        """Return the profiles whose family label equals *family*."""
        return [profile for profile in self.models if profile.family == family]

    def __iter__(self) -> Iterable[ModelProfile]:
        yield from self.models

    def __len__(self) -> int:
        return len(self.models)
437
+
438
+
439
def default_registry() -> Registry:
    """Return a fresh Registry pre-populated with the built-in model catalog."""
    catalog = _build_models()
    return Registry(models=catalog)
greenrouting/routing/scorer.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Per-candidate scoring: quality fit, energy, cost, latency."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+
7
+ from greenrouting.classifier.infer import QueryProfile
8
+ from greenrouting.energy.estimator import (
9
+ estimate_cost_usd,
10
+ estimate_energy_wh,
11
+ estimate_latency_seconds,
12
+ )
13
+ from greenrouting.routing.registry import (
14
+ CAPABILITY_BENCHMARKS,
15
+ CAPABILITY_KEYS,
16
+ ModelProfile,
17
+ )
18
+
19
# Capability probabilities below this floor are treated as noise and excluded
# from the quality-fit weighting (see quality_fit below).
CAPABILITY_PROB_FLOOR: float = 0.10
20
+
21
+
22
@dataclass
class CandidateScore:
    """Routing metrics for one candidate model evaluated against one query profile."""

    model_id: str
    display_name: str
    # Capability-weighted benchmark fit (see quality_fit).
    quality: float
    # Estimated energy for the whole query, watt-hours.
    energy_wh: float
    # Estimated provider cost in USD.
    cost_usd: float
    # Estimated end-to-end latency in seconds.
    latency_s: float
    # True when `quality` met the caller-supplied floor.
    qualifies: bool

    def as_dict(self) -> dict:
        """Serialize for API responses, rounding floats to stable precision."""
        return {
            "model_id": self.model_id,
            "display_name": self.display_name,
            "quality": round(self.quality, 4),
            "energy_wh": round(self.energy_wh, 4),
            "cost_usd": round(self.cost_usd, 6),
            "latency_s": round(self.latency_s, 3),
            "qualifies": self.qualifies,
        }
42
+
43
+
44
def quality_fit(profile: QueryProfile, model: ModelProfile) -> float:
    """Capability-probability-weighted average of benchmark scores."""
    probs = profile.capabilities.as_dict()
    numerator = 0.0
    denominator = 0.0
    for capability in CAPABILITY_KEYS:
        weight = probs[capability]
        # Probabilities under the floor are treated as noise.
        if weight < CAPABILITY_PROB_FLOOR:
            continue
        raw = (model.benchmark(b) for b in CAPABILITY_BENCHMARKS.get(capability, ()))
        available = [score for score in raw if score is not None]
        if not available:
            continue
        numerator += weight * (sum(available) / len(available))
        denominator += weight
    if denominator == 0:
        # No capability cleared the floor (or none had benchmark coverage);
        # fall back to MMLU as a generic competency floor.
        fallback = model.benchmark("mmlu")
        return fallback if fallback is not None else 0.0
    return numerator / denominator
66
+
67
+
68
def score_candidate(
    profile: QueryProfile,
    model: ModelProfile,
    quality_floor: float,
) -> CandidateScore:
    """Evaluate one model for one query: quality fit plus energy/cost/latency estimates."""
    input_toks = profile.expected_input_tokens
    output_toks = profile.expected_output_tokens_p50
    fit = quality_fit(profile, model)
    return CandidateScore(
        model_id=model.id,
        display_name=model.display_name,
        quality=fit,
        energy_wh=estimate_energy_wh(model, input_toks, output_toks),
        cost_usd=estimate_cost_usd(model, input_toks, output_toks),
        latency_s=estimate_latency_seconds(model, output_toks),
        qualifies=fit >= quality_floor,
    )
mapper.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Maps the GreenRouting classifier output to the partner's response schema.
2
+
3
+ Inputs:
4
+ - QueryProfile from greenrouting.classifier (8 capability probabilities,
5
+ continuous difficulty in log-parameters, length distribution)
6
+ - PartnerRegistry of candidate models (tier + per-category 1-10 scores + cost)
7
+
8
+ Outputs:
9
+ - capability_weights: dict[7-key partner schema -> float in 0..1]
10
+ - category: argmax over the 5-category public set
11
+ - complexity: simple|moderate|complex
12
+ - difficulty: integer 1..5
13
+ - chosen model_id from the registry
14
+ - energy_savings_pct vs an always-ultra-tier baseline
15
+ - reason string for the partner's audit log
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import math
21
+ from typing import Optional
22
+
23
+ from greenrouting.classifier.infer import QueryProfile
24
+
25
+ from partner_registry import PARTNER_SCORE_KEYS, PartnerModel, PartnerRegistry
26
+
27
+
28
# Partner-facing category set; `category` in responses is the argmax of these
# five (the internal "coding" key is renamed "code" on the way out).
PUBLIC_CATEGORIES: tuple[str, ...] = ("chat", "code", "math", "research", "creative")
# Ordered complexity labels produced by pick_complexity.
COMPLEXITY_BUCKETS: tuple[str, ...] = ("simple", "moderate", "complex")
# Cost (partner's 1-10 scale) assumed for the always-ultra baseline in savings math.
ULTRA_BASELINE_COST: int = 10
31
+
32
+
33
def rebucket_capabilities(profile: QueryProfile) -> dict[str, float]:
    """Map our 8 internal capabilities to the partner's 7 score categories."""
    caps = profile.capabilities
    buckets = {
        "coding": caps.code,
        "math": caps.math,
        # Research pools reasoning with knowledge; capped at 1.0.
        "research": min(1.0, caps.reasoning + caps.knowledge),
        "creative": caps.creative,
        # Chat pools casual chat with instruction-following; capped at 1.0.
        "chat": min(1.0, caps.simple_chat + caps.instruction),
        # No internal roleplay axis: approximate from creative.
        "roleplay": caps.creative * 0.5,
        # Brainstorming: blend of creative and reasoning.
        "ideas": min(1.0, (caps.creative + caps.reasoning) * 0.4),
    }
    return {name: round(value, 3) for name, value in buckets.items()}
52
+
53
+
54
def pick_category(weights: dict[str, float]) -> str:
    """Argmax over the five public categories; internal 'coding' is exposed as 'code'."""
    candidates = ("chat", "coding", "math", "research", "creative")
    winner = max(candidates, key=lambda name: weights[name])
    return "code" if winner == "coding" else winner
60
+
61
+
62
def pick_complexity(profile: QueryProfile) -> str:
    """Bucket the continuous difficulty (log active parameters) into three labels."""
    # Band edges: under 3B active params is "simple", under 20B "moderate".
    thresholds = ((math.log(3e9), "simple"), (math.log(20e9), "moderate"))
    difficulty = profile.difficulty_log_params
    for bound, label in thresholds:
        if difficulty < bound:
            return label
    return "complex"
69
+
70
+
71
def pick_difficulty_int(profile: QueryProfile) -> int:
    """Map the continuous difficulty onto the partner's integer 1..5 scale."""
    difficulty = profile.difficulty_log_params
    # Band edges at 1B / 5B / 15B / 50B active parameters; the rank is one
    # plus the number of edges at or below the predicted difficulty.
    edges = (math.log(b * 1e9) for b in (1, 5, 15, 50))
    crossed = sum(1 for edge in edges if difficulty >= edge)
    return min(5, max(1, 1 + crossed))
81
+
82
+
83
+ def _allowed_tiers(difficulty: int) -> set[str]:
84
+ if difficulty <= 1:
85
+ return {"lite", "standard"}
86
+ if difficulty == 2:
87
+ return {"lite", "standard"}
88
+ if difficulty == 3:
89
+ return {"standard", "pro"}
90
+ if difficulty == 4:
91
+ return {"pro", "ultra"}
92
+ return {"ultra"}
93
+
94
+
95
def quality_fit(model: PartnerModel, weights: dict[str, float]) -> float:
    """Weighted mean of the partner's 1-10 scores, normalized to 0..1."""
    # `or 1.0` guards the all-zero-weights case against division by zero.
    denominator = sum(weights[key] for key in PARTNER_SCORE_KEYS) or 1.0
    numerator = 0.0
    for key in PARTNER_SCORE_KEYS:
        numerator += weights[key] * (model.scores.get(key, 0) / 10.0)
    return numerator / denominator
99
+
100
+
101
def _best_ultra(registry: PartnerRegistry, weights: dict[str, float]) -> PartnerModel:
    """Return the highest-quality ultra-tier model, falling back to the whole registry."""
    candidates = registry.by_tier("ultra") or registry.models
    return max(candidates, key=lambda m: quality_fit(m, weights))
105
+
106
+
107
def select_model(
    registry: PartnerRegistry,
    weights: dict[str, float],
    difficulty: int,
    is_ood: bool = False,
    quality_floor_ratio: float = 0.65,
) -> tuple[PartnerModel, bool]:
    """Returns (chosen_model, escalated). Escalated means we fell back to the
    ultra-tier anchor (low confidence in the prediction)."""
    if not registry.models:
        raise ValueError("partner registry is empty")

    # Out-of-distribution queries bypass tier gating: route to the strongest
    # ultra-tier model rather than trust a low-confidence profile.
    if is_ood:
        return _best_ultra(registry, weights), True

    allowed = registry.by_tier(*_allowed_tiers(difficulty))
    if not allowed:
        # Registry has no model in the difficulty-appropriate tiers; escalate
        # instead of failing the request.
        return _best_ultra(registry, weights), True

    # The quality floor is relative: a fraction of the best fit achievable
    # within the allowed tiers.
    best_allowed = max(allowed, key=lambda m: quality_fit(m, weights))
    floor = quality_fit(best_allowed, weights) * quality_floor_ratio

    qualifying = [m for m in allowed if quality_fit(m, weights) >= floor]
    if not qualifying:
        return best_allowed, False

    # Cheapest qualifying model wins; higher quality breaks cost ties.
    chosen = min(qualifying, key=lambda m: (m.cost, -quality_fit(m, weights)))
    return chosen, False
135
+
136
+
137
def energy_savings_pct(chosen: PartnerModel, baseline_cost: int = ULTRA_BASELINE_COST) -> float:
    """Percent saved (0..100) versus always routing at the ultra-tier baseline cost."""
    if baseline_cost <= 0:
        return 0.0
    fraction = (baseline_cost - chosen.cost) / baseline_cost
    clamped = min(1.0, max(0.0, fraction))
    return clamped * 100.0
142
+
143
+
144
def build_reason(
    weights: dict[str, float],
    complexity: str,
    chosen: PartnerModel,
    escalated: bool,
    is_ood: bool = False,
) -> str:
    """Compose the human-readable explanation for the partner's audit log."""
    dominant, dominant_score = max(weights.items(), key=lambda kv: kv[1])
    parts: list[str] = []
    if is_ood:
        parts.append("low-confidence input (escalated to ultra tier)")
    elif dominant_score >= 0.5:
        parts.append(f"{dominant} dominant ({dominant_score:.2f})")
    else:
        parts.append("mixed signal")
    if not is_ood:
        parts.append(f"{complexity} difficulty")
    if escalated and not is_ood:
        parts.append("escalated (no qualifying tier-allowed model)")
    elif not escalated:
        parts.append(f"picked {chosen.tier} tier (cost {chosen.cost})")
    return ", ".join(parts)
166
+
167
+
168
def fold_recent_context(message: str, recent: Optional[list[dict]]) -> str:
    """Prefix *message* with up to 200 chars of the most recent turn's content."""
    if not recent:
        return message
    newest = recent[-1]
    if not isinstance(newest, dict):
        return message
    snippet = (newest.get("content") or "")[:200]
    if not snippet:
        return message
    return f"{snippet}\n{message}"
models/classifier_v1/calibration.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "temperature": 1.9504629214867681
3
+ }
models/classifier_v1/encoder_name.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ BAAI/bge-small-en-v1.5
models/classifier_v1/head.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce9f40ded994d27f8695683ec77d2abfa57f36380c9f5767e074411b6f34ce22
3
+ size 673429
models/classifier_v1/metadata.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "capability_keys": [
3
+ "code",
4
+ "math",
5
+ "reasoning",
6
+ "knowledge",
7
+ "instruction",
8
+ "creative",
9
+ "multilingual",
10
+ "simple_chat"
11
+ ],
12
+ "length_buckets": [
13
+ "short",
14
+ "medium",
15
+ "long"
16
+ ],
17
+ "embedding_dim": 384,
18
+ "hidden_dim": 256,
19
+ "max_seq_len": 256,
20
+ "diff_target_center": 22.80270737862625
21
+ }
models/classifier_v1/ood_stats.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c71aa5f272960594a5b5f043699136f9595c2f9ac0bf6c9cb918c3556482cc03
3
+ size 518936
models/classifier_v1/training_history.json ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "epoch": 0,
4
+ "train_loss": 1.3644548597789945,
5
+ "cap_precision": 0.0,
6
+ "cap_recall": 0.0,
7
+ "cap_f1": 0.0,
8
+ "diff_mae": 0.8598755598068237,
9
+ "len_acc": 0.559322033898305
10
+ },
11
+ {
12
+ "epoch": 1,
13
+ "train_loss": 1.1052290626934596,
14
+ "cap_precision": 0.3870967741935484,
15
+ "cap_recall": 0.14457831325301204,
16
+ "cap_f1": 0.21052631578947364,
17
+ "diff_mae": 0.7214581966400146,
18
+ "len_acc": 0.7288135593220338
19
+ },
20
+ {
21
+ "epoch": 2,
22
+ "train_loss": 0.9240592576208568,
23
+ "cap_precision": 0.45161290322580644,
24
+ "cap_recall": 0.1686746987951807,
25
+ "cap_f1": 0.24561403508771928,
26
+ "diff_mae": 0.7095464468002319,
27
+ "len_acc": 0.6949152542372882
28
+ },
29
+ {
30
+ "epoch": 3,
31
+ "train_loss": 0.8142032254309881,
32
+ "cap_precision": 0.6,
33
+ "cap_recall": 0.2891566265060241,
34
+ "cap_f1": 0.3902439024390244,
35
+ "diff_mae": 0.7512079477310181,
36
+ "len_acc": 0.711864406779661
37
+ },
38
+ {
39
+ "epoch": 4,
40
+ "train_loss": 0.7119971627280826,
41
+ "cap_precision": 0.4457831325301205,
42
+ "cap_recall": 0.4457831325301205,
43
+ "cap_f1": 0.4457831325301205,
44
+ "diff_mae": 0.6864577531814575,
45
+ "len_acc": 0.6610169491525424
46
+ },
47
+ {
48
+ "epoch": 5,
49
+ "train_loss": 0.629969752970196,
50
+ "cap_precision": 0.5930232558139535,
51
+ "cap_recall": 0.6144578313253012,
52
+ "cap_f1": 0.6035502958579881,
53
+ "diff_mae": 0.6868146061897278,
54
+ "len_acc": 0.6949152542372882
55
+ },
56
+ {
57
+ "epoch": 6,
58
+ "train_loss": 0.562128157842727,
59
+ "cap_precision": 0.5888888888888889,
60
+ "cap_recall": 0.6385542168674698,
61
+ "cap_f1": 0.6127167630057803,
62
+ "diff_mae": 0.7360132336616516,
63
+ "len_acc": 0.6949152542372882
64
+ },
65
+ {
66
+ "epoch": 7,
67
+ "train_loss": 0.4673078145299639,
68
+ "cap_precision": 0.6304347826086957,
69
+ "cap_recall": 0.6987951807228916,
70
+ "cap_f1": 0.6628571428571429,
71
+ "diff_mae": 0.7615002393722534,
72
+ "len_acc": 0.6779661016949152
73
+ },
74
+ {
75
+ "epoch": 8,
76
+ "train_loss": 0.4211572749274118,
77
+ "cap_precision": 0.6703296703296703,
78
+ "cap_recall": 0.7349397590361446,
79
+ "cap_f1": 0.7011494252873562,
80
+ "diff_mae": 0.7574763894081116,
81
+ "len_acc": 0.6101694915254238
82
+ },
83
+ {
84
+ "epoch": 9,
85
+ "train_loss": 0.3946749922775087,
86
+ "cap_precision": 0.7,
87
+ "cap_recall": 0.6746987951807228,
88
+ "cap_f1": 0.6871165644171778,
89
+ "diff_mae": 0.7892473340034485,
90
+ "len_acc": 0.6610169491525424
91
+ },
92
+ {
93
+ "epoch": 10,
94
+ "train_loss": 0.34337903772081646,
95
+ "cap_precision": 0.7142857142857143,
96
+ "cap_recall": 0.7228915662650602,
97
+ "cap_f1": 0.718562874251497,
98
+ "diff_mae": 0.7348798513412476,
99
+ "len_acc": 0.5932203389830508
100
+ },
101
+ {
102
+ "epoch": 11,
103
+ "train_loss": 0.2987311219885236,
104
+ "cap_precision": 0.7228915662650602,
105
+ "cap_recall": 0.7228915662650602,
106
+ "cap_f1": 0.7228915662650603,
107
+ "diff_mae": 0.7976469993591309,
108
+ "len_acc": 0.6271186440677966
109
+ },
110
+ {
111
+ "epoch": 12,
112
+ "train_loss": 0.27304122419584365,
113
+ "cap_precision": 0.7792207792207793,
114
+ "cap_recall": 0.7228915662650602,
115
+ "cap_f1": 0.75,
116
+ "diff_mae": 0.8239098787307739,
117
+ "len_acc": 0.6610169491525424
118
+ },
119
+ {
120
+ "epoch": 13,
121
+ "train_loss": 0.24270852761609213,
122
+ "cap_precision": 0.7763157894736842,
123
+ "cap_recall": 0.7108433734939759,
124
+ "cap_f1": 0.7421383647798742,
125
+ "diff_mae": 0.8853136301040649,
126
+ "len_acc": 0.6610169491525424
127
+ },
128
+ {
129
+ "epoch": 14,
130
+ "train_loss": 0.2204317024775914,
131
+ "cap_precision": 0.8055555555555556,
132
+ "cap_recall": 0.6987951807228916,
133
+ "cap_f1": 0.7483870967741936,
134
+ "diff_mae": 0.7929121851921082,
135
+ "len_acc": 0.6779661016949152
136
+ },
137
+ {
138
+ "epoch": 15,
139
+ "train_loss": 0.18839974346615018,
140
+ "cap_precision": 0.8082191780821918,
141
+ "cap_recall": 0.7108433734939759,
142
+ "cap_f1": 0.7564102564102564,
143
+ "diff_mae": 0.8756879568099976,
144
+ "len_acc": 0.6440677966101694
145
+ },
146
+ {
147
+ "epoch": 16,
148
+ "train_loss": 0.1629447014558883,
149
+ "cap_precision": 0.8169014084507042,
150
+ "cap_recall": 0.6987951807228916,
151
+ "cap_f1": 0.7532467532467533,
152
+ "diff_mae": 0.7843820452690125,
153
+ "len_acc": 0.6271186440677966
154
+ },
155
+ {
156
+ "epoch": 17,
157
+ "train_loss": 0.1407343131445703,
158
+ "cap_precision": 0.7792207792207793,
159
+ "cap_recall": 0.7228915662650602,
160
+ "cap_f1": 0.75,
161
+ "diff_mae": 0.8463668823242188,
162
+ "len_acc": 0.6101694915254238
163
+ },
164
+ {
165
+ "epoch": 18,
166
+ "train_loss": 0.1221757513426599,
167
+ "cap_precision": 0.8024691358024691,
168
+ "cap_recall": 0.7831325301204819,
169
+ "cap_f1": 0.7926829268292682,
170
+ "diff_mae": 0.8194808959960938,
171
+ "len_acc": 0.6440677966101694
172
+ },
173
+ {
174
+ "epoch": 19,
175
+ "train_loss": 0.11088378017856962,
176
+ "cap_precision": 0.8051948051948052,
177
+ "cap_recall": 0.7469879518072289,
178
+ "cap_f1": 0.7749999999999999,
179
+ "diff_mae": 0.8044652938842773,
180
+ "len_acc": 0.6610169491525424
181
+ }
182
+ ]
partner_registry.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Loader for the downstream partner's model registry.
2
+
3
+ The partner ships a JSON list of model entries, each with an `id`, `tier`,
4
+ `scores` (per-category 1-10), and `cost` (1-10). This file does not ship the
5
+ registry data itself - it is loaded at runtime from a path supplied via the
6
+ PARTNER_REGISTRY_PATH environment variable, kept outside source control.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import json
12
+ import os
13
+ from dataclasses import dataclass
14
+ from pathlib import Path
15
+ from typing import Optional
16
+
17
+
18
# The partner's seven per-model scoring categories (each scored 1-10 in the
# registry JSON); also the key set expected in capability-weight dicts.
PARTNER_SCORE_KEYS: tuple[str, ...] = (
    "coding",
    "math",
    "research",
    "creative",
    "chat",
    "roleplay",
    "ideas",
)

# Partner tier labels, in ascending capability/cost order.
TIERS: tuple[str, ...] = ("lite", "standard", "pro", "ultra")
+
30
+
31
+ @dataclass(frozen=True)
32
+ class PartnerModel:
33
+ id: str
34
+ tier: str
35
+ is_open_router: bool
36
+ strengths: tuple[str, ...]
37
+ scores: dict[str, int]
38
+ cost: int
39
+
40
+ def fits_tier(self, tier_set: set[str]) -> bool:
41
+ return self.tier in tier_set
42
+
43
+
44
+ @dataclass
45
+ class PartnerRegistry:
46
+ models: list[PartnerModel]
47
+
48
+ def all(self) -> list[PartnerModel]:
49
+ return list(self.models)
50
+
51
+ def by_tier(self, *tiers: str) -> list[PartnerModel]:
52
+ keep = set(tiers)
53
+ return [m for m in self.models if m.tier in keep]
54
+
55
+ def get(self, model_id: str) -> Optional[PartnerModel]:
56
+ for m in self.models:
57
+ if m.id == model_id:
58
+ return m
59
+ return None
60
+
61
+ def __len__(self) -> int:
62
+ return len(self.models)
63
+
64
+
65
+ def _coerce(entry: dict) -> PartnerModel:
66
+ scores = {k: int(entry.get("scores", {}).get(k, 0)) for k in PARTNER_SCORE_KEYS}
67
+ return PartnerModel(
68
+ id=str(entry["id"]),
69
+ tier=str(entry.get("tier", "standard")).lower(),
70
+ is_open_router=bool(entry.get("isOpenRouter", False)),
71
+ strengths=tuple(entry.get("strengths", [])),
72
+ scores=scores,
73
+ cost=int(entry.get("cost", 5)),
74
+ )
75
+
76
+
77
+ def load_registry(path: str | Path | None = None) -> PartnerRegistry:
78
+ """Loads from one of three sources, in priority order:
79
+
80
+ 1. The `path` argument, if supplied.
81
+ 2. The PARTNER_REGISTRY_JSON env var containing the raw JSON content (used
82
+ in deployments where a file is awkward to ship, e.g. HF Space secrets).
83
+ 3. The PARTNER_REGISTRY_PATH env var pointing at a JSON file on disk.
84
+ """
85
+ raw_text: Optional[str] = None
86
+ source = "argument"
87
+
88
+ if path is None:
89
+ inline = os.environ.get("PARTNER_REGISTRY_JSON")
90
+ if inline:
91
+ raw_text = inline
92
+ source = "env:PARTNER_REGISTRY_JSON"
93
+ else:
94
+ env_path = os.environ.get("PARTNER_REGISTRY_PATH")
95
+ if env_path:
96
+ path = env_path
97
+ source = "env:PARTNER_REGISTRY_PATH"
98
+
99
+ if raw_text is None:
100
+ if path is None:
101
+ raise RuntimeError(
102
+ "no registry source supplied (set PARTNER_REGISTRY_JSON or PARTNER_REGISTRY_PATH)"
103
+ )
104
+ p = Path(path)
105
+ if not p.exists():
106
+ raise FileNotFoundError(f"partner registry JSON not found at {p}")
107
+ raw_text = p.read_text(encoding="utf-8")
108
+
109
+ raw = json.loads(raw_text)
110
+ if not isinstance(raw, list):
111
+ raise ValueError(f"partner registry from {source} must be a top-level list")
112
+ models = [_coerce(e) for e in raw]
113
+ if not models:
114
+ raise ValueError(f"partner registry from {source} is empty")
115
+ return PartnerRegistry(models=models)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi>=0.110
2
+ uvicorn[standard]>=0.30
3
+ pydantic>=2.6
4
+ torch>=2.2
5
+ transformers>=4.40
6
+ numpy>=1.26