Spaces:
Running
Running
add logs and fix train error
Browse files- scripts/test_space_api.py +4 -0
- scripts/test_trace_logging.py +102 -0
- server/runtime.py +50 -1
- test.py +2 -0
- training/trace_logging.py +175 -0
- training/train_grpo.py +193 -66
scripts/test_space_api.py
CHANGED
|
@@ -32,6 +32,10 @@ def main() -> None:
|
|
| 32 |
assert "completed_steps" in train_status.json()
|
| 33 |
assert "remaining_steps" in train_status.json()
|
| 34 |
assert "phase" in train_status.json()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
reset = client.post("/reset", json={"difficulty": "easy", "problem_id": "sum_even_numbers"})
|
| 37 |
assert reset.status_code == 200
|
|
|
|
| 32 |
assert "completed_steps" in train_status.json()
|
| 33 |
assert "remaining_steps" in train_status.json()
|
| 34 |
assert "phase" in train_status.json()
|
| 35 |
+
assert "run_manifest_path" in train_status.json()
|
| 36 |
+
assert "events_path" in train_status.json()
|
| 37 |
+
assert "latest_checkpoint_path" in train_status.json()
|
| 38 |
+
assert "logs_deleted_from_space" in train_status.json()
|
| 39 |
|
| 40 |
reset = client.post("/reset", json={"difficulty": "easy", "problem_id": "sum_even_numbers"})
|
| 41 |
assert reset.status_code == 200
|
scripts/test_trace_logging.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import sys
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from tempfile import TemporaryDirectory
|
| 7 |
+
|
| 8 |
+
ROOT = Path(__file__).resolve().parents[1]
|
| 9 |
+
if str(ROOT) not in sys.path:
|
| 10 |
+
sys.path.insert(0, str(ROOT))
|
| 11 |
+
|
| 12 |
+
from training.trace_logging import TraceArtifactLogger
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def main() -> None:
|
| 16 |
+
with TemporaryDirectory() as tmpdir:
|
| 17 |
+
output_dir = Path(tmpdir)
|
| 18 |
+
logger = TraceArtifactLogger(
|
| 19 |
+
run_id="run-123",
|
| 20 |
+
output_dir=output_dir,
|
| 21 |
+
training_config={"max_steps": 6, "model_name": "demo-model"},
|
| 22 |
+
model_identifiers={"model_name": "demo-model", "generator_mode": "reward_aware"},
|
| 23 |
+
system_prompt="You are the Solver Agent.",
|
| 24 |
+
checkpoint_interval_steps=2,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
manifest_path = output_dir / "logs" / "run_manifest.json"
|
| 28 |
+
assert manifest_path.exists()
|
| 29 |
+
manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
|
| 30 |
+
assert manifest["run_id"] == "run-123"
|
| 31 |
+
assert manifest["training_config"]["max_steps"] == 6
|
| 32 |
+
|
| 33 |
+
logger.log_event(
|
| 34 |
+
{
|
| 35 |
+
"phase": "train",
|
| 36 |
+
"step": 0,
|
| 37 |
+
"train_episode_index": 1,
|
| 38 |
+
"problem_id": "sum_even_numbers_1",
|
| 39 |
+
"problem_family": "sum_even_numbers",
|
| 40 |
+
"difficulty": "easy",
|
| 41 |
+
"teacher_prompt": "Problem: sum the even numbers",
|
| 42 |
+
"solver_completion": "print(sum(x for x in nums if x % 2 == 0))",
|
| 43 |
+
"extracted_code": "print(sum(x for x in nums if x % 2 == 0))",
|
| 44 |
+
"reward": 0.94,
|
| 45 |
+
"pass_rate": 1.0,
|
| 46 |
+
"visible_pass_rate": 1.0,
|
| 47 |
+
"execution_status": "completed",
|
| 48 |
+
"efficiency_score": 0.94,
|
| 49 |
+
"optimization_hints": ["Avoid materializing temporary containers."],
|
| 50 |
+
"feedback": "All hidden tests passed, but the solution can still be optimized further.",
|
| 51 |
+
}
|
| 52 |
+
)
|
| 53 |
+
logger.record_progress(
|
| 54 |
+
{
|
| 55 |
+
"phase": "train",
|
| 56 |
+
"completed_steps": 2,
|
| 57 |
+
"total_steps": 6,
|
| 58 |
+
"remaining_steps": 4,
|
| 59 |
+
"progress_ratio": 0.3333,
|
| 60 |
+
"current_epoch": 2.0,
|
| 61 |
+
"current_difficulty": "easy",
|
| 62 |
+
"curriculum_level": 1,
|
| 63 |
+
"train_episode_index": 1,
|
| 64 |
+
"last_problem_id": "sum_even_numbers_1",
|
| 65 |
+
"last_problem_family": "sum_even_numbers",
|
| 66 |
+
"last_execution_status": "completed",
|
| 67 |
+
}
|
| 68 |
+
)
|
| 69 |
+
artifact_paths = logger.artifact_paths()
|
| 70 |
+
|
| 71 |
+
events_path = Path(artifact_paths["events_path"])
|
| 72 |
+
latest_checkpoint_path = Path(artifact_paths["latest_checkpoint_path"])
|
| 73 |
+
assert events_path.exists()
|
| 74 |
+
assert latest_checkpoint_path.exists()
|
| 75 |
+
|
| 76 |
+
event_line = events_path.read_text(encoding="utf-8").strip().splitlines()[0]
|
| 77 |
+
event = json.loads(event_line)
|
| 78 |
+
assert event["problem_id"] == "sum_even_numbers_1"
|
| 79 |
+
assert event["teacher_prompt"] == "Problem: sum the even numbers"
|
| 80 |
+
assert "training_config" not in event
|
| 81 |
+
|
| 82 |
+
checkpoint = json.loads(latest_checkpoint_path.read_text(encoding="utf-8"))
|
| 83 |
+
assert checkpoint["step"] == 2
|
| 84 |
+
assert checkpoint["rolling_metrics"]["avg_reward"] == 0.94
|
| 85 |
+
assert "training_config" not in checkpoint
|
| 86 |
+
|
| 87 |
+
reward_curve = output_dir / "reward_curve.csv"
|
| 88 |
+
reward_curve.write_text("step,episode_reward\n0,0.94\n", encoding="utf-8")
|
| 89 |
+
summary_paths = logger.finalize(
|
| 90 |
+
reward_curve_csv=reward_curve,
|
| 91 |
+
final_metrics={"completed_steps": 6},
|
| 92 |
+
)
|
| 93 |
+
summary_path = Path(summary_paths)
|
| 94 |
+
assert summary_path.exists()
|
| 95 |
+
summary = json.loads(summary_path.read_text(encoding="utf-8"))
|
| 96 |
+
assert summary["final_metrics"]["completed_steps"] == 6
|
| 97 |
+
|
| 98 |
+
print("Trace logging smoke tests passed")
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
if __name__ == "__main__":
|
| 102 |
+
main()
|
server/runtime.py
CHANGED
|
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
| 2 |
|
| 3 |
import json
|
| 4 |
import os
|
|
|
|
| 5 |
import threading
|
| 6 |
import traceback
|
| 7 |
from dataclasses import asdict, dataclass, field
|
|
@@ -69,6 +70,13 @@ class TrainingJobState:
|
|
| 69 |
reward_curve_csv: str | None = None
|
| 70 |
model_repo_id: str | None = None
|
| 71 |
uploaded_revision: str | None = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
phase: str = "idle"
|
| 73 |
completed_steps: int = 0
|
| 74 |
total_steps: int = 0
|
|
@@ -430,6 +438,13 @@ class SpaceTrainingManager:
|
|
| 430 |
reward_curve_csv=payload.get("reward_curve_csv"),
|
| 431 |
model_repo_id=payload.get("model_repo_id"),
|
| 432 |
uploaded_revision=payload.get("uploaded_revision"),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 433 |
phase=payload.get("phase", "idle"),
|
| 434 |
completed_steps=int(payload.get("completed_steps", 0) or 0),
|
| 435 |
total_steps=int(payload.get("total_steps", 0) or 0),
|
|
@@ -508,6 +523,7 @@ class SpaceTrainingManager:
|
|
| 508 |
else:
|
| 509 |
output_dir = self.output_root / requested_output_dir / run_id
|
| 510 |
config.output_dir = str(output_dir)
|
|
|
|
| 511 |
|
| 512 |
self._job = TrainingJobState(
|
| 513 |
status="running",
|
|
@@ -519,6 +535,12 @@ class SpaceTrainingManager:
|
|
| 519 |
reward_curve_csv=None,
|
| 520 |
model_repo_id=os.getenv("HF_MODEL_REPO_ID"),
|
| 521 |
uploaded_revision=None,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 522 |
phase="queued",
|
| 523 |
completed_steps=0,
|
| 524 |
total_steps=int(config.max_steps),
|
|
@@ -562,11 +584,22 @@ class SpaceTrainingManager:
|
|
| 562 |
)
|
| 563 |
return getattr(commit_info, "oid", None) or getattr(commit_info, "commit_hash", None) or "unknown"
|
| 564 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 565 |
def _run_training_job(self, run_id: str, config: TrainingConfig) -> None:
|
|
|
|
| 566 |
try:
|
| 567 |
-
summary = run_training(config, progress_callback=self._update_progress)
|
| 568 |
artifact_path = summary["output_dir"]
|
| 569 |
uploaded_revision = self._upload_artifacts(artifact_path, run_id)
|
|
|
|
| 570 |
self.model_registry.load_latest_from_hub()
|
| 571 |
|
| 572 |
with self._lock:
|
|
@@ -576,6 +609,13 @@ class SpaceTrainingManager:
|
|
| 576 |
self._job.reward_curve_csv = summary.get("reward_curve_csv")
|
| 577 |
self._job.model_repo_id = os.getenv("HF_MODEL_REPO_ID")
|
| 578 |
self._job.uploaded_revision = uploaded_revision
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 579 |
self._job.phase = "completed"
|
| 580 |
self._job.completed_steps = int(summary.get("completed_steps", config.max_steps))
|
| 581 |
self._job.total_steps = int(config.max_steps)
|
|
@@ -589,9 +629,18 @@ class SpaceTrainingManager:
|
|
| 589 |
self._job.traceback = None
|
| 590 |
self._persist_status()
|
| 591 |
except Exception as exc:
|
|
|
|
| 592 |
with self._lock:
|
| 593 |
self._job.status = "failed"
|
| 594 |
self._job.finished_at = _utc_now()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 595 |
self._job.error = str(exc)
|
| 596 |
self._job.traceback = traceback.format_exc()
|
| 597 |
self._persist_status()
|
|
|
|
| 2 |
|
| 3 |
import json
|
| 4 |
import os
|
| 5 |
+
import shutil
|
| 6 |
import threading
|
| 7 |
import traceback
|
| 8 |
from dataclasses import asdict, dataclass, field
|
|
|
|
| 70 |
reward_curve_csv: str | None = None
|
| 71 |
model_repo_id: str | None = None
|
| 72 |
uploaded_revision: str | None = None
|
| 73 |
+
logs_dir: str | None = None
|
| 74 |
+
run_manifest_path: str | None = None
|
| 75 |
+
events_path: str | None = None
|
| 76 |
+
latest_checkpoint_path: str | None = None
|
| 77 |
+
run_summary_path: str | None = None
|
| 78 |
+
checkpoint_paths: list[str] = field(default_factory=list)
|
| 79 |
+
logs_deleted_from_space: bool = False
|
| 80 |
phase: str = "idle"
|
| 81 |
completed_steps: int = 0
|
| 82 |
total_steps: int = 0
|
|
|
|
| 438 |
reward_curve_csv=payload.get("reward_curve_csv"),
|
| 439 |
model_repo_id=payload.get("model_repo_id"),
|
| 440 |
uploaded_revision=payload.get("uploaded_revision"),
|
| 441 |
+
logs_dir=payload.get("logs_dir"),
|
| 442 |
+
run_manifest_path=payload.get("run_manifest_path"),
|
| 443 |
+
events_path=payload.get("events_path"),
|
| 444 |
+
latest_checkpoint_path=payload.get("latest_checkpoint_path"),
|
| 445 |
+
run_summary_path=payload.get("run_summary_path"),
|
| 446 |
+
checkpoint_paths=payload.get("checkpoint_paths", []),
|
| 447 |
+
logs_deleted_from_space=bool(payload.get("logs_deleted_from_space", False)),
|
| 448 |
phase=payload.get("phase", "idle"),
|
| 449 |
completed_steps=int(payload.get("completed_steps", 0) or 0),
|
| 450 |
total_steps=int(payload.get("total_steps", 0) or 0),
|
|
|
|
| 523 |
else:
|
| 524 |
output_dir = self.output_root / requested_output_dir / run_id
|
| 525 |
config.output_dir = str(output_dir)
|
| 526 |
+
logs_dir = output_dir / "logs"
|
| 527 |
|
| 528 |
self._job = TrainingJobState(
|
| 529 |
status="running",
|
|
|
|
| 535 |
reward_curve_csv=None,
|
| 536 |
model_repo_id=os.getenv("HF_MODEL_REPO_ID"),
|
| 537 |
uploaded_revision=None,
|
| 538 |
+
logs_dir=str(logs_dir),
|
| 539 |
+
run_manifest_path=str(logs_dir / "run_manifest.json"),
|
| 540 |
+
events_path=str(logs_dir / "events.jsonl"),
|
| 541 |
+
latest_checkpoint_path=str(logs_dir / "latest_checkpoint.json"),
|
| 542 |
+
run_summary_path=str(logs_dir / "run_summary.json"),
|
| 543 |
+
checkpoint_paths=[],
|
| 544 |
phase="queued",
|
| 545 |
completed_steps=0,
|
| 546 |
total_steps=int(config.max_steps),
|
|
|
|
| 584 |
)
|
| 585 |
return getattr(commit_info, "oid", None) or getattr(commit_info, "commit_hash", None) or "unknown"
|
| 586 |
|
| 587 |
+
def _cleanup_local_logs(self, log_dir: str | None) -> bool:
|
| 588 |
+
if not log_dir:
|
| 589 |
+
return False
|
| 590 |
+
folder_path = Path(log_dir)
|
| 591 |
+
if not folder_path.exists():
|
| 592 |
+
return False
|
| 593 |
+
shutil.rmtree(folder_path, ignore_errors=True)
|
| 594 |
+
return not folder_path.exists()
|
| 595 |
+
|
| 596 |
def _run_training_job(self, run_id: str, config: TrainingConfig) -> None:
|
| 597 |
+
summary: dict[str, Any] | None = None
|
| 598 |
try:
|
| 599 |
+
summary = run_training(config, run_id=run_id, progress_callback=self._update_progress)
|
| 600 |
artifact_path = summary["output_dir"]
|
| 601 |
uploaded_revision = self._upload_artifacts(artifact_path, run_id)
|
| 602 |
+
logs_deleted = self._cleanup_local_logs(summary.get("logs_dir"))
|
| 603 |
self.model_registry.load_latest_from_hub()
|
| 604 |
|
| 605 |
with self._lock:
|
|
|
|
| 609 |
self._job.reward_curve_csv = summary.get("reward_curve_csv")
|
| 610 |
self._job.model_repo_id = os.getenv("HF_MODEL_REPO_ID")
|
| 611 |
self._job.uploaded_revision = uploaded_revision
|
| 612 |
+
self._job.logs_dir = None if logs_deleted else summary.get("logs_dir")
|
| 613 |
+
self._job.run_manifest_path = None if logs_deleted else summary.get("run_manifest_path")
|
| 614 |
+
self._job.events_path = None if logs_deleted else summary.get("events_path")
|
| 615 |
+
self._job.latest_checkpoint_path = None if logs_deleted else summary.get("latest_checkpoint_path")
|
| 616 |
+
self._job.run_summary_path = None if logs_deleted else summary.get("run_summary_path")
|
| 617 |
+
self._job.checkpoint_paths = [] if logs_deleted else summary.get("checkpoint_paths", [])
|
| 618 |
+
self._job.logs_deleted_from_space = logs_deleted
|
| 619 |
self._job.phase = "completed"
|
| 620 |
self._job.completed_steps = int(summary.get("completed_steps", config.max_steps))
|
| 621 |
self._job.total_steps = int(config.max_steps)
|
|
|
|
| 629 |
self._job.traceback = None
|
| 630 |
self._persist_status()
|
| 631 |
except Exception as exc:
|
| 632 |
+
logs_deleted = self._cleanup_local_logs(summary.get("logs_dir") if summary else self._job.logs_dir)
|
| 633 |
with self._lock:
|
| 634 |
self._job.status = "failed"
|
| 635 |
self._job.finished_at = _utc_now()
|
| 636 |
+
if logs_deleted:
|
| 637 |
+
self._job.logs_dir = None
|
| 638 |
+
self._job.run_manifest_path = None
|
| 639 |
+
self._job.events_path = None
|
| 640 |
+
self._job.latest_checkpoint_path = None
|
| 641 |
+
self._job.run_summary_path = None
|
| 642 |
+
self._job.checkpoint_paths = []
|
| 643 |
+
self._job.logs_deleted_from_space = logs_deleted
|
| 644 |
self._job.error = str(exc)
|
| 645 |
self._job.traceback = traceback.format_exc()
|
| 646 |
self._persist_status()
|
test.py
CHANGED
|
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
| 2 |
|
| 3 |
from scripts.test_env import main as run_env_smoke
|
| 4 |
from scripts.test_space_api import main as run_space_api_smoke
|
|
|
|
| 5 |
from scripts.test_verifier import test_cases
|
| 6 |
from verifier.verifier import verify
|
| 7 |
|
|
@@ -9,6 +10,7 @@ from verifier.verifier import verify
|
|
| 9 |
def main() -> None:
|
| 10 |
run_env_smoke()
|
| 11 |
run_space_api_smoke()
|
|
|
|
| 12 |
|
| 13 |
reward, info = verify(
|
| 14 |
"n=int(input())\nnums=list(map(int,input().split()))\nprint(sum(x for x in nums if x % 2 == 0))",
|
|
|
|
| 2 |
|
| 3 |
from scripts.test_env import main as run_env_smoke
|
| 4 |
from scripts.test_space_api import main as run_space_api_smoke
|
| 5 |
+
from scripts.test_trace_logging import main as run_trace_logging_smoke
|
| 6 |
from scripts.test_verifier import test_cases
|
| 7 |
from verifier.verifier import verify
|
| 8 |
|
|
|
|
| 10 |
def main() -> None:
|
| 11 |
run_env_smoke()
|
| 12 |
run_space_api_smoke()
|
| 13 |
+
run_trace_logging_smoke()
|
| 14 |
|
| 15 |
reward, info = verify(
|
| 16 |
"n=int(input())\nnums=list(map(int,input().split()))\nprint(sum(x for x in nums if x % 2 == 0))",
|
training/trace_logging.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import shutil
|
| 5 |
+
from collections import deque
|
| 6 |
+
from dataclasses import dataclass, field
|
| 7 |
+
from datetime import datetime, timezone
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _utc_now_iso() -> str:
|
| 13 |
+
return datetime.now(timezone.utc).isoformat()
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _json_safe(value: Any) -> Any:
|
| 17 |
+
if isinstance(value, Path):
|
| 18 |
+
return str(value)
|
| 19 |
+
if isinstance(value, dict):
|
| 20 |
+
return {str(key): _json_safe(item) for key, item in value.items()}
|
| 21 |
+
if isinstance(value, list):
|
| 22 |
+
return [_json_safe(item) for item in value]
|
| 23 |
+
return value
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass
|
| 27 |
+
class TraceArtifactLogger:
|
| 28 |
+
run_id: str
|
| 29 |
+
output_dir: Path
|
| 30 |
+
training_config: dict[str, Any]
|
| 31 |
+
model_identifiers: dict[str, Any]
|
| 32 |
+
system_prompt: str
|
| 33 |
+
checkpoint_interval_steps: int = 10
|
| 34 |
+
schema_version: str = "1.0"
|
| 35 |
+
logs_dir: Path = field(init=False)
|
| 36 |
+
manifest_path: Path = field(init=False)
|
| 37 |
+
events_path: Path = field(init=False)
|
| 38 |
+
latest_checkpoint_path: Path = field(init=False)
|
| 39 |
+
run_summary_path: Path = field(init=False)
|
| 40 |
+
checkpoint_paths: list[Path] = field(default_factory=list, init=False)
|
| 41 |
+
_last_checkpoint_step: int = field(default=0, init=False, repr=False)
|
| 42 |
+
_latest_event: dict[str, Any] = field(default_factory=dict, init=False, repr=False)
|
| 43 |
+
_recent_rewards: deque[float] = field(default_factory=lambda: deque(maxlen=25), init=False, repr=False)
|
| 44 |
+
_recent_pass_rates: deque[float] = field(default_factory=lambda: deque(maxlen=25), init=False, repr=False)
|
| 45 |
+
_recent_efficiency_scores: deque[float] = field(default_factory=lambda: deque(maxlen=25), init=False, repr=False)
|
| 46 |
+
_latest_progress: dict[str, Any] = field(default_factory=dict, init=False, repr=False)
|
| 47 |
+
|
| 48 |
+
def __post_init__(self) -> None:
|
| 49 |
+
self.logs_dir = self.output_dir / "logs"
|
| 50 |
+
self.logs_dir.mkdir(parents=True, exist_ok=True)
|
| 51 |
+
self.manifest_path = self.logs_dir / "run_manifest.json"
|
| 52 |
+
self.events_path = self.logs_dir / "events.jsonl"
|
| 53 |
+
self.latest_checkpoint_path = self.logs_dir / "latest_checkpoint.json"
|
| 54 |
+
self.run_summary_path = self.logs_dir / "run_summary.json"
|
| 55 |
+
manifest = {
|
| 56 |
+
"run_id": self.run_id,
|
| 57 |
+
"schema_version": self.schema_version,
|
| 58 |
+
"started_at": _utc_now_iso(),
|
| 59 |
+
"training_config": self.training_config,
|
| 60 |
+
"model_identifiers": self.model_identifiers,
|
| 61 |
+
"system_prompt": self.system_prompt,
|
| 62 |
+
"checkpoint_interval_steps": int(max(self.checkpoint_interval_steps, 1)),
|
| 63 |
+
}
|
| 64 |
+
self.manifest_path.write_text(json.dumps(_json_safe(manifest), indent=2), encoding="utf-8")
|
| 65 |
+
|
| 66 |
+
def artifact_paths(self) -> dict[str, Any]:
|
| 67 |
+
return {
|
| 68 |
+
"logs_dir": str(self.logs_dir),
|
| 69 |
+
"run_manifest_path": str(self.manifest_path),
|
| 70 |
+
"events_path": str(self.events_path),
|
| 71 |
+
"latest_checkpoint_path": str(self.latest_checkpoint_path),
|
| 72 |
+
"checkpoint_paths": [str(path) for path in self.checkpoint_paths],
|
| 73 |
+
"run_summary_path": str(self.run_summary_path),
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
def log_event(self, event: dict[str, Any]) -> None:
|
| 77 |
+
dynamic_event = {
|
| 78 |
+
"run_id": self.run_id,
|
| 79 |
+
"timestamp": _utc_now_iso(),
|
| 80 |
+
"phase": event.get("phase"),
|
| 81 |
+
"step": event.get("step"),
|
| 82 |
+
"train_episode_index": event.get("train_episode_index"),
|
| 83 |
+
"problem_id": event.get("problem_id"),
|
| 84 |
+
"problem_family": event.get("problem_family"),
|
| 85 |
+
"difficulty": event.get("difficulty"),
|
| 86 |
+
"teacher_prompt": event.get("teacher_prompt"),
|
| 87 |
+
"solver_completion": event.get("solver_completion"),
|
| 88 |
+
"extracted_code": event.get("extracted_code"),
|
| 89 |
+
"reward": event.get("reward"),
|
| 90 |
+
"pass_rate": event.get("pass_rate"),
|
| 91 |
+
"visible_pass_rate": event.get("visible_pass_rate"),
|
| 92 |
+
"execution_status": event.get("execution_status"),
|
| 93 |
+
"efficiency_score": event.get("efficiency_score"),
|
| 94 |
+
"optimization_hints": event.get("optimization_hints", []),
|
| 95 |
+
"feedback": event.get("feedback"),
|
| 96 |
+
}
|
| 97 |
+
with self.events_path.open("a", encoding="utf-8") as handle:
|
| 98 |
+
handle.write(json.dumps(_json_safe(dynamic_event)) + "\n")
|
| 99 |
+
|
| 100 |
+
self._latest_event = dynamic_event
|
| 101 |
+
if dynamic_event.get("reward") is not None:
|
| 102 |
+
self._recent_rewards.append(float(dynamic_event["reward"]))
|
| 103 |
+
if dynamic_event.get("pass_rate") is not None:
|
| 104 |
+
self._recent_pass_rates.append(float(dynamic_event["pass_rate"]))
|
| 105 |
+
if dynamic_event.get("efficiency_score") is not None:
|
| 106 |
+
self._recent_efficiency_scores.append(float(dynamic_event["efficiency_score"]))
|
| 107 |
+
|
| 108 |
+
def record_progress(self, progress: dict[str, Any]) -> None:
|
| 109 |
+
self._latest_progress.update({key: value for key, value in progress.items() if value is not None})
|
| 110 |
+
completed_steps = int(self._latest_progress.get("completed_steps", 0) or 0)
|
| 111 |
+
interval = int(max(self.checkpoint_interval_steps, 1))
|
| 112 |
+
if completed_steps > 0 and completed_steps % interval == 0 and completed_steps != self._last_checkpoint_step:
|
| 113 |
+
self._write_checkpoint(completed_steps)
|
| 114 |
+
|
| 115 |
+
def finalize(self, *, reward_curve_csv: Path | None = None, final_metrics: dict[str, Any] | None = None) -> Path:
|
| 116 |
+
copied_reward_curve = None
|
| 117 |
+
if reward_curve_csv is not None and reward_curve_csv.exists():
|
| 118 |
+
copied_reward_curve = self.logs_dir / "reward_curve.csv"
|
| 119 |
+
if reward_curve_csv.resolve() != copied_reward_curve.resolve():
|
| 120 |
+
shutil.copy2(reward_curve_csv, copied_reward_curve)
|
| 121 |
+
else:
|
| 122 |
+
copied_reward_curve = reward_curve_csv
|
| 123 |
+
|
| 124 |
+
summary = {
|
| 125 |
+
"run_id": self.run_id,
|
| 126 |
+
"finished_at": _utc_now_iso(),
|
| 127 |
+
"artifact_paths": self.artifact_paths(),
|
| 128 |
+
"reward_curve_csv": str(copied_reward_curve) if copied_reward_curve else None,
|
| 129 |
+
"latest_progress": self._latest_progress,
|
| 130 |
+
"latest_event": self._latest_event,
|
| 131 |
+
"rolling_metrics": self._rolling_metrics(),
|
| 132 |
+
"final_metrics": final_metrics or {},
|
| 133 |
+
}
|
| 134 |
+
self.run_summary_path.write_text(json.dumps(_json_safe(summary), indent=2), encoding="utf-8")
|
| 135 |
+
return self.run_summary_path
|
| 136 |
+
|
| 137 |
+
def _write_checkpoint(self, step: int) -> None:
|
| 138 |
+
checkpoint_payload = {
|
| 139 |
+
"run_id": self.run_id,
|
| 140 |
+
"timestamp": _utc_now_iso(),
|
| 141 |
+
"step": int(step),
|
| 142 |
+
"phase": self._latest_progress.get("phase"),
|
| 143 |
+
"total_steps": self._latest_progress.get("total_steps"),
|
| 144 |
+
"remaining_steps": self._latest_progress.get("remaining_steps"),
|
| 145 |
+
"progress_ratio": self._latest_progress.get("progress_ratio"),
|
| 146 |
+
"current_epoch": self._latest_progress.get("current_epoch"),
|
| 147 |
+
"current_difficulty": self._latest_progress.get("current_difficulty"),
|
| 148 |
+
"curriculum_level": self._latest_progress.get("curriculum_level"),
|
| 149 |
+
"train_episode_index": self._latest_progress.get("train_episode_index"),
|
| 150 |
+
"last_problem_id": self._latest_progress.get("last_problem_id"),
|
| 151 |
+
"last_problem_family": self._latest_progress.get("last_problem_family"),
|
| 152 |
+
"last_execution_status": self._latest_progress.get("last_execution_status"),
|
| 153 |
+
"rolling_metrics": self._rolling_metrics(),
|
| 154 |
+
"artifact_paths": {
|
| 155 |
+
"events_path": str(self.events_path),
|
| 156 |
+
"latest_checkpoint_path": str(self.latest_checkpoint_path),
|
| 157 |
+
},
|
| 158 |
+
}
|
| 159 |
+
checkpoint_path = self.logs_dir / f"checkpoint_step_{step:05d}.json"
|
| 160 |
+
checkpoint_path.write_text(json.dumps(_json_safe(checkpoint_payload), indent=2), encoding="utf-8")
|
| 161 |
+
self.latest_checkpoint_path.write_text(json.dumps(_json_safe(checkpoint_payload), indent=2), encoding="utf-8")
|
| 162 |
+
self.checkpoint_paths.append(checkpoint_path)
|
| 163 |
+
self._last_checkpoint_step = step
|
| 164 |
+
|
| 165 |
+
def _rolling_metrics(self) -> dict[str, Any]:
|
| 166 |
+
def _average(values: deque[float]) -> float | None:
|
| 167 |
+
if not values:
|
| 168 |
+
return None
|
| 169 |
+
return round(sum(values) / len(values), 4)
|
| 170 |
+
|
| 171 |
+
return {
|
| 172 |
+
"avg_reward": _average(self._recent_rewards),
|
| 173 |
+
"avg_pass_rate": _average(self._recent_pass_rates),
|
| 174 |
+
"avg_efficiency_score": _average(self._recent_efficiency_scores),
|
| 175 |
+
}
|
training/train_grpo.py
CHANGED
|
@@ -11,6 +11,7 @@ from typing import Any, Callable
|
|
| 11 |
from env.adapt_env import AdaptEnvironment, MAX_STEPS_PER_EPISODE
|
| 12 |
from env.generator import DIFFICULTY_LABELS, GeneratorAgent
|
| 13 |
from models import AdaptAction
|
|
|
|
| 14 |
|
| 15 |
SYSTEM_PROMPT = """You are the Solver Agent for ADAPT.
|
| 16 |
Write only runnable Python code.
|
|
@@ -44,6 +45,8 @@ class TrainingConfig:
|
|
| 44 |
wandb_run_name: str | None = None
|
| 45 |
generator_mode: str = "reward_aware"
|
| 46 |
non_deterministic_generator: bool = False
|
|
|
|
|
|
|
| 47 |
|
| 48 |
def to_dict(self) -> dict[str, Any]:
|
| 49 |
return asdict(self)
|
|
@@ -60,6 +63,7 @@ TRAINING_PRESETS: dict[str, dict[str, Any]] = {
|
|
| 60 |
"baseline_eval": False,
|
| 61 |
"disable_wandb": True,
|
| 62 |
"output_dir": "outputs_smoke",
|
|
|
|
| 63 |
},
|
| 64 |
"default": {},
|
| 65 |
}
|
|
@@ -156,6 +160,8 @@ def namespace_to_config(args: argparse.Namespace) -> TrainingConfig:
|
|
| 156 |
wandb_run_name=args.wandb_run_name,
|
| 157 |
generator_mode=args.generator_mode,
|
| 158 |
non_deterministic_generator=args.non_deterministic_generator,
|
|
|
|
|
|
|
| 159 |
)
|
| 160 |
|
| 161 |
|
|
@@ -329,12 +335,27 @@ class TrainingLogger:
|
|
| 329 |
use_wandb: bool = True
|
| 330 |
wandb_project: str = "adapt-dsa-tutor"
|
| 331 |
wandb_run_name: str | None = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
rows: list[dict[str, Any]] = field(default_factory=list)
|
| 333 |
global_step: int = 0
|
| 334 |
_wandb_run: Any = field(default=None, init=False, repr=False)
|
|
|
|
| 335 |
|
| 336 |
def __post_init__(self) -> None:
|
| 337 |
self.output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
if not self.use_wandb:
|
| 339 |
return
|
| 340 |
try:
|
|
@@ -381,10 +402,37 @@ class TrainingLogger:
|
|
| 381 |
if extra:
|
| 382 |
row.update(extra)
|
| 383 |
self.rows.append(row)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
if self._wandb_run is not None:
|
| 385 |
self._wandb_run.log(row, step=self.global_step)
|
| 386 |
self.global_step += 1
|
| 387 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 388 |
def write_csv(self) -> Path:
|
| 389 |
output_path = self.output_dir / "reward_curve.csv"
|
| 390 |
fieldnames: list[str] = []
|
|
@@ -402,11 +450,37 @@ class TrainingLogger:
|
|
| 402 |
if self._wandb_run is not None:
|
| 403 |
self._wandb_run.finish()
|
| 404 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
|
| 406 |
def build_dataset(size: int, controller: GeneratorController, curriculum: CurriculumManager) -> GeneratorRolloutDataset:
|
| 407 |
return GeneratorRolloutDataset(size=size, controller=controller, curriculum=curriculum)
|
| 408 |
|
| 409 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
def build_reward_func(
|
| 411 |
curriculum: CurriculumManager,
|
| 412 |
controller: GeneratorController,
|
|
@@ -453,6 +527,13 @@ def build_reward_func(
|
|
| 453 |
extra={
|
| 454 |
"generator_reward": round(float(observation.generator_reward_signal), 4),
|
| 455 |
"problem_id": problem["problem_id"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 456 |
},
|
| 457 |
)
|
| 458 |
if progress_callback is not None:
|
|
@@ -552,6 +633,9 @@ def run_policy_evaluation(
|
|
| 552 |
session_id=env.session_id,
|
| 553 |
generator_mode=generator_mode,
|
| 554 |
)
|
|
|
|
|
|
|
|
|
|
| 555 |
|
| 556 |
for _ in range(MAX_STEPS_PER_EPISODE):
|
| 557 |
prompt = build_solver_prompt(observation.model_dump())
|
|
@@ -561,10 +645,13 @@ def run_policy_evaluation(
|
|
| 561 |
prompt=prompt,
|
| 562 |
max_new_tokens=max_new_tokens,
|
| 563 |
)
|
|
|
|
|
|
|
|
|
|
| 564 |
observation = env.step(
|
| 565 |
AdaptAction(
|
| 566 |
session_id=env.session_id,
|
| 567 |
-
code=
|
| 568 |
)
|
| 569 |
)
|
| 570 |
if observation.done:
|
|
@@ -591,6 +678,13 @@ def run_policy_evaluation(
|
|
| 591 |
extra={
|
| 592 |
"generator_reward": round(float(observation.generator_reward_signal), 4),
|
| 593 |
"problem_id": problem["problem_id"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 594 |
},
|
| 595 |
)
|
| 596 |
|
|
@@ -615,11 +709,17 @@ def print_evaluation_summary(baseline: dict[str, Any], trained: dict[str, Any])
|
|
| 615 |
def run_training(
|
| 616 |
config: TrainingConfig | argparse.Namespace,
|
| 617 |
*,
|
|
|
|
| 618 |
progress_callback: Callable[[dict[str, Any]], None] | None = None,
|
| 619 |
) -> dict[str, Any]:
|
| 620 |
if isinstance(config, argparse.Namespace):
|
| 621 |
config = namespace_to_config(config)
|
| 622 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 623 |
try:
|
| 624 |
from trl import GRPOConfig, GRPOTrainer
|
| 625 |
from unsloth import FastLanguageModel, PatchFastRL
|
|
@@ -634,13 +734,20 @@ def run_training(
|
|
| 634 |
|
| 635 |
PatchFastRL("GRPO", FastLanguageModel)
|
| 636 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 637 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 638 |
model_name=config.model_name,
|
| 639 |
max_seq_length=config.max_seq_length,
|
|
|
|
| 640 |
load_in_4bit=not config.disable_4bit,
|
| 641 |
)
|
| 642 |
if tokenizer.pad_token is None:
|
| 643 |
tokenizer.pad_token = tokenizer.eos_token
|
|
|
|
|
|
|
| 644 |
|
| 645 |
model = FastLanguageModel.get_peft_model(
|
| 646 |
model,
|
|
@@ -661,22 +768,36 @@ def run_training(
|
|
| 661 |
use_wandb=not config.disable_wandb,
|
| 662 |
wandb_project=config.wandb_project,
|
| 663 |
wandb_run_name=config.wandb_run_name,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 664 |
)
|
| 665 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 666 |
baseline_summary = {"easy": 0.0, "medium": 0.0, "hard": 0.0, "overall": 0.0}
|
| 667 |
trained_summary = {"easy": 0.0, "medium": 0.0, "hard": 0.0, "overall": 0.0}
|
| 668 |
|
| 669 |
if config.baseline_eval:
|
| 670 |
FastLanguageModel.for_inference(model)
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
)
|
| 680 |
baseline_summary = run_policy_evaluation(
|
| 681 |
model=model,
|
| 682 |
tokenizer=tokenizer,
|
|
@@ -688,14 +809,13 @@ def run_training(
|
|
| 688 |
max_new_tokens=config.eval_max_new_tokens,
|
| 689 |
)
|
| 690 |
print(f"[baseline_eval] {json.dumps(baseline_summary)}")
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
)
|
| 699 |
|
| 700 |
training_args = GRPOConfig(
|
| 701 |
output_dir=str(output_dir),
|
|
@@ -707,49 +827,47 @@ def run_training(
|
|
| 707 |
max_completion_length=config.max_completion_length,
|
| 708 |
max_steps=config.max_steps,
|
| 709 |
logging_steps=1,
|
| 710 |
-
bf16=
|
|
|
|
| 711 |
report_to=[],
|
| 712 |
)
|
| 713 |
|
| 714 |
class ProgressCallback(TrainerCallback):
|
| 715 |
def on_train_begin(self, args, state, control, **kwargs): # type: ignore[override]
|
| 716 |
del args, control, kwargs
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
|
| 725 |
-
|
| 726 |
-
)
|
| 727 |
|
| 728 |
def on_step_end(self, args, state, control, **kwargs): # type: ignore[override]
|
| 729 |
del args, control, kwargs
|
| 730 |
-
|
| 731 |
-
|
| 732 |
-
|
| 733 |
-
|
| 734 |
-
|
| 735 |
-
|
| 736 |
-
|
| 737 |
-
|
| 738 |
-
|
| 739 |
-
)
|
| 740 |
|
| 741 |
def on_train_end(self, args, state, control, **kwargs): # type: ignore[override]
|
| 742 |
del args, control, kwargs
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
)
|
| 753 |
|
| 754 |
trainer = GRPOTrainer(
|
| 755 |
model=model,
|
|
@@ -765,15 +883,14 @@ def run_training(
|
|
| 765 |
|
| 766 |
if config.baseline_eval:
|
| 767 |
FastLanguageModel.for_inference(model)
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
|
| 773 |
-
|
| 774 |
-
|
| 775 |
-
|
| 776 |
-
)
|
| 777 |
trained_summary = run_policy_evaluation(
|
| 778 |
model=model,
|
| 779 |
tokenizer=tokenizer,
|
|
@@ -786,16 +903,23 @@ def run_training(
|
|
| 786 |
)
|
| 787 |
print(f"[trained_eval] {json.dumps(trained_summary)}")
|
| 788 |
print_evaluation_summary(baseline_summary, trained_summary)
|
| 789 |
-
|
| 790 |
-
|
| 791 |
-
|
| 792 |
-
|
| 793 |
-
|
| 794 |
-
|
| 795 |
-
|
| 796 |
-
)
|
| 797 |
|
| 798 |
csv_path = logger.write_csv()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 799 |
logger.close()
|
| 800 |
print(f"[artifacts] reward curve CSV written to {csv_path}")
|
| 801 |
|
|
@@ -803,6 +927,7 @@ def run_training(
|
|
| 803 |
"config": config.to_dict(),
|
| 804 |
"output_dir": str(output_dir.resolve()),
|
| 805 |
"reward_curve_csv": str(csv_path.resolve()),
|
|
|
|
| 806 |
"baseline_summary": baseline_summary,
|
| 807 |
"trained_summary": trained_summary,
|
| 808 |
"completed_steps": int(config.max_steps),
|
|
@@ -832,6 +957,8 @@ def build_parser() -> argparse.ArgumentParser:
|
|
| 832 |
parser.add_argument("--disable-wandb", action="store_true")
|
| 833 |
parser.add_argument("--wandb-project", default="adapt-dsa-tutor")
|
| 834 |
parser.add_argument("--wandb-run-name", default=None)
|
|
|
|
|
|
|
| 835 |
parser.add_argument(
|
| 836 |
"--generator-mode",
|
| 837 |
choices=["heuristic", "reward_aware"],
|
|
|
|
| 11 |
from env.adapt_env import AdaptEnvironment, MAX_STEPS_PER_EPISODE
|
| 12 |
from env.generator import DIFFICULTY_LABELS, GeneratorAgent
|
| 13 |
from models import AdaptAction
|
| 14 |
+
from training.trace_logging import TraceArtifactLogger
|
| 15 |
|
| 16 |
SYSTEM_PROMPT = """You are the Solver Agent for ADAPT.
|
| 17 |
Write only runnable Python code.
|
|
|
|
| 45 |
wandb_run_name: str | None = None
|
| 46 |
generator_mode: str = "reward_aware"
|
| 47 |
non_deterministic_generator: bool = False
|
| 48 |
+
trace_logging_enabled: bool = True
|
| 49 |
+
checkpoint_log_interval_steps: int = 10
|
| 50 |
|
| 51 |
def to_dict(self) -> dict[str, Any]:
|
| 52 |
return asdict(self)
|
|
|
|
| 63 |
"baseline_eval": False,
|
| 64 |
"disable_wandb": True,
|
| 65 |
"output_dir": "outputs_smoke",
|
| 66 |
+
"checkpoint_log_interval_steps": 2,
|
| 67 |
},
|
| 68 |
"default": {},
|
| 69 |
}
|
|
|
|
| 160 |
wandb_run_name=args.wandb_run_name,
|
| 161 |
generator_mode=args.generator_mode,
|
| 162 |
non_deterministic_generator=args.non_deterministic_generator,
|
| 163 |
+
trace_logging_enabled=args.trace_logging_enabled,
|
| 164 |
+
checkpoint_log_interval_steps=args.checkpoint_log_interval_steps,
|
| 165 |
)
|
| 166 |
|
| 167 |
|
|
|
|
| 335 |
use_wandb: bool = True
|
| 336 |
wandb_project: str = "adapt-dsa-tutor"
|
| 337 |
wandb_run_name: str | None = None
|
| 338 |
+
run_id: str | None = None
|
| 339 |
+
training_config: dict[str, Any] = field(default_factory=dict)
|
| 340 |
+
model_identifiers: dict[str, Any] = field(default_factory=dict)
|
| 341 |
+
trace_logging_enabled: bool = True
|
| 342 |
+
checkpoint_log_interval_steps: int = 10
|
| 343 |
rows: list[dict[str, Any]] = field(default_factory=list)
|
| 344 |
global_step: int = 0
|
| 345 |
_wandb_run: Any = field(default=None, init=False, repr=False)
|
| 346 |
+
_trace_logger: TraceArtifactLogger | None = field(default=None, init=False, repr=False)
|
| 347 |
|
| 348 |
def __post_init__(self) -> None:
|
| 349 |
self.output_dir.mkdir(parents=True, exist_ok=True)
|
| 350 |
+
if self.trace_logging_enabled and self.run_id:
|
| 351 |
+
self._trace_logger = TraceArtifactLogger(
|
| 352 |
+
run_id=self.run_id,
|
| 353 |
+
output_dir=self.output_dir,
|
| 354 |
+
training_config=dict(self.training_config),
|
| 355 |
+
model_identifiers=dict(self.model_identifiers),
|
| 356 |
+
system_prompt=SYSTEM_PROMPT,
|
| 357 |
+
checkpoint_interval_steps=int(max(self.checkpoint_log_interval_steps, 1)),
|
| 358 |
+
)
|
| 359 |
if not self.use_wandb:
|
| 360 |
return
|
| 361 |
try:
|
|
|
|
| 402 |
if extra:
|
| 403 |
row.update(extra)
|
| 404 |
self.rows.append(row)
|
| 405 |
+
if self._trace_logger is not None:
|
| 406 |
+
self._trace_logger.log_event(
|
| 407 |
+
{
|
| 408 |
+
"phase": phase,
|
| 409 |
+
"step": self.global_step,
|
| 410 |
+
"train_episode_index": extra.get("train_episode_index") if extra else None,
|
| 411 |
+
"problem_id": row.get("problem_id"),
|
| 412 |
+
"problem_family": row.get("problem_family"),
|
| 413 |
+
"difficulty": row.get("difficulty_tier"),
|
| 414 |
+
"teacher_prompt": row.get("teacher_prompt"),
|
| 415 |
+
"solver_completion": row.get("solver_completion"),
|
| 416 |
+
"extracted_code": row.get("extracted_code"),
|
| 417 |
+
"reward": row.get("episode_reward"),
|
| 418 |
+
"pass_rate": row.get("pass_rate"),
|
| 419 |
+
"visible_pass_rate": row.get("visible_pass_rate"),
|
| 420 |
+
"execution_status": row.get("execution_status"),
|
| 421 |
+
"efficiency_score": row.get("efficiency_score"),
|
| 422 |
+
"optimization_hints": row.get("optimization_hints", []),
|
| 423 |
+
"feedback": row.get("feedback"),
|
| 424 |
+
}
|
| 425 |
+
)
|
| 426 |
if self._wandb_run is not None:
|
| 427 |
self._wandb_run.log(row, step=self.global_step)
|
| 428 |
self.global_step += 1
|
| 429 |
|
| 430 |
+
def record_progress(self, updates: dict[str, Any]) -> dict[str, Any]:
|
| 431 |
+
if self._trace_logger is None:
|
| 432 |
+
return {}
|
| 433 |
+
self._trace_logger.record_progress(updates)
|
| 434 |
+
return self._trace_logger.artifact_paths()
|
| 435 |
+
|
| 436 |
def write_csv(self) -> Path:
|
| 437 |
output_path = self.output_dir / "reward_curve.csv"
|
| 438 |
fieldnames: list[str] = []
|
|
|
|
| 450 |
if self._wandb_run is not None:
|
| 451 |
self._wandb_run.finish()
|
| 452 |
|
| 453 |
+
def finalize_trace_artifacts(
|
| 454 |
+
self,
|
| 455 |
+
*,
|
| 456 |
+
reward_curve_csv: Path | None = None,
|
| 457 |
+
final_metrics: dict[str, Any] | None = None,
|
| 458 |
+
) -> dict[str, Any]:
|
| 459 |
+
if self._trace_logger is None:
|
| 460 |
+
return {}
|
| 461 |
+
self._trace_logger.finalize(reward_curve_csv=reward_curve_csv, final_metrics=final_metrics)
|
| 462 |
+
return self._trace_logger.artifact_paths()
|
| 463 |
+
|
| 464 |
|
| 465 |
def build_dataset(size: int, controller: GeneratorController, curriculum: CurriculumManager) -> GeneratorRolloutDataset:
|
| 466 |
return GeneratorRolloutDataset(size=size, controller=controller, curriculum=curriculum)
|
| 467 |
|
| 468 |
|
| 469 |
+
def extract_optimization_hints(feedback: str) -> list[str]:
|
| 470 |
+
lines = [line.strip() for line in feedback.splitlines()]
|
| 471 |
+
hints: list[str] = []
|
| 472 |
+
capture = False
|
| 473 |
+
for line in lines:
|
| 474 |
+
if line == "Optimization hints:":
|
| 475 |
+
capture = True
|
| 476 |
+
continue
|
| 477 |
+
if capture and line.startswith("- "):
|
| 478 |
+
hints.append(line[2:])
|
| 479 |
+
elif capture and line:
|
| 480 |
+
break
|
| 481 |
+
return hints
|
| 482 |
+
|
| 483 |
+
|
| 484 |
def build_reward_func(
|
| 485 |
curriculum: CurriculumManager,
|
| 486 |
controller: GeneratorController,
|
|
|
|
| 527 |
extra={
|
| 528 |
"generator_reward": round(float(observation.generator_reward_signal), 4),
|
| 529 |
"problem_id": problem["problem_id"],
|
| 530 |
+
"teacher_prompt": prompt,
|
| 531 |
+
"solver_completion": completion,
|
| 532 |
+
"extracted_code": extract_code(completion),
|
| 533 |
+
"feedback": observation.feedback,
|
| 534 |
+
"efficiency_score": observation.reward_components.get("efficiency_score"),
|
| 535 |
+
"optimization_hints": extract_optimization_hints(observation.feedback),
|
| 536 |
+
"train_episode_index": int(controller.history["episode_index"]),
|
| 537 |
},
|
| 538 |
)
|
| 539 |
if progress_callback is not None:
|
|
|
|
| 633 |
session_id=env.session_id,
|
| 634 |
generator_mode=generator_mode,
|
| 635 |
)
|
| 636 |
+
last_prompt = ""
|
| 637 |
+
last_completion = ""
|
| 638 |
+
last_code = ""
|
| 639 |
|
| 640 |
for _ in range(MAX_STEPS_PER_EPISODE):
|
| 641 |
prompt = build_solver_prompt(observation.model_dump())
|
|
|
|
| 645 |
prompt=prompt,
|
| 646 |
max_new_tokens=max_new_tokens,
|
| 647 |
)
|
| 648 |
+
last_prompt = prompt
|
| 649 |
+
last_completion = completion
|
| 650 |
+
last_code = extract_code(completion)
|
| 651 |
observation = env.step(
|
| 652 |
AdaptAction(
|
| 653 |
session_id=env.session_id,
|
| 654 |
+
code=last_code,
|
| 655 |
)
|
| 656 |
)
|
| 657 |
if observation.done:
|
|
|
|
| 678 |
extra={
|
| 679 |
"generator_reward": round(float(observation.generator_reward_signal), 4),
|
| 680 |
"problem_id": problem["problem_id"],
|
| 681 |
+
"teacher_prompt": last_prompt,
|
| 682 |
+
"solver_completion": last_completion,
|
| 683 |
+
"extracted_code": last_code,
|
| 684 |
+
"feedback": observation.feedback,
|
| 685 |
+
"efficiency_score": observation.reward_components.get("efficiency_score"),
|
| 686 |
+
"optimization_hints": extract_optimization_hints(observation.feedback),
|
| 687 |
+
"train_episode_index": int(controller.history["episode_index"]),
|
| 688 |
},
|
| 689 |
)
|
| 690 |
|
|
|
|
| 709 |
def run_training(
|
| 710 |
config: TrainingConfig | argparse.Namespace,
|
| 711 |
*,
|
| 712 |
+
run_id: str | None = None,
|
| 713 |
progress_callback: Callable[[dict[str, Any]], None] | None = None,
|
| 714 |
) -> dict[str, Any]:
|
| 715 |
if isinstance(config, argparse.Namespace):
|
| 716 |
config = namespace_to_config(config)
|
| 717 |
|
| 718 |
+
try:
|
| 719 |
+
import torch
|
| 720 |
+
except ImportError as exc:
|
| 721 |
+
raise RuntimeError("Training requires `torch` to be installed.") from exc
|
| 722 |
+
|
| 723 |
try:
|
| 724 |
from trl import GRPOConfig, GRPOTrainer
|
| 725 |
from unsloth import FastLanguageModel, PatchFastRL
|
|
|
|
| 734 |
|
| 735 |
PatchFastRL("GRPO", FastLanguageModel)
|
| 736 |
|
| 737 |
+
use_cuda = torch.cuda.is_available()
|
| 738 |
+
use_bf16 = bool(config.bf16) or (use_cuda and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported())
|
| 739 |
+
model_dtype = torch.bfloat16 if use_bf16 else (torch.float16 if use_cuda else torch.float32)
|
| 740 |
+
|
| 741 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 742 |
model_name=config.model_name,
|
| 743 |
max_seq_length=config.max_seq_length,
|
| 744 |
+
dtype=model_dtype,
|
| 745 |
load_in_4bit=not config.disable_4bit,
|
| 746 |
)
|
| 747 |
if tokenizer.pad_token is None:
|
| 748 |
tokenizer.pad_token = tokenizer.eos_token
|
| 749 |
+
if hasattr(model, "config"):
|
| 750 |
+
model.config.torch_dtype = model_dtype
|
| 751 |
|
| 752 |
model = FastLanguageModel.get_peft_model(
|
| 753 |
model,
|
|
|
|
| 768 |
use_wandb=not config.disable_wandb,
|
| 769 |
wandb_project=config.wandb_project,
|
| 770 |
wandb_run_name=config.wandb_run_name,
|
| 771 |
+
run_id=run_id,
|
| 772 |
+
training_config=config.to_dict(),
|
| 773 |
+
model_identifiers={
|
| 774 |
+
"model_name": config.model_name,
|
| 775 |
+
"generator_mode": config.generator_mode,
|
| 776 |
+
},
|
| 777 |
+
trace_logging_enabled=config.trace_logging_enabled,
|
| 778 |
+
checkpoint_log_interval_steps=config.checkpoint_log_interval_steps,
|
| 779 |
)
|
| 780 |
|
| 781 |
+
def emit_progress(update: dict[str, Any]) -> None:
|
| 782 |
+
artifact_paths = logger.record_progress(update)
|
| 783 |
+
if progress_callback is not None:
|
| 784 |
+
payload = dict(update)
|
| 785 |
+
payload.update(artifact_paths)
|
| 786 |
+
progress_callback(payload)
|
| 787 |
+
|
| 788 |
baseline_summary = {"easy": 0.0, "medium": 0.0, "hard": 0.0, "overall": 0.0}
|
| 789 |
trained_summary = {"easy": 0.0, "medium": 0.0, "hard": 0.0, "overall": 0.0}
|
| 790 |
|
| 791 |
if config.baseline_eval:
|
| 792 |
FastLanguageModel.for_inference(model)
|
| 793 |
+
emit_progress(
|
| 794 |
+
{
|
| 795 |
+
"phase": "baseline_eval",
|
| 796 |
+
"status": "running",
|
| 797 |
+
"completed_steps": 0,
|
| 798 |
+
"total_steps": int(config.max_steps),
|
| 799 |
+
}
|
| 800 |
+
)
|
|
|
|
| 801 |
baseline_summary = run_policy_evaluation(
|
| 802 |
model=model,
|
| 803 |
tokenizer=tokenizer,
|
|
|
|
| 809 |
max_new_tokens=config.eval_max_new_tokens,
|
| 810 |
)
|
| 811 |
print(f"[baseline_eval] {json.dumps(baseline_summary)}")
|
| 812 |
+
emit_progress(
|
| 813 |
+
{
|
| 814 |
+
"phase": "baseline_eval",
|
| 815 |
+
"status": "completed",
|
| 816 |
+
"baseline_summary": baseline_summary,
|
| 817 |
+
}
|
| 818 |
+
)
|
|
|
|
| 819 |
|
| 820 |
training_args = GRPOConfig(
|
| 821 |
output_dir=str(output_dir),
|
|
|
|
| 827 |
max_completion_length=config.max_completion_length,
|
| 828 |
max_steps=config.max_steps,
|
| 829 |
logging_steps=1,
|
| 830 |
+
bf16=use_bf16,
|
| 831 |
+
fp16=use_cuda and not use_bf16,
|
| 832 |
report_to=[],
|
| 833 |
)
|
| 834 |
|
| 835 |
class ProgressCallback(TrainerCallback):
|
| 836 |
def on_train_begin(self, args, state, control, **kwargs): # type: ignore[override]
|
| 837 |
del args, control, kwargs
|
| 838 |
+
emit_progress(
|
| 839 |
+
{
|
| 840 |
+
"phase": "train",
|
| 841 |
+
"status": "running",
|
| 842 |
+
"completed_steps": int(getattr(state, "global_step", 0)),
|
| 843 |
+
"total_steps": int(config.max_steps),
|
| 844 |
+
"current_epoch": float(getattr(state, "epoch", 0.0) or 0.0),
|
| 845 |
+
}
|
| 846 |
+
)
|
|
|
|
| 847 |
|
| 848 |
def on_step_end(self, args, state, control, **kwargs): # type: ignore[override]
|
| 849 |
del args, control, kwargs
|
| 850 |
+
emit_progress(
|
| 851 |
+
{
|
| 852 |
+
"phase": "train",
|
| 853 |
+
"status": "running",
|
| 854 |
+
"completed_steps": int(getattr(state, "global_step", 0)),
|
| 855 |
+
"total_steps": int(config.max_steps),
|
| 856 |
+
"current_epoch": float(getattr(state, "epoch", 0.0) or 0.0),
|
| 857 |
+
}
|
| 858 |
+
)
|
|
|
|
| 859 |
|
| 860 |
def on_train_end(self, args, state, control, **kwargs): # type: ignore[override]
|
| 861 |
del args, control, kwargs
|
| 862 |
+
emit_progress(
|
| 863 |
+
{
|
| 864 |
+
"phase": "train",
|
| 865 |
+
"status": "completed",
|
| 866 |
+
"completed_steps": int(getattr(state, "global_step", 0)),
|
| 867 |
+
"total_steps": int(config.max_steps),
|
| 868 |
+
"current_epoch": float(getattr(state, "epoch", 0.0) or 0.0),
|
| 869 |
+
}
|
| 870 |
+
)
|
|
|
|
| 871 |
|
| 872 |
trainer = GRPOTrainer(
|
| 873 |
model=model,
|
|
|
|
| 883 |
|
| 884 |
if config.baseline_eval:
|
| 885 |
FastLanguageModel.for_inference(model)
|
| 886 |
+
emit_progress(
|
| 887 |
+
{
|
| 888 |
+
"phase": "trained_eval",
|
| 889 |
+
"status": "running",
|
| 890 |
+
"completed_steps": int(config.max_steps),
|
| 891 |
+
"total_steps": int(config.max_steps),
|
| 892 |
+
}
|
| 893 |
+
)
|
|
|
|
| 894 |
trained_summary = run_policy_evaluation(
|
| 895 |
model=model,
|
| 896 |
tokenizer=tokenizer,
|
|
|
|
| 903 |
)
|
| 904 |
print(f"[trained_eval] {json.dumps(trained_summary)}")
|
| 905 |
print_evaluation_summary(baseline_summary, trained_summary)
|
| 906 |
+
emit_progress(
|
| 907 |
+
{
|
| 908 |
+
"phase": "trained_eval",
|
| 909 |
+
"status": "completed",
|
| 910 |
+
"trained_summary": trained_summary,
|
| 911 |
+
}
|
| 912 |
+
)
|
|
|
|
| 913 |
|
| 914 |
csv_path = logger.write_csv()
|
| 915 |
+
trace_artifact_paths = logger.finalize_trace_artifacts(
|
| 916 |
+
reward_curve_csv=csv_path,
|
| 917 |
+
final_metrics={
|
| 918 |
+
"baseline_summary": baseline_summary,
|
| 919 |
+
"trained_summary": trained_summary,
|
| 920 |
+
"completed_steps": int(config.max_steps),
|
| 921 |
+
},
|
| 922 |
+
)
|
| 923 |
logger.close()
|
| 924 |
print(f"[artifacts] reward curve CSV written to {csv_path}")
|
| 925 |
|
|
|
|
| 927 |
"config": config.to_dict(),
|
| 928 |
"output_dir": str(output_dir.resolve()),
|
| 929 |
"reward_curve_csv": str(csv_path.resolve()),
|
| 930 |
+
**trace_artifact_paths,
|
| 931 |
"baseline_summary": baseline_summary,
|
| 932 |
"trained_summary": trained_summary,
|
| 933 |
"completed_steps": int(config.max_steps),
|
|
|
|
| 957 |
parser.add_argument("--disable-wandb", action="store_true")
|
| 958 |
parser.add_argument("--wandb-project", default="adapt-dsa-tutor")
|
| 959 |
parser.add_argument("--wandb-run-name", default=None)
|
| 960 |
+
parser.add_argument("--trace-logging-enabled", action=argparse.BooleanOptionalAction, default=True)
|
| 961 |
+
parser.add_argument("--checkpoint-log-interval-steps", type=int, default=10)
|
| 962 |
parser.add_argument(
|
| 963 |
"--generator-mode",
|
| 964 |
choices=["heuristic", "reward_aware"],
|