Abhinav Singh committed on
Commit
7841be7
·
1 Parent(s): c540c55

feat(core): add grader and SQLOptimEnv environment class

Browse files

graders.py β€” 6-component composite reward function:
1. Issue Detection 60% keyword-match against ground truth issues
2. Optimized Query 15% length + anti-pattern removal heuristics
3. Approval Correct 10% bool match vs. approved_expected
4. Summary Quality 8% progressive scoring on summary length
5. Improvement Est. 4% keyword-match on estimated_improvement field
6. Severity Labels 3% checks severity values are present
Minimum reward of 0.02 for any non-empty submission (partial signal)

env.py β€” SQLOptimEnv class:
- reset(task_id): validates task, initialises episode state, returns Observation
- step(action): grades action, tracks issues_found_so_far, returns StepResult
- state(): returns EnvironmentState snapshot without advancing episode
- Episode terminates on max_steps OR reward >= 0.95 (early exit)

Files changed (2) hide show
  1. env.py +109 -0
  2. graders.py +126 -0
env.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ from models import Observation, Action, Reward, StepResult, EnvironmentState
3
+ from tasks import TASKS
4
+ from graders import grade
5
+
6
+
7
+ class SQLOptimEnv:
8
+ """
9
+ OpenEnv-compliant environment for SQL Query Optimization.
10
+
11
+ An AI agent iteratively analyzes a SQL query, identifies performance issues,
12
+ and submits optimized rewrites. The environment grades each action and tracks
13
+ progress across multiple steps within an episode.
14
+ """
15
+
16
+ def __init__(self):
17
+ self._task_data: Optional[dict] = None
18
+ self._step_count: int = 0
19
+ self._done: bool = False
20
+ self._cumulative_reward: float = 0.0
21
+ self._issues_found: list = []
22
+
23
+ def reset(self, task_id: str = "task_1_basic_antipatterns") -> Observation:
24
+ """Start a new episode for the given task."""
25
+ if task_id not in TASKS:
26
+ raise ValueError(
27
+ f"Unknown task_id '{task_id}'. "
28
+ f"Valid tasks: {list(TASKS.keys())}"
29
+ )
30
+ self._task_data = TASKS[task_id]
31
+ self._step_count = 0
32
+ self._done = False
33
+ self._cumulative_reward = 0.0
34
+ self._issues_found = []
35
+
36
+ return self._make_observation()
37
+
38
+ def step(self, action: Action) -> StepResult:
39
+ """Process one agent action and return (observation, reward, done, info)."""
40
+ if self._task_data is None:
41
+ raise RuntimeError("Episode not started. Call reset() first.")
42
+ if self._done:
43
+ raise RuntimeError("Episode already finished. Call reset() to start a new episode.")
44
+
45
+ self._step_count += 1
46
+
47
+ # Grade the action
48
+ reward: Reward = grade(self._task_data, action)
49
+ self._cumulative_reward += reward.score
50
+
51
+ # Track issue types found so far
52
+ for s in action.suggestions:
53
+ issue_type = s.get("issue_type", "")
54
+ if issue_type and issue_type not in self._issues_found:
55
+ self._issues_found.append(issue_type)
56
+
57
+ # Episode ends when max_steps reached OR agent finds a perfect score
58
+ max_steps = self._task_data["max_steps"]
59
+ done = self._step_count >= max_steps or reward.score >= 0.95
60
+
61
+ self._done = done
62
+
63
+ obs = self._make_observation()
64
+
65
+ return StepResult(
66
+ observation=obs,
67
+ reward=reward,
68
+ done=done,
69
+ info={
70
+ "step": self._step_count,
71
+ "cumulative_reward": round(self._cumulative_reward, 4),
72
+ "issues_found_count": len(self._issues_found),
73
+ }
74
+ )
75
+
76
+ def state(self) -> EnvironmentState:
77
+ """Return current environment state (for /state endpoint)."""
78
+ if self._task_data is None:
79
+ return EnvironmentState(
80
+ task_id="none",
81
+ step_count=0,
82
+ max_steps=0,
83
+ episode_done=True,
84
+ cumulative_reward=0.0,
85
+ current_task="No active episode"
86
+ )
87
+ return EnvironmentState(
88
+ task_id=self._task_data["task_id"],
89
+ step_count=self._step_count,
90
+ max_steps=self._task_data["max_steps"],
91
+ episode_done=self._done,
92
+ cumulative_reward=round(self._cumulative_reward, 4),
93
+ current_task=self._task_data["task_name"],
94
+ )
95
+
96
+ def _make_observation(self) -> Observation:
97
+ d = self._task_data
98
+ return Observation(
99
+ task_id=d["task_id"],
100
+ task_name=d["task_name"],
101
+ task_description=d["task_description"],
102
+ sql_query=d["sql_query"],
103
+ schema_info=d["schema_info"],
104
+ dialect=d.get("dialect", "postgresql"),
105
+ difficulty=d["difficulty"],
106
+ step_count=self._step_count,
107
+ max_steps=d["max_steps"],
108
+ issues_found_so_far=list(self._issues_found),
109
+ )
graders.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any, List
2
+ from models import Action, Reward
3
+
4
+
5
+ def _keyword_match(text: str, keywords: List[str]) -> bool:
6
+ """Check if any keyword appears in text (case-insensitive)."""
7
+ text_lower = text.lower()
8
+ return any(kw.lower() in text_lower for kw in keywords)
9
+
10
+
11
def _suggestions_text(action: Action) -> str:
    """Flatten the summary, optimized query, estimate, and every suggestion field
    into a single space-joined string for keyword searching."""
    suggestion_fields = ("issue_type", "description", "fix", "line", "severity")
    parts = [action.summary, action.optimized_query, action.estimated_improvement]
    for suggestion in action.suggestions:
        parts.extend(str(suggestion.get(field, "")) for field in suggestion_fields)
    return " ".join(parts)
