ar9avg commited on
Commit
3c665d2
Β·
1 Parent(s): d796343

Initial submission: SQL Agent OpenEnv for Meta+HF hackathon

Browse files
Files changed (49) hide show
  1. .gitignore +8 -0
  2. Dockerfile +58 -0
  3. README.md +13 -0
  4. backend/api/__init__.py +0 -0
  5. backend/api/demo.py +495 -0
  6. backend/api/openenv.py +138 -0
  7. backend/data/.gitkeep +0 -0
  8. backend/data/benchmark.db +0 -0
  9. backend/env/__init__.py +0 -0
  10. backend/env/database.py +430 -0
  11. backend/env/sql_env.py +594 -0
  12. backend/env/tasks.py +345 -0
  13. backend/gepa/__init__.py +0 -0
  14. backend/gepa/optimizer.py +347 -0
  15. backend/main.py +104 -0
  16. backend/requirements.txt +9 -0
  17. backend/rl/__init__.py +0 -0
  18. backend/rl/environment.py +266 -0
  19. backend/rl/error_classifier.py +98 -0
  20. backend/rl/experience.py +208 -0
  21. backend/rl/grader.py +99 -0
  22. backend/rl/linucb.py +190 -0
  23. backend/rl/repair_strategies.py +219 -0
  24. backend/rl/types.py +161 -0
  25. frontend/index.html +14 -0
  26. frontend/package-lock.json +0 -0
  27. frontend/package.json +30 -0
  28. frontend/postcss.config.js +6 -0
  29. frontend/src/App.tsx +179 -0
  30. frontend/src/components/BenchmarkPanel.tsx +384 -0
  31. frontend/src/components/ChatPanel.tsx +599 -0
  32. frontend/src/components/ERDiagram.tsx +234 -0
  33. frontend/src/components/Header.tsx +110 -0
  34. frontend/src/components/LeftSidebar.tsx +157 -0
  35. frontend/src/components/PerformanceGraph.tsx +175 -0
  36. frontend/src/components/PromptEvolution.tsx +148 -0
  37. frontend/src/components/ResultsTable.tsx +78 -0
  38. frontend/src/components/RightSidebar.tsx +27 -0
  39. frontend/src/index.css +187 -0
  40. frontend/src/lib/api.ts +97 -0
  41. frontend/src/lib/types.ts +131 -0
  42. frontend/src/main.tsx +19 -0
  43. frontend/src/store/useStore.ts +175 -0
  44. frontend/src/vite-env.d.ts +9 -0
  45. frontend/tailwind.config.js +20 -0
  46. frontend/tsconfig.json +24 -0
  47. frontend/vite.config.ts +28 -0
  48. inference.py +230 -0
  49. openenv.yaml +137 -0
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ .venv/
2
+ __pycache__/
3
+ *.pyc
4
+ backend/data/rl_weights.json
5
+ backend/data/rl_experiences.json
6
+ backend/data/gepa_prompt.json
7
+ node_modules/
8
+ frontend/dist/
Dockerfile ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SQL Agent OpenEnv β€” Docker build for Hugging Face Spaces
2
+ #
3
+ # Stage 1: Build React frontend
4
+ # Stage 2: Python FastAPI app serving both the API and static UI
5
+ #
6
+ # HF Spaces expects the app to listen on port 7860.
7
+
8
+ # ── Stage 1: Frontend build ───────────────────────────────────────────────────
9
+ FROM node:20-slim AS frontend-builder
10
+
11
+ WORKDIR /app/frontend
12
+
13
+ # Install deps first (layer cache)
14
+ COPY frontend/package.json frontend/package-lock.json* ./
15
+ RUN npm ci --prefer-offline --no-audit
16
+
17
+ # Build the app
18
+ COPY frontend/ ./
19
+ RUN npm run build
20
+
21
+
22
+ # ── Stage 2: Python runtime ───────────────────────────────────────────────────
23
+ FROM python:3.11-slim
24
+
25
+ # System deps
26
+ RUN apt-get update && apt-get install -y --no-install-recommends \
27
+ gcc \
28
+ && rm -rf /var/lib/apt/lists/*
29
+
30
+ WORKDIR /app
31
+
32
+ # Python deps
33
+ COPY backend/requirements.txt ./requirements.txt
34
+ RUN pip install --no-cache-dir -r requirements.txt
35
+
36
+ # Copy backend source
37
+ COPY backend/ ./backend/
38
+
39
+ # Copy built frontend
40
+ COPY --from=frontend-builder /app/frontend/dist ./frontend/dist
41
+
42
+ # Copy repo-root artefacts
43
+ COPY inference.py openenv.yaml README.md ./
44
+
45
+ # Ensure data dir exists (RL weights, GEPA prompts, SQLite DB)
46
+ RUN mkdir -p ./backend/data
47
+
48
+ # ── HF Spaces config ──────────────────────────────────────────────────────────
49
+ EXPOSE 7860
50
+
51
+ ENV PORT=7860 \
52
+ PYTHONUNBUFFERED=1 \
53
+ PYTHONDONTWRITEBYTECODE=1
54
+
55
+ # Run from backend/ so relative imports and data/ paths resolve correctly
56
+ WORKDIR /app/backend
57
+
58
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
README.md CHANGED
@@ -1,4 +1,17 @@
1
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  title: Sql Agent Openenv
3
  emoji: 🏒
4
  colorFrom: yellow
 
1
  ---
2
+ title: SQL Agent OpenEnv
3
+ emoji: πŸ—„οΈ
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: docker
7
+ pinned: false
8
+ tags:
9
+ - openenv
10
+ - sql
11
+ - reinforcement-learning
12
+ - contextual-bandit
13
+ ---
14
+ ---
15
  title: Sql Agent Openenv
16
  emoji: 🏒
17
  colorFrom: yellow
backend/api/__init__.py ADDED
File without changes
backend/api/demo.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Demo API routes β€” streaming SSE endpoints matching the original TypeScript API.
3
+
4
+ Routes:
5
+ GET /api/init
6
+ POST /api/execute-query (SSE)
7
+ POST /api/benchmark (SSE)
8
+ GET /api/rl-state
9
+ GET /api/schema-graph
10
+ POST /api/feedback
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import asyncio
16
+ import json
17
+ import time
18
+ from typing import AsyncIterator, Optional
19
+
20
+ from fastapi import APIRouter
21
+ from pydantic import BaseModel
22
+ from sse_starlette.sse import EventSourceResponse
23
+
24
+ from env.database import (
25
+ ensure_seeded,
26
+ get_table_stats,
27
+ get_schema_info,
28
+ get_schema_graph,
29
+ execute_query,
30
+ )
31
+ from env.tasks import TASKS, get_task
32
+ from env.sql_env import SQLAgentEnv, Action, get_env, BASE_SYSTEM_PROMPT, _clean_sql
33
+ from rl.environment import get_bandit_state
34
+ from rl.types import RepairAction, REPAIR_ACTION_NAMES, REPAIR_ACTION_BY_NAME
35
+ from rl.error_classifier import classify_error, extract_offending_token
36
+ from rl.grader import GraderInput, compute_reward, compute_episode_reward
37
+ from rl.types import RLState, EpisodeStep, featurize, ERROR_CLASS_NAMES
38
+ from gepa.optimizer import get_gepa, QueryResult
39
+
40
+ router = APIRouter()
41
+
42
+
43
+ # ─── /api/init ────────────────────────────────────────────────────
44
+
45
+ @router.get("/init")
46
+ async def init_db():
47
+ seeded = ensure_seeded()
48
+ tables = get_table_stats()
49
+ return {"tables": tables, "seeded": seeded}
50
+
51
+
52
+ # ─── /api/execute-query ───────────────────────────────────────────
53
+
54
+ class ExecuteQueryRequest(BaseModel):
55
+ question: str
56
+ task_id: str = "simple_queries"
57
+
58
+
59
+ @router.post("/execute-query")
60
+ async def execute_query_stream(req: ExecuteQueryRequest):
61
+ async def event_generator() -> AsyncIterator[dict]:
62
+ env = get_env()
63
+ obs = env.reset(req.task_id)
64
+
65
+ # Pick first question of task matching question text, or default
66
+ task = get_task(req.task_id)
67
+ question_obj = task.questions[0]
68
+ # Override question text
69
+ env._episode.question = req.question # type: ignore[union-attr]
70
+
71
+ max_attempts = env.MAX_ATTEMPTS
72
+ done = False
73
+ all_step_rewards: list[float] = []
74
+ success = False
75
+
76
+ # Initial generate action
77
+ action = Action(repair_action="generate")
78
+
79
+ for attempt in range(1, max_attempts + 1):
80
+ yield {"data": json.dumps({"type": "attempt_start", "attempt": attempt})}
81
+
82
+ ep = env._episode # type: ignore[union-attr]
83
+ ep.attempt_number = attempt
84
+
85
+ # Generate SQL with streaming
86
+ from env.sql_env import _make_client, _MODEL
87
+ from openai import AsyncOpenAI
88
+
89
+ if attempt == 1 or ep.current_sql is None:
90
+ system_prompt = BASE_SYSTEM_PROMPT
91
+ user_msg = (
92
+ f"Schema:\n{obs.schema_info}\n\nQuestion: {req.question}\n\n"
93
+ "Write a SQL query to answer this question."
94
+ )
95
+ else:
96
+ from rl.repair_strategies import RepairContext, get_repair_system_suffix, build_repair_user_message
97
+ from env.sql_env import REPAIR_ACTION_BY_NAME
98
+
99
+ # Bandit selects action
100
+ if ep.current_features is not None:
101
+ repair_enum, scores = env._bandit.select_action(ep.current_features)
102
+ ucb_scores = {
103
+ REPAIR_ACTION_NAMES[RepairAction(i)]: round(scores[i], 4)
104
+ for i in range(len(scores))
105
+ }
106
+ action = Action(repair_action=REPAIR_ACTION_NAMES[repair_enum])
107
+ yield {"data": json.dumps({
108
+ "type": "rl_action",
109
+ "action": action.repair_action,
110
+ "ucb_scores": ucb_scores,
111
+ })}
112
+ else:
113
+ repair_enum = RepairAction.REWRITE_FULL
114
+ action = Action(repair_action="rewrite_full")
115
+
116
+ suffix = get_repair_system_suffix(repair_enum)
117
+ offending = extract_offending_token(ep.error_message or "")
118
+ ctx = RepairContext(
119
+ schema=obs.schema_info,
120
+ question=req.question,
121
+ failing_sql=ep.current_sql or "",
122
+ error_message=ep.error_message or "",
123
+ offending_token=offending,
124
+ )
125
+ system_prompt = BASE_SYSTEM_PROMPT + suffix
126
+ user_msg = build_repair_user_message(repair_enum, ctx)
127
+
128
+ # Stream SQL generation
129
+ client = _make_client()
130
+ chunks: list[str] = []
131
+ try:
132
+ stream = await client.chat.completions.create(
133
+ model=_MODEL,
134
+ messages=[
135
+ {"role": "system", "content": system_prompt},
136
+ {"role": "user", "content": user_msg},
137
+ ],
138
+ stream=True,
139
+ temperature=0.1,
140
+ )
141
+ async for chunk in stream:
142
+ delta = chunk.choices[0].delta.content
143
+ if delta:
144
+ chunks.append(delta)
145
+ yield {"data": json.dumps({"type": "sql_chunk", "chunk": delta})}
146
+ except Exception as e:
147
+ yield {"data": json.dumps({"type": "error", "error": str(e), "error_class": "other"})}
148
+ break
149
+
150
+ generated_sql = _clean_sql("".join(chunks))
151
+ yield {"data": json.dumps({"type": "sql_complete", "sql": generated_sql})}
152
+ yield {"data": json.dumps({"type": "executing"})}
153
+
154
+ rows, error = execute_query(generated_sql)
155
+
156
+ from env.tasks import grade_response
157
+ task_score = grade_response(
158
+ req.task_id, question_obj.id, generated_sql, rows, error, attempt
159
+ )
160
+ attempt_success = task_score >= 0.8
161
+
162
+ current_error_class = None
163
+ error_class_name = None
164
+
165
+ if error:
166
+ ec = classify_error(error)
167
+ current_error_class = ec
168
+ error_class_name = ERROR_CLASS_NAMES[ec]
169
+
170
+ error_changed = (
171
+ ep.previous_error_class is not None
172
+ and ep.previous_error_class != current_error_class
173
+ )
174
+ if ep.previous_error_class == current_error_class:
175
+ ep.consecutive_same_error += 1
176
+ else:
177
+ ep.consecutive_same_error = 1
178
+
179
+ rl_state = RLState(
180
+ error_class=current_error_class,
181
+ attempt_number=attempt,
182
+ previous_action=ep.last_action,
183
+ error_changed=error_changed,
184
+ consecutive_same_error=ep.consecutive_same_error,
185
+ )
186
+ ep.current_rl_state = rl_state
187
+ ep.current_features = featurize(rl_state)
188
+
189
+ # Stream diagnosis chunk
190
+ try:
191
+ diag_stream = await client.chat.completions.create(
192
+ model=_MODEL,
193
+ messages=[
194
+ {"role": "system", "content": "You are a SQL debugger. Briefly explain the error in one sentence."},
195
+ {"role": "user", "content": f"Error: {error}\nSQL: {generated_sql}"},
196
+ ],
197
+ stream=True,
198
+ temperature=0.3,
199
+ )
200
+ async for chunk in diag_stream:
201
+ delta = chunk.choices[0].delta.content
202
+ if delta:
203
+ yield {"data": json.dumps({"type": "diagnosis_chunk", "chunk": delta})}
204
+ except Exception:
205
+ pass
206
+
207
+ yield {"data": json.dumps({"type": "error", "error": error, "error_class": error_class_name})}
208
+
209
+ # Grader + RL update
210
+ grader_in = GraderInput(
211
+ success=attempt_success,
212
+ attempt_number=attempt,
213
+ current_error_class=current_error_class,
214
+ previous_error_class=ep.previous_error_class,
215
+ )
216
+ grader_out = compute_reward(grader_in)
217
+ all_step_rewards.append(grader_out.reward)
218
+
219
+ if ep.current_rl_state and ep.current_features:
220
+ repair_enum_for_step = REPAIR_ACTION_BY_NAME.get(
221
+ action.repair_action, RepairAction.REWRITE_FULL
222
+ )
223
+ step_obj = EpisodeStep(
224
+ state=ep.current_rl_state,
225
+ featurized=ep.current_features,
226
+ action=repair_enum_for_step,
227
+ reward=grader_out.reward,
228
+ error_message=error or "",
229
+ sql=generated_sql,
230
+ success=attempt_success,
231
+ )
232
+ ep.steps.append(step_obj)
233
+ env._bandit.update(ep.current_features, repair_enum_for_step, grader_out.reward)
234
+ ep.last_action = repair_enum_for_step
235
+
236
+ ep.current_sql = generated_sql
237
+ ep.error_message = error
238
+ ep.error_class = error_class_name
239
+ ep.previous_error_class = current_error_class
240
+
241
+ yield {"data": json.dumps({
242
+ "type": "rl_reward",
243
+ "reward": grader_out.reward,
244
+ "breakdown": {
245
+ "base": grader_out.breakdown.base,
246
+ "attempt_penalty": grader_out.breakdown.attempt_penalty,
247
+ "severity_bonus": grader_out.breakdown.severity_bonus,
248
+ "change_bonus": grader_out.breakdown.change_bonus,
249
+ },
250
+ })}
251
+
252
+ if attempt_success:
253
+ success = True
254
+ yield {"data": json.dumps({
255
+ "type": "success",
256
+ "rows": rows,
257
+ "row_count": len(rows),
258
+ "sql": generated_sql,
259
+ })}
260
+ done = True
261
+ break
262
+
263
+ total_reward = compute_episode_reward(all_step_rewards, success)
264
+ yield {"data": json.dumps({
265
+ "type": "rl_episode_end",
266
+ "total_reward": total_reward,
267
+ "success": success,
268
+ })}
269
+
270
+ # Record GEPA history
271
+ gepa = get_gepa()
272
+ gepa.record_result(QueryResult(
273
+ question=req.question,
274
+ final_sql=env._episode.current_sql or "" if env._episode else "", # type: ignore[union-attr]
275
+ attempts=len(all_step_rewards),
276
+ success=success,
277
+ errors=[s.error_message for s in (env._episode.steps if env._episode else []) if s.error_message],
278
+ timestamp=time.time(),
279
+ ))
280
+
281
+ # Finalize episode
282
+ env._finalize_episode(success=success)
283
+ if env._episode:
284
+ env._episode.done = True
285
+ env._episode.success = success
286
+
287
+ # Trigger GEPA if needed
288
+ if gepa.should_optimize():
289
+ try:
290
+ await gepa.run_optimization_cycle()
291
+ except Exception:
292
+ pass
293
+
294
+ return EventSourceResponse(event_generator())
295
+
296
+
297
+ # ─── /api/benchmark ───────────────────────────────────────────────
298
+
299
+ class BenchmarkRequest(BaseModel):
300
+ task_id: str = "simple_queries"
301
+
302
+
303
+ @router.post("/benchmark")
304
+ async def run_benchmark(req: BenchmarkRequest):
305
+ async def event_generator() -> AsyncIterator[dict]:
306
+ task = get_task(req.task_id)
307
+ scores: list[float] = []
308
+
309
+ for question_obj in task.questions:
310
+ yield {"data": json.dumps({
311
+ "type": "query_start",
312
+ "query_id": question_obj.id,
313
+ "question": question_obj.question,
314
+ })}
315
+
316
+ # Run the question through the env
317
+ env = SQLAgentEnv()
318
+ obs = env.reset_with_question(req.task_id, question_obj.id)
319
+
320
+ attempt = 0
321
+ sql = ""
322
+ success = False
323
+ task_score = 0.0
324
+ max_attempts = env.MAX_ATTEMPTS
325
+ ep = env._episode # type: ignore[union-attr]
326
+
327
+ gepa = get_gepa()
328
+ system_prompt = gepa.get_current_prompt()
329
+ from env.sql_env import _make_client, _MODEL
330
+
331
+ for attempt in range(1, max_attempts + 1):
332
+ ep.attempt_number = attempt
333
+
334
+ if attempt == 1 or ep.current_sql is None:
335
+ user_msg = (
336
+ f"Schema:\n{obs.schema_info}\n\n"
337
+ f"Question: {question_obj.question}\n\n"
338
+ "Write a SQL query to answer this question."
339
+ )
340
+ sys_prompt = system_prompt
341
+ else:
342
+ from rl.repair_strategies import RepairContext, get_repair_system_suffix, build_repair_user_message
343
+ if ep.current_features is not None:
344
+ repair_enum, _ = env._bandit.select_action(ep.current_features)
345
+ else:
346
+ repair_enum = RepairAction.REWRITE_FULL
347
+ suffix = get_repair_system_suffix(repair_enum)
348
+ offending = extract_offending_token(ep.error_message or "")
349
+ ctx = RepairContext(
350
+ schema=obs.schema_info,
351
+ question=question_obj.question,
352
+ failing_sql=ep.current_sql or "",
353
+ error_message=ep.error_message or "",
354
+ offending_token=offending,
355
+ )
356
+ sys_prompt = system_prompt + suffix
357
+ user_msg = build_repair_user_message(repair_enum, ctx)
358
+
359
+ client = _make_client()
360
+ try:
361
+ resp = await client.chat.completions.create(
362
+ model=_MODEL,
363
+ messages=[
364
+ {"role": "system", "content": sys_prompt},
365
+ {"role": "user", "content": user_msg},
366
+ ],
367
+ temperature=0.1,
368
+ )
369
+ sql = _clean_sql(resp.choices[0].message.content or "")
370
+ except Exception as e:
371
+ break
372
+
373
+ rows, error = execute_query(sql)
374
+ from env.tasks import grade_response
375
+ task_score = grade_response(
376
+ req.task_id, question_obj.id, sql, rows, error, attempt
377
+ )
378
+ success = task_score >= 0.8
379
+
380
+ current_ec = None
381
+ if error:
382
+ ec = classify_error(error)
383
+ current_ec = ec
384
+ error_changed = ep.previous_error_class is not None and ep.previous_error_class != ec
385
+ if ep.previous_error_class == ec:
386
+ ep.consecutive_same_error += 1
387
+ else:
388
+ ep.consecutive_same_error = 1
389
+ rl_state = RLState(
390
+ error_class=ec,
391
+ attempt_number=attempt,
392
+ previous_action=ep.last_action,
393
+ error_changed=error_changed,
394
+ consecutive_same_error=ep.consecutive_same_error,
395
+ )
396
+ ep.current_rl_state = rl_state
397
+ ep.current_features = featurize(rl_state)
398
+
399
+ from rl.grader import GraderInput, compute_reward
400
+ grader_in = GraderInput(
401
+ success=success,
402
+ attempt_number=attempt,
403
+ current_error_class=current_ec,
404
+ previous_error_class=ep.previous_error_class,
405
+ )
406
+ grader_out = compute_reward(grader_in)
407
+
408
+ ep.current_sql = sql
409
+ ep.error_message = error
410
+ ep.error_class = ERROR_CLASS_NAMES[current_ec] if current_ec else None
411
+ ep.previous_error_class = current_ec
412
+
413
+ if success:
414
+ break
415
+
416
+ scores.append(task_score)
417
+
418
+ yield {"data": json.dumps({
419
+ "type": "query_result",
420
+ "query_id": question_obj.id,
421
+ "success": success,
422
+ "score": task_score,
423
+ "sql": sql,
424
+ "attempts": attempt,
425
+ })}
426
+
427
+ overall_score = sum(scores) / len(scores) if scores else 0.0
428
+ yield {"data": json.dumps({
429
+ "type": "done",
430
+ "overall_score": overall_score,
431
+ "task_id": req.task_id,
432
+ })}
433
+
434
+ return EventSourceResponse(event_generator())
435
+
436
+
437
+ # ─── /api/rl-state ────────────────────────────────────────────────
438
+
439
+ @router.get("/rl-state")
440
+ async def get_rl_state():
441
+ state = get_bandit_state()
442
+ action_names = [REPAIR_ACTION_NAMES[RepairAction(i)] for i in range(8)]
443
+ action_distribution = {
444
+ name: state["action_counts"][i]
445
+ for i, name in enumerate(action_names)
446
+ }
447
+ return {
448
+ "action_counts": state["action_counts"],
449
+ "alpha": state["alpha"],
450
+ "total_updates": state["total_updates"],
451
+ "action_distribution": action_distribution,
452
+ }
453
+
454
+
455
+ # ─── /api/schema-graph ────────────────────────────────────────────
456
+
457
+ @router.get("/schema-graph")
458
+ async def schema_graph():
459
+ return get_schema_graph()
460
+
461
+
462
+ # ─── /api/feedback ────────────────────────────────────────────────
463
+
464
+ class FeedbackRequest(BaseModel):
465
+ question: str
466
+ sql: str
467
+ correct: bool
468
+
469
+
470
+ @router.post("/feedback")
471
+ async def submit_feedback(req: FeedbackRequest):
472
+ gepa = get_gepa()
473
+ gepa.record_result(QueryResult(
474
+ question=req.question,
475
+ final_sql=req.sql,
476
+ attempts=1,
477
+ success=req.correct,
478
+ errors=[] if req.correct else ["User marked as incorrect"],
479
+ timestamp=time.time(),
480
+ ))
481
+
482
+ result = None
483
+ if not req.correct and gepa.should_optimize():
484
+ try:
485
+ result = await gepa.run_optimization_cycle(
486
+ user_feedback_context=f"User marked query as incorrect.\nQuestion: {req.question}\nSQL: {req.sql}"
487
+ )
488
+ except Exception:
489
+ pass
490
+
491
+ return {
492
+ "received": True,
493
+ "gepa_triggered": result is not None,
494
+ "reflection": result.get("reflection") if result else None,
495
+ }
backend/api/openenv.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OpenEnv spec routes.
3
+
4
+ POST /env/reset β†’ Observation
5
+ POST /env/step β†’ {observation: Observation, reward: RewardInfo}
6
+ GET /env/state β†’ current episode state dict
7
+ GET /env/tasks β†’ list of task metadata
8
+ GET /env/info β†’ env metadata
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from fastapi import APIRouter, HTTPException
14
+ from pydantic import BaseModel
15
+ from typing import Optional
16
+
17
+ from env.sql_env import get_env, Observation, Action, RewardInfo
18
+ from env.tasks import get_all_tasks
19
+
20
+ router = APIRouter()
21
+
22
+
23
+ # ─── Request Models ───────────────────────────────────────────────
24
+
25
+ class ResetRequest(BaseModel):
26
+ task_id: str = "simple_queries"
27
+ question_id: Optional[str] = None
28
+
29
+
30
+ class StepRequest(BaseModel):
31
+ repair_action: str = "generate"
32
+ custom_sql: Optional[str] = None
33
+
34
+
35
+ # ─── Routes ───────────────────────────────────────────────────────
36
+
37
+ @router.post("/reset", response_model=Observation)
38
+ async def env_reset(req: ResetRequest):
39
+ """Reset the environment to start a new episode."""
40
+ env = get_env()
41
+ if req.question_id:
42
+ obs = env.reset_with_question(req.task_id, req.question_id)
43
+ else:
44
+ obs = env.reset(req.task_id)
45
+ return obs
46
+
47
+
48
+ @router.post("/step")
49
+ async def env_step(req: StepRequest):
50
+ """Execute one step in the current episode."""
51
+ env = get_env()
52
+ try:
53
+ action = Action(
54
+ repair_action=req.repair_action,
55
+ custom_sql=req.custom_sql,
56
+ )
57
+ obs, reward = await env.step(action)
58
+ return {
59
+ "observation": obs.model_dump(),
60
+ "reward": reward.model_dump(),
61
+ }
62
+ except RuntimeError as e:
63
+ raise HTTPException(status_code=400, detail=str(e))
64
+
65
+
66
+ @router.get("/state")
67
+ async def env_state():
68
+ """Get the current episode state."""
69
+ env = get_env()
70
+ return env.state()
71
+
72
+
73
+ @router.get("/tasks")
74
+ async def list_tasks():
75
+ """List all available tasks with metadata."""
76
+ tasks = get_all_tasks()
77
+ return [
78
+ {
79
+ "id": t.id,
80
+ "name": t.name,
81
+ "difficulty": t.difficulty,
82
+ "description": t.description,
83
+ "question_count": len(t.questions),
84
+ "questions": [
85
+ {
86
+ "id": q.id,
87
+ "question": q.question,
88
+ "hint_tables": q.hint_tables,
89
+ }
90
+ for q in t.questions
91
+ ],
92
+ }
93
+ for t in tasks
94
+ ]
95
+
96
+
97
+ @router.get("/info")
98
+ async def env_info():
99
+ """Return environment metadata (matches openenv.yaml spec)."""
100
+ return {
101
+ "name": "sql-agent-openenv",
102
+ "version": "1.0.0",
103
+ "description": "SQL generation and repair environment with RL-driven repair strategy selection.",
104
+ "action_space": {
105
+ "type": "discrete",
106
+ "actions": [
107
+ "generate",
108
+ "rewrite_full",
109
+ "fix_column",
110
+ "fix_table",
111
+ "add_groupby",
112
+ "rewrite_cte",
113
+ "fix_syntax",
114
+ "change_dialect",
115
+ "relax_filter",
116
+ ],
117
+ },
118
+ "observation_space": {
119
+ "type": "dict",
120
+ "fields": [
121
+ "question",
122
+ "schema_info",
123
+ "current_sql",
124
+ "error_message",
125
+ "error_class",
126
+ "attempt_number",
127
+ "max_attempts",
128
+ "task_id",
129
+ "task_difficulty",
130
+ ],
131
+ },
132
+ "reward_range": [-1.5, 1.5],
133
+ "max_steps": 5,
134
+ "tasks": ["simple_queries", "join_queries", "complex_queries"],
135
+ "rl_algorithm": "LinUCB (contextual bandit)",
136
+ "feature_dim": 20,
137
+ "num_actions": 8,
138
+ }
backend/data/.gitkeep ADDED
File without changes
backend/data/benchmark.db ADDED
Binary file (32.8 kB). View file
 
backend/env/__init__.py ADDED
File without changes
backend/env/database.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SQLite database setup and schema for the benchmark marketplace.
3
+
4
+ Tables:
5
+ sellers (id, name, email, country, rating)
6
+ users (id, name, email, created_at, country)
7
+ products (id, name, category, price, stock_quantity, seller_id)
8
+ orders (id, user_id, product_id, quantity, total_price, status, created_at)
9
+ reviews (id, user_id, product_id, rating, comment, created_at)
10
+
11
+ ~50 rows per table of realistic seed data.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import os
17
+ import sqlite3
18
+ from pathlib import Path
19
+ from typing import Any
20
+
21
+ _DATA_DIR = Path(os.environ.get("DATA_DIR", Path(__file__).parent.parent / "data"))
22
+ DB_PATH = _DATA_DIR / "benchmark.db"
23
+
24
+
25
+ # ─── Schema ───────────────────────────────────────────────────────
26
+
27
+ _DDL = """
28
+ CREATE TABLE IF NOT EXISTS sellers (
29
+ id INTEGER PRIMARY KEY,
30
+ name TEXT NOT NULL,
31
+ email TEXT NOT NULL UNIQUE,
32
+ country TEXT NOT NULL,
33
+ rating REAL NOT NULL DEFAULT 4.0
34
+ );
35
+
36
+ CREATE TABLE IF NOT EXISTS users (
37
+ id INTEGER PRIMARY KEY,
38
+ name TEXT NOT NULL,
39
+ email TEXT NOT NULL UNIQUE,
40
+ created_at TEXT NOT NULL,
41
+ country TEXT NOT NULL
42
+ );
43
+
44
+ CREATE TABLE IF NOT EXISTS products (
45
+ id INTEGER PRIMARY KEY,
46
+ name TEXT NOT NULL,
47
+ category TEXT NOT NULL,
48
+ price REAL NOT NULL,
49
+ stock_quantity INTEGER NOT NULL DEFAULT 0,
50
+ seller_id INTEGER NOT NULL REFERENCES sellers(id)
51
+ );
52
+
53
+ CREATE TABLE IF NOT EXISTS orders (
54
+ id INTEGER PRIMARY KEY,
55
+ user_id INTEGER NOT NULL REFERENCES users(id),
56
+ product_id INTEGER NOT NULL REFERENCES products(id),
57
+ quantity INTEGER NOT NULL DEFAULT 1,
58
+ total_price REAL NOT NULL,
59
+ status TEXT NOT NULL DEFAULT 'pending',
60
+ created_at TEXT NOT NULL
61
+ );
62
+
63
+ CREATE TABLE IF NOT EXISTS reviews (
64
+ id INTEGER PRIMARY KEY,
65
+ user_id INTEGER NOT NULL REFERENCES users(id),
66
+ product_id INTEGER NOT NULL REFERENCES products(id),
67
+ rating INTEGER NOT NULL CHECK(rating BETWEEN 1 AND 5),
68
+ comment TEXT,
69
+ created_at TEXT NOT NULL
70
+ );
71
+ """
72
+
73
+ # ─── Seed Data ────────────────────────────────────────────────────
74
+
75
+ _SELLERS = [
76
+ (1, "TechGadgets Inc", "contact@techgadgets.com", "USA", 4.8),
77
+ (2, "FashionHub", "info@fashionhub.co.uk", "UK", 4.5),
78
+ (3, "HomeDecor Pro", "sales@homedecopro.de", "Germany", 4.3),
79
+ (4, "SportZone", "hello@sportzone.fr", "France", 4.6),
80
+ (5, "BookWorld", "support@bookworld.ca", "Canada", 4.9),
81
+ (6, "ElectroMart", "contact@electromart.jp", "Japan", 4.7),
82
+ (7, "GreenGrocer", "team@greengrocer.au", "Australia", 4.4),
83
+ (8, "KidsToys Hub", "info@kidstoys.us", "USA", 4.2),
84
+ (9, "PetSupplies Co", "hello@petsupplies.nl", "Netherlands", 4.6),
85
+ (10, "OfficeSupply Plus", "contact@officesupply.sg", "Singapore", 4.1),
86
+ ]
87
+
88
+ _USERS = [
89
+ (1, "Alice Johnson", "alice@example.com", "2023-01-15", "USA"),
90
+ (2, "Bob Smith", "bob@example.com", "2023-02-10", "UK"),
91
+ (3, "Carol White", "carol@example.com", "2023-03-05", "Canada"),
92
+ (4, "David Brown", "david@example.com", "2023-03-20", "Germany"),
93
+ (5, "Emma Davis", "emma@example.com", "2023-04-12", "France"),
94
+ (6, "Frank Miller", "frank@example.com", "2023-05-01", "Australia"),
95
+ (7, "Grace Wilson", "grace@example.com", "2023-05-18", "Japan"),
96
+ (8, "Henry Taylor", "henry@example.com", "2023-06-03", "USA"),
97
+ (9, "Isabella Anderson", "isabella@example.com", "2023-06-25", "UK"),
98
+ (10, "Jack Martinez", "jack@example.com", "2023-07-09", "Spain"),
99
+ (11, "Karen Thomas", "karen@example.com", "2023-07-22", "Italy"),
100
+ (12, "Liam Jackson", "liam@example.com", "2023-08-04", "Brazil"),
101
+ (13, "Mia Harris", "mia@example.com", "2023-08-17", "Canada"),
102
+ (14, "Noah Martin", "noah@example.com", "2023-09-01", "USA"),
103
+ (15, "Olivia Garcia", "olivia@example.com", "2023-09-14", "Mexico"),
104
+ (16, "Paul Robinson", "paul@example.com", "2023-10-02", "Australia"),
105
+ (17, "Quinn Lewis", "quinn@example.com", "2023-10-20", "New Zealand"),
106
+ (18, "Rachel Walker", "rachel@example.com", "2023-11-05", "UK"),
107
+ (19, "Sam Hall", "sam@example.com", "2023-11-19", "USA"),
108
+ (20, "Tina Allen", "tina@example.com", "2023-12-01", "Germany"),
109
+ (21, "Umar Young", "umar@example.com", "2024-01-08", "Pakistan"),
110
+ (22, "Vera Hernandez", "vera@example.com", "2024-01-22", "Spain"),
111
+ (23, "Will King", "will@example.com", "2024-02-06", "USA"),
112
+ (24, "Xena Wright", "xena@example.com", "2024-02-20", "Canada"),
113
+ (25, "Yusuf Lopez", "yusuf@example.com", "2024-03-05", "Morocco"),
114
+ (26, "Zoe Hill", "zoe@example.com", "2024-03-19", "UK"),
115
+ (27, "Aaron Scott", "aaron@example.com", "2024-04-02", "USA"),
116
+ (28, "Bella Green", "bella@example.com", "2024-04-16", "Australia"),
117
+ (29, "Carlos Adams", "carlos@example.com", "2024-05-01", "Brazil"),
118
+ (30, "Diana Baker", "diana@example.com", "2024-05-15", "Canada"),
119
+ (31, "Ethan Gonzalez", "ethan@example.com", "2024-05-29", "USA"),
120
+ (32, "Fatima Nelson", "fatima@example.com", "2024-06-12", "Nigeria"),
121
+ (33, "George Carter", "george@example.com", "2024-06-26", "UK"),
122
+ (34, "Hannah Mitchell", "hannah@example.com", "2024-07-10", "Germany"),
123
+ (35, "Ivan Perez", "ivan@example.com", "2024-07-24", "Russia"),
124
+ (36, "Julia Roberts", "juliar@example.com", "2024-08-07", "USA"),
125
+ (37, "Kevin Turner", "kevin@example.com", "2024-08-21", "Canada"),
126
+ (38, "Luna Phillips", "luna@example.com", "2024-09-04", "France"),
127
+ (39, "Mike Campbell", "mike@example.com", "2024-09-18", "USA"),
128
+ (40, "Nancy Parker", "nancy@example.com", "2024-10-02", "Japan"),
129
+ (41, "Oscar Evans", "oscar@example.com", "2024-10-16", "UK"),
130
+ (42, "Penny Edwards", "penny@example.com", "2024-10-30", "Australia"),
131
+ (43, "Roy Collins", "roy@example.com", "2024-11-13", "USA"),
132
+ (44, "Sara Stewart", "sara@example.com", "2024-11-27", "Canada"),
133
+ (45, "Tom Morris", "tom@example.com", "2024-12-11", "UK"),
134
+ (46, "Uma Rogers", "uma@example.com", "2024-12-25", "India"),
135
+ (47, "Victor Reed", "victor@example.com", "2025-01-08", "USA"),
136
+ (48, "Wendy Cook", "wendy@example.com", "2025-01-22", "Germany"),
137
+ (49, "Xavier Morgan", "xavier@example.com", "2025-02-05", "France"),
138
+ (50, "Yasmin Bell", "yasmin@example.com", "2025-02-19", "UK"),
139
+ ]
140
+
141
+ _PRODUCTS = [
142
+ (1, "Wireless Headphones Pro", "Electronics", 149.99, 120, 1),
143
+ (2, "Laptop Stand Adjustable", "Electronics", 49.99, 200, 1),
144
+ (3, "USB-C Hub 7-in-1", "Electronics", 39.99, 350, 6),
145
+ (4, "Mechanical Keyboard RGB", "Electronics", 89.99, 85, 6),
146
+ (5, "Webcam 4K Ultra", "Electronics", 129.99, 60, 1),
147
+ (6, "Summer Floral Dress", "Fashion", 59.99, 180, 2),
148
+ (7, "Men Slim Fit Chinos", "Fashion", 44.99, 220, 2),
149
+ (8, "Leather Wallet Bifold", "Fashion", 34.99, 300, 2),
150
+ (9, "Running Shoes Ultralight", "Fashion", 109.99, 95, 4),
151
+ (10, "Yoga Pants High Waist", "Fashion", 54.99, 150, 4),
152
+ (11, "Ceramic Vase Set", "Home & Garden", 79.99, 70, 3),
153
+ (12, "Bamboo Cutting Board", "Home & Garden", 29.99, 400, 3),
154
+ (13, "Scented Candle Collection", "Home & Garden", 24.99, 500, 3),
155
+ (14, "Smart LED Bulb Pack", "Home & Garden", 59.99, 250, 1),
156
+ (15, "Coffee Table Book Stand", "Home & Garden", 49.99, 130, 3),
157
+ (16, "Protein Powder Vanilla", "Sports & Fitness", 54.99, 210, 4),
158
+ (17, "Resistance Band Set", "Sports & Fitness", 24.99, 600, 4),
159
+ (18, "Yoga Mat Non-Slip", "Sports & Fitness", 39.99, 300, 4),
160
+ (19, "Tennis Racket Pro", "Sports & Fitness", 89.99, 45, 4),
161
+ (20, "Water Bottle Insulated", "Sports & Fitness", 29.99, 450, 7),
162
+ (21, "The Python Handbook", "Books", 29.99, 200, 5),
163
+ (22, "Machine Learning Basics", "Books", 34.99, 175, 5),
164
+ (23, "Data Structures Guide", "Books", 27.99, 220, 5),
165
+ (24, "Mystery Novel Collection", "Books", 49.99, 100, 5),
166
+ (25, "Children Story Box Set", "Books", 44.99, 130, 8),
167
+ (26, "Dog Bed Orthopedic", "Pet Supplies", 79.99, 90, 9),
168
+ (27, "Cat Scratching Post", "Pet Supplies", 34.99, 170, 9),
169
+ (28, "Fish Tank Starter Kit", "Pet Supplies", 59.99, 55, 9),
170
+ (29, "Bird Cage Deluxe", "Pet Supplies", 89.99, 35, 9),
171
+ (30, "Pet Grooming Kit", "Pet Supplies", 39.99, 140, 9),
172
+ (31, "LEGO City Set 600pcs", "Toys", 69.99, 80, 8),
173
+ (32, "Remote Control Car", "Toys", 49.99, 120, 8),
174
+ (33, "Board Game Strategy", "Toys", 34.99, 200, 8),
175
+ (34, "Puzzle 1000 Pieces", "Toys", 24.99, 350, 8),
176
+ (35, "Art & Craft Kit Kids", "Toys", 29.99, 280, 8),
177
+ (36, "Office Desk Organizer", "Office", 39.99, 300, 10),
178
+ (37, "Wireless Mouse Ergonomic", "Electronics", 59.99, 200, 6),
179
+ (38, "Notebook Set Premium", "Office", 19.99, 600, 10),
180
+ (39, "Sticky Notes Colorful", "Office", 9.99, 800, 10),
181
+ (40, "Printer Paper Ream", "Office", 14.99, 500, 10),
182
+ (41, "Smart Watch Fitness", "Electronics", 199.99, 75, 1),
183
+ (42, "Blender High Power", "Home & Garden", 89.99, 110, 3),
184
+ (43, "Air Purifier HEPA", "Home & Garden", 149.99, 65, 1),
185
+ (44, "Backpack Waterproof", "Fashion", 79.99, 160, 2),
186
+ (45, "Sunglasses Polarized", "Fashion", 69.99, 200, 2),
187
+ (46, "Dumbbells Set 20kg", "Sports & Fitness", 79.99, 85, 4),
188
+ (47, "Jump Rope Speed", "Sports & Fitness", 19.99, 400, 4),
189
+ (48, "Graphic Novel Bundle", "Books", 59.99, 90, 5),
190
+ (49, "Phone Stand Adjustable", "Electronics", 24.99, 350, 6),
191
+ (50, "Desk Lamp LED", "Office", 44.99, 230, 10),
192
+ ]
193
+
194
+ _ORDERS = [
195
+ (1, 1, 1, 1, 149.99, "delivered", "2024-01-10"),
196
+ (2, 2, 6, 2, 119.98, "delivered", "2024-01-15"),
197
+ (3, 3, 21, 1, 29.99, "delivered", "2024-01-20"),
198
+ (4, 4, 11, 1, 79.99, "delivered", "2024-01-25"),
199
+ (5, 5, 16, 2, 109.98, "delivered", "2024-02-01"),
200
+ (6, 6, 31, 1, 69.99, "delivered", "2024-02-05"),
201
+ (7, 7, 3, 2, 79.98, "shipped", "2024-02-10"),
202
+ (8, 8, 41, 1, 199.99, "delivered", "2024-02-14"),
203
+ (9, 9, 26, 1, 79.99, "delivered", "2024-02-18"),
204
+ (10, 10, 17, 3, 74.97, "delivered", "2024-02-22"),
205
+ (11, 11, 22, 1, 34.99, "delivered", "2024-03-01"),
206
+ (12, 12, 7, 1, 44.99, "delivered", "2024-03-05"),
207
+ (13, 13, 18, 2, 79.98, "delivered", "2024-03-10"),
208
+ (14, 14, 37, 1, 59.99, "shipped", "2024-03-14"),
209
+ (15, 15, 44, 1, 79.99, "delivered", "2024-03-18"),
210
+ (16, 16, 2, 1, 49.99, "delivered", "2024-03-22"),
211
+ (17, 17, 50, 1, 44.99, "pending", "2024-03-26"),
212
+ (18, 18, 5, 1, 129.99, "delivered", "2024-04-01"),
213
+ (19, 19, 12, 2, 59.98, "delivered", "2024-04-05"),
214
+ (20, 20, 33, 1, 34.99, "delivered", "2024-04-09"),
215
+ (21, 21, 9, 1, 109.99, "delivered", "2024-04-13"),
216
+ (22, 22, 14, 2, 119.98, "delivered", "2024-04-17"),
217
+ (23, 23, 43, 1, 149.99, "shipped", "2024-04-21"),
218
+ (24, 24, 25, 1, 44.99, "delivered", "2024-04-25"),
219
+ (25, 25, 8, 2, 69.98, "delivered", "2024-04-29"),
220
+ (26, 26, 4, 1, 89.99, "delivered", "2024-05-03"),
221
+ (27, 27, 29, 1, 89.99, "delivered", "2024-05-07"),
222
+ (28, 28, 20, 3, 89.97, "delivered", "2024-05-11"),
223
+ (29, 29, 35, 2, 59.98, "delivered", "2024-05-15"),
224
+ (30, 30, 46, 1, 79.99, "pending", "2024-05-19"),
225
+ (31, 31, 13, 5, 124.95, "delivered", "2024-05-23"),
226
+ (32, 32, 36, 2, 79.98, "delivered", "2024-05-27"),
227
+ (33, 33, 48, 1, 59.99, "delivered", "2024-05-31"),
228
+ (34, 34, 1, 1, 149.99, "delivered", "2024-06-04"),
229
+ (35, 35, 24, 1, 49.99, "delivered", "2024-06-08"),
230
+ (36, 36, 10, 2, 109.98, "shipped", "2024-06-12"),
231
+ (37, 37, 42, 1, 89.99, "delivered", "2024-06-16"),
232
+ (38, 38, 27, 1, 34.99, "delivered", "2024-06-20"),
233
+ (39, 39, 6, 1, 59.99, "delivered", "2024-06-24"),
234
+ (40, 40, 41, 1, 199.99, "delivered", "2024-06-28"),
235
+ (41, 41, 19, 1, 89.99, "cancelled", "2024-07-02"),
236
+ (42, 42, 34, 2, 49.98, "delivered", "2024-07-06"),
237
+ (43, 43, 23, 1, 27.99, "delivered", "2024-07-10"),
238
+ (44, 44, 47, 3, 59.97, "delivered", "2024-07-14"),
239
+ (45, 45, 15, 1, 49.99, "delivered", "2024-07-18"),
240
+ (46, 46, 32, 1, 49.99, "delivered", "2024-07-22"),
241
+ (47, 47, 3, 1, 39.99, "pending", "2024-07-26"),
242
+ (48, 48, 28, 1, 59.99, "delivered", "2024-07-30"),
243
+ (49, 49, 39, 10, 99.90, "delivered", "2024-08-03"),
244
+ (50, 50, 21, 2, 59.98, "delivered", "2024-08-07"),
245
+ ]
246
+
247
+ _REVIEWS = [
248
+ (1, 1, 1, 5, "Excellent headphones, crystal clear sound!", "2024-01-15"),
249
+ (2, 2, 6, 4, "Beautiful dress, fits perfectly.", "2024-01-20"),
250
+ (3, 3, 21, 5, "Best Python book for beginners.", "2024-01-25"),
251
+ (4, 4, 11, 4, "Very elegant vase set.", "2024-01-30"),
252
+ (5, 5, 16, 3, "Decent protein powder, average taste.", "2024-02-05"),
253
+ (6, 6, 31, 5, "My kid loves this LEGO set!", "2024-02-10"),
254
+ (7, 7, 3, 5, "Incredibly useful USB hub.", "2024-02-15"),
255
+ (8, 8, 41, 5, "Smart watch exceeded expectations.", "2024-02-20"),
256
+ (9, 9, 26, 4, "Dog loves the orthopedic bed.", "2024-02-25"),
257
+ (10, 10, 17, 5, "Great resistance bands, very durable.", "2024-03-01"),
258
+ (11, 11, 22, 4, "Solid ML intro book.", "2024-03-06"),
259
+ (12, 12, 7, 3, "Chinos are OK, sizing runs small.", "2024-03-11"),
260
+ (13, 13, 18, 5, "Perfect yoga mat, non-slip is great.", "2024-03-16"),
261
+ (14, 14, 37, 4, "Smooth wireless mouse.", "2024-03-21"),
262
+ (15, 15, 44, 5, "Waterproof backpack is amazing.", "2024-03-26"),
263
+ (16, 16, 2, 4, "Laptop stand is sturdy and adjustable.", "2024-03-31"),
264
+ (17, 17, 49, 3, "Decent phone stand but wobbly.", "2024-04-05"),
265
+ (18, 18, 5, 5, "Best webcam I've ever used.", "2024-04-10"),
266
+ (19, 19, 12, 5, "Bamboo cutting board is beautiful.", "2024-04-15"),
267
+ (20, 20, 33, 4, "Fun strategy board game.", "2024-04-20"),
268
+ (21, 21, 9, 5, "Running shoes are so comfortable!", "2024-04-25"),
269
+ (22, 22, 14, 4, "Smart bulbs work well with app.", "2024-04-30"),
270
+ (23, 23, 43, 4, "Air purifier is quiet and effective.", "2024-05-05"),
271
+ (24, 24, 25, 5, "Beautiful story box set for kids.", "2024-05-10"),
272
+ (25, 25, 8, 4, "Leather wallet is high quality.", "2024-05-15"),
273
+ (26, 26, 4, 5, "Mechanical keyboard is a joy to type on.", "2024-05-20"),
274
+ (27, 27, 29, 4, "Bird cage is spacious and well-made.", "2024-05-25"),
275
+ (28, 28, 20, 5, "Water bottle keeps drinks cold all day.", "2024-05-30"),
276
+ (29, 29, 35, 4, "Great art kit for kids.", "2024-06-04"),
277
+ (30, 30, 46, 4, "Solid dumbbells, good grip.", "2024-06-09"),
278
+ (31, 1, 13, 5, "Scented candles smell amazing.", "2024-06-14"),
279
+ (32, 2, 36, 4, "Desk organizer keeps my workspace tidy.", "2024-06-19"),
280
+ (33, 3, 48, 5, "Graphic novel bundle is worth every penny.", "2024-06-24"),
281
+ (34, 4, 1, 4, "Good headphones, comfy for long sessions.", "2024-06-29"),
282
+ (35, 5, 24, 5, "Love these mystery novels!", "2024-07-04"),
283
+ (36, 6, 10, 4, "High waist yoga pants are flattering.", "2024-07-09"),
284
+ (37, 7, 42, 4, "Powerful blender, handles frozen fruit.", "2024-07-14"),
285
+ (38, 8, 27, 5, "Cat scratching post is well built.", "2024-07-19"),
286
+ (39, 9, 6, 4, "Floral dress is as pictured.", "2024-07-24"),
287
+ (40, 10, 41, 5, "Smart watch has excellent battery life.", "2024-07-29"),
288
+ (41, 11, 19, 2, "Tennis racket feels cheap for the price.", "2024-08-03"),
289
+ (42, 12, 34, 5, "Puzzle is a perfect family activity.", "2024-08-08"),
290
+ (43, 13, 23, 5, "Data structures book is very clear.", "2024-08-13"),
291
+ (44, 14, 47, 4, "Jump rope is fast and durable.", "2024-08-18"),
292
+ (45, 15, 15, 3, "Book stand is okay, a bit light.", "2024-08-23"),
293
+ (46, 16, 32, 5, "Remote control car is very fast!", "2024-08-28"),
294
+ (47, 17, 3, 4, "USB hub works great on MacBook.", "2024-09-02"),
295
+ (48, 18, 28, 4, "Fish tank kit is easy to set up.", "2024-09-07"),
296
+ (49, 19, 38, 5, "Premium notebook has great paper.", "2024-09-12"),
297
+ (50, 20, 21, 5, "Python handbook is my go-to reference.", "2024-09-17"),
298
+ ]
299
+
300
+
301
+ # ─── Public API ───────────────────────────────────────────────────
302
+
303
+ def get_db_path() -> Path:
304
+ return DB_PATH
305
+
306
+
307
+ def ensure_seeded() -> bool:
308
+ """
309
+ Create the database and populate seed data if not already done.
310
+ Returns True if seed was needed (first run), False if already seeded.
311
+ """
312
+ _DATA_DIR.mkdir(parents=True, exist_ok=True)
313
+ conn = sqlite3.connect(str(DB_PATH))
314
+ try:
315
+ conn.executescript(_DDL)
316
+ conn.commit()
317
+
318
+ count = conn.execute("SELECT COUNT(*) FROM users").fetchone()[0]
319
+ if count >= 50:
320
+ return False # Already seeded
321
+
322
+ conn.execute("DELETE FROM reviews")
323
+ conn.execute("DELETE FROM orders")
324
+ conn.execute("DELETE FROM products")
325
+ conn.execute("DELETE FROM users")
326
+ conn.execute("DELETE FROM sellers")
327
+
328
+ conn.executemany(
329
+ "INSERT OR REPLACE INTO sellers VALUES (?,?,?,?,?)", _SELLERS
330
+ )
331
+ conn.executemany(
332
+ "INSERT OR REPLACE INTO users VALUES (?,?,?,?,?)", _USERS
333
+ )
334
+ conn.executemany(
335
+ "INSERT OR REPLACE INTO products VALUES (?,?,?,?,?,?)", _PRODUCTS
336
+ )
337
+ conn.executemany(
338
+ "INSERT OR REPLACE INTO orders VALUES (?,?,?,?,?,?,?)", _ORDERS
339
+ )
340
+ conn.executemany(
341
+ "INSERT OR REPLACE INTO reviews VALUES (?,?,?,?,?,?)", _REVIEWS
342
+ )
343
+ conn.commit()
344
+ return True
345
+ finally:
346
+ conn.close()
347
+
348
+
349
+ def get_schema_info() -> str:
350
+ """
351
+ Return a concise textual schema summary for use in prompts.
352
+ """
353
+ conn = sqlite3.connect(str(DB_PATH))
354
+ try:
355
+ lines = []
356
+ for table in ["sellers", "users", "products", "orders", "reviews"]:
357
+ info = conn.execute(f"PRAGMA table_info({table})").fetchall()
358
+ cols = ", ".join(
359
+ f"{col[1]} {col[2]}{'(PK)' if col[5] else ''}"
360
+ for col in info
361
+ )
362
+ row_count = conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
363
+ lines.append(f"Table: {table} ({row_count} rows)\n Columns: {cols}")
364
+ return "\n\n".join(lines)
365
+ finally:
366
+ conn.close()
367
+
368
+
369
+ def execute_query(sql: str) -> tuple[list[dict], str | None]:
370
+ """
371
+ Execute a SQL query and return (rows, error_message).
372
+ rows is a list of dicts; error_message is None on success.
373
+ """
374
+ conn = sqlite3.connect(str(DB_PATH))
375
+ conn.row_factory = sqlite3.Row
376
+ try:
377
+ cursor = conn.execute(sql)
378
+ rows = [dict(row) for row in cursor.fetchall()]
379
+ return rows, None
380
+ except sqlite3.Error as e:
381
+ return [], str(e)
382
+ finally:
383
+ conn.close()
384
+
385
+
386
+ def get_table_stats() -> list[dict]:
387
+ """Return [{name, rows}, ...] for all tables."""
388
+ conn = sqlite3.connect(str(DB_PATH))
389
+ try:
390
+ tables = ["sellers", "users", "products", "orders", "reviews"]
391
+ return [
392
+ {
393
+ "name": t,
394
+ "rows": conn.execute(f"SELECT COUNT(*) FROM {t}").fetchone()[0],
395
+ }
396
+ for t in tables
397
+ ]
398
+ finally:
399
+ conn.close()
400
+
401
+
402
+ def get_schema_graph() -> dict:
403
+ """Return schema graph with tables, columns, and foreign keys."""
404
+ conn = sqlite3.connect(str(DB_PATH))
405
+ try:
406
+ tables = []
407
+ for table in ["sellers", "users", "products", "orders", "reviews"]:
408
+ info = conn.execute(f"PRAGMA table_info({table})").fetchall()
409
+ columns = [
410
+ {"name": col[1], "type": col[2], "pk": bool(col[5])}
411
+ for col in info
412
+ ]
413
+ tables.append({"name": table, "columns": columns})
414
+
415
+ foreign_keys = []
416
+ for table in ["sellers", "users", "products", "orders", "reviews"]:
417
+ fks = conn.execute(f"PRAGMA foreign_key_list({table})").fetchall()
418
+ for fk in fks:
419
+ foreign_keys.append(
420
+ {
421
+ "from_table": table,
422
+ "from_col": fk[3],
423
+ "to_table": fk[2],
424
+ "to_col": fk[4],
425
+ }
426
+ )
427
+
428
+ return {"tables": tables, "foreign_keys": foreign_keys}
429
+ finally:
430
+ conn.close()
backend/env/sql_env.py ADDED
@@ -0,0 +1,594 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SQLAgentEnv β€” OpenEnv-compliant environment for SQL generation.
3
+
4
+ Observation β†’ Action β†’ (Observation, Reward) loop.
5
+
6
+ The step() function:
7
+ 1. Selects a repair prompt based on action.repair_action
8
+ 2. Calls the LLM (OpenAI-compatible) to generate/repair SQL
9
+ 3. Executes SQL on the benchmark DB
10
+ 4. Classifies any error
11
+ 5. Computes reward via grader
12
+ 6. Updates LinUCB bandit
13
+ 7. Returns (new_observation, reward)
14
+
15
+ Environment variables:
16
+ API_BASE_URL β€” OpenAI-compatible base URL (default: https://api.openai.com/v1)
17
+ MODEL_NAME β€” model to use (default: gpt-4o-mini)
18
+ HF_TOKEN β€” bearer token / API key
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import asyncio
24
+ import os
25
+ import re
26
+ from typing import Optional, AsyncIterator
27
+
28
+ from openai import AsyncOpenAI
29
+ from pydantic import BaseModel
30
+
31
+ from env.database import ensure_seeded, get_schema_info, execute_query
32
+ from env.tasks import get_task, get_all_tasks, TASKS
33
+ from rl.types import RepairAction, REPAIR_ACTION_NAMES, REPAIR_ACTION_BY_NAME
34
+ from rl.error_classifier import classify_error, extract_offending_token
35
+ from rl.grader import GraderInput, compute_reward, compute_episode_reward
36
+ from rl.linucb import LinUCB
37
+ from rl.repair_strategies import RepairContext, get_repair_system_suffix, build_repair_user_message
38
+ from rl.experience import record_episode
39
+ from rl.types import RLState, EpisodeStep, featurize, ERROR_CLASS_NAMES
40
+
41
+ # ─── OpenEnv Models ──────────────────────────────────────────────
42
+
43
+
44
+ class Observation(BaseModel):
45
+ question: str
46
+ schema_info: str
47
+ current_sql: Optional[str] = None
48
+ error_message: Optional[str] = None
49
+ error_class: Optional[str] = None
50
+ attempt_number: int = 0
51
+ max_attempts: int = 5
52
+ task_id: str
53
+ task_difficulty: str
54
+
55
+
56
+ class Action(BaseModel):
57
+ repair_action: str # one of 8 repair action names or "generate"
58
+ custom_sql: Optional[str] = None # optional direct SQL override
59
+
60
+
61
+ class RewardInfo(BaseModel):
62
+ value: float
63
+ success: bool
64
+ done: bool
65
+ info: dict
66
+
67
+
68
+ # ─── LLM Client ──────────────────────────────────────────────────
69
+
70
+ def _make_client() -> AsyncOpenAI:
71
+ return AsyncOpenAI(
72
+ api_key=os.environ.get("HF_TOKEN", ""),
73
+ base_url=os.environ.get("API_BASE_URL", "https://api.openai.com/v1"),
74
+ )
75
+
76
+
77
+ _MODEL = os.environ.get("MODEL_NAME", "gpt-4o-mini")
78
+
79
+ BASE_SYSTEM_PROMPT = """You are a SQL expert. Given a natural language question and a SQLite database schema, write a correct SQL query.
80
+
81
+ Rules:
82
+ - Output ONLY the SQL query, nothing else
83
+ - No markdown, no code fences, no explanation
84
+ - Use SQLite syntax
85
+ - Do not include semicolons at the end"""
86
+
87
+
88
+ def _clean_sql(raw: str) -> str:
89
+ """Strip markdown code fences and extra whitespace."""
90
+ raw = raw.strip()
91
+ raw = re.sub(r"^```(?:sql)?\s*", "", raw, flags=re.IGNORECASE)
92
+ raw = re.sub(r"\s*```$", "", raw)
93
+ return raw.strip().rstrip(";")
94
+
95
+
96
+ async def _call_llm(
97
+ system_prompt: str,
98
+ user_message: str,
99
+ stream: bool = False,
100
+ ) -> AsyncIterator[str] | str:
101
+ """Call the LLM and return the generated text."""
102
+ client = _make_client()
103
+
104
+ if stream:
105
+ async def _gen():
106
+ resp = await client.chat.completions.create(
107
+ model=_MODEL,
108
+ messages=[
109
+ {"role": "system", "content": system_prompt},
110
+ {"role": "user", "content": user_message},
111
+ ],
112
+ stream=True,
113
+ temperature=0.1,
114
+ )
115
+ async for chunk in resp:
116
+ delta = chunk.choices[0].delta.content
117
+ if delta:
118
+ yield delta
119
+ return _gen()
120
+ else:
121
+ resp = await client.chat.completions.create(
122
+ model=_MODEL,
123
+ messages=[
124
+ {"role": "system", "content": system_prompt},
125
+ {"role": "user", "content": user_message},
126
+ ],
127
+ temperature=0.1,
128
+ )
129
+ return resp.choices[0].message.content or ""
130
+
131
+
132
+ # ─── Episode State ────────────────────────────────────────────────
133
+
134
+ class _Episode:
135
+ def __init__(self, task_id: str, question_id: str, question: str) -> None:
136
+ self.task_id = task_id
137
+ self.question_id = question_id
138
+ self.question = question
139
+ self.attempt_number = 0
140
+ self.current_sql: Optional[str] = None
141
+ self.error_message: Optional[str] = None
142
+ self.error_class: Optional[str] = None
143
+ self.steps: list[EpisodeStep] = []
144
+ self.step_rewards: list[float] = []
145
+ self.previous_error_class = None
146
+ self.consecutive_same_error = 0
147
+ self.last_action: Optional[RepairAction] = None
148
+ self.current_rl_state: Optional[RLState] = None
149
+ self.current_features: Optional[list[float]] = None
150
+ self.done = False
151
+ self.success = False
152
+
153
+
154
+ # ─── Main Environment Class ───────────────────────────────────────
155
+
156
+ class SQLAgentEnv:
157
+ """
158
+ OpenEnv-compliant environment for SQL generation and repair.
159
+ One active episode at a time.
160
+ """
161
+
162
+ MAX_ATTEMPTS = 5
163
+
164
+ def __init__(self) -> None:
165
+ ensure_seeded()
166
+ self._bandit = LinUCB()
167
+ self._episode: Optional[_Episode] = None
168
+ self._schema_info = get_schema_info()
169
+
170
+ def reset(self, task_id: str = "simple_queries") -> Observation:
171
+ """Start a new episode, picking the first question of the task."""
172
+ if self._episode and self._episode.steps and not self._episode.done:
173
+ self._finalize_episode(success=False)
174
+
175
+ task = get_task(task_id)
176
+ question_obj = task.questions[0]
177
+
178
+ self._episode = _Episode(
179
+ task_id=task_id,
180
+ question_id=question_obj.id,
181
+ question=question_obj.question,
182
+ )
183
+
184
+ return self._build_observation()
185
+
186
+ def reset_with_question(
187
+ self, task_id: str, question_id: str
188
+ ) -> Observation:
189
+ """Start a new episode for a specific question."""
190
+ if self._episode and self._episode.steps and not self._episode.done:
191
+ self._finalize_episode(success=False)
192
+
193
+ task = get_task(task_id)
194
+ question_obj = next(
195
+ (q for q in task.questions if q.id == question_id), task.questions[0]
196
+ )
197
+
198
+ self._episode = _Episode(
199
+ task_id=task_id,
200
+ question_id=question_obj.id,
201
+ question=question_obj.question,
202
+ )
203
+ return self._build_observation()
204
+
205
+ async def step(self, action: Action) -> tuple[Observation, RewardInfo]:
206
+ """
207
+ Execute one step:
208
+ 1. Generate/repair SQL via LLM
209
+ 2. Execute SQL
210
+ 3. Grade and reward
211
+ 4. Update bandit
212
+ """
213
+ if self._episode is None:
214
+ raise RuntimeError("Call reset() before step()")
215
+ if self._episode.done:
216
+ raise RuntimeError("Episode is done. Call reset() to start a new one.")
217
+
218
+ ep = self._episode
219
+ ep.attempt_number += 1
220
+
221
+ # ── 1. Build prompt ──────────────────────────────────────
222
+ if action.custom_sql:
223
+ generated_sql = action.custom_sql
224
+ else:
225
+ generated_sql = await self._generate_sql(action, ep)
226
+
227
+ generated_sql = _clean_sql(generated_sql)
228
+
229
+ # ── 2. Execute SQL ───────────────────────────────────────
230
+ rows, error = execute_query(generated_sql)
231
+ success = error is None and len(rows) > 0
232
+
233
+ # ── 3. Grade ─────────────────────────────────────────────
234
+ task = get_task(ep.task_id)
235
+ question_obj = next(q for q in task.questions if q.id == ep.question_id)
236
+
237
+ from env.tasks import grade_response
238
+ task_score = grade_response(
239
+ ep.task_id, ep.question_id, generated_sql, rows, error, ep.attempt_number
240
+ )
241
+ success = task_score >= 0.8
242
+
243
+ # ── 4. RL state + reward ─────────────────────────────────
244
+ current_error_class = None
245
+ error_class_name = None
246
+ if error:
247
+ ec = classify_error(error)
248
+ current_error_class = ec
249
+ error_class_name = ERROR_CLASS_NAMES[ec]
250
+
251
+ error_changed = (
252
+ ep.previous_error_class is not None
253
+ and ep.previous_error_class != current_error_class
254
+ )
255
+
256
+ if ep.previous_error_class == current_error_class:
257
+ ep.consecutive_same_error += 1
258
+ else:
259
+ ep.consecutive_same_error = 1
260
+
261
+ rl_state = RLState(
262
+ error_class=current_error_class,
263
+ attempt_number=ep.attempt_number,
264
+ previous_action=ep.last_action,
265
+ error_changed=error_changed,
266
+ consecutive_same_error=ep.consecutive_same_error,
267
+ )
268
+ ep.current_rl_state = rl_state
269
+ ep.current_features = featurize(rl_state)
270
+
271
+ grader_in = GraderInput(
272
+ success=success,
273
+ attempt_number=ep.attempt_number,
274
+ current_error_class=current_error_class,
275
+ previous_error_class=ep.previous_error_class,
276
+ )
277
+ grader_out = compute_reward(grader_in)
278
+
279
+ if ep.current_rl_state and ep.current_features:
280
+ # Determine action index
281
+ if action.repair_action == "generate":
282
+ repair_action_enum = RepairAction.REWRITE_FULL
283
+ else:
284
+ repair_action_enum = REPAIR_ACTION_BY_NAME.get(
285
+ action.repair_action, RepairAction.REWRITE_FULL
286
+ )
287
+
288
+ step_obj = EpisodeStep(
289
+ state=ep.current_rl_state,
290
+ featurized=ep.current_features,
291
+ action=repair_action_enum,
292
+ reward=grader_out.reward,
293
+ error_message=error or "",
294
+ sql=generated_sql,
295
+ success=success,
296
+ )
297
+ ep.steps.append(step_obj)
298
+
299
+ ep.step_rewards.append(grader_out.reward)
300
+ ep.current_sql = generated_sql
301
+ ep.error_message = error
302
+ ep.error_class = error_class_name
303
+ ep.previous_error_class = current_error_class
304
+
305
+ # ── 5. Done check ────────────────────────────────────────
306
+ done = success or ep.attempt_number >= self.MAX_ATTEMPTS
307
+
308
+ if done:
309
+ self._finalize_episode(success=success)
310
+ ep.done = True
311
+ ep.success = success
312
+
313
+ obs = self._build_observation()
314
+ reward_info = RewardInfo(
315
+ value=grader_out.reward,
316
+ success=success,
317
+ done=done,
318
+ info={
319
+ "task_score": task_score,
320
+ "attempt": ep.attempt_number,
321
+ "breakdown": {
322
+ "base": grader_out.breakdown.base,
323
+ "attempt_penalty": grader_out.breakdown.attempt_penalty,
324
+ "severity_bonus": grader_out.breakdown.severity_bonus,
325
+ "change_bonus": grader_out.breakdown.change_bonus,
326
+ },
327
+ "rows": rows[:5] if rows else [],
328
+ "row_count": len(rows),
329
+ "sql": generated_sql,
330
+ },
331
+ )
332
+
333
+ return obs, reward_info
334
+
335
+ async def step_streaming(
336
+ self, action: Action
337
+ ) -> AsyncIterator[dict]:
338
+ """
339
+ Step with SSE-compatible event streaming.
340
+ Yields dicts representing stream events.
341
+ """
342
+ if self._episode is None:
343
+ raise RuntimeError("Call reset() before step_streaming()")
344
+
345
+ ep = self._episode
346
+ ep.attempt_number += 1
347
+
348
+ yield {"type": "attempt_start", "attempt": ep.attempt_number}
349
+
350
+ # Generate SQL
351
+ if action.custom_sql:
352
+ generated_sql = action.custom_sql
353
+ yield {"type": "sql_complete", "sql": generated_sql}
354
+ else:
355
+ chunks = []
356
+ async for chunk in await self._generate_sql_streaming(action, ep):
357
+ chunks.append(chunk)
358
+ yield {"type": "sql_chunk", "chunk": chunk}
359
+ generated_sql = _clean_sql("".join(chunks))
360
+ yield {"type": "sql_complete", "sql": generated_sql}
361
+
362
+ yield {"type": "executing"}
363
+
364
+ rows, error = execute_query(generated_sql)
365
+
366
+ from env.tasks import grade_response
367
+ task_score = grade_response(
368
+ ep.task_id, ep.question_id, generated_sql, rows, error, ep.attempt_number
369
+ )
370
+ success = task_score >= 0.8
371
+
372
+ # RL processing
373
+ current_error_class = None
374
+ error_class_name = None
375
+ repair_action_enum = RepairAction.REWRITE_FULL
376
+
377
+ if action.repair_action != "generate":
378
+ repair_action_enum = REPAIR_ACTION_BY_NAME.get(
379
+ action.repair_action, RepairAction.REWRITE_FULL
380
+ )
381
+
382
+ if error:
383
+ ec = classify_error(error)
384
+ current_error_class = ec
385
+ error_class_name = ERROR_CLASS_NAMES[ec]
386
+
387
+ error_changed = (
388
+ ep.previous_error_class is not None
389
+ and ep.previous_error_class != current_error_class
390
+ )
391
+ if ep.previous_error_class == current_error_class:
392
+ ep.consecutive_same_error += 1
393
+ else:
394
+ ep.consecutive_same_error = 1
395
+
396
+ rl_state = RLState(
397
+ error_class=current_error_class,
398
+ attempt_number=ep.attempt_number,
399
+ previous_action=ep.last_action,
400
+ error_changed=error_changed,
401
+ consecutive_same_error=ep.consecutive_same_error,
402
+ )
403
+ ep.current_rl_state = rl_state
404
+ ep.current_features = featurize(rl_state)
405
+
406
+ _, scores = self._bandit.select_action(ep.current_features)
407
+ ucb_scores = {
408
+ REPAIR_ACTION_NAMES[RepairAction(i)]: round(scores[i], 4)
409
+ for i in range(len(scores))
410
+ }
411
+ yield {
412
+ "type": "rl_action",
413
+ "action": REPAIR_ACTION_NAMES[repair_action_enum],
414
+ "ucb_scores": ucb_scores,
415
+ }
416
+
417
+ yield {"type": "error", "error": error, "error_class": error_class_name}
418
+
419
+ grader_in = GraderInput(
420
+ success=success,
421
+ attempt_number=ep.attempt_number,
422
+ current_error_class=current_error_class,
423
+ previous_error_class=ep.previous_error_class,
424
+ )
425
+ grader_out = compute_reward(grader_in)
426
+
427
+ if ep.current_rl_state and ep.current_features:
428
+ step_obj = EpisodeStep(
429
+ state=ep.current_rl_state,
430
+ featurized=ep.current_features,
431
+ action=repair_action_enum,
432
+ reward=grader_out.reward,
433
+ error_message=error or "",
434
+ sql=generated_sql,
435
+ success=success,
436
+ )
437
+ ep.steps.append(step_obj)
438
+ self._bandit.update(ep.current_features, repair_action_enum, grader_out.reward)
439
+
440
+ ep.step_rewards.append(grader_out.reward)
441
+ ep.current_sql = generated_sql
442
+ ep.error_message = error
443
+ ep.error_class = error_class_name
444
+ ep.previous_error_class = current_error_class
445
+
446
+ yield {
447
+ "type": "rl_reward",
448
+ "reward": grader_out.reward,
449
+ "breakdown": {
450
+ "base": grader_out.breakdown.base,
451
+ "attempt_penalty": grader_out.breakdown.attempt_penalty,
452
+ "severity_bonus": grader_out.breakdown.severity_bonus,
453
+ "change_bonus": grader_out.breakdown.change_bonus,
454
+ },
455
+ }
456
+
457
+ done = success or ep.attempt_number >= self.MAX_ATTEMPTS
458
+
459
+ if success:
460
+ yield {
461
+ "type": "success",
462
+ "rows": rows,
463
+ "row_count": len(rows),
464
+ "sql": generated_sql,
465
+ }
466
+
467
+ if done:
468
+ total_reward = compute_episode_reward(ep.step_rewards, success)
469
+ self._finalize_episode(success=success)
470
+ ep.done = True
471
+ ep.success = success
472
+ yield {
473
+ "type": "rl_episode_end",
474
+ "total_reward": total_reward,
475
+ "success": success,
476
+ }
477
+
478
+ def state(self) -> dict:
479
+ if self._episode is None:
480
+ return {"active": False}
481
+ ep = self._episode
482
+ return {
483
+ "active": True,
484
+ "task_id": ep.task_id,
485
+ "question_id": ep.question_id,
486
+ "question": ep.question,
487
+ "attempt_number": ep.attempt_number,
488
+ "max_attempts": self.MAX_ATTEMPTS,
489
+ "current_sql": ep.current_sql,
490
+ "error_message": ep.error_message,
491
+ "error_class": ep.error_class,
492
+ "done": ep.done,
493
+ "success": ep.success,
494
+ "step_rewards": ep.step_rewards,
495
+ "total_reward": compute_episode_reward(ep.step_rewards, ep.success),
496
+ }
497
+
498
+ # ─── Private Helpers ──────────────────────────────────────────
499
+
500
+ def _build_observation(self) -> Observation:
501
+ if self._episode is None:
502
+ raise RuntimeError("No active episode")
503
+ ep = self._episode
504
+ task = get_task(ep.task_id)
505
+ return Observation(
506
+ question=ep.question,
507
+ schema_info=self._schema_info,
508
+ current_sql=ep.current_sql,
509
+ error_message=ep.error_message,
510
+ error_class=ep.error_class,
511
+ attempt_number=ep.attempt_number,
512
+ max_attempts=self.MAX_ATTEMPTS,
513
+ task_id=ep.task_id,
514
+ task_difficulty=task.difficulty,
515
+ )
516
+
517
+ async def _generate_sql(self, action: Action, ep: _Episode) -> str:
518
+ if action.repair_action == "generate" or ep.current_sql is None:
519
+ system = BASE_SYSTEM_PROMPT
520
+ user = (
521
+ f"Schema:\n{self._schema_info}\n\n"
522
+ f"Question: {ep.question}\n\n"
523
+ "Write a SQL query to answer this question."
524
+ )
525
+ else:
526
+ repair_action_enum = REPAIR_ACTION_BY_NAME.get(
527
+ action.repair_action, RepairAction.REWRITE_FULL
528
+ )
529
+ suffix = get_repair_system_suffix(repair_action_enum)
530
+ offending_token = extract_offending_token(ep.error_message or "")
531
+ ctx = RepairContext(
532
+ schema=self._schema_info,
533
+ question=ep.question,
534
+ failing_sql=ep.current_sql or "",
535
+ error_message=ep.error_message or "",
536
+ offending_token=offending_token,
537
+ )
538
+ system = BASE_SYSTEM_PROMPT + suffix
539
+ user = build_repair_user_message(repair_action_enum, ctx)
540
+
541
+ result = await _call_llm(system, user, stream=False)
542
+ return result # type: ignore[return-value]
543
+
544
+ async def _generate_sql_streaming(
545
+ self, action: Action, ep: _Episode
546
+ ) -> AsyncIterator[str]:
547
+ if action.repair_action == "generate" or ep.current_sql is None:
548
+ system = BASE_SYSTEM_PROMPT
549
+ user = (
550
+ f"Schema:\n{self._schema_info}\n\n"
551
+ f"Question: {ep.question}\n\n"
552
+ "Write a SQL query to answer this question."
553
+ )
554
+ else:
555
+ repair_action_enum = REPAIR_ACTION_BY_NAME.get(
556
+ action.repair_action, RepairAction.REWRITE_FULL
557
+ )
558
+ suffix = get_repair_system_suffix(repair_action_enum)
559
+ offending_token = extract_offending_token(ep.error_message or "")
560
+ ctx = RepairContext(
561
+ schema=self._schema_info,
562
+ question=ep.question,
563
+ failing_sql=ep.current_sql or "",
564
+ error_message=ep.error_message or "",
565
+ offending_token=offending_token,
566
+ )
567
+ system = BASE_SYSTEM_PROMPT + suffix
568
+ user = build_repair_user_message(repair_action_enum, ctx)
569
+
570
+ return await _call_llm(system, user, stream=True) # type: ignore[return-value]
571
+
572
+ def _finalize_episode(self, success: bool) -> None:
573
+ ep = self._episode
574
+ if ep is None or not ep.steps:
575
+ return
576
+ try:
577
+ episode_obj, relabeled = record_episode(ep.question, ep.steps, success)
578
+ for exp in relabeled:
579
+ self._bandit.update(exp.state, exp.action, exp.reward)
580
+ self._bandit.decay_alpha()
581
+ except Exception:
582
+ pass
583
+
584
+
585
+ # ─── Singleton instance ───────────────────────────────────────────
586
+
587
+ _env_instance: Optional[SQLAgentEnv] = None
588
+
589
+
590
+ def get_env() -> SQLAgentEnv:
591
+ global _env_instance
592
+ if _env_instance is None:
593
+ _env_instance = SQLAgentEnv()
594
+ return _env_instance
backend/env/tasks.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Task definitions for the SQL agent benchmark.
3
+
4
+ Three difficulty tiers, each with 5 questions and a grader function.
5
+
6
+ Grader contract: grader(sql, rows, error, attempts) -> float in [0.0, 1.0]
7
+ - rows: list[dict] from the executed SQL (may be empty)
8
+ - error: str | None
9
+ - attempts: int (1-indexed count of attempts taken)
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import re
15
+ from dataclasses import dataclass, field
16
+ from typing import Callable, Optional
17
+
18
+ from env.database import execute_query
19
+
20
+
21
+ # ─── Task Definitions ─────────────────────────────────────────────
22
+
23
+ @dataclass
24
+ class TaskQuestion:
25
+ id: str
26
+ question: str
27
+ expected_columns: list[str] # at least these columns should appear
28
+ min_rows: int # minimum expected rows
29
+ max_rows: Optional[int] = None # None = no upper bound
30
+ hint_tables: list[str] = field(default_factory=list) # tables that must be touched
31
+
32
+
33
+ @dataclass
34
+ class Task:
35
+ id: str
36
+ name: str
37
+ difficulty: str # "easy" | "medium" | "hard"
38
+ description: str
39
+ questions: list[TaskQuestion]
40
+ grader: Callable # grader(question, sql, rows, error, attempts) -> float
41
+
42
+
43
+ # ─── Grader Helpers ───────────────────────────────────────────────
44
+
45
+ def _has_required_columns(rows: list[dict], required: list[str]) -> bool:
46
+ if not rows:
47
+ return False
48
+ row_keys = {k.lower() for k in rows[0].keys()}
49
+ return all(col.lower() in row_keys for col in required)
50
+
51
+
52
+ def _row_count_score(rows: list[dict], min_rows: int, max_rows: Optional[int]) -> float:
53
+ n = len(rows)
54
+ if n == 0:
55
+ return 0.0
56
+ if n >= min_rows:
57
+ if max_rows is None or n <= max_rows:
58
+ return 1.0
59
+ # Over the expected maximum β€” might be a missing WHERE clause
60
+ return 0.5
61
+ # Partial result
62
+ return 0.5 * (n / min_rows)
63
+
64
+
65
+ # ─── Task 1: Simple Queries (Easy) ────────────────────────────────
66
+
67
+ _SIMPLE_QUESTIONS = [
68
+ TaskQuestion(
69
+ id="sq-01",
70
+ question="List all users from the USA.",
71
+ expected_columns=["name", "email", "country"],
72
+ min_rows=10,
73
+ max_rows=25,
74
+ hint_tables=["users"],
75
+ ),
76
+ TaskQuestion(
77
+ id="sq-02",
78
+ question="Show all products in the 'Electronics' category with their prices.",
79
+ expected_columns=["name", "price"],
80
+ min_rows=8,
81
+ max_rows=20,
82
+ hint_tables=["products"],
83
+ ),
84
+ TaskQuestion(
85
+ id="sq-03",
86
+ question="Find all orders with status 'delivered'.",
87
+ expected_columns=["id", "status"],
88
+ min_rows=30,
89
+ max_rows=50,
90
+ hint_tables=["orders"],
91
+ ),
92
+ TaskQuestion(
93
+ id="sq-04",
94
+ question="List all sellers and their countries.",
95
+ expected_columns=["name", "country"],
96
+ min_rows=10,
97
+ max_rows=10,
98
+ hint_tables=["sellers"],
99
+ ),
100
+ TaskQuestion(
101
+ id="sq-05",
102
+ question="Show all reviews with a rating of 5 stars.",
103
+ expected_columns=["rating"],
104
+ min_rows=15,
105
+ max_rows=35,
106
+ hint_tables=["reviews"],
107
+ ),
108
+ ]
109
+
110
+
111
+ def _grade_simple(
112
+ question: TaskQuestion,
113
+ sql: str,
114
+ rows: list[dict],
115
+ error: Optional[str],
116
+ attempts: int,
117
+ ) -> float:
118
+ if error:
119
+ return 0.0
120
+
121
+ col_ok = _has_required_columns(rows, question.expected_columns)
122
+ row_score = _row_count_score(rows, question.min_rows, question.max_rows)
123
+
124
+ if col_ok and row_score == 1.0:
125
+ return 1.0
126
+ if col_ok or row_score >= 0.5:
127
+ return 0.5
128
+ return 0.0
129
+
130
+
131
+ _TASK_SIMPLE = Task(
132
+ id="simple_queries",
133
+ name="Simple Queries",
134
+ difficulty="easy",
135
+ description="Single-table SELECT queries with basic filters.",
136
+ questions=_SIMPLE_QUESTIONS,
137
+ grader=_grade_simple,
138
+ )
139
+
140
+
141
+ # ─── Task 2: Join Queries (Medium) ────────────────────────────────
142
+
143
+ _JOIN_QUESTIONS = [
144
+ TaskQuestion(
145
+ id="jq-01",
146
+ question="Show the total number of orders per user, including the user's name.",
147
+ expected_columns=["name"],
148
+ min_rows=10,
149
+ hint_tables=["users", "orders"],
150
+ ),
151
+ TaskQuestion(
152
+ id="jq-02",
153
+ question="List products along with the name of their seller.",
154
+ expected_columns=["name", "name"], # product name + seller name both called 'name'
155
+ min_rows=20,
156
+ hint_tables=["products", "sellers"],
157
+ ),
158
+ TaskQuestion(
159
+ id="jq-03",
160
+ question="Find the average rating for each product category.",
161
+ expected_columns=["category"],
162
+ min_rows=5,
163
+ max_rows=10,
164
+ hint_tables=["products", "reviews"],
165
+ ),
166
+ TaskQuestion(
167
+ id="jq-04",
168
+ question="Show the total revenue (sum of total_price) per seller.",
169
+ expected_columns=["name"],
170
+ min_rows=5,
171
+ hint_tables=["sellers", "products", "orders"],
172
+ ),
173
+ TaskQuestion(
174
+ id="jq-05",
175
+ question="List the top 5 most reviewed products with their review counts.",
176
+ expected_columns=["name"],
177
+ min_rows=5,
178
+ max_rows=5,
179
+ hint_tables=["products", "reviews"],
180
+ ),
181
+ ]
182
+
183
+
184
+ def _grade_join(
185
+ question: TaskQuestion,
186
+ sql: str,
187
+ rows: list[dict],
188
+ error: Optional[str],
189
+ attempts: int,
190
+ ) -> float:
191
+ if error:
192
+ return 0.0
193
+
194
+ col_ok = _has_required_columns(rows, [question.expected_columns[0]])
195
+ row_score = _row_count_score(rows, question.min_rows, question.max_rows)
196
+
197
+ base = 0.0
198
+ if col_ok and row_score == 1.0:
199
+ base = 1.0
200
+ elif col_ok or row_score >= 0.5:
201
+ base = 0.5
202
+
203
+ # Penalize extra attempts
204
+ attempt_penalty = max(0.0, 0.1 * (attempts - 1))
205
+ return max(0.0, base - attempt_penalty)
206
+
207
+
208
+ _TASK_JOIN = Task(
209
+ id="join_queries",
210
+ name="Join Queries",
211
+ difficulty="medium",
212
+ description="Multi-table JOINs with GROUP BY and aggregation.",
213
+ questions=_JOIN_QUESTIONS,
214
+ grader=_grade_join,
215
+ )
216
+
217
+
218
+ # ─── Task 3: Complex Queries (Hard) ───────────────────────────────
219
+
220
+ _COMPLEX_QUESTIONS = [
221
+ TaskQuestion(
222
+ id="cq-01",
223
+ question=(
224
+ "Find users who have placed more than 1 order, showing their name "
225
+ "and total number of orders, ordered by order count descending."
226
+ ),
227
+ expected_columns=["name"],
228
+ min_rows=1,
229
+ hint_tables=["users", "orders"],
230
+ ),
231
+ TaskQuestion(
232
+ id="cq-02",
233
+ question=(
234
+ "For each product category, show the category name, number of products, "
235
+ "average price, and total stock. Use a CTE."
236
+ ),
237
+ expected_columns=["category"],
238
+ min_rows=5,
239
+ max_rows=10,
240
+ hint_tables=["products"],
241
+ ),
242
+ TaskQuestion(
243
+ id="cq-03",
244
+ question=(
245
+ "Show each seller's name, their total sales revenue, and rank them "
246
+ "by revenue using a window function (RANK() or ROW_NUMBER())."
247
+ ),
248
+ expected_columns=["name"],
249
+ min_rows=5,
250
+ hint_tables=["sellers", "products", "orders"],
251
+ ),
252
+ TaskQuestion(
253
+ id="cq-04",
254
+ question=(
255
+ "Find the top-rated product in each category (highest average review rating). "
256
+ "Show category, product name, and average rating."
257
+ ),
258
+ expected_columns=["category", "name"],
259
+ min_rows=5,
260
+ max_rows=10,
261
+ hint_tables=["products", "reviews"],
262
+ ),
263
+ TaskQuestion(
264
+ id="cq-05",
265
+ question=(
266
+ "Calculate the month-over-month order count for 2024, showing year, "
267
+ "month, order_count, and a running total."
268
+ ),
269
+ expected_columns=["month"],
270
+ min_rows=6,
271
+ max_rows=12,
272
+ hint_tables=["orders"],
273
+ ),
274
+ ]
275
+
276
+
277
+ def _grade_complex(
278
+ question: TaskQuestion,
279
+ sql: str,
280
+ rows: list[dict],
281
+ error: Optional[str],
282
+ attempts: int,
283
+ ) -> float:
284
+ if error:
285
+ return 0.0
286
+
287
+ col_ok = _has_required_columns(rows, question.expected_columns)
288
+ row_score = _row_count_score(rows, question.min_rows, question.max_rows)
289
+
290
+ if not col_ok or row_score == 0.0:
291
+ return 0.0
292
+
293
+ # Hard task base max is 0.8 unless first-attempt bonus
294
+ if row_score == 1.0 and col_ok:
295
+ base = 0.8 + (0.2 if attempts == 1 else 0.0)
296
+ else:
297
+ base = 0.4 # partial
298
+
299
+ # Strict attempt penalty for hard queries
300
+ attempt_penalty = 0.1 * (attempts - 1)
301
+ return max(0.0, base - attempt_penalty)
302
+
303
+
304
+ _TASK_COMPLEX = Task(
305
+ id="complex_queries",
306
+ name="Complex Queries",
307
+ difficulty="hard",
308
+ description="CTEs, window functions, and nested aggregations.",
309
+ questions=_COMPLEX_QUESTIONS,
310
+ grader=_grade_complex,
311
+ )
312
+
313
+
314
+ # ─── Registry ─────────────────────────────────────────────────────
315
+
316
+ TASKS: dict[str, Task] = {
317
+ "simple_queries": _TASK_SIMPLE,
318
+ "join_queries": _TASK_JOIN,
319
+ "complex_queries": _TASK_COMPLEX,
320
+ }
321
+
322
+
323
+ def get_task(task_id: str) -> Task:
324
+ if task_id not in TASKS:
325
+ raise ValueError(f"Unknown task_id: {task_id!r}. Valid: {list(TASKS)}")
326
+ return TASKS[task_id]
327
+
328
+
329
+ def get_all_tasks() -> list[Task]:
330
+ return list(TASKS.values())
331
+
332
+
333
+ def grade_response(
334
+ task_id: str,
335
+ question_id: str,
336
+ sql: str,
337
+ rows: list[dict],
338
+ error: Optional[str],
339
+ attempts: int,
340
+ ) -> float:
341
+ task = get_task(task_id)
342
+ question = next((q for q in task.questions if q.id == question_id), None)
343
+ if question is None:
344
+ raise ValueError(f"Unknown question_id {question_id!r} in task {task_id!r}")
345
+ return task.grader(question, sql, rows, error, attempts)
backend/gepa/__init__.py ADDED
File without changes
backend/gepa/optimizer.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GEPA (Goal-directed Evolutionary Prompt Adaptation) optimizer.
3
+
4
+ Ported from gepa.ts. Key steps:
5
+ 1. Reflection: LLM analyzes failure history, outputs diagnosis
6
+ 2. Mutation: LLM rewrites system prompt based on diagnosis
7
+ 3. Scoring: Run 3 golden queries with new prompt, compute score
8
+ 4. Pareto front: Keep top 3 prompts by (score, diversity)
9
+
10
+ State is persisted to data/gepa_prompt.json.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import json
16
+ import os
17
+ import time
18
+ from pathlib import Path
19
+ from typing import Optional
20
+
21
+ from openai import AsyncOpenAI
22
+ from pydantic import BaseModel
23
+
24
+ _DATA_DIR = Path(os.environ.get("DATA_DIR", Path(__file__).parent.parent / "data"))
25
+ GEPA_PATH = _DATA_DIR / "gepa_prompt.json"
26
+
27
+ _MODEL = os.environ.get("MODEL_NAME", "gpt-4o-mini")
28
+
29
+ SEED_SYSTEM_PROMPT = """You are a SQL expert. Given a natural language question and a SQLite database schema, write a correct SQL query.
30
+
31
+ Rules:
32
+ - Output ONLY the SQL query, nothing else
33
+ - No markdown, no code fences, no explanation
34
+ - Use SQLite syntax"""
35
+
36
+
37
+ # ─── Models ──────────────────────────────────────────────────────
38
+
39
+ class QueryResult(BaseModel):
40
+ question: str
41
+ final_sql: str
42
+ attempts: int
43
+ success: bool
44
+ errors: list[str]
45
+ timestamp: float
46
+
47
+
48
+ class Candidate(BaseModel):
49
+ system_prompt: str
50
+ score: float
51
+ avg_attempts: float
52
+ success_rate: float
53
+ generation: int
54
+ feedback: list[str]
55
+
56
+
57
+ # ─── LLM Helper ──────────────────────────────────────────────────
58
+
59
+ def _make_client() -> AsyncOpenAI:
60
+ return AsyncOpenAI(
61
+ api_key=os.environ.get("HF_TOKEN", ""),
62
+ base_url=os.environ.get("API_BASE_URL", "https://api.openai.com/v1"),
63
+ )
64
+
65
+
66
+ async def _complete(system: str, user: str) -> str:
67
+ client = _make_client()
68
+ resp = await client.chat.completions.create(
69
+ model=_MODEL,
70
+ messages=[
71
+ {"role": "system", "content": system},
72
+ {"role": "user", "content": user},
73
+ ],
74
+ temperature=0.7,
75
+ )
76
+ return resp.choices[0].message.content or ""
77
+
78
+
79
+ # ─── Golden Queries for Scoring ──────────────────────────────────
80
+
81
+ _GOLDEN_QUERIES = [
82
+ {
83
+ "id": "gq-01",
84
+ "question": "List all users from the USA.",
85
+ "expected_min_rows": 10,
86
+ },
87
+ {
88
+ "id": "gq-02",
89
+ "question": "Show all products in the 'Electronics' category.",
90
+ "expected_min_rows": 8,
91
+ },
92
+ {
93
+ "id": "gq-03",
94
+ "question": "Find the total number of orders per user.",
95
+ "expected_min_rows": 10,
96
+ },
97
+ {
98
+ "id": "gq-04",
99
+ "question": "Show the average rating for each product category.",
100
+ "expected_min_rows": 5,
101
+ },
102
+ {
103
+ "id": "gq-05",
104
+ "question": "List products along with their seller name.",
105
+ "expected_min_rows": 20,
106
+ },
107
+ ]
108
+
109
+
110
+ # ─── Optimizer Class ──────────────────────────────────────────────
111
+
112
+ class GEPAOptimizer:
113
+ def __init__(self) -> None:
114
+ self._history: list[QueryResult] = []
115
+ self._pareto_front: list[Candidate] = [
116
+ Candidate(
117
+ system_prompt=SEED_SYSTEM_PROMPT,
118
+ score=0.5,
119
+ avg_attempts=3.0,
120
+ success_rate=0.5,
121
+ generation=0,
122
+ feedback=[],
123
+ )
124
+ ]
125
+ self._load()
126
+
127
+ # ─── Public Interface ─────────────────────────────────────────
128
+
129
+ def record_result(self, result: QueryResult) -> None:
130
+ self._history.append(result)
131
+ self._save()
132
+
133
+ def get_current_prompt(self) -> str:
134
+ if not self._pareto_front:
135
+ return SEED_SYSTEM_PROMPT
136
+ return max(self._pareto_front, key=lambda c: c.score).system_prompt
137
+
138
+ def get_history(self) -> list[QueryResult]:
139
+ return list(self._history)
140
+
141
+ def get_pareto_front(self) -> list[Candidate]:
142
+ return list(self._pareto_front)
143
+
144
+ def set_current_prompt(self, prompt: str) -> None:
145
+ if self._pareto_front:
146
+ best = max(self._pareto_front, key=lambda c: c.score)
147
+ best.system_prompt = prompt
148
+ else:
149
+ self._pareto_front.append(
150
+ Candidate(
151
+ system_prompt=prompt,
152
+ score=0.5,
153
+ avg_attempts=3.0,
154
+ success_rate=0.5,
155
+ generation=0,
156
+ feedback=[],
157
+ )
158
+ )
159
+ self._save()
160
+
161
+ def should_optimize(self) -> bool:
162
+ return len(self._history) > 0 and len(self._history) % 4 == 0
163
+
164
+ def reset(self) -> None:
165
+ self._history.clear()
166
+ self._pareto_front.clear()
167
+ self._pareto_front.append(
168
+ Candidate(
169
+ system_prompt=SEED_SYSTEM_PROMPT,
170
+ score=0.5,
171
+ avg_attempts=3.0,
172
+ success_rate=0.5,
173
+ generation=0,
174
+ feedback=[],
175
+ )
176
+ )
177
+ self._save()
178
+
179
+ async def run_optimization_cycle(
180
+ self,
181
+ user_feedback_context: Optional[str] = None,
182
+ dialect: str = "SQLite",
183
+ ) -> Optional[dict]:
184
+ """
185
+ Run one GEPA cycle: reflect β†’ mutate β†’ score β†’ update Pareto front.
186
+ Returns {new_prompt, reflection} or None if not enough data.
187
+ """
188
+ if len(self._history) < 2:
189
+ return None
190
+
191
+ recent_failures = [
192
+ h for h in self._history if h.attempts > 1 or not h.success
193
+ ][-8:]
194
+ if len(recent_failures) < 2:
195
+ return None
196
+
197
+ current_best = self.get_current_prompt()
198
+
199
+ # ── Step 1: Reflect ──────────────────────────────────────
200
+ failure_summary = "\n\n---\n\n".join(
201
+ f'Query {i+1}: "{f.question}"\n'
202
+ f"Attempts: {f.attempts}\n"
203
+ f"Errors:\n" + "\n".join(f" - {e}" for e in f.errors) + "\n"
204
+ f"Final SQL: {f.final_sql}"
205
+ for i, f in enumerate(recent_failures)
206
+ )
207
+
208
+ user_ctx_block = (
209
+ f"\n\nUser conversation:\n{user_feedback_context}"
210
+ if user_feedback_context
211
+ else ""
212
+ )
213
+
214
+ reflection = await _complete(
215
+ f"You are an expert SQL prompt engineer analyzing why an LLM SQL agent is failing.\n"
216
+ f"The target database is {dialect} β€” all rules must use {dialect} syntax.\n"
217
+ "Your job: identify specific, recurring patterns in these failures and state EXACTLY "
218
+ "what rules or knowledge the system prompt is missing.\n"
219
+ "Be very specific β€” name the exact functions, syntax patterns, or schema reasoning gaps.\n"
220
+ "Output a concise diagnosis (3-5 bullet points max).",
221
+ f"Current system prompt:\n{current_best}\n\n"
222
+ f"Recent failures:\n{failure_summary}{user_ctx_block}",
223
+ )
224
+
225
+ # ── Step 2: Mutate ───────────────────────────────────────
226
+ current_generation = max(c.generation for c in self._pareto_front) if self._pareto_front else 0
227
+
228
+ new_prompt = await _complete(
229
+ f"You are an expert prompt engineer. Improve a system prompt for a {dialect} SQL generation agent.\n"
230
+ "Rules for the new prompt:\n"
231
+ "- Keep it concise and actionable\n"
232
+ f"- The target database is {dialect} β€” use ONLY {dialect} syntax and functions\n"
233
+ "- Add specific rules that address the diagnosed failure patterns\n"
234
+ "- Do NOT add generic fluff β€” every rule must be earned by a real failure\n"
235
+ "- Output ONLY the improved system prompt text, nothing else",
236
+ f"Current system prompt:\n{current_best}\n\n"
237
+ f"Diagnosed failure patterns:\n{reflection}\n\n"
238
+ "Write the improved system prompt:",
239
+ )
240
+
241
+ # ── Step 3: Score ────────────────────────────────────────
242
+ benchmark_score = await self._score_prompt(new_prompt)
243
+
244
+ current_avg_attempts = (
245
+ sum(h.attempts for h in self._history) / len(self._history)
246
+ if self._history
247
+ else 3.0
248
+ )
249
+
250
+ new_candidate = Candidate(
251
+ system_prompt=new_prompt,
252
+ score=benchmark_score,
253
+ avg_attempts=max(current_avg_attempts - 0.5, 1.0),
254
+ success_rate=benchmark_score,
255
+ generation=current_generation + 1,
256
+ feedback=[reflection],
257
+ )
258
+
259
+ # ── Step 4: Update Pareto front ──────────────────────────
260
+ self._pareto_front.append(new_candidate)
261
+ self._pareto_front.sort(key=lambda c: c.score, reverse=True)
262
+ if len(self._pareto_front) > 3:
263
+ self._pareto_front = self._pareto_front[:3]
264
+
265
+ self._save()
266
+ return {"new_prompt": new_prompt, "reflection": reflection}
267
+
268
+ async def _score_prompt(self, prompt: str) -> float:
269
+ """
270
+ Score a prompt by running 3 golden queries and measuring success rate.
271
+ """
272
+ from env.database import execute_query, get_schema_info
273
+ import re
274
+
275
+ schema = get_schema_info()
276
+ client = _make_client()
277
+
278
+ scores = []
279
+ for gq in _GOLDEN_QUERIES[:3]:
280
+ try:
281
+ resp = await client.chat.completions.create(
282
+ model=_MODEL,
283
+ messages=[
284
+ {"role": "system", "content": prompt},
285
+ {
286
+ "role": "user",
287
+ "content": (
288
+ f"Schema:\n{schema}\n\n"
289
+ f"Question: {gq['question']}\n\n"
290
+ "Write a SQL query."
291
+ ),
292
+ },
293
+ ],
294
+ temperature=0.1,
295
+ )
296
+ sql = resp.choices[0].message.content or ""
297
+ sql = re.sub(r"^```(?:sql)?\s*", "", sql.strip(), flags=re.IGNORECASE)
298
+ sql = re.sub(r"\s*```$", "", sql).strip().rstrip(";")
299
+
300
+ rows, error = execute_query(sql)
301
+ if error is None and len(rows) >= gq["expected_min_rows"]:
302
+ scores.append(1.0)
303
+ elif error is None and rows:
304
+ scores.append(0.5)
305
+ else:
306
+ scores.append(0.0)
307
+ except Exception:
308
+ scores.append(0.0)
309
+
310
+ return sum(scores) / len(scores) if scores else 0.3
311
+
312
+ # ─── Persistence ─────────────────────────────────────────────
313
+
314
+ def _save(self) -> None:
315
+ try:
316
+ GEPA_PATH.parent.mkdir(parents=True, exist_ok=True)
317
+ data = {
318
+ "history": [r.model_dump() for r in self._history[-100:]],
319
+ "pareto_front": [c.model_dump() for c in self._pareto_front],
320
+ }
321
+ GEPA_PATH.write_text(json.dumps(data, default=str))
322
+ except Exception:
323
+ pass
324
+
325
+ def _load(self) -> None:
326
+ try:
327
+ if not GEPA_PATH.exists():
328
+ return
329
+ data = json.loads(GEPA_PATH.read_text())
330
+ self._history = [QueryResult(**r) for r in data.get("history", [])]
331
+ loaded_front = [Candidate(**c) for c in data.get("pareto_front", [])]
332
+ if loaded_front:
333
+ self._pareto_front = loaded_front
334
+ except Exception:
335
+ pass
336
+
337
+
338
+ # ─── Singleton ────────────────────────────────────────────────────
339
+
340
+ _gepa_instance: Optional[GEPAOptimizer] = None
341
+
342
+
343
+ def get_gepa() -> GEPAOptimizer:
344
+ global _gepa_instance
345
+ if _gepa_instance is None:
346
+ _gepa_instance = GEPAOptimizer()
347
+ return _gepa_instance
backend/main.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SQL Agent OpenEnv β€” FastAPI entry point.
3
+
4
+ Start with:
5
+ uvicorn main:app --reload --port 8000
6
+
7
+ Environment variables:
8
+ API_BASE_URL β€” OpenAI-compatible base URL
9
+ MODEL_NAME β€” model name
10
+ HF_TOKEN β€” API key / bearer token
11
+ DATA_DIR β€” override data directory (default: ./data)
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import os
17
+ from pathlib import Path
18
+
19
+ from fastapi import FastAPI
20
+ from fastapi.middleware.cors import CORSMiddleware
21
+ from fastapi.staticfiles import StaticFiles
22
+
23
+ from api.demo import router as demo_router
24
+ from api.openenv import router as openenv_router, ResetRequest, StepRequest, env_reset, env_step, env_state
25
+ from env.database import ensure_seeded
26
+
27
+ app = FastAPI(
28
+ title="SQL Agent OpenEnv",
29
+ description=(
30
+ "A SQL generation environment powered by a LinUCB contextual bandit "
31
+ "and GEPA prompt evolution, built for the Meta + Hugging Face OpenEnv hackathon."
32
+ ),
33
+ version="1.0.0",
34
+ )
35
+
36
+ # ─── CORS ────────────────────────────────────────────────────────
37
+
38
+ app.add_middleware(
39
+ CORSMiddleware,
40
+ allow_origins=["*"],
41
+ allow_credentials=True,
42
+ allow_methods=["*"],
43
+ allow_headers=["*"],
44
+ )
45
+
46
+ # ─── Routers ─────────────────────────────────────────────────────
47
+
48
+ app.include_router(demo_router, prefix="/api", tags=["demo"])
49
+ app.include_router(openenv_router, prefix="/env", tags=["openenv"])
50
+
51
+ # ─── Top-level OpenEnv aliases (required by openenv validate + pre-validation) ─
52
+ # The validator pings POST <url>/reset β€” these mirror /env/* without the prefix.
53
+
54
+ @app.post("/reset", tags=["openenv"])
55
+ async def root_reset(req: ResetRequest = None):
56
+ return await env_reset(req or ResetRequest())
57
+
58
+
59
+ @app.post("/step", tags=["openenv"])
60
+ async def root_step(req: StepRequest = None):
61
+ return await env_step(req or StepRequest())
62
+
63
+
64
+ @app.get("/state", tags=["openenv"])
65
+ async def root_state():
66
+ return await env_state()
67
+
68
+
69
+ # ─── Health check ────────────────────────────────────────────────
70
+
71
+ @app.get("/health", tags=["system"])
72
+ async def health():
73
+ return {"status": "ok", "service": "sql-agent-openenv"}
74
+
75
+
76
+ # ─── Startup ─────────────────────────────────────────────────────
77
+
78
+ @app.on_event("startup")
79
+ async def startup_event():
80
+ """Seed the database on first startup."""
81
+ try:
82
+ ensure_seeded()
83
+ except Exception as e:
84
+ print(f"Warning: database seed failed: {e}")
85
+
86
+
87
+ # ─── Static files (frontend) β€” mount last ─────────────────────────
88
+
89
+ _frontend_dist = Path(__file__).parent.parent / "frontend" / "dist"
90
+ if _frontend_dist.exists():
91
+ app.mount(
92
+ "/",
93
+ StaticFiles(directory=str(_frontend_dist), html=True),
94
+ name="frontend",
95
+ )
96
+ else:
97
+ @app.get("/", tags=["system"])
98
+ async def root():
99
+ return {
100
+ "message": "SQL Agent OpenEnv API",
101
+ "docs": "/docs",
102
+ "health": "/health",
103
+ "env_info": "/env/info",
104
+ }
backend/requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ fastapi>=0.115.0
2
+ uvicorn[standard]>=0.30.0
3
+ openai>=1.40.0
4
+ pydantic>=2.8.0
5
+ numpy>=1.26.0
6
+ aiofiles>=24.0.0
7
+ python-multipart>=0.0.9
8
+ sse-starlette>=2.1.0
9
+ aiosqlite>=0.20.0
backend/rl/__init__.py ADDED
File without changes
backend/rl/environment.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SQLDebugEnvironment β€” Gym-like RL environment for the SQL debug loop.
3
+
4
+ Lifecycle:
5
+ 1. env.reset(question) β€” start new episode
6
+ 2. env.observe_error(error, sql) β€” classify error, build state
7
+ 3. env.select_action() β€” bandit picks repair strategy
8
+ 4. env.get_repair_prompt(...) β€” get specialized prompt for chosen action
9
+ 5. env.record_step(success) β€” record outcome, compute reward
10
+ 6. Repeat 2-5 until success or max attempts
11
+ 7. env.end_episode(success) β€” finalize, HER relabeling, bandit update
12
+
13
+ This module is a stateful singleton β€” one active episode at a time.
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import time
19
+ from typing import Optional
20
+
21
+ from rl.types import (
22
+ RLState,
23
+ RepairAction,
24
+ ErrorClass,
25
+ EpisodeStep,
26
+ RLMetrics,
27
+ featurize,
28
+ REPAIR_ACTION_NAMES,
29
+ ERROR_CLASS_NAMES,
30
+ )
31
+ from rl.error_classifier import classify_error, extract_offending_token
32
+ from rl.grader import GraderInput, compute_reward
33
+ from rl.linucb import LinUCB
34
+ from rl.experience import record_episode, get_metrics, reset_experience
35
+ from rl.repair_strategies import (
36
+ RepairContext,
37
+ get_repair_system_suffix,
38
+ build_repair_user_message,
39
+ )
40
+
41
+ # ─── Singleton State ─────────────────────────────────────────────
42
+
43
+ _bandit: Optional[LinUCB] = None
44
+
45
+
46
+ class _EpisodeContext:
47
+ def __init__(self, question: str) -> None:
48
+ self.question = question
49
+ self.steps: list[EpisodeStep] = []
50
+ self.previous_error_class: Optional[ErrorClass] = None
51
+ self.consecutive_same_error: int = 0
52
+ self.last_action: Optional[RepairAction] = None
53
+ self.current_state: Optional[RLState] = None
54
+ self.current_features: Optional[list[float]] = None
55
+
56
+
57
+ _current_episode: Optional[_EpisodeContext] = None
58
+
59
+
60
+ def _get_bandit() -> LinUCB:
61
+ global _bandit
62
+ if _bandit is None:
63
+ _bandit = LinUCB()
64
+ return _bandit
65
+
66
+
67
+ # ─── Environment Interface ────────────────────────────────────────
68
+
69
+ def reset(question: str) -> None:
70
+ """Start a new episode. If a previous episode was active, end it as failure."""
71
+ global _current_episode
72
+ if _current_episode and _current_episode.steps:
73
+ end_episode(False)
74
+ _current_episode = _EpisodeContext(question)
75
+
76
+
77
+ def observe_error(
78
+ error_message: str,
79
+ failing_sql: str,
80
+ attempt_number: int,
81
+ ) -> dict:
82
+ """
83
+ Classify the SQL execution error and build the RL state.
84
+ Returns a dict with keys: error_class, error_class_name, state.
85
+ """
86
+ if _current_episode is None:
87
+ raise RuntimeError("Call reset() before observe_error()")
88
+
89
+ error_class = classify_error(error_message)
90
+ error_changed = (
91
+ _current_episode.previous_error_class is not None
92
+ and _current_episode.previous_error_class != error_class
93
+ )
94
+
95
+ if _current_episode.previous_error_class == error_class:
96
+ _current_episode.consecutive_same_error += 1
97
+ else:
98
+ _current_episode.consecutive_same_error = 1
99
+
100
+ state = RLState(
101
+ error_class=error_class,
102
+ attempt_number=attempt_number,
103
+ previous_action=_current_episode.last_action,
104
+ error_changed=error_changed,
105
+ consecutive_same_error=_current_episode.consecutive_same_error,
106
+ )
107
+
108
+ _current_episode.current_state = state
109
+ _current_episode.current_features = featurize(state)
110
+
111
+ return {
112
+ "error_class": error_class,
113
+ "error_class_name": ERROR_CLASS_NAMES[error_class],
114
+ "state": state,
115
+ }
116
+
117
+
118
+ def select_action() -> dict:
119
+ """
120
+ Ask the bandit to select a repair action based on current state.
121
+ Returns dict with keys: action, action_name, scores.
122
+ """
123
+ if _current_episode is None or _current_episode.current_features is None:
124
+ raise RuntimeError("Call observe_error() before select_action()")
125
+
126
+ b = _get_bandit()
127
+ action, scores = b.select_action(_current_episode.current_features)
128
+ _current_episode.last_action = action
129
+
130
+ return {
131
+ "action": action,
132
+ "action_name": REPAIR_ACTION_NAMES[action],
133
+ "scores": scores,
134
+ }
135
+
136
+
137
+ def get_repair_prompt(
138
+ action: RepairAction,
139
+ schema: str,
140
+ question: str,
141
+ failing_sql: str,
142
+ error_message: str,
143
+ ) -> dict:
144
+ """
145
+ Build the system suffix and user message for the chosen repair action.
146
+ Returns dict with keys: system_suffix, user_message.
147
+ """
148
+ offending_token = extract_offending_token(error_message)
149
+ ctx = RepairContext(
150
+ schema=schema,
151
+ question=question,
152
+ failing_sql=failing_sql,
153
+ error_message=error_message,
154
+ offending_token=offending_token,
155
+ )
156
+ return {
157
+ "system_suffix": get_repair_system_suffix(action),
158
+ "user_message": build_repair_user_message(action, ctx),
159
+ }
160
+
161
+
162
+ def record_step(
163
+ action: RepairAction,
164
+ success: bool,
165
+ error_message: str,
166
+ sql: str,
167
+ ) -> dict:
168
+ """
169
+ Record the outcome of a repair step and compute shaped reward.
170
+ Returns dict with keys: reward, breakdown.
171
+ """
172
+ if _current_episode is None or _current_episode.current_state is None:
173
+ raise RuntimeError("Call observe_error() before record_step()")
174
+
175
+ state = _current_episode.current_state
176
+
177
+ grader_input = GraderInput(
178
+ success=success,
179
+ attempt_number=state.attempt_number,
180
+ current_error_class=None if success else classify_error(error_message),
181
+ previous_error_class=_current_episode.previous_error_class,
182
+ )
183
+ result = compute_reward(grader_input)
184
+
185
+ step = EpisodeStep(
186
+ state=state,
187
+ featurized=_current_episode.current_features or featurize(state),
188
+ action=action,
189
+ reward=result.reward,
190
+ error_message=error_message,
191
+ sql=sql,
192
+ success=success,
193
+ )
194
+
195
+ _current_episode.steps.append(step)
196
+ _current_episode.previous_error_class = state.error_class
197
+
198
+ return {
199
+ "reward": result.reward,
200
+ "breakdown": {
201
+ "base": result.breakdown.base,
202
+ "attempt_penalty": result.breakdown.attempt_penalty,
203
+ "severity_bonus": result.breakdown.severity_bonus,
204
+ "change_bonus": result.breakdown.change_bonus,
205
+ },
206
+ }
207
+
208
+
209
+ def end_episode(success: bool) -> Optional[dict]:
210
+ """
211
+ End the current episode. Runs HER relabeling and updates the bandit.
212
+ Returns dict with keys: total_reward, episode_length.
213
+ """
214
+ global _current_episode
215
+ if _current_episode is None or not _current_episode.steps:
216
+ _current_episode = None
217
+ return None
218
+
219
+ b = _get_bandit()
220
+ episode, relabeled = record_episode(
221
+ _current_episode.question,
222
+ _current_episode.steps,
223
+ success,
224
+ )
225
+
226
+ for exp in relabeled:
227
+ b.update(exp.state, exp.action, exp.reward)
228
+
229
+ b.decay_alpha()
230
+
231
+ result = {
232
+ "total_reward": episode.total_reward,
233
+ "episode_length": len(episode.steps),
234
+ }
235
+
236
+ _current_episode = None
237
+ return result
238
+
239
+
240
+ # ─── Query Interface ──────────────────────────────────────────────
241
+
242
+ def get_rl_metrics() -> RLMetrics:
243
+ return get_metrics()
244
+
245
+
246
+ def get_bandit_state() -> dict:
247
+ b = _get_bandit()
248
+ return {
249
+ "action_counts": b.get_action_counts(),
250
+ "total_updates": b.get_total_updates(),
251
+ "alpha": b.get_alpha(),
252
+ "action_distribution": b.get_action_distribution(),
253
+ }
254
+
255
+
256
+ def is_episode_active() -> bool:
257
+ return _current_episode is not None
258
+
259
+
260
+ def reset_rl() -> None:
261
+ """Reset the entire RL system β€” bandit weights and experience store."""
262
+ global _bandit, _current_episode
263
+ if _bandit:
264
+ _bandit.reset()
265
+ reset_experience()
266
+ _current_episode = None
backend/rl/error_classifier.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SQL error classifier: maps raw SQLite error messages to one of 8
3
+ canonical ErrorClass values.
4
+
5
+ Severity ordering (lower = less severe / closer to correct):
6
+ OTHER=5, SYNTAX_ERROR=4, NO_SUCH_FUNCTION=3, NO_SUCH_TABLE=3,
7
+ DATATYPE_MISMATCH=2, AGGREGATION_ERROR=2,
8
+ NO_SUCH_COLUMN=1, AMBIGUOUS_COLUMN=1
9
+ """
10
+
11
+ import re
12
+ from typing import Optional
13
+
14
+ from rl.types import ErrorClass
15
+
16
+ _SEVERITY: dict[ErrorClass, int] = {
17
+ ErrorClass.OTHER: 5,
18
+ ErrorClass.SYNTAX_ERROR: 4,
19
+ ErrorClass.NO_SUCH_FUNCTION: 3,
20
+ ErrorClass.NO_SUCH_TABLE: 3,
21
+ ErrorClass.DATATYPE_MISMATCH: 2,
22
+ ErrorClass.AGGREGATION_ERROR: 2,
23
+ ErrorClass.NO_SUCH_COLUMN: 1,
24
+ ErrorClass.AMBIGUOUS_COLUMN: 1,
25
+ }
26
+
27
+
28
+ def error_severity(error_class: ErrorClass) -> int:
29
+ return _SEVERITY[error_class]
30
+
31
+
32
+ def classify_error(error_message: str) -> ErrorClass:
33
+ """
34
+ Classify a raw SQLite error message into one of 8 canonical classes.
35
+ Patterns are ordered most-specific-first to avoid false matches.
36
+ """
37
+ msg = error_message.lower()
38
+
39
+ # Column-level errors
40
+ if "no such column" in msg:
41
+ return ErrorClass.NO_SUCH_COLUMN
42
+ if "ambiguous column" in msg:
43
+ return ErrorClass.AMBIGUOUS_COLUMN
44
+
45
+ # Table-level errors
46
+ if "no such table" in msg:
47
+ return ErrorClass.NO_SUCH_TABLE
48
+
49
+ # Function errors
50
+ if "no such function" in msg:
51
+ return ErrorClass.NO_SUCH_FUNCTION
52
+
53
+ # Aggregation / GROUP BY
54
+ if (
55
+ "not an aggregate" in msg
56
+ or "misuse of aggregate" in msg
57
+ or ("group by" in msg and "must appear" in msg)
58
+ or "must be an aggregate" in msg
59
+ ):
60
+ return ErrorClass.AGGREGATION_ERROR
61
+
62
+ # Syntax errors (broad β€” must come after more specific patterns)
63
+ if "syntax error" in msg or re.search(r'near\s+"', msg):
64
+ return ErrorClass.SYNTAX_ERROR
65
+
66
+ # Type errors
67
+ if "datatype mismatch" in msg or "type mismatch" in msg:
68
+ return ErrorClass.DATATYPE_MISMATCH
69
+
70
+ return ErrorClass.OTHER
71
+
72
+
73
+ def extract_offending_token(error_message: str) -> Optional[str]:
74
+ """
75
+ Extract the offending token from a SQLite error message.
76
+ Returns None if no specific token can be identified.
77
+ """
78
+ # "no such column: X"
79
+ m = re.search(r"no such column:\s*(\S+)", error_message, re.IGNORECASE)
80
+ if m:
81
+ return m.group(1)
82
+
83
+ # "no such table: X"
84
+ m = re.search(r"no such table:\s*(\S+)", error_message, re.IGNORECASE)
85
+ if m:
86
+ return m.group(1)
87
+
88
+ # 'near "X": syntax error'
89
+ m = re.search(r'near\s+"([^"]+)"', error_message, re.IGNORECASE)
90
+ if m:
91
+ return m.group(1)
92
+
93
+ # "no such function: X"
94
+ m = re.search(r"no such function:\s*(\S+)", error_message, re.IGNORECASE)
95
+ if m:
96
+ return m.group(1)
97
+
98
+ return None
backend/rl/experience.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Experience store: logs episodes, persists to disk, and implements
3
+ Hindsight Experience Replay (HER) for reward relabeling.
4
+
5
+ HER (Andrychowicz et al., 2017): If a later attempt in the same episode
6
+ succeeded, relabel earlier failed steps with partial credit proportional
7
+ to their distance from the success step. This multiplies the effective
8
+ training signal from sparse rewards.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import json
14
+ import os
15
+ import time
16
+ import random
17
+ from pathlib import Path
18
+ from typing import Optional
19
+
20
+ from rl.types import (
21
+ Episode,
22
+ EpisodeStep,
23
+ Experience,
24
+ RLMetrics,
25
+ RepairAction,
26
+ REPAIR_ACTION_NAMES,
27
+ ERROR_CLASS_NAMES,
28
+ )
29
+ from rl.grader import compute_episode_reward
30
+
31
+ _DATA_DIR = Path(os.environ.get("DATA_DIR", Path(__file__).parent.parent / "data"))
32
+ EXPERIENCE_PATH = _DATA_DIR / "rl_experiences.json"
33
+ MAX_EPISODES = 500
34
+
35
+ _episodes: list[Episode] = []
36
+ _loaded: bool = False
37
+
38
+
39
+ def _ensure_loaded() -> None:
40
+ global _loaded, _episodes
41
+ if _loaded:
42
+ return
43
+ _loaded = True
44
+ try:
45
+ if EXPERIENCE_PATH.exists():
46
+ raw = json.loads(EXPERIENCE_PATH.read_text())
47
+ _episodes = [Episode(**ep) for ep in raw]
48
+ except Exception:
49
+ _episodes = []
50
+
51
+
52
+ def _persist() -> None:
53
+ try:
54
+ EXPERIENCE_PATH.parent.mkdir(parents=True, exist_ok=True)
55
+ data = [ep.model_dump() for ep in _episodes[-MAX_EPISODES:]]
56
+ EXPERIENCE_PATH.write_text(json.dumps(data, default=str))
57
+ except Exception:
58
+ pass
59
+
60
+
61
+ def record_episode(
62
+ question: str,
63
+ steps: list[EpisodeStep],
64
+ success: bool,
65
+ ) -> tuple[Episode, list[Experience]]:
66
+ """
67
+ Record a completed episode, run HER relabeling, and persist.
68
+ Returns (episode, relabeled_experiences).
69
+ """
70
+ _ensure_loaded()
71
+
72
+ step_rewards = [s.reward for s in steps]
73
+ total_reward = compute_episode_reward(step_rewards, success)
74
+
75
+ episode = Episode(
76
+ id=f"ep-{int(time.time() * 1000)}-{random.randint(1000, 9999)}",
77
+ question=question,
78
+ steps=steps,
79
+ total_reward=total_reward,
80
+ success=success,
81
+ timestamp=time.time(),
82
+ )
83
+
84
+ _episodes.append(episode)
85
+ if len(_episodes) > MAX_EPISODES:
86
+ _episodes[:] = _episodes[-MAX_EPISODES:]
87
+ _persist()
88
+
89
+ relabeled = _apply_her(episode)
90
+ return episode, relabeled
91
+
92
+
93
+ def _apply_her(episode: Episode) -> list[Experience]:
94
+ """
95
+ Hindsight Experience Replay.
96
+
97
+ If the episode eventually succeeded at step T, relabel earlier
98
+ failed steps with a hindsight bonus:
99
+ bonus(t) = 0.3 * (1 - (T - t) / T)
100
+
101
+ Steps closer to the eventual success receive more credit.
102
+ """
103
+ experiences: list[Experience] = []
104
+ success_step_idx = next(
105
+ (i for i, s in enumerate(episode.steps) if s.success), -1
106
+ )
107
+
108
+ for t, step in enumerate(episode.steps):
109
+ reward = step.reward
110
+
111
+ if success_step_idx > t:
112
+ distance = success_step_idx - t
113
+ total_steps = len(episode.steps)
114
+ her_bonus = 0.3 * (1.0 - distance / total_steps)
115
+ reward += her_bonus
116
+
117
+ next_step = episode.steps[t + 1] if t < len(episode.steps) - 1 else None
118
+
119
+ experiences.append(
120
+ Experience(
121
+ state=step.featurized,
122
+ action=step.action,
123
+ reward=reward,
124
+ next_state=next_step.featurized if next_step else None,
125
+ done=(t == len(episode.steps) - 1),
126
+ timestamp=episode.timestamp,
127
+ metadata={
128
+ "question": episode.question,
129
+ "error_message": step.error_message,
130
+ "sql": step.sql,
131
+ "error_class": int(step.state.error_class),
132
+ "attempt_number": step.state.attempt_number,
133
+ },
134
+ )
135
+ )
136
+
137
+ return experiences
138
+
139
+
140
+ def replay_all(bandit) -> int:
141
+ """
142
+ Replay all stored experiences through the bandit to rebuild weights.
143
+ Useful after a reset or if weights are lost.
144
+ """
145
+ _ensure_loaded()
146
+ count = 0
147
+ for ep in _episodes:
148
+ relabeled = _apply_her(ep)
149
+ for exp in relabeled:
150
+ bandit.update(exp.state, exp.action, exp.reward)
151
+ count += 1
152
+ return count
153
+
154
+
155
+ def get_metrics() -> RLMetrics:
156
+ _ensure_loaded()
157
+
158
+ recent_window = 50
159
+ recent = _episodes[-recent_window:]
160
+ all_steps = [s for ep in _episodes for s in ep.steps]
161
+
162
+ action_dist: dict[str, int] = {}
163
+ error_dist: dict[str, int] = {}
164
+
165
+ for step in all_steps:
166
+ a_name = REPAIR_ACTION_NAMES[step.action]
167
+ action_dist[a_name] = action_dist.get(a_name, 0) + 1
168
+ e_name = ERROR_CLASS_NAMES[step.state.error_class]
169
+ error_dist[e_name] = error_dist.get(e_name, 0) + 1
170
+
171
+ return RLMetrics(
172
+ total_episodes=len(_episodes),
173
+ total_steps=len(all_steps),
174
+ cumulative_reward=sum(ep.total_reward for ep in _episodes),
175
+ success_rate=(
176
+ sum(1 for ep in recent if ep.success) / len(recent)
177
+ if recent
178
+ else 0.0
179
+ ),
180
+ avg_attempts=(
181
+ sum(len(ep.steps) for ep in recent) / len(recent)
182
+ if recent
183
+ else 0.0
184
+ ),
185
+ action_distribution=action_dist,
186
+ error_distribution=error_dist,
187
+ reward_history=[ep.total_reward for ep in _episodes],
188
+ )
189
+
190
+
191
+ def get_episodes() -> list[Episode]:
192
+ _ensure_loaded()
193
+ return list(_episodes)
194
+
195
+
196
+ def get_recent_episodes(n: int) -> list[Episode]:
197
+ _ensure_loaded()
198
+ return _episodes[-n:]
199
+
200
+
201
+ def reset_experience() -> None:
202
+ global _episodes, _loaded
203
+ _episodes = []
204
+ _loaded = True
205
+ try:
206
+ EXPERIENCE_PATH.unlink(missing_ok=True)
207
+ except Exception:
208
+ pass
backend/rl/grader.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Shaped reward function for the SQL debug RL environment.
3
+
4
+ Reward components:
5
+ +1.0 base success reward
6
+ -0.1 per attempt (attempt penalty β€” incentivizes early resolution)
7
+ +0.2 if error severity decreased (progress signal)
8
+ +0.1 if error class changed at all (exploration signal)
9
+ -0.1 base failure penalty per step
10
+
11
+ The shaping is potential-based (Ng et al., 1999), preserving
12
+ the optimal policy while accelerating learning.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ from typing import Optional
18
+ from dataclasses import dataclass
19
+
20
+ from rl.types import ErrorClass
21
+ from rl.error_classifier import error_severity
22
+
23
+
24
+ @dataclass
25
+ class GraderInput:
26
+ success: bool
27
+ attempt_number: int # 1-indexed
28
+ current_error_class: Optional[ErrorClass] # None if success
29
+ previous_error_class: Optional[ErrorClass] # None on first attempt
30
+
31
+
32
+ @dataclass
33
+ class RewardBreakdown:
34
+ base: float
35
+ attempt_penalty: float
36
+ severity_bonus: float
37
+ change_bonus: float
38
+
39
+
40
+ @dataclass
41
+ class GraderOutput:
42
+ reward: float
43
+ breakdown: RewardBreakdown
44
+
45
+
46
+ def compute_reward(inp: GraderInput) -> GraderOutput:
47
+ if inp.success:
48
+ base = 1.0
49
+ attempt_penalty = -0.1 * (inp.attempt_number - 1)
50
+ return GraderOutput(
51
+ reward=base + attempt_penalty,
52
+ breakdown=RewardBreakdown(
53
+ base=base,
54
+ attempt_penalty=attempt_penalty,
55
+ severity_bonus=0.0,
56
+ change_bonus=0.0,
57
+ ),
58
+ )
59
+
60
+ # Failed step β€” base penalty + potential shaping
61
+ base = -0.1
62
+ attempt_penalty = -0.05 * inp.attempt_number
63
+
64
+ severity_bonus = 0.0
65
+ change_bonus = 0.0
66
+
67
+ if inp.previous_error_class is not None and inp.current_error_class is not None:
68
+ prev_sev = error_severity(inp.previous_error_class)
69
+ curr_sev = error_severity(inp.current_error_class)
70
+
71
+ if curr_sev < prev_sev:
72
+ severity_bonus = 0.2 # Progress toward solution
73
+ elif curr_sev > prev_sev:
74
+ severity_bonus = -0.1 # Regression
75
+
76
+ if inp.current_error_class != inp.previous_error_class:
77
+ change_bonus = 0.1 # At least something different happened
78
+
79
+ reward = base + attempt_penalty + severity_bonus + change_bonus
80
+
81
+ return GraderOutput(
82
+ reward=reward,
83
+ breakdown=RewardBreakdown(
84
+ base=base,
85
+ attempt_penalty=attempt_penalty,
86
+ severity_bonus=severity_bonus,
87
+ change_bonus=change_bonus,
88
+ ),
89
+ )
90
+
91
+
92
+ def compute_episode_reward(step_rewards: list[float], success: bool) -> float:
93
+ """
94
+ Compute total episode reward from individual step rewards.
95
+ Includes a terminal bonus/penalty based on final outcome.
96
+ """
97
+ total = sum(step_rewards)
98
+ terminal = 0.5 if success else -0.5
99
+ return total + terminal
backend/rl/linucb.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LinUCB Contextual Bandit (Li et al., 2010).
3
+
4
+ Maintains per-action inverse covariance matrices using the
5
+ Sherman-Morrison rank-1 update formula for O(d^2) updates.
6
+
7
+ For each action a in {0..K-1}:
8
+ A_inv[a] β€” dΓ—d inverse covariance (starts as I_d)
9
+ b[a] β€” d reward-weighted feature accumulator
10
+ theta[a] = A_inv[a] @ b[a] (ridge regression estimate)
11
+ UCB_a(x) = theta[a] @ x + alpha * sqrt(max(0, x @ A_inv[a] @ x))
12
+
13
+ Action selection: argmax_a UCB_a(x)
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import json
19
+ import os
20
+ import random
21
+ from pathlib import Path
22
+ from typing import List, Optional, Tuple
23
+
24
+ import numpy as np
25
+
26
+ from rl.types import FEATURE_DIM, NUM_ACTIONS, RepairAction, REPAIR_ACTION_NAMES
27
+
28
+ # Default path β€” can be overridden by DATA_DIR env var
29
+ _DATA_DIR = Path(os.environ.get("DATA_DIR", Path(__file__).parent.parent / "data"))
30
+ WEIGHTS_PATH = _DATA_DIR / "rl_weights.json"
31
+
32
+
33
+ class LinUCB:
34
+ """
35
+ LinUCB contextual bandit with Sherman-Morrison updates and alpha decay.
36
+ Weights are persisted to JSON after every 10 updates.
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ d: int = FEATURE_DIM,
42
+ K: int = NUM_ACTIONS,
43
+ alpha: float = 1.5,
44
+ ) -> None:
45
+ self.d = d
46
+ self.K = K
47
+ self.alpha = alpha
48
+ self.total_updates = 0
49
+
50
+ loaded = self._load_weights()
51
+ if loaded is not None:
52
+ self.A_inv = loaded["A_inv"]
53
+ self.b = loaded["b"]
54
+ self.counts = loaded["counts"]
55
+ self.total_updates = loaded["total_updates"]
56
+ else:
57
+ self.A_inv: List[np.ndarray] = [np.eye(d) for _ in range(K)]
58
+ self.b: List[np.ndarray] = [np.zeros(d) for _ in range(K)]
59
+ self.counts: List[int] = [0] * K
60
+
61
+ # ─── Core Interface ──────────────────────────────────────────
62
+
63
+ def select_action(self, x: List[float]) -> Tuple[RepairAction, List[float]]:
64
+ """
65
+ Select the action with highest UCB score.
66
+ Returns (action, scores_for_all_actions).
67
+ """
68
+ xv = np.array(x, dtype=np.float64)
69
+ scores = []
70
+
71
+ for a in range(self.K):
72
+ theta = self.A_inv[a] @ self.b[a]
73
+ exploit = float(theta @ xv)
74
+ quad = float(xv @ self.A_inv[a] @ xv)
75
+ explore = self.alpha * float(np.sqrt(max(0.0, quad)))
76
+ scores.append(exploit + explore)
77
+
78
+ # Argmax with random tie-breaking
79
+ best_action = 0
80
+ best_score = scores[0]
81
+ for a in range(1, self.K):
82
+ if scores[a] > best_score or (
83
+ scores[a] == best_score and random.random() > 0.5
84
+ ):
85
+ best_score = scores[a]
86
+ best_action = a
87
+
88
+ return RepairAction(best_action), scores
89
+
90
+ def update(self, x: List[float], action: RepairAction, reward: float) -> None:
91
+ """
92
+ Update the model after observing a reward.
93
+ Uses Sherman-Morrison: (A + xx^T)^{-1} = A^{-1} - (A^{-1}xx^T A^{-1}) / (1 + x^T A^{-1} x)
94
+ """
95
+ a = int(action)
96
+ xv = np.array(x, dtype=np.float64)
97
+
98
+ A_inv_x = self.A_inv[a] @ xv # shape (d,)
99
+ denom = 1.0 + float(xv @ A_inv_x) # scalar
100
+
101
+ # Rank-1 downdate
102
+ self.A_inv[a] -= np.outer(A_inv_x, A_inv_x) / denom
103
+
104
+ # Reward-weighted feature accumulation
105
+ self.b[a] += reward * xv
106
+
107
+ self.counts[a] += 1
108
+ self.total_updates += 1
109
+
110
+ if self.total_updates % 10 == 0:
111
+ self.save_weights()
112
+
113
+ def get_estimated_rewards(self, x: List[float]) -> List[float]:
114
+ """
115
+ Return theta^T x for each action (no exploration bonus).
116
+ Useful for understanding learned policy.
117
+ """
118
+ xv = np.array(x, dtype=np.float64)
119
+ return [float((self.A_inv[a] @ self.b[a]) @ xv) for a in range(self.K)]
120
+
121
+ def get_action_counts(self) -> List[int]:
122
+ return list(self.counts)
123
+
124
+ def get_total_updates(self) -> int:
125
+ return self.total_updates
126
+
127
+ def get_alpha(self) -> float:
128
+ return self.alpha
129
+
130
+ def decay_alpha(self, min_alpha: float = 0.3) -> None:
131
+ """Decay exploration coefficient toward exploitation."""
132
+ self.alpha = max(min_alpha, self.alpha * 0.995)
133
+
134
+ def get_action_distribution(self) -> dict:
135
+ total = sum(self.counts) or 1
136
+ return {
137
+ REPAIR_ACTION_NAMES[RepairAction(a)]: self.counts[a] / total
138
+ for a in range(self.K)
139
+ }
140
+
141
+ # ─── Persistence ─────────────────────────────────────────────
142
+
143
+ def save_weights(self) -> None:
144
+ try:
145
+ WEIGHTS_PATH.parent.mkdir(parents=True, exist_ok=True)
146
+ data = {
147
+ "A_inv": [m.tolist() for m in self.A_inv],
148
+ "b": [v.tolist() for v in self.b],
149
+ "counts": self.counts,
150
+ "total_updates": self.total_updates,
151
+ "alpha": self.alpha,
152
+ }
153
+ WEIGHTS_PATH.write_text(json.dumps(data))
154
+ except Exception:
155
+ pass # Non-fatal
156
+
157
+ def _load_weights(self) -> Optional[dict]:
158
+ try:
159
+ if not WEIGHTS_PATH.exists():
160
+ return None
161
+ raw = json.loads(WEIGHTS_PATH.read_text())
162
+ A_inv = [np.array(m, dtype=np.float64) for m in raw["A_inv"]]
163
+ b = [np.array(v, dtype=np.float64) for v in raw["b"]]
164
+ # Validate dimensions
165
+ if (
166
+ len(A_inv) == self.K
167
+ and A_inv[0].shape == (self.d, self.d)
168
+ and len(b) == self.K
169
+ and b[0].shape == (self.d,)
170
+ ):
171
+ return {
172
+ "A_inv": A_inv,
173
+ "b": b,
174
+ "counts": raw["counts"],
175
+ "total_updates": raw["total_updates"],
176
+ }
177
+ return None
178
+ except Exception:
179
+ return None
180
+
181
+ def reset(self) -> None:
182
+ self.A_inv = [np.eye(self.d) for _ in range(self.K)]
183
+ self.b = [np.zeros(self.d) for _ in range(self.K)]
184
+ self.counts = [0] * self.K
185
+ self.total_updates = 0
186
+ self.alpha = 1.5
187
+ try:
188
+ WEIGHTS_PATH.unlink(missing_ok=True)
189
+ except Exception:
190
+ pass
backend/rl/repair_strategies.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Repair strategy prompt templates for each of the 8 RepairAction values.
3
+
4
+ Each strategy provides:
5
+ - system_suffix: appended to the base system prompt
6
+ - user_template: callable that builds the user message given a RepairContext
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from dataclasses import dataclass
12
+ from typing import Optional, Callable
13
+
14
+ from rl.types import RepairAction
15
+
16
+
17
+ @dataclass
18
+ class RepairContext:
19
+ schema: str
20
+ question: str
21
+ failing_sql: str
22
+ error_message: str
23
+ offending_token: Optional[str]
24
+
25
+
26
+ @dataclass
27
+ class RepairStrategy:
28
+ action: RepairAction
29
+ name: str
30
+ system_suffix: str
31
+ user_template: Callable[[RepairContext], str]
32
+
33
+
34
+ def _tmpl_rewrite_full(ctx: RepairContext) -> str:
35
+ return (
36
+ f"Schema:\n{ctx.schema}\n\n"
37
+ f"Question: {ctx.question}\n\n"
38
+ f"A previous attempt failed with: {ctx.error_message}\n\n"
39
+ "Write a completely new SQL query from scratch. Do NOT reference the previous attempt."
40
+ )
41
+
42
+
43
+ def _tmpl_fix_column(ctx: RepairContext) -> str:
44
+ token_hint = f"\n\nThe problematic column is: {ctx.offending_token}" if ctx.offending_token else ""
45
+ return (
46
+ f"Schema:\n{ctx.schema}\n\n"
47
+ f"Question: {ctx.question}\n\n"
48
+ f"Previous SQL:\n{ctx.failing_sql}\n\n"
49
+ f"Error: {ctx.error_message}"
50
+ f"{token_hint}\n\n"
51
+ "Fix ONLY the column name issue. Check the schema for correct column names."
52
+ )
53
+
54
+
55
+ def _tmpl_fix_table(ctx: RepairContext) -> str:
56
+ token_hint = f"\n\nThe problematic table is: {ctx.offending_token}" if ctx.offending_token else ""
57
+ return (
58
+ f"Schema:\n{ctx.schema}\n\n"
59
+ f"Question: {ctx.question}\n\n"
60
+ f"Previous SQL:\n{ctx.failing_sql}\n\n"
61
+ f"Error: {ctx.error_message}"
62
+ f"{token_hint}\n\n"
63
+ "Fix the table name or JOIN issue. Verify all table names exist in the schema."
64
+ )
65
+
66
+
67
+ def _tmpl_add_groupby(ctx: RepairContext) -> str:
68
+ return (
69
+ f"Schema:\n{ctx.schema}\n\n"
70
+ f"Question: {ctx.question}\n\n"
71
+ f"Previous SQL:\n{ctx.failing_sql}\n\n"
72
+ f"Error: {ctx.error_message}\n\n"
73
+ "Fix the GROUP BY / aggregation issue. Ensure every non-aggregate column in SELECT is in GROUP BY."
74
+ )
75
+
76
+
77
+ def _tmpl_rewrite_cte(ctx: RepairContext) -> str:
78
+ return (
79
+ f"Schema:\n{ctx.schema}\n\n"
80
+ f"Question: {ctx.question}\n\n"
81
+ f"Previous SQL:\n{ctx.failing_sql}\n\n"
82
+ f"Error: {ctx.error_message}\n\n"
83
+ "Restructure the CTEs or subqueries. Break the query into clear, named WITH clauses."
84
+ )
85
+
86
+
87
+ def _tmpl_fix_syntax(ctx: RepairContext) -> str:
88
+ token_hint = f"\n\nSyntax error near: {ctx.offending_token}" if ctx.offending_token else ""
89
+ return (
90
+ f"Schema:\n{ctx.schema}\n\n"
91
+ f"Question: {ctx.question}\n\n"
92
+ f"Previous SQL:\n{ctx.failing_sql}\n\n"
93
+ f"Error: {ctx.error_message}"
94
+ f"{token_hint}\n\n"
95
+ "Fix the syntax error. Check for typos, missing commas, unmatched parentheses."
96
+ )
97
+
98
+
99
+ def _tmpl_change_dialect(ctx: RepairContext) -> str:
100
+ return (
101
+ f"Schema:\n{ctx.schema}\n\n"
102
+ f"Question: {ctx.question}\n\n"
103
+ f"Previous SQL:\n{ctx.failing_sql}\n\n"
104
+ f"Error: {ctx.error_message}\n\n"
105
+ "The SQL uses functions or syntax not supported by SQLite. "
106
+ "Rewrite using SQLite-compatible alternatives."
107
+ )
108
+
109
+
110
+ def _tmpl_relax_filter(ctx: RepairContext) -> str:
111
+ return (
112
+ f"Schema:\n{ctx.schema}\n\n"
113
+ f"Question: {ctx.question}\n\n"
114
+ f"Previous SQL:\n{ctx.failing_sql}\n\n"
115
+ f"Error: {ctx.error_message}\n\n"
116
+ "Review and relax the WHERE/HAVING conditions. "
117
+ "Check date formats, value ranges, and filter logic."
118
+ )
119
+
120
+
121
+ _STRATEGIES: dict[RepairAction, RepairStrategy] = {
122
+ RepairAction.REWRITE_FULL: RepairStrategy(
123
+ action=RepairAction.REWRITE_FULL,
124
+ name="Full Rewrite",
125
+ system_suffix=(
126
+ "\n\nIMPORTANT: The previous SQL attempt was fundamentally flawed. "
127
+ "Discard it entirely and write a new query from scratch based only on "
128
+ "the schema and question. Do NOT try to patch the previous SQL."
129
+ ),
130
+ user_template=_tmpl_rewrite_full,
131
+ ),
132
+ RepairAction.FIX_COLUMN: RepairStrategy(
133
+ action=RepairAction.FIX_COLUMN,
134
+ name="Fix Column",
135
+ system_suffix=(
136
+ "\n\nIMPORTANT: The previous SQL referenced a wrong column name. "
137
+ "Carefully check the schema for the exact column names in each table. "
138
+ "Pay attention to singular vs plural, underscores, and exact spelling."
139
+ ),
140
+ user_template=_tmpl_fix_column,
141
+ ),
142
+ RepairAction.FIX_TABLE: RepairStrategy(
143
+ action=RepairAction.FIX_TABLE,
144
+ name="Fix Table",
145
+ system_suffix=(
146
+ "\n\nIMPORTANT: The previous SQL referenced a wrong table name or had "
147
+ "incorrect JOIN relationships. Check the schema for exact table names "
148
+ "and foreign key relationships."
149
+ ),
150
+ user_template=_tmpl_fix_table,
151
+ ),
152
+ RepairAction.ADD_GROUPBY: RepairStrategy(
153
+ action=RepairAction.ADD_GROUPBY,
154
+ name="Fix GROUP BY",
155
+ system_suffix=(
156
+ "\n\nIMPORTANT: The previous SQL has an aggregation error. Every column "
157
+ "in SELECT that is not inside an aggregate function (COUNT, SUM, AVG, etc.) "
158
+ "MUST appear in the GROUP BY clause. Check all selected columns."
159
+ ),
160
+ user_template=_tmpl_add_groupby,
161
+ ),
162
+ RepairAction.REWRITE_CTE: RepairStrategy(
163
+ action=RepairAction.REWRITE_CTE,
164
+ name="Rewrite CTE/Subquery",
165
+ system_suffix=(
166
+ "\n\nIMPORTANT: The previous SQL had issues with CTEs or subqueries. "
167
+ "Restructure the query β€” consider using WITH clauses for clarity, or "
168
+ "flatten nested subqueries. Ensure CTE column names are explicitly defined if needed."
169
+ ),
170
+ user_template=_tmpl_rewrite_cte,
171
+ ),
172
+ RepairAction.FIX_SYNTAX: RepairStrategy(
173
+ action=RepairAction.FIX_SYNTAX,
174
+ name="Fix Syntax",
175
+ system_suffix=(
176
+ "\n\nIMPORTANT: The previous SQL has a syntax error. Check for: "
177
+ "missing commas, unmatched parentheses, misspelled keywords, "
178
+ "incorrect operator usage, missing AS aliases."
179
+ ),
180
+ user_template=_tmpl_fix_syntax,
181
+ ),
182
+ RepairAction.CHANGE_DIALECT: RepairStrategy(
183
+ action=RepairAction.CHANGE_DIALECT,
184
+ name="Fix Dialect",
185
+ system_suffix=(
186
+ "\n\nIMPORTANT: The previous SQL used functions or syntax not available in SQLite. "
187
+ "Key SQLite rules:\n"
188
+ "- Use strftime() for date formatting, NOT DATE_FORMAT or EXTRACT\n"
189
+ "- No FULL OUTER JOIN or RIGHT JOIN β€” use LEFT JOIN with UNION\n"
190
+ "- Use CAST(x AS INTEGER), not CONVERT()\n"
191
+ "- No ILIKE β€” use LIKE (case-insensitive by default for ASCII)\n"
192
+ "- String concatenation uses || not CONCAT()\n"
193
+ "- No LIMIT inside subqueries with IN (use CTE instead)"
194
+ ),
195
+ user_template=_tmpl_change_dialect,
196
+ ),
197
+ RepairAction.RELAX_FILTER: RepairStrategy(
198
+ action=RepairAction.RELAX_FILTER,
199
+ name="Relax Filter",
200
+ system_suffix=(
201
+ "\n\nIMPORTANT: The previous SQL may have overly restrictive WHERE conditions, "
202
+ "incorrect date ranges, or wrong filter values causing empty results or errors. "
203
+ "Review the filter conditions and broaden them to capture the intended data."
204
+ ),
205
+ user_template=_tmpl_relax_filter,
206
+ ),
207
+ }
208
+
209
+
210
+ def get_repair_system_suffix(action: RepairAction) -> str:
211
+ return _STRATEGIES[action].system_suffix
212
+
213
+
214
+ def build_repair_user_message(action: RepairAction, ctx: RepairContext) -> str:
215
+ return _STRATEGIES[action].user_template(ctx)
216
+
217
+
218
+ def get_repair_name(action: RepairAction) -> str:
219
+ return _STRATEGIES[action].name
backend/rl/types.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RL type definitions and feature engineering.
3
+
4
+ Mirrors the TypeScript types.ts exactly:
5
+ - 8 error classes, 8 repair actions
6
+ - FEATURE_DIM = 20
7
+ - featurize() builds the state vector
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from enum import IntEnum
13
+ from typing import Optional, List, Dict, Any
14
+ from pydantic import BaseModel
15
+
16
+
17
+ # ─── Error Taxonomy ─────────────────────────────────────────────
18
+
19
+ class ErrorClass(IntEnum):
20
+ NO_SUCH_COLUMN = 0
21
+ NO_SUCH_TABLE = 1
22
+ SYNTAX_ERROR = 2
23
+ AMBIGUOUS_COLUMN = 3
24
+ DATATYPE_MISMATCH = 4
25
+ NO_SUCH_FUNCTION = 5
26
+ AGGREGATION_ERROR = 6
27
+ OTHER = 7
28
+
29
+
30
+ ERROR_CLASS_NAMES: Dict[ErrorClass, str] = {
31
+ ErrorClass.NO_SUCH_COLUMN: "no_such_column",
32
+ ErrorClass.NO_SUCH_TABLE: "no_such_table",
33
+ ErrorClass.SYNTAX_ERROR: "syntax_error",
34
+ ErrorClass.AMBIGUOUS_COLUMN: "ambiguous_column",
35
+ ErrorClass.DATATYPE_MISMATCH: "datatype_mismatch",
36
+ ErrorClass.NO_SUCH_FUNCTION: "no_such_function",
37
+ ErrorClass.AGGREGATION_ERROR: "aggregation_error",
38
+ ErrorClass.OTHER: "other",
39
+ }
40
+
41
+ NUM_ERROR_CLASSES = 8
42
+
43
+
44
+ # ─── Repair Actions ─────────────────────────────────────────────
45
+
46
+ class RepairAction(IntEnum):
47
+ REWRITE_FULL = 0
48
+ FIX_COLUMN = 1
49
+ FIX_TABLE = 2
50
+ ADD_GROUPBY = 3
51
+ REWRITE_CTE = 4
52
+ FIX_SYNTAX = 5
53
+ CHANGE_DIALECT = 6
54
+ RELAX_FILTER = 7
55
+
56
+
57
+ REPAIR_ACTION_NAMES: Dict[RepairAction, str] = {
58
+ RepairAction.REWRITE_FULL: "rewrite_full",
59
+ RepairAction.FIX_COLUMN: "fix_column",
60
+ RepairAction.FIX_TABLE: "fix_table",
61
+ RepairAction.ADD_GROUPBY: "add_groupby",
62
+ RepairAction.REWRITE_CTE: "rewrite_cte",
63
+ RepairAction.FIX_SYNTAX: "fix_syntax",
64
+ RepairAction.CHANGE_DIALECT: "change_dialect",
65
+ RepairAction.RELAX_FILTER: "relax_filter",
66
+ }
67
+
68
+ # Inverse map: name β†’ enum
69
+ REPAIR_ACTION_BY_NAME: Dict[str, RepairAction] = {v: k for k, v in REPAIR_ACTION_NAMES.items()}
70
+
71
+ NUM_ACTIONS = 8
72
+
73
+ # Feature vector:
74
+ # [0..7] error class one-hot (8)
75
+ # [8] attempt / 5.0 (1)
76
+ # [9..16] prev action one-hot (8)
77
+ # [17] error_changed (1)
78
+ # [18] consec_count / 5.0 (1)
79
+ # [19] bias = 1.0 (1)
80
+ # total = 20
81
+ FEATURE_DIM = 20
82
+
83
+
84
+ # ─── State ──────────────────────────────────────────────────────
85
+
86
+ class RLState(BaseModel):
87
+ error_class: ErrorClass
88
+ attempt_number: int # 1-indexed
89
+ previous_action: Optional[RepairAction] = None
90
+ error_changed: bool = False
91
+ consecutive_same_error: int = 1
92
+
93
+
94
+ def featurize(state: RLState) -> List[float]:
95
+ """Build the 20-dimensional feature vector from an RLState."""
96
+ x = [0.0] * FEATURE_DIM
97
+
98
+ # Error class one-hot [0..7]
99
+ x[state.error_class] = 1.0
100
+
101
+ # Attempt number normalized [8]
102
+ x[8] = state.attempt_number / 5.0
103
+
104
+ # Previous action one-hot [9..16]
105
+ if state.previous_action is not None:
106
+ x[9 + int(state.previous_action)] = 1.0
107
+
108
+ # Error changed flag [17]
109
+ x[17] = 1.0 if state.error_changed else 0.0
110
+
111
+ # Consecutive same error normalized [18]
112
+ x[18] = min(state.consecutive_same_error, 5) / 5.0
113
+
114
+ # Bias term [19]
115
+ x[19] = 1.0
116
+
117
+ return x
118
+
119
+
120
+ # ─── Experience / Episode ────────────────────────────────────────
121
+
122
+ class EpisodeStep(BaseModel):
123
+ state: RLState
124
+ featurized: List[float]
125
+ action: RepairAction
126
+ reward: float
127
+ error_message: str
128
+ sql: str
129
+ success: bool
130
+
131
+
132
+ class Episode(BaseModel):
133
+ id: str
134
+ question: str
135
+ steps: List[EpisodeStep]
136
+ total_reward: float
137
+ success: bool
138
+ timestamp: float
139
+
140
+
141
+ class Experience(BaseModel):
142
+ state: List[float]
143
+ action: RepairAction
144
+ reward: float
145
+ next_state: Optional[List[float]] = None
146
+ done: bool
147
+ timestamp: float
148
+ metadata: Dict[str, Any]
149
+
150
+
151
+ # ─── Metrics ────────────────────────────────────────────────────
152
+
153
+ class RLMetrics(BaseModel):
154
+ total_episodes: int
155
+ total_steps: int
156
+ cumulative_reward: float
157
+ success_rate: float
158
+ avg_attempts: float
159
+ action_distribution: Dict[str, int]
160
+ error_distribution: Dict[str, int]
161
+ reward_history: List[float]
frontend/index.html ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!doctype html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8" />
5
+ <link rel="icon" type="image/svg+xml" href="/favicon.svg" />
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
+ <title>SQL Agent OpenEnv β€” RL Environment</title>
8
+ <meta name="description" content="SQL Agent with Reinforcement Learning and GEPA prompt evolution" />
9
+ </head>
10
+ <body>
11
+ <div id="root"></div>
12
+ <script type="module" src="/src/main.tsx"></script>
13
+ </body>
14
+ </html>
frontend/package-lock.json ADDED
The diff for this file is too large to render. See raw diff
 
frontend/package.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "sql-openenv-ui",
3
+ "private": true,
4
+ "version": "0.1.0",
5
+ "type": "module",
6
+ "scripts": {
7
+ "dev": "vite --port 5173",
8
+ "build": "vite build",
9
+ "preview": "vite preview"
10
+ },
11
+ "dependencies": {
12
+ "react": "^19.0.0",
13
+ "react-dom": "^19.0.0",
14
+ "framer-motion": "^11.0.0",
15
+ "lucide-react": "^0.400.0",
16
+ "recharts": "^2.12.0",
17
+ "zustand": "^4.5.0",
18
+ "react-markdown": "^9.0.0"
19
+ },
20
+ "devDependencies": {
21
+ "@types/react": "^19.0.0",
22
+ "@types/react-dom": "^19.0.0",
23
+ "@vitejs/plugin-react": "^4.3.0",
24
+ "typescript": "^5.5.0",
25
+ "vite": "^5.4.0",
26
+ "tailwindcss": "^3.4.0",
27
+ "autoprefixer": "^10.4.0",
28
+ "postcss": "^8.4.0"
29
+ }
30
+ }
frontend/postcss.config.js ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ export default {
2
+ plugins: {
3
+ tailwindcss: {},
4
+ autoprefixer: {},
5
+ },
6
+ }
frontend/src/App.tsx ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useState, useEffect } from 'react'
2
+ import { motion, AnimatePresence } from 'framer-motion'
3
+ import { MessageSquare, Target, GitFork, X } from 'lucide-react'
4
+
5
+ import { Header } from './components/Header'
6
+ import { LeftSidebar } from './components/LeftSidebar'
7
+ import { ChatPanel } from './components/ChatPanel'
8
+ import { BenchmarkPanel } from './components/BenchmarkPanel'
9
+ import { ERDiagram } from './components/ERDiagram'
10
+ import { RightSidebar } from './components/RightSidebar'
11
+ import { useStore } from './store/useStore'
12
+ import { fetchInit } from './lib/api'
13
+
14
+ type Tab = 'chat' | 'benchmark' | 'er'
15
+
16
+ const TABS: { id: Tab; label: string; icon: React.ReactNode }[] = [
17
+ { id: 'chat', label: 'Chat', icon: <MessageSquare size={12} /> },
18
+ { id: 'benchmark', label: 'Benchmark', icon: <Target size={12} /> },
19
+ { id: 'er', label: 'ER Diagram', icon: <GitFork size={12} /> },
20
+ ]
21
+
22
+ export default function App() {
23
+ const [activeTab, setActiveTab] = useState<Tab>('chat')
24
+ const [leftOpen, setLeftOpen] = useState(false)
25
+ const [rightOpen, setRightOpen] = useState(false)
26
+
27
+ const { theme, setDbSeeded, setTables, setSchemaGraph } = useStore()
28
+
29
+ // Apply theme on mount / change
30
+ useEffect(() => {
31
+ document.documentElement.setAttribute('data-theme', theme)
32
+ }, [theme])
33
+
34
+ // Restore theme from storage on mount
35
+ useEffect(() => {
36
+ try {
37
+ const saved = localStorage.getItem('theme') as 'dark' | 'light' | null
38
+ if (saved) {
39
+ document.documentElement.setAttribute('data-theme', saved)
40
+ useStore.setState({ theme: saved })
41
+ }
42
+ } catch { /* noop */ }
43
+ }, [])
44
+
45
+ // Fetch init data
46
+ useEffect(() => {
47
+ fetchInit()
48
+ .then((d) => {
49
+ setDbSeeded(true)
50
+ setTables(d.tables)
51
+ // Lazy-load schema graph
52
+ fetch('/api/schema-graph')
53
+ .then((r) => r.json())
54
+ .then((g) => setSchemaGraph(g))
55
+ .catch(() => { /* noop */ })
56
+ })
57
+ .catch(() => { /* noop */ })
58
+ }, [setDbSeeded, setTables, setSchemaGraph])
59
+
60
+ // Close mobile sidebars on tab change
61
+ useEffect(() => {
62
+ setLeftOpen(false)
63
+ setRightOpen(false)
64
+ }, [activeTab])
65
+
66
+ return (
67
+ <div
68
+ className="h-screen flex flex-col overflow-hidden theme-bg-primary theme-text-primary"
69
+ style={{ fontFamily: 'ui-monospace,"SF Mono",Consolas,"Liberation Mono",monospace' }}
70
+ >
71
+ <Header
72
+ onToggleLeft={() => { setLeftOpen((v) => !v); setRightOpen(false) }}
73
+ onToggleRight={() => { setRightOpen((v) => !v); setLeftOpen(false) }}
74
+ />
75
+
76
+ <div className="flex flex-1 overflow-hidden relative">
77
+ {/* Overlay backdrop (mobile) */}
78
+ {(leftOpen || rightOpen) && (
79
+ <div
80
+ className="fixed inset-0 bg-black/50 z-30 lg:hidden"
81
+ onClick={() => { setLeftOpen(false); setRightOpen(false) }}
82
+ />
83
+ )}
84
+
85
+ {/* LEFT SIDEBAR */}
86
+ <aside
87
+ className={`
88
+ fixed top-[53px] bottom-0 left-0 z-40 w-60 border-r theme-border flex flex-col overflow-y-auto
89
+ transition-transform duration-200 ease-out
90
+ lg:static lg:w-60 lg:shrink-0 lg:translate-x-0 lg:z-auto
91
+ ${leftOpen ? 'translate-x-0' : '-translate-x-full'}
92
+ `}
93
+ style={{ background: 'var(--bg-secondary)' }}
94
+ >
95
+ <div className="flex items-center justify-between px-4 pt-3 pb-1 lg:hidden">
96
+ <span className="text-[10px] font-semibold text-gray-500 uppercase tracking-wider">
97
+ Dataset & Tasks
98
+ </span>
99
+ <button
100
+ onClick={() => setLeftOpen(false)}
101
+ className="p-1 rounded hover:bg-white/5 text-gray-500"
102
+ >
103
+ <X size={14} />
104
+ </button>
105
+ </div>
106
+ <div className="flex-1 px-4 py-3">
107
+ <LeftSidebar />
108
+ </div>
109
+ </aside>
110
+
111
+ {/* CENTER: Tabbed panel */}
112
+ <main className="flex-1 flex flex-col overflow-hidden min-w-0">
113
+ {/* Tab bar */}
114
+ <div
115
+ className="flex items-center gap-1 px-2 sm:px-4 py-2.5 border-b theme-border shrink-0 overflow-x-auto scrollbar-none"
116
+ style={{ background: 'var(--bg-secondary)' }}
117
+ >
118
+ {TABS.map((tab) => (
119
+ <button
120
+ key={tab.id}
121
+ onClick={() => setActiveTab(tab.id)}
122
+ className={`flex items-center gap-1.5 px-2.5 sm:px-3 py-1.5 rounded-lg text-xs font-medium transition-all whitespace-nowrap shrink-0 ${
123
+ activeTab === tab.id
124
+ ? 'bg-violet-600/20 text-violet-300 border border-violet-500/30'
125
+ : 'text-gray-500 hover:text-gray-300 hover:bg-white/5 border border-transparent'
126
+ }`}
127
+ >
128
+ {tab.icon}
129
+ <span>{tab.label}</span>
130
+ </button>
131
+ ))}
132
+ </div>
133
+
134
+ {/* Tab content */}
135
+ <div className="flex-1 overflow-hidden relative">
136
+ <AnimatePresence mode="wait">
137
+ <motion.div
138
+ key={activeTab}
139
+ initial={{ opacity: 0, y: 4 }}
140
+ animate={{ opacity: 1, y: 0 }}
141
+ exit={{ opacity: 0 }}
142
+ transition={{ duration: 0.15 }}
143
+ className="absolute inset-0 flex flex-col overflow-hidden"
144
+ >
145
+ {activeTab === 'chat' && <ChatPanel />}
146
+ {activeTab === 'benchmark' && <BenchmarkPanel />}
147
+ {activeTab === 'er' && <ERDiagram />}
148
+ </motion.div>
149
+ </AnimatePresence>
150
+ </div>
151
+ </main>
152
+
153
+ {/* RIGHT SIDEBAR */}
154
+ <aside
155
+ className={`
156
+ fixed top-[53px] bottom-0 right-0 z-40 w-72 border-l theme-border flex flex-col overflow-hidden
157
+ transition-transform duration-200 ease-out
158
+ lg:static lg:w-72 lg:shrink-0 lg:translate-x-0 lg:z-auto
159
+ ${rightOpen ? 'translate-x-0' : 'translate-x-full'}
160
+ `}
161
+ style={{ background: 'var(--bg-secondary)' }}
162
+ >
163
+ <div className="flex items-center justify-between px-4 pt-3 pb-1 lg:hidden">
164
+ <span className="text-[10px] font-semibold text-gray-500 uppercase tracking-wider">
165
+ GEPA & RL
166
+ </span>
167
+ <button
168
+ onClick={() => setRightOpen(false)}
169
+ className="p-1 rounded hover:bg-white/5 text-gray-500"
170
+ >
171
+ <X size={14} />
172
+ </button>
173
+ </div>
174
+ <RightSidebar />
175
+ </aside>
176
+ </div>
177
+ </div>
178
+ )
179
+ }
frontend/src/components/BenchmarkPanel.tsx ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useState, useCallback } from 'react'
2
+ import { motion, AnimatePresence } from 'framer-motion'
3
+ import {
4
+ Target, Play, Loader2, CheckCircle2, XCircle,
5
+ ChevronDown, RotateCcw, Zap,
6
+ } from 'lucide-react'
7
+ import { useStore } from '../store/useStore'
8
+ import { streamBenchmark } from '../lib/api'
9
+ import type { BenchmarkResult, Difficulty } from '../lib/types'
10
+
11
+ const DIFFICULTY_TABS: { id: Difficulty; label: string }[] = [
12
+ { id: 'easy', label: 'Easy' },
13
+ { id: 'medium', label: 'Medium' },
14
+ { id: 'hard', label: 'Hard' },
15
+ ]
16
+
17
+ function QueryRow({
18
+ result,
19
+ isActive,
20
+ isExpanded,
21
+ onToggleExpand,
22
+ onRunSingle,
23
+ isRunning,
24
+ dbSeeded,
25
+ }: {
26
+ result: BenchmarkResult
27
+ isActive: boolean
28
+ isExpanded: boolean
29
+ onToggleExpand: () => void
30
+ onRunSingle: () => void
31
+ isRunning: boolean
32
+ dbSeeded: boolean
33
+ }) {
34
+ const statusIcon = () => {
35
+ switch (result.status) {
36
+ case 'pending': return <span className="w-2 h-2 rounded-full bg-gray-600 shrink-0" />
37
+ case 'running': return <Loader2 size={12} className="text-violet-400 animate-spin shrink-0" />
38
+ case 'pass': return <CheckCircle2 size={12} className="text-green-400 shrink-0" />
39
+ case 'fail': return <XCircle size={12} className="text-red-400 shrink-0" />
40
+ }
41
+ }
42
+
43
+ const difficultyColor =
44
+ result.difficulty === 'hard'
45
+ ? 'text-red-400 bg-red-500/10 border-red-500/25'
46
+ : result.difficulty === 'medium'
47
+ ? 'text-amber-400 bg-amber-500/10 border-amber-500/25'
48
+ : 'text-blue-400 bg-blue-500/10 border-blue-500/25'
49
+
50
+ return (
51
+ <div
52
+ className={`rounded-xl border transition-all duration-150 ${
53
+ isActive
54
+ ? 'border-violet-500/40 bg-violet-500/5'
55
+ : 'border-white/5 bg-white/[0.02] hover:bg-white/[0.04]'
56
+ }`}
57
+ >
58
+ <div
59
+ className="flex items-start gap-2 px-3 py-2.5 cursor-pointer"
60
+ onClick={onToggleExpand}
61
+ >
62
+ <div className="mt-0.5 shrink-0">{statusIcon()}</div>
63
+ <div className="flex-1 min-w-0">
64
+ <div className="flex items-center gap-2 mb-0.5 flex-wrap">
65
+ <span className="text-[10px] font-mono text-gray-600">{result.id}</span>
66
+ <span className={`text-[9px] font-semibold px-1.5 py-0.5 rounded-full border ${difficultyColor}`}>
67
+ {result.difficulty}
68
+ </span>
69
+ {result.score !== null && (
70
+ <span className={`text-[10px] font-mono font-bold ${result.status === 'pass' ? 'text-green-400' : 'text-red-400'}`}>
71
+ {result.score.toFixed(2)}
72
+ </span>
73
+ )}
74
+ {result.attempts !== null && (
75
+ <span className="text-[9px] text-gray-600 font-mono">
76
+ {result.attempts} attempt{result.attempts !== 1 ? 's' : ''}
77
+ </span>
78
+ )}
79
+ </div>
80
+ <div className="text-xs text-gray-300 leading-relaxed line-clamp-2">
81
+ {result.question}
82
+ </div>
83
+ {result.reason && result.status !== 'pending' && (
84
+ <div className={`text-[10px] mt-1 ${result.status === 'pass' ? 'text-green-500/70' : 'text-red-400/70'}`}>
85
+ {result.reason.length > 120 ? result.reason.slice(0, 120) + '…' : result.reason}
86
+ </div>
87
+ )}
88
+ </div>
89
+ <div className="flex items-center gap-1.5 shrink-0">
90
+ {result.status === 'pending' && dbSeeded && !isRunning && (
91
+ <button
92
+ onClick={(e) => { e.stopPropagation(); onRunSingle() }}
93
+ className="p-1 rounded-lg hover:bg-white/10 transition-colors"
94
+ title="Run this query"
95
+ >
96
+ <Play size={10} className="text-gray-500 hover:text-violet-400" />
97
+ </button>
98
+ )}
99
+ <ChevronDown
100
+ size={11}
101
+ className={`text-gray-600 transition-transform duration-150 ${isExpanded ? 'rotate-180' : ''}`}
102
+ />
103
+ </div>
104
+ </div>
105
+
106
+ {/* Expanded detail */}
107
+ <AnimatePresence>
108
+ {isExpanded && (
109
+ <motion.div
110
+ initial={{ height: 0, opacity: 0 }}
111
+ animate={{ height: 'auto', opacity: 1 }}
112
+ exit={{ height: 0, opacity: 0 }}
113
+ transition={{ duration: 0.15 }}
114
+ className="overflow-hidden"
115
+ >
116
+ <div className="px-3 pb-3 flex flex-col gap-2 border-t border-white/5 pt-2">
117
+ <p className="text-xs text-gray-400 leading-relaxed">{result.question}</p>
118
+
119
+ {result.sql && (
120
+ <div>
121
+ <div className="text-[10px] text-gray-600 mb-1 font-semibold uppercase tracking-wider">
122
+ Generated SQL
123
+ </div>
124
+ <pre className="text-[10px] font-mono text-violet-200/70 bg-black/40 rounded-lg p-2.5 border border-white/5 whitespace-pre-wrap leading-relaxed max-h-40 overflow-y-auto">
125
+ {result.sql}
126
+ </pre>
127
+ </div>
128
+ )}
129
+
130
+ {(result.refRowCount !== null || result.reason) && (
131
+ <div className="flex flex-col gap-1.5">
132
+ {result.refRowCount !== null && (
133
+ <div className="flex items-center gap-3 text-[10px] font-mono">
134
+ <span className="text-gray-600">reference:</span>
135
+ <span className="text-blue-400">{result.refRowCount} rows</span>
136
+ <span className="text-gray-600">agent:</span>
137
+ <span className={
138
+ result.agentRowCount === result.refRowCount
139
+ ? 'text-green-400'
140
+ : result.agentRowCount === 0
141
+ ? 'text-red-400'
142
+ : 'text-amber-400'
143
+ }>
144
+ {result.agentRowCount ?? 0} rows
145
+ </span>
146
+ </div>
147
+ )}
148
+ {result.reason && (
149
+ <div className={`text-[10px] leading-relaxed ${result.status === 'pass' ? 'text-green-400/80' : 'text-red-400/80'}`}>
150
+ {result.reason}
151
+ </div>
152
+ )}
153
+ </div>
154
+ )}
155
+
156
+ {result.status !== 'pending' && result.status !== 'running' && !isRunning && dbSeeded && (
157
+ <button
158
+ onClick={(e) => { e.stopPropagation(); onRunSingle() }}
159
+ className="flex items-center gap-1 text-[10px] text-violet-400 hover:text-violet-300 transition-colors self-start mt-1"
160
+ >
161
+ <RotateCcw size={9} />
162
+ Re-run
163
+ </button>
164
+ )}
165
+ </div>
166
+ </motion.div>
167
+ )}
168
+ </AnimatePresence>
169
+ </div>
170
+ )
171
+ }
172
+
173
+ export function BenchmarkPanel() {
174
+ const {
175
+ benchmarkResults, isBenchmarking, overallScore,
176
+ activeBenchmarkId, dbSeeded,
177
+ setIsBenchmarking, updateBenchmarkResult, setOverallScore,
178
+ setActiveBenchmarkId, resetBenchmark,
179
+ taskDifficulty, setTaskDifficulty,
180
+ } = useStore()
181
+
182
+ const [expandedIds, setExpandedIds] = useState<Set<string>>(new Set())
183
+
184
+ const toggleExpand = (id: string) => {
185
+ setExpandedIds((prev) => {
186
+ const next = new Set(prev)
187
+ if (next.has(id)) next.delete(id)
188
+ else next.add(id)
189
+ return next
190
+ })
191
+ }
192
+
193
+ const runBenchmark = useCallback(
194
+ async (queryIds?: string[]) => {
195
+ if (isBenchmarking) return
196
+ setIsBenchmarking(true)
197
+
198
+ const targetIds = queryIds ?? benchmarkResults.map((r) => r.id)
199
+ for (const id of targetIds) {
200
+ const existing = benchmarkResults.find((r) => r.id === id)
201
+ if (existing) {
202
+ updateBenchmarkResult({ ...existing, status: 'running', score: null, reason: null, sql: null })
203
+ }
204
+ }
205
+
206
+ try {
207
+ for await (const event of streamBenchmark(taskDifficulty, queryIds)) {
208
+ if (event.type === 'query_start') {
209
+ setActiveBenchmarkId(event.id as string)
210
+ const existing = benchmarkResults.find((r) => r.id === event.id)
211
+ if (existing) updateBenchmarkResult({ ...existing, status: 'running' })
212
+ } else if (event.type === 'query_result') {
213
+ const existing = benchmarkResults.find((r) => r.id === event.id)
214
+ if (existing) {
215
+ updateBenchmarkResult({
216
+ ...existing,
217
+ status: (event.pass as boolean) ? 'pass' : 'fail',
218
+ score: event.score as number,
219
+ reason: event.reason as string,
220
+ sql: event.sql as string,
221
+ attempts: (event.attempts as number) ?? null,
222
+ refRowCount: (event.refRowCount as number) ?? null,
223
+ agentRowCount: (event.agentRowCount as number) ?? null,
224
+ })
225
+ }
226
+ } else if (event.type === 'done') {
227
+ setOverallScore(event.overallScore as number)
228
+ setActiveBenchmarkId(null)
229
+ setIsBenchmarking(false)
230
+ } else if (event.type === 'error') {
231
+ setActiveBenchmarkId(null)
232
+ setIsBenchmarking(false)
233
+ }
234
+ }
235
+ } catch {
236
+ setIsBenchmarking(false)
237
+ setActiveBenchmarkId(null)
238
+ }
239
+ },
240
+ [isBenchmarking, benchmarkResults, setIsBenchmarking, updateBenchmarkResult,
241
+ setOverallScore, setActiveBenchmarkId, taskDifficulty]
242
+ )
243
+
244
+ const passedCount = benchmarkResults.filter((r) => r.status === 'pass').length
245
+ const completedCount = benchmarkResults.filter((r) => r.status === 'pass' || r.status === 'fail').length
246
+ const totalScore = benchmarkResults.reduce((s, r) => s + (r.score ?? 0), 0)
247
+ const progressPct = benchmarkResults.length > 0 ? Math.round((completedCount / benchmarkResults.length) * 100) : 0
248
+ const scorePct = completedCount > 0 ? Math.round((totalScore / benchmarkResults.length) * 100) : 0
249
+
250
+ return (
251
+ <div className="flex flex-col h-full">
252
+ {/* Header */}
253
+ <div className="px-4 py-3 border-b border-white/[0.06] shrink-0">
254
+ <div className="flex items-center justify-between mb-2">
255
+ <div className="flex items-center gap-2">
256
+ <Target size={14} className="text-violet-400" />
257
+ <span className="text-xs font-semibold text-white">Benchmark</span>
258
+ {completedCount > 0 && (
259
+ <span className="text-xs text-gray-500 font-mono">
260
+ {passedCount}/{benchmarkResults.length} passed
261
+ </span>
262
+ )}
263
+ </div>
264
+ <div className="flex items-center gap-2">
265
+ {completedCount > 0 && (
266
+ <button
267
+ onClick={resetBenchmark}
268
+ disabled={isBenchmarking}
269
+ className="flex items-center gap-1 px-2 py-1 rounded-lg text-[10px] text-gray-500 hover:text-gray-300 hover:bg-white/5 transition-all disabled:opacity-40"
270
+ >
271
+ <RotateCcw size={10} />
272
+ Reset
273
+ </button>
274
+ )}
275
+ <button
276
+ onClick={() => void runBenchmark()}
277
+ disabled={isBenchmarking || !dbSeeded}
278
+ className="flex items-center gap-1.5 px-3 py-1.5 rounded-lg bg-violet-600 hover:bg-violet-500 disabled:opacity-40 disabled:cursor-not-allowed transition-all text-white text-xs font-semibold"
279
+ >
280
+ {isBenchmarking ? (
281
+ <Loader2 size={11} className="animate-spin" />
282
+ ) : (
283
+ <Play size={11} />
284
+ )}
285
+ Run All
286
+ </button>
287
+ </div>
288
+ </div>
289
+
290
+ {/* Overall score */}
291
+ {overallScore !== null && (
292
+ <motion.div
293
+ initial={{ opacity: 0, scale: 0.95 }}
294
+ animate={{ opacity: 1, scale: 1 }}
295
+ className="mb-2 flex items-center gap-3 px-3 py-2 rounded-xl border border-violet-500/20 bg-violet-500/5"
296
+ >
297
+ <Zap size={14} className="text-violet-400 shrink-0" />
298
+ <div>
299
+ <div className="text-[10px] text-gray-500 uppercase tracking-wider">Overall Score</div>
300
+ <div className="text-xl font-bold font-mono text-violet-300">
301
+ {(overallScore * 100).toFixed(0)}%
302
+ </div>
303
+ </div>
304
+ </motion.div>
305
+ )}
306
+
307
+ {/* Score bar */}
308
+ {completedCount > 0 && (
309
+ <div className="flex flex-col gap-1">
310
+ <div className="flex items-center justify-between text-[10px]">
311
+ <span className="text-gray-500">
312
+ Score: {totalScore.toFixed(1)}/{benchmarkResults.length}
313
+ </span>
314
+ <span className="text-violet-400 font-mono">{scorePct}%</span>
315
+ </div>
316
+ <div className="h-1.5 bg-white/5 rounded-full overflow-hidden">
317
+ <motion.div
318
+ className="h-full rounded-full bg-gradient-to-r from-violet-600 to-violet-400"
319
+ initial={{ width: 0 }}
320
+ animate={{ width: `${scorePct}%` }}
321
+ transition={{ duration: 0.5, ease: 'easeOut' }}
322
+ />
323
+ </div>
324
+ </div>
325
+ )}
326
+
327
+ {/* Progress */}
328
+ {isBenchmarking && (
329
+ <div className="mt-1.5">
330
+ <div className="h-1 bg-white/5 rounded-full overflow-hidden">
331
+ <motion.div
332
+ className="h-full rounded-full bg-violet-500/60"
333
+ initial={{ width: 0 }}
334
+ animate={{ width: `${progressPct}%` }}
335
+ transition={{ duration: 0.3 }}
336
+ />
337
+ </div>
338
+ </div>
339
+ )}
340
+ </div>
341
+
342
+ {/* Difficulty tabs */}
343
+ <div className="flex items-center gap-1 px-4 py-2 border-b border-white/[0.06] shrink-0">
344
+ {DIFFICULTY_TABS.map((tab) => (
345
+ <button
346
+ key={tab.id}
347
+ onClick={() => setTaskDifficulty(tab.id)}
348
+ className={`px-3 py-1 rounded-lg text-xs font-medium transition-all ${
349
+ taskDifficulty === tab.id
350
+ ? 'bg-violet-600/20 text-violet-300 border border-violet-500/30'
351
+ : 'text-gray-500 hover:text-gray-300 hover:bg-white/5 border border-transparent'
352
+ }`}
353
+ >
354
+ {tab.label}
355
+ </button>
356
+ ))}
357
+ </div>
358
+
359
+ {/* Query list */}
360
+ <div className="flex-1 overflow-y-auto">
361
+ <div className="p-2 flex flex-col gap-1">
362
+ {benchmarkResults.map((result) => (
363
+ <QueryRow
364
+ key={result.id}
365
+ result={result}
366
+ isActive={activeBenchmarkId === result.id}
367
+ isExpanded={expandedIds.has(result.id)}
368
+ onToggleExpand={() => toggleExpand(result.id)}
369
+ onRunSingle={() => void runBenchmark([result.id])}
370
+ isRunning={isBenchmarking}
371
+ dbSeeded={dbSeeded}
372
+ />
373
+ ))}
374
+ </div>
375
+ </div>
376
+
377
+ {!dbSeeded && (
378
+ <div className="px-4 py-2 border-t border-white/[0.06] text-[10px] text-gray-600 text-center shrink-0">
379
+ Waiting for database initialization...
380
+ </div>
381
+ )}
382
+ </div>
383
+ )
384
+ }
frontend/src/components/ChatPanel.tsx ADDED
@@ -0,0 +1,599 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useState, useRef, useEffect, useCallback } from 'react'
2
+ import { motion, AnimatePresence } from 'framer-motion'
3
+ import {
4
+ Send, CheckCircle2, XCircle, ChevronDown, ChevronUp,
5
+ Loader2, MessageSquare, Zap, RefreshCw, Trash2,
6
+ } from 'lucide-react'
7
+ import { useStore } from '../store/useStore'
8
+ import { streamExecuteQuery, submitFeedback } from '../lib/api'
9
+ import { ResultsTable } from './ResultsTable'
10
+ import type { ChatMessage, AttemptStep } from '../lib/types'
11
+
12
+ // ─── SQL Syntax Highlighter ───────────────────────────────────────
13
+
14
+ const SQL_KEYWORDS = /\b(SELECT|FROM|WHERE|JOIN|LEFT|RIGHT|INNER|OUTER|FULL|ON|GROUP\s+BY|ORDER\s+BY|HAVING|LIMIT|OFFSET|UNION|ALL|DISTINCT|AS|AND|OR|NOT|IN|IS|NULL|LIKE|BETWEEN|CASE|WHEN|THEN|ELSE|END|WITH|CTE|INSERT|UPDATE|DELETE|CREATE|DROP|ALTER|TABLE|INDEX|VIEW|SET|VALUES|INTO|EXISTS|COUNT|SUM|AVG|MIN|MAX|COALESCE|NULLIF|CAST|OVER|PARTITION\s+BY|ROW_NUMBER|RANK|DENSE_RANK|LAG|LEAD|DATE|STRFTIME|JULIANDAY|ROUND|ABS|LENGTH|SUBSTR|UPPER|LOWER|TRIM|REPLACE|IFNULL)\b/gi
15
+
16
+ function SqlBlock({ sql, streaming }: { sql: string; streaming?: boolean }) {
17
+ const parts: React.ReactNode[] = []
18
+ let last = 0
19
+ let match: RegExpExecArray | null
20
+
21
+ const re = new RegExp(SQL_KEYWORDS.source, 'gi')
22
+ while ((match = re.exec(sql)) !== null) {
23
+ if (match.index > last) {
24
+ parts.push(<span key={`t-${last}`}>{sql.slice(last, match.index)}</span>)
25
+ }
26
+ parts.push(
27
+ <span key={`k-${match.index}`} className="sql-keyword">
28
+ {match[0]}
29
+ </span>
30
+ )
31
+ last = match.index + match[0].length
32
+ }
33
+ if (last < sql.length) {
34
+ parts.push(<span key={`t-end`}>{sql.slice(last)}</span>)
35
+ }
36
+
37
+ return (
38
+ <pre
39
+ className="px-3 py-2.5 text-xs font-mono bg-violet-950/20 whitespace-pre-wrap overflow-x-auto leading-relaxed border-t border-white/[0.04]"
40
+ style={{ color: 'rgba(221, 214, 254, 0.8)' }}
41
+ >
42
+ {parts}
43
+ {streaming && <span className="cursor-blink" />}
44
+ </pre>
45
+ )
46
+ }
47
+
48
+ // ─── Attempt badge ────────────────────────────────────────────────
49
+
50
+ function AttemptBadge({ attempt, total }: { attempt: number; total: number }) {
51
+ const colors =
52
+ attempt === 1
53
+ ? 'text-gray-400 bg-white/5 border-white/10'
54
+ : attempt === 2
55
+ ? 'text-amber-400 bg-amber-500/10 border-amber-500/20'
56
+ : attempt === 3
57
+ ? 'text-orange-400 bg-orange-500/10 border-orange-500/20'
58
+ : 'text-red-400 bg-red-500/10 border-red-500/20'
59
+
60
+ return (
61
+ <span className={`text-[10px] font-semibold px-2 py-0.5 rounded-full border ${colors}`}>
62
+ Attempt {attempt}/{total}
63
+ </span>
64
+ )
65
+ }
66
+
67
+ // ─── RL Action badge ──────────────────────────────────────────────
68
+
69
+ function RLActionBadge({ action, score }: { action: string; score?: number }) {
70
+ return (
71
+ <span className="inline-flex items-center gap-1 text-[10px] font-semibold px-2 py-0.5 rounded-full border border-orange-500/30 bg-orange-500/10 text-orange-400">
72
+ <Zap size={9} />
73
+ {action}
74
+ {score !== undefined && (
75
+ <span className="text-orange-400/60 ml-0.5">{score.toFixed(2)}</span>
76
+ )}
77
+ </span>
78
+ )
79
+ }
80
+
81
+ // ─── Reward display ──────────────────────────────────────────────
82
+
83
+ function RewardBadge({ reward }: { reward: number }) {
84
+ const positive = reward >= 0
85
+ return (
86
+ <motion.span
87
+ initial={{ scale: 0.8, opacity: 0 }}
88
+ animate={{ scale: 1, opacity: 1 }}
89
+ transition={{ type: 'spring', stiffness: 300 }}
90
+ className={`inline-flex items-center gap-0.5 text-[11px] font-bold tabular-nums reward-pulse ${
91
+ positive ? 'text-green-400' : 'text-red-400'
92
+ }`}
93
+ >
94
+ {positive ? '+' : ''}{reward.toFixed(2)}
95
+ </motion.span>
96
+ )
97
+ }
98
+
99
+ // ─── Attempt steps collapsible ────────────────────────────────────
100
+
101
+ function AttemptSteps({ steps }: { steps: AttemptStep[] }) {
102
+ const [open, setOpen] = useState(false)
103
+ if (steps.length <= 1) return null
104
+
105
+ return (
106
+ <div className="border border-white/[0.05] rounded-xl overflow-hidden">
107
+ <button
108
+ onClick={() => setOpen((v) => !v)}
109
+ className="w-full flex items-center justify-between px-3 py-2 bg-white/[0.02] hover:bg-white/[0.04] transition-colors text-[10px] text-gray-500"
110
+ >
111
+ <span>{steps.length} attempts to solve</span>
112
+ {open ? <ChevronUp size={11} /> : <ChevronDown size={11} />}
113
+ </button>
114
+ <AnimatePresence>
115
+ {open && (
116
+ <motion.div
117
+ initial={{ height: 0, opacity: 0 }}
118
+ animate={{ height: 'auto', opacity: 1 }}
119
+ exit={{ height: 0, opacity: 0 }}
120
+ transition={{ duration: 0.15 }}
121
+ className="overflow-hidden"
122
+ >
123
+ <div className="flex flex-col divide-y divide-white/[0.04]">
124
+ {steps.map((step) => (
125
+ <div key={step.attempt} className="px-3 py-2">
126
+ <div className="flex items-center gap-2 mb-1.5">
127
+ <AttemptBadge attempt={step.attempt} total={steps.length} />
128
+ {step.action && (
129
+ <RLActionBadge action={step.action} score={step.actionScore} />
130
+ )}
131
+ {step.reward !== undefined && <RewardBadge reward={step.reward} />}
132
+ </div>
133
+ {step.error && (
134
+ <div className="text-[10px] text-red-400/70 mb-1 bg-red-500/5 rounded px-2 py-1 border border-red-500/15">
135
+ {step.error}
136
+ </div>
137
+ )}
138
+ <SqlBlock sql={step.sql} />
139
+ </div>
140
+ ))}
141
+ </div>
142
+ </motion.div>
143
+ )}
144
+ </AnimatePresence>
145
+ </div>
146
+ )
147
+ }
148
+
149
+ // ─── Suggested query chips ────────────────────────────────────────
150
+
151
+ const SUGGESTED: Record<string, string[]> = {
152
+ easy: ['Show all products', 'List users from USA', 'What categories exist?'],
153
+ medium: ['Top 5 sellers by revenue', 'Average order value by country', 'Products with low stock'],
154
+ hard: ['Rolling 7-day revenue', 'Seller ranking with rank change', 'Cohort retention analysis'],
155
+ }
156
+
157
+ function EmptyState({ onSelect }: { onSelect: (q: string) => void }) {
158
+ const { taskDifficulty } = useStore()
159
+ const suggestions = SUGGESTED[taskDifficulty] ?? SUGGESTED.easy
160
+
161
+ return (
162
+ <div className="flex flex-col items-center justify-center h-full gap-6 px-8 text-center">
163
+ <div>
164
+ <div
165
+ className="w-12 h-12 rounded-2xl flex items-center justify-center mx-auto mb-4"
166
+ style={{ background: '#1e3a5f', boxShadow: '0 8px 24px rgba(30,58,95,0.4)' }}
167
+ >
168
+ <MessageSquare size={22} className="text-white" />
169
+ </div>
170
+ <h2 className="text-base font-semibold text-white mb-1">Ask about your data</h2>
171
+ <p className="text-xs text-gray-500 max-w-xs">
172
+ Type a question in natural language. The agent will generate SQL, execute it,
173
+ and self-repair on errors using reinforcement learning.
174
+ </p>
175
+ </div>
176
+
177
+ <div className="flex flex-col gap-2 w-full max-w-sm">
178
+ <div className="text-[10px] text-gray-600 uppercase tracking-wider mb-0.5">
179
+ Try these queries
180
+ </div>
181
+ {suggestions.map((q) => (
182
+ <button
183
+ key={q}
184
+ onClick={() => onSelect(q)}
185
+ className="flex items-center gap-2 px-3 py-2.5 rounded-xl border border-white/[0.06] bg-white/[0.02] hover:bg-white/[0.05] hover:border-violet-500/30 transition-all text-left group"
186
+ >
187
+ <span className="text-violet-500 shrink-0 group-hover:text-violet-400">β€Ί</span>
188
+ <span className="text-xs text-gray-300">{q}</span>
189
+ </button>
190
+ ))}
191
+ </div>
192
+ </div>
193
+ )
194
+ }
195
+
196
+ // ─── Message Card ─────────────────────────────────────────────────
197
+
198
+ function MessageCard({
199
+ msg,
200
+ onFeedback,
201
+ onRetry,
202
+ }: {
203
+ msg: ChatMessage
204
+ onFeedback: (id: string, correct: boolean) => Promise<void>
205
+ onRetry: (q: string) => void
206
+ }) {
207
+ const [sqlOpen, setSqlOpen] = useState(true)
208
+
209
+ return (
210
+ <div className="flex flex-col gap-2.5">
211
+ {/* User question bubble */}
212
+ <div className="flex justify-end">
213
+ <div className="max-w-[80%] bg-violet-600/20 border border-violet-500/25 rounded-2xl rounded-tr-sm px-4 py-2.5">
214
+ <p className="text-sm text-white leading-relaxed">{msg.question}</p>
215
+ </div>
216
+ </div>
217
+
218
+ {/* Agent response */}
219
+ <div className="flex flex-col gap-2">
220
+ {/* Streaming thinking */}
221
+ {msg.status === 'streaming' && !msg.sql && (
222
+ <div className="flex items-center gap-2 text-xs text-gray-500 px-1">
223
+ <Loader2 size={11} className="animate-spin text-violet-400" />
224
+ Generating SQL...
225
+ </div>
226
+ )}
227
+
228
+ {/* Multiple attempts */}
229
+ <AttemptSteps steps={msg.steps} />
230
+
231
+ {/* Final SQL block */}
232
+ {msg.sql && (
233
+ <div className="border border-white/[0.06] rounded-xl overflow-hidden">
234
+ <button
235
+ onClick={() => setSqlOpen((v) => !v)}
236
+ className="w-full flex items-center justify-between px-3 py-2 bg-white/[0.02] hover:bg-white/[0.04] transition-colors"
237
+ >
238
+ <div className="flex items-center gap-2">
239
+ <span className="text-[10px] font-semibold text-gray-500 uppercase tracking-wider">
240
+ SQL
241
+ </span>
242
+ {msg.status === 'streaming' && (
243
+ <Loader2 size={10} className="animate-spin text-violet-400" />
244
+ )}
245
+ {msg.attempts > 1 && (
246
+ <AttemptBadge attempt={msg.attempts} total={msg.attempts} />
247
+ )}
248
+ </div>
249
+ {sqlOpen ? (
250
+ <ChevronUp size={11} className="text-gray-600" />
251
+ ) : (
252
+ <ChevronDown size={11} className="text-gray-600" />
253
+ )}
254
+ </button>
255
+ {sqlOpen && (
256
+ <SqlBlock sql={msg.sql} streaming={msg.status === 'streaming'} />
257
+ )}
258
+ </div>
259
+ )}
260
+
261
+ {/* Executing indicator */}
262
+ {msg.status === 'streaming' && msg.sql && msg.rows.length === 0 && !msg.errorMsg && (
263
+ <div className="flex items-center gap-2 text-xs text-gray-500 px-1">
264
+ <Loader2 size={11} className="animate-spin text-violet-400" />
265
+ Executing...
266
+ </div>
267
+ )}
268
+
269
+ {/* RL badges row */}
270
+ {(msg.rlAction || msg.reward !== undefined) && (
271
+ <div className="flex items-center gap-2 flex-wrap">
272
+ {msg.rlAction && (
273
+ <RLActionBadge action={msg.rlAction} score={msg.rlActionScore} />
274
+ )}
275
+ {msg.reward !== undefined && <RewardBadge reward={msg.reward} />}
276
+ </div>
277
+ )}
278
+
279
+ {/* Result table */}
280
+ {msg.status === 'done' && msg.attempts > 0 && (
281
+ <div className="flex flex-col gap-1.5">
282
+ <div className="flex items-center gap-2 text-[10px] px-0.5">
283
+ <CheckCircle2 size={11} className="text-green-400" />
284
+ <span className="text-green-400 font-semibold">Success</span>
285
+ <span className="text-gray-600">
286
+ Β· {msg.rowCount} row{msg.rowCount !== 1 ? 's' : ''}
287
+ </span>
288
+ {msg.attempts > 1 && (
289
+ <span className="text-amber-400/60">{msg.attempts} attempts</span>
290
+ )}
291
+ </div>
292
+ <ResultsTable rows={msg.rows} rowCount={msg.rowCount} />
293
+ </div>
294
+ )}
295
+
296
+ {/* Error */}
297
+ {msg.status === 'error' && (
298
+ <div className="flex items-start gap-2 bg-red-500/10 border border-red-500/20 rounded-xl px-3 py-2.5 text-xs text-red-300">
299
+ <XCircle size={12} className="shrink-0 mt-0.5" />
300
+ <div>
301
+ <p className="font-semibold mb-0.5">Query failed</p>
302
+ <p className="opacity-80">{msg.errorMsg ?? 'Agent exhausted all repair attempts'}</p>
303
+ </div>
304
+ </div>
305
+ )}
306
+
307
+ {/* Feedback */}
308
+ {msg.status === 'done' && msg.attempts > 0 && (
309
+ <div className="flex items-center gap-2">
310
+ {msg.feedback ? (
311
+ <div
312
+ className={`text-xs flex items-center gap-1.5 ${
313
+ msg.feedback === 'correct' ? 'text-green-400' : 'text-red-400'
314
+ }`}
315
+ >
316
+ {msg.feedback === 'correct' ? (
317
+ <CheckCircle2 size={12} />
318
+ ) : (
319
+ <XCircle size={12} />
320
+ )}
321
+ Marked as {msg.feedback}
322
+ </div>
323
+ ) : (
324
+ <>
325
+ <span className="text-[10px] text-gray-600 mr-0.5">Was this correct?</span>
326
+ <button
327
+ disabled={msg.feedbackSending}
328
+ onClick={() => onFeedback(msg.id, true)}
329
+ className="flex items-center gap-1 px-2 py-1 text-[10px] font-medium rounded-lg border border-green-500/25 bg-green-500/8 text-green-400 hover:bg-green-500/15 transition-all disabled:opacity-40"
330
+ >
331
+ <CheckCircle2 size={10} />
332
+ Correct
333
+ </button>
334
+ <button
335
+ disabled={msg.feedbackSending}
336
+ onClick={() => onFeedback(msg.id, false)}
337
+ className="flex items-center gap-1 px-2 py-1 text-[10px] font-medium rounded-lg border border-red-500/25 bg-red-500/8 text-red-400 hover:bg-red-500/15 transition-all disabled:opacity-40"
338
+ >
339
+ <XCircle size={10} />
340
+ Wrong
341
+ </button>
342
+ </>
343
+ )}
344
+ {(msg.status === 'done' || msg.status === 'error') && (
345
+ <button
346
+ onClick={() => onRetry(msg.question)}
347
+ className="ml-auto flex items-center gap-1 text-[10px] text-gray-600 hover:text-gray-400 transition-colors"
348
+ >
349
+ <RefreshCw size={10} />
350
+ Retry
351
+ </button>
352
+ )}
353
+ </div>
354
+ )}
355
+ </div>
356
+ </div>
357
+ )
358
+ }
359
+
360
+ // ─── Chat Panel ───────────────────────────────────────────────────
361
+
362
+ export function ChatPanel() {
363
+ const {
364
+ messages, addMessage, updateMessage, clearMessages,
365
+ isExecuting, setIsExecuting,
366
+ taskId, taskDifficulty,
367
+ optimizingBanner, setOptimizingBanner,
368
+ promptGeneration,
369
+ } = useStore()
370
+
371
+ const [input, setInput] = useState('')
372
+ const bottomRef = useRef<HTMLDivElement>(null)
373
+ const inputRef = useRef<HTMLTextAreaElement>(null)
374
+
375
+ useEffect(() => {
376
+ bottomRef.current?.scrollIntoView({ behavior: 'smooth' })
377
+ }, [messages.length])
378
+
379
+ const handleFeedback = useCallback(
380
+ async (id: string, correct: boolean) => {
381
+ const msg = messages.find((m) => m.id === id)
382
+ if (!msg) return
383
+ updateMessage(id, { feedbackSending: true })
384
+ try {
385
+ await submitFeedback(msg.question, msg.sql, correct)
386
+ updateMessage(id, { feedback: correct ? 'correct' : 'wrong', feedbackSending: false })
387
+ } catch {
388
+ updateMessage(id, { feedbackSending: false })
389
+ }
390
+ },
391
+ [messages, updateMessage]
392
+ )
393
+
394
+ const execute = useCallback(
395
+ async (question: string) => {
396
+ if (!question.trim() || isExecuting) return
397
+ setIsExecuting(true)
398
+
399
+ const msgId = `msg-${Date.now()}`
400
+ const newMsg: ChatMessage = {
401
+ id: msgId,
402
+ question,
403
+ status: 'streaming',
404
+ sql: '',
405
+ rows: [],
406
+ rowCount: 0,
407
+ attempts: 0,
408
+ steps: [],
409
+ feedback: null,
410
+ promptGeneration,
411
+ }
412
+ addMessage(newMsg)
413
+
414
+ try {
415
+ for await (const event of streamExecuteQuery(question, taskId)) {
416
+ if (event.type === 'sql') {
417
+ updateMessage(msgId, { sql: event.sql as string })
418
+ } else if (event.type === 'sql_chunk') {
419
+ // incremental SQL streaming β€” read current sql from store
420
+ const curSql = useStore.getState().messages.find((m) => m.id === msgId)?.sql ?? ''
421
+ updateMessage(msgId, { sql: curSql + (event.chunk as string) })
422
+ } else if (event.type === 'attempt') {
423
+ const step: AttemptStep = {
424
+ attempt: event.attempt as number,
425
+ sql: event.sql as string,
426
+ error: event.error as string | undefined,
427
+ action: event.action as string | undefined,
428
+ actionScore: event.action_score as number | undefined,
429
+ reward: event.reward as number | undefined,
430
+ }
431
+ const curSteps = useStore.getState().messages.find((m) => m.id === msgId)?.steps ?? []
432
+ updateMessage(msgId, {
433
+ attempts: event.attempt as number,
434
+ steps: [...curSteps, step],
435
+ sql: event.sql as string,
436
+ rlAction: event.action as string | undefined,
437
+ rlActionScore: event.action_score as number | undefined,
438
+ })
439
+ } else if (event.type === 'result') {
440
+ updateMessage(msgId, {
441
+ rows: (event.rows as Record<string, unknown>[]) ?? [],
442
+ rowCount: (event.row_count as number) ?? 0,
443
+ reward: event.reward as number | undefined,
444
+ })
445
+ } else if (event.type === 'done') {
446
+ updateMessage(msgId, {
447
+ status: 'done',
448
+ attempts: (event.attempts as number) ?? 1,
449
+ reward: event.reward as number | undefined,
450
+ })
451
+ } else if (event.type === 'error') {
452
+ updateMessage(msgId, {
453
+ status: 'error',
454
+ errorMsg: event.message as string,
455
+ })
456
+ } else if (event.type === 'gepa_start') {
457
+ setOptimizingBanner(true)
458
+ } else if (event.type === 'gepa_done') {
459
+ setOptimizingBanner(false)
460
+ }
461
+ }
462
+ } catch (err) {
463
+ updateMessage(msgId, {
464
+ status: 'error',
465
+ errorMsg: err instanceof Error ? err.message : 'Network error',
466
+ })
467
+ } finally {
468
+ setIsExecuting(false)
469
+ // If still streaming after generator ends, mark done
470
+ const finalMsg = useStore.getState().messages.find((m) => m.id === msgId)
471
+ if (finalMsg?.status === 'streaming') {
472
+ updateMessage(msgId, { status: finalMsg.sql ? 'done' : 'error' })
473
+ }
474
+ }
475
+ },
476
+ [isExecuting, setIsExecuting, addMessage, updateMessage, taskId, promptGeneration, setOptimizingBanner]
477
+ )
478
+
479
+ const handleSend = () => {
480
+ if (!input.trim()) return
481
+ const q = input.trim()
482
+ setInput('')
483
+ void execute(q)
484
+ }
485
+
486
+ const handleKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
487
+ if (e.key === 'Enter' && !e.shiftKey) {
488
+ e.preventDefault()
489
+ handleSend()
490
+ }
491
+ }
492
+
493
+ const suggestions = SUGGESTED[taskDifficulty] ?? SUGGESTED.easy
494
+
495
+ return (
496
+ <div className="flex flex-col h-full">
497
+ {/* Optimizing banner */}
498
+ <AnimatePresence>
499
+ {optimizingBanner && (
500
+ <motion.div
501
+ initial={{ height: 0, opacity: 0 }}
502
+ animate={{ height: 'auto', opacity: 1 }}
503
+ exit={{ height: 0, opacity: 0 }}
504
+ className="shrink-0 overflow-hidden"
505
+ >
506
+ <div className="shimmer-banner border-b border-violet-500/20 px-4 py-2 flex items-center gap-2">
507
+ <Loader2 size={12} className="animate-spin text-violet-400" />
508
+ <span className="text-xs text-violet-300 font-semibold">
509
+ Optimizing system prompt via GEPA...
510
+ </span>
511
+ </div>
512
+ </motion.div>
513
+ )}
514
+ </AnimatePresence>
515
+
516
+ {/* Messages */}
517
+ <div className="flex-1 overflow-y-auto px-4 py-4">
518
+ {messages.length === 0 ? (
519
+ <EmptyState onSelect={(q) => { setInput(q); inputRef.current?.focus() }} />
520
+ ) : (
521
+ <div className="flex flex-col gap-6 max-w-3xl mx-auto">
522
+ {messages.map((msg) => (
523
+ <MessageCard
524
+ key={msg.id}
525
+ msg={msg}
526
+ onFeedback={handleFeedback}
527
+ onRetry={(q) => { setInput(q); inputRef.current?.focus() }}
528
+ />
529
+ ))}
530
+ <div ref={bottomRef} />
531
+ </div>
532
+ )}
533
+ </div>
534
+
535
+ {/* Input area */}
536
+ <div
537
+ className="shrink-0 border-t border-white/[0.06] px-4 py-3"
538
+ style={{ background: 'var(--bg-secondary)' }}
539
+ >
540
+ {/* Suggested chips */}
541
+ {messages.length > 0 && (
542
+ <div className="flex gap-1.5 flex-wrap mb-2.5">
543
+ {suggestions.slice(0, 3).map((q) => (
544
+ <button
545
+ key={q}
546
+ onClick={() => { setInput(q); inputRef.current?.focus() }}
547
+ className="text-[10px] px-2.5 py-1 rounded-full border border-white/[0.06] text-gray-500 hover:text-gray-300 hover:border-violet-500/30 transition-all"
548
+ >
549
+ {q}
550
+ </button>
551
+ ))}
552
+ </div>
553
+ )}
554
+
555
+ <div className="flex items-end gap-2">
556
+ <div className="flex-1 relative">
557
+ <textarea
558
+ ref={inputRef}
559
+ value={input}
560
+ onChange={(e) => setInput(e.target.value)}
561
+ onKeyDown={handleKeyDown}
562
+ placeholder="Ask about products, orders, sellers..."
563
+ disabled={isExecuting}
564
+ rows={1}
565
+ className="w-full px-3 py-2.5 pr-10 text-sm text-white rounded-xl border border-white/[0.06] bg-white/[0.03] placeholder-gray-600 resize-none focus:outline-none focus:border-violet-500/40 focus:bg-white/[0.05] transition-all disabled:opacity-50"
566
+ style={{ minHeight: 40, maxHeight: 120, overflowY: 'auto' }}
567
+ />
568
+ </div>
569
+ <div className="flex flex-col gap-1.5 shrink-0">
570
+ <button
571
+ onClick={handleSend}
572
+ disabled={!input.trim() || isExecuting}
573
+ className="w-9 h-9 rounded-xl bg-violet-600 hover:bg-violet-500 disabled:opacity-40 disabled:cursor-not-allowed transition-all flex items-center justify-center"
574
+ >
575
+ {isExecuting ? (
576
+ <Loader2 size={14} className="animate-spin text-white" />
577
+ ) : (
578
+ <Send size={14} className="text-white" />
579
+ )}
580
+ </button>
581
+ {messages.length > 0 && (
582
+ <button
583
+ onClick={clearMessages}
584
+ disabled={isExecuting}
585
+ className="w-9 h-9 rounded-xl border border-white/[0.06] hover:bg-white/5 disabled:opacity-40 transition-all flex items-center justify-center text-gray-600 hover:text-gray-400"
586
+ title="Clear chat"
587
+ >
588
+ <Trash2 size={12} />
589
+ </button>
590
+ )}
591
+ </div>
592
+ </div>
593
+ <p className="text-[9px] text-gray-700 mt-1.5 text-center">
594
+ Enter to send Β· Shift+Enter for newline Β· Agent uses LinUCB + GEPA
595
+ </p>
596
+ </div>
597
+ </div>
598
+ )
599
+ }
frontend/src/components/ERDiagram.tsx ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useState, useEffect, useRef } from 'react'
2
+ import { Loader2, GitFork } from 'lucide-react'
3
+ import { useStore } from '../store/useStore'
4
+ import { fetchSchemaGraph } from '../lib/api'
5
+ import type { SchemaTable, SchemaRelationship } from '../lib/types'
6
+
7
+ // ─── Table card ───────────────────────────────────────────────────
8
+
9
+ function TableCard({ table, x, y }: { table: SchemaTable; x: number; y: number }) {
10
+ return (
11
+ <g transform={`translate(${x},${y})`}>
12
+ {/* Card bg */}
13
+ <rect
14
+ width={180}
15
+ height={28 + table.columns.length * 20}
16
+ rx={8}
17
+ fill="#0e0e16"
18
+ stroke="rgba(255,255,255,0.08)"
19
+ strokeWidth={1}
20
+ />
21
+ {/* Header */}
22
+ <rect width={180} height={28} rx={8} fill="rgba(139,92,246,0.15)" />
23
+ <rect y={20} width={180} height={8} fill="rgba(139,92,246,0.15)" />
24
+ <text
25
+ x={10}
26
+ y={18}
27
+ fill="#a78bfa"
28
+ fontSize={11}
29
+ fontWeight="bold"
30
+ fontFamily="ui-monospace,monospace"
31
+ >
32
+ {table.name}
33
+ </text>
34
+
35
+ {/* Columns */}
36
+ {table.columns.map((col, i) => (
37
+ <g key={col.name} transform={`translate(0,${28 + i * 20})`}>
38
+ <rect
39
+ width={180}
40
+ height={20}
41
+ fill={i % 2 === 0 ? 'rgba(255,255,255,0.01)' : 'transparent'}
42
+ />
43
+ <text
44
+ x={10}
45
+ y={14}
46
+ fill={col.pk ? '#60a5fa' : col.fk ? '#34d399' : 'rgba(255,255,255,0.5)'}
47
+ fontSize={10}
48
+ fontFamily="ui-monospace,monospace"
49
+ >
50
+ {col.pk ? 'πŸ”‘ ' : col.fk ? 'πŸ”— ' : ' '}
51
+ {col.name}
52
+ </text>
53
+ <text
54
+ x={170}
55
+ y={14}
56
+ fill="rgba(255,255,255,0.2)"
57
+ fontSize={9}
58
+ fontFamily="ui-monospace,monospace"
59
+ textAnchor="end"
60
+ >
61
+ {col.type}
62
+ </text>
63
+ </g>
64
+ ))}
65
+ </g>
66
+ )
67
+ }
68
+
69
+ // ─── Layout helpers ───────────────────────────────────────────────
70
+
71
+ function layoutTables(tables: SchemaTable[]) {
72
+ const CARD_W = 180
73
+ const CARD_H_BASE = 28
74
+ const COL_H = 20
75
+ const GAP_X = 40
76
+ const GAP_Y = 30
77
+ const COLS_PER_ROW = 3
78
+
79
+ const positions: Record<string, { x: number; y: number; w: number; h: number }> = {}
80
+ let maxRowH = 0
81
+
82
+ tables.forEach((t, i) => {
83
+ const col = i % COLS_PER_ROW
84
+ const row = Math.floor(i / COLS_PER_ROW)
85
+ const h = CARD_H_BASE + t.columns.length * COL_H
86
+
87
+ if (row === Math.floor(i / COLS_PER_ROW) && col === 0) maxRowH = 0
88
+ maxRowH = Math.max(maxRowH, h)
89
+
90
+ const prevRowsH = tables
91
+ .slice(0, row * COLS_PER_ROW)
92
+ .reduce((acc, _, idx) => {
93
+ if (idx % COLS_PER_ROW === 0) {
94
+ const rowH = tables.slice(idx, idx + COLS_PER_ROW).reduce(
95
+ (m, rt) => Math.max(m, CARD_H_BASE + rt.columns.length * COL_H),
96
+ 0
97
+ )
98
+ return acc + rowH + GAP_Y
99
+ }
100
+ return acc
101
+ }, 0)
102
+
103
+ positions[t.name] = {
104
+ x: col * (CARD_W + GAP_X) + 20,
105
+ y: prevRowsH + 20,
106
+ w: CARD_W,
107
+ h,
108
+ }
109
+ })
110
+
111
+ return positions
112
+ }
113
+
114
+ function RelationshipLine({
115
+ from,
116
+ to,
117
+ positions,
118
+ }: {
119
+ from: string
120
+ to: string
121
+ positions: Record<string, { x: number; y: number; w: number; h: number }>
122
+ }) {
123
+ const a = positions[from]
124
+ const b = positions[to]
125
+ if (!a || !b) return null
126
+
127
+ const x1 = a.x + a.w
128
+ const y1 = a.y + 14
129
+ const x2 = b.x
130
+ const y2 = b.y + 14
131
+ const cx = (x1 + x2) / 2
132
+
133
+ return (
134
+ <path
135
+ d={`M${x1},${y1} C${cx},${y1} ${cx},${y2} ${x2},${y2}`}
136
+ stroke="rgba(139,92,246,0.3)"
137
+ strokeWidth={1.5}
138
+ fill="none"
139
+ strokeDasharray="4 3"
140
+ />
141
+ )
142
+ }
143
+
144
+ // ─── ER Diagram component ─────────────────────────────────────────
145
+
146
+ export function ERDiagram() {
147
+ const { schemaGraph, setSchemaGraph } = useStore()
148
+ const [loading, setLoading] = useState(false)
149
+ const svgRef = useRef<SVGSVGElement>(null)
150
+
151
+ const load = async () => {
152
+ setLoading(true)
153
+ try {
154
+ const data = await fetchSchemaGraph()
155
+ setSchemaGraph(data)
156
+ } catch {
157
+ // noop
158
+ } finally {
159
+ setLoading(false)
160
+ }
161
+ }
162
+
163
+ useEffect(() => {
164
+ if (!schemaGraph) void load()
165
+ // eslint-disable-next-line react-hooks/exhaustive-deps
166
+ }, [])
167
+
168
+ if (loading) {
169
+ return (
170
+ <div className="flex items-center justify-center h-full gap-2 text-gray-500">
171
+ <Loader2 size={16} className="animate-spin" />
172
+ <span className="text-sm">Loading schema...</span>
173
+ </div>
174
+ )
175
+ }
176
+
177
+ if (!schemaGraph || schemaGraph.tables.length === 0) {
178
+ return (
179
+ <div className="flex flex-col items-center justify-center h-full gap-3 text-gray-600">
180
+ <GitFork size={32} className="text-gray-700" />
181
+ <p className="text-sm">Schema will appear after database connects</p>
182
+ <button
183
+ onClick={() => void load()}
184
+ className="text-xs text-violet-400 hover:text-violet-300 transition-colors"
185
+ >
186
+ Retry
187
+ </button>
188
+ </div>
189
+ )
190
+ }
191
+
192
+ const { tables, relationships } = schemaGraph
193
+ const positions = layoutTables(tables)
194
+
195
+ const allX = Object.values(positions).map((p) => p.x + p.w)
196
+ const allY = Object.values(positions).map((p) => p.y + p.h)
197
+ const svgW = Math.max(...allX) + 40
198
+ const svgH = Math.max(...allY) + 40
199
+
200
+ return (
201
+ <div className="h-full overflow-auto p-4">
202
+ <div className="text-[10px] text-gray-500 uppercase tracking-widest mb-3 flex items-center gap-1.5">
203
+ <GitFork size={10} className="text-violet-400" />
204
+ Entity Relationship Diagram
205
+ <span className="text-gray-700">Β· {tables.length} tables</span>
206
+ </div>
207
+ <svg
208
+ ref={svgRef}
209
+ width={svgW}
210
+ height={svgH}
211
+ style={{ minWidth: svgW }}
212
+ >
213
+ {/* FK lines */}
214
+ {(relationships as SchemaRelationship[]).map((rel, i) => (
215
+ <RelationshipLine
216
+ key={i}
217
+ from={rel.from}
218
+ to={rel.to}
219
+ positions={positions}
220
+ />
221
+ ))}
222
+ {/* Tables */}
223
+ {tables.map((t: SchemaTable) => (
224
+ <TableCard
225
+ key={t.name}
226
+ table={t}
227
+ x={positions[t.name]?.x ?? 0}
228
+ y={positions[t.name]?.y ?? 0}
229
+ />
230
+ ))}
231
+ </svg>
232
+ </div>
233
+ )
234
+ }
frontend/src/components/Header.tsx ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { Database, Sun, Moon, PanelLeftOpen, PanelRightOpen, Cpu } from 'lucide-react'
2
+ import { useStore } from '../store/useStore'
3
+ import type { Difficulty } from '../lib/types'
4
+
5
+ interface HeaderProps {
6
+ onToggleLeft: () => void
7
+ onToggleRight: () => void
8
+ }
9
+
10
+ const DIFFICULTIES: { id: Difficulty; label: string; color: string }[] = [
11
+ { id: 'easy', label: 'Easy', color: 'text-green-400 border-green-500/30 bg-green-500/10' },
12
+ { id: 'medium', label: 'Medium', color: 'text-amber-400 border-amber-500/30 bg-amber-500/10' },
13
+ { id: 'hard', label: 'Hard', color: 'text-red-400 border-red-500/30 bg-red-500/10' },
14
+ ]
15
+
16
+ export function Header({ onToggleLeft, onToggleRight }: HeaderProps) {
17
+ const { theme, toggleTheme, dbSeeded, taskDifficulty, setTaskDifficulty } = useStore()
18
+
19
+ return (
20
+ <header
21
+ className="border-b px-3 sm:px-5 py-3 flex items-center justify-between shrink-0 backdrop-blur-sm sticky top-0 z-50 theme-border"
22
+ style={{ background: 'var(--bg-secondary)' }}
23
+ >
24
+ <div className="flex items-center gap-2 sm:gap-3">
25
+ {/* Mobile sidebar toggle */}
26
+ <button
27
+ onClick={onToggleLeft}
28
+ className="lg:hidden flex items-center gap-1 px-2 py-1.5 rounded-lg hover:bg-white/5 text-gray-400 hover:text-white transition-colors text-[10px]"
29
+ >
30
+ <PanelLeftOpen size={14} />
31
+ <span className="hidden sm:inline">Data</span>
32
+ </button>
33
+
34
+ {/* Logo */}
35
+ <div
36
+ className="w-7 h-7 rounded-lg flex items-center justify-center shadow-lg shrink-0"
37
+ style={{ background: '#1e3a5f', boxShadow: '0 4px 12px rgba(30,58,95,0.4)' }}
38
+ >
39
+ <Database size={13} className="text-white" />
40
+ </div>
41
+
42
+ {/* Title */}
43
+ <div>
44
+ <h1 className="text-sm font-bold text-white tracking-tight leading-none">
45
+ SQL Agent OpenEnv
46
+ </h1>
47
+ <p className="text-[10px] text-gray-600 hidden sm:block mt-0.5">
48
+ Reinforcement Learning Environment
49
+ </p>
50
+ </div>
51
+ </div>
52
+
53
+ <div className="flex items-center gap-2 sm:gap-3">
54
+ {/* Connection status */}
55
+ {dbSeeded ? (
56
+ <div className="hidden sm:flex items-center gap-1.5 text-[10px] text-green-400">
57
+ <span className="w-1.5 h-1.5 rounded-full bg-green-400 inline-block" />
58
+ benchmark db
59
+ </div>
60
+ ) : (
61
+ <div className="hidden sm:flex items-center gap-1.5 text-[10px] text-amber-400">
62
+ <span className="w-1.5 h-1.5 rounded-full bg-amber-400 inline-block animate-pulse" />
63
+ connecting...
64
+ </div>
65
+ )}
66
+
67
+ {/* RL indicator */}
68
+ <div className="hidden md:flex items-center gap-1.5 text-[10px] text-violet-400 border border-violet-500/20 rounded-full px-2 py-0.5">
69
+ <Cpu size={10} />
70
+ LinUCB Active
71
+ </div>
72
+
73
+ {/* Difficulty selector */}
74
+ <div className="flex items-center gap-1 border border-white/[0.06] rounded-lg p-0.5">
75
+ {DIFFICULTIES.map((d) => (
76
+ <button
77
+ key={d.id}
78
+ onClick={() => setTaskDifficulty(d.id)}
79
+ className={`text-[10px] font-semibold px-2 py-1 rounded transition-all ${
80
+ taskDifficulty === d.id
81
+ ? `${d.color} border`
82
+ : 'text-gray-500 hover:text-gray-300 border border-transparent'
83
+ }`}
84
+ >
85
+ {d.label}
86
+ </button>
87
+ ))}
88
+ </div>
89
+
90
+ {/* Theme toggle */}
91
+ <button
92
+ onClick={toggleTheme}
93
+ className="p-1.5 rounded-lg hover:bg-white/5 transition-colors theme-text-muted"
94
+ title={theme === 'dark' ? 'Switch to light' : 'Switch to dark'}
95
+ >
96
+ {theme === 'dark' ? <Sun size={14} /> : <Moon size={14} />}
97
+ </button>
98
+
99
+ {/* Mobile right sidebar toggle */}
100
+ <button
101
+ onClick={onToggleRight}
102
+ className="lg:hidden flex items-center gap-1 px-2 py-1.5 rounded-lg hover:bg-white/5 text-gray-400 hover:text-white transition-colors text-[10px]"
103
+ >
104
+ <span className="hidden sm:inline">GEPA</span>
105
+ <PanelRightOpen size={14} />
106
+ </button>
107
+ </div>
108
+ </header>
109
+ )
110
+ }
frontend/src/components/LeftSidebar.tsx ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useState } from 'react'
2
+ import { motion, AnimatePresence } from 'framer-motion'
3
+ import { Database, Table2, ChevronDown, ChevronRight, GitFork, ShoppingCart } from 'lucide-react'
4
+ import { useStore } from '../store/useStore'
5
+ import type { Difficulty } from '../lib/types'
6
+
7
+ const DIFFICULTY_CONFIG: Record<Difficulty, { label: string; bg: string; text: string; border: string }> = {
8
+ easy: { label: 'Easy', bg: 'bg-green-500/10', text: 'text-green-400', border: 'border-green-500/30' },
9
+ medium: { label: 'Medium', bg: 'bg-amber-500/10', text: 'text-amber-400', border: 'border-amber-500/30' },
10
+ hard: { label: 'Hard', bg: 'bg-red-500/10', text: 'text-red-400', border: 'border-red-500/30' },
11
+ }
12
+
13
+ export function LeftSidebar() {
14
+ const { tables, taskDifficulty, setTaskDifficulty, dbSeeded } = useStore()
15
+ const [tablesExpanded, setTablesExpanded] = useState(true)
16
+
17
+ const cfg = DIFFICULTY_CONFIG[taskDifficulty]
18
+
19
+ return (
20
+ <div className="flex flex-col gap-4 py-1">
21
+ {/* Task Difficulty */}
22
+ <section>
23
+ <div className="text-[10px] font-semibold text-gray-500 uppercase tracking-widest mb-2 flex items-center gap-1.5">
24
+ <GitFork size={10} className="text-violet-400" />
25
+ Task Difficulty
26
+ </div>
27
+ <div className="flex flex-col gap-1">
28
+ {(Object.keys(DIFFICULTY_CONFIG) as Difficulty[]).map((d) => {
29
+ const c = DIFFICULTY_CONFIG[d]
30
+ const active = d === taskDifficulty
31
+ return (
32
+ <button
33
+ key={d}
34
+ onClick={() => setTaskDifficulty(d)}
35
+ className={`flex items-center justify-between px-3 py-2 rounded-lg border text-xs font-medium transition-all ${
36
+ active
37
+ ? `${c.bg} ${c.text} ${c.border}`
38
+ : 'border-transparent text-gray-500 hover:text-gray-300 hover:bg-white/5'
39
+ }`}
40
+ >
41
+ <span>{c.label}</span>
42
+ {active && (
43
+ <span className={`text-[9px] font-mono ${c.text} opacity-70`}>selected</span>
44
+ )}
45
+ </button>
46
+ )
47
+ })}
48
+ </div>
49
+ </section>
50
+
51
+ {/* Schema Tables */}
52
+ <section>
53
+ <button
54
+ className="text-[10px] font-semibold text-gray-500 uppercase tracking-widest mb-2 flex items-center gap-1.5 w-full"
55
+ onClick={() => setTablesExpanded((v) => !v)}
56
+ >
57
+ <Database size={10} className="text-blue-400" />
58
+ <span className="flex-1 text-left">Database Schema</span>
59
+ {tablesExpanded ? <ChevronDown size={10} /> : <ChevronRight size={10} />}
60
+ </button>
61
+ <AnimatePresence>
62
+ {tablesExpanded && (
63
+ <motion.div
64
+ initial={{ opacity: 0, height: 0 }}
65
+ animate={{ opacity: 1, height: 'auto' }}
66
+ exit={{ opacity: 0, height: 0 }}
67
+ className="overflow-hidden"
68
+ >
69
+ {dbSeeded && tables.length > 0 ? (
70
+ <div className="flex flex-col gap-1">
71
+ {tables.map((t) => (
72
+ <div
73
+ key={t.name}
74
+ className="flex items-center justify-between px-2.5 py-1.5 rounded-lg border border-white/[0.04] bg-white/[0.02] hover:bg-white/[0.04] transition-colors"
75
+ >
76
+ <div className="flex items-center gap-1.5">
77
+ <Table2 size={10} className="text-blue-400 shrink-0" />
78
+ <span className="text-xs text-gray-300 font-mono">{t.name}</span>
79
+ </div>
80
+ <span className="text-[9px] text-gray-600 font-mono tabular-nums">
81
+ {t.rows.toLocaleString()}
82
+ </span>
83
+ </div>
84
+ ))}
85
+ </div>
86
+ ) : (
87
+ <div className="flex flex-col gap-1">
88
+ {[120, 80, 95, 60, 70].map((w, i) => (
89
+ <div
90
+ key={i}
91
+ className="flex items-center justify-between px-2.5 py-1.5 rounded-lg border border-white/[0.04] bg-white/[0.02]"
92
+ >
93
+ <div
94
+ className="h-2 rounded bg-white/10 animate-pulse"
95
+ style={{ width: w }}
96
+ />
97
+ <div className="h-2 w-8 rounded bg-white/10 animate-pulse" />
98
+ </div>
99
+ ))}
100
+ </div>
101
+ )}
102
+ </motion.div>
103
+ )}
104
+ </AnimatePresence>
105
+ </section>
106
+
107
+ {/* Business Context */}
108
+ <section>
109
+ <div className="text-[10px] font-semibold text-gray-500 uppercase tracking-widest mb-2 flex items-center gap-1.5">
110
+ <ShoppingCart size={10} className="text-orange-400" />
111
+ Business Context
112
+ </div>
113
+ <div
114
+ className="rounded-xl border border-white/[0.05] p-3 text-[11px] text-gray-500 leading-relaxed"
115
+ style={{ background: 'var(--bg-card)' }}
116
+ >
117
+ <p className="mb-2 text-gray-400 font-medium">E-Commerce Marketplace</p>
118
+ <p>
119
+ Multi-vendor marketplace with products, orders, sellers, users, and reviews.
120
+ Supports complex analytical queries across sales, inventory, and user behavior.
121
+ </p>
122
+ <div className="mt-2 flex flex-wrap gap-1">
123
+ {['Products', 'Orders', 'Sellers', 'Users', 'Reviews', 'Categories'].map((t) => (
124
+ <span
125
+ key={t}
126
+ className="text-[9px] px-1.5 py-0.5 rounded border border-white/[0.06] text-gray-600"
127
+ >
128
+ {t}
129
+ </span>
130
+ ))}
131
+ </div>
132
+ </div>
133
+ </section>
134
+
135
+ {/* Current task badge */}
136
+ <section>
137
+ <div
138
+ className={`rounded-xl border ${cfg.border} ${cfg.bg} p-3 flex flex-col gap-1.5`}
139
+ >
140
+ <div className="flex items-center justify-between">
141
+ <span className={`text-[10px] font-semibold uppercase tracking-wider ${cfg.text}`}>
142
+ Current Task
143
+ </span>
144
+ <span className={`text-[10px] font-mono ${cfg.text}`}>{cfg.label}</span>
145
+ </div>
146
+ <p className="text-[11px] text-gray-400 leading-relaxed">
147
+ {taskDifficulty === 'easy'
148
+ ? 'Simple SELECT queries, basic filtering and aggregation'
149
+ : taskDifficulty === 'medium'
150
+ ? 'Multi-table JOINs, GROUP BY, subqueries, window functions'
151
+ : 'Complex CTEs, rolling aggregations, cohort analysis, ranking'}
152
+ </p>
153
+ </div>
154
+ </section>
155
+ </div>
156
+ )
157
+ }
frontend/src/components/PerformanceGraph.tsx ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useEffect } from 'react'
2
+ import {
3
+ LineChart, Line, BarChart, Bar,
4
+ XAxis, YAxis, CartesianGrid, Tooltip,
5
+ ResponsiveContainer, ReferenceLine,
6
+ } from 'recharts'
7
+ import { TrendingUp, Loader2, RefreshCw } from 'lucide-react'
8
+ import { useStore } from '../store/useStore'
9
+ import { fetchRLState } from '../lib/api'
10
+
11
+ const CustomTooltip = ({
12
+ active,
13
+ payload,
14
+ label,
15
+ }: {
16
+ active?: boolean
17
+ payload?: { value: number; name: string; color: string }[]
18
+ label?: string | number
19
+ }) => {
20
+ if (active && payload?.length) {
21
+ return (
22
+ <div
23
+ className="border border-white/10 rounded-lg px-3 py-2 text-xs"
24
+ style={{ background: '#1a1a2e' }}
25
+ >
26
+ <p className="text-gray-400 mb-1">#{label}</p>
27
+ {payload.map((p) => (
28
+ <p key={p.name} style={{ color: p.color }}>
29
+ {p.name}: <span className="font-semibold">{p.value}</span>
30
+ </p>
31
+ ))}
32
+ </div>
33
+ )
34
+ }
35
+ return null
36
+ }
37
+
38
+ export function PerformanceGraph() {
39
+ const { rlState, setRlState } = useStore()
40
+
41
+ const load = async () => {
42
+ try {
43
+ const data = await fetchRLState()
44
+ setRlState(data)
45
+ } catch {
46
+ // noop β€” backend might not be up
47
+ }
48
+ }
49
+
50
+ useEffect(() => {
51
+ void load()
52
+ const interval = setInterval(() => void load(), 10_000)
53
+ return () => clearInterval(interval)
54
+ // eslint-disable-next-line react-hooks/exhaustive-deps
55
+ }, [])
56
+
57
+ if (!rlState) {
58
+ return (
59
+ <div className="flex flex-col items-center justify-center h-40 text-gray-600 gap-2">
60
+ <TrendingUp size={24} className="text-gray-700" />
61
+ <p className="text-[11px] text-center">
62
+ RL metrics appear after agent episodes
63
+ </p>
64
+ <Loader2 size={14} className="animate-spin text-gray-700" />
65
+ </div>
66
+ )
67
+ }
68
+
69
+ const { totalEpisodes, successRate, currentAlpha, episodes, actionDistribution } = rlState
70
+
71
+ return (
72
+ <div className="flex flex-col gap-3">
73
+ {/* Stats row */}
74
+ <div className="grid grid-cols-3 gap-1.5">
75
+ {[
76
+ { label: 'Episodes', value: totalEpisodes, color: 'text-blue-400' },
77
+ { label: 'Success', value: `${(successRate * 100).toFixed(0)}%`, color: 'text-green-400' },
78
+ { label: 'Alpha', value: currentAlpha.toFixed(3), color: 'text-orange-400' },
79
+ ].map((s) => (
80
+ <div
81
+ key={s.label}
82
+ className="bg-white/5 rounded-xl p-2 text-center"
83
+ >
84
+ <div className={`text-sm font-bold font-mono ${s.color}`}>{s.value}</div>
85
+ <div className="text-[9px] text-gray-500 mt-0.5">{s.label}</div>
86
+ </div>
87
+ ))}
88
+ </div>
89
+
90
+ {/* Reward per episode */}
91
+ {episodes.length > 0 && (
92
+ <div>
93
+ <div className="flex items-center justify-between mb-1.5">
94
+ <div className="text-[10px] text-gray-500 font-medium">Reward per Episode</div>
95
+ <button
96
+ onClick={() => void load()}
97
+ className="p-1 rounded hover:bg-white/5 text-gray-600 hover:text-gray-400 transition-colors"
98
+ title="Refresh"
99
+ >
100
+ <RefreshCw size={10} />
101
+ </button>
102
+ </div>
103
+ <ResponsiveContainer width="100%" height={110}>
104
+ <LineChart data={episodes} margin={{ top: 4, right: 4, bottom: 0, left: -20 }}>
105
+ <CartesianGrid strokeDasharray="3 3" stroke="#ffffff08" />
106
+ <XAxis dataKey="episode" tick={{ fontSize: 9, fill: '#6b7280' }} />
107
+ <YAxis domain={[-1, 1]} tick={{ fontSize: 9, fill: '#6b7280' }} />
108
+ <Tooltip content={<CustomTooltip />} />
109
+ <ReferenceLine y={0} stroke="#ffffff20" strokeDasharray="3 3" />
110
+ <Line
111
+ type="monotone"
112
+ dataKey="totalReward"
113
+ name="Reward"
114
+ stroke="#f97316"
115
+ strokeWidth={2}
116
+ dot={episodes.length < 30 ? { fill: '#f97316', r: 2 } : false}
117
+ activeDot={{ r: 4 }}
118
+ />
119
+ </LineChart>
120
+ </ResponsiveContainer>
121
+ </div>
122
+ )}
123
+
124
+ {/* Action distribution */}
125
+ {actionDistribution.length > 0 && (
126
+ <div>
127
+ <div className="text-[10px] text-gray-500 mb-1.5 font-medium">
128
+ LinUCB Action Distribution
129
+ </div>
130
+ <ResponsiveContainer width="100%" height={90}>
131
+ <BarChart
132
+ data={actionDistribution}
133
+ margin={{ top: 4, right: 4, bottom: 0, left: -20 }}
134
+ >
135
+ <CartesianGrid strokeDasharray="3 3" stroke="#ffffff08" />
136
+ <XAxis
137
+ dataKey="action"
138
+ tick={{ fontSize: 8, fill: '#6b7280' }}
139
+ tickFormatter={(v: string) => v.replace('FIX_', '').slice(0, 6)}
140
+ />
141
+ <YAxis tick={{ fontSize: 9, fill: '#6b7280' }} />
142
+ <Tooltip content={<CustomTooltip />} />
143
+ <Bar dataKey="count" name="Uses" fill="#8b5cf6" radius={[3, 3, 0, 0]} />
144
+ </BarChart>
145
+ </ResponsiveContainer>
146
+ </div>
147
+ )}
148
+
149
+ {/* Success rate line */}
150
+ {episodes.length >= 3 && (
151
+ <div>
152
+ <div className="text-[10px] text-gray-500 mb-1.5 font-medium">
153
+ Rolling Success Rate
154
+ </div>
155
+ <ResponsiveContainer width="100%" height={80}>
156
+ <LineChart data={episodes} margin={{ top: 4, right: 4, bottom: 0, left: -20 }}>
157
+ <CartesianGrid strokeDasharray="3 3" stroke="#ffffff08" />
158
+ <XAxis dataKey="episode" tick={{ fontSize: 9, fill: '#6b7280' }} />
159
+ <YAxis domain={[0, 1]} tick={{ fontSize: 9, fill: '#6b7280' }} />
160
+ <Tooltip content={<CustomTooltip />} />
161
+ <Line
162
+ type="monotone"
163
+ dataKey="successRate"
164
+ name="Success"
165
+ stroke="#22c55e"
166
+ strokeWidth={2}
167
+ dot={false}
168
+ />
169
+ </LineChart>
170
+ </ResponsiveContainer>
171
+ </div>
172
+ )}
173
+ </div>
174
+ )
175
+ }
frontend/src/components/PromptEvolution.tsx ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useState, useEffect } from 'react'
2
+ import { motion, AnimatePresence } from 'framer-motion'
3
+ import { Brain, ChevronDown, ChevronUp, Zap, History } from 'lucide-react'
4
+ import { useStore } from '../store/useStore'
5
+ import { fetchPromptHistory } from '../lib/api'
6
+
7
+ const SEED_PROMPT = `You are a SQL expert. Given a natural language question and a SQLite database schema, write a correct SQL query.
8
+
9
+ Rules:
10
+ - Output ONLY the SQL query, nothing else
11
+ - No markdown, no code fences, no explanation
12
+ - Use SQLite syntax
13
+ - Always qualify column names with table aliases when using JOINs`
14
+
15
+ export function PromptEvolution() {
16
+ const { currentPrompt, promptGeneration, promptHistory, setPromptData } = useStore()
17
+ const [expanded, setExpanded] = useState(false)
18
+ const [historyExpanded, setHistoryExpanded] = useState(false)
19
+ const [loading, setLoading] = useState(false)
20
+
21
+ const prompt = currentPrompt || SEED_PROMPT
22
+ const generation = promptGeneration
23
+
24
+ const loadHistory = async () => {
25
+ setLoading(true)
26
+ try {
27
+ const data = await fetchPromptHistory()
28
+ setPromptData(data)
29
+ } catch {
30
+ // noop
31
+ } finally {
32
+ setLoading(false)
33
+ }
34
+ }
35
+
36
+ useEffect(() => {
37
+ void loadHistory()
38
+ // eslint-disable-next-line react-hooks/exhaustive-deps
39
+ }, [])
40
+
41
+ return (
42
+ <div className="flex flex-col gap-2">
43
+ {/* Header */}
44
+ <button
45
+ onClick={() => setExpanded((v) => !v)}
46
+ className="flex items-center justify-between w-full group"
47
+ >
48
+ <div className="flex items-center gap-2">
49
+ <Brain size={14} className="text-violet-400" />
50
+ <span className="text-xs font-semibold text-white/70">System Prompt</span>
51
+ {generation > 0 ? (
52
+ <span className="text-[10px] bg-violet-500/20 text-violet-300 border border-violet-500/30 rounded-full px-2 py-0.5">
53
+ Gen {generation} Β· Optimized
54
+ </span>
55
+ ) : (
56
+ <span className="text-[10px] bg-white/5 text-gray-500 rounded-full px-2 py-0.5">
57
+ Seed
58
+ </span>
59
+ )}
60
+ </div>
61
+ {expanded ? (
62
+ <ChevronUp size={13} className="text-gray-500" />
63
+ ) : (
64
+ <ChevronDown size={13} className="text-gray-500" />
65
+ )}
66
+ </button>
67
+
68
+ <AnimatePresence>
69
+ {expanded && (
70
+ <motion.div
71
+ initial={{ opacity: 0, height: 0 }}
72
+ animate={{ opacity: 1, height: 'auto' }}
73
+ exit={{ opacity: 0, height: 0 }}
74
+ transition={{ duration: 0.2 }}
75
+ className="overflow-hidden"
76
+ >
77
+ {/* Prompt preview */}
78
+ <div className="max-h-40 overflow-y-auto">
79
+ <pre className="text-[11px] font-mono text-violet-200/70 bg-violet-950/30 rounded-xl p-3 border border-violet-500/20 whitespace-pre-wrap leading-relaxed">
80
+ {prompt}
81
+ </pre>
82
+ </div>
83
+
84
+ {/* History button */}
85
+ {promptHistory.length > 0 && (
86
+ <button
87
+ onClick={() => setHistoryExpanded((v) => !v)}
88
+ className="mt-2 w-full flex items-center justify-center gap-2 px-3 py-2 text-xs font-medium bg-violet-600/15 text-violet-300 border border-violet-500/25 rounded-xl hover:bg-violet-600/25 hover:border-violet-500/40 transition-all"
89
+ >
90
+ <History size={12} />
91
+ {historyExpanded ? 'Hide' : 'View'} Evolution History
92
+ <span className="text-[10px] text-violet-400/60 ml-1">
93
+ ({promptHistory.length} gen{promptHistory.length !== 1 ? 's' : ''})
94
+ </span>
95
+ </button>
96
+ )}
97
+
98
+ {/* Generation history */}
99
+ <AnimatePresence>
100
+ {historyExpanded && promptHistory.length > 0 && (
101
+ <motion.div
102
+ initial={{ height: 0, opacity: 0 }}
103
+ animate={{ height: 'auto', opacity: 1 }}
104
+ exit={{ height: 0, opacity: 0 }}
105
+ transition={{ duration: 0.15 }}
106
+ className="overflow-hidden mt-2"
107
+ >
108
+ <div className="flex flex-col gap-1.5">
109
+ <div className="text-[10px] text-gray-500 font-medium flex items-center gap-1">
110
+ <Zap size={10} className="text-violet-400" />
111
+ Optimization History
112
+ </div>
113
+ {promptHistory.map((snap) => (
114
+ <div
115
+ key={snap.generation}
116
+ className="border border-white/5 rounded-xl p-2.5 hover:border-white/10 hover:bg-white/[0.02] transition-all"
117
+ >
118
+ <div className="flex items-center justify-between mb-1">
119
+ <span className="text-[10px] font-semibold text-violet-400">
120
+ Generation {snap.generation}
121
+ </span>
122
+ <span className="text-[10px] font-mono text-green-400">
123
+ {(snap.score * 100).toFixed(0)}%
124
+ </span>
125
+ </div>
126
+ <p className="text-[10px] text-gray-400 leading-relaxed line-clamp-2">
127
+ {snap.summary}
128
+ </p>
129
+ <p className="text-[9px] text-gray-600 mt-1">{snap.timestamp}</p>
130
+ </div>
131
+ ))}
132
+ </div>
133
+ </motion.div>
134
+ )}
135
+ </AnimatePresence>
136
+
137
+ {loading && (
138
+ <div className="flex items-center gap-2 text-[10px] text-gray-500 mt-2 px-1">
139
+ <span className="w-3 h-3 border border-violet-500/40 border-t-violet-400 rounded-full animate-spin inline-block" />
140
+ Loading history...
141
+ </div>
142
+ )}
143
+ </motion.div>
144
+ )}
145
+ </AnimatePresence>
146
+ </div>
147
+ )
148
+ }
frontend/src/components/ResultsTable.tsx ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ const MAX_ROWS = 10
2
+ const MAX_CELL_LEN = 30
3
+
4
+ function truncate(val: unknown): string {
5
+ const s = val === null || val === undefined ? 'null' : String(val)
6
+ return s.length > MAX_CELL_LEN ? s.slice(0, MAX_CELL_LEN) + '…' : s
7
+ }
8
+
9
+ interface ResultsTableProps {
10
+ rows: Record<string, unknown>[]
11
+ rowCount: number
12
+ }
13
+
14
+ export function ResultsTable({ rows, rowCount }: ResultsTableProps) {
15
+ if (rows.length === 0) {
16
+ return (
17
+ <div className="text-xs text-gray-500 italic px-3 py-2 border border-white/[0.06] rounded-xl">
18
+ No rows returned.
19
+ </div>
20
+ )
21
+ }
22
+
23
+ const columns = Object.keys(rows[0])
24
+ const displayRows = rows.slice(0, MAX_ROWS)
25
+
26
+ return (
27
+ <div className="overflow-auto max-h-60 rounded-xl border border-white/[0.06]" style={{ fontSize: 11 }}>
28
+ {rowCount > MAX_ROWS && (
29
+ <div className="px-3 py-1 text-[10px] text-amber-400/70 bg-amber-500/5 border-b border-amber-500/10 shrink-0">
30
+ Showing {MAX_ROWS} of {rowCount} rows
31
+ </div>
32
+ )}
33
+ <table className="w-full font-mono border-collapse">
34
+ <thead>
35
+ <tr
36
+ className="border-b border-white/[0.06] sticky top-0"
37
+ style={{ background: 'var(--bg-tertiary)' }}
38
+ >
39
+ {columns.map((col) => (
40
+ <th
41
+ key={col}
42
+ className="px-3 py-1.5 text-left text-[10px] font-semibold text-gray-500 uppercase tracking-wider whitespace-nowrap"
43
+ >
44
+ {col}
45
+ </th>
46
+ ))}
47
+ </tr>
48
+ </thead>
49
+ <tbody>
50
+ {displayRows.map((row, i) => (
51
+ <tr
52
+ key={i}
53
+ className="border-b border-white/[0.03] hover:bg-white/[0.02] transition-colors"
54
+ >
55
+ {columns.map((col) => (
56
+ <td
57
+ key={col}
58
+ className={`px-3 py-1.5 whitespace-nowrap ${
59
+ row[col] === null ? 'text-gray-600 italic' : 'text-gray-300'
60
+ }`}
61
+ title={row[col] !== null ? String(row[col]) : undefined}
62
+ >
63
+ {truncate(row[col])}
64
+ </td>
65
+ ))}
66
+ </tr>
67
+ ))}
68
+ </tbody>
69
+ </table>
70
+ <div
71
+ className="px-3 py-1 text-[10px] text-gray-600 border-t border-white/[0.04]"
72
+ style={{ background: 'var(--bg-tertiary)' }}
73
+ >
74
+ Showing {displayRows.length} of {rowCount} rows
75
+ </div>
76
+ </div>
77
+ )
78
+ }
frontend/src/components/RightSidebar.tsx ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { Zap, Brain } from 'lucide-react'
2
+ import { PromptEvolution } from './PromptEvolution'
3
+ import { PerformanceGraph } from './PerformanceGraph'
4
+
5
+ export function RightSidebar() {
6
+ return (
7
+ <div className="flex flex-col h-full overflow-y-auto">
8
+ {/* GEPA Section */}
9
+ <div className="p-4 border-b border-white/[0.06] shrink-0">
10
+ <div className="flex items-center gap-2 text-[10px] font-semibold text-gray-500 uppercase tracking-widest mb-3">
11
+ <Brain size={10} className="text-violet-400" />
12
+ GEPA Prompt Evolution
13
+ </div>
14
+ <PromptEvolution />
15
+ </div>
16
+
17
+ {/* RL Charts */}
18
+ <div className="p-4 flex-1 overflow-y-auto">
19
+ <div className="flex items-center gap-2 text-[10px] font-semibold text-gray-500 uppercase tracking-widest mb-3">
20
+ <Zap size={10} className="text-violet-400" />
21
+ RL Learning Progress
22
+ </div>
23
+ <PerformanceGraph />
24
+ </div>
25
+ </div>
26
+ )
27
+ }
frontend/src/index.css ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @tailwind base;
2
+ @tailwind components;
3
+ @tailwind utilities;
4
+
5
+ /* ─── Theme Variables ──────────────────────────────────────────── */
6
+
7
+ :root {
8
+ --bg-primary: #08080d;
9
+ --bg-secondary: #09090f;
10
+ --bg-tertiary: #0a0a12;
11
+ --bg-card: #0e0e16;
12
+ --bg-input: rgba(255, 255, 255, 0.03);
13
+ --bg-hover: rgba(255, 255, 255, 0.02);
14
+ --bg-hover-strong: rgba(255, 255, 255, 0.05);
15
+
16
+ --text-primary: #ffffff;
17
+ --text-secondary: rgba(255, 255, 255, 0.7);
18
+ --text-muted: #6b7280;
19
+ --text-dim: #4b5563;
20
+
21
+ --border-color: rgba(255, 255, 255, 0.06);
22
+ --border-hover: rgba(255, 255, 255, 0.12);
23
+
24
+ --accent-violet: #8b5cf6;
25
+ --accent-green: #22c55e;
26
+ --accent-orange: #f97316;
27
+ --accent-red: #ef4444;
28
+ --accent-blue: #3b82f6;
29
+
30
+ --background: var(--bg-primary);
31
+ --foreground: var(--text-primary);
32
+ }
33
+
34
+ [data-theme="light"] {
35
+ --bg-primary: #f5f6f8;
36
+ --bg-secondary: #ffffff;
37
+ --bg-tertiary: #eef0f3;
38
+ --bg-card: #ffffff;
39
+ --bg-input: rgba(0, 0, 0, 0.04);
40
+ --bg-hover: rgba(0, 0, 0, 0.02);
41
+ --bg-hover-strong: rgba(0, 0, 0, 0.05);
42
+
43
+ --text-primary: #111827;
44
+ --text-secondary: #374151;
45
+ --text-muted: #6b7280;
46
+ --text-dim: #9ca3af;
47
+
48
+ --border-color: rgba(0, 0, 0, 0.1);
49
+ --border-hover: rgba(0, 0, 0, 0.2);
50
+
51
+ --background: var(--bg-primary);
52
+ --foreground: var(--text-primary);
53
+ }
54
+
55
+ /* ─── Base ─────────────────────────────────────────────────────── */
56
+
57
+ * {
58
+ box-sizing: border-box;
59
+ margin: 0;
60
+ padding: 0;
61
+ }
62
+
63
+ html,
64
+ body,
65
+ #root {
66
+ height: 100%;
67
+ width: 100%;
68
+ }
69
+
70
+ body {
71
+ background: var(--bg-primary);
72
+ color: var(--text-primary);
73
+ font-family: ui-monospace, 'SF Mono', Consolas, 'Liberation Mono', monospace;
74
+ -webkit-font-smoothing: antialiased;
75
+ -moz-osx-font-smoothing: grayscale;
76
+ }
77
+
78
+ /* ─── Theme Utility Classes ──────────────────────────────────────── */
79
+
80
+ .theme-bg-primary { background-color: var(--bg-primary) !important; }
81
+ .theme-bg-secondary { background-color: var(--bg-secondary) !important; }
82
+ .theme-bg-tertiary { background-color: var(--bg-tertiary) !important; }
83
+ .theme-bg-card { background-color: var(--bg-card) !important; }
84
+ .theme-text-primary { color: var(--text-primary) !important; }
85
+ .theme-text-secondary { color: var(--text-secondary) !important; }
86
+ .theme-text-muted { color: var(--text-muted) !important; }
87
+ .theme-border { border-color: var(--border-color) !important; }
88
+ .theme-border-hover { border-color: var(--border-hover) !important; }
89
+
90
+ /* ─── Scrollbars ─────────────────────────────────────────────────── */
91
+
92
+ .scrollbar-none {
93
+ -ms-overflow-style: none;
94
+ scrollbar-width: none;
95
+ }
96
+ .scrollbar-none::-webkit-scrollbar {
97
+ display: none;
98
+ }
99
+
100
+ ::-webkit-scrollbar {
101
+ width: 4px;
102
+ height: 4px;
103
+ }
104
+ ::-webkit-scrollbar-track {
105
+ background: transparent;
106
+ }
107
+ ::-webkit-scrollbar-thumb {
108
+ background: rgba(255, 255, 255, 0.1);
109
+ border-radius: 4px;
110
+ }
111
+ [data-theme="light"] ::-webkit-scrollbar-thumb {
112
+ background: rgba(0, 0, 0, 0.15);
113
+ }
114
+
115
+ /* ─── SQL Syntax Highlighting ───────────────────────────────────── */
116
+
117
+ .sql-keyword { color: #a78bfa; font-weight: 600; }
118
+ .sql-function { color: #60a5fa; }
119
+ .sql-string { color: #34d399; }
120
+ .sql-number { color: #f97316; }
121
+ .sql-comment { color: #6b7280; font-style: italic; }
122
+ .sql-operator { color: #e5e7eb; }
123
+
124
+ /* ─── Blinking cursor ──────────────────────────────────────────── */
125
+
126
+ @keyframes blink {
127
+ 0%, 100% { opacity: 1; }
128
+ 50% { opacity: 0; }
129
+ }
130
+ .cursor-blink {
131
+ display: inline-block;
132
+ width: 2px;
133
+ height: 1em;
134
+ background: currentColor;
135
+ animation: blink 1s step-end infinite;
136
+ vertical-align: text-bottom;
137
+ margin-left: 1px;
138
+ }
139
+
140
+ /* ─── Reward pulse animation ─────────────────────────────────────── */
141
+
142
+ @keyframes rewardPulse {
143
+ 0% { transform: scale(1); opacity: 0.7; }
144
+ 50% { transform: scale(1.15); opacity: 1; }
145
+ 100% { transform: scale(1); opacity: 1; }
146
+ }
147
+ .reward-pulse {
148
+ animation: rewardPulse 0.5s ease-out;
149
+ }
150
+
151
+ /* ─── Optimizing banner ──────────────────────────────────────────── */
152
+
153
+ @keyframes shimmer {
154
+ 0% { background-position: -200% 0; }
155
+ 100% { background-position: 200% 0; }
156
+ }
157
+ .shimmer-banner {
158
+ background: linear-gradient(
159
+ 90deg,
160
+ rgba(139, 92, 246, 0.15) 0%,
161
+ rgba(139, 92, 246, 0.3) 50%,
162
+ rgba(139, 92, 246, 0.15) 100%
163
+ );
164
+ background-size: 200% 100%;
165
+ animation: shimmer 2s linear infinite;
166
+ }
167
+
168
+ /* ─── Light Mode Global Overrides ─────────────────────────────── */
169
+
170
+ [data-theme="light"] .text-white { color: var(--text-primary) !important; }
171
+ [data-theme="light"] .text-white\/70 { color: var(--text-secondary) !important; }
172
+ [data-theme="light"] .text-gray-200 { color: #1f2937 !important; }
173
+ [data-theme="light"] .text-gray-300 { color: #374151 !important; }
174
+ [data-theme="light"] .text-gray-400 { color: #4b5563 !important; }
175
+ [data-theme="light"] .text-gray-500 { color: #6b7280 !important; }
176
+ [data-theme="light"] .text-gray-600 { color: #9ca3af !important; }
177
+ [data-theme="light"] .text-violet-300 { color: #7c3aed !important; }
178
+ [data-theme="light"] .text-violet-400 { color: #7c3aed !important; }
179
+ [data-theme="light"] .text-green-400 { color: #15803d !important; }
180
+ [data-theme="light"] .text-red-400 { color: #b91c1c !important; }
181
+ [data-theme="light"] pre {
182
+ background-color: var(--bg-tertiary) !important;
183
+ color: #374151 !important;
184
+ }
185
+ [data-theme="light"] .recharts-cartesian-grid line {
186
+ stroke: rgba(0, 0, 0, 0.06) !important;
187
+ }
frontend/src/lib/api.ts ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import type { InitResponse, RLState, SchemaGraph, SSEEvent } from './types'
2
+
3
+ const BASE_URL: string = import.meta.env.VITE_API_URL ?? ''
4
+
5
+ async function* parseSSE(response: Response): AsyncGenerator<SSEEvent> {
6
+ const reader = response.body!.getReader()
7
+ const decoder = new TextDecoder()
8
+ let buffer = ''
9
+
10
+ while (true) {
11
+ const { done, value } = await reader.read()
12
+ if (done) break
13
+ buffer += decoder.decode(value, { stream: true })
14
+ const lines = buffer.split('\n')
15
+ buffer = lines.pop() ?? ''
16
+
17
+ for (const line of lines) {
18
+ if (!line.startsWith('data: ')) continue
19
+ const raw = line.slice(6).trim()
20
+ if (raw === '[DONE]') return
21
+ try {
22
+ yield JSON.parse(raw) as SSEEvent
23
+ } catch {
24
+ // ignore malformed lines
25
+ }
26
+ }
27
+ }
28
+ }
29
+
30
+ export async function* streamExecuteQuery(
31
+ question: string,
32
+ taskId: string
33
+ ): AsyncGenerator<SSEEvent> {
34
+ const res = await fetch(`${BASE_URL}/api/execute-query`, {
35
+ method: 'POST',
36
+ headers: { 'Content-Type': 'application/json' },
37
+ body: JSON.stringify({ question, task_id: taskId }),
38
+ })
39
+ if (!res.ok) {
40
+ throw new Error(`HTTP ${res.status}: ${res.statusText}`)
41
+ }
42
+ yield* parseSSE(res)
43
+ }
44
+
45
+ export async function* streamBenchmark(
46
+ taskId: string,
47
+ queryIds?: string[]
48
+ ): AsyncGenerator<SSEEvent> {
49
+ const body: Record<string, unknown> = { task_id: taskId }
50
+ if (queryIds) body.queryIds = queryIds
51
+
52
+ const res = await fetch(`${BASE_URL}/api/benchmark`, {
53
+ method: 'POST',
54
+ headers: { 'Content-Type': 'application/json' },
55
+ body: JSON.stringify(body),
56
+ })
57
+ if (!res.ok) {
58
+ throw new Error(`HTTP ${res.status}: ${res.statusText}`)
59
+ }
60
+ yield* parseSSE(res)
61
+ }
62
+
63
+ export async function fetchInit(): Promise<InitResponse> {
64
+ const res = await fetch(`${BASE_URL}/api/init`)
65
+ if (!res.ok) throw new Error(`HTTP ${res.status}`)
66
+ return res.json() as Promise<InitResponse>
67
+ }
68
+
69
+ export async function fetchRLState(): Promise<RLState> {
70
+ const res = await fetch(`${BASE_URL}/api/rl-state`)
71
+ if (!res.ok) throw new Error(`HTTP ${res.status}`)
72
+ return res.json() as Promise<RLState>
73
+ }
74
+
75
+ export async function fetchSchemaGraph(): Promise<SchemaGraph> {
76
+ const res = await fetch(`${BASE_URL}/api/schema-graph`)
77
+ if (!res.ok) throw new Error(`HTTP ${res.status}`)
78
+ return res.json() as Promise<SchemaGraph>
79
+ }
80
+
81
+ export async function submitFeedback(
82
+ question: string,
83
+ sql: string,
84
+ correct: boolean
85
+ ): Promise<void> {
86
+ await fetch(`${BASE_URL}/api/feedback`, {
87
+ method: 'POST',
88
+ headers: { 'Content-Type': 'application/json' },
89
+ body: JSON.stringify({ question, sql, correct }),
90
+ })
91
+ }
92
+
93
+ export async function fetchPromptHistory() {
94
+ const res = await fetch(`${BASE_URL}/api/prompt-history`)
95
+ if (!res.ok) throw new Error(`HTTP ${res.status}`)
96
+ return res.json()
97
+ }
frontend/src/lib/types.ts ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // ─── Chat Types ──────────────────────────────────────────────────
2
+
3
+ export type MessageStatus = 'streaming' | 'done' | 'error'
4
+ export type FeedbackType = 'correct' | 'wrong' | null
5
+
6
+ export interface AttemptStep {
7
+ attempt: number
8
+ sql: string
9
+ error?: string
10
+ action?: string
11
+ actionScore?: number
12
+ reward?: number
13
+ }
14
+
15
+ export interface ChatMessage {
16
+ id: string
17
+ question: string
18
+ status: MessageStatus
19
+ sql: string
20
+ rows: Record<string, unknown>[]
21
+ rowCount: number
22
+ errorMsg?: string
23
+ attempts: number
24
+ steps: AttemptStep[]
25
+ reward?: number
26
+ rlAction?: string
27
+ rlActionScore?: number
28
+ feedback: FeedbackType
29
+ feedbackSending?: boolean
30
+ promptGeneration: number
31
+ streamingCursor?: boolean
32
+ }
33
+
34
+ // ─── Benchmark Types ─────────────────────────────────────────────
35
+
36
+ export type BenchmarkStatus = 'pending' | 'running' | 'pass' | 'fail'
37
+ export type Difficulty = 'easy' | 'medium' | 'hard'
38
+
39
+ export interface BenchmarkQuery {
40
+ id: string
41
+ question: string
42
+ difficulty: Difficulty
43
+ }
44
+
45
+ export interface BenchmarkResult {
46
+ id: string
47
+ question: string
48
+ difficulty: Difficulty
49
+ status: BenchmarkStatus
50
+ score: number | null
51
+ sql: string | null
52
+ reason: string | null
53
+ attempts: number | null
54
+ refRowCount: number | null
55
+ agentRowCount: number | null
56
+ }
57
+
58
+ // ─── RL State ────────────────────────────────────────────────────
59
+
60
+ export interface RLEpisode {
61
+ episode: number
62
+ totalReward: number
63
+ successRate: number
64
+ }
65
+
66
+ export interface ActionCount {
67
+ action: string
68
+ count: number
69
+ }
70
+
71
+ export interface RLState {
72
+ totalEpisodes: number
73
+ successRate: number
74
+ currentAlpha: number
75
+ episodes: RLEpisode[]
76
+ actionDistribution: ActionCount[]
77
+ currentGeneration: number
78
+ }
79
+
80
+ // ─── GEPA / Prompt ───────────────────────────────────────────────
81
+
82
+ export interface PromptSnapshot {
83
+ generation: number
84
+ prompt: string
85
+ score: number
86
+ summary: string
87
+ timestamp: string
88
+ }
89
+
90
+ // ─── Schema ──────────────────────────────────────────────────────
91
+
92
+ export interface TableInfo {
93
+ name: string
94
+ rows: number
95
+ }
96
+
97
+ export interface ColumnInfo {
98
+ name: string
99
+ type: string
100
+ pk?: boolean
101
+ fk?: string
102
+ }
103
+
104
+ export interface SchemaTable {
105
+ name: string
106
+ columns: ColumnInfo[]
107
+ }
108
+
109
+ export interface SchemaRelationship {
110
+ from: string
111
+ fromCol: string
112
+ to: string
113
+ toCol: string
114
+ }
115
+
116
+ export interface SchemaGraph {
117
+ tables: SchemaTable[]
118
+ relationships: SchemaRelationship[]
119
+ }
120
+
121
+ // ─── API Response Types ──────────────────────────────────────────
122
+
123
+ export interface InitResponse {
124
+ seeded: boolean
125
+ tables: TableInfo[]
126
+ }
127
+
128
+ export interface SSEEvent {
129
+ type: string
130
+ [key: string]: unknown
131
+ }
frontend/src/main.tsx ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React from 'react'
2
+ import ReactDOM from 'react-dom/client'
3
+ import App from './App'
4
+ import './index.css'
5
+
6
+ // Restore persisted theme
7
+ try {
8
+ const saved = localStorage.getItem('theme') as 'dark' | 'light' | null
9
+ if (saved) document.documentElement.setAttribute('data-theme', saved)
10
+ else document.documentElement.setAttribute('data-theme', 'dark')
11
+ } catch {
12
+ document.documentElement.setAttribute('data-theme', 'dark')
13
+ }
14
+
15
+ ReactDOM.createRoot(document.getElementById('root')!).render(
16
+ <React.StrictMode>
17
+ <App />
18
+ </React.StrictMode>
19
+ )
frontend/src/store/useStore.ts ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { create } from 'zustand'
2
+ import type {
3
+ ChatMessage,
4
+ BenchmarkResult,
5
+ RLState,
6
+ TableInfo,
7
+ SchemaGraph,
8
+ PromptSnapshot,
9
+ Difficulty,
10
+ } from '../lib/types'
11
+
12
+ interface Store {
13
+ // Theme
14
+ theme: 'dark' | 'light'
15
+ toggleTheme: () => void
16
+
17
+ // Task
18
+ taskId: string
19
+ taskDifficulty: Difficulty
20
+ setTaskId: (id: string) => void
21
+ setTaskDifficulty: (d: Difficulty) => void
22
+
23
+ // Init / DB
24
+ dbSeeded: boolean
25
+ setDbSeeded: (v: boolean) => void
26
+ tables: TableInfo[]
27
+ setTables: (tables: TableInfo[]) => void
28
+ schemaGraph: SchemaGraph | null
29
+ setSchemaGraph: (g: SchemaGraph) => void
30
+
31
+ // Chat
32
+ messages: ChatMessage[]
33
+ addMessage: (msg: ChatMessage) => void
34
+ updateMessage: (id: string, update: Partial<ChatMessage>) => void
35
+ clearMessages: () => void
36
+ isExecuting: boolean
37
+ setIsExecuting: (v: boolean) => void
38
+ optimizingBanner: boolean
39
+ setOptimizingBanner: (v: boolean) => void
40
+
41
+ // Benchmark
42
+ benchmarkResults: BenchmarkResult[]
43
+ setBenchmarkResults: (r: BenchmarkResult[]) => void
44
+ updateBenchmarkResult: (r: BenchmarkResult) => void
45
+ resetBenchmark: () => void
46
+ isBenchmarking: boolean
47
+ setIsBenchmarking: (v: boolean) => void
48
+ activeBenchmarkId: string | null
49
+ setActiveBenchmarkId: (id: string | null) => void
50
+ overallScore: number | null
51
+ setOverallScore: (s: number) => void
52
+
53
+ // RL State
54
+ rlState: RLState | null
55
+ setRlState: (s: RLState) => void
56
+
57
+ // GEPA / Prompt
58
+ currentPrompt: string
59
+ promptGeneration: number
60
+ promptHistory: PromptSnapshot[]
61
+ setPromptData: (data: { prompt: string; generation: number; history: PromptSnapshot[] }) => void
62
+ }
63
+
64
+ const EASY_QUERIES: BenchmarkResult[] = [
65
+ { id: 'E1', question: 'Show all products', difficulty: 'easy', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
66
+ { id: 'E2', question: 'List all users from the USA', difficulty: 'easy', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
67
+ { id: 'E3', question: 'What product categories exist?', difficulty: 'easy', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
68
+ { id: 'E4', question: 'How many orders are in the database?', difficulty: 'easy', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
69
+ { id: 'E5', question: 'Show all sellers with their names', difficulty: 'easy', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
70
+ ]
71
+
72
+ const MEDIUM_QUERIES: BenchmarkResult[] = [
73
+ { id: 'M1', question: 'Top 5 sellers by total revenue', difficulty: 'medium', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
74
+ { id: 'M2', question: 'Average order value by country', difficulty: 'medium', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
75
+ { id: 'M3', question: 'Products with stock below 10 units', difficulty: 'medium', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
76
+ { id: 'M4', question: 'Monthly order count for the last 12 months', difficulty: 'medium', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
77
+ { id: 'M5', question: 'Categories ranked by number of products', difficulty: 'medium', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
78
+ ]
79
+
80
+ const HARD_QUERIES: BenchmarkResult[] = [
81
+ { id: 'H1', question: 'Rolling 7-day revenue for the past 30 days', difficulty: 'hard', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
82
+ { id: 'H2', question: 'Seller ranking with rank change from previous month', difficulty: 'hard', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
83
+ { id: 'H3', question: 'Cohort retention analysis by signup month', difficulty: 'hard', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
84
+ { id: 'H4', question: 'Identify top products contributing to 80% of revenue (Pareto)', difficulty: 'hard', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
85
+ { id: 'H5', question: 'Customer lifetime value segmented by acquisition channel', difficulty: 'hard', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
86
+ ]
87
+
88
+ export const useStore = create<Store>((set) => ({
89
+ // Theme
90
+ theme: 'dark',
91
+ toggleTheme: () =>
92
+ set((s) => {
93
+ const next = s.theme === 'dark' ? 'light' : 'dark'
94
+ document.documentElement.setAttribute('data-theme', next)
95
+ try { localStorage.setItem('theme', next) } catch { /* noop */ }
96
+ return { theme: next }
97
+ }),
98
+
99
+ // Task
100
+ taskId: 'easy',
101
+ taskDifficulty: 'easy',
102
+ setTaskId: (id) => set({ taskId: id }),
103
+ setTaskDifficulty: (d) =>
104
+ set({
105
+ taskDifficulty: d,
106
+ taskId: d,
107
+ benchmarkResults:
108
+ d === 'easy' ? EASY_QUERIES : d === 'medium' ? MEDIUM_QUERIES : HARD_QUERIES,
109
+ overallScore: null,
110
+ }),
111
+
112
+ // Init
113
+ dbSeeded: false,
114
+ setDbSeeded: (v) => set({ dbSeeded: v }),
115
+ tables: [],
116
+ setTables: (tables) => set({ tables }),
117
+ schemaGraph: null,
118
+ setSchemaGraph: (g) => set({ schemaGraph: g }),
119
+
120
+ // Chat
121
+ messages: [],
122
+ addMessage: (msg) => set((s) => ({ messages: [...s.messages, msg] })),
123
+ updateMessage: (id, update) =>
124
+ set((s) => ({
125
+ messages: s.messages.map((m) => (m.id === id ? { ...m, ...update } : m)),
126
+ })),
127
+ clearMessages: () => set({ messages: [] }),
128
+ isExecuting: false,
129
+ setIsExecuting: (v) => set({ isExecuting: v }),
130
+ optimizingBanner: false,
131
+ setOptimizingBanner: (v) => set({ optimizingBanner: v }),
132
+
133
+ // Benchmark
134
+ benchmarkResults: EASY_QUERIES,
135
+ setBenchmarkResults: (r) => set({ benchmarkResults: r }),
136
+ updateBenchmarkResult: (r) =>
137
+ set((s) => ({
138
+ benchmarkResults: s.benchmarkResults.map((br) => (br.id === r.id ? r : br)),
139
+ })),
140
+ resetBenchmark: () =>
141
+ set((s) => ({
142
+ benchmarkResults: s.benchmarkResults.map((r) => ({
143
+ ...r,
144
+ status: 'pending' as const,
145
+ score: null,
146
+ sql: null,
147
+ reason: null,
148
+ attempts: null,
149
+ refRowCount: null,
150
+ agentRowCount: null,
151
+ })),
152
+ overallScore: null,
153
+ })),
154
+ isBenchmarking: false,
155
+ setIsBenchmarking: (v) => set({ isBenchmarking: v }),
156
+ activeBenchmarkId: null,
157
+ setActiveBenchmarkId: (id) => set({ activeBenchmarkId: id }),
158
+ overallScore: null,
159
+ setOverallScore: (s) => set({ overallScore: s }),
160
+
161
+ // RL State
162
+ rlState: null,
163
+ setRlState: (s) => set({ rlState: s }),
164
+
165
+ // GEPA
166
+ currentPrompt: '',
167
+ promptGeneration: 0,
168
+ promptHistory: [],
169
+ setPromptData: (data) =>
170
+ set({
171
+ currentPrompt: data.prompt,
172
+ promptGeneration: data.generation,
173
+ promptHistory: data.history,
174
+ }),
175
+ }))
frontend/src/vite-env.d.ts ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ /// <reference types="vite/client" />
2
+
3
+ interface ImportMetaEnv {
4
+ readonly VITE_API_URL?: string
5
+ }
6
+
7
+ interface ImportMeta {
8
+ readonly env: ImportMetaEnv
9
+ }
frontend/tailwind.config.js ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /** @type {import('tailwindcss').Config} */
2
+ export default {
3
+ content: [
4
+ './index.html',
5
+ './src/**/*.{js,ts,jsx,tsx}',
6
+ ],
7
+ theme: {
8
+ extend: {
9
+ colors: {
10
+ 'bg-primary': '#08080d',
11
+ 'bg-secondary': '#09090f',
12
+ 'bg-card': '#0e0e16',
13
+ },
14
+ fontFamily: {
15
+ mono: ['ui-monospace', '"SF Mono"', 'Consolas', '"Liberation Mono"', 'monospace'],
16
+ },
17
+ },
18
+ },
19
+ plugins: [],
20
+ }
frontend/tsconfig.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "compilerOptions": {
3
+ "target": "ES2020",
4
+ "useDefineForClassFields": true,
5
+ "lib": ["ES2020", "DOM", "DOM.Iterable"],
6
+ "module": "ESNext",
7
+ "skipLibCheck": true,
8
+ "moduleResolution": "bundler",
9
+ "allowImportingTsExtensions": true,
10
+ "resolveJsonModule": true,
11
+ "isolatedModules": true,
12
+ "noEmit": true,
13
+ "jsx": "react-jsx",
14
+ "strict": true,
15
+ "noUnusedLocals": true,
16
+ "noUnusedParameters": true,
17
+ "noFallthroughCasesInSwitch": true,
18
+ "baseUrl": ".",
19
+ "paths": {
20
+ "@/*": ["./src/*"]
21
+ }
22
+ },
23
+ "include": ["src"]
24
+ }
frontend/vite.config.ts ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { defineConfig } from 'vite'
2
+ import react from '@vitejs/plugin-react'
3
+ import path from 'path'
4
+
5
+ export default defineConfig({
6
+ plugins: [react()],
7
+ resolve: {
8
+ alias: {
9
+ '@': path.resolve(__dirname, './src'),
10
+ },
11
+ },
12
+ server: {
13
+ port: 5173,
14
+ proxy: {
15
+ '/api': {
16
+ target: 'http://localhost:8000',
17
+ changeOrigin: true,
18
+ },
19
+ '/env': {
20
+ target: 'http://localhost:8000',
21
+ changeOrigin: true,
22
+ },
23
+ },
24
+ },
25
+ build: {
26
+ outDir: 'dist',
27
+ },
28
+ })
inference.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SQL Agent OpenEnv β€” Baseline Inference Script
3
+ ==============================================
4
+
5
+ Runs a baseline LLM agent against all 3 tasks of the SQL Agent OpenEnv environment.
6
+
7
+ Environment variables (required):
8
+ API_BASE_URL β€” OpenAI-compatible base URL (default: https://router.huggingface.co/v1)
9
+ MODEL_NAME β€” Model identifier (default: Qwen/Qwen2.5-72B-Instruct)
10
+ HF_TOKEN β€” Hugging Face / API key
11
+
12
+ STDOUT format (strictly enforced):
13
+ [START] task=<task_id> env=sql-agent-openenv model=<model>
14
+ [STEP] step=<n> action=<action> reward=<0.00> done=<true|false> error=<msg|null>
15
+ [END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...>
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import asyncio
21
+ import os
22
+ import sys
23
+ import textwrap
24
+ from typing import List, Optional
25
+
26
+ # ── Path setup (inference.py lives at repo root; backend is a subdirectory) ──
27
+ _BACKEND = os.path.join(os.path.dirname(os.path.abspath(__file__)), "backend")
28
+ if _BACKEND not in sys.path:
29
+ sys.path.insert(0, _BACKEND)
30
+
31
+ from openai import OpenAI # noqa: E402
32
+
33
+ from env.sql_env import SQLAgentEnv, Action, Observation # noqa: E402
34
+
35
+ # ── Config ────────────────────────────────────────────────────────────────────
36
+
37
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY", "")
38
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
39
+ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
40
+ BENCHMARK = "sql-agent-openenv"
41
+
42
+ TASKS = ["simple_queries", "join_queries", "complex_queries"]
43
+ MAX_STEPS = 5
44
+ TEMPERATURE = 0.2
45
+ MAX_TOKENS = 50
46
+
47
+ REPAIR_ACTIONS = [
48
+ "rewrite_full",
49
+ "fix_column",
50
+ "fix_table",
51
+ "add_groupby",
52
+ "rewrite_cte",
53
+ "fix_syntax",
54
+ "change_dialect",
55
+ "relax_filter",
56
+ ]
57
+
58
+ SYSTEM_PROMPT = textwrap.dedent("""
59
+ You are an expert SQL agent interacting with a SQL repair environment.
60
+
61
+ At each step you receive a natural language question, a database schema,
62
+ and optionally the last SQL attempt + error message.
63
+
64
+ Your job: pick ONE repair action from the list below that is most likely
65
+ to fix the SQL error on the next attempt.
66
+
67
+ Available actions:
68
+ generate β€” write fresh SQL from scratch (use on first attempt)
69
+ rewrite_full β€” completely rewrite the query from scratch
70
+ fix_column β€” fix wrong column name references
71
+ fix_table β€” fix wrong table name references
72
+ add_groupby β€” add or fix GROUP BY / aggregation clauses
73
+ rewrite_cte β€” restructure subqueries or CTEs
74
+ fix_syntax β€” fix syntax errors (brackets, commas, keywords)
75
+ change_dialect β€” convert to SQLite-compatible functions
76
+ relax_filter β€” broaden or remove overly strict WHERE conditions
77
+
78
+ Reply with ONLY the action name. No explanation. No punctuation.
79
+ Example: fix_column
80
+ """).strip()
81
+
82
+
83
+ # ── Logging ───────────────────────────────────────────────────────────────────
84
+
85
+ def log_start(task: str, model: str) -> None:
86
+ print(f"[START] task={task} env={BENCHMARK} model={model}", flush=True)
87
+
88
+
89
+ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
90
+ error_val = error.replace("\n", " ").strip() if error else "null"
91
+ done_val = str(done).lower()
92
+ print(
93
+ f"[STEP] step={step} action={action} reward={reward:.2f} "
94
+ f"done={done_val} error={error_val}",
95
+ flush=True,
96
+ )
97
+
98
+
99
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
100
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
101
+ print(
102
+ f"[END] success={str(success).lower()} steps={steps} "
103
+ f"score={score:.3f} rewards={rewards_str}",
104
+ flush=True,
105
+ )
106
+
107
+
108
+ # ── LLM helper ────────────────────────────────────────────────────────────────
109
+
110
+ def pick_action(
111
+ client: OpenAI,
112
+ obs: Observation,
113
+ step: int,
114
+ ) -> str:
115
+ """Ask the LLM to pick a repair action given the current observation."""
116
+ if step == 1 or obs.current_sql is None:
117
+ return "generate"
118
+
119
+ user_msg = textwrap.dedent(f"""
120
+ Question: {obs.question}
121
+
122
+ Current SQL (failed):
123
+ {obs.current_sql}
124
+
125
+ Error: {obs.error_message or "unknown"}
126
+ Error class: {obs.error_class or "unknown"}
127
+ Attempt number: {obs.attempt_number} of {obs.max_attempts}
128
+
129
+ Which repair action should I use next?
130
+ """).strip()
131
+
132
+ try:
133
+ completion = client.chat.completions.create(
134
+ model=MODEL_NAME,
135
+ messages=[
136
+ {"role": "system", "content": SYSTEM_PROMPT},
137
+ {"role": "user", "content": user_msg},
138
+ ],
139
+ temperature=TEMPERATURE,
140
+ max_tokens=MAX_TOKENS,
141
+ )
142
+ raw = (completion.choices[0].message.content or "").strip().lower()
143
+ # Normalise to valid action name
144
+ for action in REPAIR_ACTIONS:
145
+ if action in raw:
146
+ return action
147
+ return "rewrite_full"
148
+ except Exception as exc:
149
+ print(f"[DEBUG] LLM call failed: {exc}", flush=True)
150
+ return "rewrite_full"
151
+
152
+
153
+ # ── Single-episode runner ─────────────────────────────────────────────────────
154
+
155
+ async def run_episode(
156
+ env: SQLAgentEnv,
157
+ client: OpenAI,
158
+ task_id: str,
159
+ ) -> None:
160
+ """Run one full episode for a task, emitting structured stdout logs."""
161
+ log_start(task=task_id, model=MODEL_NAME)
162
+
163
+ rewards: List[float] = []
164
+ steps_taken = 0
165
+ score = 0.0
166
+ success = False
167
+ last_error: Optional[str] = None
168
+
169
+ try:
170
+ obs = env.reset(task_id)
171
+
172
+ for step in range(1, MAX_STEPS + 1):
173
+ action_name = pick_action(client, obs, step)
174
+ action = Action(repair_action=action_name)
175
+
176
+ try:
177
+ obs, reward_info = await env.step(action)
178
+ except RuntimeError as exc:
179
+ log_step(step=step, action=action_name, reward=0.0, done=True, error=str(exc))
180
+ rewards.append(0.0)
181
+ steps_taken = step
182
+ break
183
+
184
+ reward = reward_info.value
185
+ done = reward_info.done
186
+ last_error = obs.error_message
187
+ success = reward_info.success
188
+
189
+ rewards.append(reward)
190
+ steps_taken = step
191
+
192
+ log_step(
193
+ step=step,
194
+ action=action_name,
195
+ reward=reward,
196
+ done=done,
197
+ error=last_error,
198
+ )
199
+
200
+ if done:
201
+ break
202
+
203
+ # Score: clamp sum of rewards to [0, 1]
204
+ total = sum(rewards)
205
+ max_possible = MAX_STEPS * 1.0 # max reward per step is 1.0
206
+ score = min(max(total / max_possible, 0.0), 1.0)
207
+
208
+ finally:
209
+ log_end(
210
+ success=success,
211
+ steps=steps_taken,
212
+ score=score,
213
+ rewards=rewards,
214
+ )
215
+
216
+
217
+ # ── Main ──────────────────────────────────────────────────────────────────────
218
+
219
+ async def main() -> None:
220
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
221
+ env = SQLAgentEnv()
222
+
223
+ for task_id in TASKS:
224
+ await run_episode(env, client, task_id)
225
+ # Small gap between tasks for readability
226
+ print("", flush=True)
227
+
228
+
229
+ if __name__ == "__main__":
230
+ asyncio.run(main())
openenv.yaml ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: sql-agent-openenv
2
+ version: "1.0.0"
3
+ description: >
4
+ A SQL generation and repair environment where an AI agent learns to write
5
+ correct SQL queries through a self-debugging loop powered by a LinUCB
6
+ contextual bandit and GEPA prompt evolution. Models real-world data analyst
7
+ workflows β€” querying databases with natural language, handling errors, and
8
+ iteratively improving.
9
+
10
+ author: sql-agent-openenv-team
11
+ tags:
12
+ - openenv
13
+ - sql
14
+ - rl
15
+ - nlp
16
+ - contextual-bandit
17
+
18
+ # ── Endpoints ────────────────────────────────────────────────────────────────
19
+ api:
20
+ reset: /reset
21
+ step: /step
22
+ state: /state
23
+
24
+ # ── Action Space ─────────────────────────────────────────────────────────────
25
+ action_space:
26
+ type: discrete
27
+ n: 9
28
+ actions:
29
+ - name: generate
30
+ description: "Generate SQL from scratch (first attempt)"
31
+ - name: rewrite_full
32
+ description: "Completely rewrite the query from scratch"
33
+ - name: fix_column
34
+ description: "Fix wrong column name references using schema"
35
+ - name: fix_table
36
+ description: "Fix wrong table name references or JOIN structure"
37
+ - name: add_groupby
38
+ description: "Add or fix GROUP BY / aggregation clauses"
39
+ - name: rewrite_cte
40
+ description: "Restructure CTEs or subqueries"
41
+ - name: fix_syntax
42
+ description: "Fix syntax errors (brackets, commas, keywords)"
43
+ - name: change_dialect
44
+ description: "Convert to SQLite-compatible functions"
45
+ - name: relax_filter
46
+ description: "Broaden or remove overly strict WHERE conditions"
47
+
48
+ # ── Observation Space ────────────────────────────────────────────────────────
49
+ observation_space:
50
+ type: dict
51
+ fields:
52
+ - name: question
53
+ type: string
54
+ description: "Natural language question to answer with SQL"
55
+ - name: schema_info
56
+ type: string
57
+ description: "Full database schema (tables, columns, types, FK relationships)"
58
+ - name: current_sql
59
+ type: string
60
+ nullable: true
61
+ description: "The SQL generated on the last attempt (null on first step)"
62
+ - name: error_message
63
+ type: string
64
+ nullable: true
65
+ description: "SQLite error message from the last attempt (null on success)"
66
+ - name: error_class
67
+ type: string
68
+ nullable: true
69
+ description: "Classified error type (e.g. no_such_column, syntax_error)"
70
+ - name: attempt_number
71
+ type: integer
72
+ description: "Current attempt number (0 at reset, increments each step)"
73
+ - name: max_attempts
74
+ type: integer
75
+ description: "Maximum allowed attempts per episode (5)"
76
+ - name: task_id
77
+ type: string
78
+ description: "Active task identifier"
79
+ - name: task_difficulty
80
+ type: string
81
+ description: "Task difficulty level: easy | medium | hard"
82
+
83
+ # ── Reward ───────────────────────────────────────────────────────────────────
84
+ reward:
85
+ range: [-1.5, 1.5]
86
+ description: >
87
+ Shaped reward providing partial progress signals throughout the episode.
88
+ Success on attempt N: 1.0 - 0.1*(N-1).
89
+ Failure step: -0.1 - 0.05*N + severity_improvement_bonus + error_class_change_bonus.
90
+ Penalizes infinite loops (consecutive same error) and rewards convergence toward correct SQL.
91
+
92
+ # ── Tasks ────────────────────────────────────────────────────────────────────
93
+ tasks:
94
+ - id: simple_queries
95
+ name: Simple SQL Queries
96
+ difficulty: easy
97
+ description: >
98
+ Single-table SELECT queries. Agent must retrieve correct rows by applying
99
+ basic filters and projections on the marketplace database.
100
+ question_count: 5
101
+ grader: >
102
+ Checks that required output columns are present and row count falls
103
+ within expected bounds. Attempt penalty not applied.
104
+
105
+ - id: join_queries
106
+ name: SQL Join Queries
107
+ difficulty: medium
108
+ description: >
109
+ Multi-table JOIN queries with GROUP BY and aggregation. Agent must
110
+ correctly join tables and compute aggregates over the marketplace data.
111
+ question_count: 5
112
+ grader: >
113
+ Correct columns + row count score multiplied by (1.0 - 0.1*(attempts-1)).
114
+ Rewards efficient, first-try solutions.
115
+
116
+ - id: complex_queries
117
+ name: Complex SQL Queries
118
+ difficulty: hard
119
+ description: >
120
+ Advanced queries using CTEs, window functions, nested aggregations, and
121
+ multi-level joins. Requires precise SQLite syntax knowledge.
122
+ question_count: 5
123
+ grader: >
124
+ Strict correctness required. Score capped at 0.8 without first-attempt
125
+ bonus. Attempt penalty of 0.1*(attempts-1) applied. Hard tasks genuinely
126
+ challenge frontier models.
127
+
128
+ # ── Environment Metadata ─────────────────────────────────────────────────────
129
+ metadata:
130
+ max_steps_per_episode: 5
131
+ database: SQLite (marketplace schema β€” users, products, orders, reviews, sellers)
132
+ rl_algorithm: LinUCB contextual bandit (feature_dim=20, 8 repair actions)
133
+ prompt_optimizer: GEPA (Generative Evolutionary Prompt Adaptation)
134
+ runtime_estimate_minutes: 5
135
+ compute_requirements:
136
+ vcpu: 2
137
+ memory_gb: 4