Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- .gitignore +9 -0
- app.py +3 -1
- train_grpo.py +787 -0
- train_rl.md +92 -0
.gitignore
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.env
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
| 4 |
+
outputs/
|
| 5 |
+
*.log
|
| 6 |
+
.DS_Store
|
| 7 |
+
node_modules/
|
| 8 |
+
dist/
|
| 9 |
+
build/
|
app.py
CHANGED
|
@@ -293,7 +293,9 @@ def _run_chaos_pipeline(con):
|
|
| 293 |
)
|
| 294 |
|
| 295 |
@app.post("/reset", tags=["Environment"])
|
| 296 |
-
def reset_episode(req: ResetRequest):
|
|
|
|
|
|
|
| 297 |
task_id = req.task_id if req.task_id in TASKS else "task_1_easy"
|
| 298 |
task = TASKS[task_id]
|
| 299 |
|
|
|
|
| 293 |
)
|
| 294 |
|
| 295 |
@app.post("/reset", tags=["Environment"])
|
| 296 |
+
def reset_episode(req: ResetRequest = None):
|
| 297 |
+
if req is None:
|
| 298 |
+
req = ResetRequest()
|
| 299 |
task_id = req.task_id if req.task_id in TASKS else "task_1_easy"
|
| 300 |
task = TASKS[task_id]
|
| 301 |
|
train_grpo.py
ADDED
|
@@ -0,0 +1,787 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
train_grpo.py β Full GRPO training pipeline for SQL Debug & Data Pipeline Repair
|
| 3 |
+
using Qwen/Qwen2.5-Coder-7B-Instruct + TRL GRPOTrainer.
|
| 4 |
+
|
| 5 |
+
Follows the Module 5 pattern from https://github.com/huggingface/openenv-course
|
| 6 |
+
|
| 7 |
+
Pipeline:
|
| 8 |
+
1. Init environment (local server or HF Space URL)
|
| 9 |
+
2. Init model & tokenizer (Qwen2.5-Coder-7B-Instruct)
|
| 10 |
+
3. Define system prompt (rules, response format, strategy, goal)
|
| 11 |
+
4. Helper functions (prompt builder, SQL extractor)
|
| 12 |
+
5. Rollout function (plays one full episode against the environment)
|
| 13 |
+
6. Reward functions (wraps our grader decomposition into TRL callbacks)
|
| 14 |
+
7. Create dataset (prompts for all 3 tasks Γ N variants)
|
| 15 |
+
8. Configure GRPO (GRPOConfig)
|
| 16 |
+
9. Create GRPOTrainer and train
|
| 17 |
+
10. Save & push to Hub
|
| 18 |
+
11. Evaluation loop
|
| 19 |
+
|
| 20 |
+
Usage:
|
| 21 |
+
# Local environment (start server first)
|
| 22 |
+
uvicorn server.app:app --host 0.0.0.0 --port 7860
|
| 23 |
+
|
| 24 |
+
# Training (single GPU A100/H100 recommended)
|
| 25 |
+
python train_grpo.py
|
| 26 |
+
|
| 27 |
+
# With HF Space
|
| 28 |
+
ENV_URL=https://your-username-sql-debug-env.hf.space python train_grpo.py
|
| 29 |
+
|
| 30 |
+
Requirements:
|
| 31 |
+
pip install trl>=0.12.0 transformers>=4.45.0 torch>=2.3.0
|
| 32 |
+
pip install duckdb pandas pydantic requests vllm # for local env
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
from __future__ import annotations
|
| 36 |
+
|
| 37 |
+
import json
|
| 38 |
+
import os
|
| 39 |
+
import re
|
| 40 |
+
import sys
|
| 41 |
+
import time
|
| 42 |
+
from dataclasses import dataclass
|
| 43 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 44 |
+
|
| 45 |
+
import torch
|
| 46 |
+
from datasets import Dataset
|
| 47 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 48 |
+
from trl import GRPOConfig, GRPOTrainer
|
| 49 |
+
|
| 50 |
+
# ββ Make local env importable βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 51 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 52 |
+
|
| 53 |
+
from client import SQLDebugEnv
|
| 54 |
+
from models import SQLDebugAction, SQLDebugObservation
|
| 55 |
+
from server.data import TASKS
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# =============================================================================
|
| 59 |
+
# 1. ENVIRONMENT SETUP
|
| 60 |
+
# =============================================================================
|
| 61 |
+
|
| 62 |
+
# Point to your deployed HF Space or local server
|
| 63 |
+
ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")
|
| 64 |
+
|
| 65 |
+
# For training we use the local Python environment directly (no HTTP round-trip)
|
| 66 |
+
# This is faster and avoids network latency during rollouts.
|
| 67 |
+
# Switch to SQLDebugEnv(ENV_URL) if you want to use the HTTP server.
|
| 68 |
+
USE_LOCAL_ENV = os.environ.get("USE_LOCAL_ENV", "true").lower() == "true"
|
| 69 |
+
|
| 70 |
+
if USE_LOCAL_ENV:
|
| 71 |
+
from server.environment import SQLDebugEnvironment
|
| 72 |
+
_SHARED_ENV = SQLDebugEnvironment() # single instance, reset() per episode
|
| 73 |
+
else:
|
| 74 |
+
# HTTP client β point at your HF Space
|
| 75 |
+
_HTTP_CLIENT = SQLDebugEnv(base_url=ENV_URL)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def get_env():
    """Return the environment handle (local or HTTP)."""
    # Exactly one of the two module-level handles exists, selected at import
    # time by USE_LOCAL_ENV; dispatch on the same flag here.
    return _SHARED_ENV if USE_LOCAL_ENV else _HTTP_CLIENT
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# =============================================================================
|
| 86 |
+
# 2. MODEL & TOKENIZER
|
| 87 |
+
# =============================================================================
|
| 88 |
+
|
| 89 |
+
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-Coder-7B-Instruct")
|
| 90 |
+
HF_REPO_ID = os.environ.get("HF_REPO_ID", "sai1912/sql-debug-qwen-grpo")
|
| 91 |
+
|
| 92 |
+
print(f"Loading tokenizer: {MODEL_NAME}")
|
| 93 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
| 94 |
+
if tokenizer.pad_token is None:
|
| 95 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 96 |
+
tokenizer.padding_side = "left" # Required for decoder-only models in GRPO
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# =============================================================================
|
| 100 |
+
# 3. SYSTEM PROMPT β Rules, Response Format, Strategy, Goal
|
| 101 |
+
# =============================================================================
|
| 102 |
+
|
| 103 |
+
SYSTEM_PROMPT = """\
|
| 104 |
+
You are an expert SQL debugger and data engineer. Your goal is to diagnose \
|
| 105 |
+
and fix broken SQL queries and ETL pipelines.
|
| 106 |
+
|
| 107 |
+
RULES:
|
| 108 |
+
- Read the broken SQL or pipeline code carefully
|
| 109 |
+
- Study the schema β table names, column names, and types matter
|
| 110 |
+
- Look for: syntax errors, wrong aliases, wrong JOIN types, type casting bugs
|
| 111 |
+
- Your fix must produce exactly the correct output described in the task
|
| 112 |
+
- Never use DROP TABLE, DELETE, or TRUNCATE on real data tables
|
| 113 |
+
- Do not repeat the same query if it was already rejected
|
| 114 |
+
|
| 115 |
+
RESPONSE FORMAT:
|
| 116 |
+
Always respond with EXACTLY this structure (no deviation):
|
| 117 |
+
|
| 118 |
+
<think>
|
| 119 |
+
[Your step-by-step diagnosis of the bug. Be explicit about what is wrong and why.]
|
| 120 |
+
</think>
|
| 121 |
+
|
| 122 |
+
```sql
|
| 123 |
+
[Your complete corrected SQL query here]
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
EXPLANATION (Task 3 only):
|
| 127 |
+
[One sentence naming the root cause step and why it causes wrong results]
|
| 128 |
+
|
| 129 |
+
STRATEGY:
|
| 130 |
+
- Task 1 (easy): Look for syntax errors (missing commas) and wrong table aliases
|
| 131 |
+
- Task 2 (medium): Check JOIN types β INNER JOIN silently drops NULL-keyed rows
|
| 132 |
+
- Task 3 (hard): Trace the timezone handling β CAST(ts AS DATE) strips offset
|
| 133 |
+
|
| 134 |
+
GOAL:
|
| 135 |
+
Return a corrected SQL query (Tasks 1/2) or corrected Python pipeline \
|
| 136 |
+
code (Task 3) that produces output matching the ground truth exactly.
|
| 137 |
+
"""
|
| 138 |
+
|
| 139 |
+
# Task-specific addendum injected into user messages
|
| 140 |
+
TASK_HINTS = {
|
| 141 |
+
"task1_syntax_fix": (
|
| 142 |
+
"Hint: Check each line of the SELECT clause carefully. "
|
| 143 |
+
"Also verify every table alias used in JOIN conditions matches the FROM clause aliases."
|
| 144 |
+
),
|
| 145 |
+
"task2_join_aggregation": (
|
| 146 |
+
"Hint: Consider what happens when a JOIN key is NULL. "
|
| 147 |
+
"INNER JOIN silently drops those rows β is that correct for this aggregation?"
|
| 148 |
+
),
|
| 149 |
+
"task3_etl_timezone": (
|
| 150 |
+
"Hint: The timestamps include timezone offsets like '+05:30'. "
|
| 151 |
+
"What does CAST(ts AS DATE) do to that offset? "
|
| 152 |
+
"Which DuckDB type preserves timezone information?"
|
| 153 |
+
),
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# =============================================================================
|
| 158 |
+
# 4. HELPER FUNCTIONS
|
| 159 |
+
# =============================================================================
|
| 160 |
+
|
| 161 |
+
def build_user_message(obs: SQLDebugObservation) -> str:
    """
    Format an observation into a user-turn message.
    Mirrors baseline.py but adds structured context for RL training.
    """
    # Schema block: one "  table(col type, ...)" line per table.
    schema_str = "\n".join(
        "  {}({})".format(
            table, ", ".join(f"{c['column']} {c['type']}" for c in cols)
        )
        for table, cols in obs.schema_info.items()
    )

    # Code block plus a per-task response instruction.
    if obs.task_id == "task3_etl_timezone":
        code_block = (
            f"## Broken ETL Pipeline (Python/DuckDB)\n\n"
            f"```python\n{obs.pipeline_code}\n```"
        )
        if obs.intermediate_outputs:
            wrong_output = json.dumps(
                obs.intermediate_outputs[-1]["rows"][:3], indent=2, default=str
            )
            code_block += (
                f"\n\n## Step 4 Wrong Output (first 3 rows)\n\n"
                f"```json\n{wrong_output}\n```"
            )
        response_instruction = (
            "Return the COMPLETE corrected Python pipeline code in a "
            "```python ... ``` block. Set EXPLANATION to name the buggy step."
        )
    else:
        code_block = f"## Broken SQL Query\n\n```sql\n{obs.broken_sql}\n```"
        response_instruction = "Return the corrected SQL inside a ```sql ... ``` block."

    # Previous attempts, so the policy can learn from rejected fixes.
    history = ""
    if obs.previous_attempts:
        attempt_lines = ["\n## Previous Attempts (learn from these)\n"]
        for attempt in obs.previous_attempts:
            verdict = (
                "CORRECT" if attempt.reward >= 1.0 else f"reward={attempt.reward:.2f}"
            )
            preview = attempt.fixed_sql[:150].replace("\n", " ")
            attempt_lines.append(f"  Attempt {attempt.step} [{verdict}]: {preview}...")
        history = "\n".join(attempt_lines)

    hint = TASK_HINTS.get(obs.task_id, "")

    message = (
        f"## Task ({obs.difficulty.upper()}): {obs.task_id}\n\n"
        f"{obs.task_description}\n\n"
        f"## Database Schema\n\n{schema_str}\n\n"
        f"{code_block}"
        f"{history}\n\n"
        f"## Instruction\n{response_instruction}\n\n"
        f"{hint}"
    )
    return message.strip()
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def extract_sql_from_response(text: str, is_pipeline: bool = False) -> str:
    """
    Extract the SQL or Python code block from a model response.

    Tries, in order: a fenced block tagged with the expected language, an
    untagged fenced block, then any fenced block at all. Falls back to the
    raw text if no code block is found.

    Args:
        text: Full model response (may contain <think> sections and prose).
        is_pipeline: If True, look for ```python blocks (Task 3);
            otherwise look for ```sql blocks.

    Returns:
        The extracted code, stripped of surrounding whitespace.
    """
    lang = "python" if is_pipeline else "sql"
    patterns = [
        rf"```{lang}\s*\n(.*?)```",  # block tagged with the expected language
        r"```\s*\n(.*?)```",         # untagged block
        r"```(.*?)```",              # any block (tag may leak into the capture)
    ]
    for pattern in patterns:
        m = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
        if m:
            code = m.group(1).strip()
            # Bug fix: the generic fallback captures a stray language tag when
            # the model fenced with the *other* language (e.g. "python\nSELECT..."
            # while SQL was expected). Strip a bare tag on the first line.
            first_line, _, rest = code.partition("\n")
            if first_line.strip().lower() in ("sql", "python") and rest:
                code = rest.strip()
            return code
    return text.strip()
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def extract_explanation(text: str) -> Optional[str]:
    """Extract EXPLANATION section (Task 3 root-cause scoring)."""
    # Prefer an explicit EXPLANATION section; fall back to the <think> block,
    # which often names the buggy step. Either way, cap at 300 characters.
    candidate_patterns = (
        r"EXPLANATION[:\s]+(.*?)(?:```|$)",
        r"<think>(.*?)</think>",
    )
    for pattern in candidate_patterns:
        match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
        if match:
            return match.group(1).strip()[:300]
    return None
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def format_messages(obs: SQLDebugObservation) -> List[Dict[str, str]]:
    """Build the chat message list for the model."""
    system_turn = {"role": "system", "content": SYSTEM_PROMPT}
    user_turn = {"role": "user", "content": build_user_message(obs)}
    return [system_turn, user_turn]
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
# =============================================================================
|
| 255 |
+
# 5. ROLLOUT FUNCTION
|
| 256 |
+
# =============================================================================
|
| 257 |
+
|
| 258 |
+
def generate_rollout_completions(trainer: GRPOTrainer, batch_messages: List[List[Dict]]) -> List[Dict]:
    """
    Generate completions using the current policy model via TRL's built-in
    generate_completions utility (vLLM-backed when use_vllm=True).

    Returns a list of dicts with keys: 'text', 'prompt_ids', 'completion_ids', 'logprobs'.
    """
    # Tokenize prompts: render each chat to a string with the generation
    # prompt appended, then batch-encode. The module-level tokenizer is
    # configured with padding_side="left", so shorter prompts are padded on
    # the left and all sequences end at the generation point.
    texts = [
        tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        for msgs in batch_messages
    ]
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True,
                       max_length=2048).to(trainer.model.device)

    # Sampled generation (temperature/top_p) — rollouts need diversity for GRPO.
    with torch.no_grad():
        output_ids = trainer.model.generate(
            **inputs,
            max_new_tokens=1024,
            temperature=0.8,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,  # pad was aliased to eos at load time
        )

    results = []
    for i, (prompt_ids, out_ids) in enumerate(zip(inputs["input_ids"], output_ids)):
        # generate() returns prompt + completion; slice the prompt prefix off.
        # prompt_len counts left-pad tokens too, which is correct because
        # out_ids carries the same padded prefix.
        prompt_len = prompt_ids.shape[0]
        completion_ids = out_ids[prompt_len:]
        text = tokenizer.decode(completion_ids, skip_special_tokens=True)
        results.append({
            "text": text,
            "prompt_ids": prompt_ids,
            "completion_ids": completion_ids,
            "logprobs": None,  # TRL computes logprobs internally
        })
    return results
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def rollout_func(
    trainer: GRPOTrainer,
    batch: Dict[str, Any],
    tokenizer: AutoTokenizer,
) -> Dict[str, Any]:
    """
    TRL rollout function. Called by GRPOTrainer during training.

    Plays one full episode per row in the batch:
      1. reset() the environment for the task
      2. Generate a fix with the current policy
      3. step() the environment
      4. Repeat up to max_turns (multi-turn RL)

    Args:
        trainer: The active GRPOTrainer (supplies the policy model).
        batch: Batch dict; must contain a "task_id" list.
        tokenizer: Tokenizer used to re-encode prompts for TRL.

    Returns:
        A batch-format dict with padded "prompt_ids" / "completion_ids"
        tensors and a float32 "rewards" tensor, as TRL expects.
    """
    env = get_env()
    max_turns = 3  # 3 attempts per training episode (saves compute)

    all_prompt_ids: List[torch.Tensor] = []
    all_completion_ids: List[torch.Tensor] = []
    all_rewards: List[float] = []

    task_ids: List[str] = batch["task_id"]

    for task_id in task_ids:
        # ── Episode start ──────────────────────────────────────────────────
        # Only the local environment accepts a seed on reset.
        if USE_LOCAL_ENV:
            obs = env.reset(seed=42, task_id=task_id)
        else:
            obs = env.reset(task_id=task_id)

        episode_prompt_ids = []
        episode_completion_ids = []
        episode_rewards = []
        is_pipeline = (task_id == "task3_etl_timezone")
        done = False

        for _turn in range(max_turns):
            if done:
                break

            messages = format_messages(obs)
            completion = generate_rollout_completions(trainer, [messages])[0]

            fixed_sql = extract_sql_from_response(completion["text"], is_pipeline=is_pipeline)
            explanation = extract_explanation(completion["text"])
            action = SQLDebugAction(fixed_sql=fixed_sql, explanation=explanation)

            # Local and HTTP environments expose the same step() signature,
            # so one call covers both (the original duplicated this branch).
            obs, reward, done, info = env.step(action)

            episode_prompt_ids.append(
                tokenizer(
                    tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True),
                    return_tensors="pt",
                )["input_ids"][0]
            )
            episode_completion_ids.append(completion["completion_ids"])
            episode_rewards.append(reward)

        # Use the best reward in the episode as the final signal for every turn.
        best_reward = max(episode_rewards) if episode_rewards else 0.0
        all_rewards.extend([best_reward] * len(episode_rewards))
        all_prompt_ids.extend(episode_prompt_ids)
        all_completion_ids.extend(episode_completion_ids)

    # Robustness: an empty batch would make the max() calls below raise
    # ValueError. Return empty tensors instead.
    if not all_prompt_ids:
        return {
            "prompt_ids": torch.empty(0, 0, dtype=torch.long),
            "completion_ids": torch.empty(0, 0, dtype=torch.long),
            "rewards": torch.tensor([], dtype=torch.float32),
        }

    # Pad to rectangular tensors: prompts on the left (decoder-only
    # convention), completions on the right.
    max_prompt_len = max(t.shape[0] for t in all_prompt_ids)
    max_comp_len = max(t.shape[0] for t in all_completion_ids)

    padded_prompts = torch.stack([
        torch.nn.functional.pad(t, (max_prompt_len - t.shape[0], 0), value=tokenizer.pad_token_id)
        for t in all_prompt_ids
    ])
    padded_completions = torch.stack([
        torch.nn.functional.pad(t, (0, max_comp_len - t.shape[0]), value=tokenizer.pad_token_id)
        for t in all_completion_ids
    ])

    return {
        "prompt_ids": padded_prompts,
        "completion_ids": padded_completions,
        "rewards": torch.tensor(all_rewards, dtype=torch.float32),
    }
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
# =============================================================================
|
| 390 |
+
# 6. REWARD FUNCTIONS (TRL-style callbacks)
|
| 391 |
+
# =============================================================================
|
| 392 |
+
# TRL's GRPOTrainer can accept multiple reward_funcs. Each receives
|
| 393 |
+
# (completions, prompts, **kwargs) and returns a list of floats.
|
| 394 |
+
# We use our grader decomposition to provide multi-signal training.
|
| 395 |
+
|
| 396 |
+
def _run_grader(completion_text: str, task_id: str, is_pipeline: bool) -> Dict[str, float]:
    """Run the environment grader and return breakdown dict."""
    # Lazy imports keep module import cheap and avoid pulling duckdb into
    # processes that never grade.
    import duckdb as _duckdb
    from server.data import TASK_MAP
    from server.graders import grade_task1, grade_task2, grade_task3

    task = TASK_MAP[task_id]
    # Fresh in-memory DB per grading call: schema, seed data, ground truth.
    con = _duckdb.connect(":memory:")
    con.execute(task.schema_ddl)
    con.execute(task.seed_sql)
    gt_df = con.execute(task.ground_truth_query).df()

    fixed = extract_sql_from_response(completion_text, is_pipeline=is_pipeline)
    explanation = extract_explanation(completion_text)

    try:
        if task_id == "task1_syntax_fix":
            score, breakdown = grade_task1(fixed, gt_df, con)
        elif task_id == "task2_join_aggregation":
            score, breakdown = grade_task2(fixed, gt_df, con)
        elif task_id == "task3_etl_timezone":
            # NOTE(review): schema_ddl/seed_sql are executed a second time here,
            # presumably to restore tables before re-running the pipeline —
            # confirm the DDL is idempotent (CREATE OR REPLACE / DROP-first),
            # otherwise this raises and is swallowed into a 0.0 score below.
            con.execute(task.schema_ddl)
            con.execute(task.seed_sql)
            score, breakdown = grade_task3(fixed, gt_df, con, explanation)
        else:
            # Unknown task id: no reward signal.
            score, breakdown = 0.0, {}
    except Exception:
        # Any grader/SQL failure counts as a zero-reward attempt; this is
        # deliberate best-effort behavior during RL rollouts.
        score, breakdown = 0.0, {}
    finally:
        con.close()

    return {"score": score, **breakdown}
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def reward_correctness(completions: List[str], prompts: List[str], **kwargs) -> List[float]:
    """
    Primary reward: overall grader score (0.0-1.0).
    This is the dense, decomposed score from our grader.
    """
    # TRL forwards dataset columns through kwargs; default to task 1 ids.
    task_ids: List[str] = kwargs.get("task_id", ["task1_syntax_fix"] * len(completions))
    return [
        _run_grader(text, tid, tid == "task3_etl_timezone")["score"]
        for text, tid in zip(completions, task_ids)
    ]
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def reward_parses(completions: List[str], prompts: List[str], **kwargs) -> List[float]:
    """
    Shaping reward: did the SQL parse? (+0.1 bonus).
    Encourages the model to produce syntactically valid SQL even when
    semantics are wrong - important early in training.
    """
    # TRL forwards dataset columns through kwargs; default to task 1 ids.
    task_ids: List[str] = kwargs.get("task_id", ["task1_syntax_fix"] * len(completions))
    return [
        _run_grader(text, tid, tid == "task3_etl_timezone").get("parses", 0.0)
        for text, tid in zip(completions, task_ids)
    ]
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def reward_format(completions: List[str], prompts: List[str], **kwargs) -> List[float]:
    """
    Format reward: did the model use a ```sql ... ``` block?
    This teaches the model the required response format.
    """
    task_ids: List[str] = kwargs.get("task_id", ["task1_syntax_fix"] * len(completions))
    scores = []
    for completion, tid in zip(completions, task_ids):
        # Task 3 expects a python fence; everything else expects sql.
        expected = "python" if tid == "task3_etl_timezone" else "sql"
        fence_ok = re.search(rf"```{expected}", completion, re.IGNORECASE) is not None
        think_ok = re.search(r"<think>.*?</think>", completion, re.DOTALL) is not None
        # Half credit for each of the two required structural elements.
        scores.append((0.5 if fence_ok else 0.0) + (0.5 if think_ok else 0.0))
    return scores
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def reward_no_repetition(completions: List[str], prompts: List[str], **kwargs) -> List[float]:
    """
    Penalise repetitive/trivial outputs (empty or < 10 chars of code).
    """
    task_ids: List[str] = kwargs.get("task_id", ["task1_syntax_fix"] * len(completions))
    penalties = []
    for completion, tid in zip(completions, task_ids):
        extracted = extract_sql_from_response(
            completion, is_pipeline=(tid == "task3_etl_timezone")
        )
        # -0.3 for degenerate outputs, neutral otherwise.
        penalties.append(0.0 if len(extracted) >= 10 else -0.3)
    return penalties
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
# =============================================================================
|
| 490 |
+
# 7. CREATE DATASET
|
| 491 |
+
# =============================================================================
|
| 492 |
+
|
| 493 |
+
def create_training_dataset(n_repeats: int = 50) -> Dataset:
    """
    Build a training dataset from the 3 tasks.
    Each task is repeated n_repeats times so the model sees diverse episodes.
    The 'prompt' column is a pre-tokenised chat string; 'task_id' is metadata
    passed through to reward functions via kwargs.
    """
    env = get_env()
    rows: List[Dict[str, Any]] = []

    for task in TASKS:
        # One reset per task to obtain the canonical observation/prompt.
        if USE_LOCAL_ENV:
            obs = env.reset(seed=42, task_id=task.task_id)
        else:
            obs = env.reset(task_id=task.task_id)
        prompt_text = tokenizer.apply_chat_template(
            format_messages(obs), tokenize=False, add_generation_prompt=True
        )

        # Seed varies so GRPO sees slightly different phrasings across epochs.
        rows.extend(
            {
                "prompt": prompt_text,
                "task_id": task.task_id,
                "difficulty": task.difficulty,
                "seed": 42 + i,
            }
            for i in range(n_repeats)
        )

    dataset = Dataset.from_list(rows)
    print(f"Dataset created: {len(dataset)} rows "
          f"({n_repeats} × {len(TASKS)} tasks)")
    return dataset
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
# =============================================================================
|
| 526 |
+
# 8. CONFIGURE GRPO TRAINING
|
| 527 |
+
# =============================================================================
|
| 528 |
+
|
| 529 |
+
def get_grpo_config(output_dir: str = "./sql-debug-qwen-grpo") -> GRPOConfig:
    """
    Return a GRPOConfig tuned for Qwen2.5-Coder-7B on a single A100/H100 40GB.
    Reduce per_device_train_batch_size and num_generations for smaller GPUs.

    Args:
        output_dir: Where checkpoints and logs are written.

    Returns:
        A fully populated GRPOConfig ready to pass to GRPOTrainer.
    """
    return GRPOConfig(
        # ── Output ──────────────────────────────────────────────────────────
        output_dir=output_dir,
        run_name="sql-debug-grpo-qwen25coder7b",

        # ── Training schedule ───────────────────────────────────────────────
        num_train_epochs=3,
        learning_rate=5e-6,  # conservative LR; RL fine-tuning diverges easily
        lr_scheduler_type="cosine",
        warmup_ratio=0.05,

        # ── Batch & memory ──────────────────────────────────────────────────
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,  # effective batch = 8
        gradient_checkpointing=True,
        bf16=True,

        # ── GRPO-specific ───────────────────────────────────────────────────
        num_generations=4,  # G: candidates per prompt to compare
        max_prompt_length=2048,
        max_completion_length=1024,  # SQL fixes can be verbose

        # ── vLLM for fast generation (requires vllm package) ────────────────
        # Set use_vllm=False if not using vLLM (much slower but works on any GPU)
        use_vllm=False,  # set True on A100+ with vllm installed
        # vllm_mode="colocate",
        # vllm_gpu_memory_utilization=0.2,

        # ── Logging ─────────────────────────────────────────────────────────
        logging_steps=5,
        save_steps=50,
        eval_steps=50,
        report_to="none",  # set "wandb" or "tensorboard" as needed

        # ── Hub ─────────────────────────────────────────────────────────────
        push_to_hub=False,  # set True to auto-push checkpoints
        hub_model_id=HF_REPO_ID,
    )
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
# =============================================================================
|
| 575 |
+
# 9. CREATE TRAINER & TRAIN
|
| 576 |
+
# =============================================================================
|
| 577 |
+
|
| 578 |
+
def build_trainer(
    dataset: Dataset,
    grpo_config: GRPOConfig,
) -> GRPOTrainer:
    """
    Construct the GRPOTrainer for the SQL-debug task.

    Wires together:
      - the base model (passed by name; TRL loads Qwen2.5-Coder-7B-Instruct itself)
      - three reward functions combined as a weighted sum
      - the training dataset and the shared tokenizer
    """
    # TRL combines multiple reward functions with the given weights; the
    # ordering of reward_funcs and reward_weights must line up.
    scoring_fns = [
        reward_correctness,    # primary: correctness score 0.0-1.0
        reward_format,         # shaping: forces <think> + ```sql``` format
        reward_no_repetition,  # penalty: discourages trivial empty outputs
    ]
    return GRPOTrainer(
        model=MODEL_NAME,
        reward_funcs=scoring_fns,
        reward_weights=[0.7, 0.2, 0.1],
        args=grpo_config,
        train_dataset=dataset,
        processing_class=tokenizer,
        # rollout_func is deliberately not passed: TRL >=0.12 drives
        # single-turn tasks through reward_funcs directly. Supply a
        # rollout_func here only for multi-turn RL.
    )
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
def train(n_repeats: int = 50):
    """Main training entry point: build the dataset and trainer, run GRPO."""
    banner = "=" * 60
    print(banner)
    print(f"Model: {MODEL_NAME}")
    print(f"Env URL: {ENV_URL if not USE_LOCAL_ENV else 'local'}")
    print(f"Tasks: {[t.task_id for t in TASKS]}")
    print(banner)

    # Dataset first, then config, then the trainer that consumes both.
    dataset = create_training_dataset(n_repeats=n_repeats)
    trainer = build_trainer(dataset, get_grpo_config())

    print("\nStarting GRPO training…")
    trainer.train()

    return trainer
