Spaces:
Running
Running
| """ | |
| nl2sql-bench/server/tasks/base.py | |
| ================================== | |
| Abstract base for all NL2SQL tasks and the global task registry. | |
| Each task holds a list of (question, ground_truth_sql) pairs. | |
| The environment picks one pair per episode via a deterministic round-robin | |
| so that the same task always cycles through the same question sequence — | |
| this keeps grader results reproducible across runs. | |
| """ | |
| from __future__ import annotations | |
| import sqlite3 | |
| from abc import ABC, abstractmethod | |
| from typing import Dict, List, NamedTuple, Tuple, Type | |
class TaskExample(NamedTuple):
    """A single (question, ground-truth SQL) pair belonging to a task."""

    # Natural-language question.
    question: str
    # Ground-truth SQL paired with the question.
    sql: str
    # Human-readable description of what makes this question that difficulty.
    notes: str = ""
class BaseTask(ABC):
    """Abstract base class for all tasks.

    Subclasses are expected to override the ``name``, ``difficulty`` and
    ``examples`` class attributes.  Every instance keeps its own cursor and
    hands out ``examples`` in declaration order, wrapping around at the end.
    """

    name: str = ""
    difficulty: str = ""  # easy | medium | hard
    # Empty on the base class; __init__ refuses to build a task without examples.
    examples: List[TaskExample] = []

    def __init__(self) -> None:
        """Validate the task and reset the round-robin cursor.

        Raises:
            ValueError: If ``examples`` is empty.
        """
        if not self.examples:
            raise ValueError(f"Task {self.name!r} has no examples defined.")
        self._cursor = 0  # round-robin index

    def next_example(self) -> TaskExample:
        """Return the next question in round-robin order."""
        # The modulo makes the ever-growing cursor wrap around the example list.
        example = self.examples[self._cursor % len(self.examples)]
        self._cursor += 1
        return example

    @classmethod
    def schema_context(cls) -> str:
        """Return a compact schema description for the agent system prompt.

        Declared as a classmethod so it can be called on the class itself
        (``BaseTask.schema_context()``) as well as on an instance.
        """
        return _SCHEMA_CONTEXT

    def description(self) -> str:
        """One-sentence description for openenv.yaml."""
        # NOTE(review): this base implementation has no body beyond the
        # docstring, so it returns None despite the ``-> str`` annotation.
        # Presumably subclasses override it (it may have been intended as an
        # @abstractmethod) -- confirm before relying on the base behaviour.
| # ββ Global schema context string (injected into every observation) βββββββββ | |
| _SCHEMA_CONTEXT = """\ | |
| Database: ecommerce (SQLite, read-only) | |
| TABLES | |
| ------ | |
| categories(id INTEGER PK, name TEXT) | |
| products(id INTEGER PK, name TEXT, category_id INTEGER FKβcategories.id, | |
| price REAL, stock_quantity INTEGER) | |
| customers(id INTEGER PK, name TEXT, email TEXT, country TEXT, | |
| tier TEXT β {bronze|silver|gold}, created_at TEXT ISO-8601) | |
| orders(id INTEGER PK, customer_id INTEGER FKβcustomers.id, | |
| status TEXT β {pending|processing|shipped|delivered|cancelled}, | |
| created_at TEXT ISO-8601, total_amount REAL) | |
| order_items(id INTEGER PK, order_id INTEGER FKβorders.id, | |
| product_id INTEGER FKβproducts.id, | |
| quantity INTEGER, unit_price REAL) | |
| reviews(id INTEGER PK, product_id INTEGER FKβproducts.id, | |
| customer_id INTEGER FKβcustomers.id, | |
| rating INTEGER 1-5, created_at TEXT ISO-8601) | |
| NOTES | |
| ----- | |
| - Date comparisons: use created_at >= '2024-01-01' (text ISO sort works) | |
| - SQLite window functions (RANK, DENSE_RANK, ROW_NUMBER, LAG, LEAD) are available | |
| - strftime('%Y-%m', created_at) returns 'YYYY-MM' month strings | |
| - All monetary values are in USD | |
| """ | |
| # ββ Task registry ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| _REGISTRY: Dict[str, Type[BaseTask]] = {} | |
def register(cls: Type[BaseTask]) -> Type[BaseTask]:
    """Class decorator that records *cls* in the registry under ``cls.name``.

    A class registered later under the same name replaces the earlier entry.
    The class itself is returned unchanged so the decorator is transparent.
    """
    task_name = cls.name
    _REGISTRY[task_name] = cls
    return cls
def get_task(name: str) -> BaseTask:
    """Instantiate and return the task registered under *name*.

    Raises:
        KeyError: If no task with that name is registered; the message
            lists the names that are available.
    """
    task_cls = _REGISTRY.get(name)
    if task_cls is None:
        raise KeyError(f"Unknown task {name!r}. Available: {list(_REGISTRY)}")
    # A fresh instance per call, so each caller gets its own cursor.
    return task_cls()
def all_task_names() -> List[str]:
    """Return the names of all registered tasks, in registration order."""
    return [task_name for task_name in _REGISTRY]