Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from typing import Any, Literal | |
| from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator | |
# Closed string vocabularies used as field types by the models in this file.
Difficulty = Literal["easy", "medium", "hard"]
ActionType = Literal[
    "identify_issue",
    "suggest_fix",
    "approve",
    "request_more_context",
]
IssueCategory = Literal["syntax", "performance", "security", "logic", "style"]
class StrictModel(BaseModel):
    """Shared pydantic base for the models declared in this file.

    The configuration forbids unknown fields (``extra="forbid"``) and lets a
    field that declares an alias also be populated by its attribute name
    (``populate_by_name=True``).
    """

    model_config = ConfigDict(extra="forbid", populate_by_name=True)
class GroundTruthIssue(StrictModel):
    """A reference issue attached to a task.

    ``id``, ``description`` and ``fix`` must be non-empty strings, ``severity``
    must satisfy ``0 < severity <= 1``, and ``keywords`` is normalized by
    :meth:`normalize_keywords` during validation.
    """

    id: str = Field(min_length=1)
    category: IssueCategory
    description: str = Field(min_length=1)
    severity: float = Field(gt=0.0, le=1.0)
    fix: str = Field(min_length=1)
    keywords: list[str] = Field(default_factory=list)

    # Without these decorators pydantic never invokes the method, so keywords
    # would be stored exactly as supplied.
    @field_validator("keywords")
    @classmethod
    def normalize_keywords(cls, value: list[str]) -> list[str]:
        """Return ``value`` stripped, lowercased and de-duplicated.

        Entries that are empty after stripping are dropped; the first
        occurrence of each keyword keeps its position.
        """
        cleaned = (keyword.strip().lower() for keyword in value)
        # dict.fromkeys keeps insertion order, giving an O(n) ordered de-dup.
        return list(dict.fromkeys(keyword for keyword in cleaned if keyword))
class TaskRecord(StrictModel):
    """A single review task together with its reference issues.

    ``task_id``, ``query`` and ``context`` must be non-empty; ``max_steps`` must
    be between 1 and 12 inclusive.
    """

    task_id: str = Field(min_length=1)
    difficulty: Difficulty
    query: str = Field(min_length=1)
    # Read from the "schema" key in input data; the attribute name is accepted
    # as well because StrictModel enables populate_by_name.
    # Presumably maps table name -> column name -> column type; confirm against
    # the task data.
    schema_info: dict[str, dict[str, str]] = Field(default_factory=dict, alias="schema")
    context: str = Field(min_length=1)
    ground_truth_issues: list[GroundTruthIssue] = Field(default_factory=list)
    max_steps: int = Field(ge=1, le=12)
class IdentifiedIssue(StrictModel):
    """An issue recorded during a review: an id, a category and a description.

    ``issue_id`` and ``description`` must be non-empty strings.
    """

    issue_id: str = Field(min_length=1)
    category: IssueCategory
    description: str = Field(min_length=1)
class SQLReviewAction(StrictModel):
    """One action submitted by the reviewer.

    Which optional fields are required depends on ``action_type``; see
    :meth:`validate_action`. ``confidence`` defaults to 0.5 and must lie in
    ``[0, 1]``.
    """

    action_type: ActionType
    issue_category: IssueCategory | None = None
    issue_description: str | None = None
    suggested_fix: str | None = None
    confidence: float = Field(default=0.5, ge=0.0, le=1.0)

    # Registered as an "after" model validator so the cross-field rules run on
    # every construction; as a plain method it would never be called.
    @model_validator(mode="after")
    def validate_action(self) -> "SQLReviewAction":
        """Enforce the fields each action type needs.

        Returns:
            The validated instance itself.

        Raises:
            ValueError: if ``identify_issue`` lacks ``issue_category`` or
                ``issue_description``, or ``suggest_fix`` lacks
                ``suggested_fix``. Other action types have no extra
                requirements.
        """
        if self.action_type == "identify_issue":
            if not self.issue_category or not self.issue_description:
                raise ValueError("identify_issue requires issue_category and issue_description")
        elif self.action_type == "suggest_fix":
            if not self.suggested_fix:
                raise ValueError("suggest_fix requires suggested_fix")
        return self
class SQLReviewObservation(StrictModel):
    """Observation payload: the query under review plus progress information.

    ``remaining_actions`` must be non-negative.
    """

    query: str
    # Unlike TaskRecord.schema_info, this field declares no alias.
    schema_info: dict[str, dict[str, str]] = Field(default_factory=dict)
    context: str
    issues_found_so_far: list[IdentifiedIssue] = Field(default_factory=list)
    remaining_actions: int = Field(ge=0)
    difficulty: Difficulty
    feedback: str
class SQLReviewState(StrictModel):
    """Mutable bookkeeping for one review episode.

    Counters default to 0 and must be non-negative; the list fields default to
    empty lists; ``done`` and ``approved`` default to ``False``.
    """

    task_id: str
    step_count: int = Field(default=0, ge=0)
    issues_identified: list[IdentifiedIssue] = Field(default_factory=list)
    total_reward: float = 0.0
    done: bool = False
    approved: bool = False
    fixes_suggested: list[str] = Field(default_factory=list)
    false_positive_count: int = Field(default=0, ge=0)
    # None until set; when set it must lie strictly between 0 and 1.
    # NOTE(review): the exclusive bounds reject exactly 0.0 and 1.0 - confirm
    # that is intended.
    final_score: float | None = Field(default=None, gt=0.0, lt=1.0)
class StepResult(StrictModel):
    """Bundle of an observation, a reward, a done flag and a free-form info dict."""

    observation: SQLReviewObservation
    reward: float
    done: bool
    info: dict[str, Any] = Field(default_factory=dict)
class ResetRequest(StrictModel):
    """Reset request carrying an optional ``task_id`` (``None`` when omitted)."""

    task_id: str | None = None