|
| 625 |
+
|
| 626 |
+
|
| 627 |
+
# =============================================================================
|
| 628 |
+
# 10. SAVE & PUSH TO HUB
|
| 629 |
+
# =============================================================================
|
| 630 |
+
|
| 631 |
+
def save_and_push(trainer: GRPOTrainer, output_dir: str = "./sql-debug-qwen-grpo"):
    """
    Save the trained model and tokenizer locally; optionally push to the Hub.

    The push is gated on the PUSH_TO_HUB environment variable. The target
    repository comes from GRPOConfig.hub_model_id (set to HF_REPO_ID in
    get_grpo_config), so no repo id is passed to push_to_hub() here.
    """
    print(f"\nSaving model to {output_dir}…")
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Accept the common truthy spellings so e.g. PUSH_TO_HUB=1 also works.
    push = os.environ.get("PUSH_TO_HUB", "false").strip().lower() in ("true", "1", "yes")
    if push:
        print(f"Pushing to Hub: {HF_REPO_ID}")
        # BUGFIX: Trainer.push_to_hub() takes no repo_id kwarg — it pushes to
        # args.hub_model_id (already HF_REPO_ID). Unknown kwargs are forwarded
        # to create_model_card() and raise TypeError, so repo_id was removed.
        trainer.push_to_hub(commit_message="GRPO-trained SQL debug model")
        print(f"Model available at: https://huggingface.co/{HF_REPO_ID}")
    else:
        print(f"Set PUSH_TO_HUB=true to push to {HF_REPO_ID}")
