Dishaaa25 commited on
Commit
6e7ed91
·
1 Parent(s): 39fe61e

add logs and fix train error

Browse files
scripts/test_space_api.py CHANGED
@@ -32,6 +32,10 @@ def main() -> None:
32
  assert "completed_steps" in train_status.json()
33
  assert "remaining_steps" in train_status.json()
34
  assert "phase" in train_status.json()
 
 
 
 
35
 
36
  reset = client.post("/reset", json={"difficulty": "easy", "problem_id": "sum_even_numbers"})
37
  assert reset.status_code == 200
 
32
  assert "completed_steps" in train_status.json()
33
  assert "remaining_steps" in train_status.json()
34
  assert "phase" in train_status.json()
35
+ assert "run_manifest_path" in train_status.json()
36
+ assert "events_path" in train_status.json()
37
+ assert "latest_checkpoint_path" in train_status.json()
38
+ assert "logs_deleted_from_space" in train_status.json()
39
 
40
  reset = client.post("/reset", json={"difficulty": "easy", "problem_id": "sum_even_numbers"})
41
  assert reset.status_code == 200
scripts/test_trace_logging.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import sys
5
+ from pathlib import Path
6
+ from tempfile import TemporaryDirectory
7
+
8
+ ROOT = Path(__file__).resolve().parents[1]
9
+ if str(ROOT) not in sys.path:
10
+ sys.path.insert(0, str(ROOT))
11
+
12
+ from training.trace_logging import TraceArtifactLogger
13
+
14
+
15
+ def main() -> None:
16
+ with TemporaryDirectory() as tmpdir:
17
+ output_dir = Path(tmpdir)
18
+ logger = TraceArtifactLogger(
19
+ run_id="run-123",
20
+ output_dir=output_dir,
21
+ training_config={"max_steps": 6, "model_name": "demo-model"},
22
+ model_identifiers={"model_name": "demo-model", "generator_mode": "reward_aware"},
23
+ system_prompt="You are the Solver Agent.",
24
+ checkpoint_interval_steps=2,
25
+ )
26
+
27
+ manifest_path = output_dir / "logs" / "run_manifest.json"
28
+ assert manifest_path.exists()
29
+ manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
30
+ assert manifest["run_id"] == "run-123"
31
+ assert manifest["training_config"]["max_steps"] == 6
32
+
33
+ logger.log_event(
34
+ {
35
+ "phase": "train",
36
+ "step": 0,
37
+ "train_episode_index": 1,
38
+ "problem_id": "sum_even_numbers_1",
39
+ "problem_family": "sum_even_numbers",
40
+ "difficulty": "easy",
41
+ "teacher_prompt": "Problem: sum the even numbers",
42
+ "solver_completion": "print(sum(x for x in nums if x % 2 == 0))",
43
+ "extracted_code": "print(sum(x for x in nums if x % 2 == 0))",
44
+ "reward": 0.94,
45
+ "pass_rate": 1.0,
46
+ "visible_pass_rate": 1.0,
47
+ "execution_status": "completed",
48
+ "efficiency_score": 0.94,
49
+ "optimization_hints": ["Avoid materializing temporary containers."],
50
+ "feedback": "All hidden tests passed, but the solution can still be optimized further.",
51
+ }
52
+ )
53
+ logger.record_progress(
54
+ {
55
+ "phase": "train",
56
+ "completed_steps": 2,
57
+ "total_steps": 6,
58
+ "remaining_steps": 4,
59
+ "progress_ratio": 0.3333,
60
+ "current_epoch": 2.0,
61
+ "current_difficulty": "easy",
62
+ "curriculum_level": 1,
63
+ "train_episode_index": 1,
64
+ "last_problem_id": "sum_even_numbers_1",
65
+ "last_problem_family": "sum_even_numbers",
66
+ "last_execution_status": "completed",
67
+ }
68
+ )
69
+ artifact_paths = logger.artifact_paths()
70
+
71
+ events_path = Path(artifact_paths["events_path"])
72
+ latest_checkpoint_path = Path(artifact_paths["latest_checkpoint_path"])
73
+ assert events_path.exists()
74
+ assert latest_checkpoint_path.exists()
75
+
76
+ event_line = events_path.read_text(encoding="utf-8").strip().splitlines()[0]
77
+ event = json.loads(event_line)
78
+ assert event["problem_id"] == "sum_even_numbers_1"
79
+ assert event["teacher_prompt"] == "Problem: sum the even numbers"
80
+ assert "training_config" not in event
81
+
82
+ checkpoint = json.loads(latest_checkpoint_path.read_text(encoding="utf-8"))
83
+ assert checkpoint["step"] == 2
84
+ assert checkpoint["rolling_metrics"]["avg_reward"] == 0.94
85
+ assert "training_config" not in checkpoint
86
+
87
+ reward_curve = output_dir / "reward_curve.csv"
88
+ reward_curve.write_text("step,episode_reward\n0,0.94\n", encoding="utf-8")
89
+ summary_paths = logger.finalize(
90
+ reward_curve_csv=reward_curve,
91
+ final_metrics={"completed_steps": 6},
92
+ )
93
+ summary_path = Path(summary_paths)
94
+ assert summary_path.exists()
95
+ summary = json.loads(summary_path.read_text(encoding="utf-8"))
96
+ assert summary["final_metrics"]["completed_steps"] == 6
97
+
98
+ print("Trace logging smoke tests passed")
99
+
100
+
101
+ if __name__ == "__main__":
102
+ main()
server/runtime.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
2
 
3
  import json
4
  import os
 
5
  import threading
6
  import traceback
7
  from dataclasses import asdict, dataclass, field
@@ -69,6 +70,13 @@ class TrainingJobState:
69
  reward_curve_csv: str | None = None
70
  model_repo_id: str | None = None
71
  uploaded_revision: str | None = None
 
 
 
 
 
 
 
72
  phase: str = "idle"
73
  completed_steps: int = 0
74
  total_steps: int = 0
@@ -430,6 +438,13 @@ class SpaceTrainingManager:
430
  reward_curve_csv=payload.get("reward_curve_csv"),
431
  model_repo_id=payload.get("model_repo_id"),
432
  uploaded_revision=payload.get("uploaded_revision"),
 
 
 
 
 
 
 
433
  phase=payload.get("phase", "idle"),
