vaibhav12332112312 commited on
Commit
a1be3fe
·
1 Parent(s): 6c01076
Files changed (1) hide show
  1. training/train_grpo.ipynb +73 -61
training/train_grpo.ipynb CHANGED
@@ -23,27 +23,41 @@
23
  },
24
  {
25
  "cell_type": "code",
 
26
  "metadata": {},
 
 
 
 
 
 
 
 
 
 
 
 
27
  "source": [
28
  "# Cell 1: Install dependencies\n",
29
  "!pip install -q torch torchvision torchaudio\n",
30
- "!pip install -q transformers>=4.40.0 accelerate peft>=0.10.0 trl>=0.8.0 datasets bitsandbytes\n",
31
  "!pip install -q matplotlib pandas\n",
32
  "!pip install -q pydantic httpx\n",
33
  "!pip install -q \"openenv-core[core]>=0.2.2\""
34
- ],
35
- "execution_count": null,
36
- "outputs": []
37
  },
38
  {
39
  "cell_type": "code",
 
40
  "metadata": {},
 
41
  "source": [
42
  "# Cell 2: Clone the repo and set up paths\n",
43
  "import os, sys\n",
44
  "REPO_DIR = \"/content/viral-posts-env\"\n",
 
45
  "if not os.path.exists(REPO_DIR):\n",
46
- " !git clone https://github.com/VaibhavKhandare/viral-posts-env.git {REPO_DIR}\n",
47
  "os.chdir(REPO_DIR)\n",
48
  "sys.path.insert(0, REPO_DIR)\n",
49
  "\n",
@@ -51,13 +65,13 @@
51
  "os.makedirs(PLOTS_DIR, exist_ok=True)\n",
52
  "print(f\"Working dir: {os.getcwd()}\")\n",
53
  "print(f\"Plots dir: {PLOTS_DIR}\")"
54
- ],
55
- "execution_count": null,
56
- "outputs": []
57
  },
58
  {
59
  "cell_type": "code",
 
60
  "metadata": {},
 
61
  "source": [
62
  "# Cell 3: Imports\n",
63
  "import json, random, time, textwrap, copy\n",
@@ -84,9 +98,7 @@
84
  "\n",
85
  "print(f\"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}\")\n",
86
  "print(f\"Tags: {len(TAG_POOL)}, Topics: {len(ALL_TOPICS)}, Horizon: {TASK_HORIZON} days\")"
87
- ],
88
- "execution_count": null,
89
- "outputs": []
90
  },
91
  {
92
  "cell_type": "markdown",
@@ -99,7 +111,9 @@
99
  },
100
  {
101
  "cell_type": "code",
 
102
  "metadata": {},
 
103
  "source": [
104
  "# Cell 4: Define heuristic agents + episode runner\n",
105
  "_rng = random.Random(42)\n",
@@ -176,13 +190,13 @@
176
  " \"rewards\": rewards, \"energies\": energies}\n",
177
  "\n",
178
  "print(\"Agents and episode runner defined.\")"
179
- ],
180
- "execution_count": null,
181
- "outputs": []
182
  },
183
  {
184
  "cell_type": "code",
 
185
  "metadata": {},
 
186
  "source": [
187
  "# Cell 5: Run baselines\n",
188
  "print(\"Running heuristic baselines (5 agents × 3 tasks)...\")\n",
@@ -205,13 +219,13 @@
205
  "for name in BASELINE_AGENTS:\n",
206
  " scores = [baseline_results[name][t][\"grader_score\"] for t in TASKS]\n",
207
  " print(f\"{name:<14s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {sum(scores)/3:>8.4f}\")"
208
- ],
209
- "execution_count": null,
210
- "outputs": []
211
  },
212
  {
213
  "cell_type": "code",
 
214
  "metadata": {},
 
215
  "source": [
216
  "# Cell 6: Baseline plots\n",
217
  "fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n",
@@ -229,9 +243,7 @@
229
  "fig.tight_layout()\n",
230
  "fig.savefig(f\"{PLOTS_DIR}/baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n",
231
  "plt.show()"
232
- ],
233
- "execution_count": null,
234
- "outputs": []
235
  },
236
  {
237
  "cell_type": "markdown",
@@ -244,7 +256,9 @@
244
  },
245
  {
246
  "cell_type": "code",
 
247
  "metadata": {},
 
248
  "source": [
249
  "# Cell 7: Load model\n",
250
  "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
@@ -268,13 +282,13 @@
268
  "model.eval()\n",
269
  "print(f\"Model loaded. Device: {model.device}\")\n",
270
  "print(f\"Memory: {torch.cuda.memory_allocated()/1e9:.1f} GB\")"
271
- ],
272
- "execution_count": null,
273
- "outputs": []
274
  },
