# Spaces: Running  (page-status residue captured with this file; not part of the module)
"""
Typed Pydantic models for the SQL Debug Environment.

Implements the OpenEnv spec: Observation, Action, Reward.
"""

# Standard library
from enum import Enum
from typing import Any, Dict, List, Optional

# Third-party
from pydantic import BaseModel, Field
class ActionType(str, Enum):
    """Kinds of action an agent may send to the SQL Debug Environment.

    Members also subclass ``str``, so each member compares equal to its
    plain string value (e.g. ``ActionType.SUBMIT_QUERY == "submit_query"``).
    """

    # Submit a fixed SQL query for evaluation.
    SUBMIT_QUERY = "submit_query"
    # Request schema info (costs 0 reward, gives info).
    INSPECT_SCHEMA = "inspect_schema"
    # Request error details (costs 0, gives stack trace).
    INSPECT_ERROR = "inspect_error"
    # Request 3 sample rows from a table.
    INSPECT_SAMPLE = "inspect_sample"
    # Reset to the original broken query (incurs a 0.05 penalty).
    RESET_QUERY = "reset_query"
class SQLDebugAction(BaseModel):
    """Action model for the SQL Debug Environment.

    Each action carries exactly one ``action_type``:

    - submit_query: Submit a fixed SQL string for evaluation
    - inspect_schema: Get table schema info (free action, no reward change)
    - inspect_error: Get detailed error message from last query run
    - inspect_sample: Get sample rows from a specified table
    - reset_query: Go back to original broken query (costs -0.05 penalty)

    ``query`` and ``table_name`` both default to ``None``. The "required
    when ..." rules stated in their descriptions are not validated by this
    model, so whoever consumes the action has to check them.
    """

    # Pydantic v2 configuration. ``json_schema_extra`` is a v2 config key, and
    # v2 deprecates the nested ``class Config`` form in favour of
    # ``model_config``; a plain dict is accepted, so no extra import is needed.
    model_config = {
        "json_schema_extra": {
            "example": {
                "action_type": "submit_query",
                "query": (
                    "SELECT u.name, COUNT(o.id) as order_count "
                    "FROM users u LEFT JOIN orders o ON u.id = o.user_id "
                    "GROUP BY u.id, u.name ORDER BY order_count DESC"
                ),
            }
        }
    }

    action_type: ActionType = Field(
        description="Type of action to take",
    )
    query: Optional[str] = Field(
        default=None,
        description="SQL query string. Required when action_type is 'submit_query'.",
    )
    table_name: Optional[str] = Field(
        default=None,
        description="Table name. Required when action_type is 'inspect_sample'.",
    )
class QueryResult(BaseModel):
    """Result of executing a SQL query.

    Only ``success`` is required; every other field defaults to ``None``.
    """

    success: bool
    # One dict per returned row (presumably column name -> value; confirm
    # against the code that builds this model).
    rows: Optional[List[Dict[str, Any]]] = None
    row_count: Optional[int] = None
    error_message: Optional[str] = None
    # Execution time in milliseconds, per the field name.
    execution_time_ms: Optional[float] = None
class SchemaInfo(BaseModel):
    """Database schema information."""

    # table_name -> list of {name, type, nullable} column descriptors.
    # NOTE(review): descriptor values are typed ``str``, so ``nullable`` must
    # be supplied as a string rather than a bool -- confirm with the producer.
    tables: Dict[str, List[Dict[str, str]]]
    # Optional sample rows; the outer dict is presumably keyed by table name,
    # mirroring ``tables`` -- confirm against the producer.
    sample_data: Optional[Dict[str, List[Dict[str, Any]]]] = None
class SQLDebugObservation(BaseModel):
    """Observation returned after each step.

    Contains the current state of the debugging session:

    - The original broken query (always visible)
    - The agent's current best query
    - Result of last action
    - Progress indicators
    - Schema/error info if requested

    Note that ``current_score`` is documented as lying strictly between 0
    and 1, but this model declares no numeric bounds on it, so that range
    is not validated here.
    """

    task_id: str = Field(description="Current task identifier")
    task_description: str = Field(
        description="Natural language description of the bug to fix",
    )
    original_query: str = Field(description="The original broken SQL query")
    current_query: Optional[str] = Field(
        default=None,
        description="Agent's last submitted query",
    )
    expected_description: str = Field(
        description="Description of what the correct output should look like",
    )

    # Last action result
    last_action_type: str
    last_query_result: Optional[QueryResult] = None

    # Progress
    steps_taken: int
    steps_remaining: int
    current_score: float = Field(
        description="Current score in strict range (0, 1) for this episode",
    )

    # Contextual help (populated based on action type)
    schema_info: Optional[SchemaInfo] = None
    error_details: Optional[str] = None
    sample_rows: Optional[List[Dict[str, Any]]] = None

    # Hints (unlocked after step 3 on easy, step 5 on medium/hard)
    hint: Optional[str] = None

    # Episode status
    is_done: bool = False
    success: bool = False
class SQLDebugReward(BaseModel):
    """Reward signal for the SQL Debug Environment.

    ``value`` is the total reward for the step; the remaining numeric fields
    break it down into components:

    - correctness: 0.0-0.6 based on row-level match vs expected output
    - efficiency: 0.0-0.2 bonus for solving in fewer steps
    - syntax_progress: 0.0-0.1 for getting a syntactically valid query (even if wrong)
    - schema_bonus: 0.0-0.1 for queries that reference correct tables/columns
    - penalty: 0.0-0.2 deduction for reset_query, infinite loops, destructive SQL;
      stored as a non-negative magnitude, not as a negative number

    ``value`` is validated to lie in [0.001, 0.999], so it cannot always
    equal the raw component total (the positive components alone can reach
    1.0). Presumably the producer clamps the total before building this
    model -- confirm against the reward computation.
    """

    value: float = Field(ge=0.001, le=0.999, description="Total reward for this step")
    correctness: float = Field(ge=0.0, le=0.6)
    efficiency: float = Field(ge=0.0, le=0.2)
    syntax_progress: float = Field(ge=0.0, le=0.1)
    schema_bonus: float = Field(ge=0.0, le=0.1)
    penalty: float = Field(
        ge=0.0,
        le=0.2,
        description="Penalty deduction magnitude (non-negative)",
    )
    breakdown: str = Field(description="Human-readable reward breakdown")
class EpisodeState(BaseModel):
    """Full internal state of an episode. Used by state() endpoint.

    No field declares a default value.
    """

    # Task identity
    task_id: str
    task_difficulty: str

    # Queries
    original_query: str
    # May be None, but has no default: under Pydantic v2 the caller must
    # still pass it explicitly.
    current_query: Optional[str]

    # Progress
    best_score_so_far: float
    steps_taken: int
    max_steps: int

    # History (presumably one entry per step taken -- confirm with the producer)
    action_history: List[Dict[str, Any]]
    reward_history: List[float]

    # Episode status
    is_done: bool
    success: bool

    # Schema of the episode's database; structure is not constrained here.
    db_schema: Dict[str, Any]