434
  completed_steps=int(payload.get("completed_steps", 0) or 0),
435
  total_steps=int(payload.get("total_steps", 0) or 0),
@@ -508,6 +523,7 @@ class SpaceTrainingManager:
508
  else:
509
  output_dir = self.output_root / requested_output_dir / run_id
510
  config.output_dir = str(output_dir)
 
511
 
512
  self._job = TrainingJobState(
513
  status="running",
@@ -519,6 +535,12 @@ class SpaceTrainingManager:
519
  reward_curve_csv=None,
520
  model_repo_id=os.getenv("HF_MODEL_REPO_ID"),
521
  uploaded_revision=None,
 
 
 
 
 
 
522
  phase="queued",
523
  completed_steps=0,
524
  total_steps=int(config.max_steps),
@@ -562,11 +584,22 @@ class SpaceTrainingManager:
562
  )
563
  return getattr(commit_info, "oid", None) or getattr(commit_info, "commit_hash", None) or "unknown"
564
 
 
 
 
 
 
 
 
 
 
565
  def _run_training_job(self, run_id: str, config: TrainingConfig) -> None:
 
566
  try:
567
- summary = run_training(config, progress_callback=self._update_progress)
568
  artifact_path = summary["output_dir"]
569
  uploaded_revision = self._upload_artifacts(artifact_path, run_id)
 
570
  self.model_registry.load_latest_from_hub()
571
 
572
  with self._lock:
@@ -576,6 +609,13 @@ class SpaceTrainingManager:
576
  self._job.reward_curve_csv = summary.get("reward_curve_csv")
577
  self._job.model_repo_id = os.getenv("HF_MODEL_REPO_ID")
578
  self._job.uploaded_revision = uploaded_revision
 
 
 
 
 
 
 
579
  self._job.phase = "completed"
580
  self._job.completed_steps = int(summary.get("completed_steps", config.max_steps))
581
  self._job.total_steps = int(config.max_steps)
@@ -589,9 +629,18 @@ class SpaceTrainingManager:
589
  self._job.traceback = None
590
  self._persist_status()
591
  except Exception as exc:
 
592
  with self._lock:
593
  self._job.status = "failed"
594
  self._job.finished_at = _utc_now()
 
 
 
 
 
 
 
 
595
  self._job.error = str(exc)
596
  self._job.traceback = traceback.format_exc()
597
  self._persist_status()
 
2
 
3
  import json
4
  import os
5
+ import shutil
6
  import threading
7
  import traceback
8
  from dataclasses import asdict, dataclass, field
 
70
  reward_curve_csv: str | None = None
71
  model_repo_id: str | None = None
72
  uploaded_revision: str | None = None
73
+ logs_dir: str | None = None
74
+ run_manifest_path: str | None = None
75
+ events_path: str | None = None
76
+ latest_checkpoint_path: str | None = None
77
+ run_summary_path: str | None = None
78
+ checkpoint_paths: list[str] = field(default_factory=list)
79
+ logs_deleted_from_space: bool = False
80
  phase: str = "idle"
81
  completed_steps: int = 0
82
  total_steps: int = 0
 
438
  reward_curve_csv=payload.get("reward_curve_csv"),
439
  model_repo_id=payload.get("model_repo_id"),
440
  uploaded_revision=payload.get("uploaded_revision"),
441
+ logs_dir=payload.get("logs_dir"),
442
+ run_manifest_path=payload.get("run_manifest_path"),
443
+ events_path=payload.get("events_path"),
444
+ latest_checkpoint_path=payload.get("latest_checkpoint_path"),
445
+ run_summary_path=payload.get("run_summary_path"),
446
+ checkpoint_paths=payload.get("checkpoint_paths", []),
447
+ logs_deleted_from_space=bool(payload.get("logs_deleted_from_space", False)),
448
  phase=payload.get("phase", "idle"),
449
  completed_steps=int(payload.get("completed_steps", 0) or 0),
450
  total_steps=int(payload.get("total_steps", 0) or 0),
 
523
  else:
524
  output_dir = self.output_root / requested_output_dir / run_id
525
  config.output_dir = str(output_dir)
526
+ logs_dir = output_dir / "logs"
527
 
528
  self._job = TrainingJobState(
529
  status="running",
 
535
  reward_curve_csv=None,
536
  model_repo_id=os.getenv("HF_MODEL_REPO_ID"),
537
  uploaded_revision=None,
538
+ logs_dir=str(logs_dir),
539
+ run_manifest_path=str(logs_dir / "run_manifest.json"),
540
+ events_path=str(logs_dir / "events.jsonl"),
541
+ latest_checkpoint_path=str(logs_dir / "latest_checkpoint.json"),
542
+ run_summary_path=str(logs_dir / "run_summary.json"),
543
+ checkpoint_paths=[],
544
  phase="queued",
545
  completed_steps=0,
546
  total_steps=int(config.max_steps),
 
584
  )
585
  return getattr(commit_info, "oid", None) or getattr(commit_info, "commit_hash", None) or "unknown"
586
 
587
+ def _cleanup_local_logs(self, log_dir: str | None) -> bool:
588
+ if not log_dir:
589
+ return False
590
+ folder_path = Path(log_dir)
591
+ if not folder_path.exists():
592
+ return False
593
+ shutil.rmtree(folder_path, ignore_errors=True)
594
+ return not folder_path.exists()
595
+
596
  def _run_training_job(self, run_id: str, config: TrainingConfig) -> None:
597
+ summary: dict[str, Any] | None = None
598
  try:
599
+ summary = run_training(config, run_id=run_id, progress_callback=self._update_progress)
600
  artifact_path = summary["output_dir"]
601
  uploaded_revision = self._upload_artifacts(artifact_path, run_id)
602
+ logs_deleted = self._cleanup_local_logs(summary.get("logs_dir"))
603
  self.model_registry.load_latest_from_hub()
604
 
605
  with self._lock:
 
609
  self._job.reward_curve_csv = summary.get("reward_curve_csv")
610
  self._job.model_repo_id = os.getenv("HF_MODEL_REPO_ID")
611
  self._job.uploaded_revision = uploaded_revision