275
  {
276
  "cell_type": "code",
 
277
  "metadata": {},
 
278
  "source": [
279
  "# Cell 8: LLM agent functions\n",
280
  "SYSTEM_PROMPT = textwrap.dedent(\"\"\"\\\n",
@@ -390,9 +404,7 @@
390
  " \"burned_out\": obs.creator_energy <= 0}\n",
391
  "\n",
392
  "print(\"LLM agent functions defined.\")"
393
- ],
394
- "execution_count": null,
395
- "outputs": []
396
  },
397
  {
398
  "cell_type": "markdown",
@@ -405,7 +417,9 @@
405
  },
406
  {
407
  "cell_type": "code",
 
408
  "metadata": {},
 
409
  "source": [
410
  "# Cell 9: Run untrained model\n",
411
  "print(\"Running UNTRAINED base model on all tasks...\")\n",
@@ -422,9 +436,7 @@
422
  "print(\"BEFORE TRAINING:\")\n",
423
  "for t in TASKS:\n",
424
  " print(f\" {t}: grader={before_results[t]['grader_score']:.4f}\")"
425
- ],
426
- "execution_count": null,
427
- "outputs": []
428
  },
429
  {
430
  "cell_type": "markdown",
@@ -443,7 +455,9 @@
443
  },
444
  {
445
  "cell_type": "code",
 
446
  "metadata": {},
 
447
  "source": [
448
  "# Cell 10: Attach LoRA adapter\n",
449
  "from peft import LoraConfig, get_peft_model, TaskType\n",
@@ -458,13 +472,13 @@
458
  "model.enable_input_require_grads()\n",
459
  "peft_model = get_peft_model(model, lora_config)\n",
460
  "peft_model.print_trainable_parameters()"
461
- ],
462
- "execution_count": null,
463
- "outputs": []
464
  },
465
  {
466
  "cell_type": "code",
 
467
  "metadata": {},
 
468
  "source": [
469
  "# Cell 11: Training loop\n",
470
  "from trl import SFTTrainer, SFTConfig\n",
@@ -529,14 +543,14 @@
529
  " warmup_steps=5,\n",
530
  " logging_steps=5,\n",
531
  " save_strategy=\"no\",\n",
532
- " max_seq_length=1024,\n",
533
  " fp16=True,\n",
534
  " report_to=\"none\",\n",
535
  " )\n",
536
  "\n",
537
  " peft_model.train()\n",
538
  " trainer = SFTTrainer(\n",
539
- " model=peft_model, tokenizer=tokenizer,\n",
540
  " train_dataset=dataset, args=sft_config,\n",
541
  " )\n",
542
  " train_result = trainer.train()\n",
@@ -555,9 +569,7 @@
555
  "elapsed = time.time() - t_start\n",
556
  "print(f\"\\nTraining complete in {elapsed/60:.1f} min\")\n",
557
  "print(pd.DataFrame(training_log).to_string(index=False))"
558
- ],
559
- "execution_count": null,
560
- "outputs": []
561
  },
562
  {
563
  "cell_type": "markdown",
@@ -570,7 +582,9 @@
570
  },
571
  {
572
  "cell_type": "code",
 
573
  "metadata": {},
 
574
  "source": [
575
  "# Cell 12: Run trained model\n",
576
  "print(\"Running TRAINED model on all tasks...\")\n",
@@ -588,9 +602,7 @@
588
  "print(\"AFTER TRAINING:\")\n",
589
  "for t in TASKS:\n",
590
  " print(f\" {t}: grader={after_results[t]['grader_score']:.4f}\")"
591
- ],
592
- "execution_count": null,
593
- "outputs": []
594
  },
595
  {
596
  "cell_type": "markdown",
@@ -601,7 +613,9 @@
601
  },
602
  {
603
  "cell_type": "code",
 
604
  "metadata": {},
 
605
  "source": [
606
  "# Cell 13: Training curves\n",
607
  "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
@@ -623,13 +637,13 @@
623
  "fig.tight_layout()\n",
624
  "fig.savefig(f'{PLOTS_DIR}/reward_curve.png', dpi=150, bbox_inches='tight')\n",
625
  "plt.show()"
626
- ],
627
- "execution_count": null,
628
- "outputs": []
629
  },
