Spaces:
Sleeping
Sleeping
| """ | |
| OpenEnv typed models — Observation, Action, Reward. | |
| All models are Pydantic v2 compliant. | |
| """ | |
| from __future__ import annotations | |
| from typing import Any, Dict, List, Optional | |
| from pydantic import BaseModel, Field | |
| # --------------------------------------------------------------------------- | |
| # Observation | |
| # --------------------------------------------------------------------------- | |
class Observation(BaseModel):
    """What the agent sees at each step.

    Carries the task being solved, the SQL query to fix / optimise together
    with its schema context, an optional hint, and the episode's progress
    counters.
    """

    # NOTE(review): the description lists tasks 1-3, but the range is not
    # enforced here; add ge/le only if the task set is known to be fixed.
    task_id: int = Field(..., description="Which task (1=easy, 2=medium, 3=hard)")
    task_name: str = Field(..., description="Human-readable task name")
    task_description: str = Field(..., description="What the agent must accomplish")
    query: str = Field(..., description="The SQL query the agent must fix / optimise")
    schema_context: str = Field(
        ..., description="DDL / schema description relevant to the query"
    )
    hint: Optional[str] = Field(
        None, description="Optional natural-language hint for the current step"
    )
    # The step index is 0-indexed, so a negative value is never meaningful.
    # Previously this was an unconstrained int and -1 would validate silently.
    step_number: int = Field(
        0, ge=0, description="Current step within the episode (0-indexed)"
    )
    # An episode must allow at least one step; 0 or negative budgets are rejected.
    max_steps: int = Field(
        5, ge=1, description="Maximum steps allowed per episode"
    )
    done: bool = Field(False, description="Whether the episode has ended")
| # --------------------------------------------------------------------------- | |
| # Action | |
| # --------------------------------------------------------------------------- | |
class Action(BaseModel):
    """What the agent submits at each step.

    A candidate rewrite of the SQL query plus a free-text explanation of the
    changes. ``is_done`` lets the agent signal that it considers the query
    fully optimised.
    """

    # Required (no default supplied): the candidate SQL text.
    rewritten_query: str = Field(
        description="The agent's rewritten / improved SQL query",
    )
    # Required (no default supplied): the agent's rationale.
    explanation: str = Field(
        description="Natural-language explanation of changes made",
    )
    # Optional completion flag; False unless the agent sets it.
    is_done: bool = Field(
        default=False,
        description="Set True when the agent believes the query is fully optimised",
    )
| # --------------------------------------------------------------------------- | |
| # Reward | |
| # --------------------------------------------------------------------------- | |
class RewardBreakdown(BaseModel):
    """Per-dimension partial scores attached to a reward.

    The three quality dimensions are each bounded to [0, 1]. ``step_penalty``
    is only bounded above by 0, so it is either zero or negative.
    """

    correctness: float = Field(default=0.0, ge=0.0, le=1.0)
    performance: float = Field(default=0.0, ge=0.0, le=1.0)
    style: float = Field(default=0.0, ge=0.0, le=1.0)
    # Always <= 0. NOTE(review): no lower bound is enforced; confirm whether a
    # floor (e.g. -1.0) is intended.
    step_penalty: float = Field(default=0.0, le=0.0)
class Reward(BaseModel):
    """Reward returned after each step.

    ``score`` and ``grader_score`` are required. The breakdown, feedback text
    and running total all have defaults. Every numeric score here is
    constrained to the closed interval [0, 1].
    """

    # Required (no default supplied).
    score: float = Field(ge=0.0, le=1.0, description="Aggregate step reward")
    # Required (no default supplied).
    grader_score: float = Field(
        ge=0.0,
        le=1.0,
        description="Raw grader score for the submitted query",
    )
    # A fresh all-zero breakdown per instance, via default_factory.
    breakdown: RewardBreakdown = Field(
        default_factory=RewardBreakdown,
        description="Per-dimension partial scores",
    )
    feedback: str = Field(
        default="", description="Human-readable feedback from the grader"
    )
    # NOTE(review): described as a running total, yet capped at 1.0. If the
    # environment sums per-step scores across an episode, this will raise a
    # ValidationError; confirm the value is averaged or clamped upstream.
    cumulative_score: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Total score accumulated over episode so far",
    )