612
+ self._job.logs_dir = None if logs_deleted else summary.get("logs_dir")
613
+ self._job.run_manifest_path = None if logs_deleted else summary.get("run_manifest_path")
614
+ self._job.events_path = None if logs_deleted else summary.get("events_path")
615
+ self._job.latest_checkpoint_path = None if logs_deleted else summary.get("latest_checkpoint_path")
616
+ self._job.run_summary_path = None if logs_deleted else summary.get("run_summary_path")
617
+ self._job.checkpoint_paths = [] if logs_deleted else summary.get("checkpoint_paths", [])
618
+ self._job.logs_deleted_from_space = logs_deleted
619
  self._job.phase = "completed"
620
  self._job.completed_steps = int(summary.get("completed_steps", config.max_steps))
621
  self._job.total_steps = int(config.max_steps)
 
629
  self._job.traceback = None
630
  self._persist_status()
631
  except Exception as exc:
632
+ logs_deleted = self._cleanup_local_logs(summary.get("logs_dir") if summary else self._job.logs_dir)
633
  with self._lock:
634
  self._job.status = "failed"
635
  self._job.finished_at = _utc_now()
636
+ if logs_deleted:
637
+ self._job.logs_dir = None
638
+ self._job.run_manifest_path = None
639
+ self._job.events_path = None
640
+ self._job.latest_checkpoint_path = None
641
+ self._job.run_summary_path = None
642
+ self._job.checkpoint_paths = []
643
+ self._job.logs_deleted_from_space = logs_deleted
644
  self._job.error = str(exc)
645
  self._job.traceback = traceback.format_exc()
646
  self._persist_status()
test.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
2
 
3
  from scripts.test_env import main as run_env_smoke
4
  from scripts.test_space_api import main as run_space_api_smoke
 
5
  from scripts.test_verifier import test_cases
6
  from verifier.verifier import verify
7
 
@@ -9,6 +10,7 @@ from verifier.verifier import verify
9
  def main() -> None:
10
  run_env_smoke()
11
  run_space_api_smoke()
 
12
 