630
  {
631
  "cell_type": "code",
 
632
  "metadata": {},
 
633
  "source": [
634
  "# Cell 14: Before vs After\n",
635
  "task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n",
@@ -659,13 +673,13 @@
659
  "fig.tight_layout()\n",
660
  "fig.savefig(f'{PLOTS_DIR}/before_after.png', dpi=150, bbox_inches='tight')\n",
661
  "plt.show()"
662
- ],
663
- "execution_count": null,
664
- "outputs": []
665
  },
666
  {
667
  "cell_type": "code",
 
668
  "metadata": {},
 
669
  "source": [
670
  "# Cell 15: Trajectory comparison\n",
671
  "fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
@@ -689,9 +703,7 @@
689
  "fig.tight_layout()\n",
690
  "fig.savefig(f'{PLOTS_DIR}/training_trajectories.png', dpi=150, bbox_inches='tight')\n",
691
  "plt.show()"
692
- ],
693
- "execution_count": null,
694
- "outputs": []
695
  },
696
  {
697
  "cell_type": "markdown",
@@ -702,7 +714,9 @@
702
  },
703
  {
704
  "cell_type": "code",
 
705
  "metadata": {},
 
706
  "source": [
707
  "# Cell 16: Final summary\n",
708
  "print(\"=\" * 67)\n",
@@ -739,13 +753,13 @@
739
  "\n",
740
  "print(f\"\\nSaved to {PLOTS_DIR}/\")\n",
741
  "print(\"All results are from real LoRA weight updates on real environment runs.\")"
742
- ],
743
- "execution_count": null,
744
- "outputs": []
745
  },
746
  {
747
  "cell_type": "code",
 
748
  "metadata": {},
 
749
  "source": [
750
  "# Cell 17: Save adapter\n",
751
  "save_path = \"./viraltest_trained_adapter\"\n",
@@ -753,24 +767,22 @@
753
  "tokenizer.save_pretrained(save_path)\n",
754
  "print(f\"LoRA adapter saved to {save_path}\")\n",
755
  "print(\"Load with: PeftModel.from_pretrained(base_model, save_path)\")"
756
- ],
757
- "execution_count": null,
758
- "outputs": []
759
  }
760
  ],
761
  "metadata": {
 
 
762
  "kernelspec": {
763
- "display_name": "Python 3",
764
  "language": "python",
765
  "name": "python3"
766
  },
767
  "language_info": {
768
  "name": "python",
769
- "version": "3.10.0"
770
- },
771
- "accelerator": "GPU",
772
- "gpuClass": "standard"
773
  },
774
  "nbformat": 4,
775
  "nbformat_minor": 4
776
- }
 
23
  },
24
  {
25
  "cell_type": "code",
26
+ "execution_count": null,
27
  "metadata": {},
28
+ "outputs": [
29
+ {
30
+ "ename": "",
31
+ "evalue": "",
32
+ "output_type": "error",
33
+ "traceback": [
34
+ "\u001b[1;31mRunning cells with '.venv (Python 3.13.1)' requires the ipykernel package.\n",
35
+ "\u001b[1;31mInstall 'ipykernel' into the Python environment. \n",
36
+ "\u001b[1;31mCommand: '/Users/vaibhavkhandare/Projects/mernstack/openenv-course/viraltest/.venv/bin/python -m pip install ipykernel -U --force-reinstall'"
37
+ ]
38
+ }
39
+ ],
40
  "source": [
41
  "# Cell 1: Install dependencies\n",
42
  "!pip install -q torch torchvision torchaudio\n",
43
+ "!pip install -q transformers>=4.45.0 accelerate peft>=0.10.0 trl>=0.20.0 datasets bitsandbytes\n",
44
  "!pip install -q matplotlib pandas\n",
45
  "!pip install -q pydantic httpx\n",
46
  "!pip install -q \"openenv-core[core]>=0.2.2\""
47
+ ]
 
 
48
  },
49
  {
50
  "cell_type": "code",
51
+ "execution_count": null,
52
  "metadata": {},
53
+ "outputs": [],
54
  "source": [
55
  "# Cell 2: Clone the repo and set up paths\n",
56
  "import os, sys\n",
57
  "REPO_DIR = \"/content/viral-posts-env\"\n",
58
+ "REPO_BRANCH = \"hack1\"\n",
59
  "if not os.path.exists(REPO_DIR):\n",
60
+ " !git clone --branch {REPO_BRANCH} --depth 1 https://github.com/VaibhavKhandare/viral-posts-env.git {REPO_DIR}\n",
61
  "os.chdir(REPO_DIR)\n",
62
  "sys.path.insert(0, REPO_DIR)\n",
63
  "\n",
 
65
  "os.makedirs(PLOTS_DIR, exist_ok=True)\n",
66
  "print(f\"Working dir: {os.getcwd()}\")\n",
67
  "print(f\"Plots dir: {PLOTS_DIR}\")"
68
+ ]
 
 
69
  },
