Upload folder using huggingface_hub
Files changed:

- README.md +49 -0
- graders.py +35 -6
- models.py +7 -0
- notebooks/ghostexec_unsloth_grpo_hf_api.ipynb +337 -64
- openenv_ghostexec.egg-info/PKG-INFO +5 -0
- openenv_ghostexec.egg-info/SOURCES.txt +2 -2
- openenv_ghostexec.egg-info/requires.txt +6 -0
- outputs/logs/episode_rewards.jsonl +0 -0
- pyproject.toml +6 -0
- scripts/eval_reward_ablation.py +64 -0
- scripts/plot_training_report.py +170 -0
- scripts/train_sft_then_grpo.py +641 -0
- server/ghostexec_environment.py +74 -0
- server/reward.py +166 -1
- tests/test_phase4.py +25 -0
- uv.lock +0 -0
README.md
CHANGED

@@ -215,6 +215,55 @@ set GHOSTEXEC_WS_BASE_URL=http://127.0.0.1:8000
 uv run pytest tests/test_complete_integration.py::test_ghostexec_env_client_against_live_url_if_set -q
 ```
 
+Post-training plot pack (loss + reward + components + baseline bar):
+
+```bash
+uv run python scripts/plot_training_report.py \
+  --trainer-history outputs/trainer_state.json \
+  --reward-csv outputs/reward_log.csv \
+  --baselines-json outputs/compliance_manifest.json \
+  --out-dir outputs/plots
+```
+
+The script writes:
+
+- `outputs/plots/loss_curve.png`
+- `outputs/plots/reward_curve.png`
+- `outputs/plots/components_curve.png`
+- `outputs/plots/baseline_comparison.png`
+
+SFT before GRPO (with partial live-env usage during SFT data generation and GRPO rewards):
+
+```bash
+uv run python scripts/train_sft_then_grpo.py \
+  --model-preset small_iter_fast \
+  --training-preset hackathon_turbo \
+  --env-url http://127.0.0.1:8000 \
+  --generate-sft-from-env \
+  --sft-samples 120 \
+  --max-sft-steps 60 \
+  --max-grpo-steps 120 \
+  --env-reward-scale 1.0 \
+  --local-reward-scale 0.35 \
+  --complexity-curriculum easy_to_full \
+  --curriculum-ramp-ratio 0.60
+```
+
+This performs:
+
+- SFT warm-start on JSONL (`prompt` + `completion`) generated from live `/reset` briefings.
+- GRPO continuation from the SFT adapter.
+- Mixed reward shaping where the env-derived reward stays active and local shaping can be down- or up-weighted via the two scales.
+- An optional complexity curriculum (`easy_to_full`) that starts with stronger scaffold/local signals and anneals to an env-dominant reward later.
+- Stability-first optimization defaults (cosine schedule + warmup + grad clipping + higher GRPO KL beta) and optional guardrails:
+  - `--reward-ema-decay 0..1` smooths the *env* reward channel (defaults come from `--training-preset`).
+  - Leave `--no-stability-tripwire` unset to keep early stopping enabled when logs show repeated “env reward down + loss up” (GRPO) or repeated loss blow-ups (SFT).
+
+Recommended model strategy for hackathon iteration speed:
+
+- Start with `--model-preset small_iter_fast` (`unsloth/Qwen2.5-3B-Instruct`) + QLoRA.
+- Run many short SFT->GRPO loops, improve the reward signals, and scale model size only after the curves stabilize.
+- Use larger presets only when memory and runtime are consistently stable.
+- Use `--training-preset hackathon_turbo` to apply stable, aggressive defaults tuned for iterative win rate.
+- The script prints SFT/GRPO LoRA delta checks; if the deltas are near zero it stops, so you never mistake a no-op run for real finetuning.
+
 ---
 
 ## Hugging Face Spaces
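For orientation, here is a minimal sketch of how the `--env-reward-scale`, `--local-reward-scale`, and `--curriculum-ramp-ratio` knobs could compose per step. The helper name and the linear annealing schedule are illustrative assumptions; the authoritative logic lives in `scripts/train_sft_then_grpo.py`.

```python
# Hypothetical sketch of the mixed-reward weighting; the linear easy_to_full
# annealing below is an assumption, not the exact train_sft_then_grpo.py code.
def mixed_reward(env_r: float, local_r: float, step: int, max_steps: int,
                 env_scale: float = 1.0, local_scale: float = 0.35,
                 ramp_ratio: float = 0.60) -> float:
    ramp_steps = max(1, int(max_steps * ramp_ratio))
    progress = min(step / ramp_steps, 1.0)          # 0 -> 1 over the ramp window
    local_weight = local_scale * (1.0 - progress)   # strong scaffold early
    return env_scale * env_r + local_weight * local_r  # env stays active throughout
```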
graders.py
CHANGED

@@ -7,28 +7,57 @@ rewards in `server/reward.py`. The hackathon validator reads `openenv.yaml`
 """
 from __future__ import annotations
 
+import math
+from typing import List
 
 STRICT_MIN = 0.01
 STRICT_MAX = 0.99
 
 
 def _bounded(value: float) -> float:
+    try:
+        v = round(float(value), 4)
+    except (TypeError, ValueError):
+        return 0.5
+    if not math.isfinite(v):
+        return 0.5
+    return min(max(v, STRICT_MIN), STRICT_MAX)
 
 
 def _as_reward_list(trajectory: dict | None) -> List[float]:
     payload = trajectory or {}
+    if not isinstance(payload, dict):
+        return []
     rewards = payload.get("rewards")
     if isinstance(rewards, list) and rewards:
+        out: List[float] = []
+        for r in rewards:
+            try:
+                rv = float(r)
+            except (TypeError, ValueError):
+                continue
+            if math.isfinite(rv):
+                out.append(rv)
+        return out
     if "score" in payload:
+        try:
+            v = float(payload["score"])
+            return [v] if math.isfinite(v) else []
+        except (TypeError, ValueError):
+            return []
     reward = payload.get("reward")
     if isinstance(reward, dict) and "total" in reward:
+        try:
+            v = float(reward["total"])
+            return [v] if math.isfinite(v) else []
+        except (TypeError, ValueError):
+            return []
    if reward is not None:
+        try:
+            v = float(reward)
+            return [v] if math.isfinite(v) else []
+        except (TypeError, ValueError):
+            return []
    return []
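The reconstructed helpers above accept several trajectory payload shapes in priority order: a `rewards` list, then `score`, then `reward.total`, then a scalar `reward`. A quick check of the expected behavior (the values follow directly from the code above):

```python
assert _as_reward_list({"rewards": [0.2, "bad", 0.4]}) == [0.2, 0.4]  # non-numeric entries skipped
assert _as_reward_list({"score": 0.7}) == [0.7]
assert _as_reward_list({"reward": {"total": 0.3}}) == [0.3]
assert _as_reward_list({"reward": 0.9}) == [0.9]
assert _as_reward_list(None) == []

assert _bounded(1.5) == 0.99      # clamped into [STRICT_MIN, STRICT_MAX]
assert _bounded("oops") == 0.5    # unparseable input falls back to the midpoint
```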
models.py
CHANGED

@@ -195,6 +195,13 @@ class RewardBreakdown(BaseModel):
     conflict: float = 0.0
     relationship: float = 0.0
     task: float = 0.0
+    shaping_synergy: float = 0.0
+    shaping_tradeoff: float = 0.0
+    shaping_potential: float = 0.0
+    shaping_scaffold: float = 0.0
+    shaping_quality: float = 0.0
+    shaping_total: float = 0.0
+    shaping_to_base_ratio: float = 0.0
     weighted_base: float = 0.0
     output_scale: float = 1.0
     invalid_step_adjustment: float = 0.0
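The new `shaping_*` fields expose an auditable shaping channel alongside the base components. A hedged sketch of how the aggregate fields could be derived (the field names match `RewardBreakdown`; the arithmetic itself is an assumption, since the actual logic lives in `server/reward.py`):

```python
def fill_shaping_totals(b: RewardBreakdown) -> None:
    # Assumed aggregation: sum the five shaping components.
    b.shaping_total = (
        b.shaping_synergy + b.shaping_tradeoff + b.shaping_potential
        + b.shaping_scaffold + b.shaping_quality
    )
    # A ratio makes it easy for tripwires to flag runs where shaping dwarfs the base signal.
    b.shaping_to_base_ratio = abs(b.shaping_total) / max(abs(b.weighted_base), 1e-6)
```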
notebooks/ghostexec_unsloth_grpo_hf_api.ipynb
CHANGED

@@ -4,15 +4,15 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
+"# Ghostexec — Unsloth + TRL SFT -> GRPO against the deployed HF Space API\n",
 "\n",
+"Post-train `unsloth/Llama-3.2-3B-Instruct` with **SFT warmup first** and then GRPO, where rewards are fetched over HTTP from the **live** Ghostexec OpenEnv Space.\n",
 "\n",
 "- Live endpoint: `https://modelbuilderhq-ghostexec.hf.space`\n",
+"- Algorithm: TRL `0.22.2` `SFTTrainer` -> `GRPOTrainer` (no vLLM — HF `generate()` path)\n",
+"- Base (recommended for fast winning iterations): `unsloth/Qwen2.5-3B-Instruct` (4-bit) + LoRA r=16 + bf16\n",
+"- Curriculum: **easy -> full** annealing (strong local scaffold early, env-dominant later)\n",
+"- Rewards: four **independent** functions — `env_reward` (live Space) / `format_reward` / `semantic_action_reward` / `anti_idle_reward`\n",
 "\n",
 "### Help Guide phase map (notebook sections mirror `[Participant Help Guide] §18`)\n",
 "| Phase | Where |\n",
@@ -21,12 +21,13 @@
 "| 2 Build the environment | section 2 (already deployed; health check here) |\n",
 "| 3 Build rewards | section 3 |\n",
 "| 4 Deploy | section 4 (confirm) |\n",
+"| 5 Train small | section 5 (SFT + Stage B) |\n",
 "| 6 Inspect for hacking | section 6 |\n",
 "| 7 Add curriculum | section 7 (Stages C + D) |\n",
 "| 8 Train bigger | section 8 (knobs, not action) |\n",
 "| 9 Save and demo | section 9 |"
+],
+"id": "33566e3d"
 },
 {
 "cell_type": "markdown",

@@ -90,7 +91,8 @@
 "from typing import Any\n",
 "\n",
 "GHOSTEXEC_ENV_URL = os.environ.get(\"GHOSTEXEC_ENV_URL\", \"https://modelbuilderhq-ghostexec.hf.space\")\n",
+"# Small-model-first default for rapid iteration and higher success probability.\n",
+"MODEL_ID = os.environ.get(\"MODEL_ID\", \"unsloth/Qwen2.5-3B-Instruct\")\n",
 "RUN_NAME = os.environ.get(\"RUN_NAME\", \"ghostexec-unsloth-grpo\")\n",
 "HUB_REPO_ID = os.environ.get(\"HUB_REPO_ID\", \"\")\n",
 "OUT = pathlib.Path(\"/content/ghostexec_out\") if os.path.exists(\"/content\") else pathlib.Path(\"./ghostexec_out\")\n",

@@ -175,8 +177,13 @@
 "source": [
 "### 2.2 Verifier sanity check (Help Guide §8)\n",
 "\n",
+"**Colab / stale cells:** If the traceback mentions **`do_nothing is not the worst/floor`** on **line ~28**, you are running **old cached notebook code** (that assert was removed). Use **Runtime → Disconnect and delete runtime**, then **re-clone** the repo or **re-download** this notebook from GitHub and run from the top.\n",
+"\n",
+"**If every proactive action prints `-0.25` and only `do_nothing` is `-0.15`:** every non-idle smoke is an **invalid step** (wrong ids like `email_01`, or an outdated `_smoke_action`). This cell expects **real `phase2_core` ids** (`e01`, `e09`, `m02`, …) — see `_smoke_action` below.\n",
+"\n",
+"Fire every legal `action_type` once with **semantically valid** payloads (real ids from `scenarios/phase2_core.json`). Fake ids deserialize but fail validation (−0.25 invalid-step) and are not a fair probe. Also: **`do_nothing` is not guaranteed to be the lowest reward** — a valid but harmful action (e.g. cancelling an important meeting) can push the weighted score below the idle penalty. We instead assert **non-idle smokes are `step_ok=True`** and **`do_nothing` scores below a benign `reply_email` on `e01`**. If rewards are all identical, abort — GRPO cannot learn from a degenerate verifier."
+],
+"id": "b747bc4e"
 },
 {
 "cell_type": "code",

@@ -188,33 +195,75 @@
 "]\n",
 "\n",
 "def _smoke_action(action_type: str) -> dict:\n",
+"    # Real IDs from phase2_core scenario\n",
+"    base = {\"action_type\": action_type, \"message\": \"\"}\n",
+"\n",
+"    if action_type == \"reply_email\":\n",
+"        return {**base, \"email_id\": \"e01\", \"message_body\": \"Acknowledged — on it now.\"}\n",
+"    if action_type == \"archive_email\":\n",
+"        return {**base, \"email_id\": \"e09\"}\n",
+"    if action_type == \"reschedule_meeting\":\n",
+"        return {\n",
+"            **base,\n",
+"            \"meeting_id\": \"m02\",\n",
+"            \"new_time\": \"2026-04-21T18:00:00\",\n",
+"            \"reason\": \"freeing the morning block\",\n",
+"        }\n",
+"    if action_type == \"cancel_meeting\":\n",
+"        return {**base, \"meeting_id\": \"m10\", \"reason\": \"smoke test cancel\"}\n",
+"    if action_type == \"complete_task\":\n",
+"        return {**base, \"task_id\": \"t07\"}\n",
+"    if action_type == \"delegate_task\":\n",
+"        return {**base, \"task_id\": \"t08\", \"contact_name\": \"Jordan Lee\"}\n",
+"    if action_type == \"send_message\":\n",
+"        return {\n",
+"            **base,\n",
+"            \"contact_name\": \"Jamie Liu\",\n",
+"            \"message_body\": \"Quick sync when you have a minute.\",\n",
+"        }\n",
+"\n",
+"    # do_nothing\n",
 "    return {\n",
+"        **base,\n",
+"        \"email_id\": \"\",\n",
+"        \"message_body\": \"\",\n",
+"        \"meeting_id\": \"\",\n",
+"        \"new_time\": \"\",\n",
+"        \"reason\": \"\",\n",
+"        \"task_id\": \"\",\n",
+"        \"contact_name\": \"\",\n",
 "    }\n",
 "\n",
 "rewards_by_action: dict[str, float] = {}\n",
+"step_ok_by_action: dict[str, bool | None] = {}\n",
+"\n",
 "for at in LEGAL_ACTION_TYPES:\n",
 "    env.reset()\n",
+"    r, raw = env.step(_smoke_action(at))\n",
 "    rewards_by_action[at] = round(r, 4)\n",
+"    obs = raw.get(\"observation\") or {}\n",
+"    step_ok_by_action[at] = (obs.get(\"metadata\") or {}).get(\"step_ok\")\n",
+"\n",
+"print(json.dumps({\"reward\": rewards_by_action, \"step_ok\": step_ok_by_action}, indent=2))\n",
 "\n",
 "uniq = set(rewards_by_action.values())\n",
 "assert len(uniq) > 1, \"Verifier is constant across actions — env can't teach anything.\"\n",
+"\n",
+"# All non-idle smokes must be valid\n",
+"for at in LEGAL_ACTION_TYPES:\n",
+"    if at == \"do_nothing\":\n",
+"        continue\n",
+"    assert step_ok_by_action.get(at) is True, f\"{at} smoke is invalid; check IDs.\"\n",
+"\n",
+"# Idle should be worse than benign good action\n",
+"assert rewards_by_action[\"do_nothing\"] < rewards_by_action[\"reply_email\"] - 1e-6, \\\n",
+"    \"do_nothing should score below reply_email(e01).\"\n",
+"\n",
+"print(\"\\nverifier OK — rewards vary, smokes are valid, do_nothing < reply_email(e01).\")"
 ],
 "execution_count": null,
+"outputs": [],
+"id": "5ed1a9bc"
 },
 {
 "cell_type": "markdown",

@@ -264,6 +313,39 @@
 "    try: return parse_action_strict(text)\n",
 "    except Exception: return {\"action_type\": \"do_nothing\"}\n",
 "\n",
+"LEGAL_ACTION_TYPES = {\n",
+"    \"reply_email\", \"archive_email\", \"reschedule_meeting\", \"cancel_meeting\",\n",
+"    \"complete_task\", \"delegate_task\", \"send_message\", \"do_nothing\",\n",
+"}\n",
+"LEGAL_ACTION_KEYS = {\n",
+"    \"action_type\", \"email_id\", \"message_body\", \"meeting_id\",\n",
+"    \"new_time\", \"reason\", \"task_id\", \"contact_name\", \"message\",\n",
+"}\n",
+"\n",
+"\n",
+"def sanitize_action(raw: dict) -> dict:\n",
+"    \"\"\"Keep only legal Ghostexec fields and coerce malformed IDs/actions safely.\"\"\"\n",
+"    action = {k: v for k, v in (raw or {}).items() if k in LEGAL_ACTION_KEYS}\n",
+"\n",
+"    at = str(action.get(\"action_type\", \"do_nothing\"))\n",
+"    if at not in LEGAL_ACTION_TYPES:\n",
+"        at = \"do_nothing\"\n",
+"    action[\"action_type\"] = at\n",
+"\n",
+"    # Common model mistake: writes message text into `message` instead of `message_body`.\n",
+"    if at in {\"reply_email\", \"send_message\"}:\n",
+"        if not action.get(\"message_body\") and action.get(\"message\"):\n",
+"            action[\"message_body\"] = action[\"message\"]\n",
+"\n",
+"    if \"email_id\" in action and not re.fullmatch(r\"e\\d{2}\", str(action[\"email_id\"])):\n",
+"        action[\"email_id\"] = \"\"\n",
+"    if \"meeting_id\" in action and not re.fullmatch(r\"m\\d{2}\", str(action[\"meeting_id\"])):\n",
+"        action[\"meeting_id\"] = \"\"\n",
+"    if \"task_id\" in action and not re.fullmatch(r\"t\\d{2}\", str(action[\"task_id\"])):\n",
+"        action[\"task_id\"] = \"\"\n",
+"\n",
+"    return action\n",
+"\n",
 "assert parse_action_strict('```json\\n{\"action_type\":\"archive_email\",\"email_id\":\"email_01\"}\\n```')[\"action_type\"] == \"archive_email\"\n",
 "assert parse_action(\"garbage\")[\"action_type\"] == \"do_nothing\"\n",
 "print(\"parser OK\")"

@@ -280,43 +362,115 @@
 "        return c[0].get(\"content\", \"\")\n",
 "    return c if isinstance(c, str) else str(c)\n",
 "\n",
+"\n",
+"def _prompt_to_text(p) -> str:\n",
+"    if isinstance(p, list) and p and isinstance(p[-1], dict):\n",
+"        return str(p[-1].get(\"content\", \"\"))\n",
+"    if isinstance(p, dict):\n",
+"        return str(p.get(\"content\", \"\"))\n",
+"    return str(p)\n",
+"\n",
+"\n",
+"# Curriculum scalars are updated per stage: easy -> full.\n",
+"CURRENT_ENV_SCALE = 0.85\n",
+"CURRENT_LOCAL_SCALE = 0.60\n",
+"\n",
+"\n",
 "def env_reward(completions, prompts=None, **_) -> list[float]:\n",
 "    out: list[float] = []\n",
 "    for c in completions:\n",
 "        text = _completion_text(c)\n",
+"        action = sanitize_action(parse_action(text))\n",
 "        try:\n",
 "            env.reset()\n",
 "            r, _ = env.step(action)\n",
 "        except Exception:\n",
 "            r = -1.0\n",
+"        out.append(float(r) * CURRENT_ENV_SCALE)\n",
 "    return out\n",
 "\n",
+"\n",
 "def format_reward(completions, **_) -> list[float]:\n",
 "    out: list[float] = []\n",
 "    for c in completions:\n",
 "        text = _completion_text(c)\n",
 "        try:\n",
+"            parse_action_strict(text)\n",
+"            out.append(0.12 * CURRENT_LOCAL_SCALE)\n",
 "        except Exception:\n",
+"            out.append(-0.20 * CURRENT_LOCAL_SCALE)\n",
+"    return out\n",
+"\n",
+"\n",
+"def semantic_action_reward(completions, prompts=None, **_) -> list[float]:\n",
+"    \"\"\"\n",
+"    Reward canonical, briefing-grounded action payloads before env call.\n",
+"    Scaled by CURRENT_LOCAL_SCALE for easy->full curriculum annealing.\n",
+"    \"\"\"\n",
+"    out: list[float] = []\n",
+"    for i, c in enumerate(completions):\n",
+"        text = _completion_text(c)\n",
+"        act = parse_action(text)\n",
+"        at = act.get(\"action_type\", \"do_nothing\")\n",
+"\n",
+"        prompt_text = \"\"\n",
+"        if prompts is not None and i < len(prompts):\n",
+"            prompt_text = _prompt_to_text(prompts[i])\n",
+"\n",
+"        def present(tok: str) -> bool:\n",
+"            return bool(tok) and re.search(rf\"\\b{re.escape(tok)}\\b\", prompt_text) is not None\n",
+"\n",
+"        r = -0.30\n",
+"        if at == \"do_nothing\":\n",
+"            r = -0.05\n",
+"        elif at == \"reply_email\":\n",
+"            eid = act.get(\"email_id\", \"\")\n",
+"            mb = (act.get(\"message_body\", \"\") or \"\").strip()\n",
+"            r = 0.30 if present(eid) and bool(re.fullmatch(r\"e\\d{2}\", eid)) and mb else -0.30\n",
+"        elif at == \"archive_email\":\n",
+"            eid = act.get(\"email_id\", \"\")\n",
+"            r = 0.30 if present(eid) and bool(re.fullmatch(r\"e\\d{2}\", eid)) else -0.30\n",
+"        elif at == \"reschedule_meeting\":\n",
+"            mid = act.get(\"meeting_id\", \"\")\n",
+"            nt = (act.get(\"new_time\", \"\") or \"\").strip()\n",
+"            r = 0.30 if present(mid) and bool(re.fullmatch(r\"m\\d{2}\", mid)) and nt else -0.30\n",
+"        elif at == \"cancel_meeting\":\n",
+"            mid = act.get(\"meeting_id\", \"\")\n",
+"            r = 0.30 if present(mid) and bool(re.fullmatch(r\"m\\d{2}\", mid)) else -0.30\n",
+"        elif at == \"complete_task\":\n",
+"            tid = act.get(\"task_id\", \"\")\n",
+"            r = 0.30 if present(tid) and bool(re.fullmatch(r\"t\\d{2}\", tid)) else -0.30\n",
+"        elif at == \"delegate_task\":\n",
+"            tid = act.get(\"task_id\", \"\")\n",
+"            cn = (act.get(\"contact_name\", \"\") or \"\").strip()\n",
+"            r = 0.30 if present(tid) and bool(re.fullmatch(r\"t\\d{2}\", tid)) and (cn in prompt_text) else -0.30\n",
+"        elif at == \"send_message\":\n",
+"            cn = (act.get(\"contact_name\", \"\") or \"\").strip()\n",
+"            mb = (act.get(\"message_body\", \"\") or \"\").strip()\n",
+"            r = 0.30 if cn and (cn in prompt_text) and mb else -0.30\n",
+"\n",
+"        out.append(float(r) * CURRENT_LOCAL_SCALE)\n",
 "    return out\n",
 "\n",
+"\n",
 "def anti_idle_reward(completions, **_) -> list[float]:\n",
 "    out: list[float] = []\n",
 "    for c in completions:\n",
 "        text = _completion_text(c)\n",
 "        act = parse_action(text)\n",
+"        out.append((-0.28 if act.get(\"action_type\") == \"do_nothing\" else 0.03) * CURRENT_LOCAL_SCALE)\n",
 "    return out\n",
 "\n",
+"\n",
 "_dummy = '{\"action_type\":\"archive_email\",\"email_id\":\"email_01\"}'\n",
+"print(\"env      :\", env_reward([_dummy]))\n",
 "print(\"format   :\", format_reward([_dummy]))\n",
+"print(\"semantic :\", semantic_action_reward([_dummy], prompts=[\"... e01 e09 t07 m02 Jamie Liu ...\"]))\n",
 "print(\"anti_idle:\", anti_idle_reward([_dummy]))"
 ],
 "execution_count": null,
+"outputs": [],
+"id": "3bd66b49"
 },
 {
 "cell_type": "code",

@@ -325,24 +479,32 @@
 "from transformers import TrainerCallback\n",
 "\n",
 "class HackingTripwire(TrainerCallback):\n",
+"    \"\"\"Stop training on mode collapse, invalid-action collapse, or reward-format divergence.\"\"\"\n",
+"    def __init__(self, min_unique_ratio: float = 0.2, invalid_floor: float = -0.24):\n",
 "        self.min_unique_ratio = min_unique_ratio\n",
+"        self.invalid_floor = invalid_floor\n",
 "\n",
 "    def on_log(self, args, state, control, logs=None, **kw):\n",
 "        logs = logs or {}\n",
 "        uniq = logs.get(\"completions/unique_ratio\") or logs.get(\"completions/mean_unique\")\n",
 "        env_r = logs.get(\"rewards/env_reward/mean\")\n",
 "        fmt_r = logs.get(\"rewards/format_reward/mean\")\n",
+"\n",
 "        if uniq is not None and uniq < self.min_unique_ratio:\n",
 "            print(f\"[TRIPWIRE] unique_ratio={uniq:.2f} < {self.min_unique_ratio} — stopping.\")\n",
 "            control.should_training_stop = True\n",
+"\n",
+"        if env_r is not None and env_r <= self.invalid_floor:\n",
+"            print(f\"[TRIPWIRE] env_reward mean stuck at {env_r:.2f} (invalid-action collapse). stopping.\")\n",
+"            control.should_training_stop = True\n",
+"\n",
 "        if env_r is not None and fmt_r is not None and env_r > 0.8 and fmt_r < 0.0:\n",
 "            print(f\"[TRIPWIRE] env_r={env_r:.2f} but fmt_r={fmt_r:.2f} — possible hack. stopping.\")\n",
 "            control.should_training_stop = True"
 ],
 "execution_count": null,
+"outputs": [],
+"id": "a6a37dad"
 },
 {
 "cell_type": "markdown",

@@ -357,10 +519,11 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
+"## Phase 5 — Train small (SFT warmup -> GRPO)\n",
 "\n",
+"Load `unsloth/Llama-3.2-3B-Instruct` in 4-bit with Unsloth, attach LoRA, run a **short SFT warmup first**, then run GRPO. vLLM is not used anywhere in this notebook — rollouts go through the standard HF `generate()` path inside `GRPOTrainer`."
+],
+"id": "428d6377"
 },
 {
 "cell_type": "code",

@@ -408,11 +571,17 @@
 "    \"Legal action_type values: reply_email, archive_email, reschedule_meeting, cancel_meeting, \"\n",
 "    \"complete_task, delegate_task, send_message, do_nothing.\\n\\n\"\n",
 "    \"Output ONLY a compact JSON object with these keys (no prose, no code fences):\\n\"\n",
+"    \"{\\\"action_type\\\": \\\"\\\", \\\"email_id\\\": \\\"\\\", \\\"message_body\\\": \\\"\\\", \"\n",
 "    \"\\\"meeting_id\\\": \\\"\\\", \\\"new_time\\\": \\\"\\\", \\\"reason\\\": \\\"\\\", \\\"task_id\\\": \\\"\\\", \"\n",
 "    \"\\\"contact_name\\\": \\\"\\\", \\\"message\\\": \\\"\\\"}.\\n\\n\"\n",
+"    \"ID RULES:\\n\"\n",
+"    \"- email_id must be an email token from briefing like e01, e02, ...\\n\"\n",
+"    \"- meeting_id must be a meeting token like m01, m02, ...\\n\"\n",
+"    \"- task_id must be a task token like t01, t02, ...\\n\"\n",
+"    \"- contact_name must exactly match a contact shown in briefing.\\n\"\n",
+"    \"- Never use subject/body/description text as an ID.\\n\"\n",
+"    \"- If you cannot find a valid ID for your chosen action, output {\\\"action_type\\\":\\\"do_nothing\\\"}.\\n\\n\"\n",
+"    \"Prefer high-impact valid actions; avoid do_nothing when critical items are unresolved.\"\n",
 ")\n",
 "\n",
 "def build_prompt(briefing: str) -> list[dict]:\n",

@@ -425,7 +594,8 @@
 "    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)"
 ],
 "execution_count": null,
+"outputs": [],
+"id": "883dce70"
 },
 {
 "cell_type": "code",

@@ -492,9 +662,12 @@
 "        pad_token_id=tokenizer.pad_token_id,\n",
 "    )\n",
 "    completion = tokenizer.decode(out[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
+"    action = sanitize_action(parse_action(completion))\n",
 "    env.reset()\n",
+"    try:\n",
+"        r, _ = env.step(action)\n",
+"    except RuntimeError:\n",
+"        r, _ = env.step({\"action_type\": \"do_nothing\"})\n",
 "    rs.append(r)\n",
 "    FastLanguageModel.for_training(model)\n",
 "    return rs\n",

@@ -510,26 +683,71 @@
 "baselines = {\"random\": random_rewards, \"frozen\": frozen_rewards}"
 ],
 "execution_count": null,
+"outputs": [],
+"id": "9c2ff7d6"
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
+"### 5.2 Stage B — first GRPO stage (easy->full curriculum starts here)\n",
 "\n",
+"We run a short SFT warmup first, then GRPO Stage B with stronger local scaffold weights (`CURRENT_LOCAL_SCALE`) and slightly lower env scale (`CURRENT_ENV_SCALE`).\n",
+"\n",
+"As stages progress (B -> C -> D), the notebook anneals toward full env-dominant training."
+],
+"id": "018d2c7c"
 },
 {
 "cell_type": "code",
 "metadata": {},
 "source": [
+"from trl import GRPOConfig, GRPOTrainer, SFTConfig, SFTTrainer\n",
 "\n",
+"reward_funcs = [env_reward, format_reward, semantic_action_reward, anti_idle_reward]\n",
 "stage_logs: dict[str, list[dict]] = {}\n",
 "\n",
+"# -------- SFT warmup --------\n",
+"def _heuristic_action_for_sft(briefing: str) -> dict:\n",
+"    b = briefing.lower()\n",
+"    if \"e01\" in b:\n",
+"        return {\"action_type\": \"reply_email\", \"email_id\": \"e01\", \"message_body\": \"Acknowledged, sharing an update shortly.\"}\n",
+"    if \"m02\" in b:\n",
+"        return {\"action_type\": \"reschedule_meeting\", \"meeting_id\": \"m02\", \"new_time\": \"2026-04-21T18:00:00\", \"reason\": \"resolve overlap\"}\n",
+"    if \"t06\" in b:\n",
+"        return {\"action_type\": \"complete_task\", \"task_id\": \"t06\"}\n",
+"    return {\"action_type\": \"do_nothing\"}\n",
+"\n",
+"sft_rows = []\n",
+"for b in briefings:\n",
+"    msgs = build_prompt(b)\n",
+"    prompt_txt = render_chat(msgs)\n",
+"    completion_txt = json.dumps(_heuristic_action_for_sft(b), ensure_ascii=True)\n",
+"    sft_rows.append({\"prompt_text\": prompt_txt, \"completion_text\": completion_txt})\n",
+"\n",
+"sft_ds = Dataset.from_list(sft_rows)\n",
+"sft_cfg = SFTConfig(\n",
+"    output_dir=str(OUT / \"sft_warmup\"),\n",
+"    max_steps=30,\n",
+"    per_device_train_batch_size=1,\n",
+"    gradient_accumulation_steps=4,\n",
+"    learning_rate=2e-5,\n",
+"    logging_steps=5,\n",
+"    report_to=\"none\",\n",
+")\n",
+"sft_trainer = SFTTrainer(\n",
+"    model=policy,\n",
+"    processing_class=tokenizer,\n",
+"    train_dataset=sft_ds,\n",
+"    args=sft_cfg,\n",
+"    dataset_text_field=\"prompt_text\",\n",
+"    formatting_func=lambda ex: [f\"{p}{c}\" for p, c in zip(ex[\"prompt_text\"], ex[\"completion_text\"])],\n",
+")\n",
+"print(\"\\n=== SFT warmup ===\")\n",
+"sft_trainer.train()\n",
+"policy = sft_trainer.model\n",
+"\n",
+"\n",
 "def grpo_config(name: str, *, temperature: float, num_generations: int, max_steps: int, lr: float) -> GRPOConfig:\n",
 "    return GRPOConfig(\n",
 "        output_dir=str(OUT / f\"stage_{name}\"),\n",

@@ -537,20 +755,38 @@
 "        gradient_accumulation_steps=4,\n",
 "        num_generations=num_generations,\n",
 "        max_prompt_length=1920,\n",
+"        max_completion_length=48,\n",
 "        temperature=temperature,\n",
 "        learning_rate=lr,\n",
 "        beta=0.04,\n",
 "        max_steps=max_steps,\n",
 "        logging_steps=1,\n",
+"        bf16=False,\n",
+"        fp16=True,\n",
 "        report_to=\"none\",\n",
 "        save_strategy=\"no\",\n",
 "        remove_unused_columns=False,\n",
 "        log_completions=True,\n",
 "    )\n",
 "\n",
+"\n",
+"def set_curriculum_scales(stage_name: str) -> None:\n",
+"    global CURRENT_ENV_SCALE, CURRENT_LOCAL_SCALE\n",
+"    # easy -> full complexity curriculum\n",
+"    if stage_name == \"B\":\n",
+"        CURRENT_ENV_SCALE = 0.85\n",
+"        CURRENT_LOCAL_SCALE = 0.60\n",
+"    elif stage_name == \"C\":\n",
+"        CURRENT_ENV_SCALE = 0.95\n",
+"        CURRENT_LOCAL_SCALE = 0.40\n",
+"    else:\n",
+"        CURRENT_ENV_SCALE = 1.00\n",
+"        CURRENT_LOCAL_SCALE = 0.25\n",
+"    print(f\"curriculum[{stage_name}] env={CURRENT_ENV_SCALE:.2f} local={CURRENT_LOCAL_SCALE:.2f}\")\n",
+"\n",
+"\n",
 "def run_stage(name: str, **kw) -> None:\n",
+"    set_curriculum_scales(name)\n",
 "    print(f\"\\n=== Stage {name} → {kw} ===\")\n",
 "    trainer = GRPOTrainer(\n",
 "        model=policy,\n",

@@ -567,10 +803,12 @@
 "    tokenizer.save_pretrained(adapter_dir)\n",
 "    print(f\"stage {name} adapter → {adapter_dir}\")\n",
 "\n",
+"\n",
+"run_stage(\"B\", temperature=0.8, num_generations=2, max_steps=20, lr=5e-6)"
 ],
 "execution_count": null,
+"outputs": [],
+"id": "10b073d0"
 },
 {
 "cell_type": "markdown",

@@ -610,8 +848,15 @@
 "source": [
 "## Phase 7 — Add curriculum\n",
 "\n",
+"The deployed Space scenario is fixed, so curriculum is applied through both:\n",
+"\n",
+"1. **Exploration schedule** (temperature/lr across stages)\n",
+"2. **Complexity curriculum (easy -> full)** via reward scales:\n",
+"   - Stage B: stronger local scaffold guidance\n",
+"   - Stage C: mixed guidance\n",
+"   - Stage D: env-dominant optimization"
+],
+"id": "524f6691"
 },
 {
 "cell_type": "code",

@@ -682,20 +927,38 @@
 "plt.show()\n",
 "\n",
 "rows = []\n",
+"loss_rows = []\n",
 "step_counter = 0\n",
 "for name, log in stage_logs.items():\n",
 "    for entry in log:\n",
 "        r = entry.get(\"rewards/env_reward/mean\", entry.get(\"reward\"))\n",
+"        if \"loss\" in entry:\n",
+"            loss_rows.append({\"stage\": name, \"global_step\": step_counter + 1, \"loss\": entry[\"loss\"]})\n",
+"        if r is None:\n",
+"            continue\n",
 "        step_counter += 1\n",
 "        rows.append({\n",
+"            \"stage\": name,\n",
+"            \"global_step\": step_counter,\n",
+"            \"env\": r,\n",
 "            \"fmt\": entry.get(\"rewards/format_reward/mean\", 0.0),\n",
+"            \"semantic\": entry.get(\"rewards/semantic_action_reward/mean\", 0.0),\n",
 "            \"idle\": entry.get(\"rewards/anti_idle_reward/mean\", 0.0),\n",
 "        })\n",
+"\n",
 "df = pd.DataFrame(rows)\n",
 "df.to_csv(OUT / \"reward_log.csv\", index=False)\n",
 "\n",
+"loss_df = pd.DataFrame(loss_rows)\n",
+"if not loss_df.empty:\n",
+"    plt.figure(figsize=(8, 4))\n",
+"    for name, sub in loss_df.groupby(\"stage\"):\n",
+"        plt.plot(sub[\"global_step\"], sub[\"loss\"], label=f\"stage {name}\")\n",
+"    plt.xlabel(\"global step\"); plt.ylabel(\"loss\")\n",
+"    plt.title(\"Ghostexec SFT+GRPO — loss vs step\")\n",
+"    plt.legend(); plt.tight_layout()\n",
+"    plt.savefig(OUT / \"loss_curve.png\", dpi=150); plt.show()\n",
+"\n",
 "if not df.empty:\n",
 "    plt.figure(figsize=(8, 4))\n",
 "    for name, sub in df.groupby(\"stage\"):\n",

@@ -708,6 +971,7 @@
 "    plt.figure(figsize=(8, 4))\n",
 "    plt.plot(df[\"global_step\"], df[\"env\"], label=\"env_reward\")\n",
 "    plt.plot(df[\"global_step\"], df[\"fmt\"], label=\"format_reward\")\n",
+"    plt.plot(df[\"global_step\"], df[\"semantic\"], label=\"semantic_action_reward\")\n",
 "    plt.plot(df[\"global_step\"], df[\"idle\"], label=\"anti_idle_reward\")\n",
 "    plt.xlabel(\"global step\"); plt.ylabel(\"mean component reward\")\n",
 "    plt.title(\"Reward components — hacking-watch\")\n",

@@ -717,7 +981,8 @@
 "    print(\"No numeric reward log found — skipping curve plots.\")"
 ],
 "execution_count": null,
+"outputs": [],
+"id": "5ccb3832"
 },
 {
 "cell_type": "code",

@@ -753,16 +1018,23 @@
 "    \"env_url\": GHOSTEXEC_ENV_URL,\n",
 "    \"model\": MODEL_ID,\n",
 "    \"run\": RUN_NAME,\n",
+"    \"stack\": {\"unsloth\": True, \"trl\": \"0.22.2\", \"pipeline\": \"SFT->GRPO\"},\n",
 "    \"rewards\": {\n",
 "        \"random_mean\": summary[\"random\"],\n",
 "        \"frozen_mean\": summary[\"frozen\"],\n",
 "        \"trained_mean\": summary[\"trained\"],\n",
 "        \"improvement_vs_frozen\": summary[\"trained\"] - summary[\"frozen\"],\n",
 "    },\n",
+"    \"stages\": [\"SFT\"] + list(stage_logs.keys()),\n",
+"    \"reward_fns\": [\"env_reward\", \"format_reward\", \"semantic_action_reward\", \"anti_idle_reward\"],\n",
+"    \"curriculum\": {\n",
+"        \"type\": \"easy_to_full\",\n",
+"        \"stage_scales\": {\n",
+"            \"B\": {\"env\": 0.85, \"local\": 0.60},\n",
+"            \"C\": {\"env\": 0.95, \"local\": 0.40},\n",
+"            \"D\": {\"env\": 1.00, \"local\": 0.25},\n",
+"        },\n",
+"    },\n",
 "    \"tripwire\": \"HackingTripwire (unique_ratio<0.2 or env↑/fmt↓)\",\n",
 "    \"adapter_path\": str(final_adapter),\n",
 "    \"mean_space_latency_ms\": round(sum(env.latency_ms) / max(len(env.latency_ms), 1), 1),\n",

@@ -773,7 +1045,8 @@
 "print(\"\\nmanifest →\", OUT / \"manifest.json\")"
 ],
 "execution_count": null,
+"outputs": [],
+"id": "81fdfca3"
 }
 ],
 "metadata": {
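Of the notebook changes above, `sanitize_action` is the piece that keeps malformed completions from poisoning `env_reward`. A small illustration of its behavior on a typical bad completion (the output follows the cell as reconstructed):

```python
raw = {
    "action_type": "reply_email",
    "email_id": "email_01",   # wrong id pattern (the env expects e01-style ids)
    "message": "On it!",      # common mistake: text in `message`, not `message_body`
    "extra_field": 123,       # junk key
}
print(sanitize_action(raw))
# {'action_type': 'reply_email', 'email_id': '', 'message': 'On it!',
#  'message_body': 'On it!'}
# junk keys dropped, message copied into message_body, malformed id blanked
```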
openenv_ghostexec.egg-info/PKG-INFO
CHANGED

@@ -13,3 +13,8 @@ Provides-Extra: constrained
 Requires-Dist: lm-format-enforcer>=0.10; extra == "constrained"
 Provides-Extra: constrained-outlines
 Requires-Dist: outlines>=0.1; extra == "constrained-outlines"
+Provides-Extra: train
+Requires-Dist: datasets>=2.20.0; extra == "train"
+Requires-Dist: trl>=0.22.2; extra == "train"
+Requires-Dist: transformers>=4.45.0; extra == "train"
+Requires-Dist: accelerate>=0.34.0; extra == "train"
openenv_ghostexec.egg-info/SOURCES.txt
CHANGED

@@ -9,6 +9,7 @@ pyproject.toml
 ./client.py
 ./conftest.py
 ./graders.py
+./inference.py
 ./models.py
 ./scenarios/dinner_disaster.json
 ./scenarios/monday_morning.json
@@ -41,5 +42,4 @@ tests/test_phase1.py
 tests/test_phase2.py
 tests/test_phase3.py
 tests/test_phase4.py
-tests/test_reward_dead_suite.py
-tests/test_submission_plots_committed.py
+tests/test_reward_dead_suite.py
openenv_ghostexec.egg-info/requires.txt
CHANGED

@@ -11,3 +11,9 @@ pytest>=8.0.0
 pytest-cov>=4.0.0
 pyyaml>=6.0.0
 matplotlib>=3.8.0
+
+[train]
+datasets>=2.20.0
+trl>=0.22.2
+transformers>=4.45.0
+accelerate>=0.34.0
outputs/logs/episode_rewards.jsonl
CHANGED

(diff too large to render)
pyproject.toml
CHANGED

@@ -42,6 +42,12 @@ constrained = [
 constrained-outlines = [
     "outlines>=0.1",
 ]
+train = [
+    "datasets>=2.20.0",
+    "trl>=0.22.2",
+    "transformers>=4.45.0",
+    "accelerate>=0.34.0",
+]
 
 [project.scripts]
 # Server entry point - enables running via: uv run --project . server
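With the `train` extra in place, the training stack installs in one step, e.g. `uv sync --extra train` or `pip install -e ".[train]"` (the exact invocation depends on your workflow).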
scripts/eval_reward_ablation.py
ADDED

@@ -0,0 +1,64 @@
from __future__ import annotations

import argparse
import statistics
import sys
from pathlib import Path

ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT.parent))

from ghostexec.models import GhostexecAction
from ghostexec.server.ghostexec_environment import GhostexecEnvironment


def _run_episode(mode: str, scenario: Path) -> float:
    env = GhostexecEnvironment(scenario_path=scenario, reward_mode=mode)
    env.reset()
    actions = [
        GhostexecAction(action_type="reschedule_meeting", meeting_id="m02", new_time="2026-04-21T18:00:00"),
        GhostexecAction(action_type="reply_email", email_id="e01", message_body="Sharing revised numbers now."),
        GhostexecAction(action_type="archive_email", email_id="e09"),
        GhostexecAction(action_type="send_message", contact_name="Jordan Lee", message_body="Quick status sync."),
        GhostexecAction(action_type="complete_task", task_id="t06"),
    ]
    rewards = [float(env.step(a).reward or 0.0) for a in actions]
    return statistics.fmean(rewards)


def _run(mode: str, scenario: Path, episodes: int) -> dict[str, float]:
    vals = [_run_episode(mode, scenario) for _ in range(episodes)]
    return {
        "mean": statistics.fmean(vals),
        "std": statistics.pstdev(vals) if len(vals) > 1 else 0.0,
        "min": min(vals),
        "max": max(vals),
    }


def main() -> None:
    parser = argparse.ArgumentParser(description="Reward-mode ablation for Ghostexec.")
    parser.add_argument("--episodes", type=int, default=30)
    parser.add_argument(
        "--scenario",
        type=Path,
        default=ROOT / "scenarios" / "phase2_core.json",
    )
    args = parser.parse_args()

    modes = ("base", "full")
    results = {m: _run(m, args.scenario, args.episodes) for m in modes}
    print("Ghostexec reward ablation")
    print(f"scenario={args.scenario} episodes={args.episodes}")
    for m in modes:
        r = results[m]
        print(
            f"{m:>5}: mean={r['mean']:.4f} std={r['std']:.4f} "
            f"min={r['min']:.4f} max={r['max']:.4f}"
        )
    delta = results["full"]["mean"] - results["base"]["mean"]
    print(f"delta(full-base)={delta:.4f}")


if __name__ == "__main__":
    main()
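A minimal way to exercise the ablation helpers programmatically, as a sketch: it assumes the checkout is importable as a `ghostexec` package, which is what the script's own `ROOT`-based `sys.path` setup expects, and uses a small episode count just for a quick check.

```python
# Hypothetical driver reusing _run from the ablation script above.
import sys
from pathlib import Path

sys.path.insert(0, str(Path("scripts").resolve()))  # make the script importable
from eval_reward_ablation import _run

scenario = Path("scenarios/phase2_core.json")
for mode in ("base", "full"):
    stats = _run(mode, scenario, episodes=5)  # small count for a smoke check
    print(mode, stats)
```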
scripts/plot_training_report.py
ADDED

@@ -0,0 +1,170 @@
from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pandas as pd


def _load_trainer_history(path: Path) -> list[dict[str, Any]]:
    if not path.exists():
        return []
    data = json.loads(path.read_text(encoding="utf-8"))
    if isinstance(data, dict) and isinstance(data.get("log_history"), list):
        return [x for x in data["log_history"] if isinstance(x, dict)]
    if isinstance(data, list):
        return [x for x in data if isinstance(x, dict)]
    return []


def _load_baselines(path: Path) -> dict[str, float]:
    if not path.exists():
        return {}
    data = json.loads(path.read_text(encoding="utf-8"))
    if isinstance(data, dict) and "rewards" in data and isinstance(data["rewards"], dict):
        data = data["rewards"]
    out: dict[str, float] = {}
    for k in ("random", "frozen", "trained", "random_mean", "frozen_mean", "trained_mean"):
        if k in data:
            v = data[k]
            name = k.replace("_mean", "")
            out[name] = float(v)
    return out


def _ensure_dir(path: Path) -> None:
    path.mkdir(parents=True, exist_ok=True)


def _plot_loss(history: list[dict[str, Any]], out_dir: Path) -> bool:
    rows = []
    for i, h in enumerate(history):
        step = h.get("step", h.get("global_step", i))
        if "loss" in h:
            rows.append((float(step), float(h["loss"])))
    if not rows:
        return False
    df = pd.DataFrame(rows, columns=["step", "loss"]).sort_values("step")
    plt.figure(figsize=(9, 4.8))
    plt.plot(df["step"], df["loss"], label="train_loss")
    plt.xlabel("global step")
    plt.ylabel("loss")
    plt.title("Ghostexec training loss")
    plt.grid(alpha=0.2)
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_dir / "loss_curve.png", dpi=150)
    plt.close()
    return True


def _plot_reward_components(reward_csv: Path, out_dir: Path) -> tuple[bool, bool]:
    if not reward_csv.exists():
        return False, False
    df = pd.read_csv(reward_csv)
    if "global_step" not in df.columns:
        return False, False

    made_reward_curve = False
    for col in ("env", "reward", "mean_reward"):
        if col in df.columns:
            plt.figure(figsize=(9, 4.8))
            plt.plot(df["global_step"], df[col], label=col)
            plt.xlabel("global step")
            plt.ylabel("reward")
            plt.title("Ghostexec reward vs step")
            plt.grid(alpha=0.2)
            plt.legend()
            plt.tight_layout()
            plt.savefig(out_dir / "reward_curve.png", dpi=150)
            plt.close()
            made_reward_curve = True
            break

    component_cols = [c for c in ("env", "fmt", "semantic", "idle") if c in df.columns]
    if len(component_cols) >= 2:
        plt.figure(figsize=(9, 4.8))
        for c in component_cols:
            plt.plot(df["global_step"], df[c], label=c)
        plt.xlabel("global step")
        plt.ylabel("mean component reward")
        plt.title("Reward components vs step")
        plt.grid(alpha=0.2)
        plt.legend()
        plt.tight_layout()
        plt.savefig(out_dir / "components_curve.png", dpi=150)
        plt.close()
        return made_reward_curve, True
    return made_reward_curve, False


def _plot_baseline_bars(baselines: dict[str, float], out_dir: Path) -> bool:
    needed = ("random", "frozen", "trained")
    if not all(k in baselines for k in needed):
        return False
    names = list(needed)
    vals = [baselines[n] for n in names]
    colors = ["#888888", "#1f77b4", "#2ca02c"]
    plt.figure(figsize=(8.2, 4.8))
    plt.bar(names, vals, color=colors)
    plt.ylabel("mean episode reward (higher is better)")
    plt.title("Ghostexec: random vs frozen vs trained")
    plt.tight_layout()
    plt.savefig(out_dir / "baseline_comparison.png", dpi=150)
    plt.close()
    return True


def main() -> None:
    parser = argparse.ArgumentParser(description="Generate post-training Ghostexec plots.")
    parser.add_argument(
        "--trainer-history",
        type=Path,
        default=Path("outputs/trainer_state.json"),
        help="JSON with HF/Unsloth log history (trainer_state.json or list of logs).",
    )
    parser.add_argument(
        "--reward-csv",
        type=Path,
        default=Path("outputs/reward_log.csv"),
        help="CSV containing global_step and reward columns.",
    )
    parser.add_argument(
        "--baselines-json",
        type=Path,
        default=Path("outputs/compliance_manifest.json"),
        help="JSON containing random/frozen/trained means (or rewards object).",
    )
    parser.add_argument(
        "--out-dir",
        type=Path,
        default=Path("outputs/plots"),
        help="Directory to save plot PNGs.",
    )
    args = parser.parse_args()

    _ensure_dir(args.out_dir)
    history = _load_trainer_history(args.trainer_history)
    baselines = _load_baselines(args.baselines_json)

    made_loss = _plot_loss(history, args.out_dir)
    made_reward, made_components = _plot_reward_components(args.reward_csv, args.out_dir)
    made_bars = _plot_baseline_bars(baselines, args.out_dir)

    print("Generated plots:")
    print(f"- loss_curve.png: {'yes' if made_loss else 'no (missing loss history)'}")
    print(f"- reward_curve.png: {'yes' if made_reward else 'no (missing reward csv columns)'}")
    print(
        f"- components_curve.png: {'yes' if made_components else 'no (missing component columns)'}"
    )
    print(
        f"- baseline_comparison.png: {'yes' if made_bars else 'no (missing random/frozen/trained means)'}"
    )
    print(f"Output directory: {args.out_dir.resolve()}")


if __name__ == "__main__":
    main()
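To sanity-check this plotting pipeline without a real training run, one can synthesize the two inputs the script parses: a `trainer_state.json`-style `log_history` (dicts with `step` and `loss`) and a reward CSV with `global_step` plus the `env`/`fmt`/`semantic`/`idle` component columns. A sketch using the default paths (file contents are made up; the baseline bar chart will report "no" since no compliance manifest is written):

```python
import json
import subprocess
from pathlib import Path

out = Path("outputs")
out.mkdir(exist_ok=True)
# Synthetic HF-style log history: the loss plotter reads "step" and "loss".
(out / "trainer_state.json").write_text(
    json.dumps({"log_history": [{"step": s, "loss": 2.0 / (s + 1)} for s in range(1, 21)]}),
    encoding="utf-8",
)
# Synthetic reward CSV with the component columns the script looks for.
rows = ["global_step,env,fmt,semantic,idle"]
rows += [f"{s},{0.01 * s:.3f},0.050,0.020,-0.010" for s in range(1, 21)]
(out / "reward_log.csv").write_text("\n".join(rows) + "\n", encoding="utf-8")

subprocess.run(["python", "scripts/plot_training_report.py"], check=True)
```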
scripts/train_sft_then_grpo.py
ADDED

@@ -0,0 +1,641 @@
from __future__ import annotations

import argparse
import json
import random
import re
from pathlib import Path
from typing import Any

import requests
from transformers import TrainerCallback


LEGAL_ACTION_TYPES = [
    "reply_email",
    "archive_email",
    "reschedule_meeting",
    "cancel_meeting",
    "complete_task",
    "delegate_task",
    "send_message",
    "do_nothing",
]

MODEL_PRESETS: dict[str, str] = {
    # Fast iteration winner preset: small, strong instruction following, QLoRA-friendly.
    "small_iter_fast": "unsloth/Qwen2.5-3B-Instruct",
    # Existing baseline used in this repo.
    "balanced_3b": "unsloth/Llama-3.2-3B-Instruct",
    # Larger option when compute budget is stable.
    "bigger_4b": "unsloth/Qwen3-4B-Instruct-2507",
}

TRAINING_PRESETS: dict[str, dict[str, float | int | str]] = {
    "hackathon_turbo": {
        "max_sft_steps": 80,
        "max_grpo_steps": 180,
        "env_reward_scale": 1.00,
        "local_reward_scale": 0.45,
        "complexity_curriculum": "easy_to_full",
        "curriculum_ramp_ratio": 0.65,
        "sft_samples": 180,
        # Optimizer / schedule knobs (stability-first for iterative winning runs)
        "sft_lr": 1.2e-5,
        "sft_grad_accum": 8,
        "grpo_lr": 3.0e-6,
        "grpo_grad_accum": 8,
        "grpo_beta": 0.08,
        "reward_ema_decay": 0.35,
    },
    # Quicker loop for smoke iterations on weaker hardware.
    "quick_smoke": {
        "max_sft_steps": 30,
        "max_grpo_steps": 80,
        "env_reward_scale": 0.95,
        "local_reward_scale": 0.35,
        "complexity_curriculum": "easy_to_full",
        "curriculum_ramp_ratio": 0.50,
        "sft_samples": 90,
        "sft_lr": 1.5e-5,
        "sft_grad_accum": 4,
        "grpo_lr": 4.0e-6,
        "grpo_grad_accum": 4,
        "grpo_beta": 0.06,
        "reward_ema_decay": 0.25,
    },
}


def _as_float(x: object | None) -> float | None:
    if x is None:
        return None
    try:
        return float(x)
    except Exception:
        return None


class StabilityTripwire(TrainerCallback):
    """Stop training when logs show sustained reward collapse + loss blow-up."""

    def __init__(
        self,
        *,
        min_step: int,
        reward_key: str,
        loss_key: str,
        reward_drop: float,
        loss_spike: float,
        bad_streak: int,
    ) -> None:
        self.min_step = min_step
        self.reward_key = reward_key
        self.loss_key = loss_key
        self.reward_drop = reward_drop
        self.loss_spike = loss_spike
        self.bad_streak = bad_streak
        self._best_reward: float | None = None
        self._best_loss: float | None = None
        self._streak = 0

    def on_log(self, args, state, control, logs=None, **kw):  # type: ignore[no-untyped-def]
        logs = logs or {}
        step = int(getattr(state, "global_step", 0) or 0)
        if step < self.min_step:
            return control

        r = _as_float(logs.get(self.reward_key))
        loss = _as_float(logs.get(self.loss_key))

        reward_bad = False
        loss_bad = False

        if r is not None:
            if self._best_reward is None or r > self._best_reward:
                self._best_reward = r
            elif self._best_reward is not None and self._best_reward - r >= self.reward_drop:
                reward_bad = True

        if loss is not None:
            if self._best_loss is None or loss < self._best_loss:
                self._best_loss = loss
            elif self._best_loss is not None and loss - self._best_loss >= self.loss_spike:
                loss_bad = True

        bad = reward_bad and loss_bad and r is not None and loss is not None

        if bad:
            self._streak += 1
        else:
            self._streak = 0

        if self._streak >= self.bad_streak:
            print(
                f"[STABILITY] stopping: sustained instability "
                f"(best {self.reward_key}={self._best_reward}, best loss={self._best_loss}, streak={self._streak})."
            )
            control.should_training_stop = True
        return control


class LossSpikeTripwire(TrainerCallback):
    """SFT guardrail: stop if loss repeatedly blows up vs the best-so-far."""

    def __init__(self, *, min_step: int, loss_key: str, loss_spike: float, bad_streak: int) -> None:
        self.min_step = min_step
        self.loss_key = loss_key
        self.loss_spike = loss_spike
        self.bad_streak = bad_streak
        self._best_loss: float | None = None
        self._streak = 0

    def on_log(self, args, state, control, logs=None, **kw):  # type: ignore[no-untyped-def]
        logs = logs or {}
        step = int(getattr(state, "global_step", 0) or 0)
        if step < self.min_step:
            return control

        loss = _as_float(logs.get(self.loss_key))
        if loss is None:
            return control

        if self._best_loss is None or loss < self._best_loss:
            self._best_loss = loss
            self._streak = 0
            return control

        if self._best_loss is not None and loss - self._best_loss >= self.loss_spike:
            self._streak += 1
        else:
            self._streak = 0

        if self._streak >= self.bad_streak:
            print(f"[STABILITY] stopping SFT: repeated loss spikes (best={self._best_loss}, streak={self._streak}).")
            control.should_training_stop = True
        return control


def _extract_briefing(reset_payload: dict[str, Any]) -> str:
    obs = reset_payload.get("observation", reset_payload)
    if isinstance(obs, dict):
        return str(obs.get("echoed_message", "")).strip()
    return ""


def _legal_action_heuristic(briefing: str) -> dict[str, Any]:
    # Minimal heuristic used only for SFT warm-start data generation.
    # Keeps the action schema valid and non-idle-biased.
    lower = briefing.lower()
    if "e01" in lower:
        return {
            "action_type": "reply_email",
            "email_id": "e01",
            "message_body": "Acknowledged. Sharing a concise update shortly.",
        }
    if "m02" in lower:
        return {
            "action_type": "reschedule_meeting",
            "meeting_id": "m02",
            "new_time": "2026-04-21T18:00:00",
            "reason": "Resolve overlap with higher priority commitments.",
        }
    if "t06" in lower:
        return {"action_type": "complete_task", "task_id": "t06"}
    return {"action_type": random.choice(LEGAL_ACTION_TYPES)}


def generate_sft_jsonl_from_env(
    env_url: str,
    out_jsonl: Path,
    samples: int = 120,
    task_id: str = "phase2_core",
) -> None:
    out_jsonl.parent.mkdir(parents=True, exist_ok=True)
    rows: list[dict[str, str]] = []
    for _ in range(samples):
        r = requests.post(f"{env_url.rstrip('/')}/reset", json={"task_id": task_id}, timeout=30)
        r.raise_for_status()
        payload = r.json()
        briefing = _extract_briefing(payload)
        if not briefing:
            continue
        action = _legal_action_heuristic(briefing)
        prompt = (
            "You are Ghostexec AI Chief-of-Staff.\n"
            "Output one valid GhostexecAction JSON only.\n\n"
            f"{briefing}"
        )
        rows.append({"prompt": prompt, "completion": json.dumps(action, ensure_ascii=True)})
    with out_jsonl.open("w", encoding="utf-8") as fh:
        for row in rows:
            fh.write(json.dumps(row, ensure_ascii=True) + "\n")
    print(f"Wrote {len(rows)} SFT rows to {out_jsonl}")


def run_sft_then_grpo(
    model_name: str,
    env_url: str,
    sft_jsonl: Path,
    out_dir: Path,
    env_reward_scale: float,
    local_reward_scale: float,
    max_sft_steps: int,
    max_grpo_steps: int,
    complexity_curriculum: str,
    curriculum_ramp_ratio: float,
    *,
    sft_lr: float,
    sft_grad_accum: int,
    grpo_lr: float,
    grpo_grad_accum: int,
    grpo_beta: float,
    reward_ema_decay: float,
    stability_tripwire: bool,
) -> None:
    try:
        from datasets import load_dataset
        from trl import GRPOConfig, GRPOTrainer, SFTConfig, SFTTrainer
        from unsloth import FastLanguageModel
    except Exception as exc:  # pragma: no cover
        raise RuntimeError(
            "Missing training deps. Install unsloth, trl, datasets, transformers before running."
        ) from exc

    out_dir.mkdir(parents=True, exist_ok=True)

    def _trainable_lora_sum_abs(model) -> float:
        total = 0.0
        for n, p in model.named_parameters():
            if not p.requires_grad:
                continue
            if "lora" not in n.lower():
                continue
            total += float(p.detach().abs().sum().item())
        return total

    policy, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=2048,
        dtype=None,
        load_in_4bit=True,
    )
    policy = FastLanguageModel.get_peft_model(
        policy,
        r=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_alpha=16,
        lora_dropout=0.0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )

    ds = load_dataset("json", data_files=str(sft_jsonl), split="train")
    sft_cfg = SFTConfig(
        output_dir=str(out_dir / "sft"),
        max_steps=max_sft_steps,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=sft_grad_accum,
        learning_rate=sft_lr,
        lr_scheduler_type="cosine",
        warmup_ratio=0.06,
        max_grad_norm=1.0,
        adam_beta1=0.9,
        adam_beta2=0.95,
        logging_steps=5,
        save_steps=max(10, max_sft_steps),
        report_to=[],
    )
    sft_trainer = SFTTrainer(
        model=policy,
        tokenizer=tokenizer,
        train_dataset=ds,
        args=sft_cfg,
        dataset_text_field="prompt",
        formatting_func=lambda ex: [f"{p}\n\n{c}" for p, c in zip(ex["prompt"], ex["completion"])],
    )
    if stability_tripwire:
        sft_trainer.add_callback(
            LossSpikeTripwire(
                min_step=max(10, max_sft_steps // 6),
                loss_key="loss",
                loss_spike=0.85,
                bad_streak=4,
            )
        )

    sft_before = _trainable_lora_sum_abs(policy)
    sft_trainer.train()
    sft_after = _trainable_lora_sum_abs(sft_trainer.model)
    sft_delta = abs(sft_after - sft_before)
    print(f"SFT LoRA delta(abs-sum): {sft_delta:.6f}")
    if sft_delta <= 1e-6:
        raise RuntimeError("SFT appears not to have updated LoRA weights (delta too small).")
    sft_dir = out_dir / "sft_adapter"
    sft_trainer.model.save_pretrained(sft_dir)
    tokenizer.save_pretrained(sft_dir)
    print(f"SFT complete. Adapter saved: {sft_dir}")

    def _extract_json(text: str) -> dict[str, Any] | None:
        m = re.search(r"\{.*\}", text, flags=re.S)
        if not m:
            return None
        try:
            obj = json.loads(m.group(0))
        except Exception:
            return None
        return obj if isinstance(obj, dict) else None

    def _env_step_reward_from_completion(text: str) -> float:
        payload = _extract_json(text)
        if payload is None:
            return -0.25
        payload.setdefault("action_type", "do_nothing")
        try:
            r = requests.post(f"{env_url.rstrip('/')}/reset", json={"task_id": "phase2_core"}, timeout=30)
            r.raise_for_status()
            s = requests.post(
                f"{env_url.rstrip('/')}/step",
                json={"action": payload},
                timeout=30,
            )
            s.raise_for_status()
            raw = s.json()
        except Exception:
            return 0.0
        rew = raw.get("reward")
        if rew is None and isinstance(raw.get("observation"), dict):
            rew = raw["observation"].get("reward", 0.0)
        try:
            return float(rew)
        except Exception:
            return 0.0

    progress = {"step": 0, "total": max(1, max_grpo_steps)}
    reward_ema_state = {"env": None}

    class _ProgressCallback(TrainerCallback):
        def on_step_end(self, args, state, control, **kwargs):  # type: ignore[override]
            progress["step"] = int(getattr(state, "global_step", progress["step"]))
            return control

    def _progress_frac() -> float:
        return min(1.0, progress["step"] / progress["total"])

    def _curriculum_phase_weight() -> float:
        frac = _progress_frac()
        ramp = max(0.05, min(1.0, curriculum_ramp_ratio))
        if complexity_curriculum == "off":
            return 1.0
        # easy_to_full: start with strong scaffold guidance, then smoothly
        # transition to full env-dominant optimization.
        if frac >= ramp:
            return 0.0
        return max(0.0, 1.0 - (frac / ramp))

    def _annealed_local_scale() -> float:
        frac = _progress_frac()
        base = local_reward_scale * (1.20 - 0.70 * frac)
        return base * (1.0 + 0.70 * _curriculum_phase_weight())

    def _annealed_env_scale() -> float:
        w = _curriculum_phase_weight()
        # Slightly downweight env reward in early easy phase to reduce variance,
        # then recover to full strength by the end of ramp.
        return env_reward_scale * (1.0 - 0.30 * w)

    def env_reward(completions, **_):
        scale = _annealed_env_scale()
        raw = [scale * _env_step_reward_from_completion(str(c)) for c in completions]
        if reward_ema_decay <= 0.0:
            return raw
        batch_mean = sum(raw) / max(len(raw), 1)
        prev = reward_ema_state["env"]
        d = max(0.0, min(1.0, reward_ema_decay))
        if prev is None:
            smoothed_mean = batch_mean
        else:
            smoothed_mean = (1.0 - d) * prev + d * batch_mean
        reward_ema_state["env"] = smoothed_mean
        delta = smoothed_mean - batch_mean
        return [r + delta for r in raw]

    def format_reward(completions, **_):
        scale = _annealed_local_scale()
        outs: list[float] = []
        for c in completions:
            txt = str(c).strip()
            obj = _extract_json(txt)
            if obj is None:
                outs.append(-0.20 * scale)
                continue
            if obj.get("action_type") not in LEGAL_ACTION_TYPES:
                outs.append(-0.20 * scale)
                continue
            # Encourage concise, parseable schema-correct JSON.
            length_pen = -0.04 * scale if len(txt) > 500 else 0.0
            outs.append(0.12 * scale + length_pen)
        return outs

    def semantic_action_reward(completions, prompts=None, **_):
        scale = _annealed_local_scale()
        outs: list[float] = []
        for i, c in enumerate(completions):
            obj = _extract_json(str(c))
            if obj is None:
                outs.append(-0.10 * scale)
                continue
            at = str(obj.get("action_type", ""))
            ptxt = str(prompts[i] if prompts and i < len(prompts) else "").lower()
            bonus = 0.0
            if "critical" in ptxt and at == "reply_email":
                bonus += 0.08
            if "clash" in ptxt and at in ("reschedule_meeting", "cancel_meeting"):
                bonus += 0.08
            if ("overdue" in ptxt or "due soon" in ptxt) and at in ("complete_task", "delegate_task"):
                bonus += 0.08
            outs.append(scale * bonus)
        return outs

    def anti_idle_reward(completions, **_):
        scale = _annealed_local_scale()
        outs = []
        for c in completions:
            txt = str(c).lower()
            outs.append((-0.20 if "do_nothing" in txt else 0.02) * scale)
        return outs

    grpo_cfg = GRPOConfig(
        output_dir=str(out_dir / "grpo"),
        learning_rate=grpo_lr,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=grpo_grad_accum,
        max_steps=max_grpo_steps,
        logging_steps=5,
        num_generations=2,
        beta=grpo_beta,
        lr_scheduler_type="cosine",
        warmup_ratio=0.06,
        max_grad_norm=1.0,
        adam_beta1=0.9,
        adam_beta2=0.95,
        report_to=[],
    )
    grpo_callbacks = [_ProgressCallback()]
    if stability_tripwire:
        grpo_callbacks.append(
            StabilityTripwire(
                min_step=max(15, max_grpo_steps // 8),
                reward_key="rewards/env_reward/mean",
                loss_key="loss",
                reward_drop=0.12,
                loss_spike=0.35,
                bad_streak=3,
            )
        )
    grpo_trainer = GRPOTrainer(
        model=sft_trainer.model,
        processing_class=tokenizer,
        reward_funcs=[env_reward, format_reward, semantic_action_reward, anti_idle_reward],
        train_dataset=ds,
        args=grpo_cfg,
        callbacks=grpo_callbacks,
    )
    grpo_before = _trainable_lora_sum_abs(sft_trainer.model)
    grpo_trainer.train()
    progress["step"] = progress["total"]
    grpo_after = _trainable_lora_sum_abs(grpo_trainer.model)
    grpo_delta = abs(grpo_after - grpo_before)
    print(f"GRPO LoRA delta(abs-sum): {grpo_delta:.6f}")
    if grpo_delta <= 1e-6:
        raise RuntimeError("GRPO appears not to have updated LoRA weights (delta too small).")
    final_dir = out_dir / "grpo_adapter"
    grpo_trainer.model.save_pretrained(final_dir)
    tokenizer.save_pretrained(final_dir)
    print(f"GRPO complete. Adapter saved: {final_dir}")


def main() -> None:
    parser = argparse.ArgumentParser(description="Run SFT warmup before GRPO.")
    parser.add_argument(
        "--model-name",
        default="",
        help="Optional explicit model id. If omitted, --model-preset is used.",
    )
    parser.add_argument(
        "--model-preset",
        choices=sorted(MODEL_PRESETS.keys()),
        default="small_iter_fast",
        help="Recommended compute-aware preset. small_iter_fast is best for iteration speed.",
    )
    parser.add_argument(
        "--training-preset",
        choices=sorted(TRAINING_PRESETS.keys()),
        default="hackathon_turbo",
        help="Compute-aware run preset. hackathon_turbo is best default for iterative winning loops.",
    )
    parser.add_argument("--env-url", default="http://127.0.0.1:8000")
    parser.add_argument("--sft-jsonl", type=Path, default=Path("outputs/sft_from_env.jsonl"))
    parser.add_argument("--out-dir", type=Path, default=Path("outputs/train_runs/sft_then_grpo"))
    parser.add_argument("--generate-sft-from-env", action="store_true")
    parser.add_argument("--sft-samples", type=int, default=120)
    parser.add_argument("--max-sft-steps", type=int, default=60)
    parser.add_argument("--max-grpo-steps", type=int, default=120)
    parser.add_argument("--env-reward-scale", type=float, default=1.0)
    parser.add_argument("--local-reward-scale", type=float, default=0.35)
    parser.add_argument(
        "--complexity-curriculum",
        choices=["off", "easy_to_full"],
        default="easy_to_full",
        help="Reward curriculum: easy_to_full starts with stronger local scaffold and anneals to env-dominant.",
    )
    parser.add_argument(
        "--curriculum-ramp-ratio",
        type=float,
        default=0.60,
        help="Fraction of GRPO steps used to ramp from easy scaffold to full env weighting.",
    )
    parser.add_argument(
        "--no-stability-tripwire",
        action="store_true",
        help="Disable oscillation/collapse early-stop guardrails (not recommended).",
    )
    parser.add_argument(
        "--reward-ema-decay",
        type=float,
        default=-1.0,
        help="EMA decay in [0,1] for env reward smoothing; -1 uses training preset default.",
    )
    args = parser.parse_args()
    model_name = args.model_name.strip() or MODEL_PRESETS[args.model_preset]
    p = TRAINING_PRESETS[args.training_preset]
    max_sft_steps = int(p["max_sft_steps"])
    max_grpo_steps = int(p["max_grpo_steps"])
    env_reward_scale = float(p["env_reward_scale"])
    local_reward_scale = float(p["local_reward_scale"])
    complexity_curriculum = str(p["complexity_curriculum"])
    curriculum_ramp_ratio = float(p["curriculum_ramp_ratio"])
    sft_samples = int(p["sft_samples"])
    sft_lr = float(p["sft_lr"])
    sft_grad_accum = int(p["sft_grad_accum"])
    grpo_lr = float(p["grpo_lr"])
    grpo_grad_accum = int(p["grpo_grad_accum"])
    grpo_beta = float(p["grpo_beta"])
    reward_ema_decay = float(p["reward_ema_decay"])
    if args.max_sft_steps != 60:
        max_sft_steps = args.max_sft_steps
    if args.max_grpo_steps != 120:
        max_grpo_steps = args.max_grpo_steps
    if args.env_reward_scale != 1.0:
        env_reward_scale = args.env_reward_scale
    if args.local_reward_scale != 0.35:
        local_reward_scale = args.local_reward_scale
    if args.complexity_curriculum != "easy_to_full":
        complexity_curriculum = args.complexity_curriculum
    if args.curriculum_ramp_ratio != 0.60:
        curriculum_ramp_ratio = args.curriculum_ramp_ratio
    if args.sft_samples != 120:
        sft_samples = args.sft_samples
    if args.reward_ema_decay >= 0.0:
        reward_ema_decay = float(args.reward_ema_decay)
    stability_tripwire = not args.no_stability_tripwire
    print(f"Model preset: {args.model_preset} -> {model_name}")
    print(
        "Training preset:"
        f" {args.training_preset} -> sft={max_sft_steps}, grpo={max_grpo_steps},"
        f" env_scale={env_reward_scale}, local_scale={local_reward_scale},"
        f" curriculum={complexity_curriculum}, ramp={curriculum_ramp_ratio}"
    )

    if args.generate_sft_from_env or not args.sft_jsonl.exists():
        generate_sft_jsonl_from_env(
            env_url=args.env_url,
            out_jsonl=args.sft_jsonl,
            samples=sft_samples,
            task_id="phase2_core",
        )

    run_sft_then_grpo(
        model_name=model_name,
        env_url=args.env_url,
        sft_jsonl=args.sft_jsonl,
        out_dir=args.out_dir,
        env_reward_scale=env_reward_scale,
        local_reward_scale=local_reward_scale,
        max_sft_steps=max_sft_steps,
        max_grpo_steps=max_grpo_steps,
        complexity_curriculum=complexity_curriculum,
        curriculum_ramp_ratio=curriculum_ramp_ratio,
        sft_lr=sft_lr,
        sft_grad_accum=sft_grad_accum,
        grpo_lr=grpo_lr,
        grpo_grad_accum=grpo_grad_accum,
        grpo_beta=grpo_beta,
        reward_ema_decay=reward_ema_decay,
        stability_tripwire=stability_tripwire,
    )


if __name__ == "__main__":
    main()
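The curriculum annealing above is easiest to see with numbers. This standalone sketch re-derives the schedule for the CLI defaults (`local_reward_scale=0.35`, `env_reward_scale=1.0`, `curriculum_ramp_ratio=0.60`); it is a re-implementation for illustration, not an import of the script:

```python
def phase_weight(frac: float, ramp: float = 0.60) -> float:
    # easy_to_full curriculum: full scaffold emphasis at step 0,
    # linearly decayed to 0 once `frac` reaches the ramp ratio.
    return 0.0 if frac >= ramp else 1.0 - frac / ramp

def local_scale(frac: float, base: float = 0.35) -> float:
    # Mirrors _annealed_local_scale: linear decay plus curriculum boost.
    return base * (1.20 - 0.70 * frac) * (1.0 + 0.70 * phase_weight(frac))

def env_scale(frac: float, base: float = 1.0) -> float:
    # Mirrors _annealed_env_scale: env damped early, full strength after ramp.
    return base * (1.0 - 0.30 * phase_weight(frac))

for frac in (0.0, 0.30, 0.60, 1.0):
    print(f"frac={frac:.2f} local={local_scale(frac):.3f} env={env_scale(frac):.3f}")
# frac=0.00 local=0.714 env=0.700   (scaffold-heavy, env damped)
# frac=0.30 local=0.468 env=0.850
# frac=0.60 local=0.273 env=1.000   (ramp done: env-dominant)
# frac=1.00 local=0.175 env=1.000
```

The EMA smoothing in `env_reward` works the same way at the batch level: with decay `d`, the smoothed mean is `(1 - d) * prev + d * batch_mean`, and every reward in the batch is shifted by `smoothed_mean - batch_mean`, so relative rankings within a batch are preserved while the mean drifts smoothly.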
server/ghostexec_environment.py
CHANGED

@@ -15,6 +15,7 @@ Rewards aggregate conflict / relationship / task scores and log each step to out
 from __future__ import annotations
 
 import json
+import os
 from datetime import datetime, timedelta, timezone
 from pathlib import Path
 from typing import Any
@@ -71,6 +72,7 @@ _REL_DISPLAY: dict[str, str] = {
 
 _INVALID_ACTION_REWARD = -0.25
 _DEFAULT_STEP_REWARD = 0.0
+_MOOD_ORDER: tuple[Mood, ...] = ("furious", "angry", "annoyed", "neutral", "happy")
 
 
 def _default_scenario_path() -> Path:
@@ -104,6 +106,7 @@ class GhostexecEnvironment(Environment):
         self,
         scenario_path: str | Path | None = None,
         schema_drift_events_path: str | Path | None = None,
+        reward_mode: str | None = None,
     ) -> None:
         self._scenario_path = Path(scenario_path) if scenario_path else _default_scenario_path()
         self._drift_events_path = (
@@ -124,6 +127,9 @@ class GhostexecEnvironment(Environment):
         self._last_step_error: str | None = None
         self._last_step_detail: str = ""
         self._last_reward_breakdown: RewardBreakdown | None = None
+        self._reward_mode = (reward_mode or os.getenv("GHOSTEXEC_REWARD_MODE", "full")).strip().lower()
+        if self._reward_mode not in {"full", "base", "shaping"}:
+            self._reward_mode = "full"
 
     # --- lifecycle ---
 
@@ -176,6 +182,7 @@ class GhostexecEnvironment(Environment):
 
         before = self.world.model_copy(deep=True)
         action_ok = self._apply_action(action)
+        self._apply_post_action_dynamics(action, action_ok=action_ok)
         self._rebuild_conflict_list()
 
         episode_done = False
@@ -191,6 +198,9 @@ class GhostexecEnvironment(Environment):
             action_ok=action_ok,
             episode_done=episode_done,
             relationship_suppressed_for_email_to=frozenset(self._reply_relationship_suppressed),
+            reward_mode=self._reward_mode,
+            step_index=self._state.step_count,
+            max_steps=self.world.max_episode_steps,
         )
         self._last_reward_breakdown = breakdown
         self._append_reward_log(breakdown, episode_done, action)
@@ -540,6 +550,62 @@ class GhostexecEnvironment(Environment):
         self._world.action_log.append(f"error: {msg}")
         return False
 
+    def _advance_clock(self, minutes: int) -> None:
+        now = _parse_dt(self.world.simulation_time)
+        new_now = (now + timedelta(minutes=minutes)).replace(tzinfo=None)
+        self.world.simulation_time = new_now.isoformat(timespec="seconds")
+        self._reapply_task_overdue_flags()
+
+    def _shift_contact_mood(self, name: str, delta: int) -> None:
+        if delta == 0:
+            return
+        c = self.get_contact(name)
+        if c is None:
+            return
+        idx = _MOOD_ORDER.index(c.mood)
+        next_idx = max(0, min(len(_MOOD_ORDER) - 1, idx + delta))
+        if next_idx != idx:
+            self.update_contact_mood(name, _MOOD_ORDER[next_idx])
+
+    def _apply_post_action_dynamics(self, action: GhostexecAction, *, action_ok: bool) -> None:
+        # Step-level world progression adds realistic pressure dynamics while
+        # remaining deterministic and learnable for policy optimization.
+        self._advance_clock(minutes=20)
+        now = _parse_dt(self.world.simulation_time)
+
+        if action_ok and action.action_type == "reply_email" and action.email_id:
+            em = next((e for e in self.world.emails if e.id == action.email_id), None)
+            if em:
+                self._shift_contact_mood(em.sender, +1)
+
+        if action_ok and action.action_type == "send_message" and action.contact_name:
+            self._shift_contact_mood(action.contact_name.strip(), +1)
+
+        if action_ok and action.action_type == "cancel_meeting" and action.meeting_id:
+            mtg = next((m for m in self.world.meetings if m.id == action.meeting_id), None)
+            if mtg:
+                for attendee in mtg.attendees:
+                    self._shift_contact_mood(attendee, -1)
+
+        # Pressure escalation only on idle/invalid behavior to keep
+        # action-quality separation sharp for learning.
+        if (not action_ok) or action.action_type == "do_nothing":
+            critical_pending = [e for e in self.world.emails if e.priority == "critical" and not e.replied]
+            if critical_pending:
+                self._shift_contact_mood(critical_pending[0].sender, -1)
+
+        # Meetings that have already ended without cancellation and still overlap
+        # indicate unresolved calendar debt; this increases stress pressure.
+        unresolved_past_conflicts = 0
+        for row in self.detect_meeting_conflicts():
+            overlap_end = _parse_dt(row["overlap_end"])
+            if overlap_end <= now:
+                unresolved_past_conflicts += 1
+        if unresolved_past_conflicts > 0:
+            self.world.action_log.append(
+                f"pressure: {unresolved_past_conflicts} unresolved past overlap(s) increased stress pressure."
+            )
+
     def _ensure_reward_log_dir(self) -> None:
         self._reward_log_path.parent.mkdir(parents=True, exist_ok=True)
 
@@ -566,6 +632,13 @@ class GhostexecEnvironment(Environment):
             "task": breakdown.task,
             "weighted_base": breakdown.weighted_base,
             "output_scale": breakdown.output_scale,
+            "shaping_synergy": breakdown.shaping_synergy,
+            "shaping_tradeoff": breakdown.shaping_tradeoff,
+            "shaping_potential": breakdown.shaping_potential,
+            "shaping_scaffold": breakdown.shaping_scaffold,
+            "shaping_quality": breakdown.shaping_quality,
+            "shaping_total": breakdown.shaping_total,
+            "shaping_to_base_ratio": breakdown.shaping_to_base_ratio,
             "invalid_step_adjustment": breakdown.invalid_step_adjustment,
             "episode_completion_bonus": breakdown.episode_completion_bonus,
             "catastrophic_penalty": breakdown.catastrophic_penalty,
@@ -573,6 +646,7 @@ class GhostexecEnvironment(Environment):
             "calendar_overlap_pairs": len(self.detect_meeting_conflicts()),
             "critical_unreplied": crit_open,
             "overdue_tasks": overdue_n,
+            "reward_mode": self._reward_mode,
         }
         with self._reward_log_path.open("a", encoding="utf-8") as fh:
             fh.write(json.dumps(line) + "\n")
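The mood shift added above is a clamped walk along `_MOOD_ORDER`. A tiny standalone illustration of the index arithmetic, with the contact plumbing elided:

```python
MOOD_ORDER = ("furious", "angry", "annoyed", "neutral", "happy")

def shift(mood: str, delta: int) -> str:
    # Same clamping as _shift_contact_mood: never walk off either end.
    idx = MOOD_ORDER.index(mood)
    return MOOD_ORDER[max(0, min(len(MOOD_ORDER) - 1, idx + delta))]

assert shift("neutral", +1) == "happy"
assert shift("happy", +1) == "happy"      # already at the top, stays there
assert shift("annoyed", -1) == "angry"
assert shift("furious", -1) == "furious"  # clamped at the bottom
```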
server/reward.py
CHANGED
|
@@ -12,6 +12,7 @@ scaling.
|
|
| 12 |
|
| 13 |
from __future__ import annotations
|
| 14 |
|
|
|
|
| 15 |
from datetime import datetime, timedelta, timezone
|
| 16 |
from typing import Any
|
| 17 |
|
|
@@ -41,6 +42,12 @@ _SEND_MESSAGE_VALID_MICRO_BONUS: float = 0.08
|
|
| 41 |
_COMPLETE_TASK_VALID_MICRO_BONUS: float = 0.06
|
| 42 |
_DELEGATE_TASK_VALID_MICRO_BONUS: float = 0.10
|
| 43 |
_DO_NOTHING_STRICT_PENALTY: float = -0.15
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
_REPLY_PRIORITY_MICRO_BONUS: dict[str, float] = {
|
| 45 |
"critical": 0.30,
|
| 46 |
"high": 0.15,
|
|
@@ -261,6 +268,95 @@ def catastrophic(world: WorldState) -> bool:
|
|
| 261 |
return vip_furious or critical_open > 3
|
| 262 |
|
| 263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
def aggregate_scores(
|
| 265 |
conflict: float,
|
| 266 |
relationship: float,
|
|
@@ -269,11 +365,23 @@ def aggregate_scores(
|
|
| 269 |
conflict_raw: float,
|
| 270 |
critical_queue_bonus: float,
|
| 271 |
weighted_inner: float,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
action_ok: bool,
|
| 273 |
episode_done: bool,
|
| 274 |
world_after: WorldState,
|
| 275 |
) -> RewardBreakdown:
|
| 276 |
weighted = WEIGHTED_OUTPUT_SCALE * weighted_inner
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
inv = 0.0
|
| 278 |
if not action_ok:
|
| 279 |
inv = -0.25
|
|
@@ -291,6 +399,13 @@ def aggregate_scores(
|
|
| 291 |
conflict=conflict,
|
| 292 |
relationship=relationship,
|
| 293 |
task=task,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
weighted_base=weighted,
|
| 295 |
output_scale=WEIGHTED_OUTPUT_SCALE,
|
| 296 |
invalid_step_adjustment=inv,
|
|
@@ -322,6 +437,9 @@ def compute_step_reward(
|
|
| 322 |
action_ok: bool,
|
| 323 |
episode_done: bool,
|
| 324 |
relationship_suppressed_for_email_to: frozenset[str] | None = None,
|
|
|
|
|
|
|
|
|
|
| 325 |
) -> RewardBreakdown:
|
| 326 |
c_core = score_conflict_resolution(before, after, action, action_ok=action_ok)
|
| 327 |
crit_b = score_critical_queue_bonus(before, after)
|
|
@@ -335,7 +453,48 @@ def compute_step_reward(
|
|
| 335 |
relationship_suppressed_for_email_to=relationship_suppressed_for_email_to,
|
| 336 |
)
|
| 337 |
t = score_task_completion(before, after, action, action_ok=action_ok)
|
| 338 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
bd = aggregate_scores(
|
| 340 |
c,
|
| 341 |
r,
|
|
@@ -343,6 +502,12 @@ def compute_step_reward(
|
|
| 343 |
conflict_raw=c_raw,
|
| 344 |
critical_queue_bonus=crit_b,
|
| 345 |
weighted_inner=weighted_inner,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
action_ok=action_ok,
|
| 347 |
episode_done=episode_done,
|
| 348 |
world_after=after,
|
|
|
|
| 12 |
|
| 13 |
from __future__ import annotations
|
| 14 |
|
| 15 |
+
import math
|
| 16 |
from datetime import datetime, timedelta, timezone
|
| 17 |
from typing import Any
|
| 18 |
|
|
|
|
| 42 |
_COMPLETE_TASK_VALID_MICRO_BONUS: float = 0.06
|
| 43 |
_DELEGATE_TASK_VALID_MICRO_BONUS: float = 0.10
|
| 44 |
_DO_NOTHING_STRICT_PENALTY: float = -0.15
|
| 45 |
+
_SYNERGY_CAP: float = 0.40
|
| 46 |
+
_TRADEOFF_CAP: float = 0.30
|
| 47 |
+
_POTENTIAL_CAP: float = 0.25
|
| 48 |
+
_SCAFFOLD_CAP: float = 0.35
|
| 49 |
+
_SHAPING_TO_BASE_BUDGET: float = 1.25
|
| 50 |
+
_QUALITY_CAP: float = 0.28
|
| 51 |
_REPLY_PRIORITY_MICRO_BONUS: dict[str, float] = {
|
| 52 |
"critical": 0.30,
|
| 53 |
"high": 0.15,
|
|
|
|
| 268 |
return vip_furious or critical_open > 3
|
| 269 |
|
| 270 |
|
| 271 |
+
def _scaffold_learning_signal(
|
| 272 |
+
before: WorldState,
|
| 273 |
+
after: WorldState,
|
| 274 |
+
action: GhostexecAction,
|
| 275 |
+
*,
|
| 276 |
+
action_ok: bool,
|
| 277 |
+
step_index: int | None,
|
| 278 |
+
max_steps: int | None,
|
| 279 |
+
) -> float:
|
| 280 |
+
if not action_ok:
|
| 281 |
+
return 0.0
|
| 282 |
+
if action.action_type == "do_nothing":
|
| 283 |
+
return 0.0
|
| 284 |
+
s = 0.0
|
| 285 |
+
critical_before = critical_unreplied_count(before)
|
| 286 |
+
critical_after = critical_unreplied_count(after)
|
| 287 |
+
conflict_before = len(meeting_conflicts(before))
|
| 288 |
+
conflict_after = len(meeting_conflicts(after))
|
| 289 |
+
overdue_before = len(_overdue_tasks(before))
|
| 290 |
+
overdue_after = len(_overdue_tasks(after))
|
| 291 |
+
if action.action_type == "reply_email":
|
| 292 |
+
if critical_after < critical_before:
|
| 293 |
+
s += 0.16
|
| 294 |
+
elif critical_before > 0:
|
| 295 |
+
s += 0.05
|
| 296 |
+
if action.action_type in ("reschedule_meeting", "cancel_meeting"):
|
| 297 |
+
if conflict_after < conflict_before:
|
| 298 |
+
s += 0.15
|
| 299 |
+
elif conflict_before > 0:
|
| 300 |
+
s += 0.04
|
| 301 |
+
if action.action_type in ("complete_task", "delegate_task"):
|
| 302 |
+
if overdue_after < overdue_before:
|
| 303 |
+
s += 0.12
|
| 304 |
+
elif overdue_before > 0:
|
| 305 |
+
s += 0.03
|
| 306 |
+
# Early episode shaping slightly amplified for better exploration guidance.
|
| 307 |
+
if step_index is not None and max_steps and max_steps > 0:
|
| 308 |
+
frac = max(0.0, min(1.0, step_index / max_steps))
|
| 309 |
+
if frac <= 0.33:
|
| 310 |
+
s *= 1.20
|
| 311 |
+
elif frac >= 0.85:
|
| 312 |
+
s *= 0.90
|
| 313 |
+
|
| 314 |
+
return max(-_SCAFFOLD_CAP, min(_SCAFFOLD_CAP, s))
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def _state_potential(world: WorldState) -> float:
|
| 318 |
+
conflicts = len(meeting_conflicts(world))
|
| 319 |
+
critical_open = critical_unreplied_count(world)
|
| 320 |
+
overdue = len(_overdue_tasks(world))
|
| 321 |
+
stress = float(world.stress)
|
| 322 |
+
# Lower operational pressure => higher potential.
|
| 323 |
+
return -(
|
| 324 |
+
1.15 * critical_open
|
| 325 |
+
+ 0.90 * conflicts
|
| 326 |
+
+ 0.55 * overdue
|
| 327 |
+
+ 0.02 * stress
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def _budgeted_shaping_total(base_weighted_inner: float, shaping_total_inner: float) -> float:
|
| 332 |
+
# Keep shaping informative but bounded against the base objective to avoid exploit loops.
|
| 333 |
+
budget = _SHAPING_TO_BASE_BUDGET * (abs(base_weighted_inner) + 0.05)
|
| 334 |
+
return max(-budget, min(budget, shaping_total_inner))
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def _quality_separation_signal(
|
| 338 |
+
*,
|
| 339 |
+
c: float,
|
| 340 |
+
r: float,
|
| 341 |
+
t: float,
|
| 342 |
+
action: GhostexecAction,
|
| 343 |
+
action_ok: bool,
|
| 344 |
+
) -> float:
|
| 345 |
+
# Amplify distance between clearly good vs clearly bad valid actions.
|
| 346 |
+
if not action_ok or action.action_type == "do_nothing":
|
| 347 |
+
return 0.0
|
| 348 |
+
base = W_CONFLICT * c + W_REL * r + W_TASK * t
|
| 349 |
+
if base >= 0.90:
|
| 350 |
+
return _QUALITY_CAP
|
| 351 |
+
if base >= 0.35:
|
| 352 |
+
return 0.12
|
| 353 |
+
if base <= -0.90:
|
| 354 |
+
return -_QUALITY_CAP
|
| 355 |
+
if base <= -0.35:
|
| 356 |
+
return -0.12
|
| 357 |
+
return 0.0
|
| 358 |
+
|
| 359 |
+
|
```diff
 def aggregate_scores(
     conflict: float,
     relationship: float,
     ...
     conflict_raw: float,
     critical_queue_bonus: float,
     weighted_inner: float,
+    weighted_base_only: float,
+    shaping_synergy: float,
+    shaping_tradeoff: float,
+    shaping_potential: float,
+    shaping_scaffold: float,
+    shaping_quality: float,
     action_ok: bool,
     episode_done: bool,
     world_after: WorldState,
 ) -> RewardBreakdown:
     weighted = WEIGHTED_OUTPUT_SCALE * weighted_inner
+    weighted_base_only_scaled = WEIGHTED_OUTPUT_SCALE * weighted_base_only
+    shaping_total = WEIGHTED_OUTPUT_SCALE * (
+        shaping_synergy + shaping_tradeoff + shaping_potential + shaping_scaffold + shaping_quality
+    )
+    denom = abs(weighted_base_only_scaled) + 1e-6
+    shaping_ratio = min(10.0, abs(shaping_total) / denom)
     inv = 0.0
     if not action_ok:
         inv = -0.25
     ...
         conflict=conflict,
         relationship=relationship,
         task=task,
+        shaping_synergy=WEIGHTED_OUTPUT_SCALE * shaping_synergy,
+        shaping_tradeoff=WEIGHTED_OUTPUT_SCALE * shaping_tradeoff,
+        shaping_potential=WEIGHTED_OUTPUT_SCALE * shaping_potential,
+        shaping_scaffold=WEIGHTED_OUTPUT_SCALE * shaping_scaffold,
+        shaping_quality=WEIGHTED_OUTPUT_SCALE * shaping_quality,
+        shaping_total=shaping_total,
+        shaping_to_base_ratio=shaping_ratio,
         weighted_base=weighted,
         output_scale=WEIGHTED_OUTPUT_SCALE,
         invalid_step_adjustment=inv,
```
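`shaping_to_base_ratio` is purely a logged diagnostic: it reports how large the combined shaping channel is relative to the base channel, capped at 10 so a near-zero base cannot produce unbounded values. A standalone rendering of the same arithmetic:

```python
# Same arithmetic as the shaping_ratio lines above, lifted out for illustration.
def shaping_ratio(shaping_total: float, base_scaled: float) -> float:
    denom = abs(base_scaled) + 1e-6
    return min(10.0, abs(shaping_total) / denom)

print(shaping_ratio(0.15, 0.50))  # ~0.30: shaping is about 30% of base
print(shaping_ratio(0.15, 0.00))  # 10.0: base collapsed; the cap flags it in logs
```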
|
|
```diff
     action_ok: bool,
     episode_done: bool,
     relationship_suppressed_for_email_to: frozenset[str] | None = None,
+    reward_mode: str = "full",
+    step_index: int | None = None,
+    max_steps: int | None = None,
 ) -> RewardBreakdown:
     c_core = score_conflict_resolution(before, after, action, action_ok=action_ok)
     crit_b = score_critical_queue_bonus(before, after)
     ...
         relationship_suppressed_for_email_to=relationship_suppressed_for_email_to,
     )
     t = score_task_completion(before, after, action, action_ok=action_ok)
+    weighted_base_only = W_CONFLICT * c + W_REL * r + W_TASK * t
+    weighted_inner = weighted_base_only
+    synergy = 0.0
+    tradeoff_penalty = 0.0
+    potential_progress = 0.0
+    scaffold_signal = 0.0
+    quality_signal = 0.0
+    if reward_mode in ("full", "shaping"):
+        # Bounded nonlinear shaping to speed learning without overpowering base channels.
+        if c > 0.0 and r > 0.0:
+            synergy += min(_SYNERGY_CAP, 0.18 * math.tanh(0.35 * c * r))
+        if t > 0.0 and (c > 0.0 or r > 0.0):
+            bridge = max(c, 0.0) + max(r, 0.0)
+            synergy += min(_SYNERGY_CAP, 0.10 * math.tanh(0.25 * t * bridge))
+        if c < -0.5 and r < -0.5:
+            tradeoff_penalty -= min(_TRADEOFF_CAP, 0.12 * math.tanh(0.25 * abs(c * r)))
+        if t < -0.5 and (c < 0.0 or r < 0.0):
+            debt = abs(t) * (abs(min(c, 0.0)) + abs(min(r, 0.0)))
+            tradeoff_penalty -= min(_TRADEOFF_CAP, 0.08 * math.tanh(0.18 * debt))
+        potential_progress = max(
+            -_POTENTIAL_CAP,
+            min(_POTENTIAL_CAP, _state_potential(after) - _state_potential(before)),
+        )
+        scaffold_signal = _scaffold_learning_signal(
+            before,
+            after,
+            action,
+            action_ok=action_ok,
+            step_index=step_index,
+            max_steps=max_steps,
+        )
+        quality_signal = _quality_separation_signal(
+            c=c,
+            r=r,
+            t=t,
+            action=action,
+            action_ok=action_ok,
+        )
+        shaping_total_inner = (
+            synergy + tradeoff_penalty + potential_progress + scaffold_signal + quality_signal
+        )
+        weighted_inner += _budgeted_shaping_total(weighted_base_only, shaping_total_inner)
     bd = aggregate_scores(
         c,
         r,
         ...
         conflict_raw=c_raw,
         critical_queue_bonus=crit_b,
         weighted_inner=weighted_inner,
+        weighted_base_only=weighted_base_only,
+        shaping_synergy=synergy,
+        shaping_tradeoff=tradeoff_penalty,
+        shaping_potential=potential_progress,
+        shaping_scaffold=scaffold_signal,
+        shaping_quality=quality_signal,
         action_ok=action_ok,
         episode_done=episode_done,
         world_after=after,
```
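Note the control flow: in `base` mode the whole shaping branch is skipped, so `weighted_inner` stays equal to `weighted_base_only`; in `full` (or `shaping`) mode the five terms are summed, budget-clamped, and added on top. Downstream code can read the resulting channels from step metadata; a minimal consumer sketch, assuming a constructed environment `env` and a valid `action` (the metadata layout matches what the tests below assert):

```python
# Minimal consumer sketch; `env` and `action` are assumed to already exist.
obs = env.step(action)
bd = (obs.metadata or {}).get("reward_breakdown") or {}
print(bd.get("shaping_total"), bd.get("shaping_to_base_ratio"))
# If shaping starts to dominate, an ablation run with reward_mode="base" isolates it.
if float(bd.get("shaping_to_base_ratio") or 0.0) > 2.0:
    print("warning: shaping dominates the base channel")
```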
tests/test_phase4.py
CHANGED

```diff
@@ -28,6 +28,12 @@ def test_reward_weights_and_aggregator_helpers():
         conflict_raw=c,
         critical_queue_bonus=0.0,
         weighted_inner=weighted_inner,
+        weighted_base_only=weighted_inner,
+        shaping_synergy=0.0,
+        shaping_tradeoff=0.0,
+        shaping_potential=0.0,
+        shaping_scaffold=0.0,
+        shaping_quality=0.0,
         action_ok=True,
         episode_done=False,
         world_after=w,
@@ -92,6 +98,25 @@ def test_scripted_episode_reward_direction_and_log(tmp_path, monkeypatch):
     assert "reward" in row and "episode_id" in row
     assert row.get("action_type") == "reschedule_meeting"
     assert "conflict_raw" in row and "step_ok" in row
+    assert "shaping_total" in row and "shaping_to_base_ratio" in row
+    assert "shaping_scaffold" in row
+    assert row.get("reward_mode") == "full"
+
+
+def test_reward_mode_base_turns_off_shaping_terms():
+    env = GhostexecEnvironment(SCENARIO, reward_mode="base")
+    env.reset()
+    obs = env.step(
+        GhostexecAction(
+            action_type="reschedule_meeting",
+            meeting_id="m02",
+            new_time="2026-04-21T18:00:00",
+        )
+    )
+    bd = (obs.metadata or {}).get("reward_breakdown") or {}
+    assert float(bd.get("shaping_synergy") or 0.0) == pytest.approx(0.0)
+    assert float(bd.get("shaping_tradeoff") or 0.0) == pytest.approx(0.0)
+    assert float(bd.get("shaping_potential") or 0.0) == pytest.approx(0.0)
 
 
 def test_schema_drift_events_mutate_world():
```
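The new assertions and the `reward_mode="base"` ablation test run with the repo's usual invocation, e.g. `uv run pytest tests/test_phase4.py -q`.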
uv.lock
CHANGED

The diff for this file is too large to render. See the raw diff.