13
  reward, info = verify(
14
  "n=int(input())\nnums=list(map(int,input().split()))\nprint(sum(x for x in nums if x % 2 == 0))",
 
2
 
3
  from scripts.test_env import main as run_env_smoke
4
  from scripts.test_space_api import main as run_space_api_smoke
5
+ from scripts.test_trace_logging import main as run_trace_logging_smoke
6
  from scripts.test_verifier import test_cases
7
  from verifier.verifier import verify
8
 
 
10
  def main() -> None:
11
  run_env_smoke()
12
  run_space_api_smoke()
13
+ run_trace_logging_smoke()
14
 
15
  reward, info = verify(
16
  "n=int(input())\nnums=list(map(int,input().split()))\nprint(sum(x for x in nums if x % 2 == 0))",
training/trace_logging.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import shutil
5
+ from collections import deque
6
+ from dataclasses import dataclass, field
7
+ from datetime import datetime, timezone
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+
12
+ def _utc_now_iso() -> str:
13
+ return datetime.now(timezone.utc).isoformat()
14
+
15
+
16
+ def _json_safe(value: Any) -> Any:
17
+ if isinstance(value, Path):
18
+ return str(value)
19
+ if isinstance(value, dict):
20
+ return {str(key): _json_safe(item) for key, item in value.items()}
21
+ if isinstance(value, list):
22
+ return [_json_safe(item) for item in value]
23
+ return value
24
+
25
+
26
+ @dataclass
27
+ class TraceArtifactLogger:
28
+ run_id: str
29
+ output_dir: Path
30
+ training_config: dict[str, Any]
31
+ model_identifiers: dict[str, Any]
32
+ system_prompt: str
33
+ checkpoint_interval_steps: int = 10
34
+ schema_version: str = "1.0"
35
+ logs_dir: Path = field(init=False)
36
+ manifest_path: Path = field(init=False)
37
+ events_path: Path = field(init=False)
38
+ latest_checkpoint_path: Path = field(init=False)
39
+ run_summary_path: Path = field(init=False)
40
+ checkpoint_paths: list[Path] = field(default_factory=list, init=False)
41
+ _last_checkpoint_step: int = field(default=0, init=False, repr=False)
42
+ _latest_event: dict[str, Any] = field(default_factory=dict, init=False, repr=False)
43
+ _recent_rewards: deque[float] = field(default_factory=lambda: deque(maxlen=25), init=False, repr=False)
44
+ _recent_pass_rates: deque[float] = field(default_factory=lambda: deque(maxlen=25), init=False, repr=False)
45
+ _recent_efficiency_scores: deque[float] = field(default_factory=lambda: deque(maxlen=25), init=False, repr=False)
46
+ _latest_progress: dict[str, Any] = field(default_factory=dict, init=False, repr=False)
47
+
48
+ def __post_init__(self) -> None:
49
+ self.logs_dir = self.output_dir / "logs"
50
+ self.logs_dir.mkdir(parents=True, exist_ok=True)
51
+ self.manifest_path = self.logs_dir / "run_manifest.json"
52
+ self.events_path = self.logs_dir / "events.jsonl"
53
+ self.latest_checkpoint_path = self.logs_dir / "latest_checkpoint.json"
54
+ self.run_summary_path = self.logs_dir / "run_summary.json"
55
+ manifest = {
56
+ "run_id": self.run_id,
57
+ "schema_version": self.schema_version,
58
+ "started_at": _utc_now_iso(),
59
+ "training_config": self.training_config,
60
+ "model_identifiers": self.model_identifiers,
61
+ "system_prompt": self.system_prompt,
62
+ "checkpoint_interval_steps": int(max(self.checkpoint_interval_steps, 1)),
63
+ }
64
+ self.manifest_path.write_text(json.dumps(_json_safe(manifest), indent=2), encoding="utf-8")
65
+
66
+ def artifact_paths(self) -> dict[str, Any]:
67
+ return {
68
+ "logs_dir": str(self.logs_dir),
69
+ "run_manifest_path": str(self.manifest_path),
70
+ "events_path": str(self.events_path),
71
+ "latest_checkpoint_path": str(self.latest_checkpoint_path),
72
+ "checkpoint_paths": [str(path) for path in self.checkpoint_paths],
73
+ "run_summary_path": str(self.run_summary_path),
74
+ }
75
+
76
+ def log_event(self, event: dict[str, Any]) -> None:
77
+ dynamic_event = {
78
+ "run_id": self.run_id,
79
+ "timestamp": _utc_now_iso(),
80
+ "phase": event.get("phase"),
81
+ "step": event.get("step"),
82
+ "train_episode_index": event.get("train_episode_index"),
83
+ "problem_id": event.get("problem_id"),
84
+ "problem_family": event.get("problem_family"),
85
+ "difficulty": event.get("difficulty"),
86
+ "teacher_prompt": event.get("teacher_prompt"),
87
+ "solver_completion": event.get("solver_completion"),
88
+ "extracted_code": event.get("extracted_code"),
89
+ "reward": event.get("reward"),
90
+ "pass_rate": event.get("pass_rate"),
91
+ "visible_pass_rate": event.get("visible_pass_rate"),
92
+ "execution_status": event.get("execution_status"),
93
+ "efficiency_score": event.get("efficiency_score"),
94
+ "optimization_hints": event.get("optimization_hints", []),
95
+ "feedback": event.get("feedback"),
96
+ }
97
+ with self.events_path.open("a", encoding="utf-8") as handle:
98
+ handle.write(json.dumps(_json_safe(dynamic_event)) + "\n")
99
+
100
+ self._latest_event = dynamic_event
101
+ if dynamic_event.get("reward") is not None:
102
+ self._recent_rewards.append(float(dynamic_event["reward"]))
103
+ if dynamic_event.get("pass_rate") is not None:
104
+ self._recent_pass_rates.append(float(dynamic_event["pass_rate"]))
105
+ if dynamic_event.get("efficiency_score") is not None:
106
+ self._recent_efficiency_scores.append(float(dynamic_event["efficiency_score"]))
107
+
108
+ def record_progress(self, progress: dict[str, Any]) -> None:
109
+ self._latest_progress.update({key: value for key, value in progress.items() if value is not None})
110
+ completed_steps = int(self._latest_progress.get("completed_steps", 0) or 0)
111
+ interval = int(max(self.checkpoint_interval_steps, 1))
112
+ if completed_steps > 0 and completed_steps % interval == 0 and completed_steps != self._last_checkpoint_step:
113
+ self._write_checkpoint(completed_steps)
114
+
115
+ def finalize(self, *, reward_curve_csv: Path | None = None, final_metrics: dict[str, Any] | None = None) -> Path:
116
+ copied_reward_curve = None
117
+ if reward_curve_csv is not None and reward_curve_csv.exists():
118
+ copied_reward_curve = self.logs_dir / "reward_curve.csv"
119
+ if reward_curve_csv.resolve() != copied_reward_curve.resolve():
120
+ shutil.copy2(reward_curve_csv, copied_reward_curve)
121
+ else:
122
+ copied_reward_curve = reward_curve_csv
123
+
124
+ summary = {
125
+ "run_id": self.run_id,
126
+ "finished_at": _utc_now_iso(),
127
+ "artifact_paths": self.artifact_paths(),
128
+ "reward_curve_csv": str(copied_reward_curve) if copied_reward_curve else None,
129
+ "latest_progress": self._latest_progress,
130
+ "latest_event": self._latest_event,
131
+ "rolling_metrics": self._rolling_metrics(),
132
+ "final_metrics": final_metrics or {},
133
+ }
134
+ self.run_summary_path.write_text(json.dumps(_json_safe(summary), indent=2), encoding="utf-8")
135
+ return self.run_summary_path
136
+
137
+ def _write_checkpoint(self, step: int) -> None:
138
+ checkpoint_payload = {
139
+ "run_id": self.run_id,
140
+ "timestamp": _utc_now_iso(),
141
+ "step": int(step),
142
+ "phase": self._latest_progress.get("phase"),
143
+ "total_steps": self._latest_progress.get("total_steps"),
144
+ "remaining_steps": self._latest_progress.get("remaining_steps"),
145
+ "progress_ratio": self._latest_progress.get("progress_ratio"),
146
+ "current_epoch": self._latest_progress.get("current_epoch"),
147
+ "current_difficulty": self._latest_progress.get("current_difficulty"),
148
+ "curriculum_level": self._latest_progress.get("curriculum_level"),
149
+ "train_episode_index": self._latest_progress.get("train_episode_index"),
150
+ "last_problem_id": self._latest_progress.get("last_problem_id"),
151
+ "last_problem_family": self._latest_progress.get("last_problem_family"),
152
+ "last_execution_status": self._latest_progress.get("last_execution_status"),
153
+ "rolling_metrics": self._rolling_metrics(),
154
+ "artifact_paths": {
155
+ "events_path": str(self.events_path),
156
+ "latest_checkpoint_path": str(self.latest_checkpoint_path),
157
+ },
158
+ }
159
+ checkpoint_path = self.logs_dir / f"checkpoint_step_{step:05d}.json"
160
+ checkpoint_path.write_text(json.dumps(_json_safe(checkpoint_payload), indent=2), encoding="utf-8")
161
+ self.latest_checkpoint_path.write_text(json.dumps(_json_safe(checkpoint_payload), indent=2), encoding="utf-8")
162
+ self.checkpoint_paths.append(checkpoint_path)
163
+ self._last_checkpoint_step = step
164
+
165
+ def _rolling_metrics(self) -> dict[str, Any]:
166
+ def _average(values: deque[float]) -> float | None:
167
+ if not values:
168
+ return None
169
+ return round(sum(values) / len(values), 4)
170
+
171
+ return {
172
+ "avg_reward": _average(self._recent_rewards),
173
+ "avg_pass_rate": _average(self._recent_pass_rates),
174
+ "avg_efficiency_score": _average(self._recent_efficiency_scores),
175
+ }
training/train_grpo.py CHANGED
@@ -11,6 +11,7 @@ from typing import Any, Callable
11
  from env.adapt_env import AdaptEnvironment, MAX_STEPS_PER_EPISODE
12
  from env.generator import DIFFICULTY_LABELS, GeneratorAgent
13
  from models import AdaptAction
 
14
 
15
  SYSTEM_PROMPT = """You are the Solver Agent for ADAPT.
16
  Write only runnable Python code.
@@ -44,6 +45,8 @@ class TrainingConfig:
44
  wandb_run_name: str | None = None
45
  generator_mode: str = "reward_aware"
46
  non_deterministic_generator: bool = False
 
 
47
 
48
  def to_dict(self) -> dict[str, Any]:
49
  return asdict(self)
@@ -60,6 +63,7 @@ TRAINING_PRESETS: dict[str, dict[str, Any]] = {
60
  "baseline_eval": False,
61
  "disable_wandb": True,
62
  "output_dir": "outputs_smoke",
 
63
  },
64
  "default": {},
65
  }
@@ -156,6 +160,8 @@ def namespace_to_config(args: argparse.Namespace) -> TrainingConfig:
156
  wandb_run_name=args.wandb_run_name,
157
  generator_mode=args.generator_mode,
158
  non_deterministic_generator=args.non_deterministic_generator,
 
 
159
  )
160
 
161
 
@@ -329,12 +335,27 @@ class TrainingLogger:
329
  use_wandb: bool = True
330
  wandb_project: str = "adapt-dsa-tutor"
331
  wandb_run_name: str | None = None
 
 
 
 
 
332
  rows: list[dict[str, Any]] = field(default_factory=list)
333
  global_step: int = 0
334
  _wandb_run: Any = field(default=None, init=False, repr=False)
 
335
 
336
  def __post_init__(self) -> None:
337
  self.output_dir.mkdir(parents=True, exist_ok=True)
 
 
 
 
 
 
 
 
 
338
  if not self.use_wandb:
339
  return
340
  try:
@@ -381,10 +402,37 @@ class TrainingLogger:
381
  if extra:
382
  row.update(extra)
383
  self.rows.append(row)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
  if self._wandb_run is not None:
385
  self._wandb_run.log(row, step=self.global_step)
386
  self.global_step += 1
387
 
 
 
 
 
 
 
388
  def write_csv(self) -> Path:
389
  output_path = self.output_dir / "reward_curve.csv"
390
  fieldnames: list[str] = []
@@ -402,11 +450,37 @@ class TrainingLogger:
402
  if self._wandb_run is not None:
403
  self._wandb_run.finish()
404
 
 
 
 
 
 
 
 
 
 
 
 
405
 
406
  def build_dataset(size: int, controller: GeneratorController, curriculum: CurriculumManager) -> GeneratorRolloutDataset:
407
  return GeneratorRolloutDataset(size=size, controller=controller, curriculum=curriculum)
408
 
409
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
  def build_reward_func(
411
  curriculum: CurriculumManager,
412
  controller: GeneratorController,
@@ -453,6 +527,13 @@ def build_reward_func(
453
  extra={
454
  "generator_reward": round(float(observation.generator_reward_signal), 4),
455
  "problem_id": problem["problem_id"],
 
 
 
 
 
 
 
456
  },
457
  )
458
  if progress_callback is not None:
@@ -552,6 +633,9 @@ def run_policy_evaluation(
552
  session_id=env.session_id,
553
  generator_mode=generator_mode,
554
  )
 
 
 
555
 
556
  for _ in range(MAX_STEPS_PER_EPISODE):
557
  prompt = build_solver_prompt(observation.model_dump())
@@ -561,10 +645,13 @@ def run_policy_evaluation(
561
  prompt=prompt,
562
  max_new_tokens=max_new_tokens,
563
  )
 
 
 
564
  observation = env.step(
565
  AdaptAction(
566
  session_id=env.session_id,
567
- code=extract_code(completion),
568
  )
569
  )
570
  if observation.done:
@@ -591,6 +678,13 @@ def run_policy_evaluation(
591
  extra={
592
  "generator_reward": round(float(observation.generator_reward_signal), 4),
593
  "problem_id": problem["problem_id"],
 
 
 
 
 
 
 
594
  },
595
  )
596
 
@@ -615,11 +709,17 @@ def print_evaluation_summary(baseline: dict[str, Any], trained: dict[str, Any])
615
  def run_training(
616
  config: TrainingConfig | argparse.Namespace,
617
  *,
 
618
  progress_callback: Callable[[dict[str, Any]], None] | None = None,
619
  ) -> dict[str, Any]:
620
  if isinstance(config, argparse.Namespace):
621
  config = namespace_to_config(config)
622
 
 
 
 
 
 
623
  try:
624
  from trl import GRPOConfig, GRPOTrainer
625
  from unsloth import FastLanguageModel, PatchFastRL
@@ -634,13 +734,20 @@ def run_training(
634
 
635
  PatchFastRL("GRPO", FastLanguageModel)
636
 
 
 
 
 
637
  model, tokenizer = FastLanguageModel.from_pretrained(
638
  model_name=config.model_name,
639
  max_seq_length=config.max_seq_length,
 
640
  load_in_4bit=not config.disable_4bit,
641
  )
642
  if tokenizer.pad_token is None:
643
  tokenizer.pad_token = tokenizer.eos_token
 
 
644
 
645
  model = FastLanguageModel.get_peft_model(
646
  model,
@@ -661,22 +768,36 @@ def run_training(
661
  use_wandb=not config.disable_wandb,
662
  wandb_project=config.wandb_project,
663
  wandb_run_name=config.wandb_run_name,
 
 
 
 
 
 
 
 
664
  )
665
 
 
 
 
 
 
 
 
666
  baseline_summary = {"easy": 0.0, "medium": 0.0, "hard": 0.0, "overall": 0.0}
667
  trained_summary = {"easy": 0.0, "medium": 0.0, "hard": 0.0, "overall": 0.0}
668
 
669
  if config.baseline_eval:
670
  FastLanguageModel.for_inference(model)
671
- if progress_callback is not None:
672
- progress_callback(
673
- {
674
- "phase": "baseline_eval",
675
- "status": "running",
676
- "completed_steps": 0,
677
- "total_steps": int(config.max_steps),
678
- }
679
- )
680
  baseline_summary = run_policy_evaluation(
681
  model=model,
682
  tokenizer=tokenizer,
@@ -688,14 +809,13 @@ def run_training(
688
  max_new_tokens=config.eval_max_new_tokens,
689
  )
690
  print(f"[baseline_eval] {json.dumps(baseline_summary)}")
691
- if progress_callback is not None:
692
- progress_callback(
693
- {
694
- "phase": "baseline_eval",
695
- "status": "completed",
696
- "baseline_summary": baseline_summary,
697
- }
698
- )
699
 
700
  training_args = GRPOConfig(
701
  output_dir=str(output_dir),
@@ -707,49 +827,47 @@ def run_training(
707
  max_completion_length=config.max_completion_length,
708
  max_steps=config.max_steps,
709
  logging_steps=1,
710
- bf16=config.bf16,
 
711
  report_to=[],
712
  )
713
 
714
  class ProgressCallback(TrainerCallback):
715
  def on_train_begin(self, args, state, control, **kwargs): # type: ignore[override]
716
  del args, control, kwargs
717
- if progress_callback is not None:
718
- progress_callback(
719
- {
720
- "phase": "train",
721
- "status": "running",
722
- "completed_steps": int(getattr(state, "global_step", 0)),
723
- "total_steps": int(config.max_steps),
724
- "current_epoch": float(getattr(state, "epoch", 0.0) or 0.0),
725
- }
726
- )
727
 
728
  def on_step_end(self, args, state, control, **kwargs): # type: ignore[override]
729
  del args, control, kwargs
730
- if progress_callback is not None:
731
- progress_callback(
732
- {
733
- "phase": "train",
734
- "status": "running",
735
- "completed_steps": int(getattr(state, "global_step", 0)),
736
- "total_steps": int(config.max_steps),
737
- "current_epoch": float(getattr(state, "epoch", 0.0) or 0.0),
738
- }
739
- )
740
 
741
  def on_train_end(self, args, state, control, **kwargs): # type: ignore[override]
742
  del args, control, kwargs
743
- if progress_callback is not None:
744
- progress_callback(
745
- {
746
- "phase": "train",
747
- "status": "completed",
748
- "completed_steps": int(getattr(state, "global_step", 0)),
749
- "total_steps": int(config.max_steps),
750
- "current_epoch": float(getattr(state, "epoch", 0.0) or 0.0),
751
- }
752
- )
753
 
754
  trainer = GRPOTrainer(
755
  model=model,
@@ -765,15 +883,14 @@ def run_training(
765
 
766
  if config.baseline_eval:
767
  FastLanguageModel.for_inference(model)
768
- if progress_callback is not None:
769
- progress_callback(
770
- {
771
- "phase": "trained_eval",
772
- "status": "running",
773
- "completed_steps": int(config.max_steps),
774
- "total_steps": int(config.max_steps),
775
- }
776
- )
777
  trained_summary = run_policy_evaluation(
778
  model=model,
779
  tokenizer=tokenizer,
@@ -786,16 +903,23 @@ def run_training(
786
  )
787
  print(f"[trained_eval] {json.dumps(trained_summary)}")
788
  print_evaluation_summary(baseline_summary, trained_summary)
789
- if progress_callback is not None:
790
- progress_callback(
791
- {
792
- "phase": "trained_eval",
793
- "status": "completed",
794
- "trained_summary": trained_summary,
795
- }
796
- )
797
 
798
  csv_path = logger.write_csv()
 
 
 
 
 
 
 
 
799
  logger.close()
800
  print(f"[artifacts] reward curve CSV written to {csv_path}")
801
 
@@ -803,6 +927,7 @@ def run_training(
803
  "config": config.to_dict(),
804
  "output_dir": str(output_dir.resolve()),
805
  "reward_curve_csv": str(csv_path.resolve()),
 
806
  "baseline_summary": baseline_summary,
807
  "trained_summary": trained_summary,
808
  "completed_steps": int(config.max_steps),
@@ -832,6 +957,8 @@ def build_parser() -> argparse.ArgumentParser:
832
  parser.add_argument("--disable-wandb", action="store_true")
833
  parser.add_argument("--wandb-project", default="adapt-dsa-tutor")
834
  parser.add_argument("--wandb-run-name", default=None)
 
 
835
  parser.add_argument(
836
  "--generator-mode",
837
  choices=["heuristic", "reward_aware"],
 
11
  from env.adapt_env import AdaptEnvironment, MAX_STEPS_PER_EPISODE
12
  from env.generator import DIFFICULTY_LABELS, GeneratorAgent
13
  from models import AdaptAction
14
+ from training.trace_logging import TraceArtifactLogger
15
 
16
  SYSTEM_PROMPT = """You are the Solver Agent for ADAPT.
17
  Write only runnable Python code.
 
45
  wandb_run_name: str | None = None
46
  generator_mode: str = "reward_aware"
47
  non_deterministic_generator: bool = False
48
+ trace_logging_enabled: bool = True
49
+ checkpoint_log_interval_steps: int = 10
50
 
51
  def to_dict(self) -> dict[str, Any]:
52
  return asdict(self)
 
63
  "baseline_eval": False,
64
  "disable_wandb": True,
65
  "output_dir": "outputs_smoke",
66
+ "checkpoint_log_interval_steps": 2,
67
  },
68
  "default": {},
69
  }
 