70
  {
71
  "cell_type": "code",
72
+ "execution_count": null,
73
  "metadata": {},
74
+ "outputs": [],
75
  "source": [
76
  "# Cell 3: Imports\n",
77
  "import json, random, time, textwrap, copy\n",
 
98
  "\n",
99
  "print(f\"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}\")\n",
100
  "print(f\"Tags: {len(TAG_POOL)}, Topics: {len(ALL_TOPICS)}, Horizon: {TASK_HORIZON} days\")"
101
+ ]
 
 
102
  },
103
  {
104
  "cell_type": "markdown",
 
111
  },
112
  {
113
  "cell_type": "code",
114
+ "execution_count": null,
115
  "metadata": {},
116
+ "outputs": [],
117
  "source": [
118
  "# Cell 4: Define heuristic agents + episode runner\n",
119
  "_rng = random.Random(42)\n",
 
190
  " \"rewards\": rewards, \"energies\": energies}\n",
191
  "\n",
192
  "print(\"Agents and episode runner defined.\")"
193
+ ]
 
 
194
  },
195
  {
196
  "cell_type": "code",
197
+ "execution_count": null,
198
  "metadata": {},
199
+ "outputs": [],
200
  "source": [
201
  "# Cell 5: Run baselines\n",
202
  "print(\"Running heuristic baselines (5 agents × 3 tasks)...\")\n",
 
219
  "for name in BASELINE_AGENTS:\n",
220
  " scores = [baseline_results[name][t][\"grader_score\"] for t in TASKS]\n",
221
  " print(f\"{name:<14s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {sum(scores)/3:>8.4f}\")"
222
+ ]
 
 
223
  },
224
  {
225
  "cell_type": "code",
226
+ "execution_count": null,
227
  "metadata": {},
228
+ "outputs": [],
229
  "source": [
230
  "# Cell 6: Baseline plots\n",
231
  "fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n",
 
243
  "fig.tight_layout()\n",
244
  "fig.savefig(f\"{PLOTS_DIR}/baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n",
245
  "plt.show()"
246
+ ]
 
 
247
  },
248
  {
249
  "cell_type": "markdown",
 
256
  },
257
  {
258
  "cell_type": "code",
259
+ "execution_count": null,
260
  "metadata": {},
261
+ "outputs": [],
262
  "source": [
263
  "# Cell 7: Load model\n",
264
  "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
 
282
  "model.eval()\n",
283
  "print(f\"Model loaded. Device: {model.device}\")\n",
284
  "print(f\"Memory: {torch.cuda.memory_allocated()/1e9:.1f} GB\")"
285
+ ]
 
 
286
  },
287
  {
288
  "cell_type": "code",
289
+ "execution_count": null,
290
  "metadata": {},
291
+ "outputs": [],
292
  "source": [
293
  "# Cell 8: LLM agent functions\n",
294
  "SYSTEM_PROMPT = textwrap.dedent(\"\"\"\\\n",
 
404
  " \"burned_out\": obs.creator_energy <= 0}\n",
405
  "\n",
406
  "print(\"LLM agent functions defined.\")"
407
+ ]
 
 
408
  },
409
  {
410
  "cell_type": "markdown",
 
417
  },
418
  {
419
  "cell_type": "code",
420
+ "execution_count": null,
421
  "metadata": {},
422
+ "outputs": [],
423
  "source": [
424
  "# Cell 9: Run untrained model\n",
425
  "print(\"Running UNTRAINED base model on all tasks...\")\n",
 
436
  "print(\"BEFORE TRAINING:\")\n",
437
  "for t in TASKS:\n",
438
  " print(f\" {t}: grader={before_results[t]['grader_score']:.4f}\")"
439
+ ]
 
 
440
  },
441
  {
442
  "cell_type": "markdown",
 
455
  },
456
  {
457
  "cell_type": "code",
458
+ "execution_count": null,
459
  "metadata": {},
460
+ "outputs": [],
461
  "source": [
462
  "# Cell 10: Attach LoRA adapter\n",
463
  "from peft import LoraConfig, get_peft_model, TaskType\n",
 
472
  "model.enable_input_require_grads()\n",
473
  "peft_model = get_peft_model(model, lora_config)\n",
474
  "peft_model.print_trainable_parameters()"
475
+ ]
 
 
476
  },