|
| 647 |
+
|
| 648 |
+
|
| 649 |
+
# =============================================================================
|
| 650 |
+
# 11. EVALUATION
|
| 651 |
+
# =============================================================================
|
| 652 |
+
|
| 653 |
+
@dataclass
class EvalResult:
    """Aggregate evaluation metrics for one task across evaluation episodes."""
    task_id: str        # environment task identifier
    difficulty: str     # difficulty label copied from the task definition
    n_episodes: int     # number of evaluation episodes run for this task
    mean_reward: float  # mean over episodes of each episode's best step reward
    best_reward: float  # maximum per-episode best reward observed
    n_solved: int  # episodes with reward >= 1.0
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
def evaluate(
    model_path: str = "./sql-debug-qwen-grpo",
    n_episodes: int = 10,
    max_steps: int = 5,
) -> List[EvalResult]:
    """
    Evaluate the trained model against all tasks.

    Loads the fine-tuned model from `model_path` and runs `n_episodes`
    greedy-decoded episodes per task, each capped at `max_steps` environment
    steps. Writes a JSON report under outputs/evals/ and returns one
    EvalResult per task.
    """
    print(f"\n{'='*60}\nEVALUATION — {model_path}\n{'='*60}")

    eval_model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    eval_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    eval_model.eval()

    env = get_env()
    results: List[EvalResult] = []

    for task in TASKS:
        episode_rewards = []
        n_solved = 0

        for ep in range(n_episodes):
            seed = 1000 + ep  # different seeds from training
            # Only the local env supports deterministic seeding; the remote
            # HTTP env takes just the task id.
            if USE_LOCAL_ENV:
                obs = env.reset(seed=seed, task_id=task.task_id)
            else:
                obs = env.reset(task_id=task.task_id)

            best_reward = 0.0
            done = False
            # Task 3 expects a multi-statement pipeline fix rather than a
            # single SELECT, so SQL extraction differs.
            is_pipeline = (task.task_id == "task3_etl_timezone")

            for step in range(max_steps):
                if done:
                    break

                messages = format_messages(obs)
                prompt_text = eval_tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
                inputs = eval_tokenizer(
                    prompt_text, return_tensors="pt", truncation=True, max_length=2048
                ).to(eval_model.device)

                with torch.no_grad():
                    # Greedy decoding for deterministic evaluation.
                    # BUGFIX: temperature=0.0 was removed — it is ignored (and
                    # warned about) by transformers when do_sample=False.
                    output_ids = eval_model.generate(
                        **inputs,
                        max_new_tokens=1024,
                        do_sample=False,
                        pad_token_id=eval_tokenizer.eos_token_id,
                    )

                prompt_len = inputs["input_ids"].shape[1]
                completion = eval_tokenizer.decode(
                    output_ids[0][prompt_len:], skip_special_tokens=True
                )

                fixed_sql = extract_sql_from_response(completion, is_pipeline=is_pipeline)
                explanation = extract_explanation(completion)
                action = SQLDebugAction(fixed_sql=fixed_sql, explanation=explanation)

                # BUGFIX: the original `env.step(action) if USE_LOCAL_ENV else
                # env.step(action)` had identical branches — dead conditional.
                obs, reward, done, info = env.step(action)

                best_reward = max(best_reward, reward)

            episode_rewards.append(best_reward)
            if best_reward >= 1.0:
                n_solved += 1

        # Guard the n_episodes == 0 edge case: no division by zero, no max()
        # on an empty sequence.
        mean_r = sum(episode_rewards) / len(episode_rewards) if episode_rewards else 0.0
        best_r = max(episode_rewards, default=0.0)

        result = EvalResult(
            task_id=task.task_id,
            difficulty=task.difficulty,
            n_episodes=n_episodes,
            mean_reward=round(mean_r, 4),
            best_reward=round(best_r, 4),
            n_solved=n_solved,
        )
        results.append(result)
        print(f"  {task.task_id:40s} mean={mean_r:.4f} best={best_r:.4f} "
              f"solved={n_solved}/{n_episodes}")

    # Write evaluation report
    report = {
        "model": model_path,
        "n_episodes": n_episodes,
        "tasks": [r.__dict__ for r in results],
    }
    os.makedirs("outputs/evals", exist_ok=True)
    report_path = f"outputs/evals/eval_{int(time.time())}.json"
    with open(report_path, "w") as f:
        json.dump(report, f, indent=2)
    print(f"\nEval report saved: {report_path}")

    return results