160
  wandb_run_name=args.wandb_run_name,
161
  generator_mode=args.generator_mode,
162
  non_deterministic_generator=args.non_deterministic_generator,
163
+ trace_logging_enabled=args.trace_logging_enabled,
164
+ checkpoint_log_interval_steps=args.checkpoint_log_interval_steps,
165
  )
166
 
167
 
 
335
  use_wandb: bool = True
336
  wandb_project: str = "adapt-dsa-tutor"
337
  wandb_run_name: str | None = None
338
+ run_id: str | None = None
339
+ training_config: dict[str, Any] = field(default_factory=dict)
340
+ model_identifiers: dict[str, Any] = field(default_factory=dict)
341
+ trace_logging_enabled: bool = True
342
+ checkpoint_log_interval_steps: int = 10
343
  rows: list[dict[str, Any]] = field(default_factory=list)
344
  global_step: int = 0
345
  _wandb_run: Any = field(default=None, init=False, repr=False)
346
+ _trace_logger: TraceArtifactLogger | None = field(default=None, init=False, repr=False)
347
 
348
  def __post_init__(self) -> None:
349
  self.output_dir.mkdir(parents=True, exist_ok=True)
350
+ if self.trace_logging_enabled and self.run_id:
351
+ self._trace_logger = TraceArtifactLogger(
352
+ run_id=self.run_id,
353
+ output_dir=self.output_dir,
354
+ training_config=dict(self.training_config),
355
+ model_identifiers=dict(self.model_identifiers),
356
+ system_prompt=SYSTEM_PROMPT,
357
+ checkpoint_interval_steps=int(max(self.checkpoint_log_interval_steps, 1)),
358
+ )
359
  if not self.use_wandb:
