md896 committed
Commit ceee0e3 · 1 Parent(s): ac3911c

Fix HF Jobs bootstrap (pin transformers/trl, drop torchao stack); add reward and trainer JSONL logging; stabilize launch_job.

Files changed (2)
  1. launch_job.py +86 -11
  2. ultimate_sota_training.py +339 -100
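Context for the bootstrap fix below: the pinned stack can be reproduced locally with a minimal sketch like this (default versions and the BOOTSTRAP_*_VERSION overrides are copied from bootstrap_deps in ultimate_sota_training.py; the snippet is illustrative and not part of the commit):

# Illustrative sketch: apply the commit's default pins outside HF Jobs.
import os
import subprocess
import sys

pins = [
    f"transformers=={os.environ.get('BOOTSTRAP_TRANSFORMERS_VERSION', '4.48.3')}",
    f"accelerate=={os.environ.get('BOOTSTRAP_ACCELERATE_VERSION', '0.34.2')}",
    f"trl=={os.environ.get('BOOTSTRAP_TRL_VERSION', '0.18.2')}",
]
subprocess.run([sys.executable, "-m", "pip", "install", *pins], check=True)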
launch_job.py CHANGED
@@ -1,13 +1,88 @@
  from huggingface_hub import HfApi
- api = HfApi()
- try:
-     job = api.create_compute_job(
-         namespace="md896",
-         flavor="a10g-small",
-         image="pytorch/pytorch:2.11.0-cuda12.8-cudnn9-devel",
-         command=["bash", "-c", "set -euxo pipefail; apt-get update; apt-get install -y git; git clone https://huggingface.co/spaces/md896/sql-debug-env; cd sql-debug-env; python -u ultimate_sota_training.py"],
-         secrets=["HF_TOKEN"]
      )
-     print("JOB_ID:", job.job_id)
- except Exception as e:
-     print("FAILED:", str(e))

+ """
+ Submit ultimate_sota_training.py to Hugging Face GPU Jobs (HfApi.run_job).
+
+ The Job command must be a single robust shell line (semicolon-separated). Hugging Face
+ has been observed to flatten multiline `bash -lc` payloads, which breaks `set` and can
+ leave the job stuck or failing silently.
+
+ Requires: huggingface_hub, `huggingface-cli login`.
+
+ Secrets: if SKIP_HUB_PUSH is not 1, the job requests Hub secret name HF_TOKEN mapped into
+ the container as env HF_TOKEN (Settings → Access Tokens / Job secrets).
+
+ Environment (optional):
+     HF_JOB_NAMESPACE      default: whoami
+     HF_JOB_FLAVOR         default: l4x1 (often faster than T4 for this workload; override with t4-small to save $)
+     HF_JOB_IMAGE          default: pytorch CUDA 12.4 devel
+     HF_JOB_TIMEOUT        default: 8h
+     TRAIN_REPO_GIT_URL, OPENENV_BASE_URL
+     TRAIN_MAX_STEPS       default: 80 (faster run; raise for stronger fit)
+     ROWS_PER_TASK         default: 32
+     GRPO_NUM_GENERATIONS  default: 6
+     SKIP_HUB_PUSH         default: 0
+ """
+ from __future__ import annotations
+
+ import os
+ import shlex
+
  from huggingface_hub import HfApi
+
+ _DEFAULT_REPO = "https://huggingface.co/spaces/md896/sql-debug-env"
+ _REPO_URL = os.environ.get("TRAIN_REPO_GIT_URL", _DEFAULT_REPO)
+ _OPENENV = os.environ.get("OPENENV_BASE_URL", "https://md896-sql-debug-env.hf.space")
+ _MAX_STEPS = os.environ.get("TRAIN_MAX_STEPS", "80")
+ _ROWS = os.environ.get("ROWS_PER_TASK", "32")
+ _NUM_GEN = os.environ.get("GRPO_NUM_GENERATIONS", "6")
+ _SKIP_PUSH = os.environ.get("SKIP_HUB_PUSH", "0")
+ _TIMEOUT = os.environ.get("HF_JOB_TIMEOUT", "8h")
+ # l4x1: newer GPU, good for Unsloth; use HF_JOB_FLAVOR=t4-small if queue or cost is better for you
+ _FLAVOR = os.environ.get("HF_JOB_FLAVOR", "l4x1")
+ _IMAGE = os.environ.get(
+     "HF_JOB_IMAGE",
+     "pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel",
+ )
+ _NAMESPACE = os.environ.get("HF_JOB_NAMESPACE")
+
+ _SECRETS = None
+ if _SKIP_PUSH.strip().lower() not in ("1", "true", "yes"):
+     _SECRETS = {"HF_TOKEN": "HF_TOKEN"}
+
+ # One line only — survives UI/API newline flattening.
+ _bash = (
+     "set -euxo pipefail; "
+     "export DEBIAN_FRONTEND=noninteractive; "
+     "apt-get update -qq && apt-get install -y -qq git ca-certificates; "
+     "export PIP_BREAK_SYSTEM_PACKAGES=1; "
+     f"rm -rf train-repo; git clone {shlex.quote(_REPO_URL)} train-repo; "
+     "cd train-repo; "
+     "python -u ultimate_sota_training.py"
+ )
+
+ _job_env = {
+     "OPENENV_BASE_URL": _OPENENV,
+     "TRAIN_MAX_STEPS": _MAX_STEPS,
+     "ROWS_PER_TASK": _ROWS,
+     "GRPO_NUM_GENERATIONS": _NUM_GEN,
+     "SKIP_HUB_PUSH": _SKIP_PUSH,
+ }
+
+ if __name__ == "__main__":
+     api = HfApi()
+     ns = _NAMESPACE or api.whoami()["name"]
+     job = api.run_job(
+         image=_IMAGE,
+         command=["bash", "-lc", _bash],
+         flavor=_FLAVOR,
+         namespace=ns,
+         timeout=_TIMEOUT,
+         secrets=_SECRETS,
+         env=_job_env,
+     )
+     print("JOB_ID:", job.id)
+     print("JOB_URL:", job.url)
+     print("FLAVOR:", _FLAVOR, "| TRAIN_MAX_STEPS:", _MAX_STEPS, "| ROWS_PER_TASK:", _ROWS)
+     print(
+         "Note: SCHEDULING is Hugging Face queue time, not your script. "
+         "Cancel stuck jobs and retry, or try HF_JOB_FLAVOR=t4-small / t4-medium."
      )
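After submission, the job can be followed from the same machine; a hedged sketch (it assumes the Jobs API in huggingface_hub >= 0.34, i.e. inspect_job / fetch_job_logs / cancel_job alongside run_job; verify the names against your installed version):

# Sketch only: monitor the job submitted above. The method names are
# assumptions based on the huggingface_hub Jobs API; confirm before use.
from huggingface_hub import HfApi

api = HfApi()
job_id = "..."  # paste the JOB_ID printed by launch_job.py

print(api.inspect_job(job_id=job_id).status)    # current stage, e.g. RUNNING
for line in api.fetch_job_logs(job_id=job_id):  # streams logs while the job runs
    print(line)
# api.cancel_job(job_id=job_id)  # escape hatch for jobs stuck in SCHEDULING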
 
 
 
ultimate_sota_training.py CHANGED
@@ -1,10 +1,27 @@
  """
- 🏆 Unsloth + OpenEnv GRPO training script
-
- Goal: produce *real* training evidence (reward curves + logs) and optionally push LoRA
- weights to the Hub.
-
- This script is designed to run inside Hugging Face Jobs/Spaces containers where:
  - system Python may be externally managed (PEP-668) → uses --break-system-packages
  - preinstalled CUDA/PyTorch stacks can conflict with optional vision packages

@@ -16,17 +33,23 @@ Key stability choices:

  from __future__ import annotations

  import json
  import os
  import random
  import re
  import subprocess
  import sys
  import time
  from dataclasses import dataclass
  from pathlib import Path
  from typing import Any, Dict, List, Optional


  def _run(cmd: List[str], *, check: bool = True) -> subprocess.CompletedProcess:
      return subprocess.run(cmd, check=check)
@@ -41,6 +64,7 @@ def bootstrap_deps() -> None:
      Best-effort dependency bootstrap for ephemeral HF containers.

      Set SKIP_BOOTSTRAP=1 to disable.
      """
      if os.environ.get("SKIP_BOOTSTRAP") == "1":
          return
@@ -53,31 +77,39 @@ def bootstrap_deps() -> None:
      # (PEP-668). Prefer an explicit opt-out for all pip ops in ephemeral jobs.
      os.environ.setdefault("PIP_BREAK_SYSTEM_PACKAGES", "1")

-     print("📦 Bootstrapping dependencies...")

      # Text-only run: torchvision/torchaudio are not required and are a common source
      # of crashes when torch versions shift in container images.
      _pip(["uninstall", "--break-system-packages", "-y", "torchvision", "torchaudio"], check=False)

-     # Keep these scoped; avoid blanket -U to reduce resolver churn.
      _pip(
          [
              "install",
              "--break-system-packages",
              "httpx>=0.27.0",
              "datasets>=3.4.1,<4.4.0",
-             "trl>=0.18.2,<=0.22.2",
-             "mergekit",
-             "llm-blender",
-             "weave",
-             "wandb",
              "matplotlib",
          ]
      )

-     # Unsloth (and its dependency set) can be fast-moving; install from git.
-     # Build isolation/resolution can sometimes change torch; removing torchvision
-     # above keeps transformers imports stable for text-only workloads.
      _pip(
          [
              "install",
@@ -86,10 +118,32 @@ def bootstrap_deps() -> None:
          ]
      )

-     # Some dependency resolution paths can reintroduce torchvision. Remove it
-     # again right before importing transformers/trl.
      _pip(["uninstall", "--break-system-packages", "-y", "torchvision", "torchaudio"], check=False)


  bootstrap_deps()

@@ -126,71 +180,146 @@ import transformers.utils.hub
  if not hasattr(transformers.utils.hub, "TRANSFORMERS_CACHE"):
      transformers.utils.hub.TRANSFORMERS_CACHE = "/tmp"

  from trl import GRPOConfig, GRPOTrainer
  from unsloth import FastLanguageModel

- # --- 1. CONFIGURATION ---
- # Using your permanent Hugging Face Space!
- BRIDGE_URL = "https://md896-sql-debug-env.hf.space"
- BYPASS_HEADERS = {}  # No longer needed for HF Spaces!

- # Using the massive 7B Coder model, but squeezing it into memory using Unsloth 4-bit!
- MODEL_NAME = "unsloth/Qwen2.5-Coder-7B-Instruct"

- # --- 2. THE XML FORMATTING PROMPT ---
- SYSTEM_PROMPT = """You are an elite SQL Database Administrator fixing a critical fan trap (Cartesian Explosion).
- You MUST output your reasoning process inside <think> tags.
  After you have finished thinking, you MUST output the exact fixed SQL query inside <sql> tags.
  Do not output any markdown blocks like ```sql.

  Example:
- <think>
- I need to aggregate the totals first using a CTE to avoid a Cartesian explosion.
- </think>
  <sql>
- WITH OrderTotals AS ( ... ) SELECT ...
  </sql>"""

- def make_real_dataset():
-     print(f"🔗 Connecting to Environment at {BRIDGE_URL}...")
-     tasks = ["hard_finance_explosion"]
-     rows = []
-
-     with httpx.Client(base_url=BRIDGE_URL, headers=BYPASS_HEADERS, timeout=30.0) as client:
-         for t_id in tasks:
              resp = client.post("/reset", json={"task_id": t_id})
              obs = resp.json()["observation"]
-
              prompt = (
-                 f"{SYSTEM_PROMPT}\n\n"
                  f"Task: {obs['task_description']}\n"
                  f"Broken Query: {obs['original_query']}\n\n"
-                 "Provide your <think> and <sql> output:"
              )
-             # Generate 40 identical starting states for the model to explore
-             for _ in range(40):
                  rows.append({"prompt": prompt, "task_id": t_id})
-
      if not rows:
-         raise RuntimeError("Failed to connect to environment!")
      return Dataset.from_list(rows)

- # --- 3. MULTI-REWARD SHAPING (The Secret Weapon) ---

  def extract_xml_tag(text, tag):
      pattern = f"<{tag}>(.*?)</{tag}>"
      match = re.search(pattern, text, re.DOTALL)
      return match.group(1).strip() if match else None

  def format_reward_func(completions, **kwargs):
-     """Reward 1: Did the model use <think> and <sql> tags? (+0.1)"""
      rewards = []
      for comp in completions:
-         has_think = extract_xml_tag(comp, "think") is not None
          has_sql = extract_xml_tag(comp, "sql") is not None
          rewards.append(0.1 if (has_think and has_sql) else 0.0)
      return rewards

  def syntax_reward_func(completions, **kwargs):
      """Reward 2: Does the SQL look like valid code? (+0.2)"""
      rewards = []
@@ -200,29 +329,85 @@ def syntax_reward_func(completions, **kwargs):
              rewards.append(0.2)
          else:
              rewards.append(0.0)
      return rewards

  def execution_reward_func(completions, task_id, **kwargs):
-     """Reward 3: The Ultimate Sandbox Test (+1.0)"""
-     rewards = []
-     with httpx.Client(base_url=BRIDGE_URL, headers=BYPASS_HEADERS, timeout=30.0) as client:
          for query, t_id in zip(completions, task_id):
              sql = extract_xml_tag(query, "sql")
              if not sql:
-                 rewards.append(0.0)
                  continue
-
              try:
-                 client.post("/reset", json={"task_id": t_id})
-                 resp = client.post("/step", json={"action": {"action_type": "submit_query", "query": sql}})
-                 reward = resp.json().get("reward", 0.0)
              except Exception:
                  reward = 0.0
-
-             reward += random.uniform(-1e-6, 1e-6)
              rewards.append(reward)
      return rewards

  # --- 3b. ARTIFACTS / PLOTS (REAL, FROM LOGS) ---

  @dataclass(frozen=True)
@@ -329,17 +514,42 @@ def plot_reward_curve(reward_series: List[tuple[float, float]], paths: ArtifactPaths):
      _ensure_dir(paths.root)
      plt.tight_layout()
      plt.savefig(paths.reward_curve_png, dpi=200)
-     print(f"Saved {paths.reward_curve_png}")


- # --- 4. THE UNSLOTH + DEEPSEEK-R1 TRAINING LOOP ---
  def run_sota_train():
-     print(f"🚀 Starting Unsloth GRPO on {MODEL_NAME}...")
-
-     # LOAD WITH UNSLOTH 4-BIT QUANTIZATION (2X FASTER, 70% LESS MEMORY)
      model, tokenizer = FastLanguageModel.from_pretrained(
          model_name=MODEL_NAME,
-         max_seq_length=1024,
          load_in_4bit=True,
      )

@@ -357,10 +567,7 @@ def run_sota_train():

      def quick_exec_eval(max_items: int = 8) -> float:
          """
-         Quick before/after check:
-         - sample a few prompts
-         - generate <think>/<sql>
-         - score via live execution reward
          """
          subset = train_dataset.select(range(min(max_items, len(train_dataset))))
          prompts = subset["prompt"]
@@ -382,39 +589,60 @@ def run_sota_train():
          rewards = execution_reward_func(completions, task_ids)
          return float(sum(rewards) / max(len(rewards), 1))

-     print("📏 Quick baseline eval (pre-train)...")
      baseline_avg_reward = quick_exec_eval()

-     training_args = GRPOConfig(
-         output_dir="./sota_results",
-         learning_rate=5e-6,
-         per_device_train_batch_size=1,
-         gradient_accumulation_steps=2,
-         num_generations=8,
-         max_completion_length=400,  # Lots of room for <think> and <sql> CTEs
-         temperature=0.9,  # Forces creative exploration
-         num_train_epochs=1,
-         max_steps=30,
-         logging_steps=1,
-         report_to="none"
      )

      trainer = GRPOTrainer(
          model=model,
-         reward_funcs=[format_reward_func, syntax_reward_func, execution_reward_func],
          args=training_args,
          train_dataset=train_dataset,
          processing_class=tokenizer,
      )

-     print("🧠 SOTA Sandbox Active. Let the RL begin...")
      trainer.train()

-     print("📏 Quick eval (post-train)...")
      post_avg_reward = quick_exec_eval()

      # --- Save artifacts (real logs/plots) ---
-     artifacts = ArtifactPaths(root=Path("./sota_results/artifacts"))
      log_history = getattr(trainer.state, "log_history", []) or []
      save_log_history(log_history, artifacts)
      reward_series = extract_reward_series(log_history)
@@ -427,9 +655,16 @@ def run_sota_train():
      metrics = {}
      metrics.update(
          {
              "baseline_avg_reward": baseline_avg_reward,
              "post_avg_reward": post_avg_reward,
              "delta_avg_reward": post_avg_reward - baseline_avg_reward,
          }
      )
      metrics_path.write_text(json.dumps(metrics, indent=2), encoding="utf-8")
@@ -447,22 +682,26 @@ def run_sota_train():
          out_path = artifacts.root / "before_after_avg_reward.png"
          plt.tight_layout()
          plt.savefig(out_path, dpi=200)
-         print(f"Saved {out_path}")
      except Exception as e:
-         print(f"⚠️ Could not generate before/after plot: {e}")
-
-     print("\n💾 Saving and (optionally) pushing LoRA weights...")
-     model.save_pretrained("./sota_sql_agent_unsloth")
-
-     # CRITICAL: Since you are running on HF Jobs, the server deletes everything when it finishes.
-     # We MUST push the weights to your account so you can actually use them!
-     try:
-         model.push_to_hub("md896/sota-sql-agent-7b", token=os.environ.get("HF_TOKEN"))
-         print("✅ Successfully pushed to https://huggingface.co/md896/sota-sql-agent-7b")
-     except Exception as e:
-         print(f"⚠️ Could not push to hub. Make sure HF_TOKEN is set. Error: {e}")
-
-     print("\n📊 Training artifacts saved under ./sota_results/artifacts")

  if __name__ == "__main__":
      run_sota_train()

  """
+ Unsloth + OpenEnv GRPO training (production-oriented).
+
+ Produces real training artifacts (trainer log_history, metrics JSON, reward plots) and
+ optional Hub push of LoRA weights. Every execution reward calls your live Space (or
+ local server) at OPENENV_BASE_URL — not a mock.
+
+ Environment (control cost vs quality on HF Jobs / local GPU):
+     OPENENV_BASE_URL — OpenEnv HTTP root (default: Space URL from openenv.yaml)
+     OPENENV_TASK_IDS — Comma-separated list; if unset, uses GET /tasks from the server
+     ROWS_PER_TASK — GRPO rows per task_id (default: 48)
+     OPENENV_REQUEST_TIMEOUT_SEC — HTTP timeout for reset/step (default: 120)
+     REASONING_XML_TAG — XML tag name for chain-of-thought (default: think)
+     TRAIN_MAX_STEPS — GRPO optimizer steps (default: 200; was 30 for smoke)
+     TRAIN_NUM_EPOCHS, TRAIN_LR, GRPO_NUM_GENERATIONS, GRPO_MAX_COMPLETION_LEN
+     PER_DEVICE_TRAIN_BS, GRAD_ACCUM
+     TRL_REPORT_TO — none | wandb | tensorboard (auto: wandb if key else tensorboard)
+     BOOTSTRAP_*_VERSION — pin transformers / accelerate / trl for HF Jobs (see bootstrap_deps)
+     Artifacts: artifacts/reward_components.jsonl, artifacts/trainer_on_log.jsonl, tensorboard/
+     HF_HUB_REPO_ID — push target (default md896/sota-sql-agent-7b)
+     SKIP_HUB_PUSH=1 — do not push after train
+     HF_TOKEN / HUGGING_FACE_HUB_TOKEN — Hub auth for push
+
+ Designed for Hugging Face Jobs / Spaces where:
  - system Python may be externally managed (PEP-668) → uses --break-system-packages
  - preinstalled CUDA/PyTorch stacks can conflict with optional vision packages


  from __future__ import annotations

+ import contextvars
  import json
+ import math
  import os
  import random
  import re
  import subprocess
  import sys
  import time
+ import uuid
  from dataclasses import dataclass
  from pathlib import Path
  from typing import Any, Dict, List, Optional

+ # Set by TrainerCallback so reward funcs can tag JSONL rows with the real global_step.
+ CURRENT_GRPO_STEP: contextvars.ContextVar[int] = contextvars.ContextVar("CURRENT_GRPO_STEP", default=-1)
+

  def _run(cmd: List[str], *, check: bool = True) -> subprocess.CompletedProcess:
      return subprocess.run(cmd, check=check)

      Best-effort dependency bootstrap for ephemeral HF containers.

      Set SKIP_BOOTSTRAP=1 to disable.
+     Pins: BOOTSTRAP_TRANSFORMERS_VERSION, BOOTSTRAP_ACCELERATE_VERSION, BOOTSTRAP_TRL_VERSION.
      """
      if os.environ.get("SKIP_BOOTSTRAP") == "1":
          return

      # (PEP-668). Prefer an explicit opt-out for all pip ops in ephemeral jobs.
      os.environ.setdefault("PIP_BREAK_SYSTEM_PACKAGES", "1")

+     print("Bootstrapping dependencies...")

      # Text-only run: torchvision/torchaudio are not required and are a common source
      # of crashes when torch versions shift in container images.
      _pip(["uninstall", "--break-system-packages", "-y", "torchvision", "torchaudio"], check=False)

+     _pip(["uninstall", "-y", "torchao"], check=False)
+
      _pip(
          [
              "install",
              "--break-system-packages",
              "httpx>=0.27.0",
              "datasets>=3.4.1,<4.4.0",
              "matplotlib",
+             "tensorboard",
+             "wandb",
+         ]
+     )
+
+     _tf = os.environ.get("BOOTSTRAP_TRANSFORMERS_VERSION", "4.48.3")
+     _acc = os.environ.get("BOOTSTRAP_ACCELERATE_VERSION", "0.34.2")
+     _trl = os.environ.get("BOOTSTRAP_TRL_VERSION", "0.18.2")
+     _pip(
+         [
+             "install",
+             "--break-system-packages",
+             f"transformers=={_tf}",
+             f"accelerate=={_acc}",
+             f"trl=={_trl}",
          ]
      )

      _pip(
          [
              "install",

          ]
      )

+     _pip(
+         [
+             "install",
+             "--break-system-packages",
+             "--force-reinstall",
+             "--no-deps",
+             f"transformers=={_tf}",
+             f"accelerate=={_acc}",
+         ]
+     )
+     _pip(["install", "--break-system-packages", "--no-deps", f"trl=={_trl}"])
+
+     _pip(["uninstall", "-y", "torchao"], check=False)
      _pip(["uninstall", "--break-system-packages", "-y", "torchvision", "torchaudio"], check=False)

+     try:
+         import accelerate  # noqa: F401
+         import transformers  # noqa: F401
+         from trl import GRPOConfig as _BootstrapGRPOConfig  # noqa: F401
+
+         _ = _BootstrapGRPOConfig
+     except Exception as e:
+         raise RuntimeError(
+             "Post-bootstrap import check failed. Adjust BOOTSTRAP_*_VERSION or SKIP_BOOTSTRAP=1."
+         ) from e
+

  bootstrap_deps()

  if not hasattr(transformers.utils.hub, "TRANSFORMERS_CACHE"):
      transformers.utils.hub.TRANSFORMERS_CACHE = "/tmp"

+ from transformers import TrainerCallback
  from trl import GRPOConfig, GRPOTrainer
  from unsloth import FastLanguageModel

+ # --- 1. CONFIGURATION (env-first; defaults match openenv.yaml) ---
+ _DEFAULT_OPENENV_BASE = "https://md896-sql-debug-env.hf.space"
+ BYPASS_HEADERS: Dict[str, str] = {}
+
+ MODEL_NAME = os.environ.get("TRAIN_MODEL_NAME", "unsloth/Qwen2.5-Coder-7B-Instruct")
+

+ def get_bridge_url() -> str:
+     return os.environ.get("OPENENV_BASE_URL", _DEFAULT_OPENENV_BASE).rstrip("/")

+
+ def get_request_timeout() -> float:
+     return float(os.environ.get("OPENENV_REQUEST_TIMEOUT_SEC", "120"))
+
+
+ def build_system_prompt() -> str:
+     """Single prompt template for every task (easy → expert); tag name is configurable."""
+     tag = os.environ.get("REASONING_XML_TAG", "think")
+     return f"""You are an elite SQL engineer. You fix broken SQLite analytics queries using the task description and the broken query.
+ You MUST output your reasoning process inside <{tag}> tags.
  After you have finished thinking, you MUST output the exact fixed SQL query inside <sql> tags.
  Do not output any markdown blocks like ```sql.

  Example:
+ <{tag}>
+ I will check joins, filters, and aggregation, then write a corrected SELECT or WITH query.
+ </{tag}>
  <sql>
+ WITH OrderTotals AS (SELECT order_id, SUM(amount) AS total FROM line_items GROUP BY order_id)
+ SELECT o.id, ot.total FROM orders o JOIN OrderTotals ot ON o.id = ot.order_id;
  </sql>"""

+
+ def _fetch_task_ids(client: httpx.Client) -> List[str]:
+     raw = os.environ.get("OPENENV_TASK_IDS", "").strip()
+     if raw:
+         return [x.strip() for x in raw.split(",") if x.strip()]
+     r = client.get("/tasks", timeout=get_request_timeout())
+     r.raise_for_status()
+     body = r.json()
+     tasks = body.get("tasks") or []
+     ids = [t["task_id"] for t in tasks if t.get("task_id")]
+     if not ids:
+         raise RuntimeError("/tasks returned no task_id entries")
+     return ids
+
+
+ def make_real_dataset() -> Dataset:
+     bridge = get_bridge_url()
+     timeout = get_request_timeout()
+     rows_per_task = max(1, int(os.environ.get("ROWS_PER_TASK", "48")))
+     system = build_system_prompt()
+
+     print(f"Connecting to OpenEnv at {bridge} (timeout={timeout}s)...")
+     rows: List[Dict[str, Any]] = []
+
+     with httpx.Client(base_url=bridge, headers=BYPASS_HEADERS, timeout=timeout) as client:
+         h = client.get("/health", timeout=min(30.0, timeout))
+         h.raise_for_status()
+         print(f"OpenEnv health: {h.json()}")
+
+         task_ids = _fetch_task_ids(client)
+         print(f"Training task_ids ({len(task_ids)}): {task_ids}")
+
+         for t_id in task_ids:
              resp = client.post("/reset", json={"task_id": t_id})
+             resp.raise_for_status()
              obs = resp.json()["observation"]
+
              prompt = (
+                 f"{system}\n\n"
                  f"Task: {obs['task_description']}\n"
                  f"Broken Query: {obs['original_query']}\n\n"
+                 f"Provide your <{os.environ.get('REASONING_XML_TAG', 'think')}> and <sql> output:"
              )
+             for _ in range(rows_per_task):
                  rows.append({"prompt": prompt, "task_id": t_id})
+
      if not rows:
+         raise RuntimeError("Failed to build dataset (no rows).")
+     print(f"Dataset: {len(rows)} prompts ({rows_per_task} per task).")
      return Dataset.from_list(rows)

+ # --- 3. MULTI-REWARD SHAPING + JSONL logging (per-component batch stats) ---
+
+ _REWARD_COMPONENTS_JSONL: Optional[Path] = None
+

  def extract_xml_tag(text, tag):
      pattern = f"<{tag}>(.*?)</{tag}>"
      match = re.search(pattern, text, re.DOTALL)
      return match.group(1).strip() if match else None

+
+ def _reward_batch_stats(values: List[float]) -> Dict[str, float]:
+     if not values:
+         return {"mean": 0.0, "std": 0.0, "min": 0.0, "max": 0.0}
+     n = len(values)
+     mean = sum(values) / n
+     var = sum((x - mean) ** 2 for x in values) / max(n - 1, 1)
+     return {"mean": mean, "std": math.sqrt(var), "min": min(values), "max": max(values)}
+
+
+ def _append_jsonl(path: Path, row: Dict[str, Any]) -> None:
+     path.parent.mkdir(parents=True, exist_ok=True)
+     with path.open("a", encoding="utf-8") as f:
+         f.write(json.dumps(row, ensure_ascii=False, default=str) + "\n")
+
+
+ def _log_reward_component(name: str, values: List[float]) -> None:
+     if _REWARD_COMPONENTS_JSONL is None:
+         return
+     _append_jsonl(
+         _REWARD_COMPONENTS_JSONL,
+         {
+             "time_epoch_s": time.time(),
+             "global_step": CURRENT_GRPO_STEP.get(),
+             "reward_component": name,
+             "n": len(values),
+             **_reward_batch_stats(values),
+         },
+     )
+
+
  def format_reward_func(completions, **kwargs):
+     """Reward 1: CoT + sql XML tags (+0.1). Tag name follows REASONING_XML_TAG."""
+     tag = os.environ.get("REASONING_XML_TAG", "think")
      rewards = []
      for comp in completions:
+         has_think = extract_xml_tag(comp, tag) is not None
          has_sql = extract_xml_tag(comp, "sql") is not None
          rewards.append(0.1 if (has_think and has_sql) else 0.0)
+     _log_reward_component("format_xml", rewards)
      return rewards

+
  def syntax_reward_func(completions, **kwargs):
      """Reward 2: Does the SQL look like valid code? (+0.2)"""
      rewards = []

              rewards.append(0.2)
          else:
              rewards.append(0.0)
+     _log_reward_component("syntax_select_with", rewards)
      return rewards

+
  def execution_reward_func(completions, task_id, **kwargs):
+     """Reward 3: live OpenEnv submit_query against the real Space/API (not a stub)."""
+     rewards: List[float] = []
+     base = get_bridge_url()
+     timeout = get_request_timeout()
+     with httpx.Client(base_url=base, headers=BYPASS_HEADERS, timeout=timeout) as client:
          for query, t_id in zip(completions, task_id):
              sql = extract_xml_tag(query, "sql")
              if not sql:
+                 rewards.append(0.0)
                  continue
+
+             session_headers = {"X-Session-Id": str(uuid.uuid4())}
              try:
+                 r0 = client.post("/reset", json={"task_id": t_id}, headers=session_headers)
+                 r0.raise_for_status()
+                 resp = client.post(
+                     "/step",
+                     json={"action": {"action_type": "submit_query", "query": sql}},
+                     headers=session_headers,
+                 )
+                 resp.raise_for_status()
+                 reward = float(resp.json().get("reward", 0.0))
              except Exception:
                  reward = 0.0
+
+             reward += random.uniform(-1e-6, 1e-6)
              rewards.append(reward)
+     _log_reward_component("openenv_execution", rewards)
      return rewards

+
+ def length_shape_reward_func(completions, **kwargs):
+     """Reward 4: soft preference for shorter completions (bounded; does not replace execution reward)."""
+     cap = float(os.environ.get("COMPLETION_SOFT_CHAR_CAP", "3500"))
+     bonus_max = float(os.environ.get("LENGTH_BONUS_MAX", "0.05"))
+     rewards: List[float] = []
+     for comp in completions:
+         L = len(comp) if comp else 0
+         if L <= 0:
+             rewards.append(0.0)
+         else:
+             rewards.append(bonus_max * max(0.0, 1.0 - min(L, cap) / cap))
+     _log_reward_component("length_shape", rewards)
+     return rewards
+
+
+ class GrpoStepContextCallback(TrainerCallback):
+     """Expose true global_step to reward funcs for JSONL alignment."""
+
+     def on_step_begin(self, args, state, control, **kwargs):
+         CURRENT_GRPO_STEP.set(int(state.global_step))
+
+
+ class JsonlOnLogCallback(TrainerCallback):
+     """Mirror every trainer `logs` dict to JSONL (loss, learning_rate, reward keys, etc.)."""
+
+     def __init__(self, path: Path):
+         self.path = path
+         self.path.parent.mkdir(parents=True, exist_ok=True)
+         self._fp = path.open("w", encoding="utf-8")
+
+     def on_log(self, args, state, control, logs=None, **kwargs):
+         if not logs:
+             return
+         row: Dict[str, Any] = {"global_step": int(state.global_step), **dict(logs)}
+         self._fp.write(json.dumps(row, ensure_ascii=False, default=str) + "\n")
+         self._fp.flush()
+
+     def on_train_end(self, args, state, control, **kwargs):
+         try:
+             self._fp.close()
+         except Exception:
+             pass
+
  # --- 3b. ARTIFACTS / PLOTS (REAL, FROM LOGS) ---

  @dataclass(frozen=True)

      _ensure_dir(paths.root)
      plt.tight_layout()
      plt.savefig(paths.reward_curve_png, dpi=200)
+     print(f"Saved {paths.reward_curve_png}")
+

+ def _resolve_report_to() -> str:
+     raw = os.environ.get("TRL_REPORT_TO", "").strip().lower()
+     if raw in ("", "auto"):
+         if os.environ.get("WANDB_API_KEY"):
+             return "wandb"
+         return "tensorboard"
+     if raw in ("false", "no", "off", "none"):
+         return "none"
+     return raw

+
+ # --- 4. Unsloth GRPO training loop (live OpenEnv rewards) ---
  def run_sota_train():
+     global _REWARD_COMPONENTS_JSONL
+
+     max_steps = int(os.environ.get("TRAIN_MAX_STEPS", "200"))
+     out_dir = os.environ.get("OUTPUT_DIR", "./sota_results")
+     artifacts_early = Path(out_dir) / "artifacts"
+     _ensure_dir(artifacts_early)
+     _REWARD_COMPONENTS_JSONL = artifacts_early / "reward_components.jsonl"
+     _REWARD_COMPONENTS_JSONL.write_text("", encoding="utf-8")
+
+     print(f"Starting Unsloth GRPO on {MODEL_NAME}...")
+     print(
+         f"OpenEnv={get_bridge_url()} | max_steps={max_steps} | "
+         f"rows_per_task={os.environ.get('ROWS_PER_TASK', '48')} | "
+         f"report_to={_resolve_report_to()}"
+     )
+
+     max_seq = int(os.environ.get("MAX_SEQ_LENGTH", "1024"))
      model, tokenizer = FastLanguageModel.from_pretrained(
          model_name=MODEL_NAME,
+         max_seq_length=max_seq,
          load_in_4bit=True,
      )


      def quick_exec_eval(max_items: int = 8) -> float:
          """
+         Quick before/after check: sample prompts, generate CoT + sql, score via live OpenEnv.
          """
          subset = train_dataset.select(range(min(max_items, len(train_dataset))))
          prompts = subset["prompt"]

          rewards = execution_reward_func(completions, task_ids)
          return float(sum(rewards) / max(len(rewards), 1))

+     print("Quick baseline eval (pre-train)...")
      baseline_avg_reward = quick_exec_eval()

+     report_to = _resolve_report_to()
+     tb_dir = Path(out_dir) / "tensorboard"
+     if report_to == "tensorboard":
+         _ensure_dir(tb_dir)
+
+     _cfg: Dict[str, Any] = dict(
+         output_dir=out_dir,
+         learning_rate=float(os.environ.get("TRAIN_LR", "5e-6")),
+         per_device_train_batch_size=int(os.environ.get("PER_DEVICE_TRAIN_BS", "1")),
+         gradient_accumulation_steps=int(os.environ.get("GRAD_ACCUM", "2")),
+         num_generations=int(os.environ.get("GRPO_NUM_GENERATIONS", "8")),
+         max_completion_length=int(os.environ.get("GRPO_MAX_COMPLETION_LEN", "512")),
+         temperature=float(os.environ.get("GRPO_TEMPERATURE", "0.9")),
+         num_train_epochs=int(os.environ.get("TRAIN_NUM_EPOCHS", "1")),
+         max_steps=max_steps,
+         logging_steps=int(os.environ.get("LOGGING_STEPS", "1")),
+         logging_first_step=True,
+         report_to=report_to,
      )
+     if report_to == "tensorboard":
+         _cfg["logging_dir"] = str(tb_dir)
+     training_args = GRPOConfig(**_cfg)
+
+     trainer_logs_path = artifacts_early / "trainer_on_log.jsonl"
+     trainer_logs_path.write_text("", encoding="utf-8")

      trainer = GRPOTrainer(
          model=model,
+         reward_funcs=[
+             format_reward_func,
+             syntax_reward_func,
+             execution_reward_func,
+             length_shape_reward_func,
+         ],
          args=training_args,
          train_dataset=train_dataset,
          processing_class=tokenizer,
+         callbacks=[
+             GrpoStepContextCallback(),
+             JsonlOnLogCallback(trainer_logs_path),
+         ],
      )

+     print("Training with live execution rewards against OpenEnv...")
      trainer.train()

+     print("Quick eval (post-train)...")
      post_avg_reward = quick_exec_eval()

      # --- Save artifacts (real logs/plots) ---
+     artifacts = ArtifactPaths(root=Path(out_dir) / "artifacts")
      log_history = getattr(trainer.state, "log_history", []) or []
      save_log_history(log_history, artifacts)
      reward_series = extract_reward_series(log_history)

      metrics = {}
      metrics.update(
          {
+             "openenv_base_url": get_bridge_url(),
+             "train_max_steps": max_steps,
+             "model_name": MODEL_NAME,
              "baseline_avg_reward": baseline_avg_reward,
              "post_avg_reward": post_avg_reward,
              "delta_avg_reward": post_avg_reward - baseline_avg_reward,
+             "reward_components_jsonl": str(artifacts_early / "reward_components.jsonl"),
+             "trainer_on_log_jsonl": str(artifacts_early / "trainer_on_log.jsonl"),
+             "tensorboard_dir": str(tb_dir) if report_to == "tensorboard" else None,
+             "report_to": report_to,
          }
      )
      metrics_path.write_text(json.dumps(metrics, indent=2), encoding="utf-8")

          out_path = artifacts.root / "before_after_avg_reward.png"
          plt.tight_layout()
          plt.savefig(out_path, dpi=200)
+         print(f"Saved {out_path}")
      except Exception as e:
+         print(f"Could not generate before/after plot: {e}")
+
+     lora_dir = os.environ.get("LORA_SAVE_DIR", "./sota_sql_agent_unsloth")
+     print("\nSaving LoRA weights locally...")
+     model.save_pretrained(lora_dir)
+
+     hub_id = os.environ.get("HF_HUB_REPO_ID", "md896/sota-sql-agent-7b")
+     token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
+     if os.environ.get("SKIP_HUB_PUSH", "").strip() in ("1", "true", "yes"):
+         print("SKIP_HUB_PUSH set; not pushing to Hub.")
+     else:
+         try:
+             model.push_to_hub(hub_id, token=token)
+             print(f"Pushed LoRA to https://huggingface.co/{hub_id}")
+         except Exception as e:
+             print(f"Hub push failed (set HF_TOKEN / HF_HUB_REPO_ID or SKIP_HUB_PUSH=1): {e}")
+
+     print(f"\nTraining artifacts under {artifacts.root}")

  if __name__ == "__main__":
      run_sota_train()
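Since artifacts/reward_components.jsonl is line-delimited JSON, the per-component reward curves can be rebuilt after a run without the trainer; a minimal sketch, assuming the default OUTPUT_DIR (./sota_results) and the field names written by _log_reward_component above:

# Sketch: plot mean reward per component from reward_components.jsonl.
import json
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt

series = defaultdict(list)  # reward_component -> [(global_step, mean), ...]
path = Path("./sota_results/artifacts/reward_components.jsonl")
for line in path.read_text(encoding="utf-8").splitlines():
    if not line.strip():
        continue
    row = json.loads(line)
    series[row["reward_component"]].append((row["global_step"], row["mean"]))

for name, points in sorted(series.items()):
    points.sort()
    plt.plot([s for s, _ in points], [m for _, m in points], label=name)
plt.xlabel("global_step")
plt.ylabel("mean reward (per batch)")
plt.legend()
plt.savefig("reward_components.png", dpi=200)

The same pattern applies to artifacts/trainer_on_log.jsonl, whose rows mirror the trainer's on_log dicts (loss, learning_rate, reward keys) keyed by global_step.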