Spaces:
Paused
Paused
Commit ·
3326716
1
Parent(s): e955a2d
train(grpo): unified hint prompt, no-history chat, positive-advantage filter
Browse files- env: surface audience_active_hours + competitor_recent_post_hours in obs metadata
- prompt: single audience-hours hint, same for train + eval (clean delta = LoRA only)
- runner: drop assistant history (kills 4712-tok bloat); never append synthetic rest into training pairs; carry step idx for return back-up
- decode: greedy at eval, sampled (T=1.0, top_p=0.95) at rollout
- filter: positive group-relative advantage only; QUALITY_FLOOR=0.40 skips bad rounds
- LoRA: r=8 attn-only; lr 5e-6, 1 epoch, max_len 2048 (less drift)
Made-with: Cursor
- server/viraltest_environment.py +13 -0
- training/train_grpo.ipynb +65 -50
server/viraltest_environment.py
CHANGED
|
@@ -1097,6 +1097,19 @@ class ViraltestEnvironment(Environment):
|
|
| 1097 |
if grader_score is not None:
|
| 1098 |
meta["grader_score"] = round(grader_score, 4)
|
| 1099 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1100 |
burnout_risk = min(1.0, self._low_energy_days / 5.0)
|
| 1101 |
|
| 1102 |
return ViraltestObservation(
|
|
|
|
| 1097 |
if grader_score is not None:
|
| 1098 |
meta["grader_score"] = round(grader_score, 4)
|
| 1099 |
|
| 1100 |
+
audience_hours: set = set()
|
| 1101 |
+
for seg in _AUDIENCE_DATA.get("segments", []):
|
| 1102 |
+
audience_hours.update(seg.get("active_hours", []))
|
| 1103 |
+
meta["audience_active_hours"] = sorted(audience_hours)
|
| 1104 |
+
|
| 1105 |
+
comp_hours = [
|
| 1106 |
+
(self._hour - p["hours_ago"]) % 24
|
| 1107 |
+
for comp in self._competitors
|
| 1108 |
+
for p in comp.recent_posts
|
| 1109 |
+
if p["hours_ago"] < 48
|
| 1110 |
+
]
|
| 1111 |
+
meta["competitor_recent_post_hours"] = sorted(comp_hours)
|
| 1112 |
+
|
| 1113 |
burnout_risk = min(1.0, self._low_energy_days / 5.0)
|
| 1114 |
|
| 1115 |
return ViraltestObservation(
|
training/train_grpo.ipynb
CHANGED
|
@@ -400,7 +400,7 @@
|
|
| 400 |
"metadata": {},
|
| 401 |
"source": [
|
| 402 |
"# Cell 8: LLM agent functions\n",
|
| 403 |
-
"
|
| 404 |
"You are an Instagram content strategy agent. Each step is one day.\n",
|
| 405 |
"You manage a creator account over a 15-day cycle.\n",
|
| 406 |
"\n",
|
|
@@ -439,6 +439,12 @@
|
|
| 439 |
"- topic: free-form string\n",
|
| 440 |
"- empty scheduled_actions = full day rest\"\"\")\n",
|
| 441 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 442 |
"\n",
|
| 443 |
"def format_obs(obs):\n",
|
| 444 |
" days = [\"Mon\", \"Tue\", \"Wed\", \"Thu\", \"Fri\", \"Sat\", \"Sun\"]\n",
|
|
@@ -449,6 +455,9 @@
|
|
| 449 |
" signals_str = (f\"Signals: watch={signals.watch_time:.3f} \"\n",
|
| 450 |
" f\"sends={signals.sends_per_reach:.3f} \"\n",
|
| 451 |
" f\"saves={signals.saves:.3f}\\n\")\n",
|
|
|
|
|
|
|
|
|
|
| 452 |
" tool_str = \"\"\n",
|
| 453 |
" for tr in getattr(obs, \"tool_results\", []):\n",
|
| 454 |
" if tr.success:\n",
|
|
@@ -459,8 +468,10 @@
|
|
| 459 |
" f\"Energy: {obs.creator_energy:.2f} | Followers: {obs.follower_count}\\n\"\n",
|
| 460 |
" f\"Engagement: {obs.engagement_rate:.3f} | Queue: {obs.content_queue_size}\\n\"\n",
|
| 461 |
" f\"{signals_str}\"\n",
|
|
|
|
|
|
|
| 462 |
" f\"Tool results:\\n{tool_str}\"\n",
|
| 463 |
-
" f\"Plan
|
| 464 |
"\n",
|
| 465 |
"\n",
|
| 466 |
"def is_well_formed_response(text):\n",
|
|
@@ -527,35 +538,37 @@
|
|
| 527 |
" return torch.device(\"cpu\")\n",
|
| 528 |
"\n",
|
| 529 |
"\n",
|
| 530 |
-
"def _build_chat(
|
| 531 |
-
"
|
| 532 |
-
"
|
| 533 |
-
"
|
| 534 |
-
"
|
| 535 |
"\n",
|
| 536 |
"\n",
|
| 537 |
-
"def _batched_generate(mdl, tok, prompts,
|
| 538 |
" enc = tok(prompts, return_tensors=\"pt\", padding=True, truncation=False).to(_infer_model_device(mdl))\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 539 |
" with torch.no_grad():\n",
|
| 540 |
-
" out = mdl.generate(\n",
|
| 541 |
-
" **enc, max_new_tokens=max_new_tokens, temperature=temperature,\n",
|
| 542 |
-
" do_sample=True, top_p=0.9, pad_token_id=tok.pad_token_id,\n",
|
| 543 |
-
" )\n",
|
| 544 |
" resps = tok.batch_decode(out[:, enc[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
|
| 545 |
" return resps, enc[\"input_ids\"].shape[1]\n",
|
| 546 |
"\n",
|
| 547 |
"\n",
|
| 548 |
-
"def run_llm_episodes_batched(mdl, tok, tasks_seeds, verbose=True):\n",
|
| 549 |
" \"\"\"Run N episodes in parallel. tasks_seeds: list of (task, seed). One batched generate per day.\"\"\"\n",
|
|
|
|
| 550 |
" n = len(tasks_seeds)\n",
|
| 551 |
" envs = [ViraltestEnvironment() for _ in range(n)]\n",
|
| 552 |
" obss = [envs[i].reset(task=t, seed=s) for i, (t, s) in enumerate(tasks_seeds)]\n",
|
| 553 |
-
" histories = [[] for _ in range(n)]\n",
|
| 554 |
" rewards = [[] for _ in range(n)]\n",
|
| 555 |
" energies = [[obs.creator_energy] for obs in obss]\n",
|
| 556 |
" pairs = [[] for _ in range(n)]\n",
|
| 557 |
" done_mask = [obs.done for obs in obss]\n",
|
| 558 |
-
"
|
| 559 |
"\n",
|
| 560 |
" for day in range(1, TASK_HORIZON + 1):\n",
|
| 561 |
" active = [i for i in range(n) if not done_mask[i] and obss[i].creator_energy > 0.25]\n",
|
|
@@ -563,33 +576,26 @@
|
|
| 563 |
" if not active and not rest:\n",
|
| 564 |
" break\n",
|
| 565 |
"\n",
|
| 566 |
-
"
|
| 567 |
" if active:\n",
|
| 568 |
" prompts = [format_obs(obss[i]) for i in active]\n",
|
| 569 |
-
" chats = [_build_chat(
|
| 570 |
" texts = [tok.apply_chat_template(c, tokenize=False, add_generation_prompt=True) for c in chats]\n",
|
| 571 |
-
" resps, ptok = _batched_generate(mdl, tok, texts)\n",
|
| 572 |
" if verbose:\n",
|
| 573 |
" print(f\" D{day:2d}: batch={len(active)} rest={len(rest)} prompt_tok={ptok}\")\n",
|
| 574 |
" for j, i in enumerate(active):\n",
|
| 575 |
-
"
|
| 576 |
-
"
|
| 577 |
-
"
|
| 578 |
"\n",
|
| 579 |
" for i in range(n):\n",
|
| 580 |
-
" if done_mask[i] or i not in
|
| 581 |
" continue\n",
|
| 582 |
-
"
|
| 583 |
-
" action = parse_model_output(resp)\n",
|
| 584 |
-
" pairs[i].append({\"prompt\": prompt, \"response\": resp})\n",
|
| 585 |
-
" obss[i] = envs[i].step(action)\n",
|
| 586 |
" r = obss[i].reward or 0.0\n",
|
| 587 |
" rewards[i].append(r)\n",
|
| 588 |
" energies[i].append(obss[i].creator_energy)\n",
|
| 589 |
-
" histories[i].extend([\n",
|
| 590 |
-
" {\"role\": \"user\", \"content\": prompt},\n",
|
| 591 |
-
" {\"role\": \"assistant\", \"content\": resp},\n",
|
| 592 |
-
" ])\n",
|
| 593 |
" if obss[i].done:\n",
|
| 594 |
" done_mask[i] = True\n",
|
| 595 |
"\n",
|
|
@@ -602,8 +608,9 @@
|
|
| 602 |
" for t in reversed(range(len(rewards[i]))):\n",
|
| 603 |
" G = rewards[i][t] + GAMMA * G\n",
|
| 604 |
" rets[t] = G\n",
|
| 605 |
-
" for
|
| 606 |
-
" pr
|
|
|
|
| 607 |
" results.append({\n",
|
| 608 |
" \"task\": task, \"seed\": seed, \"grader_score\": gs,\n",
|
| 609 |
" \"total_reward\": sum(rewards[i]), \"final_energy\": obss[i].creator_energy,\n",
|
|
@@ -641,7 +648,7 @@
|
|
| 641 |
"print(\"=\" * 60)\n",
|
| 642 |
"\n",
|
| 643 |
"t0 = time.time()\n",
|
| 644 |
-
"results = run_llm_episodes_batched(model, tokenizer, [(t, 42) for t in TASKS], verbose=True)\n",
|
| 645 |
"before_results = {r[\"task\"]: r for r in results}\n",
|
| 646 |
"\n",
|
| 647 |
"print(\"\\n\" + \"=\" * 60)\n",
|
|
@@ -675,9 +682,8 @@
|
|
| 675 |
"from peft import LoraConfig, get_peft_model, TaskType\n",
|
| 676 |
"\n",
|
| 677 |
"lora_config = LoraConfig(\n",
|
| 678 |
-
" r=
|
| 679 |
-
" target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
|
| 680 |
-
" \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
|
| 681 |
" task_type=TaskType.CAUSAL_LM, bias=\"none\",\n",
|
| 682 |
")\n",
|
| 683 |
"\n",
|
|
@@ -698,7 +704,7 @@
|
|
| 698 |
"\n",
|
| 699 |
"NUM_ROUNDS = 4\n",
|
| 700 |
"EPISODES_PER_ROUND = 6\n",
|
| 701 |
-
"
|
| 702 |
"\n",
|
| 703 |
"training_log = {\n",
|
| 704 |
" \"round\": [], \"avg_episode_reward\": [], \"max_episode_reward\": [],\n",
|
|
@@ -716,7 +722,8 @@
|
|
| 716 |
" peft_model.eval()\n",
|
| 717 |
" tasks_seeds = [(TASKS[ep % len(TASKS)], 42 + (round_idx - 1) * 100 + ep) for ep in range(EPISODES_PER_ROUND)]\n",
|
| 718 |
" t_roll = time.time()\n",
|
| 719 |
-
" results = run_llm_episodes_batched(peft_model, tokenizer, tasks_seeds, verbose=False
|
|
|
|
| 720 |
" print(f\" Rollouts: {len(results)} eps × {TASK_HORIZON} days in {time.time()-t_roll:.1f}s\")\n",
|
| 721 |
"\n",
|
| 722 |
" all_pairs, episode_rewards, episode_graders = [], [], []\n",
|
|
@@ -728,7 +735,7 @@
|
|
| 728 |
" for pr in result[\"pairs\"]:\n",
|
| 729 |
" if not is_well_formed_response(pr[\"response\"]):\n",
|
| 730 |
" continue\n",
|
| 731 |
-
" text = (f\"<|im_start|>system\\n{
|
| 732 |
" f\"<|im_start|>user\\n{pr['prompt']}<|im_end|>\\n\"\n",
|
| 733 |
" f\"<|im_start|>assistant\\n{pr['response']}<|im_end|>\")\n",
|
| 734 |
" all_pairs.append({\"text\": text, \"reward\": pr[\"return\"]})\n",
|
|
@@ -738,28 +745,36 @@
|
|
| 738 |
"\n",
|
| 739 |
" avg_r = float(np.mean(episode_rewards))\n",
|
| 740 |
" avg_g = float(np.mean(episode_graders))\n",
|
| 741 |
-
"
|
|
|
|
| 742 |
" if not all_pairs:\n",
|
| 743 |
" print(\" WARNING: 0 well-formed pairs collected; skipping SFT.\")\n",
|
| 744 |
" continue\n",
|
|
|
|
|
|
|
|
|
|
| 745 |
"\n",
|
| 746 |
-
"
|
| 747 |
-
"
|
| 748 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 749 |
"\n",
|
| 750 |
" dataset = Dataset.from_list([{\"text\": p[\"text\"]} for p in filtered])\n",
|
| 751 |
"\n",
|
| 752 |
" # SFT training (real gradient updates)\n",
|
| 753 |
" sft_config = SFTConfig(\n",
|
| 754 |
" output_dir=f\"./checkpoints/round_{round_idx}\",\n",
|
| 755 |
-
" num_train_epochs=
|
| 756 |
-
" per_device_train_batch_size=
|
| 757 |
-
" gradient_accumulation_steps=
|
| 758 |
-
" learning_rate=
|
| 759 |
-
"
|
| 760 |
" logging_steps=1,\n",
|
| 761 |
" save_strategy=\"no\",\n",
|
| 762 |
-
" max_length=
|
| 763 |
" bf16=True,\n",
|
| 764 |
" report_to=\"none\",\n",
|
| 765 |
" )\n",
|
|
@@ -808,7 +823,7 @@
|
|
| 808 |
"\n",
|
| 809 |
"peft_model.eval()\n",
|
| 810 |
"t0 = time.time()\n",
|
| 811 |
-
"results = run_llm_episodes_batched(peft_model, tokenizer, [(t, 42) for t in TASKS], verbose=True)\n",
|
| 812 |
"after_results = {r[\"task\"]: r for r in results}\n",
|
| 813 |
"\n",
|
| 814 |
"print(\"\\n\" + \"=\" * 60)\n",
|
|
|
|
| 400 |
"metadata": {},
|
| 401 |
"source": [
|
| 402 |
"# Cell 8: LLM agent functions\n",
|
| 403 |
+
"_SYSTEM_BASE = textwrap.dedent(\"\"\"\\\n",
|
| 404 |
"You are an Instagram content strategy agent. Each step is one day.\n",
|
| 405 |
"You manage a creator account over a 15-day cycle.\n",
|
| 406 |
"\n",
|
|
|
|
| 439 |
"- topic: free-form string\n",
|
| 440 |
"- empty scheduled_actions = full day rest\"\"\")\n",
|
| 441 |
"\n",
|
| 442 |
+
"SYSTEM_PROMPT = _SYSTEM_BASE + textwrap.dedent(\"\"\"\n",
|
| 443 |
+
"\n",
|
| 444 |
+
"HINT: schedule posts during/just before the audience_active_hours window — that is when your target users are online.\"\"\")\n",
|
| 445 |
+
"SYSTEM_PROMPT_EVAL = SYSTEM_PROMPT\n",
|
| 446 |
+
"SYSTEM_PROMPT_TRAIN = SYSTEM_PROMPT\n",
|
| 447 |
+
"\n",
|
| 448 |
"\n",
|
| 449 |
"def format_obs(obs):\n",
|
| 450 |
" days = [\"Mon\", \"Tue\", \"Wed\", \"Thu\", \"Fri\", \"Sat\", \"Sun\"]\n",
|
|
|
|
| 455 |
" signals_str = (f\"Signals: watch={signals.watch_time:.3f} \"\n",
|
| 456 |
" f\"sends={signals.sends_per_reach:.3f} \"\n",
|
| 457 |
" f\"saves={signals.saves:.3f}\\n\")\n",
|
| 458 |
+
" meta = getattr(obs, \"metadata\", None) or {}\n",
|
| 459 |
+
" aud = meta.get(\"audience_active_hours\") or []\n",
|
| 460 |
+
" comp = meta.get(\"competitor_recent_post_hours\") or []\n",
|
| 461 |
" tool_str = \"\"\n",
|
| 462 |
" for tr in getattr(obs, \"tool_results\", []):\n",
|
| 463 |
" if tr.success:\n",
|
|
|
|
| 468 |
" f\"Energy: {obs.creator_energy:.2f} | Followers: {obs.follower_count}\\n\"\n",
|
| 469 |
" f\"Engagement: {obs.engagement_rate:.3f} | Queue: {obs.content_queue_size}\\n\"\n",
|
| 470 |
" f\"{signals_str}\"\n",
|
| 471 |
+
" f\"audience_active_hours: {aud}\\n\"\n",
|
| 472 |
+
" f\"competitor_recent_post_hours: {comp}\\n\"\n",
|
| 473 |
" f\"Tool results:\\n{tool_str}\"\n",
|
| 474 |
+
" f\"Plan today's actions (JSON only):\")\n",
|
| 475 |
"\n",
|
| 476 |
"\n",
|
| 477 |
"def is_well_formed_response(text):\n",
|
|
|
|
| 538 |
" return torch.device(\"cpu\")\n",
|
| 539 |
"\n",
|
| 540 |
"\n",
|
| 541 |
+
"def _build_chat(system, prompt):\n",
|
| 542 |
+
" return [\n",
|
| 543 |
+
" {\"role\": \"system\", \"content\": system},\n",
|
| 544 |
+
" {\"role\": \"user\", \"content\": prompt},\n",
|
| 545 |
+
" ]\n",
|
| 546 |
"\n",
|
| 547 |
"\n",
|
| 548 |
+
"def _batched_generate(mdl, tok, prompts, eval=False, max_new_tokens=512):\n",
|
| 549 |
" enc = tok(prompts, return_tensors=\"pt\", padding=True, truncation=False).to(_infer_model_device(mdl))\n",
|
| 550 |
+
" gen_kwargs = dict(max_new_tokens=max_new_tokens, pad_token_id=tok.pad_token_id)\n",
|
| 551 |
+
" if eval:\n",
|
| 552 |
+
" gen_kwargs.update(do_sample=False)\n",
|
| 553 |
+
" else:\n",
|
| 554 |
+
" gen_kwargs.update(do_sample=True, temperature=1.0, top_p=0.95)\n",
|
| 555 |
" with torch.no_grad():\n",
|
| 556 |
+
" out = mdl.generate(**enc, **gen_kwargs)\n",
|
|
|
|
|
|
|
|
|
|
| 557 |
" resps = tok.batch_decode(out[:, enc[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
|
| 558 |
" return resps, enc[\"input_ids\"].shape[1]\n",
|
| 559 |
"\n",
|
| 560 |
"\n",
|
| 561 |
+
"def run_llm_episodes_batched(mdl, tok, tasks_seeds, verbose=True, eval=False, system=None):\n",
|
| 562 |
" \"\"\"Run N episodes in parallel. tasks_seeds: list of (task, seed). One batched generate per day.\"\"\"\n",
|
| 563 |
+
" sys_prompt = system or (SYSTEM_PROMPT_EVAL if eval else SYSTEM_PROMPT_TRAIN)\n",
|
| 564 |
" n = len(tasks_seeds)\n",
|
| 565 |
" envs = [ViraltestEnvironment() for _ in range(n)]\n",
|
| 566 |
" obss = [envs[i].reset(task=t, seed=s) for i, (t, s) in enumerate(tasks_seeds)]\n",
|
|
|
|
| 567 |
" rewards = [[] for _ in range(n)]\n",
|
| 568 |
" energies = [[obs.creator_energy] for obs in obss]\n",
|
| 569 |
" pairs = [[] for _ in range(n)]\n",
|
| 570 |
" done_mask = [obs.done for obs in obss]\n",
|
| 571 |
+
" rest_action = ViraltestAction(scheduled_actions=[])\n",
|
| 572 |
"\n",
|
| 573 |
" for day in range(1, TASK_HORIZON + 1):\n",
|
| 574 |
" active = [i for i in range(n) if not done_mask[i] and obss[i].creator_energy > 0.25]\n",
|
|
|
|
| 576 |
" if not active and not rest:\n",
|
| 577 |
" break\n",
|
| 578 |
"\n",
|
| 579 |
+
" actions_by_idx = {i: rest_action for i in rest}\n",
|
| 580 |
" if active:\n",
|
| 581 |
" prompts = [format_obs(obss[i]) for i in active]\n",
|
| 582 |
+
" chats = [_build_chat(sys_prompt, p) for p in prompts]\n",
|
| 583 |
" texts = [tok.apply_chat_template(c, tokenize=False, add_generation_prompt=True) for c in chats]\n",
|
| 584 |
+
" resps, ptok = _batched_generate(mdl, tok, texts, eval=eval)\n",
|
| 585 |
" if verbose:\n",
|
| 586 |
" print(f\" D{day:2d}: batch={len(active)} rest={len(rest)} prompt_tok={ptok}\")\n",
|
| 587 |
" for j, i in enumerate(active):\n",
|
| 588 |
+
" actions_by_idx[i] = parse_model_output(resps[j])\n",
|
| 589 |
+
" pairs[i].append({\"prompt\": prompts[j], \"response\": resps[j],\n",
|
| 590 |
+
" \"step\": len(rewards[i])})\n",
|
| 591 |
"\n",
|
| 592 |
" for i in range(n):\n",
|
| 593 |
+
" if done_mask[i] or i not in actions_by_idx:\n",
|
| 594 |
" continue\n",
|
| 595 |
+
" obss[i] = envs[i].step(actions_by_idx[i])\n",
|
|
|
|
|
|
|
|
|
|
| 596 |
" r = obss[i].reward or 0.0\n",
|
| 597 |
" rewards[i].append(r)\n",
|
| 598 |
" energies[i].append(obss[i].creator_energy)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 599 |
" if obss[i].done:\n",
|
| 600 |
" done_mask[i] = True\n",
|
| 601 |
"\n",
|
|
|
|
| 608 |
" for t in reversed(range(len(rewards[i]))):\n",
|
| 609 |
" G = rewards[i][t] + GAMMA * G\n",
|
| 610 |
" rets[t] = G\n",
|
| 611 |
+
" for pr in pairs[i]:\n",
|
| 612 |
+
" k = pr.get(\"step\", 0)\n",
|
| 613 |
+
" pr[\"return\"] = rets[k] if 0 <= k < len(rets) else 0.0\n",
|
| 614 |
" results.append({\n",
|
| 615 |
" \"task\": task, \"seed\": seed, \"grader_score\": gs,\n",
|
| 616 |
" \"total_reward\": sum(rewards[i]), \"final_energy\": obss[i].creator_energy,\n",
|
|
|
|
| 648 |
"print(\"=\" * 60)\n",
|
| 649 |
"\n",
|
| 650 |
"t0 = time.time()\n",
|
| 651 |
+
"results = run_llm_episodes_batched(model, tokenizer, [(t, 42) for t in TASKS], verbose=True, eval=True)\n",
|
| 652 |
"before_results = {r[\"task\"]: r for r in results}\n",
|
| 653 |
"\n",
|
| 654 |
"print(\"\\n\" + \"=\" * 60)\n",
|
|
|
|
| 682 |
"from peft import LoraConfig, get_peft_model, TaskType\n",
|
| 683 |
"\n",
|
| 684 |
"lora_config = LoraConfig(\n",
|
| 685 |
+
" r=8, lora_alpha=16, lora_dropout=0.05,\n",
|
| 686 |
+
" target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"],\n",
|
|
|
|
| 687 |
" task_type=TaskType.CAUSAL_LM, bias=\"none\",\n",
|
| 688 |
")\n",
|
| 689 |
"\n",
|
|
|
|
| 704 |
"\n",
|
| 705 |
"NUM_ROUNDS = 4\n",
|
| 706 |
"EPISODES_PER_ROUND = 6\n",
|
| 707 |
+
"QUALITY_FLOOR = 0.40 # skip SFT for the round if no episode beats this grader score\n",
|
| 708 |
"\n",
|
| 709 |
"training_log = {\n",
|
| 710 |
" \"round\": [], \"avg_episode_reward\": [], \"max_episode_reward\": [],\n",
|
|
|
|
| 722 |
" peft_model.eval()\n",
|
| 723 |
" tasks_seeds = [(TASKS[ep % len(TASKS)], 42 + (round_idx - 1) * 100 + ep) for ep in range(EPISODES_PER_ROUND)]\n",
|
| 724 |
" t_roll = time.time()\n",
|
| 725 |
+
" results = run_llm_episodes_batched(peft_model, tokenizer, tasks_seeds, verbose=False,\n",
|
| 726 |
+
" eval=False, system=SYSTEM_PROMPT_TRAIN)\n",
|
| 727 |
" print(f\" Rollouts: {len(results)} eps × {TASK_HORIZON} days in {time.time()-t_roll:.1f}s\")\n",
|
| 728 |
"\n",
|
| 729 |
" all_pairs, episode_rewards, episode_graders = [], [], []\n",
|
|
|
|
| 735 |
" for pr in result[\"pairs\"]:\n",
|
| 736 |
" if not is_well_formed_response(pr[\"response\"]):\n",
|
| 737 |
" continue\n",
|
| 738 |
+
" text = (f\"<|im_start|>system\\n{SYSTEM_PROMPT_TRAIN}<|im_end|>\\n\"\n",
|
| 739 |
" f\"<|im_start|>user\\n{pr['prompt']}<|im_end|>\\n\"\n",
|
| 740 |
" f\"<|im_start|>assistant\\n{pr['response']}<|im_end|>\")\n",
|
| 741 |
" all_pairs.append({\"text\": text, \"reward\": pr[\"return\"]})\n",
|
|
|
|
| 745 |
"\n",
|
| 746 |
" avg_r = float(np.mean(episode_rewards))\n",
|
| 747 |
" avg_g = float(np.mean(episode_graders))\n",
|
| 748 |
+
" max_g = float(max(episode_graders))\n",
|
| 749 |
+
" print(f\" Avg reward={avg_r:.3f} Avg grader={avg_g:.4f} max_grader={max_g:.4f} | pairs={len(all_pairs)}\")\n",
|
| 750 |
" if not all_pairs:\n",
|
| 751 |
" print(\" WARNING: 0 well-formed pairs collected; skipping SFT.\")\n",
|
| 752 |
" continue\n",
|
| 753 |
+
" if max_g < QUALITY_FLOOR:\n",
|
| 754 |
+
" print(f\" SKIP SFT: no episode beat quality_floor={QUALITY_FLOOR:.2f}\")\n",
|
| 755 |
+
" continue\n",
|
| 756 |
"\n",
|
| 757 |
+
" rets = np.array([p[\"reward\"] for p in all_pairs], dtype=float)\n",
|
| 758 |
+
" adv = (rets - rets.mean()) / (rets.std() + 1e-6)\n",
|
| 759 |
+
" filtered = [p for p, a in zip(all_pairs, adv) if a > 0.0]\n",
|
| 760 |
+
" if not filtered:\n",
|
| 761 |
+
" print(\" SKIP SFT: zero positive-advantage samples\")\n",
|
| 762 |
+
" continue\n",
|
| 763 |
+
" print(f\" Kept {len(filtered)}/{len(all_pairs)} positive-advantage samples\")\n",
|
| 764 |
"\n",
|
| 765 |
" dataset = Dataset.from_list([{\"text\": p[\"text\"]} for p in filtered])\n",
|
| 766 |
"\n",
|
| 767 |
" # SFT training (real gradient updates)\n",
|
| 768 |
" sft_config = SFTConfig(\n",
|
| 769 |
" output_dir=f\"./checkpoints/round_{round_idx}\",\n",
|
| 770 |
+
" num_train_epochs=1,\n",
|
| 771 |
+
" per_device_train_batch_size=2,\n",
|
| 772 |
+
" gradient_accumulation_steps=4,\n",
|
| 773 |
+
" learning_rate=5e-6,\n",
|
| 774 |
+
" warmup_steps=5,\n",
|
| 775 |
" logging_steps=1,\n",
|
| 776 |
" save_strategy=\"no\",\n",
|
| 777 |
+
" max_length=2048,\n",
|
| 778 |
" bf16=True,\n",
|
| 779 |
" report_to=\"none\",\n",
|
| 780 |
" )\n",
|
|
|
|
| 823 |
"\n",
|
| 824 |
"peft_model.eval()\n",
|
| 825 |
"t0 = time.time()\n",
|
| 826 |
+
"results = run_llm_episodes_batched(peft_model, tokenizer, [(t, 42) for t in TASKS], verbose=True, eval=True)\n",
|
| 827 |
"after_results = {r[\"task\"]: r for r in results}\n",
|
| 828 |
"\n",
|
| 829 |
"print(\"\\n\" + \"=\" * 60)\n",
|