360
  return
361
  try:
 
402
  if extra:
403
  row.update(extra)
404
  self.rows.append(row)
405
+ if self._trace_logger is not None:
406
+ self._trace_logger.log_event(
407
+ {
408
+ "phase": phase,
409
+ "step": self.global_step,
410
+ "train_episode_index": extra.get("train_episode_index") if extra else None,
411
+ "problem_id": row.get("problem_id"),
412
+ "problem_family": row.get("problem_family"),
413
+ "difficulty": row.get("difficulty_tier"),
414
+ "teacher_prompt": row.get("teacher_prompt"),
415
+ "solver_completion": row.get("solver_completion"),
416
+ "extracted_code": row.get("extracted_code"),
417
+ "reward": row.get("episode_reward"),
418
+ "pass_rate": row.get("pass_rate"),
419
+ "visible_pass_rate": row.get("visible_pass_rate"),
420
+ "execution_status": row.get("execution_status"),
421
+ "efficiency_score": row.get("efficiency_score"),
422
+ "optimization_hints": row.get("optimization_hints", []),
423
+ "feedback": row.get("feedback"),
424
+ }
425
+ )
426
  if self._wandb_run is not None:
427
  self._wandb_run.log(row, step=self.global_step)
428
  self.global_step += 1
429
 