477
  {
478
  "cell_type": "code",
479
+ "execution_count": null,
480
  "metadata": {},
481
+ "outputs": [],
482
  "source": [
483
  "# Cell 11: Training loop\n",
484
  "from trl import SFTTrainer, SFTConfig\n",
 
543
  " warmup_steps=5,\n",
544
  " logging_steps=5,\n",
545
  " save_strategy=\"no\",\n",
546
+ " max_length=1024,\n",
547
  " fp16=True,\n",
548
  " report_to=\"none\",\n",
549
  " )\n",
550
  "\n",
551
  " peft_model.train()\n",
552
  " trainer = SFTTrainer(\n",
553
+ " model=peft_model, processing_class=tokenizer,\n",
554
  " train_dataset=dataset, args=sft_config,\n",
555
  " )\n",
556
  " train_result = trainer.train()\n",
 
569
  "elapsed = time.time() - t_start\n",
570
  "print(f\"\\nTraining complete in {elapsed/60:.1f} min\")\n",
571
  "print(pd.DataFrame(training_log).to_string(index=False))"
572
+ ]
 
 
573
  },
574
  {
575
  "cell_type": "markdown",
 
582
  },
583
  {
584
  "cell_type": "code",
585
+ "execution_count": null,
586
  "metadata": {},
587
+ "outputs": [],
588
  "source": [
589
  "# Cell 12: Run trained model\n",
590
  "print(\"Running TRAINED model on all tasks...\")\n",
 
602
  "print(\"AFTER TRAINING:\")\n",
603
  "for t in TASKS:\n",
604
  " print(f\" {t}: grader={after_results[t]['grader_score']:.4f}\")"
605
+ ]
 
 
606
  },
607
  {
608
  "cell_type": "markdown",
 
613
  },
614
  {
615
  "cell_type": "code",
616
+ "execution_count": null,
617
  "metadata": {},
618
+ "outputs": [],
619
  "source": [
620
  "# Cell 13: Training curves\n",
621
  "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
 
637
  "fig.tight_layout()\n",
638
  "fig.savefig(f'{PLOTS_DIR}/reward_curve.png', dpi=150, bbox_inches='tight')\n",
639
  "plt.show()"
640
+ ]
 
 
641
  },
642
  {
643
  "cell_type": "code",
644
+ "execution_count": null,
645
  "metadata": {},
646
+ "outputs": [],
647
  "source": [
648
  "# Cell 14: Before vs After\n",
649
  "task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n",
 
673
  "fig.tight_layout()\n",
674
  "fig.savefig(f'{PLOTS_DIR}/before_after.png', dpi=150, bbox_inches='tight')\n",
675
  "plt.show()"
676
+ ]
 
 
677
  },
678
  {
679
  "cell_type": "code",
680
+ "execution_count": null,
681
  "metadata": {},
682
+ "outputs": [],
683
  "source": [
684
  "# Cell 15: Trajectory comparison\n",
685
  "fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
 
703
  "fig.tight_layout()\n",
704
  "fig.savefig(f'{PLOTS_DIR}/training_trajectories.png', dpi=150, bbox_inches='tight')\n",
705
  "plt.show()"
706
+ ]
 
 
707
  },
708
  {
709
  "cell_type": "markdown",
 
714
  },
715
  {
716
  "cell_type": "code",
717
+ "execution_count": null,
718
  "metadata": {},
719
+ "outputs": [],
720
  "source": [
721
  "# Cell 16: Final summary\n",
722
  "print(\"=\" * 67)\n",
 
753
  "\n",
754
  "print(f\"\\nSaved to {PLOTS_DIR}/\")\n",
755
  "print(\"All results are from real LoRA weight updates on real environment runs.\")"
756
+ ]
 
 
757
  },
758
  {
759
  "cell_type": "code",
760
+ "execution_count": null,
761
  "metadata": {},
762
+ "outputs": [],
763
  "source": [
764
  "# Cell 17: Save adapter\n",
765
  "save_path = \"./viraltest_trained_adapter\"\n",
 
767
  "tokenizer.save_pretrained(save_path)\n",
768
  "print(f\"LoRA adapter saved to {save_path}\")\n",
769
  "print(\"Load with: PeftModel.from_pretrained(base_model, save_path)\")"
770
+ ]
 
 
771
  }
772
  ],
773
  "metadata": {
774
+ "accelerator": "GPU",
775
+ "gpuClass": "standard",
776
  "kernelspec": {
777
+ "display_name": ".venv",
778
  "language": "python",
779
  "name": "python3"
780
  },
781
  "language_info": {
782
  "name": "python",
783
+ "version": "3.13.1"
784
+ }
 
 
785
  },
786
  "nbformat": 4,
787
  "nbformat_minor": 4
788
+ }