md896 commited on
Commit
ee30276
Β·
1 Parent(s): ceee0e3

Fix HF Job bootstrap: transformers>=4.51 for trl 0.18, datasets<4; simplify to colab-style OpenEnv SQL reward.

Browse files
Files changed (1) hide show
  1. ultimate_sota_training.py +39 -201
ultimate_sota_training.py CHANGED
@@ -10,13 +10,10 @@ Environment (control cost vs quality on HF Jobs / local GPU):
10
  OPENENV_TASK_IDS β€” Comma list; if unset, uses GET /tasks from the server
11
  ROWS_PER_TASK β€” GRPO rows per task_id (default: 48)
12
  OPENENV_REQUEST_TIMEOUT_SEC β€” HTTP timeout for reset/step (default: 120)
13
- REASONING_XML_TAG β€” XML tag name for chain-of-thought (default: think)
14
- TRAIN_MAX_STEPS β€” GRPO optimizer steps (default: 200; was 30 for smoke)
15
- TRAIN_NUM_EPOCHS, TRAIN_LR, GRPO_NUM_GENERATIONS, GRPO_MAX_COMPLETION_LEN
16
- PER_DEVICE_TRAIN_BS, GRAD_ACCUM
17
- TRL_REPORT_TO β€” none | wandb | tensorboard (auto: wandb if key else tensorboard)
18
- BOOTSTRAP_*_VERSION β€” pin transformers / accelerate / trl for HF Jobs (see bootstrap_deps)
19
- Artifacts: artifacts/reward_components.jsonl, artifacts/trainer_on_log.jsonl, tensorboard/
20
  HF_HUB_REPO_ID β€” push target (default md896/sota-sql-agent-7b)
21
  SKIP_HUB_PUSH=1 β€” do not push after train
22
  HF_TOKEN / HUGGING_FACE_HUB_TOKEN β€” Hub auth for push
@@ -33,12 +30,9 @@ Key stability choices:
33
 
34
  from __future__ import annotations
35
 
36
- import contextvars
37
  import json
38
- import math
39
  import os
40
  import random
41
- import re
42
  import subprocess
43
  import sys
44
  import time
@@ -47,10 +41,6 @@ from dataclasses import dataclass
47
  from pathlib import Path
48
  from typing import Any, Dict, List, Optional
49
 
50
- # Set by TrainerCallback so reward funcs can tag JSONL rows with the real global_step.
51
- CURRENT_GRPO_STEP: contextvars.ContextVar[int] = contextvars.ContextVar("CURRENT_GRPO_STEP", default=-1)
52
-
53
-
54
  def _run(cmd: List[str], *, check: bool = True) -> subprocess.CompletedProcess:
55
  return subprocess.run(cmd, check=check)
56
 
@@ -82,34 +72,29 @@ def bootstrap_deps() -> None:
82
  # Text-only run: torchvision/torchaudio are not required and are a common source
83
  # of crashes when torch versions shift in container images.
84
  _pip(["uninstall", "--break-system-packages", "-y", "torchvision", "torchaudio"], check=False)
85
-
86
  _pip(["uninstall", "-y", "torchao"], check=False)
87
 
 
 
 
 
88
  _pip(
89
  [
90
  "install",
91
  "--break-system-packages",
92
  "httpx>=0.27.0",
93
- "datasets>=3.4.1,<4.4.0",
94
  "matplotlib",
95
  "tensorboard",
96
- "wandb",
97
- ]
98
- )
99
-
100
- _tf = os.environ.get("BOOTSTRAP_TRANSFORMERS_VERSION", "4.48.3")
101
- _acc = os.environ.get("BOOTSTRAP_ACCELERATE_VERSION", "0.34.2")
102
- _trl = os.environ.get("BOOTSTRAP_TRL_VERSION", "0.18.2")
103
- _pip(
104
- [
105
- "install",
106
- "--break-system-packages",
107
  f"transformers=={_tf}",
108
  f"accelerate=={_acc}",
109
  f"trl=={_trl}",
110
  ]
111
  )
