Abhinav Singh committed on
Commit
c540c55
·
1 Parent(s): cc0da5c

feat(data): add Pydantic models and task definitions


models.py:
- Observation: task metadata + sql_query + schema_info + dialect
- Action: suggestions list + optimized_query + summary +
estimated_improvement + approved flag
- Reward: score (0-1), per-criterion breakdown dict, feedback str
- StepResult and EnvironmentState for REST API responses

tasks.py:
- task_1_basic_antipatterns (easy, 3 steps): SELECT *, non-SARGable
CAST/YEAR() predicates blocking index usage on a 5M-row orders table
- task_2_join_optimization (medium, 4 steps): 3 correlated subqueries
causing N+1 pattern across users/orders/products (10M+ row tables),
missing index on filter column, unindexed ORDER BY
- task_3_advanced_optimization (hard, 5 steps): JSONB arrow expression
killing index, CTE over-materialization, window function sort cost,
implicit ::text cast preventing index use, autovacuum bloat risk,
HAVING without pre-filter on 500M-row events table
- get_task_list() for /tasks endpoint with full action schema
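
To make the Action schema described above concrete, here is a hypothetical agent payload for task 1; the values are illustrative only and are not part of this commit:

example_action = {
    "suggestions": [
        {
            "issue_type": "select_star",
            "line": 2,
            "description": "SELECT * fetches every column and prevents covering-index use.",
            "severity": "medium",
            "fix": "SELECT id, customer_id, status, total, created_at",
        },
        {
            "issue_type": "non_sargable_predicate",
            "line": 4,
            "description": "CAST(customer_id AS TEXT) disables idx_orders_customer_id.",
            "severity": "high",
            "fix": "WHERE customer_id = 12345",
        },
    ],
    "optimized_query": (
        "SELECT id, customer_id, status, total, created_at\n"
        "FROM orders\n"
        "WHERE customer_id = 12345\n"
        "  AND created_at >= '2024-01-01' AND created_at < '2025-01-01';"
    ),
    "summary": "Replaced SELECT * and rewrote both predicates to be SARGable.",
    "estimated_improvement": "Index scan instead of sequential scan; roughly 100x less I/O",
    "approved": False,
}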

Files changed (2)
  1. models.py +56 -0
  2. tasks.py +237 -0
models.py ADDED
@@ -0,0 +1,56 @@
+ from pydantic import BaseModel, Field
+ from typing import Any, Dict, List, Optional
+
+
+ class Observation(BaseModel):
+     task_id: str = Field(..., description="Unique task identifier")
+     task_name: str = Field(..., description="Human-readable task name")
+     task_description: str = Field(..., description="What the agent must do")
+     sql_query: str = Field(..., description="The SQL query to analyze/optimize")
+     schema_info: str = Field(..., description="Database schema context")
+     dialect: str = Field(default="postgresql", description="SQL dialect (postgresql, mysql, sqlite)")
+     difficulty: str = Field(..., description="easy | medium | hard")
+     step_count: int = Field(default=0, description="Steps taken in this episode")
+     max_steps: int = Field(default=5, description="Max steps per episode")
+     issues_found_so_far: List[str] = Field(default_factory=list, description="Issues agent has flagged so far")
+
+
+ class OptimizationSuggestion(BaseModel):
+     issue_type: str = Field(..., description="Type of issue (e.g. missing_index, n_plus_one, full_table_scan, etc.)")
+     line: Optional[int] = Field(None, description="Approximate line number in query")
+     description: str = Field(..., description="Detailed description of the issue")
+     severity: str = Field(..., description="critical | high | medium | low")
+     fix: str = Field(..., description="Suggested fix or rewrite")
+
+
+ class Action(BaseModel):
+     suggestions: List[Dict[str, Any]] = Field(
+         ...,
+         description="List of optimization suggestions. Each: {issue_type, line, description, severity, fix}"
+     )
+     optimized_query: str = Field(..., description="Rewritten/optimized version of the SQL query")
+     summary: str = Field(..., description="Overall analysis summary")
+     estimated_improvement: str = Field(..., description="Estimated performance improvement (e.g. '10x faster', '~50% less I/O')")
+     approved: bool = Field(..., description="Whether query is already optimal (True) or needs changes (False)")
+
+
+ class Reward(BaseModel):
+     score: float = Field(..., ge=0.0, le=1.0, description="Reward score 0.0-1.0")
+     breakdown: Dict[str, float] = Field(..., description="Per-criterion scores")
+     feedback: str = Field(..., description="Human-readable feedback on the action")
+
+
+ class StepResult(BaseModel):
+     observation: Observation
+     reward: Reward
+     done: bool
+     info: Dict[str, Any]
+
+
+ class EnvironmentState(BaseModel):
+     task_id: str
+     step_count: int
+     max_steps: int
+     episode_done: bool
+     cumulative_reward: float
+     current_task: str
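
For orientation only (not part of this commit): a minimal sketch of how a /step response could be assembled from these models. The field values are made up, and model_dump() assumes Pydantic v2 (use .dict() on v1).

from models import Observation, Reward, StepResult

obs = Observation(
    task_id="task_1_basic_antipatterns",
    task_name="Basic SQL Anti-pattern Detection",
    task_description="Analyze the SQL query below for common anti-patterns.",
    sql_query="SELECT * FROM orders WHERE CAST(customer_id AS TEXT) = '12345';",
    schema_info="Table: orders (...)",
    difficulty="easy",
    step_count=1,
    max_steps=3,
    issues_found_so_far=["select_star", "non_sargable_predicate"],
)
reward = Reward(
    score=0.7,
    breakdown={"issues_found": 0.66, "query_rewrite": 0.8, "summary": 0.7},
    feedback="2 of 3 ground-truth issues matched so far.",
)
result = StepResult(observation=obs, reward=reward, done=False, info={"matched_issues": 2})
print(result.model_dump())  # JSON-serializable dict for the REST response body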
tasks.py ADDED
@@ -0,0 +1,237 @@
+ from typing import Dict, Any, List
+
+ TASKS: Dict[str, Dict[str, Any]] = {
+
+     # ──────────────────────────────────────────────────────────────────
+     # TASK 1 - EASY: Basic Query Anti-pattern Detection
+     # ──────────────────────────────────────────────────────────────────
+     "task_1_basic_antipatterns": {
+         "task_id": "task_1_basic_antipatterns",
+         "task_name": "Basic SQL Anti-pattern Detection",
+         "task_description": (
+             "Analyze the SQL query below for common anti-patterns that hurt performance. "
+             "Identify issues such as: SELECT *, missing WHERE clauses causing full table scans, "
+             "implicit type conversions, and non-SARGable predicates that prevent index usage. "
+             "For each issue, report: issue_type, line number, description, severity (critical|high|medium|low), and a suggested fix."
+         ),
+         "difficulty": "easy",
+         "dialect": "postgresql",
+         "max_steps": 3,
+         "schema_info": """\
+ Table: orders (id SERIAL PK, customer_id INT FK, status VARCHAR(20), total DECIMAL(10,2), created_at TIMESTAMPTZ)
+ Index: idx_orders_customer_id ON orders(customer_id)
+ Index: idx_orders_created_at ON orders(created_at)
+ Table size: ~5 million rows
+ """,
+         "sql_query": """\
+ -- Fetch recent orders for reporting dashboard
+ SELECT *
+ FROM orders
+ WHERE CAST(customer_id AS TEXT) = '12345'
+   AND YEAR(created_at) = 2024;
+ """,
+         "ground_truth_issues": [
+             {
+                 "type": "select_star",
+                 "line": 2,
+                 "keywords": ["select *", "select star", "all columns", "specific columns", "unnecessary columns", "bandwidth"]
+             },
+             {
+                 "type": "non_sargable_predicate",
+                 "line": 4,
+                 "keywords": ["cast", "convert", "non-sargable", "sargable", "index", "function on column", "type conversion", "implicit"]
+             },
+             {
+                 "type": "non_sargable_predicate",
+                 "line": 5,
+                 "keywords": ["year(", "function on column", "non-sargable", "index", "date range", "between", "extract"]
+             },
+         ],
+         "approved_expected": False,
+     },
+
+     # ──────────────────────────────────────────────────────────────────
+     # TASK 2 - MEDIUM: N+1 Query and Join Optimization
+     # ──────────────────────────────────────────────────────────────────
+     "task_2_join_optimization": {
+         "task_id": "task_2_join_optimization",
+         "task_name": "N+1 Pattern & Join Optimization",
+         "task_description": (
+             "Review the SQL query below for join performance issues and N+1 query patterns. "
+             "Identify: missing indexes on join columns, inefficient subquery patterns that could be CTEs or JOINs, "
+             "correlated subqueries executing per-row, missing covering indexes, and cartesian join risks. "
+             "For each issue, report issue_type, line, description, severity, and a specific fix."
+         ),
+         "difficulty": "medium",
+         "dialect": "postgresql",
+         "max_steps": 4,
+         "schema_info": """\
+ Table: users (id SERIAL PK, email VARCHAR UNIQUE, tier VARCHAR(10), region VARCHAR(50), created_at TIMESTAMPTZ)
+ Table: orders (id SERIAL PK, user_id INT FK->users.id, product_id INT FK->products.id, amount DECIMAL, placed_at TIMESTAMPTZ, status VARCHAR(20))
+ Table: products (id SERIAL PK, name VARCHAR, category VARCHAR(50), price DECIMAL)
+ Table: order_items (id SERIAL PK, order_id INT FK->orders.id, product_id INT FK->products.id, qty INT, unit_price DECIMAL)
+ Indexes: users(id) PK, orders(user_id), products(id) PK
+ No index on: orders(product_id), orders(status), order_items(order_id)
+ Approximate sizes: users=500k rows, orders=10M rows, order_items=40M rows, products=50k rows
+ """,
+         "sql_query": """\
+ SELECT
+     u.email,
+     u.tier,
+     (SELECT COUNT(*) FROM orders o WHERE o.user_id = u.id) AS order_count,
+     (SELECT SUM(o.amount) FROM orders o WHERE o.user_id = u.id AND o.status = 'completed') AS total_spent,
+     (SELECT MAX(o.placed_at) FROM orders o WHERE o.user_id = u.id) AS last_order_date
+ FROM users u
+ WHERE u.region = 'US'
+   AND u.created_at > '2023-01-01'
+ ORDER BY total_spent DESC
+ LIMIT 100;
+ """,
+         "ground_truth_issues": [
+             {
+                 "type": "correlated_subquery",
+                 "line": 4,
+                 "keywords": ["correlated", "subquery", "per row", "n+1", "repeated", "each user", "lateral", "join"]
+             },
+             {
+                 "type": "correlated_subquery",
+                 "line": 5,
+                 "keywords": ["correlated", "subquery", "per row", "n+1", "repeated", "each user", "lateral", "join"]
+             },
+             {
+                 "type": "correlated_subquery",
+                 "line": 6,
+                 "keywords": ["correlated", "subquery", "per row", "n+1", "repeated", "each user", "lateral", "join"]
+             },
+             {
+                 "type": "missing_index",
+                 "line": 8,
+                 "keywords": ["missing index", "no index", "region", "full scan", "index on region", "composite"]
+             },
+             {
+                 "type": "sort_without_index",
+                 "line": 10,
+                 "keywords": ["order by", "sort", "filesort", "index", "total_spent", "computed", "no index for sort"]
+             },
+         ],
+         "approved_expected": False,
+     },
+
+     # ──────────────────────────────────────────────────────────────────
+     # TASK 3 - HARD: Complex Aggregation & Window Function Audit
+     # ──────────────────────────────────────────────────────────────────
+     "task_3_advanced_optimization": {
+         "task_id": "task_3_advanced_optimization",
+         "task_name": "Advanced Query & Window Function Audit",
+         "task_description": (
+             "Perform a deep performance audit of the complex analytical SQL query below. "
+             "Identify: missing partition/covering indexes for window functions, "
+             "inefficient GROUP BY with HAVING that could be pre-filtered, "
+             "implicit data type coercions preventing index usage, "
+             "redundant subqueries or CTEs that materialize too early, "
+             "missing query hints or planner directives, "
+             "and lock contention risks from large aggregations on live tables. "
+             "For each issue report: issue_type, line, severity (critical|high|medium|low), description, and a concrete fix."
+         ),
+         "difficulty": "hard",
+         "dialect": "postgresql",
+         "max_steps": 5,
+         "schema_info": """\
+ Table: events (id BIGSERIAL PK, user_id INT, session_id UUID, event_type VARCHAR(50), properties JSONB, occurred_at TIMESTAMPTZ)
+ Table: sessions (id UUID PK, user_id INT, started_at TIMESTAMPTZ, ended_at TIMESTAMPTZ, device VARCHAR(30))
+ Table: users (id INT PK, plan VARCHAR(20), country VARCHAR(3), created_at DATE)
+ Indexes: events(user_id, occurred_at), events(session_id), sessions(user_id)
+ No index on: events(event_type), events(occurred_at) standalone, users(plan, country)
+ Table sizes: events=500M rows, sessions=50M rows, users=2M rows
+ Autovacuum lag: events table has ~10% dead tuples
+ """,
+         "sql_query": """\
+ WITH user_sessions AS (
+     SELECT
+         e.user_id,
+         e.session_id,
+         COUNT(*) AS event_count,
+         SUM(CASE WHEN e.event_type = 'purchase' THEN 1 ELSE 0 END) AS purchases,
+         MIN(e.occurred_at) AS session_start,
+         MAX(e.occurred_at) AS session_end
+     FROM events e
+     JOIN sessions s ON s.id = e.session_id
+     WHERE e.occurred_at BETWEEN '2024-01-01' AND '2024-12-31'
+       AND properties->>'plan' = 'premium'
+     GROUP BY e.user_id, e.session_id
+ ),
+ ranked_sessions AS (
+     SELECT
+         *,
+         ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY purchases DESC, session_end DESC) AS rn,
+         AVG(event_count) OVER (PARTITION BY user_id) AS avg_events_per_session
+     FROM user_sessions
+ )
+ SELECT
+     u.country,
+     u.plan,
+     AVG(rs.purchases) AS avg_purchases,
+     COUNT(DISTINCT rs.user_id) AS active_users,
+     PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY rs.event_count) AS median_events
+ FROM ranked_sessions rs
+ JOIN users u ON u.id = rs.user_id
+ WHERE rs.rn = 1
+   AND u.plan::text IN ('premium', 'enterprise')
+ GROUP BY u.country, u.plan
+ HAVING COUNT(DISTINCT rs.user_id) > 10
+ ORDER BY avg_purchases DESC;
+ """,
+         "ground_truth_issues": [
+             {
+                 "type": "json_extraction_kills_index",
+                 "line": 10,
+                 "keywords": ["jsonb", "properties->", "arrow", "json", "index", "expression index", "gin", "no index", "json field"]
+             },
+             {
+                 "type": "redundant_cte_materialization",
+                 "line": 1,
+                 "keywords": ["cte", "materialize", "materialized", "inline", "common table expression", "scan twice", "performance"]
+             },
+             {
+                 "type": "window_function_missing_index",
+                 "line": 16,
+                 "keywords": ["row_number", "window", "partition", "index", "sort", "covering index", "partition by user_id"]
+             },
+             {
+                 "type": "implicit_cast_prevents_index",
+                 "line": 28,
+                 "keywords": ["cast", "::text", "implicit", "coerce", "index", "type cast", "data type", "prevent"]
+             },
+             {
+                 "type": "vacuum_bloat_risk",
+                 "line": 8,
+                 "keywords": ["vacuum", "dead tuple", "bloat", "autovacuum", "table bloat", "live rows", "500M", "performance"]
+             },
+             {
+                 "type": "having_without_pre_filter",
+                 "line": 30,
+                 "keywords": ["having", "group by", "pre-filter", "where", "filter before", "aggregate", "subquery push"]
+             },
+         ],
+         "approved_expected": False,
+     },
+ }
+
+
+ def get_task_list() -> List[Dict[str, Any]]:
+     return [
+         {
+             "task_id": t["task_id"],
+             "task_name": t["task_name"],
+             "difficulty": t["difficulty"],
+             "description": t["task_description"],
+             "action_schema": {
+                 "suggestions": "List of {issue_type: str, line: int, description: str, severity: str, fix: str}",
+                 "optimized_query": "str - rewritten SQL query with improvements",
+                 "summary": "str - overall analysis summary",
+                 "estimated_improvement": "str - expected performance gain",
+                 "approved": "bool - whether query is already optimal (True) or not (False)"
+             }
+         }
+         for t in TASKS.values()
+     ]
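
Also not part of this commit, but for context: a rough sketch of how a scorer could turn the ground_truth_issues keywords into the Reward defined in models.py. The function name, keyword-matching rule, and single-criterion breakdown are assumptions for illustration only.

from typing import Any, Dict, List

from models import Reward
from tasks import TASKS


def score_suggestions(task_id: str, suggestions: List[Dict[str, Any]]) -> Reward:
    # Hypothetical rule: an issue counts as "found" if any of its keywords
    # appears in a suggestion's description or fix (case-insensitive).
    issues = TASKS[task_id]["ground_truth_issues"]
    texts = [
        (s.get("description", "") + " " + s.get("fix", "")).lower()
        for s in suggestions
    ]
    found = sum(
        1 for issue in issues
        if any(kw in text for kw in issue["keywords"] for text in texts)
    )
    coverage = found / len(issues) if issues else 0.0
    return Reward(
        score=coverage,  # a real scorer would likely also weight the rewritten query and summary
        breakdown={"issues_found": coverage},
        feedback=f"Matched {found} of {len(issues)} ground-truth issues by keyword.",
    )


# Example usage:
# reward = score_suggestions(
#     "task_1_basic_antipatterns",
#     [{"issue_type": "select_star", "description": "SELECT * returns unnecessary columns",
#       "severity": "medium", "fix": "Select specific columns"}],
# )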