|
| 766 |
+
|
| 767 |
+
|
| 768 |
+
# =============================================================================
|
| 769 |
+
# ENTRY POINT
|
| 770 |
+
# =============================================================================
|
| 771 |
+
|
| 772 |
+
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(description="GRPO training for SQL Debug environment")
    cli.add_argument("--mode", choices=["train", "eval", "both"], default="train")
    cli.add_argument("--n-repeats", type=int, default=50, help="Dataset repeats per task")
    cli.add_argument("--n-episodes", type=int, default=10, help="Eval episodes per task")
    cli.add_argument("--output-dir", default="./sql-debug-qwen-grpo")
    args = cli.parse_args()

    # With choices restricted to {train, eval, both}, "not eval" is exactly
    # {train, both} and "not train" is exactly {eval, both}.
    if args.mode != "eval":
        trainer = train(n_repeats=args.n_repeats)
        save_and_push(trainer, output_dir=args.output_dir)

    if args.mode != "train":
        evaluate(model_path=args.output_dir, n_episodes=args.n_episodes)
|
train_rl.md
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# RL Training for SQL Debug β GRPO with Qwen2.5-Coder-7B-Instruct
|
| 2 |
+
|
| 3 |
+
> **Full training script:** [`train_grpo.py`](train_grpo.py)
|
| 4 |
+
> **HF Space deployment:** [`deploy_hf_space.md`](deploy_hf_space.md)
|
| 5 |
+
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
## Why GRPO, Not DDPG
|
| 9 |
+
|
| 10 |
+
| | DDPG | GRPO |
|
| 11 |
+
|---|---|---|
|
| 12 |
+
| Action space | Continuous R^n | Discrete tokens β
|
|
| 13 |
+
| Value network | Required | Not needed β
|
|
| 14 |
+
| Gradient signal | Bellman + actor-critic | Group relative ranking β
|
|
| 15 |
+
| Works for SQL? | β | β
|
|
| 16 |
+
|
| 17 |
+
DDPG is for robot control / trading. SQL token generation is discrete β **always use GRPO or PPO**.
|
| 18 |
+
|
| 19 |
+
---
|
| 20 |
+
|
| 21 |
+
## What `train_grpo.py` Contains
|
| 22 |
+
|
| 23 |
+
| Section | Description |
|
| 24 |
+
|---|---|
|
| 25 |
+
| 1. Environment | Local DuckDB env or HTTP client pointing at HF Space |
|
| 26 |
+
| 2. Model & Tokenizer | `Qwen/Qwen2.5-Coder-7B-Instruct`, left-padding |
|
| 27 |
+
| 3. System Prompt | Rules, Response Format (`<think>` + ```sql```), Strategy, Goal |
|
| 28 |
+
| 4. Helpers | `build_user_message()`, `extract_sql_from_response()`, `format_messages()` |
|
| 29 |
+
| 5. Rollout | `rollout_func()` β plays multi-turn episode, returns padded tensors |
|
| 30 |
+
| 6. Reward Fns | `reward_correctness`, `reward_format`, `reward_no_repetition` |
|
| 31 |
+
| 7. Dataset | 3 tasks Γ N repeats β HF `Dataset` with `prompt` + `task_id` columns |
|
| 32 |
+
| 8. GRPOConfig | A100-tuned: `num_generations=4`, `bf16=True`, `max_completion_length=1024` |
|
| 33 |
+
| 9. Trainer | `GRPOTrainer` with `reward_weights=[0.7, 0.2, 0.1]` |
|
| 34 |
+
| 10. Save & Push | `trainer.save_model()` + `push_to_hub()` when `PUSH_TO_HUB=true` |
|
| 35 |
+
| 11. Evaluation | Greedy decode, 10 episodes/task, JSON report in `outputs/evals/` |
|
| 36 |
+
|
| 37 |
+
---
|
| 38 |
+
|
| 39 |
+
## Quick Start
|
| 40 |
+
|
| 41 |
+
```powershell
|
| 42 |
+
# Install
|
| 43 |
+
pip install "trl>=0.12.0" "transformers>=4.45.0" "torch>=2.3.0" duckdb pandas pydantic
|
| 44 |
+
|
| 45 |
+
# Start local server (terminal 1)
|
| 46 |
+
uvicorn server.app:app --host 0.0.0.0 --port 7860
|
| 47 |
+
|
| 48 |
+
# Train (terminal 2)
|
| 49 |
+
python train_grpo.py --mode train --n-repeats 50
|
| 50 |
+
|
| 51 |
+
# Evaluate trained model
|
| 52 |
+
python train_grpo.py --mode eval --output-dir ./sql-debug-qwen-grpo
|
| 53 |
+
|
| 54 |
+
# Train + eval in one command
|
| 55 |
+
python train_grpo.py --mode both
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
---
|
| 59 |
+
|
| 60 |
+
## System Prompt Structure
|
| 61 |
+
|
| 62 |
+
```
|
| 63 |
+
RULES β what the agent must/must not do
|
| 64 |
+
RESPONSE FORMAT β <think>...</think> then ```sql...```
|
| 65 |
+
STRATEGY β task-specific hints (syntax / JOIN type / timezone)
|
| 66 |
+
GOAL β produce output matching the ground truth exactly
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
The `<think>` block is critical β it teaches chain-of-thought diagnosis before emitting the fix.
|
| 70 |
+
|
| 71 |
+
---
|
| 72 |
+
|
| 73 |
+
## Reward Weights
|
| 74 |
+
|
| 75 |
+
```python
|
| 76 |
+
reward_weights = [0.7, 0.2, 0.1]
|
| 77 |
+
# 0.7 Γ reward_correctness (dense 0.0β1.0 from grader)
|
| 78 |
+
# 0.2 Γ reward_format (<think> block + ```sql``` present)
|
| 79 |
+
# 0.1 Γ reward_no_repetition (penalty for trivial empty output)
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
---
|
| 83 |
+
|
| 84 |
+
## Expected Outcomes After Training
|
| 85 |
+
|
| 86 |
+
| Task | Before (GPT-4o-mini baseline) | After GRPO (estimated) |
|
| 87 |
+
|---|---|---|
|
| 88 |
+
| task1_syntax_fix | ~0.85 | ~0.95 |
|
| 89 |
+
| task2_join_aggregation | ~0.55 | ~0.75 |
|
| 90 |
+
| task3_etl_timezone | ~0.25 | ~0.50 |
|
| 91 |
+
|
| 92 |
+
Use curriculum (train on Task 1+2 first, then add Task 3) for better Hard task improvement.
|