sai1912 committed
Commit 3411777 · verified · 1 Parent(s): c272fb3

Upload folder using huggingface_hub

Files changed (4)
  1. .gitignore +9 -0
  2. app.py +3 -1
  3. train_grpo.py +787 -0
  4. train_rl.md +92 -0
.gitignore ADDED
@@ -0,0 +1,9 @@
+ .env
+ __pycache__/
+ *.pyc
+ outputs/
+ *.log
+ .DS_Store
+ node_modules/
+ dist/
+ build/
app.py CHANGED
@@ -293,7 +293,9 @@ def _run_chaos_pipeline(con):
      )

  @app.post("/reset", tags=["Environment"])
- def reset_episode(req: ResetRequest):
+ def reset_episode(req: ResetRequest = None):
+     if req is None:
+         req = ResetRequest()
      task_id = req.task_id if req.task_id in TASKS else "task_1_easy"
      task = TASKS[task_id]
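The new default makes the request body optional: FastAPI treats a body parameter with a default as not required, so a bodiless POST to `/reset` no longer fails validation. The None-guard pattern itself can be sketched framework-free, with a plain dataclass standing in for the pydantic model (the `task_id` field and its fallback are taken from the diff above; everything else is illustrative):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ResetRequest:
    # Stand-in for the server's pydantic model; the "task_1_easy"
    # default mirrors the fallback in the endpoint body.
    task_id: str = "task_1_easy"

def reset_episode(req: Optional[ResetRequest] = None) -> str:
    # Same guard as the patched endpoint: a missing body falls back to a
    # default-constructed request instead of failing validation.
    if req is None:
        req = ResetRequest()
    return req.task_id

print(reset_episode())  # -> task_1_easy
print(reset_episode(ResetRequest(task_id="task_2_medium")))  # -> task_2_medium
```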
train_grpo.py ADDED
@@ -0,0 +1,787 @@
+ """
+ train_grpo.py — Full GRPO training pipeline for SQL Debug & Data Pipeline Repair
+ using Qwen/Qwen2.5-Coder-7B-Instruct + TRL GRPOTrainer.
+
+ Follows the Module 5 pattern from https://github.com/huggingface/openenv-course
+
+ Pipeline:
+   1. Init environment (local server or HF Space URL)
+   2. Init model & tokenizer (Qwen2.5-Coder-7B-Instruct)
+   3. Define system prompt (rules, response format, strategy, goal)
+   4. Helper functions (prompt builder, SQL extractor)
+   5. Rollout function (plays one full episode against the environment)
+   6. Reward functions (wrap our grader decomposition into TRL callbacks)
+   7. Create dataset (prompts for all 3 tasks × N variants)
+   8. Configure GRPO (GRPOConfig)
+   9. Create GRPOTrainer and train
+   10. Save & push to Hub
+   11. Evaluation loop
+
+ Usage:
+     # Local environment (start server first)
+     uvicorn server.app:app --host 0.0.0.0 --port 7860
+
+     # Training (single GPU A100/H100 recommended)
+     python train_grpo.py
+
+     # With HF Space
+     ENV_URL=https://your-username-sql-debug-env.hf.space python train_grpo.py
+
+ Requirements:
+     pip install "trl>=0.12.0" "transformers>=4.45.0" "torch>=2.3.0"
+     pip install duckdb pandas pydantic requests vllm  # for local env
+ """
+
+ from __future__ import annotations
+
+ import json
+ import os
+ import re
+ import sys
+ import time
+ from dataclasses import dataclass
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import torch
+ from datasets import Dataset
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from trl import GRPOConfig, GRPOTrainer
+
+ # ── Make local env importable ─────────────────────────────────────────────────
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+ from client import SQLDebugEnv
+ from models import SQLDebugAction, SQLDebugObservation
+ from server.data import TASKS
+
+
+ # =============================================================================
+ # 1. ENVIRONMENT SETUP
+ # =============================================================================
+
+ # Point to your deployed HF Space or local server
+ ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")
+
+ # For training we use the local Python environment directly (no HTTP round-trip).
+ # This is faster and avoids network latency during rollouts.
+ # Switch to SQLDebugEnv(ENV_URL) if you want to use the HTTP server.
+ USE_LOCAL_ENV = os.environ.get("USE_LOCAL_ENV", "true").lower() == "true"
+
+ if USE_LOCAL_ENV:
+     from server.environment import SQLDebugEnvironment
+     _SHARED_ENV = SQLDebugEnvironment()  # single instance, reset() per episode
+ else:
+     # HTTP client — point at your HF Space
+     _HTTP_CLIENT = SQLDebugEnv(base_url=ENV_URL)
+
+
+ def get_env():
+     """Return the environment handle (local or HTTP)."""
+     if USE_LOCAL_ENV:
+         return _SHARED_ENV
+     return _HTTP_CLIENT
+
+
+ # =============================================================================
+ # 2. MODEL & TOKENIZER
+ # =============================================================================
+
+ MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-Coder-7B-Instruct")
+ HF_REPO_ID = os.environ.get("HF_REPO_ID", "sai1912/sql-debug-qwen-grpo")
+
+ print(f"Loading tokenizer: {MODEL_NAME}")
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
+ if tokenizer.pad_token is None:
+     tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.padding_side = "left"  # Required for decoder-only models in GRPO
+
+
+ # =============================================================================
+ # 3. SYSTEM PROMPT — Rules, Response Format, Strategy, Goal
+ # =============================================================================
+
+ SYSTEM_PROMPT = """\
+ You are an expert SQL debugger and data engineer. Your goal is to diagnose \
+ and fix broken SQL queries and ETL pipelines.
+
+ RULES:
+ - Read the broken SQL or pipeline code carefully
+ - Study the schema — table names, column names, and types matter
+ - Look for: syntax errors, wrong aliases, wrong JOIN types, type casting bugs
+ - Your fix must produce exactly the correct output described in the task
+ - Never use DROP TABLE, DELETE, or TRUNCATE on real data tables
+ - Do not repeat the same query if it was already rejected
+
+ RESPONSE FORMAT:
+ Always respond with EXACTLY this structure (no deviation):
+
+ <think>
+ [Your step-by-step diagnosis of the bug. Be explicit about what is wrong and why.]
+ </think>
+
+ ```sql
+ [Your complete corrected SQL query here]
+ ```
+
+ EXPLANATION (Task 3 only):
+ [One sentence naming the root cause step and why it causes wrong results]
+
+ STRATEGY:
+ - Task 1 (easy): Look for syntax errors (missing commas) and wrong table aliases
+ - Task 2 (medium): Check JOIN types — INNER JOIN silently drops NULL-keyed rows
+ - Task 3 (hard): Trace the timezone handling — CAST(ts AS DATE) strips the offset
+
+ GOAL:
+ Return a corrected SQL query (Tasks 1/2) or corrected Python pipeline \
+ code (Task 3) that produces output matching the ground truth exactly.
+ """
+
+ # Task-specific addendum injected into user messages
+ TASK_HINTS = {
+     "task1_syntax_fix": (
+         "Hint: Check each line of the SELECT clause carefully. "
+         "Also verify every table alias used in JOIN conditions matches the FROM clause aliases."
+     ),
+     "task2_join_aggregation": (
+         "Hint: Consider what happens when a JOIN key is NULL. "
+         "INNER JOIN silently drops those rows — is that correct for this aggregation?"
+     ),
+     "task3_etl_timezone": (
+         "Hint: The timestamps include timezone offsets like '+05:30'. "
+         "What does CAST(ts AS DATE) do to that offset? "
+         "Which DuckDB type preserves timezone information?"
+     ),
+ }
+
+
+ # =============================================================================
+ # 4. HELPER FUNCTIONS
+ # =============================================================================
+
+ def build_user_message(obs: SQLDebugObservation) -> str:
+     """
+     Format an observation into a user-turn message.
+     Mirrors baseline.py but adds structured context for RL training.
+     """
+     # Schema block
+     schema_lines = []
+     for table, cols in obs.schema_info.items():
+         col_defs = ", ".join(f"{c['column']} {c['type']}" for c in cols)
+         schema_lines.append(f"  {table}({col_defs})")
+     schema_str = "\n".join(schema_lines)
+
+     # Code block
+     if obs.task_id == "task3_etl_timezone":
+         code_block = (
+             f"## Broken ETL Pipeline (Python/DuckDB)\n\n"
+             f"```python\n{obs.pipeline_code}\n```"
+         )
+         if obs.intermediate_outputs:
+             wrong_output = json.dumps(obs.intermediate_outputs[-1]["rows"][:3], indent=2, default=str)
+             code_block += (
+                 f"\n\n## Step 4 Wrong Output (first 3 rows)\n\n"
+                 f"```json\n{wrong_output}\n```"
+             )
+         response_instruction = (
+             "Return the COMPLETE corrected Python pipeline code in a "
+             "```python ... ``` block. Set EXPLANATION to name the buggy step."
+         )
+     else:
+         code_block = f"## Broken SQL Query\n\n```sql\n{obs.broken_sql}\n```"
+         response_instruction = "Return the corrected SQL inside a ```sql ... ``` block."
+
+     # Previous attempts
+     history = ""
+     if obs.previous_attempts:
+         lines = ["\n## Previous Attempts (learn from these)\n"]
+         for a in obs.previous_attempts:
+             verdict = "CORRECT" if a.reward >= 1.0 else f"reward={a.reward:.2f}"
+             preview = a.fixed_sql[:150].replace("\n", " ")
+             lines.append(f"  Attempt {a.step} [{verdict}]: {preview}...")
+         history = "\n".join(lines)
+
+     hint = TASK_HINTS.get(obs.task_id, "")
+
+     return (
+         f"## Task ({obs.difficulty.upper()}): {obs.task_id}\n\n"
+         f"{obs.task_description}\n\n"
+         f"## Database Schema\n\n{schema_str}\n\n"
+         f"{code_block}"
+         f"{history}\n\n"
+         f"## Instruction\n{response_instruction}\n\n"
+         f"{hint}"
+     ).strip()
+
+
+ def extract_sql_from_response(text: str, is_pipeline: bool = False) -> str:
+     """
+     Extract the SQL or Python code block from a model response.
+     Falls back to the raw text if no code block is found.
+     """
+     lang = "python" if is_pipeline else "sql"
+     patterns = [
+         rf"```{lang}\s*\n(.*?)```",
+         r"```\s*\n(.*?)```",
+         r"```(.*?)```",
+     ]
+     for pattern in patterns:
+         m = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
+         if m:
+             return m.group(1).strip()
+     return text.strip()
+
+
+ def extract_explanation(text: str) -> Optional[str]:
+     """Extract the EXPLANATION section (Task 3 root-cause scoring)."""
+     m = re.search(r"EXPLANATION[:\s]+(.*?)(?:```|$)", text, re.DOTALL | re.IGNORECASE)
+     if m:
+         return m.group(1).strip()[:300]
+     # Also check the think block for step identification
+     think_m = re.search(r"<think>(.*?)</think>", text, re.DOTALL | re.IGNORECASE)
+     if think_m:
+         return think_m.group(1).strip()[:300]
+     return None
+
+
+ def format_messages(obs: SQLDebugObservation) -> List[Dict[str, str]]:
+     """Build the chat message list for the model."""
+     return [
+         {"role": "system", "content": SYSTEM_PROMPT},
+         {"role": "user", "content": build_user_message(obs)},
+     ]
+
+
+ # =============================================================================
+ # 5. ROLLOUT FUNCTION
+ # =============================================================================
+
+ def generate_rollout_completions(trainer: GRPOTrainer, batch_messages: List[List[Dict]]) -> List[Dict]:
+     """
+     Generate completions with the current policy via model.generate().
+
+     Returns a list of dicts with keys: 'text', 'prompt_ids', 'completion_ids', 'logprobs'.
+     """
+     # Tokenize prompts
+     texts = [
+         tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
+         for msgs in batch_messages
+     ]
+     inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True,
+                        max_length=2048).to(trainer.model.device)
+
+     with torch.no_grad():
+         output_ids = trainer.model.generate(
+             **inputs,
+             max_new_tokens=1024,
+             temperature=0.8,
+             top_p=0.95,
+             do_sample=True,
+             pad_token_id=tokenizer.eos_token_id,
+         )
+
+     results = []
+     for prompt_ids, out_ids in zip(inputs["input_ids"], output_ids):
+         prompt_len = prompt_ids.shape[0]
+         completion_ids = out_ids[prompt_len:]
+         text = tokenizer.decode(completion_ids, skip_special_tokens=True)
+         results.append({
+             "text": text,
+             "prompt_ids": prompt_ids,
+             "completion_ids": completion_ids,
+             "logprobs": None,  # TRL computes logprobs internally
+         })
+     return results
+
+
+ def rollout_func(
+     trainer: GRPOTrainer,
+     batch: Dict[str, Any],
+     tokenizer: AutoTokenizer,
+ ) -> Dict[str, Any]:
+     """
+     TRL rollout function. Called by GRPOTrainer during training.
+
+     Plays one full episode per row in the batch:
+       1. reset() the environment for the task
+       2. Generate a fix with the current policy
+       3. step() the environment
+       4. Repeat up to max_turns (multi-turn RL)
+
+     Returns a batch-format dict that TRL expects.
+     """
+     env = get_env()
+     max_turns = 3  # 3 attempts per training episode (saves compute)
+
+     all_prompt_ids = []
+     all_completion_ids = []
+     all_rewards = []
+
+     task_ids: List[str] = batch["task_id"]
+
+     for task_id in task_ids:
+         # ── Episode start ──────────────────────────────────────────────────
+         if USE_LOCAL_ENV:
+             obs = env.reset(seed=42, task_id=task_id)
+         else:
+             obs = env.reset(task_id=task_id)
+
+         episode_prompt_ids = []
+         episode_completion_ids = []
+         episode_rewards = []
+         is_pipeline = (task_id == "task3_etl_timezone")
+         done = False
+
+         for turn in range(max_turns):
+             if done:
+                 break
+
+             messages = format_messages(obs)
+             completions = generate_rollout_completions(trainer, [messages])
+             completion = completions[0]
+
+             fixed_sql = extract_sql_from_response(completion["text"], is_pipeline=is_pipeline)
+             explanation = extract_explanation(completion["text"])
+
+             action = SQLDebugAction(fixed_sql=fixed_sql, explanation=explanation)
+
+             # step() has the same signature for the local and HTTP env
+             obs, reward, done, info = env.step(action)
+
+             episode_prompt_ids.append(
+                 tokenizer(
+                     tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True),
+                     return_tensors="pt",
+                 )["input_ids"][0]
+             )
+             episode_completion_ids.append(completion["completion_ids"])
+             episode_rewards.append(reward)
+
+         # Use the best reward in the episode as the final signal
+         best_reward = max(episode_rewards) if episode_rewards else 0.0
+         all_rewards.extend([best_reward] * len(episode_rewards))
+         all_prompt_ids.extend(episode_prompt_ids)
+         all_completion_ids.extend(episode_completion_ids)
+
+     # Pad sequences to the same length
+     max_prompt_len = max(t.shape[0] for t in all_prompt_ids)
+     max_comp_len = max(t.shape[0] for t in all_completion_ids)
+
+     padded_prompts = torch.stack([
+         torch.nn.functional.pad(t, (max_prompt_len - t.shape[0], 0), value=tokenizer.pad_token_id)
+         for t in all_prompt_ids
+     ])
+     padded_completions = torch.stack([
+         torch.nn.functional.pad(t, (0, max_comp_len - t.shape[0]), value=tokenizer.pad_token_id)
+         for t in all_completion_ids
+     ])
+
+     return {
+         "prompt_ids": padded_prompts,
+         "completion_ids": padded_completions,
+         "rewards": torch.tensor(all_rewards, dtype=torch.float32),
+     }
+
+
+ # =============================================================================
+ # 6. REWARD FUNCTIONS (TRL-style callbacks)
+ # =============================================================================
+ # TRL's GRPOTrainer can accept multiple reward_funcs. Each receives
+ # (completions, prompts, **kwargs) and returns a list of floats.
+ # We use our grader decomposition to provide multi-signal training.
+
+ def _run_grader(completion_text: str, task_id: str, is_pipeline: bool) -> Dict[str, float]:
+     """Run the environment grader and return a breakdown dict."""
+     import duckdb as _duckdb
+     from server.data import TASK_MAP
+     from server.graders import grade_task1, grade_task2, grade_task3
+
+     task = TASK_MAP[task_id]
+     con = _duckdb.connect(":memory:")
+     con.execute(task.schema_ddl)
+     con.execute(task.seed_sql)
+     gt_df = con.execute(task.ground_truth_query).df()
+
+     fixed = extract_sql_from_response(completion_text, is_pipeline=is_pipeline)
+     explanation = extract_explanation(completion_text)
+
+     try:
+         if task_id == "task1_syntax_fix":
+             score, breakdown = grade_task1(fixed, gt_df, con)
+         elif task_id == "task2_join_aggregation":
+             score, breakdown = grade_task2(fixed, gt_df, con)
+         elif task_id == "task3_etl_timezone":
+             con.execute(task.schema_ddl)
+             con.execute(task.seed_sql)
+             score, breakdown = grade_task3(fixed, gt_df, con, explanation)
+         else:
+             score, breakdown = 0.0, {}
+     except Exception:
+         score, breakdown = 0.0, {}
+     finally:
+         con.close()
+
+     return {"score": score, **breakdown}
+
+
+ def reward_correctness(completions: List[str], prompts: List[str], **kwargs) -> List[float]:
+     """
+     Primary reward: overall grader score (0.0–1.0).
+     This is the dense, decomposed score from our grader.
+     """
+     task_ids: List[str] = kwargs.get("task_id", ["task1_syntax_fix"] * len(completions))
+     rewards = []
+     for text, task_id in zip(completions, task_ids):
+         is_pipeline = (task_id == "task3_etl_timezone")
+         result = _run_grader(text, task_id, is_pipeline)
+         rewards.append(result["score"])
+     return rewards
+
+
+ def reward_parses(completions: List[str], prompts: List[str], **kwargs) -> List[float]:
+     """
+     Shaping reward: did the SQL parse? (+0.1 bonus).
+     Encourages the model to produce syntactically valid SQL even when
+     semantics are wrong — important early in training.
+     """
+     task_ids: List[str] = kwargs.get("task_id", ["task1_syntax_fix"] * len(completions))
+     rewards = []
+     for text, task_id in zip(completions, task_ids):
+         is_pipeline = (task_id == "task3_etl_timezone")
+         result = _run_grader(text, task_id, is_pipeline)
+         rewards.append(result.get("parses", 0.0))
+     return rewards
+
+
+ def reward_format(completions: List[str], prompts: List[str], **kwargs) -> List[float]:
+     """
+     Format reward: did the model use a ```sql ... ``` block?
+     This teaches the model the required response format.
+     """
+     rewards = []
+     task_ids: List[str] = kwargs.get("task_id", ["task1_syntax_fix"] * len(completions))
+     for text, task_id in zip(completions, task_ids):
+         lang = "python" if task_id == "task3_etl_timezone" else "sql"
+         has_block = bool(re.search(rf"```{lang}", text, re.IGNORECASE))
+         has_think = bool(re.search(r"<think>.*?</think>", text, re.DOTALL))
+         score = (0.5 if has_block else 0.0) + (0.5 if has_think else 0.0)
+         rewards.append(score)
+     return rewards
+
+
+ def reward_no_repetition(completions: List[str], prompts: List[str], **kwargs) -> List[float]:
+     """
+     Penalise repetitive/trivial outputs (empty or < 10 chars of code).
+     """
+     rewards = []
+     task_ids: List[str] = kwargs.get("task_id", ["task1_syntax_fix"] * len(completions))
+     for text, task_id in zip(completions, task_ids):
+         is_pipeline = (task_id == "task3_etl_timezone")
+         code = extract_sql_from_response(text, is_pipeline=is_pipeline)
+         penalty = -0.3 if len(code) < 10 else 0.0
+         rewards.append(penalty)
+     return rewards
+
+
+ # =============================================================================
+ # 7. CREATE DATASET
+ # =============================================================================
+
+ def create_training_dataset(n_repeats: int = 50) -> Dataset:
+     """
+     Build a training dataset from the 3 tasks.
+     Each task is repeated n_repeats times so the model sees diverse episodes.
+     The 'prompt' column is a pre-tokenised chat string; 'task_id' is metadata
+     passed through to reward functions via kwargs.
+     """
+     env = get_env()
+     rows = []
+
+     for task in TASKS:
+         obs = env.reset(seed=42, task_id=task.task_id) if USE_LOCAL_ENV else env.reset(task_id=task.task_id)
+         messages = format_messages(obs)
+         prompt_text = tokenizer.apply_chat_template(
+             messages, tokenize=False, add_generation_prompt=True
+         )
+
+         for i in range(n_repeats):
+             rows.append({
+                 "prompt": prompt_text,
+                 "task_id": task.task_id,
+                 "difficulty": task.difficulty,
+                 # Per-row seed, passed through to rollout/reward kwargs
+                 "seed": 42 + i,
+             })
+
+     dataset = Dataset.from_list(rows)
+     print(f"Dataset created: {len(dataset)} rows "
+           f"({n_repeats} × {len(TASKS)} tasks)")
+     return dataset
+
+
+ # =============================================================================
+ # 8. CONFIGURE GRPO TRAINING
+ # =============================================================================
+
+ def get_grpo_config(output_dir: str = "./sql-debug-qwen-grpo") -> GRPOConfig:
+     """
+     Return a GRPOConfig tuned for Qwen2.5-Coder-7B on a single A100/H100 40GB.
+     Reduce per_device_train_batch_size and num_generations for smaller GPUs.
+     """
+     return GRPOConfig(
+         # ── Output ──────────────────────────────────────────────────────────
+         output_dir=output_dir,
+         run_name="sql-debug-grpo-qwen25coder7b",
+
+         # ── Training schedule ───────────────────────────────────────────────
+         num_train_epochs=3,
+         learning_rate=5e-6,
+         lr_scheduler_type="cosine",
+         warmup_ratio=0.05,
+
+         # ── Batch & memory ──────────────────────────────────────────────────
+         per_device_train_batch_size=1,
+         gradient_accumulation_steps=8,  # effective batch = 8
+         gradient_checkpointing=True,
+         bf16=True,
+
+         # ── GRPO-specific ───────────────────────────────────────────────────
+         num_generations=4,  # G: candidates per prompt to compare
+         max_prompt_length=2048,
+         max_completion_length=1024,  # SQL fixes can be verbose
+
+         # ── vLLM for fast generation (requires the vllm package) ────────────
+         # Set use_vllm=False if not using vLLM (much slower but works on any GPU)
+         use_vllm=False,  # set True on A100+ with vllm installed
+         # vllm_mode="colocate",
+         # vllm_gpu_memory_utilization=0.2,
+
+         # ── Logging ─────────────────────────────────────────────────────────
+         logging_steps=5,
+         save_steps=50,
+         eval_steps=50,
+         report_to="none",  # set "wandb" or "tensorboard" as needed
+
+         # ── Hub ─────────────────────────────────────────────────────────────
+         push_to_hub=False,  # set True to auto-push checkpoints
+         hub_model_id=HF_REPO_ID,
+     )
+
+
+ # =============================================================================
+ # 9. CREATE TRAINER & TRAIN
+ # =============================================================================
+
+ def build_trainer(
+     dataset: Dataset,
+     grpo_config: GRPOConfig,
+ ) -> GRPOTrainer:
+     """
+     Instantiate GRPOTrainer with:
+       - The base model (Qwen2.5-Coder-7B-Instruct)
+       - 3 reward functions (correctness, format, no-repetition)
+       - The rollout function that drives environment interaction
+       - The training dataset
+     """
+     trainer = GRPOTrainer(
+         model=MODEL_NAME,
+         # Multiple reward functions — TRL sums them with equal weight by default.
+         # You can pass reward_weights=[0.7, 0.2, 0.1] to control contribution.
+         reward_funcs=[
+             reward_correctness,   # primary: correctness score 0.0–1.0
+             reward_format,        # shaping: forces <think> + ```sql``` format
+             reward_no_repetition, # penalty: discourages trivial empty outputs
+         ],
+         reward_weights=[0.7, 0.2, 0.1],
+         args=grpo_config,
+         train_dataset=dataset,
+         processing_class=tokenizer,
+         # rollout_func: commented out here because TRL ≥0.12 uses reward_funcs
+         # directly for non-interactive tasks. Use rollout_func for multi-turn.
+         # rollout_func=rollout_func,  # uncomment for multi-turn RL
+     )
+     return trainer
+
+
+ def train(n_repeats: int = 50):
+     """Main training entry point."""
+     print("=" * 60)
+     print(f"Model:   {MODEL_NAME}")
+     print(f"Env URL: {ENV_URL if not USE_LOCAL_ENV else 'local'}")
+     print(f"Tasks:   {[t.task_id for t in TASKS]}")
+     print("=" * 60)
+
+     dataset = create_training_dataset(n_repeats=n_repeats)
+     grpo_config = get_grpo_config()
+     trainer = build_trainer(dataset, grpo_config)
+
+     print("\nStarting GRPO training…")
+     trainer.train()
+
+     return trainer
+
+
+ # =============================================================================
+ # 10. SAVE & PUSH TO HUB
+ # =============================================================================
+
+ def save_and_push(trainer: GRPOTrainer, output_dir: str = "./sql-debug-qwen-grpo"):
+     """Save the trained model locally and optionally push to the Hub."""
+     print(f"\nSaving model to {output_dir}…")
+     trainer.save_model(output_dir)
+     tokenizer.save_pretrained(output_dir)
+
+     push = os.environ.get("PUSH_TO_HUB", "false").lower() == "true"
+     if push:
+         print(f"Pushing to Hub: {HF_REPO_ID}")
+         trainer.push_to_hub(
+             commit_message="GRPO-trained SQL debug model",
+         )
+         print(f"Model available at: https://huggingface.co/{HF_REPO_ID}")
+     else:
+         print(f"Set PUSH_TO_HUB=true to push to {HF_REPO_ID}")
+
+
+ # =============================================================================
+ # 11. EVALUATION
+ # =============================================================================
+
+ @dataclass
+ class EvalResult:
+     task_id: str
+     difficulty: str
+     n_episodes: int
+     mean_reward: float
+     best_reward: float
+     n_solved: int  # episodes with reward >= 1.0
+
+
+ def evaluate(
+     model_path: str = "./sql-debug-qwen-grpo",
+     n_episodes: int = 10,
+     max_steps: int = 5,
+ ) -> List[EvalResult]:
+     """
+     Evaluate the trained model against all 3 tasks.
+     Loads the fine-tuned model and runs n_episodes per task.
+     """
+     print(f"\n{'='*60}\nEVALUATION — {model_path}\n{'='*60}")
+
+     eval_model = AutoModelForCausalLM.from_pretrained(
+         model_path,
+         torch_dtype=torch.bfloat16,
+         device_map="auto",
+         trust_remote_code=True,
+     )
+     eval_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+     eval_model.eval()
+
+     env = get_env()
+     results: List[EvalResult] = []
+
+     for task in TASKS:
+         episode_rewards = []
+         n_solved = 0
+
+         for ep in range(n_episodes):
+             seed = 1000 + ep  # different seeds from training
+             obs = env.reset(seed=seed, task_id=task.task_id) if USE_LOCAL_ENV \
+                 else env.reset(task_id=task.task_id)
+
+             best_reward = 0.0
+             done = False
+             is_pipeline = (task.task_id == "task3_etl_timezone")
+
+             for step in range(max_steps):
+                 if done:
+                     break
+
+                 messages = format_messages(obs)
+                 prompt_text = eval_tokenizer.apply_chat_template(
+                     messages, tokenize=False, add_generation_prompt=True
+                 )
+                 inputs = eval_tokenizer(
+                     prompt_text, return_tensors="pt", truncation=True, max_length=2048
+                 ).to(eval_model.device)
+
+                 with torch.no_grad():
+                     output_ids = eval_model.generate(
+                         **inputs,
+                         max_new_tokens=1024,
+                         do_sample=False,  # greedy decoding for eval
+                         pad_token_id=eval_tokenizer.eos_token_id,
+                     )
+
+                 prompt_len = inputs["input_ids"].shape[1]
+                 completion = eval_tokenizer.decode(
+                     output_ids[0][prompt_len:], skip_special_tokens=True
+                 )
+
+                 fixed_sql = extract_sql_from_response(completion, is_pipeline=is_pipeline)
+                 explanation = extract_explanation(completion)
+                 action = SQLDebugAction(fixed_sql=fixed_sql, explanation=explanation)
+
+                 # step() has the same signature for the local and HTTP env
+                 obs, reward, done, info = env.step(action)
+
+                 best_reward = max(best_reward, reward)
+
+             episode_rewards.append(best_reward)
+             if best_reward >= 1.0:
+                 n_solved += 1
+
+         mean_r = sum(episode_rewards) / len(episode_rewards)
+         best_r = max(episode_rewards)
+
+         result = EvalResult(
+             task_id=task.task_id,
+             difficulty=task.difficulty,
+             n_episodes=n_episodes,
+             mean_reward=round(mean_r, 4),
+             best_reward=round(best_r, 4),
+             n_solved=n_solved,
+         )
+         results.append(result)
+         print(f"  {task.task_id:40s} mean={mean_r:.4f} best={best_r:.4f} "
+               f"solved={n_solved}/{n_episodes}")
+
+     # Write evaluation report
+     report = {
+         "model": model_path,
+         "n_episodes": n_episodes,
+         "tasks": [r.__dict__ for r in results],
+     }
+     os.makedirs("outputs/evals", exist_ok=True)
+     report_path = f"outputs/evals/eval_{int(time.time())}.json"
+     with open(report_path, "w") as f:
+         json.dump(report, f, indent=2)
+     print(f"\nEval report saved: {report_path}")
+
+     return results
+
+
+ # =============================================================================
+ # ENTRY POINT
+ # =============================================================================
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser(description="GRPO training for SQL Debug environment")
+     parser.add_argument("--mode", choices=["train", "eval", "both"], default="train")
+     parser.add_argument("--n-repeats", type=int, default=50, help="Dataset repeats per task")
+     parser.add_argument("--n-episodes", type=int, default=10, help="Eval episodes per task")
+     parser.add_argument("--output-dir", default="./sql-debug-qwen-grpo")
+     args = parser.parse_args()
+
+     if args.mode in ("train", "both"):
+         trainer = train(n_repeats=args.n_repeats)
+         save_and_push(trainer, output_dir=args.output_dir)
+
+     if args.mode in ("eval", "both"):
+         evaluate(model_path=args.output_dir, n_episodes=args.n_episodes)
train_rl.md ADDED
@@ -0,0 +1,92 @@
+ # RL Training for SQL Debug — GRPO with Qwen2.5-Coder-7B-Instruct
+
+ > **Full training script:** [`train_grpo.py`](train_grpo.py)
+ > **HF Space deployment:** [`deploy_hf_space.md`](deploy_hf_space.md)
+
+ ---
+
+ ## Why GRPO, Not DDPG
+
+ | | DDPG | GRPO |
+ |---|---|---|
+ | Action space | Continuous R^n | Discrete tokens ✅ |
+ | Value network | Required | Not needed ✅ |
+ | Gradient signal | Bellman + actor-critic | Group-relative ranking ✅ |
+ | Works for SQL? | ❌ | ✅ |
+
+ DDPG is for robot control / trading. SQL token generation is discrete — **always use GRPO or PPO**.
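"Group-relative ranking" is the key trick: for each prompt, GRPO samples G completions and scores each one against its own group, using reward minus the group mean (scaled by the group std) as the advantage, which is why no value network is needed. A minimal illustrative sketch of that computation (not TRL's internal code):

```python
from statistics import mean, stdev

def group_relative_advantages(rewards):
    # GRPO-style advantage: standardize each completion's reward within the
    # group of G samples drawn for the same prompt -- no critic needed.
    mu = mean(rewards)
    sigma = stdev(rewards) if len(rewards) > 1 else 0.0
    if sigma == 0.0:
        return [0.0 for _ in rewards]  # all completions tie: no gradient signal
    return [(r - mu) / sigma for r in rewards]

# G = 4 completions of one prompt, scored 0.0-1.0 by the graders:
advs = group_relative_advantages([1.0, 0.4, 0.4, 0.2])
print([round(a, 2) for a in advs])
```

The best completion in the group gets a positive advantage and the worst a negative one, regardless of the absolute reward scale.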
+
+ ---
+
+ ## What `train_grpo.py` Contains
+
+ | Section | Description |
+ |---|---|
+ | 1. Environment | Local DuckDB env or HTTP client pointing at an HF Space |
+ | 2. Model & Tokenizer | `Qwen/Qwen2.5-Coder-7B-Instruct`, left-padding |
+ | 3. System Prompt | Rules, Response Format (`<think>` + ```sql```), Strategy, Goal |
+ | 4. Helpers | `build_user_message()`, `extract_sql_from_response()`, `format_messages()` |
+ | 5. Rollout | `rollout_func()` — plays a multi-turn episode, returns padded tensors |
+ | 6. Reward Fns | `reward_correctness`, `reward_format`, `reward_no_repetition` |
+ | 7. Dataset | 3 tasks × N repeats → HF `Dataset` with `prompt` + `task_id` columns |
+ | 8. GRPOConfig | A100-tuned: `num_generations=4`, `bf16=True`, `max_completion_length=1024` |
+ | 9. Trainer | `GRPOTrainer` with `reward_weights=[0.7, 0.2, 0.1]` |
+ | 10. Save & Push | `trainer.save_model()` + `push_to_hub()` when `PUSH_TO_HUB=true` |
+ | 11. Evaluation | Greedy decode, 10 episodes/task, JSON report in `outputs/evals/` |
+
+ ---
+
+ ## Quick Start
+
+ ```powershell
+ # Install
+ pip install "trl>=0.12.0" "transformers>=4.45.0" "torch>=2.3.0" duckdb pandas pydantic
+
+ # Start the local server (terminal 1)
+ uvicorn server.app:app --host 0.0.0.0 --port 7860
+
+ # Train (terminal 2)
+ python train_grpo.py --mode train --n-repeats 50
+
+ # Evaluate the trained model
+ python train_grpo.py --mode eval --output-dir ./sql-debug-qwen-grpo
+
+ # Train + eval in one command
+ python train_grpo.py --mode both
+ ```
+
+ ---
+
+ ## System Prompt Structure
+
+ ```
+ RULES           — what the agent must/must not do
+ RESPONSE FORMAT — <think>...</think> then ```sql...```
+ STRATEGY        — task-specific hints (syntax / JOIN type / timezone)
+ GOAL            — produce output matching the ground truth exactly
+ ```
+
+ The `<think>` block is critical — it teaches chain-of-thought diagnosis before emitting the fix.
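The `<think>` block is also machine-extractable: `extract_explanation()` in `train_grpo.py` falls back to a regex over it, along these lines (simplified sketch):

```python
import re
from typing import Optional

def extract_think(text: str) -> Optional[str]:
    # Same shape as the fallback in extract_explanation(): non-greedy match
    # across newlines, case-insensitive tags.
    m = re.search(r"<think>(.*?)</think>", text, re.DOTALL | re.IGNORECASE)
    return m.group(1).strip() if m else None

response = "<think>\nINNER JOIN drops NULL keys; switch to LEFT JOIN.\n</think>\n\nSELECT 1;"
print(extract_think(response))  # -> INNER JOIN drops NULL keys; switch to LEFT JOIN.
```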
+
+ ---
+
+ ## Reward Weights
+
+ ```python
+ reward_weights = [0.7, 0.2, 0.1]
+ # 0.7 × reward_correctness   (dense 0.0–1.0 from grader)
+ # 0.2 × reward_format        (<think> block + ```sql``` present)
+ # 0.1 × reward_no_repetition (penalty for trivial empty output)
+ ```
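With these weights, TRL combines the per-function scores into one reward per completion as a weighted sum. A minimal sketch of that combination (illustrative only; `combine_rewards` is not a TRL function):

```python
def combine_rewards(per_func_rewards, weights):
    # per_func_rewards[f][i]: score of completion i under reward function f.
    # Final reward for completion i = sum_f weights[f] * per_func_rewards[f][i].
    n = len(per_func_rewards[0])
    return [sum(w * rs[i] for w, rs in zip(weights, per_func_rewards))
            for i in range(n)]

# correctness / format / no-repetition scores for two completions:
final = combine_rewards(
    [[1.0, 0.3],    # reward_correctness
     [1.0, 0.5],    # reward_format
     [0.0, -0.3]],  # reward_no_repetition
    weights=[0.7, 0.2, 0.1],
)
print(final)
```

A fully correct, well-formatted completion lands near 0.9, while a wrong but parseable one still earns partial credit, which keeps the training signal dense.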
+
+ ---
+
+ ## Expected Outcomes After Training
+
+ | Task | Before (GPT-4o-mini baseline) | After GRPO (estimated) |
+ |---|---|---|
+ | task1_syntax_fix | ~0.85 | ~0.95 |
+ | task2_join_aggregation | ~0.55 | ~0.75 |
+ | task3_etl_timezone | ~0.25 | ~0.50 |
+
+ Use a curriculum (train on Tasks 1 and 2 first, then add Task 3) for better improvement on the hard task.
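One lightweight way to apply that curriculum is to filter the dataset rows by difficulty per stage, assuming rows shaped like the output of `create_training_dataset()` (the `curriculum_stages` helper here is hypothetical, not part of the script):

```python
def curriculum_stages(rows):
    # Stage 1: easy + medium tasks only; stage 2: all tasks.
    # Train on stage 1 until reward plateaus, then continue on stage 2.
    stage1 = [r for r in rows if r["difficulty"] in ("easy", "medium")]
    return [stage1, list(rows)]

rows = [
    {"task_id": "task1_syntax_fix", "difficulty": "easy"},
    {"task_id": "task2_join_aggregation", "difficulty": "medium"},
    {"task_id": "task3_etl_timezone", "difficulty": "hard"},
]
for i, stage in enumerate(curriculum_stages(rows), start=1):
    print(f"stage {i}: {[r['task_id'] for r in stage]}")
```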