21
+
22
+
23
def grade(task_data: Dict[str, Any], action: Action) -> Reward:
    """
    Grade an agent's SQL optimization action against ground truth issues.

    Args:
        task_data: Task definition dict. Reads "ground_truth_issues" (a list of
            dicts each with "keywords", "type", "line"), "sql_query", and
            optionally "approved_expected" (defaults to False).
        action: The agent's submission: suggestions, optimized_query, summary,
            approved flag, and estimated_improvement.

    Returns:
        A Reward with a score in [0, 1], a per-component breakdown, and
        human-readable feedback.

    Scoring breakdown:
    - Issue Detection: 60% (did agent find the right problems?)
    - Optimized Query Quality: 15% (did agent provide a meaningful rewrite?)
    - Approval Correctness: 10% (correctly flagged as needing changes?)
    - Summary Quality: 8% (is the summary thorough and informative?)
    - Improvement Estimate: 4% (did agent quantify the expected gain?)
    - Severity Labels: 3% (are severity levels present?)
    """
    ground_truth: List[Dict[str, Any]] = task_data["ground_truth_issues"]
    full_text = _suggestions_text(action)

    # -- 1. Issue Detection Score (0.0-0.60) --------------------------------
    detected = 0
    detection_feedback = []
    for gt_issue in ground_truth:
        found = _keyword_match(full_text, gt_issue["keywords"])
        if found:
            detected += 1
            detection_feedback.append(f"✅ Found: {gt_issue['type']} (line ~{gt_issue['line']})")
        else:
            detection_feedback.append(f"❌ Missed: {gt_issue['type']} (line ~{gt_issue['line']})")

    # FIX: guard the division — a task with no ground-truth issues previously
    # raised ZeroDivisionError; such a task now simply contributes 0.
    detection_score = (detected / len(ground_truth)) * 0.60 if ground_truth else 0.0

    # -- 2. Optimized Query Quality (0.0-0.15) ------------------------------
    # Length-based heuristic tiers, plus bonuses for removing anti-patterns.
    query_score = 0.0
    oq = action.optimized_query.strip()
    if len(oq) > 50:
        query_score = 0.05
    if len(oq) > 150:
        query_score = 0.10
    # Bonus if the rewrite removes obvious anti-patterns found in original
    original_query = task_data["sql_query"].lower()
    if "select *" in original_query and "select *" not in oq.lower():
        query_score = min(query_score + 0.03, 0.15)
    if query_score < 0.15 and len(action.suggestions) > 0 and len(oq) > 100:
        query_score = min(query_score + 0.02, 0.15)
    query_score = min(query_score, 0.15)

    # -- 3. Approval Correctness (0.0-0.10) ---------------------------------
    expected_approved = task_data.get("approved_expected", False)
    approval_score = 0.10 if action.approved == expected_approved else 0.0

    # -- 4. Summary Quality (0.0-0.08) --------------------------------------
    # Progressive scoring on summary length only (no content analysis).
    summary_score = 0.0
    if len(action.summary) > 40:
        summary_score = 0.04
    if len(action.summary) > 100:
        summary_score = 0.08

    # -- 5. Improvement Estimate Present (0.0-0.04) -------------------------
    improvement_keywords = ["x faster", "% less", "% faster", "% improvement", "times", "reduce", "improvement", "speedup"]
    has_estimate = _keyword_match(action.estimated_improvement, improvement_keywords) and len(action.estimated_improvement) > 5
    improvement_score = 0.04 if has_estimate else 0.0

    # -- 6. Severity Labels Present (0.0-0.03) ------------------------------
    severity_keywords = ["critical", "high", "medium", "low"]
    has_severity = any(
        _keyword_match(str(s.get("severity", "")), severity_keywords)
        for s in action.suggestions
    )
    severity_score = 0.03 if has_severity else 0.0

    # -- Final Score --------------------------------------------------------
    total = (
        detection_score + query_score + approval_score +
        summary_score + improvement_score + severity_score
    )
    total = round(min(max(total, 0.0), 1.0), 4)

    # Minimum signal for any submission with suggestions. Float equality is
    # intentional: a fully-missed submission scores exactly 0.0 after rounding.
    if total == 0.0 and len(action.suggestions) > 0:
        total = 0.02

    # NOTE: breakdown reflects pre-floor component scores, so it may not sum
    # to `total` when the 0.02 minimum-signal floor was applied.
    breakdown = {
        "issue_detection": round(detection_score, 4),
        "optimized_query": round(query_score, 4),
        "approval_correctness": round(approval_score, 4),
        "summary_quality": round(summary_score, 4),
        "improvement_estimate": round(improvement_score, 4),
        "severity_labels": round(severity_score, 4),
    }

    n_suggestions = len(action.suggestions)
    expected_n = len(ground_truth)

    feedback_lines = detection_feedback + [
        f"\nSuggestions submitted: {n_suggestions} (expected ~{expected_n})",
        f"Optimized query length: {len(oq)} chars",
        f"Approval correctness: {'✅' if action.approved == expected_approved else '❌'} "
        f"(you said {'approved' if action.approved else 'needs changes'}, "
        f"expected {'approved' if expected_approved else 'needs changes'})",
        f"Total score: {total:.4f}",
    ]

    return Reward(
        score=total,
        breakdown=breakdown,
        feedback="\n".join(feedback_lines)
    )