430
+ def record_progress(self, updates: dict[str, Any]) -> dict[str, Any]:
431
+ if self._trace_logger is None:
432
+ return {}
433
+ self._trace_logger.record_progress(updates)
434
+ return self._trace_logger.artifact_paths()
435
+
436
  def write_csv(self) -> Path:
437
  output_path = self.output_dir / "reward_curve.csv"
438
  fieldnames: list[str] = []
 
450
  if self._wandb_run is not None:
451
  self._wandb_run.finish()
452
 
453
+ def finalize_trace_artifacts(
454
+ self,
455
+ *,
456
+ reward_curve_csv: Path | None = None,
457
+ final_metrics: dict[str, Any] | None = None,
458
+ ) -> dict[str, Any]:
459
+ if self._trace_logger is None:
460
+ return {}
461
+ self._trace_logger.finalize(reward_curve_csv=reward_curve_csv, final_metrics=final_metrics)
462
+ return self._trace_logger.artifact_paths()
463
+
464
 
465
  def build_dataset(size: int, controller: GeneratorController, curriculum: CurriculumManager) -> GeneratorRolloutDataset:
466
  return GeneratorRolloutDataset(size=size, controller=controller, curriculum=curriculum)
467
 
468
 
469
+ def extract_optimization_hints(feedback: str) -> list[str]:
470
+ lines = [line.strip() for line in feedback.splitlines()]
471
+ hints: list[str] = []
472
+ capture = False
473
+ for line in lines:
474
+ if line == "Optimization hints:":
475
+ capture = True
476
+ continue
477
+ if capture and line.startswith("- "):
478
+ hints.append(line[2:])
479
+ elif capture and line:
480
+ break
481
+ return hints
482
+
483
+
484
  def build_reward_func(
485
  curriculum: CurriculumManager,
486
  controller: GeneratorController,
 
527
  extra={
528
  "generator_reward": round(float(observation.generator_reward_signal), 4),
529
  "problem_id": problem["problem_id"],
530
+ "teacher_prompt": prompt,
531
+ "solver_completion": completion,
532
+ "extracted_code": extract_code(completion),
533
+ "feedback": observation.feedback,
534
+ "efficiency_score": observation.reward_components.get("efficiency_score"),
535
+ "optimization_hints": extract_optimization_hints(observation.feedback),
536
+ "train_episode_index": int(controller.history["episode_index"]),
537
  },
538
  )
539
  if progress_callback is not None:
 
633
  session_id=env.session_id,
634
  generator_mode=generator_mode,
635
  )
636
+ last_prompt = ""
637
+ last_completion = ""
638
+ last_code = ""
639
 
640
  for _ in range(MAX_STEPS_PER_EPISODE):
641
  prompt = build_solver_prompt(observation.model_dump())
 
645
  prompt=prompt,
646
  max_new_tokens=max_new_tokens,
647
  )
648
+ last_prompt = prompt
649
+ last_completion = completion
650
+ last_code = extract_code(completion)
651
  observation = env.step(
652
  AdaptAction(
653
  session_id=env.session_id,
654
+ code=last_code,
655
  )
656
  )
657
  if observation.done:
 
678
  extra={
679
  "generator_reward": round(float(observation.generator_reward_signal), 4),
680
  "problem_id": problem["problem_id"],
681
+ "teacher_prompt": last_prompt,
682
+ "solver_completion": last_completion,
683
+ "extracted_code": last_code,
684
+ "feedback": observation.feedback,
685
+ "efficiency_score": observation.reward_components.get("efficiency_score"),
686
+ "optimization_hints": extract_optimization_hints(observation.feedback),
687
+ "train_episode_index": int(controller.history["episode_index"]),
688
  },
689
  )
690
 
 
709
  def run_training(
710
  config: TrainingConfig | argparse.Namespace,
711
  *,
712
+ run_id: str | None = None,
713
  progress_callback: Callable[[dict[str, Any]], None] | None = None,
714
  ) -> dict[str, Any]:
715
  if isinstance(config, argparse.Namespace):
716
  config = namespace_to_config(config)
717
 
718
+ try:
719
+ import torch
720
+ except ImportError as exc:
721
+ raise RuntimeError("Training requires `torch` to be installed.") from exc
722
+
723
  try:
724
  from trl import GRPOConfig, GRPOTrainer
725
  from unsloth import FastLanguageModel, PatchFastRL
 
734
 
735
  PatchFastRL("GRPO", FastLanguageModel)
736
 
737
+ use_cuda = torch.cuda.is_available()
738
+ use_bf16 = bool(config.bf16) or (use_cuda and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported())
739
+ model_dtype = torch.bfloat16 if use_bf16 else (torch.float16 if use_cuda else torch.float32)
740
+
741
  model, tokenizer = FastLanguageModel.from_pretrained(
742
  model_name=config.model_name,
743
  max_seq_length=config.max_seq_length,
744
+ dtype=model_dtype,
745
  load_in_4bit=not config.disable_4bit,
746
  )