112
 
 
 
 
113
  _pip(
114
  [
115
  "install",
@@ -180,7 +165,6 @@ import transformers.utils.hub
180
  if not hasattr(transformers.utils.hub, "TRANSFORMERS_CACHE"):
181
  transformers.utils.hub.TRANSFORMERS_CACHE = "/tmp"
182
 
183
- from transformers import TrainerCallback
184
  from trl import GRPOConfig, GRPOTrainer
185
  from unsloth import FastLanguageModel
186
 
@@ -199,24 +183,6 @@ def get_request_timeout() -> float:
199
  return float(os.environ.get("OPENENV_REQUEST_TIMEOUT_SEC", "120"))
200
 
201
 
202
- def build_system_prompt() -> str:
203
- """Single prompt template for every task (easy β†’ expert); tag name is configurable."""
204
- tag = os.environ.get("REASONING_XML_TAG", "think")
205
- return f"""You are an elite SQL engineer. You fix broken SQLite analytics queries using the task description and the broken query.
206
- You MUST output your reasoning process inside <{tag}> tags.
207
- After you have finished thinking, you MUST output the exact fixed SQL query inside <sql> tags.
208
- Do not output any markdown blocks like ```sql.
209
-
210
- Example:
211
- <{tag}>
212
- I will check joins, filters, and aggregation, then write a corrected SELECT or WITH query.
213
- </{tag}>
214
- <sql>
215
- WITH OrderTotals AS (SELECT order_id, SUM(amount) AS total FROM line_items GROUP BY order_id)
216
- SELECT o.id, ot.total FROM orders o JOIN OrderTotals ot ON o.id = ot.order_id;
217
- </sql>"""
218
-
219
-
220
  def _fetch_task_ids(client: httpx.Client) -> List[str]:
221
  raw = os.environ.get("OPENENV_TASK_IDS", "").strip()
222
  if raw:
@@ -232,10 +198,11 @@ def _fetch_task_ids(client: httpx.Client) -> List[str]:
232
 
233
 
234
  def make_real_dataset() -> Dataset:
 
235
  bridge = get_bridge_url()
236
  timeout = get_request_timeout()
237
  rows_per_task = max(1, int(os.environ.get("ROWS_PER_TASK", "48")))
238
- system = build_system_prompt()
239
 
240
  print(f"Connecting to OpenEnv at {bridge} (timeout={timeout}s)...")
241
  rows: List[Dict[str, Any]] = []
@@ -254,10 +221,10 @@ def make_real_dataset() -> Dataset:
254
  obs = resp.json()["observation"]
255
 
256
  prompt = (
257
- f"{system}\n\n"
258
  f"Task: {obs['task_description']}\n"
259
- f"Broken Query: {obs['original_query']}\n\n"
260
- f"Provide your <{os.environ.get('REASONING_XML_TAG', 'think')}> and <sql> output:"
261
  )
262
  for _ in range(rows_per_task):
263
  rows.append({"prompt": prompt, "task_id": t_id})
@@ -267,147 +234,42 @@ def make_real_dataset() -> Dataset:
267
  print(f"Dataset: {len(rows)} prompts ({rows_per_task} per task).")
268
  return Dataset.from_list(rows)
269
 
270
- # --- 3. MULTI-REWARD SHAPING + JSONL logging (per-component batch stats) ---
271
-
272
- _REWARD_COMPONENTS_JSONL: Optional[Path] = None
273
-
274
-
275
- def extract_xml_tag(text, tag):
276
- pattern = f"<{tag}>(.*?)</{tag}>"
277
- match = re.search(pattern, text, re.DOTALL)
278
- return match.group(1).strip() if match else None
279
 
 
280
 
281
- def _reward_batch_stats(values: List[float]) -> Dict[str, float]:
282
- if not values:
283
- return {"mean": 0.0, "std": 0.0, "min": 0.0, "max": 0.0}
284
- n = len(values)
285
- mean = sum(values) / n
286
- var = sum((x - mean) ** 2 for x in values) / max(n - 1, 1)
287
- return {"mean": mean, "std": math.sqrt(var), "min": min(values), "max": max(values)}
288
 
289
-
290
- def _append_jsonl(path: Path, row: Dict[str, Any]) -> None:
291
- path.parent.mkdir(parents=True, exist_ok=True)
292
- with path.open("a", encoding="utf-8") as f:
293
- f.write(json.dumps(row, ensure_ascii=False, default=str) + "\n")
294
-
295
-
296
- def _log_reward_component(name: str, values: List[float]) -> None:
297
- if _REWARD_COMPONENTS_JSONL is None:
298
- return
299
- _append_jsonl(
300
- _REWARD_COMPONENTS_JSONL,
301
- {
302
- "time_epoch_s": time.time(),
303
- "global_step": CURRENT_GRPO_STEP.get(),
304
- "reward_component": name,
305
- "n": len(values),
306
- **_reward_batch_stats(values),
307
- },
308
- )
309
-
310
-
311
- def format_reward_func(completions, **kwargs):
312
- """Reward 1: CoT + sql XML tags (+0.1). Tag name follows REASONING_XML_TAG."""
313
- tag = os.environ.get("REASONING_XML_TAG", "think")
314
- rewards = []
315
- for comp in completions:
316
- has_think = extract_xml_tag(comp, tag) is not None
317
- has_sql = extract_xml_tag(comp, "sql") is not None
318
- rewards.append(0.1 if (has_think and has_sql) else 0.0)
319
- _log_reward_component("format_xml", rewards)
320
- return rewards
321
-
322
-
323
- def syntax_reward_func(completions, **kwargs):
324
- """Reward 2: Does the SQL look like valid code? (+0.2)"""
325
- rewards = []
326
- for comp in completions:
327
- sql = extract_xml_tag(comp, "sql")
328
- if sql and (sql.upper().startswith("SELECT") or sql.upper().startswith("WITH")):
329
- rewards.append(0.2)
330
- else:
331
- rewards.append(0.0)
332
- _log_reward_component("syntax_select_with", rewards)
333
- return rewards
334
-
335
-
336
- def execution_reward_func(completions, task_id, **kwargs):
337
- """Reward 3: live OpenEnv submit_query against the real Space/API (not a stub)."""
338
- rewards: List[float] = []
339
  base = get_bridge_url()
340
  timeout = get_request_timeout()
 
 
341
  with httpx.Client(base_url=base, headers=BYPASS_HEADERS, timeout=timeout) as client:
342
- for query, t_id in zip(completions, task_id):
343
- sql = extract_xml_tag(query, "sql")
 
 
 
344
  if not sql:
345
  rewards.append(0.0)
346
  continue
347
-
348
- session_headers = {"X-Session-Id": str(uuid.uuid4())}
349
  try:
350
- r0 = client.post("/reset", json={"task_id": t_id}, headers=session_headers)
351
- r0.raise_for_status()
352
  resp = client.post(
353
  "/step",
354
  json={"action": {"action_type": "submit_query", "query": sql}},
355
- headers=session_headers,
356
  )
357
  resp.raise_for_status()
358
- reward = float(resp.json().get("reward", 0.0))
359
  except Exception:
360
- reward = 0.0
361
-
362
- reward += random.uniform(-1e-6, 1e-6)
363
- rewards.append(reward)
364
- _log_reward_component("openenv_execution", rewards)
365
- return rewards
366
-
367
-
368
- def length_shape_reward_func(completions, **kwargs):
369
- """Reward 4: soft preference for shorter completions (bounded; does not replace execution reward)."""
370
- cap = float(os.environ.get("COMPLETION_SOFT_CHAR_CAP", "3500"))
371
- bonus_max = float(os.environ.get("LENGTH_BONUS_MAX", "0.05"))
372
- rewards: List[float] = []
373
- for comp in completions:
374
- L = len(comp) if comp else 0
375
- if L <= 0:
376
- rewards.append(0.0)
377
- else:
378
- rewards.append(bonus_max * max(0.0, 1.0 - min(L, cap) / cap))
379
- _log_reward_component("length_shape", rewards)
380
  return rewards
381
 
382
 
383
- class GrpoStepContextCallback(TrainerCallback):
384
- """Expose true global_step to reward funcs for JSONL alignment."""
385
-
386
- def on_step_begin(self, args, state, control, **kwargs):
387
- CURRENT_GRPO_STEP.set(int(state.global_step))
388
-
389
-
390
- class JsonlOnLogCallback(TrainerCallback):
391
- """Mirror every trainer `logs` dict to JSONL (loss, learning_rate, reward keys, etc.)."""
392
-
393
- def __init__(self, path: Path):
394
- self.path = path
395
- self.path.parent.mkdir(parents=True, exist_ok=True)
396
- self._fp = path.open("w", encoding="utf-8")
397
-
398
- def on_log(self, args, state, control, logs=None, **kwargs):
399
- if not logs:
400
- return
401
- row: Dict[str, Any] = {"global_step": int(state.global_step), **dict(logs)}
402
- self._fp.write(json.dumps(row, ensure_ascii=False, default=str) + "\n")
403
- self._fp.flush()
404
-
405
- def on_train_end(self, args, state, control, **kwargs):
406
- try:
407
- self._fp.close()
408
- except Exception:
409
- pass
410
-
411
  # --- 3b. ARTIFACTS / PLOTS (REAL, FROM LOGS) ---
412
 
413
  @dataclass(frozen=True)
@@ -520,9 +382,7 @@ def plot_reward_curve(reward_series: List[tuple[float, float]], paths: ArtifactP
520
  def _resolve_report_to() -> str:
521
  raw = os.environ.get("TRL_REPORT_TO", "").strip().lower()
522
  if raw in ("", "auto"):
523
- if os.environ.get("WANDB_API_KEY"):
524
- return "wandb"
525
- return "tensorboard"
526
  if raw in ("false", "no", "off", "none"):
527
  return "none"
528
  return raw
@@ -530,14 +390,8 @@ def _resolve_report_to() -> str:
530
 
531
  # --- 4. Unsloth GRPO training loop (live OpenEnv rewards) ---
532
  def run_sota_train():
533
- global _REWARD_COMPONENTS_JSONL
534
-
535
  max_steps = int(os.environ.get("TRAIN_MAX_STEPS", "200"))
536
  out_dir = os.environ.get("OUTPUT_DIR", "./sota_results")
537
- artifacts_early = Path(out_dir) / "artifacts"
538
- _ensure_dir(artifacts_early)
539
- _REWARD_COMPONENTS_JSONL = artifacts_early / "reward_components.jsonl"
540
- _REWARD_COMPONENTS_JSONL.write_text("", encoding="utf-8")
541
 
542
  print(f"Starting Unsloth GRPO on {MODEL_NAME}...")
543
  print(
@@ -566,9 +420,7 @@ def run_sota_train():
566
  train_dataset = make_real_dataset()
567
 
568
  def quick_exec_eval(max_items: int = 8) -> float:
569
- """
570
- Quick before/after check: sample prompts, generate CoT + sql, score via live OpenEnv.
571
- """
572
  subset = train_dataset.select(range(min(max_items, len(train_dataset))))
573
  prompts = subset["prompt"]
574
  task_ids = subset["task_id"]
@@ -586,7 +438,7 @@ def run_sota_train():
586
  )
587
  completions.append(tokenizer.decode(out[0], skip_special_tokens=True))
588
 
589
- rewards = execution_reward_func(completions, task_ids)
590
  return float(sum(rewards) / max(len(rewards), 1))
591
 
592
  print("Quick baseline eval (pre-train)...")
@@ -603,7 +455,7 @@ def run_sota_train():
603
  per_device_train_batch_size=int(os.environ.get("PER_DEVICE_TRAIN_BS", "1")),
604
  gradient_accumulation_steps=int(os.environ.get("GRAD_ACCUM", "2")),
605
  num_generations=int(os.environ.get("GRPO_NUM_GENERATIONS", "8")),
606
- max_completion_length=int(os.environ.get("GRPO_MAX_COMPLETION_LEN", "512")),
607
  temperature=float(os.environ.get("GRPO_TEMPERATURE", "0.9")),
608
  num_train_epochs=int(os.environ.get("TRAIN_NUM_EPOCHS", "1")),
609
  max_steps=max_steps,
@@ -615,24 +467,12 @@ def run_sota_train():
615
  _cfg["logging_dir"] = str(tb_dir)
616
  training_args = GRPOConfig(**_cfg)
617
 
618
- trainer_logs_path = artifacts_early / "trainer_on_log.jsonl"
619
- trainer_logs_path.write_text("", encoding="utf-8")
620
-
621
  trainer = GRPOTrainer(
622
  model=model,
623
- reward_funcs=[
624
- format_reward_func,
625
- syntax_reward_func,
626
- execution_reward_func,
627
- length_shape_reward_func,
628
- ],
629
  args=training_args,
630
  train_dataset=train_dataset,
631
  processing_class=tokenizer,
632
- callbacks=[
633
- GrpoStepContextCallback(),
634
- JsonlOnLogCallback(trainer_logs_path),
635
- ],
636
  )
637
 
638
  print("Training with live execution rewards against OpenEnv...")
@@ -661,8 +501,6 @@ def run_sota_train():
661
  "baseline_avg_reward": baseline_avg_reward,
662
  "post_avg_reward": post_avg_reward,
663
  "delta_avg_reward": post_avg_reward - baseline_avg_reward,
664
- "reward_components_jsonl": str(artifacts_early / "reward_components.jsonl"),
665
- "trainer_on_log_jsonl": str(artifacts_early / "trainer_on_log.jsonl"),
666
  "tensorboard_dir": str(tb_dir) if report_to == "tensorboard" else None,
667
  "report_to": report_to,
668
  }
 
10
  OPENENV_TASK_IDS β€” Comma list; if unset, uses GET /tasks from the server
11
  ROWS_PER_TASK β€” GRPO rows per task_id (default: 48)
12
  OPENENV_REQUEST_TIMEOUT_SEC β€” HTTP timeout for reset/step (default: 120)
13
+ TRAIN_MAX_STEPS β€” GRPO steps (default 200)
14
+ TRL_REPORT_TO β€” none | wandb | tensorboard (auto: wandb if key else none)
15
+ BOOTSTRAP_*_VERSION β€” pin transformers / accelerate / trl (defaults satisfy trl>=4.50)
16
+ Artifacts: artifacts/train_log_history.jsonl, metrics, plots
 
 
 
17
  HF_HUB_REPO_ID β€” push target (default md896/sota-sql-agent-7b)
18
  SKIP_HUB_PUSH=1 β€” do not push after train
19
  HF_TOKEN / HUGGING_FACE_HUB_TOKEN β€” Hub auth for push
 
30
 
31
  from __future__ import annotations
32
 
 
33
  import json
 
34
  import os
35
  import random
 
36
  import subprocess
37
  import sys
38
  import time
 
41
  from pathlib import Path
42
  from typing import Any, Dict, List, Optional
43
 
 
 
 
 
44
  def _run(cmd: List[str], *, check: bool = True) -> subprocess.CompletedProcess:
45
  return subprocess.run(cmd, check=check)
46
 
 
72
  # Text-only run: torchvision/torchaudio are not required and are a common source
73
  # of crashes when torch versions shift in container images.
74
  _pip(["uninstall", "--break-system-packages", "-y", "torchvision", "torchaudio"], check=False)
 
75
  _pip(["uninstall", "-y", "torchao"], check=False)
76
 
77
+ # trl 0.18.x needs transformers>=4.50. datasets 4.x pulls huggingface-hub 1.x which breaks 4.5x.
78
+ _tf = os.environ.get("BOOTSTRAP_TRANSFORMERS_VERSION", "4.51.3")
79
+ _acc = os.environ.get("BOOTSTRAP_ACCELERATE_VERSION", "0.34.2")
80
+ _trl = os.environ.get("BOOTSTRAP_TRL_VERSION", "0.18.2")
81
  _pip(
82
  [
83
  "install",
84
  "--break-system-packages",
85
  "httpx>=0.27.0",
86
+ "datasets>=3.2.0,<4.0.0",
87
  "matplotlib",
88
  "tensorboard",
 
 
 
 
 
 
 
 
 
 
 
89
  f"transformers=={_tf}",
90
  f"accelerate=={_acc}",
91
  f"trl=={_trl}",
92
  ]
93
  )
94
 
95
+ if os.environ.get("WANDB_API_KEY"):
96
+ _pip(["install", "--break-system-packages", "wandb"], check=False)
97
+
98
  _pip(
99
  [
100
  "install",
 
165
  if not hasattr(transformers.utils.hub, "TRANSFORMERS_CACHE"):
166
  transformers.utils.hub.TRANSFORMERS_CACHE = "/tmp"
167
 
 
168
  from trl import GRPOConfig, GRPOTrainer
169
  from unsloth import FastLanguageModel
170
 
 
183
  return float(os.environ.get("OPENENV_REQUEST_TIMEOUT_SEC", "120"))
184
 
185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  def _fetch_task_ids(client: httpx.Client) -> List[str]:
187
  raw = os.environ.get("OPENENV_TASK_IDS", "").strip()
188
  if raw:
 
198
 
199
 
200
  def make_real_dataset() -> Dataset:
201
+ """Plain prompts + live /tasks (same spirit as colab_real_world.py, HF Space instead of loca.lt)."""
202
  bridge = get_bridge_url()
203
  timeout = get_request_timeout()
204
  rows_per_task = max(1, int(os.environ.get("ROWS_PER_TASK", "48")))
205
+ marker = os.environ.get("COMPLETION_SQL_MARKER", "Fixed SQL:")
206
 
207
  print(f"Connecting to OpenEnv at {bridge} (timeout={timeout}s)...")
208
  rows: List[Dict[str, Any]] = []
 
221
  obs = resp.json()["observation"]
222
 
223
  prompt = (
224
+ "Fix the following SQL query and provide only the fixed SQL.\n"
225
  f"Task: {obs['task_description']}\n"
226
+ f"Broken Query: {obs['original_query']}\n"
227
+ f"{marker}"
228
  )
229
  for _ in range(rows_per_task):
230
  rows.append({"prompt": prompt, "task_id": t_id})
 
234
  print(f"Dataset: {len(rows)} prompts ({rows_per_task} per task).")
235
  return Dataset.from_list(rows)
236
 
 
 
 
 
 
 
 
 
 
237
 
238
+ # --- 3. One live OpenEnv reward (colab_real_world style) ---
239
 
 
 
 
 
 
 
 
240
 
241
+ def openenv_sql_reward_func(completions, task_id, **kwargs):
242
+ """Score completions by executing extracted SQL against the real OpenEnv HTTP API."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  base = get_bridge_url()
244
  timeout = get_request_timeout()
245
+ marker = os.environ.get("COMPLETION_SQL_MARKER", "Fixed SQL:")
246
+ rewards: List[float] = []
247
  with httpx.Client(base_url=base, headers=BYPASS_HEADERS, timeout=timeout) as client:
248
+ for completion, t_id in zip(completions, task_id):
249
+ if marker in completion:
250
+ sql = completion.split(marker, 1)[-1].strip()
251
+ else:
252
+ sql = completion.strip()
253
  if not sql:
254
  rewards.append(0.0)
255
  continue
256
+ hdr = {"X-Session-Id": str(uuid.uuid4())}
 
257
  try:
258
+ client.post("/reset", json={"task_id": t_id}, headers=hdr).raise_for_status()
 
259
  resp = client.post(
260
  "/step",
261
  json={"action": {"action_type": "submit_query", "query": sql}},
262
+ headers=hdr,
263
  )
264
  resp.raise_for_status()
265
+ r = float(resp.json().get("reward", 0.0))
266
  except Exception:
267
+ r = 0.0
268
+ r += random.uniform(-1e-6, 1e-6)
269
+ rewards.append(r)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  return rewards
271
 
272
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  # --- 3b. ARTIFACTS / PLOTS (REAL, FROM LOGS) ---
274
 
275
  @dataclass(frozen=True)
 
382
  def _resolve_report_to() -> str:
383
  raw = os.environ.get("TRL_REPORT_TO", "").strip().lower()
384
  if raw in ("", "auto"):
385
+ return "wandb" if os.environ.get("WANDB_API_KEY") else "none"
 
 
386
  if raw in ("false", "no", "off", "none"):
387
  return "none"
388
  return raw
 
390
 
391
  # --- 4. Unsloth GRPO training loop (live OpenEnv rewards) ---
392
  def run_sota_train():
 
 
393
  max_steps = int(os.environ.get("TRAIN_MAX_STEPS", "200"))
394
  out_dir = os.environ.get("OUTPUT_DIR", "./sota_results")
 
 
 
 
395
 
396
  print(f"Starting Unsloth GRPO on {MODEL_NAME}...")
397
  print(
 
420
  train_dataset = make_real_dataset()
421
 
422
  def quick_exec_eval(max_items: int = 8) -> float:
423
+ """Sample prompts, generate completions, score with the same OpenEnv SQL reward."""
 
 
424
  subset = train_dataset.select(range(min(max_items, len(train_dataset))))
425
  prompts = subset["prompt"]
426
  task_ids = subset["task_id"]
 
438
  )
439
  completions.append(tokenizer.decode(out[0], skip_special_tokens=True))
440
 
441
+ rewards = openenv_sql_reward_func(completions, task_ids)
442
  return float(sum(rewards) / max(len(rewards), 1))
443
 
444
  print("Quick baseline eval (pre-train)...")
 
455
  per_device_train_batch_size=int(os.environ.get("PER_DEVICE_TRAIN_BS", "1")),
456
  gradient_accumulation_steps=int(os.environ.get("GRAD_ACCUM", "2")),
457
  num_generations=int(os.environ.get("GRPO_NUM_GENERATIONS", "8")),
458
+ max_completion_length=int(os.environ.get("GRPO_MAX_COMPLETION_LEN", "256")),
459
  temperature=float(os.environ.get("GRPO_TEMPERATURE", "0.9")),
460
  num_train_epochs=int(os.environ.get("TRAIN_NUM_EPOCHS", "1")),
461
  max_steps=max_steps,
 
467
  _cfg["logging_dir"] = str(tb_dir)
468
  training_args = GRPOConfig(**_cfg)
469
 
 
 
 
470
  trainer = GRPOTrainer(
471
  model=model,
472
+ reward_funcs=[openenv_sql_reward_func],
 
 
 
 
 
473
  args=training_args,
474
  train_dataset=train_dataset,
475
  processing_class=tokenizer,
 
 
 
 
476
  )
477
 
478
  print("Training with live execution rewards against OpenEnv...")
 
501
  "baseline_avg_reward": baseline_avg_reward,
502
  "post_avg_reward": post_avg_reward,
503
  "delta_avg_reward": post_avg_reward - baseline_avg_reward,
 
 
504
  "tensorboard_dir": str(tb_dir) if report_to == "tensorboard" else None,
505
  "report_to": report_to,
506
  }