747
  if tokenizer.pad_token is None:
748
  tokenizer.pad_token = tokenizer.eos_token
749
+ if hasattr(model, "config"):
750
+ model.config.torch_dtype = model_dtype
751
 
752
  model = FastLanguageModel.get_peft_model(
753
  model,
 
768
  use_wandb=not config.disable_wandb,
769
  wandb_project=config.wandb_project,
770
  wandb_run_name=config.wandb_run_name,
771
+ run_id=run_id,
772
+ training_config=config.to_dict(),
773
+ model_identifiers={
774
+ "model_name": config.model_name,
775
+ "generator_mode": config.generator_mode,
776
+ },
777
+ trace_logging_enabled=config.trace_logging_enabled,
778
+ checkpoint_log_interval_steps=config.checkpoint_log_interval_steps,
779
  )
780
 
781
+ def emit_progress(update: dict[str, Any]) -> None:
782
+ artifact_paths = logger.record_progress(update)
783
+ if progress_callback is not None:
784
+ payload = dict(update)
785
+ payload.update(artifact_paths)
786
+ progress_callback(payload)
787
+
788
  baseline_summary = {"easy": 0.0, "medium": 0.0, "hard": 0.0, "overall": 0.0}
789
  trained_summary = {"easy": 0.0, "medium": 0.0, "hard": 0.0, "overall": 0.0}
790
 
791
  if config.baseline_eval:
792
  FastLanguageModel.for_inference(model)
793
+ emit_progress(
794
+ {
795
+ "phase": "baseline_eval",
796
+ "status": "running",
797
+ "completed_steps": 0,
798
+ "total_steps": int(config.max_steps),
799
+ }
800
+ )
 
801
  baseline_summary = run_policy_evaluation(
802
  model=model,
803
  tokenizer=tokenizer,
 
809
  max_new_tokens=config.eval_max_new_tokens,
810
  )
811
  print(f"[baseline_eval] {json.dumps(baseline_summary)}")
812
+ emit_progress(
813
+ {
814
+ "phase": "baseline_eval",
815
+ "status": "completed",
816
+ "baseline_summary": baseline_summary,
817
+ }
818
+ )
 
819
 
820
  training_args = GRPOConfig(
821
  output_dir=str(output_dir),
 
827
  max_completion_length=config.max_completion_length,
828
  max_steps=config.max_steps,
829
  logging_steps=1,
830
+ bf16=use_bf16,
831
+ fp16=use_cuda and not use_bf16,
832
  report_to=[],
833
  )
834
 
835
  class ProgressCallback(TrainerCallback):
836
  def on_train_begin(self, args, state, control, **kwargs): # type: ignore[override]
837
  del args, control, kwargs
838
+ emit_progress(
839
+ {
840
+ "phase": "train",
841
+ "status": "running",
842
+ "completed_steps": int(getattr(state, "global_step", 0)),
843
+ "total_steps": int(config.max_steps),
844
+ "current_epoch": float(getattr(state, "epoch", 0.0) or 0.0),
845
+ }
846
+ )
 
847
 
848
  def on_step_end(self, args, state, control, **kwargs): # type: ignore[override]
849
  del args, control, kwargs
850
+ emit_progress(
851
+ {
852
+ "phase": "train",
853
+ "status": "running",
854
+ "completed_steps": int(getattr(state, "global_step", 0)),
855
+ "total_steps": int(config.max_steps),
856
+ "current_epoch": float(getattr(state, "epoch", 0.0) or 0.0),
857
+ }
858
+ )
 
859
 
860
  def on_train_end(self, args, state, control, **kwargs): # type: ignore[override]
861
  del args, control, kwargs
862
+ emit_progress(
863
+ {
864
+ "phase": "train",
865
+ "status": "completed",
866
+ "completed_steps": int(getattr(state, "global_step", 0)),
867
+ "total_steps": int(config.max_steps),
868
+ "current_epoch": float(getattr(state, "epoch", 0.0) or 0.0),
869
+ }
870
+ )
 
871
 
872
  trainer = GRPOTrainer(
873
  model=model,
 
883
 
884
  if config.baseline_eval:
885
  FastLanguageModel.for_inference(model)
886
+ emit_progress(
887
+ {
888
+ "phase": "trained_eval",
889
+ "status": "running",
890
+ "completed_steps": int(config.max_steps),
891
+ "total_steps": int(config.max_steps),
892
+ }
893
+ )
 
894
  trained_summary = run_policy_evaluation(
895
  model=model,
896
  tokenizer=tokenizer,
 
903
  )
904
  print(f"[trained_eval] {json.dumps(trained_summary)}")
905
  print_evaluation_summary(baseline_summary, trained_summary)
906
+ emit_progress(
907
+ {
908
+ "phase": "trained_eval",
909
+ "status": "completed",
910
+ "trained_summary": trained_summary,
911
+ }
912
+ )
 
913
 
914
  csv_path = logger.write_csv()
915
+ trace_artifact_paths = logger.finalize_trace_artifacts(
916
+ reward_curve_csv=csv_path,
917
+ final_metrics={
918
+ "baseline_summary": baseline_summary,
919
+ "trained_summary": trained_summary,
920
+ "completed_steps": int(config.max_steps),
921
+ },
922
+ )
923
  logger.close()
924
  print(f"[artifacts] reward curve CSV written to {csv_path}")
925
 
 
927
  "config": config.to_dict(),
928
  "output_dir": str(output_dir.resolve()),
929
  "reward_curve_csv": str(csv_path.resolve()),
930
+ **trace_artifact_paths,
931
  "baseline_summary": baseline_summary,
932
  "trained_summary": trained_summary,
933
  "completed_steps": int(config.max_steps),
 
957
  parser.add_argument("--disable-wandb", action="store_true")
958
  parser.add_argument("--wandb-project", default="adapt-dsa-tutor")
959
  parser.add_argument("--wandb-run-name", default=None)
960
+ parser.add_argument("--trace-logging-enabled", action=argparse.BooleanOptionalAction, default=True)
961
+ parser.add_argument("--checkpoint-log-interval-steps", type=int, default=10)
962
  parser.add_argument(
963
  "--generator-mode",
964
  choices=["heuristic", "reward_aware"],