nvan15 committed
Commit 6bb0065 · verified · 1 Parent(s): b03742a

Batch upload part 2

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete set.
Files changed (50)
  1. .gitattributes +2 -0
  2. examples/subject.ipynb +216 -0
  3. examples/subject_1024.ipynb +216 -0
  4. nl_tasks/README.md +45 -0
  5. nl_tasks/config/commonsense.yaml +44 -0
  6. nl_tasks/config/commonsense_opt.yaml +32 -0
  7. nl_tasks/config/glue.yaml +48 -0
  8. nl_tasks/config/math395.yaml +46 -0
  9. nl_tasks/data/MATH_test.jsonl +0 -0
  10. nl_tasks/data/MetaMathQA-40K/MetaMathQA-40K.json +3 -0
  11. nl_tasks/data/MetaMathQA/MetaMathQA-395K.json +3 -0
  12. nl_tasks/data/gsm8k_test.jsonl +0 -0
  13. nl_tasks/environment.yaml +55 -0
  14. nl_tasks/exps/run_ex01/trainer_state.json +914 -0
  15. nl_tasks/repro.sh +87 -0
  16. nl_tasks/rpeft/__init__.py +43 -0
  17. nl_tasks/rpeft/mapping.py +273 -0
  18. nl_tasks/rpeft/peft_model.py +922 -0
  19. nl_tasks/rpeft/rotation/__init__.py +3 -0
  20. nl_tasks/rpeft/rotation/layer.py +412 -0
  21. nl_tasks/rpeft/rotation/layer_test.py +296 -0
  22. nl_tasks/rpeft/rotation/model.py +392 -0
  23. nl_tasks/rpeft/rotation/rotation_config.py +89 -0
  24. nl_tasks/rpeft/utils/__init__.py +29 -0
  25. nl_tasks/rpeft/utils/adapters_utils.py +19 -0
  26. nl_tasks/rpeft/utils/config.py +220 -0
  27. nl_tasks/rpeft/utils/other.py +160 -0
  28. nl_tasks/rpeft/utils/save_and_load.py +166 -0
  29. nl_tasks/scripts/.nfs80e7f26e00566c630000664a +117 -0
  30. nl_tasks/scripts/.nfs80e7f26e0132942e00006649 +341 -0
  31. nl_tasks/scripts/copy train_cms_reasoning.sh +133 -0
  32. nl_tasks/scripts/down_math_train.sh +14 -0
  33. nl_tasks/scripts/inference.sh +14 -0
  34. nl_tasks/scripts/merge.sh +137 -0
  35. nl_tasks/scripts/merge_100k.sh +100 -0
  36. nl_tasks/scripts/merge_math.sh +31 -0
  37. nl_tasks/scripts/peft_merge.sh +60 -0
  38. nl_tasks/scripts/train_100math.sh +184 -0
  39. nl_tasks/scripts/train_cms_reasoning.sh +260 -0
  40. nl_tasks/scripts/train_initn40k.sh +341 -0
  41. nl_tasks/scripts/train_math.sh +162 -0
  42. nl_tasks/setup.py +28 -0
  43. nl_tasks/src/bb.ipynb +0 -0
  44. nl_tasks/src/cc.ipynb +0 -0
  45. nl_tasks/src/config.py +183 -0
  46. nl_tasks/src/ft_mathQ.py +702 -0
  47. nl_tasks/src/ft_mathR.py +689 -0
  48. nl_tasks/src/merge.py +82 -0
  49. nl_tasks/src/peft_merge.py +82 -0
  50. nl_tasks/src/testLlama.py +702 -0
.gitattributes CHANGED
@@ -51,3 +51,5 @@ assets/ominicontrol_art/DistractedBoyfriend.webp filter=lfs diff=lfs merge=lfs -
51
  assets/ominicontrol_art/PulpFiction.jpg filter=lfs diff=lfs merge=lfs -text
52
  assets/ominicontrol_art/breakingbad.jpg filter=lfs diff=lfs merge=lfs -text
53
  assets/ominicontrol_art/oiiai.png filter=lfs diff=lfs merge=lfs -text
54
+ nl_tasks/data/MetaMathQA-40K/MetaMathQA-40K.json filter=lfs diff=lfs merge=lfs -text
55
+ nl_tasks/data/MetaMathQA/MetaMathQA-395K.json filter=lfs diff=lfs merge=lfs -text
examples/subject.ipynb ADDED
@@ -0,0 +1,216 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "\n",
11
+ "os.chdir(\"..\")"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": null,
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "import torch\n",
21
+ "from diffusers.pipelines import FluxPipeline\n",
22
+ "from PIL import Image\n",
23
+ "\n",
24
+ "from omini.pipeline.flux_omini import Condition, generate, seed_everything"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "pipe = FluxPipeline.from_pretrained(\n",
34
+ " \"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16\n",
35
+ ")\n",
36
+ "pipe = pipe.to(\"cuda\")\n",
37
+ "pipe.load_lora_weights(\n",
38
+ " \"Yuanshi/OminiControl\",\n",
39
+ " weight_name=f\"omini/subject_512.safetensors\",\n",
40
+ " adapter_name=\"subject\",\n",
41
+ ")"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": null,
47
+ "metadata": {},
48
+ "outputs": [],
49
+ "source": [
50
+ "image = Image.open(\"assets/penguin.jpg\").convert(\"RGB\").resize((512, 512))\n",
51
+ "\n",
52
+ "# For this model, the position_delta is (0, 32).\n",
53
+ "# For more details of position_delta, please refer to:\n",
54
+ "# https://github.com/Yuanshi9815/OminiControl/issues/89#issuecomment-2827080344\n",
55
+ "condition = Condition(image, \"subject\", position_delta=(0, 32))\n",
56
+ "\n",
57
+ "prompt = \"On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat.\"\n",
58
+ "\n",
59
+ "\n",
60
+ "seed_everything(0)\n",
61
+ "\n",
62
+ "result_img = generate(\n",
63
+ " pipe,\n",
64
+ " prompt=prompt,\n",
65
+ " conditions=[condition],\n",
66
+ " num_inference_steps=8,\n",
67
+ " height=512,\n",
68
+ " width=512,\n",
69
+ ").images[0]\n",
70
+ "\n",
71
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
72
+ "concat_image.paste(image, (0, 0))\n",
73
+ "concat_image.paste(result_img, (512, 0))\n",
74
+ "concat_image"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": null,
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "image = Image.open(\"assets/tshirt.jpg\").convert(\"RGB\").resize((512, 512))\n",
84
+ "\n",
85
+ "condition = Condition(image, \"subject\", position_delta=(0, 32))\n",
86
+ "\n",
87
+ "prompt = \"On the beach, a lady sits under a beach umbrella. She's wearing this shirt and has a big smile on her face, with her surfboard hehind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple.\"\n",
88
+ "\n",
89
+ "\n",
90
+ "seed_everything()\n",
91
+ "\n",
92
+ "result_img = generate(\n",
93
+ " pipe,\n",
94
+ " prompt=prompt,\n",
95
+ " conditions=[condition],\n",
96
+ " num_inference_steps=8,\n",
97
+ " height=512,\n",
98
+ " width=512,\n",
99
+ ").images[0]\n",
100
+ "\n",
101
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
102
+ "concat_image.paste(condition.condition, (0, 0))\n",
103
+ "concat_image.paste(result_img, (512, 0))\n",
104
+ "concat_image"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": null,
110
+ "metadata": {},
111
+ "outputs": [],
112
+ "source": [
113
+ "image = Image.open(\"assets/rc_car.jpg\").convert(\"RGB\").resize((512, 512))\n",
114
+ "\n",
115
+ "condition = Condition(image, \"subject\", position_delta=(0, 32))\n",
116
+ "\n",
117
+ "prompt = \"A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.\"\n",
118
+ "\n",
119
+ "seed_everything()\n",
120
+ "\n",
121
+ "result_img = generate(\n",
122
+ " pipe,\n",
123
+ " prompt=prompt,\n",
124
+ " conditions=[condition],\n",
125
+ " num_inference_steps=8,\n",
126
+ " height=512,\n",
127
+ " width=512,\n",
128
+ ").images[0]\n",
129
+ "\n",
130
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
131
+ "concat_image.paste(condition.condition, (0, 0))\n",
132
+ "concat_image.paste(result_img, (512, 0))\n",
133
+ "concat_image"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": null,
139
+ "metadata": {},
140
+ "outputs": [],
141
+ "source": [
142
+ "image = Image.open(\"assets/clock.jpg\").convert(\"RGB\").resize((512, 512))\n",
143
+ "\n",
144
+ "condition = Condition(image, \"subject\", position_delta=(0, 32))\n",
145
+ "\n",
146
+ "prompt = \"In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.\"\n",
147
+ "\n",
148
+ "seed_everything()\n",
149
+ "\n",
150
+ "result_img = generate(\n",
151
+ " pipe,\n",
152
+ " prompt=prompt,\n",
153
+ " conditions=[condition],\n",
154
+ " num_inference_steps=8,\n",
155
+ " height=512,\n",
156
+ " width=512,\n",
157
+ ").images[0]\n",
158
+ "\n",
159
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
160
+ "concat_image.paste(condition.condition, (0, 0))\n",
161
+ "concat_image.paste(result_img, (512, 0))\n",
162
+ "concat_image"
163
+ ]
164
+ },
165
+ {
166
+ "cell_type": "code",
167
+ "execution_count": null,
168
+ "metadata": {},
169
+ "outputs": [],
170
+ "source": [
171
+ "image = Image.open(\"assets/oranges.jpg\").convert(\"RGB\").resize((512, 512))\n",
172
+ "\n",
173
+ "condition = Condition(image, \"subject\", position_delta=(0, 32))\n",
174
+ "\n",
175
+ "prompt = \"A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show.\"\n",
176
+ "\n",
177
+ "seed_everything()\n",
178
+ "\n",
179
+ "result_img = generate(\n",
180
+ " pipe,\n",
181
+ " prompt=prompt,\n",
182
+ " conditions=[condition],\n",
183
+ " num_inference_steps=8,\n",
184
+ " height=512,\n",
185
+ " width=512,\n",
186
+ ").images[0]\n",
187
+ "\n",
188
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
189
+ "concat_image.paste(condition.condition, (0, 0))\n",
190
+ "concat_image.paste(result_img, (512, 0))\n",
191
+ "concat_image"
192
+ ]
193
+ }
194
+ ],
195
+ "metadata": {
196
+ "kernelspec": {
197
+ "display_name": "base",
198
+ "language": "python",
199
+ "name": "python3"
200
+ },
201
+ "language_info": {
202
+ "codemirror_mode": {
203
+ "name": "ipython",
204
+ "version": 3
205
+ },
206
+ "file_extension": ".py",
207
+ "mimetype": "text/x-python",
208
+ "name": "python",
209
+ "nbconvert_exporter": "python",
210
+ "pygments_lexer": "ipython3",
211
+ "version": "3.12.3"
212
+ }
213
+ },
214
+ "nbformat": 4,
215
+ "nbformat_minor": 2
216
+ }
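
For quick reference, the notebook above boils down to the short script below. It is a minimal sketch of the first group of cells only, assuming the repository root as working directory and the omini.pipeline.flux_omini helpers imported in the notebook; the remaining cells repeat the same pattern with the other asset images (tshirt, rc_car, clock, oranges).

import torch
from diffusers.pipelines import FluxPipeline
from PIL import Image

from omini.pipeline.flux_omini import Condition, generate, seed_everything

# Load FLUX.1-schnell and attach the 512px subject LoRA from Yuanshi/OminiControl.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights(
    "Yuanshi/OminiControl",
    weight_name="omini/subject_512.safetensors",
    adapter_name="subject",
)

# Wrap the subject image as a Condition; the 512 model uses position_delta=(0, 32).
image = Image.open("assets/penguin.jpg").convert("RGB").resize((512, 512))
condition = Condition(image, "subject", position_delta=(0, 32))

prompt = (
    "On Christmas evening, on a crowded sidewalk, this item sits on the road, "
    "covered in snow and wearing a Christmas hat."
)

seed_everything(0)
result_img = generate(
    pipe,
    prompt=prompt,
    conditions=[condition],
    num_inference_steps=8,
    height=512,
    width=512,
).images[0]

# Paste the condition image and the generation side by side, as the notebook does.
concat_image = Image.new("RGB", (1024, 512))
concat_image.paste(image, (0, 0))
concat_image.paste(result_img, (512, 0))
concat_image.save("subject_512_penguin.jpg")  # output path chosen here, not in the notebook
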
examples/subject_1024.ipynb ADDED
@@ -0,0 +1,216 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "\n",
11
+ "os.chdir(\"..\")"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": null,
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "import torch\n",
21
+ "from diffusers.pipelines import FluxPipeline\n",
22
+ "from PIL import Image\n",
23
+ "\n",
24
+ "from omini.pipeline.flux_omini import Condition, generate, seed_everything"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "pipe = FluxPipeline.from_pretrained(\n",
34
+ " \"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16\n",
35
+ ")\n",
36
+ "pipe = pipe.to(\"cuda\")\n",
37
+ "pipe.load_lora_weights(\n",
38
+ " \"Yuanshi/OminiControl\",\n",
39
+ " weight_name=f\"omini/subject_1024_beta.safetensors\",\n",
40
+ " adapter_name=\"subject\",\n",
41
+ ")"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": null,
47
+ "metadata": {},
48
+ "outputs": [],
49
+ "source": [
50
+ "image = Image.open(\"assets/penguin.jpg\").convert(\"RGB\").resize((512, 512))\n",
51
+ "\n",
52
+ "# For this model, the position_delta is (0, -32).\n",
53
+ "# For more details of position_delta, please refer to:\n",
54
+ "# https://github.com/Yuanshi9815/OminiControl/issues/89#issuecomment-2827080344\n",
55
+ "condition = Condition(image, \"subject\", position_delta=(0, -32))\n",
56
+ "\n",
57
+ "prompt = \"On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat.\"\n",
58
+ "\n",
59
+ "\n",
60
+ "seed_everything(0)\n",
61
+ "\n",
62
+ "result_img = generate(\n",
63
+ " pipe,\n",
64
+ " prompt=prompt,\n",
65
+ " conditions=[condition],\n",
66
+ " num_inference_steps=8,\n",
67
+ " height=1024,\n",
68
+ " width=1024,\n",
69
+ ").images[0]\n",
70
+ "\n",
71
+ "concat_image = Image.new(\"RGB\", (1024 + 512, 1024))\n",
72
+ "concat_image.paste(image, (0, 0))\n",
73
+ "concat_image.paste(result_img, (512, 0))\n",
74
+ "concat_image"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": null,
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "image = Image.open(\"assets/tshirt.jpg\").convert(\"RGB\").resize((512, 512))\n",
84
+ "\n",
85
+ "condition = Condition(image, \"subject\", position_delta=(0, -32))\n",
86
+ "\n",
87
+ "prompt = \"On the beach, a lady sits under a beach umbrella. She's wearing this shirt and has a big smile on her face, with her surfboard hehind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple.\"\n",
88
+ "\n",
89
+ "\n",
90
+ "seed_everything(0)\n",
91
+ "\n",
92
+ "result_img = generate(\n",
93
+ " pipe,\n",
94
+ " prompt=prompt,\n",
95
+ " conditions=[condition],\n",
96
+ " num_inference_steps=8,\n",
97
+ " height=1024,\n",
98
+ " width=1024,\n",
99
+ ").images[0]\n",
100
+ "\n",
101
+ "concat_image = Image.new(\"RGB\", (1024 + 512, 1024))\n",
102
+ "concat_image.paste(image, (0, 0))\n",
103
+ "concat_image.paste(result_img, (512, 0))\n",
104
+ "concat_image"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": null,
110
+ "metadata": {},
111
+ "outputs": [],
112
+ "source": [
113
+ "image = Image.open(\"assets/rc_car.jpg\").convert(\"RGB\").resize((512, 512))\n",
114
+ "\n",
115
+ "condition = Condition(image, \"subject\", position_delta=(0, -32))\n",
116
+ "\n",
117
+ "prompt = \"A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.\"\n",
118
+ "\n",
119
+ "seed_everything()\n",
120
+ "\n",
121
+ "result_img = generate(\n",
122
+ " pipe,\n",
123
+ " prompt=prompt,\n",
124
+ " conditions=[condition],\n",
125
+ " num_inference_steps=8,\n",
126
+ " height=1024,\n",
127
+ " width=1024,\n",
128
+ ").images[0]\n",
129
+ "\n",
130
+ "concat_image = Image.new(\"RGB\", (1024 + 512, 1024))\n",
131
+ "concat_image.paste(image, (0, 0))\n",
132
+ "concat_image.paste(result_img, (512, 0))\n",
133
+ "concat_image"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": null,
139
+ "metadata": {},
140
+ "outputs": [],
141
+ "source": [
142
+ "image = Image.open(\"assets/clock.jpg\").convert(\"RGB\").resize((512, 512))\n",
143
+ "\n",
144
+ "condition = Condition(image, \"subject\", position_delta=(0, -32))\n",
145
+ "\n",
146
+ "prompt = \"In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.\"\n",
147
+ "\n",
148
+ "seed_everything(0)\n",
149
+ "\n",
150
+ "result_img = generate(\n",
151
+ " pipe,\n",
152
+ " prompt=prompt,\n",
153
+ " conditions=[condition],\n",
154
+ " num_inference_steps=8,\n",
155
+ " height=1024,\n",
156
+ " width=1024,\n",
157
+ ").images[0]\n",
158
+ "\n",
159
+ "concat_image = Image.new(\"RGB\", (1024 + 512, 1024))\n",
160
+ "concat_image.paste(image, (0, 0))\n",
161
+ "concat_image.paste(result_img, (512, 0))\n",
162
+ "concat_image"
163
+ ]
164
+ },
165
+ {
166
+ "cell_type": "code",
167
+ "execution_count": null,
168
+ "metadata": {},
169
+ "outputs": [],
170
+ "source": [
171
+ "image = Image.open(\"assets/oranges.jpg\").convert(\"RGB\").resize((512, 512))\n",
172
+ "\n",
173
+ "condition = Condition(image, \"subject\", position_delta=(0, -32))\n",
174
+ "\n",
175
+ "prompt = \"A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show.\"\n",
176
+ "\n",
177
+ "seed_everything()\n",
178
+ "\n",
179
+ "result_img = generate(\n",
180
+ " pipe,\n",
181
+ " prompt=prompt,\n",
182
+ " conditions=[condition],\n",
183
+ " num_inference_steps=8,\n",
184
+ " height=1024,\n",
185
+ " width=1024,\n",
186
+ ").images[0]\n",
187
+ "\n",
188
+ "concat_image = Image.new(\"RGB\", (1024 + 512, 1024))\n",
189
+ "concat_image.paste(image, (0, 0))\n",
190
+ "concat_image.paste(result_img, (512, 0))\n",
191
+ "concat_image"
192
+ ]
193
+ }
194
+ ],
195
+ "metadata": {
196
+ "kernelspec": {
197
+ "display_name": "base",
198
+ "language": "python",
199
+ "name": "python3"
200
+ },
201
+ "language_info": {
202
+ "codemirror_mode": {
203
+ "name": "ipython",
204
+ "version": 3
205
+ },
206
+ "file_extension": ".py",
207
+ "mimetype": "text/x-python",
208
+ "name": "python",
209
+ "nbconvert_exporter": "python",
210
+ "pygments_lexer": "ipython3",
211
+ "version": "3.12.3"
212
+ }
213
+ },
214
+ "nbformat": 4,
215
+ "nbformat_minor": 2
216
+ }
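
examples/subject_1024.ipynb mirrors the notebook above; the only differences visible in this diff are the LoRA weight file (omini/subject_1024_beta.safetensors), the sign of position_delta ((0, -32) instead of (0, 32)), the generation size (height=width=1024), and the 1536x1024 canvas used for the side-by-side preview.
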
nl_tasks/README.md ADDED
@@ -0,0 +1,45 @@
1
+ Dynamo
2
+ export TORCH_COMPILE_DISABLE=1
3
+ unset TORCH_COMPILE_DISABLE
4
+ echo $TORCH_COMPILE_DISABLE
5
+
6
+
7
+ Untracked files:
8
+ (use "git add <file>..." to include in what will be committed)
9
+ nl_tasks/README.md
10
+ nl_tasks/config/commonsense_opt.yaml
11
+ nl_tasks/config/glue.yaml
12
+ nl_tasks/config/math395.yaml
13
+ nl_tasks/data/MetaMathQA/
14
+ nl_tasks/data/gsm8k_infer.py
15
+ nl_tasks/exp100/
16
+ nl_tasks/exp395/
17
+ nl_tasks/exp_init/
18
+ nl_tasks/expsBOFT/
19
+ nl_tasks/expsOFT/
20
+ nl_tasks/repro.sh
21
+ nl_tasks/run_all/
22
+ nl_tasks/run_exps/
23
+ nl_tasks/scripts/copy train_cms_reasoning.sh
24
+ nl_tasks/scripts/inference.sh
25
+ nl_tasks/scripts/merge_100k.sh
26
+ nl_tasks/scripts/merge_math.sh
27
+ nl_tasks/scripts/peft_merge.sh
28
+ nl_tasks/scripts/train_100math.sh
29
+ nl_tasks/scripts/train_initn40k.sh
30
+ nl_tasks/scripts/train_math.sh
31
+ nl_tasks/src/ft_mathQ.py
32
+ nl_tasks/src/peft_merge.py
33
+ nl_tasks/src/testLlama.py
34
+ nl_tasks/testLlama.sh
35
+ nl_tasks/training_metrics_bs8.json
36
+ nlu/1Mgrid/
37
+ nlu/_scripts_/
38
+ nlu/glue22_exp/
39
+ nlu/glue_exp00/
40
+ nlu/glue_test/
41
+ nlu/seeds/
42
+ nlu/src/test.py
43
+ nlu/test.sh
44
+ nlu/training_metrics_bs8.json
45
+
nl_tasks/config/commonsense.yaml ADDED
@@ -0,0 +1,44 @@
1
+
2
+ model:
3
+ model_name: meta-llama/Llama-2-7b-hf #facebook/opt-125m #meta-llama/Llama-2-7b-hf #"openai-community/gpt2" #EleutherAI/pythia-160m #Qwen/Qwen2.5-0.5B
4
+ # model_name: facebook/opt-125m
5
+ # adapter_path: "./run_all/exnr15/ft2"
6
+ # adapter_path: './exp_init/run_ex01/ft2'
7
+ data_collator_mode: 'dynamic'
8
+
9
+ rotation_adapter_config:
10
+ r: 4
11
+ num_rotations: 4
12
+ # target_modules: ["q_proj", "v_proj", "v_proj", "o_proj", "gate_proj","up_proj","down_proj"]
13
+ target_modules: ["q_proj", "v_proj",]
14
+
15
+ data:
16
+ dataset_name: 'math'
17
+ split_ratio: 0.025
18
+ # path: "./data/gsm8k_test.jsonl"
19
+ path: ./data/MetaMathQA-40K/MetaMathQA-40K.json
20
+ # path: ./data/MetaMathQA/MetaMathQA-395K.json
21
+ dataset_split: train
22
+ # dataset_field: [question, answer]
23
+ dataset_field: [query, response]
24
+
25
+ trainer_args:
26
+ learning_rate: 2e-4
27
+ # eval_strategy: steps
28
+ per_device_train_batch_size: 32
29
+ per_device_eval_batch_size: 64
30
+ # accumulate_grad_batches: 1
31
+ # save_steps: 1000
32
+ gradient_checkpointing: False # (Turn off for faster training)
33
+ output_dir: "./run_exps"
34
+ # save_path: "runs"
35
+
36
+ report_to: wandb
37
+ logging_steps: 25
38
+ # eval_steps: 100
39
+ #dataloader_num_workers: 4
40
+
41
+ num_train_epochs: 2.0
42
+ # max_steps: -1
43
+
44
+ # device: 'cuda'
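
This YAML is one of the training configs consumed via --config_path in the scripts below. As an illustration of its layout only (the repository's own loader lives in nl_tasks/src/config.py and is not shown in this diff, so the OmegaConf usage here is an assumption), it can be read and overridden like this:

from omegaconf import OmegaConf

# Load the base config and inspect the fields the training scripts care about.
cfg = OmegaConf.load("nl_tasks/config/commonsense.yaml")
print(cfg.model.model_name)                       # meta-llama/Llama-2-7b-hf
print(cfg.rotation_adapter_config.r)              # 4
print(cfg.rotation_adapter_config.num_rotations)  # 4
print(cfg.data.path)                              # ./data/MetaMathQA-40K/MetaMathQA-40K.json

# repro.sh passes overrides such as --rotation_adapter_config.r 16 and
# --trainer_args.learning_rate=8e-4; a dotlist merge mimics that behaviour.
overrides = OmegaConf.from_dotlist(
    ["rotation_adapter_config.r=16", "trainer_args.learning_rate=8e-4"]
)
cfg = OmegaConf.merge(cfg, overrides)
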
nl_tasks/config/commonsense_opt.yaml ADDED
@@ -0,0 +1,32 @@
1
+
2
+ model:
3
+ model_name: facebook/opt-125m #"openai-community/gpt2" #EleutherAI/pythia-160m #Qwen/Qwen2.5-0.5B
4
+ # adapter_path: "./nl_tasks/run_exps/ft2"
5
+ data_collator_mode: 'dynamic'
6
+
7
+ rotation_adapter_config:
8
+ r: 4
9
+ num_rotations: 2
10
+ target_modules: ["q_proj", "v_proj"]
11
+
12
+ data:
13
+ dataset_name: 'math'
14
+ # path: "./nl_tasks/data/MetaMathQA-40K" #MetaMathQA-40K.json"
15
+ path: "./data/gsm8k_test.jsonl"
16
+ dataset_split: train[:200]
17
+ dataset_field: [question, answer]
18
+
19
+ trainer_args:
20
+ learning_rate: 2e-4
21
+ # accumulate_grad_batches: 1
22
+ # dataloader_workers: 5
23
+ # save_interval: 1000
24
+ # sample_interval: 100
25
+ # max_steps: -1
26
+ gradient_checkpointing: False # (Turn off for faster training)
27
+ output_dir: "./run_exps"
28
+ # save_path: "runs"
29
+
30
+ max_steps: 40
31
+
32
+ # device: 'cuda'
nl_tasks/config/glue.yaml ADDED
@@ -0,0 +1,48 @@
1
+
2
+ model:
3
+ model_name: microsoft/deberta-v3-base #facebook/opt-125m #meta-llama/Llama-2-7b-hf #"openai-community/gpt2" #EleutherAI/pythia-160m #Qwen/Qwen2.5-0.5B
4
+ # model_name: facebook/opt-125m
5
+ # adapter_path: "./run_all/exnr15/ft2"
6
+ # adapter_path: './run_all/run_exps9/ft2'
7
+ # adapter_path: "./exp395/run_ex07/ft2"
8
+ data_collator_mode: 'dynamic'
9
+
10
+ rotation_adapter_config:
11
+ r: 5
12
+ num_rotations: 1
13
+ # target_modules: ["q_proj", "v_proj", "v_proj", "o_proj", "gate_proj","up_proj","down_proj"]
14
+ target_modules: ["query_proj", "value_proj", "key_proj", 'attention.output.dense', 'intermediate.dense', 'output.dense']
15
+ task_type: "SEQ_CLS"
16
+
17
+ data:
18
+ dataset_name: 'math'
19
+ split_ratio: 0.00258
20
+ # path: "./data/gsm8k_test.jsonl"
21
+ # path: ./data/MetaMathQA-40K/MetaMathQA-40K.json
22
+ path: ./data/MetaMathQA/MetaMathQA-395K.json
23
+ dataset_split: train[:100000]
24
+ # dataset_field: [question, answer]
25
+ dataset_field: [query, response]
26
+
27
+ trainer_args:
28
+ learning_rate: 2e-4
29
+ warmup_ratio: 0.01
30
+ # eval_strategy: steps
31
+ per_device_train_batch_size: 32
32
+ per_device_eval_batch_size: 64
33
+ # accumulate_grad_batches: 1
34
+ # save_steps: 1000
35
+ gradient_checkpointing: False # (Turn off for faster training)
36
+ output_dir: "./exps/run_exps"
37
+ # save_path: "runs"
38
+
39
+ # report_to: wandb
40
+ logging_steps: 200
41
+ # eval_steps: 1000
42
+ #dataloader_num_workers: 4
43
+
44
+ num_train_epochs: 2.0
45
+ # max_steps: 21
46
+ # torch_compile: False
47
+
48
+ # device: 'cuda'
nl_tasks/config/math395.yaml ADDED
@@ -0,0 +1,46 @@
1
+
2
+ model:
3
+ model_name: meta-llama/Llama-2-7b-hf #facebook/opt-125m #meta-llama/Llama-2-7b-hf #"openai-community/gpt2" #EleutherAI/pythia-160m #Qwen/Qwen2.5-0.5B
4
+ # model_name: facebook/opt-125m
5
+ # adapter_path: "./run_all/exnr15/ft2"
6
+ # adapter_path: './run_all/run_exps9/ft2'
7
+ # adapter_path: "./exp395/run_ex07/ft2"
8
+ data_collator_mode: 'dynamic'
9
+
10
+ rotation_adapter_config:
11
+ r: 16
12
+ num_rotations: 1
13
+ # target_modules: ["q_proj", "v_proj", "v_proj", "o_proj", "gate_proj","up_proj","down_proj"]
14
+ target_modules: ["q_proj", "v_proj",]
15
+
16
+ data:
17
+ dataset_name: 'math'
18
+ split_ratio: 0.00258
19
+ # path: "./data/gsm8k_test.jsonl"
20
+ # path: ./data/MetaMathQA-40K/MetaMathQA-40K.json
21
+ path: ./data/MetaMathQA/MetaMathQA-395K.json
22
+ dataset_split: train[:100000]
23
+ # dataset_field: [question, answer]
24
+ dataset_field: [query, response]
25
+
26
+ trainer_args:
27
+ learning_rate: 2e-4
28
+ warmup_ratio: 0.01
29
+ # eval_strategy: steps
30
+ per_device_train_batch_size: 32
31
+ per_device_eval_batch_size: 64
32
+ # accumulate_grad_batches: 1
33
+ # save_steps: 1000
34
+ gradient_checkpointing: False # (Turn off for faster training)
35
+ output_dir: "./exps/run_exps"
36
+ # save_path: "runs"
37
+
38
+ report_to: wandb
39
+ logging_steps: 200
40
+ # eval_steps: 1000
41
+ #dataloader_num_workers: 4
42
+
43
+ num_train_epochs: 2.0
44
+ # max_steps: 21
45
+
46
+ # device: 'cuda'
nl_tasks/data/MATH_test.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
nl_tasks/data/MetaMathQA-40K/MetaMathQA-40K.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c884f10e8aa1229a6e73a6bba2c9134ee0c7b7de92a02a7b8c9459085a59e117
3
+ size 31076207
nl_tasks/data/MetaMathQA/MetaMathQA-395K.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fb39a5d8c05c042ece92eae37dfd5ea414a5979df2bf3ad3b86411bef8205725
3
+ size 395626321
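
Both MetaMathQA files are committed as Git LFS pointers (the .gitattributes change at the top of this commit adds the matching filter rules), so a clone without the LFS objects contains only the three-line pointer shown above. A small defensive check before loading them, as a sketch with an assumed path:

import json

path = "nl_tasks/data/MetaMathQA-40K/MetaMathQA-40K.json"
with open(path, "rb") as f:
    head = f.read(64)
# A bare pointer file starts with the LFS spec line instead of JSON.
if head.startswith(b"version https://git-lfs.github.com/spec/v1"):
    raise RuntimeError(f"{path} is still an LFS pointer; run `git lfs pull` first.")
with open(path, encoding="utf-8") as f:
    data = json.load(f)
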
nl_tasks/data/gsm8k_test.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
nl_tasks/environment.yaml ADDED
@@ -0,0 +1,55 @@
1
+ # environment.yml
2
+ name: allm # The name of the environment
3
+
4
+ channels: # The conda channels to search for packages
5
+ # - pytorch
6
+ - conda-forge
7
+ # - dnachun
8
+ # - anaconda
9
+ channel_priority: strict
10
+
11
+ dependencies:
12
+ # Packages to install with conda
13
+ - python=3.11.3
14
+ #- pytorch-cuda=12.4
15
+ #- pytorch
16
+ # - numpy
17
+ - transformers>=4.55
18
+ - einops
19
+ - jaxtyping
20
+
21
+ - tensorboard
22
+ - omegaconf
23
+ - accelerate
24
+ - peft
25
+
26
+ - wandb
27
+
28
+ - scipy
29
+ - pandas
30
+ - matplotlib
31
+ - scikit-image
32
+ - scikit-learn
33
+ - joblib
34
+ - pillow
35
+ - datasets
36
+ ## NO - huggingface_hub
37
+ - tqdm
38
+ - nltk
39
+ - future
40
+
41
+
42
+ - defusedxml
43
+ - ipdb
44
+ - torchinfo
45
+
46
+
47
+
48
+ - timm
49
+ - graphviz #anaconda::graphviz
50
+ - dnachun::torchviz
51
+ - pip:
52
+ # - draccus
53
+ - fraction
54
+ - vllm
55
+
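
The environment above can be created with conda env create -f nl_tasks/environment.yaml and activated with conda activate allm. Note that the PyTorch entries are commented out, so PyTorch is expected to be installed separately; repro.sh below points CPATH at this environment's include directory (/home/work/miniconda3/envs/allm/include).
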
nl_tasks/exps/run_ex01/trainer_state.json ADDED
@@ -0,0 +1,914 @@
1
+ {
2
+ "best_global_step": null,
3
+ "best_metric": null,
4
+ "best_model_checkpoint": null,
5
+ "epoch": 2.0,
6
+ "eval_steps": 100,
7
+ "global_step": 2438,
8
+ "is_hyper_param_search": false,
9
+ "is_local_process_zero": true,
10
+ "is_world_process_zero": true,
11
+ "log_history": [
12
+ {
13
+ "epoch": 0.020508613617719443,
14
+ "grad_norm": 0.06690643727779388,
15
+ "learning_rate": 4.918032786885246e-06,
16
+ "loss": 0.751,
17
+ "step": 25
18
+ },
19
+ {
20
+ "epoch": 0.04101722723543889,
21
+ "grad_norm": 0.23132523894309998,
22
+ "learning_rate": 1.0040983606557377e-05,
23
+ "loss": 0.7344,
24
+ "step": 50
25
+ },
26
+ {
27
+ "epoch": 0.06152584085315833,
28
+ "grad_norm": 0.2116735428571701,
29
+ "learning_rate": 1.5163934426229509e-05,
30
+ "loss": 0.6404,
31
+ "step": 75
32
+ },
33
+ {
34
+ "epoch": 0.08203445447087777,
35
+ "grad_norm": 0.1672675907611847,
36
+ "learning_rate": 2.028688524590164e-05,
37
+ "loss": 0.486,
38
+ "step": 100
39
+ },
40
+ {
41
+ "epoch": 0.08203445447087777,
42
+ "eval_loss": 0.4427674114704132,
43
+ "eval_runtime": 19.6288,
44
+ "eval_samples_per_second": 50.945,
45
+ "eval_steps_per_second": 0.815,
46
+ "step": 100
47
+ },
48
+ {
49
+ "epoch": 0.10254306808859721,
50
+ "grad_norm": 0.16888108849525452,
51
+ "learning_rate": 2.540983606557377e-05,
52
+ "loss": 0.4407,
53
+ "step": 125
54
+ },
55
+ {
56
+ "epoch": 0.12305168170631665,
57
+ "grad_norm": 0.17033565044403076,
58
+ "learning_rate": 3.05327868852459e-05,
59
+ "loss": 0.4031,
60
+ "step": 150
61
+ },
62
+ {
63
+ "epoch": 0.1435602953240361,
64
+ "grad_norm": 0.194916769862175,
65
+ "learning_rate": 3.5655737704918037e-05,
66
+ "loss": 0.3787,
67
+ "step": 175
68
+ },
69
+ {
70
+ "epoch": 0.16406890894175555,
71
+ "grad_norm": 0.29443657398223877,
72
+ "learning_rate": 4.077868852459016e-05,
73
+ "loss": 0.3769,
74
+ "step": 200
75
+ },
76
+ {
77
+ "epoch": 0.16406890894175555,
78
+ "eval_loss": 0.3537510335445404,
79
+ "eval_runtime": 19.4681,
80
+ "eval_samples_per_second": 51.366,
81
+ "eval_steps_per_second": 0.822,
82
+ "step": 200
83
+ },
84
+ {
85
+ "epoch": 0.184577522559475,
86
+ "grad_norm": 0.2323056161403656,
87
+ "learning_rate": 4.59016393442623e-05,
88
+ "loss": 0.3658,
89
+ "step": 225
90
+ },
91
+ {
92
+ "epoch": 0.20508613617719443,
93
+ "grad_norm": 0.239767923951149,
94
+ "learning_rate": 4.999935927058032e-05,
95
+ "loss": 0.3402,
96
+ "step": 250
97
+ },
98
+ {
99
+ "epoch": 0.22559474979491387,
100
+ "grad_norm": 0.21633633971214294,
101
+ "learning_rate": 4.997693718919013e-05,
102
+ "loss": 0.3342,
103
+ "step": 275
104
+ },
105
+ {
106
+ "epoch": 0.2461033634126333,
107
+ "grad_norm": 0.23770427703857422,
108
+ "learning_rate": 4.992251147198466e-05,
109
+ "loss": 0.3366,
110
+ "step": 300
111
+ },
112
+ {
113
+ "epoch": 0.2461033634126333,
114
+ "eval_loss": 0.3253124952316284,
115
+ "eval_runtime": 19.4632,
116
+ "eval_samples_per_second": 51.379,
117
+ "eval_steps_per_second": 0.822,
118
+ "step": 300
119
+ },
120
+ {
121
+ "epoch": 0.2666119770303528,
122
+ "grad_norm": 0.2381717562675476,
123
+ "learning_rate": 4.98361518561306e-05,
124
+ "loss": 0.3379,
125
+ "step": 325
126
+ },
127
+ {
128
+ "epoch": 0.2871205906480722,
129
+ "grad_norm": 0.2646787166595459,
130
+ "learning_rate": 4.971796899657632e-05,
131
+ "loss": 0.329,
132
+ "step": 350
133
+ },
134
+ {
135
+ "epoch": 0.30762920426579166,
136
+ "grad_norm": 0.3449915647506714,
137
+ "learning_rate": 4.9568114324266624e-05,
138
+ "loss": 0.3406,
139
+ "step": 375
140
+ },
141
+ {
142
+ "epoch": 0.3281378178835111,
143
+ "grad_norm": 0.3168472647666931,
144
+ "learning_rate": 4.938677985211011e-05,
145
+ "loss": 0.3227,
146
+ "step": 400
147
+ },
148
+ {
149
+ "epoch": 0.3281378178835111,
150
+ "eval_loss": 0.3124091625213623,
151
+ "eval_runtime": 19.4632,
152
+ "eval_samples_per_second": 51.379,
153
+ "eval_steps_per_second": 0.822,
154
+ "step": 400
155
+ },
156
+ {
157
+ "epoch": 0.34864643150123054,
158
+ "grad_norm": 0.24941708147525787,
159
+ "learning_rate": 4.9174197928947795e-05,
160
+ "loss": 0.3301,
161
+ "step": 425
162
+ },
163
+ {
164
+ "epoch": 0.36915504511895,
165
+ "grad_norm": 0.25401854515075684,
166
+ "learning_rate": 4.8930640941838104e-05,
167
+ "loss": 0.3267,
168
+ "step": 450
169
+ },
170
+ {
171
+ "epoch": 0.3896636587366694,
172
+ "grad_norm": 0.2830125689506531,
173
+ "learning_rate": 4.86564209670399e-05,
174
+ "loss": 0.3159,
175
+ "step": 475
176
+ },
177
+ {
178
+ "epoch": 0.41017227235438886,
179
+ "grad_norm": 0.26643845438957214,
180
+ "learning_rate": 4.835188937014059e-05,
181
+ "loss": 0.3006,
182
+ "step": 500
183
+ },
184
+ {
185
+ "epoch": 0.41017227235438886,
186
+ "eval_loss": 0.30365315079689026,
187
+ "eval_runtime": 19.5051,
188
+ "eval_samples_per_second": 51.269,
189
+ "eval_steps_per_second": 0.82,
190
+ "step": 500
191
+ },
192
+ {
193
+ "epoch": 0.4306808859721083,
194
+ "grad_norm": 0.2782154977321625,
195
+ "learning_rate": 4.801743635584168e-05,
196
+ "loss": 0.3015,
197
+ "step": 525
198
+ },
199
+ {
200
+ "epoch": 0.45118949958982774,
201
+ "grad_norm": 0.23164071142673492,
202
+ "learning_rate": 4.7653490467978906e-05,
203
+ "loss": 0.3095,
204
+ "step": 550
205
+ },
206
+ {
207
+ "epoch": 0.4716981132075472,
208
+ "grad_norm": 0.282122403383255,
209
+ "learning_rate": 4.726051804041709e-05,
210
+ "loss": 0.3049,
211
+ "step": 575
212
+ },
213
+ {
214
+ "epoch": 0.4922067268252666,
215
+ "grad_norm": 0.2644311487674713,
216
+ "learning_rate": 4.683902259952387e-05,
217
+ "loss": 0.3213,
218
+ "step": 600
219
+ },
220
+ {
221
+ "epoch": 0.4922067268252666,
222
+ "eval_loss": 0.29857054352760315,
223
+ "eval_runtime": 19.4613,
224
+ "eval_samples_per_second": 51.384,
225
+ "eval_steps_per_second": 0.822,
226
+ "step": 600
227
+ },
228
+ {
229
+ "epoch": 0.5127153404429861,
230
+ "grad_norm": 0.2529069185256958,
231
+ "learning_rate": 4.638954421898746e-05,
232
+ "loss": 0.3001,
233
+ "step": 625
234
+ },
235
+ {
236
+ "epoch": 0.5332239540607056,
237
+ "grad_norm": 0.2736290395259857,
238
+ "learning_rate": 4.5912658827805425e-05,
239
+ "loss": 0.2916,
240
+ "step": 650
241
+ },
242
+ {
243
+ "epoch": 0.5537325676784249,
244
+ "grad_norm": 0.2555182874202728,
245
+ "learning_rate": 4.5408977472331005e-05,
246
+ "loss": 0.3052,
247
+ "step": 675
248
+ },
249
+ {
250
+ "epoch": 0.5742411812961444,
251
+ "grad_norm": 0.30381742119789124,
252
+ "learning_rate": 4.48791455333227e-05,
253
+ "loss": 0.3036,
254
+ "step": 700
255
+ },
256
+ {
257
+ "epoch": 0.5742411812961444,
258
+ "eval_loss": 0.292624831199646,
259
+ "eval_runtime": 19.4487,
260
+ "eval_samples_per_second": 51.417,
261
+ "eval_steps_per_second": 0.823,
262
+ "step": 700
263
+ },
264
+ {
265
+ "epoch": 0.5947497949138638,
266
+ "grad_norm": 0.26833590865135193,
267
+ "learning_rate": 4.432384189900008e-05,
268
+ "loss": 0.3023,
269
+ "step": 725
270
+ },
271
+ {
272
+ "epoch": 0.6152584085315833,
273
+ "grad_norm": 0.263784259557724,
274
+ "learning_rate": 4.3743778095165764e-05,
275
+ "loss": 0.3016,
276
+ "step": 750
277
+ },
278
+ {
279
+ "epoch": 0.6357670221493027,
280
+ "grad_norm": 0.30153971910476685,
281
+ "learning_rate": 4.313969737350775e-05,
282
+ "loss": 0.2984,
283
+ "step": 775
284
+ },
285
+ {
286
+ "epoch": 0.6562756357670222,
287
+ "grad_norm": 0.27196648716926575,
288
+ "learning_rate": 4.251237375925071e-05,
289
+ "loss": 0.3034,
290
+ "step": 800
291
+ },
292
+ {
293
+ "epoch": 0.6562756357670222,
294
+ "eval_loss": 0.28956037759780884,
295
+ "eval_runtime": 19.4528,
296
+ "eval_samples_per_second": 51.407,
297
+ "eval_steps_per_second": 0.823,
298
+ "step": 800
299
+ },
300
+ {
301
+ "epoch": 0.6767842493847416,
302
+ "grad_norm": 0.25102221965789795,
303
+ "learning_rate": 4.186261105937612e-05,
304
+ "loss": 0.2961,
305
+ "step": 825
306
+ },
307
+ {
308
+ "epoch": 0.6972928630024611,
309
+ "grad_norm": 0.30812397599220276,
310
+ "learning_rate": 4.1191241832682364e-05,
311
+ "loss": 0.2995,
312
+ "step": 850
313
+ },
314
+ {
315
+ "epoch": 0.7178014766201805,
316
+ "grad_norm": 0.2563766837120056,
317
+ "learning_rate": 4.049912632300421e-05,
318
+ "loss": 0.2878,
319
+ "step": 875
320
+ },
321
+ {
322
+ "epoch": 0.7383100902379,
323
+ "grad_norm": 0.3218742311000824,
324
+ "learning_rate": 3.978715135695881e-05,
325
+ "loss": 0.296,
326
+ "step": 900
327
+ },
328
+ {
329
+ "epoch": 0.7383100902379,
330
+ "eval_loss": 0.2854253053665161,
331
+ "eval_runtime": 19.4687,
332
+ "eval_samples_per_second": 51.365,
333
+ "eval_steps_per_second": 0.822,
334
+ "step": 900
335
+ },
336
+ {
337
+ "epoch": 0.7588187038556193,
338
+ "grad_norm": 0.27102115750312805,
339
+ "learning_rate": 3.905622920763031e-05,
340
+ "loss": 0.2944,
341
+ "step": 925
342
+ },
343
+ {
344
+ "epoch": 0.7793273174733388,
345
+ "grad_norm": 0.276564359664917,
346
+ "learning_rate": 3.83072964256494e-05,
347
+ "loss": 0.2847,
348
+ "step": 950
349
+ },
350
+ {
351
+ "epoch": 0.7998359310910582,
352
+ "grad_norm": 0.3700260519981384,
353
+ "learning_rate": 3.7541312639165145e-05,
354
+ "loss": 0.2877,
355
+ "step": 975
356
+ },
357
+ {
358
+ "epoch": 0.8203445447087777,
359
+ "grad_norm": 0.2763277292251587,
360
+ "learning_rate": 3.675925932424715e-05,
361
+ "loss": 0.2819,
362
+ "step": 1000
363
+ },
364
+ {
365
+ "epoch": 0.8203445447087777,
366
+ "eval_loss": 0.2817630469799042,
367
+ "eval_runtime": 19.4695,
368
+ "eval_samples_per_second": 51.362,
369
+ "eval_steps_per_second": 0.822,
370
+ "step": 1000
371
+ },
372
+ {
373
+ "epoch": 0.8408531583264971,
374
+ "grad_norm": 0.3116626739501953,
375
+ "learning_rate": 3.596213854729328e-05,
376
+ "loss": 0.2855,
377
+ "step": 1025
378
+ },
379
+ {
380
+ "epoch": 0.8613617719442166,
381
+ "grad_norm": 0.27529171109199524,
382
+ "learning_rate": 3.515097168105444e-05,
383
+ "loss": 0.2847,
384
+ "step": 1050
385
+ },
386
+ {
387
+ "epoch": 0.881870385561936,
388
+ "grad_norm": 0.3135339021682739,
389
+ "learning_rate": 3.4326798095921656e-05,
390
+ "loss": 0.2875,
391
+ "step": 1075
392
+ },
393
+ {
394
+ "epoch": 0.9023789991796555,
395
+ "grad_norm": 0.33244284987449646,
396
+ "learning_rate": 3.349067382815217e-05,
397
+ "loss": 0.2885,
398
+ "step": 1100
399
+ },
400
+ {
401
+ "epoch": 0.9023789991796555,
402
+ "eval_loss": 0.27884459495544434,
403
+ "eval_runtime": 19.4401,
404
+ "eval_samples_per_second": 51.44,
405
+ "eval_steps_per_second": 0.823,
406
+ "step": 1100
407
+ },
408
+ {
409
+ "epoch": 0.9228876127973749,
410
+ "grad_norm": 0.2761545479297638,
411
+ "learning_rate": 3.264367022674124e-05,
412
+ "loss": 0.2857,
413
+ "step": 1125
414
+ },
415
+ {
416
+ "epoch": 0.9433962264150944,
417
+ "grad_norm": 0.24953879415988922,
418
+ "learning_rate": 3.1786872580673214e-05,
419
+ "loss": 0.2832,
420
+ "step": 1150
421
+ },
422
+ {
423
+ "epoch": 0.9639048400328137,
424
+ "grad_norm": 0.32240793108940125,
425
+ "learning_rate": 3.09213787283109e-05,
426
+ "loss": 0.2909,
427
+ "step": 1175
428
+ },
429
+ {
430
+ "epoch": 0.9844134536505332,
431
+ "grad_norm": 0.2800053358078003,
432
+ "learning_rate": 3.004829765070516e-05,
433
+ "loss": 0.297,
434
+ "step": 1200
435
+ },
436
+ {
437
+ "epoch": 0.9844134536505332,
438
+ "eval_loss": 0.2768399119377136,
439
+ "eval_runtime": 19.4646,
440
+ "eval_samples_per_second": 51.375,
441
+ "eval_steps_per_second": 0.822,
442
+ "step": 1200
443
+ },
444
+ {
445
+ "epoch": 1.0049220672682526,
446
+ "grad_norm": 0.27479347586631775,
447
+ "learning_rate": 2.916874805062701e-05,
448
+ "loss": 0.2817,
449
+ "step": 1225
450
+ },
451
+ {
452
+ "epoch": 1.0254306808859721,
453
+ "grad_norm": 0.28057900071144104,
454
+ "learning_rate": 2.828385691914301e-05,
455
+ "loss": 0.2821,
456
+ "step": 1250
457
+ },
458
+ {
459
+ "epoch": 1.0459392945036916,
460
+ "grad_norm": 0.2822401225566864,
461
+ "learning_rate": 2.7394758091570664e-05,
462
+ "loss": 0.2831,
463
+ "step": 1275
464
+ },
465
+ {
466
+ "epoch": 1.066447908121411,
467
+ "grad_norm": 0.2816685438156128,
468
+ "learning_rate": 2.6502590794664073e-05,
469
+ "loss": 0.285,
470
+ "step": 1300
471
+ },
472
+ {
473
+ "epoch": 1.066447908121411,
474
+ "eval_loss": 0.2752681076526642,
475
+ "eval_runtime": 19.4336,
476
+ "eval_samples_per_second": 51.457,
477
+ "eval_steps_per_second": 0.823,
478
+ "step": 1300
479
+ },
480
+ {
481
+ "epoch": 1.0869565217391304,
482
+ "grad_norm": 0.2750077545642853,
483
+ "learning_rate": 2.560849818689141e-05,
484
+ "loss": 0.2829,
485
+ "step": 1325
486
+ },
487
+ {
488
+ "epoch": 1.1074651353568499,
489
+ "grad_norm": 0.29034796357154846,
490
+ "learning_rate": 2.471362589367452e-05,
491
+ "loss": 0.2727,
492
+ "step": 1350
493
+ },
494
+ {
495
+ "epoch": 1.1279737489745694,
496
+ "grad_norm": 0.33246326446533203,
497
+ "learning_rate": 2.3819120539467663e-05,
498
+ "loss": 0.2806,
499
+ "step": 1375
500
+ },
501
+ {
502
+ "epoch": 1.1484823625922886,
503
+ "grad_norm": 0.26562532782554626,
504
+ "learning_rate": 2.2926128278556052e-05,
505
+ "loss": 0.2666,
506
+ "step": 1400
507
+ },
508
+ {
509
+ "epoch": 1.1484823625922886,
510
+ "eval_loss": 0.27384424209594727,
511
+ "eval_runtime": 19.4305,
512
+ "eval_samples_per_second": 51.466,
513
+ "eval_steps_per_second": 0.823,
514
+ "step": 1400
515
+ },
516
+ {
517
+ "epoch": 1.1689909762100081,
518
+ "grad_norm": 0.3040076494216919,
519
+ "learning_rate": 2.2035793326456883e-05,
520
+ "loss": 0.2799,
521
+ "step": 1425
522
+ },
523
+ {
524
+ "epoch": 1.1894995898277276,
525
+ "grad_norm": 0.38336554169654846,
526
+ "learning_rate": 2.1149256493804576e-05,
527
+ "loss": 0.2858,
528
+ "step": 1450
529
+ },
530
+ {
531
+ "epoch": 1.2100082034454471,
532
+ "grad_norm": 0.28841930627822876,
533
+ "learning_rate": 2.0267653724598747e-05,
534
+ "loss": 0.2777,
535
+ "step": 1475
536
+ },
537
+ {
538
+ "epoch": 1.2305168170631666,
539
+ "grad_norm": 0.3314826786518097,
540
+ "learning_rate": 1.9392114640687985e-05,
541
+ "loss": 0.2884,
542
+ "step": 1500
543
+ },
544
+ {
545
+ "epoch": 1.2305168170631666,
546
+ "eval_loss": 0.27183249592781067,
547
+ "eval_runtime": 19.4549,
548
+ "eval_samples_per_second": 51.401,
549
+ "eval_steps_per_second": 0.822,
550
+ "step": 1500
551
+ },
552
+ {
553
+ "epoch": 1.251025430680886,
554
+ "grad_norm": 0.35494905710220337,
555
+ "learning_rate": 1.8523761094354304e-05,
556
+ "loss": 0.2833,
557
+ "step": 1525
558
+ },
559
+ {
560
+ "epoch": 1.2715340442986054,
561
+ "grad_norm": 0.30966347455978394,
562
+ "learning_rate": 1.7663705730853012e-05,
563
+ "loss": 0.276,
564
+ "step": 1550
565
+ },
566
+ {
567
+ "epoch": 1.2920426579163249,
568
+ "grad_norm": 0.2722432613372803,
569
+ "learning_rate": 1.6813050562749778e-05,
570
+ "loss": 0.2775,
571
+ "step": 1575
572
+ },
573
+ {
574
+ "epoch": 1.3125512715340442,
575
+ "grad_norm": 0.2977288067340851,
576
+ "learning_rate": 1.5972885557881666e-05,
577
+ "loss": 0.269,
578
+ "step": 1600
579
+ },
580
+ {
581
+ "epoch": 1.3125512715340442,
582
+ "eval_loss": 0.27014729380607605,
583
+ "eval_runtime": 19.4752,
584
+ "eval_samples_per_second": 51.347,
585
+ "eval_steps_per_second": 0.822,
586
+ "step": 1600
587
+ },
588
+ {
589
+ "epoch": 1.3330598851517639,
590
+ "grad_norm": 0.2877044081687927,
591
+ "learning_rate": 1.5144287242751378e-05,
592
+ "loss": 0.2727,
593
+ "step": 1625
594
+ },
595
+ {
596
+ "epoch": 1.3535684987694832,
597
+ "grad_norm": 0.3222461938858032,
598
+ "learning_rate": 1.4328317323144284e-05,
599
+ "loss": 0.2742,
600
+ "step": 1650
601
+ },
602
+ {
603
+ "epoch": 1.3740771123872026,
604
+ "grad_norm": 0.3473854959011078,
605
+ "learning_rate": 1.3526021323735626e-05,
606
+ "loss": 0.2724,
607
+ "step": 1675
608
+ },
609
+ {
610
+ "epoch": 1.3945857260049221,
611
+ "grad_norm": 0.3066307306289673,
612
+ "learning_rate": 1.2738427248431028e-05,
613
+ "loss": 0.2696,
614
+ "step": 1700
615
+ },
616
+ {
617
+ "epoch": 1.3945857260049221,
618
+ "eval_loss": 0.26889586448669434,
619
+ "eval_runtime": 19.4809,
620
+ "eval_samples_per_second": 51.332,
621
+ "eval_steps_per_second": 0.821,
622
+ "step": 1700
623
+ },
624
+ {
625
+ "epoch": 1.4150943396226414,
626
+ "grad_norm": 0.2990802228450775,
627
+ "learning_rate": 1.1966544263156865e-05,
628
+ "loss": 0.269,
629
+ "step": 1725
630
+ },
631
+ {
632
+ "epoch": 1.435602953240361,
633
+ "grad_norm": 0.29301875829696655,
634
+ "learning_rate": 1.1211361402788226e-05,
635
+ "loss": 0.2681,
636
+ "step": 1750
637
+ },
638
+ {
639
+ "epoch": 1.4561115668580804,
640
+ "grad_norm": 0.32831257581710815,
641
+ "learning_rate": 1.047384630387131e-05,
642
+ "loss": 0.2771,
643
+ "step": 1775
644
+ },
645
+ {
646
+ "epoch": 1.4766201804758,
647
+ "grad_norm": 0.3254742920398712,
648
+ "learning_rate": 9.75494396476423e-06,
649
+ "loss": 0.2644,
650
+ "step": 1800
651
+ },
652
+ {
653
+ "epoch": 1.4766201804758,
654
+ "eval_loss": 0.2682200074195862,
655
+ "eval_runtime": 19.4711,
656
+ "eval_samples_per_second": 51.358,
657
+ "eval_steps_per_second": 0.822,
658
+ "step": 1800
659
+ },
660
+ {
661
+ "epoch": 1.4971287940935194,
662
+ "grad_norm": 0.29671165347099304,
663
+ "learning_rate": 9.05557553478459e-06,
664
+ "loss": 0.2658,
665
+ "step": 1825
666
+ },
667
+ {
668
+ "epoch": 1.5176374077112387,
669
+ "grad_norm": 0.29359179735183716,
670
+ "learning_rate": 8.376637133915558e-06,
671
+ "loss": 0.2676,
672
+ "step": 1850
673
+ },
674
+ {
675
+ "epoch": 1.5381460213289582,
676
+ "grad_norm": 0.3013196587562561,
677
+ "learning_rate": 7.718998704582739e-06,
678
+ "loss": 0.2708,
679
+ "step": 1875
680
+ },
681
+ {
682
+ "epoch": 1.5586546349466777,
683
+ "grad_norm": 0.33183860778808594,
684
+ "learning_rate": 7.0835028969730185e-06,
685
+ "loss": 0.2727,
686
+ "step": 1900
687
+ },
688
+ {
689
+ "epoch": 1.5586546349466777,
690
+ "eval_loss": 0.26749491691589355,
691
+ "eval_runtime": 19.4495,
692
+ "eval_samples_per_second": 51.415,
693
+ "eval_steps_per_second": 0.823,
694
+ "step": 1900
695
+ },
696
+ {
697
+ "epoch": 1.579163248564397,
698
+ "grad_norm": 0.3297445774078369,
699
+ "learning_rate": 6.470963989323764e-06,
700
+ "loss": 0.2792,
701
+ "step": 1925
702
+ },
703
+ {
704
+ "epoch": 1.5996718621821167,
705
+ "grad_norm": 0.33076536655426025,
706
+ "learning_rate": 5.8821668445656924e-06,
707
+ "loss": 0.2749,
708
+ "step": 1950
709
+ },
710
+ {
711
+ "epoch": 1.620180475799836,
712
+ "grad_norm": 0.290751188993454,
713
+ "learning_rate": 5.317865904656497e-06,
714
+ "loss": 0.2653,
715
+ "step": 1975
716
+ },
717
+ {
718
+ "epoch": 1.6406890894175554,
719
+ "grad_norm": 0.30254530906677246,
720
+ "learning_rate": 4.778784223893601e-06,
721
+ "loss": 0.2767,
722
+ "step": 2000
723
+ },
724
+ {
725
+ "epoch": 1.6406890894175554,
726
+ "eval_loss": 0.26705402135849,
727
+ "eval_runtime": 19.451,
728
+ "eval_samples_per_second": 51.411,
729
+ "eval_steps_per_second": 0.823,
730
+ "step": 2000
731
+ },
732
+ {
733
+ "epoch": 1.661197703035275,
734
+ "grad_norm": 0.30978214740753174,
735
+ "learning_rate": 4.265612542444827e-06,
736
+ "loss": 0.2661,
737
+ "step": 2025
738
+ },
739
+ {
740
+ "epoch": 1.6817063166529942,
741
+ "grad_norm": 0.3454751968383789,
742
+ "learning_rate": 3.7790084012840453e-06,
743
+ "loss": 0.2717,
744
+ "step": 2050
745
+ },
746
+ {
747
+ "epoch": 1.7022149302707137,
748
+ "grad_norm": 0.3721317946910858,
749
+ "learning_rate": 3.319595299665873e-06,
750
+ "loss": 0.2767,
751
+ "step": 2075
752
+ },
753
+ {
754
+ "epoch": 1.7227235438884332,
755
+ "grad_norm": 0.33564668893814087,
756
+ "learning_rate": 2.8879618962189326e-06,
757
+ "loss": 0.2614,
758
+ "step": 2100
759
+ },
760
+ {
761
+ "epoch": 1.7227235438884332,
762
+ "eval_loss": 0.26674506068229675,
763
+ "eval_runtime": 19.4651,
764
+ "eval_samples_per_second": 51.374,
765
+ "eval_steps_per_second": 0.822,
766
+ "step": 2100
767
+ },
768
+ {
769
+ "epoch": 1.7432321575061525,
770
+ "grad_norm": 0.31175196170806885,
771
+ "learning_rate": 2.484661254681381e-06,
772
+ "loss": 0.2689,
773
+ "step": 2125
774
+ },
775
+ {
776
+ "epoch": 1.7637407711238722,
777
+ "grad_norm": 0.27521854639053345,
778
+ "learning_rate": 2.110210135245147e-06,
779
+ "loss": 0.266,
780
+ "step": 2150
781
+ },
782
+ {
783
+ "epoch": 1.7842493847415914,
784
+ "grad_norm": 0.301544189453125,
785
+ "learning_rate": 1.765088332416917e-06,
786
+ "loss": 0.2746,
787
+ "step": 2175
788
+ },
789
+ {
790
+ "epoch": 1.804757998359311,
791
+ "grad_norm": 0.29790034890174866,
792
+ "learning_rate": 1.4497380602442378e-06,
793
+ "loss": 0.2671,
794
+ "step": 2200
795
+ },
796
+ {
797
+ "epoch": 1.804757998359311,
798
+ "eval_loss": 0.2663249373435974,
799
+ "eval_runtime": 19.4746,
800
+ "eval_samples_per_second": 51.349,
801
+ "eval_steps_per_second": 0.822,
802
+ "step": 2200
803
+ },
804
+ {
805
+ "epoch": 1.8252666119770304,
806
+ "grad_norm": 0.3353317677974701,
807
+ "learning_rate": 1.1645633856944977e-06,
808
+ "loss": 0.2693,
809
+ "step": 2225
810
+ },
811
+ {
812
+ "epoch": 1.8457752255947497,
813
+ "grad_norm": 0.2950851023197174,
814
+ "learning_rate": 9.099297109128407e-07,
815
+ "loss": 0.2704,
816
+ "step": 2250
817
+ },
818
+ {
819
+ "epoch": 1.8662838392124692,
820
+ "grad_norm": 0.3311655521392822,
821
+ "learning_rate": 6.861633050223526e-07,
822
+ "loss": 0.2757,
823
+ "step": 2275
824
+ },
825
+ {
826
+ "epoch": 1.8867924528301887,
827
+ "grad_norm": 0.3037591576576233,
828
+ "learning_rate": 4.935508860664601e-07,
829
+ "loss": 0.2684,
830
+ "step": 2300
831
+ },
832
+ {
833
+ "epoch": 1.8867924528301887,
834
+ "eval_loss": 0.26620855927467346,
835
+ "eval_runtime": 19.477,
836
+ "eval_samples_per_second": 51.343,
837
+ "eval_steps_per_second": 0.821,
838
+ "step": 2300
839
+ },
840
+ {
841
+ "epoch": 1.907301066447908,
842
+ "grad_norm": 0.3440220057964325,
843
+ "learning_rate": 3.323392536292436e-07,
844
+ "loss": 0.2773,
845
+ "step": 2325
846
+ },
847
+ {
848
+ "epoch": 1.9278096800656277,
849
+ "grad_norm": 0.33441346883773804,
850
+ "learning_rate": 2.0273497260433204e-07,
851
+ "loss": 0.2582,
852
+ "step": 2350
853
+ },
854
+ {
855
+ "epoch": 1.948318293683347,
856
+ "grad_norm": 0.28934305906295776,
857
+ "learning_rate": 1.0490410851763943e-07,
858
+ "loss": 0.2653,
859
+ "step": 2375
860
+ },
861
+ {
862
+ "epoch": 1.9688269073010665,
863
+ "grad_norm": 0.304415225982666,
864
+ "learning_rate": 3.8972014743038356e-08,
865
+ "loss": 0.268,
866
+ "step": 2400
867
+ },
868
+ {
869
+ "epoch": 1.9688269073010665,
870
+ "eval_loss": 0.26615387201309204,
871
+ "eval_runtime": 19.4583,
872
+ "eval_samples_per_second": 51.392,
873
+ "eval_steps_per_second": 0.822,
874
+ "step": 2400
875
+ },
876
+ {
877
+ "epoch": 1.989335520918786,
878
+ "grad_norm": 0.2703385055065155,
879
+ "learning_rate": 5.023171883647426e-09,
880
+ "loss": 0.263,
881
+ "step": 2425
882
+ },
883
+ {
884
+ "epoch": 2.0,
885
+ "step": 2438,
886
+ "total_flos": 1.58523627405312e+18,
887
+ "train_loss": 0.3074496070120157,
888
+ "train_runtime": 3055.5986,
889
+ "train_samples_per_second": 25.527,
890
+ "train_steps_per_second": 0.798
891
+ }
892
+ ],
893
+ "logging_steps": 25,
894
+ "max_steps": 2438,
895
+ "num_input_tokens_seen": 0,
896
+ "num_train_epochs": 2,
897
+ "save_steps": 500,
898
+ "stateful_callbacks": {
899
+ "TrainerControl": {
900
+ "args": {
901
+ "should_epoch_stop": false,
902
+ "should_evaluate": false,
903
+ "should_log": false,
904
+ "should_save": true,
905
+ "should_training_stop": true
906
+ },
907
+ "attributes": {}
908
+ }
909
+ },
910
+ "total_flos": 1.58523627405312e+18,
911
+ "train_batch_size": 32,
912
+ "trial_name": null,
913
+ "trial_params": null
914
+ }
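
trainer_state.json records a completed two-epoch run: 2438 optimizer steps at train batch size 32 (consistent with the MetaMathQA-40K setup in commonsense.yaml above), training loss falling from 0.75 to about 0.27 and a final eval loss of 0.266, with logging every 25 steps and evaluation every 100. A quick way to visualise it (a sketch; matplotlib is already listed in environment.yaml):

import json
import matplotlib.pyplot as plt

state = json.load(open("nl_tasks/exps/run_ex01/trainer_state.json"))
# log_history mixes training entries (with "loss") and eval entries (with "eval_loss").
train = [(e["step"], e["loss"]) for e in state["log_history"] if "loss" in e]
evals = [(e["step"], e["eval_loss"]) for e in state["log_history"] if "eval_loss" in e]

plt.plot(*zip(*train), label="train loss (every 25 steps)")
plt.plot(*zip(*evals), marker="o", label="eval loss (every 100 steps)")
plt.xlabel("step")
plt.ylabel("loss")
plt.legend()
plt.savefig("run_ex01_loss.png")
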
nl_tasks/repro.sh ADDED
@@ -0,0 +1,87 @@
1
+ export OMINI_CONFIG=./config/commonsense.yaml
2
+
3
+ #echo $OMINI_CONFIG
4
+ export TOKENIZERS_PARALLELISM=true
5
+
6
+ # CUDA Include (/cuda.h)
7
+ CUDA_INCLUDE_PATH="/home/work/miniconda3/envs/allm/include"
8
+
9
+ # 3. Add into CPATH & CPLUS_INCLUDE_PATH (C/C++ compiler)
10
+ export CPATH=$CPATH:$CUDA_INCLUDE_PATH
11
+ export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$CUDA_INCLUDE_PATH
12
+ # echo "CPATH is set to: $CPATH"
13
+ # echo "CPLUS_INCLUDE_PATH is set to: $CPLUS_INCLUDE_PATH"
14
+
15
+ export WANDB_PROJECT="Llama2_7B_FT_Math40k_2"
16
+
17
+ export OMP_NUM_THREADS=1
18
+ export MKL_NUM_THREADS=1
19
+ export OPENBLAS_NUM_THREADS=1
20
+ export NUMEXPR_NUM_THREADS=1
21
+
22
+ date +"%F %T"
23
+ TEXT=("oft" "boft" "loco" "hra")
24
+
25
+ # --run_text "$text" --dynamo_backend no
26
+ export ACCELERATE_DYNAMO_BACKEND="no"
27
+ # --trainer_args.max_steps=81 \
28
+
29
+ accelerate launch --dynamo_backend no --main_process_port 41353 -m src.testLlama \
30
+ --config_path $OMINI_CONFIG --trainer_args.output_dir "./expsBOFT/seed44/" --trainer_args.learning_rate=8e-4 \
31
+ --run_text "boft" --trainer_args.per_device_train_batch_size 32 \
32
+ --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
33
+ --trainer_args.gradient_accumulation_steps 2 \
34
+ --trainer_args.num_train_epochs 2.0 --data.dataset_split train \
35
+ --trainer_args.eval_strategy '"no"' \
36
+ --trainer_args.load_best_model_at_end False \
37
+ --trainer_args.save_strategy '"no"' \
38
+ --trainer_args.logging_step 50 \
39
+ --trainer_args.report_to none --trainer_args.warmup_steps 100 \
40
+ --seed 44
41
+ date +"%F %T"
42
+
43
+ accelerate launch --dynamo_backend no --main_process_port 41353 -m src.testLlama \
44
+ --config_path $OMINI_CONFIG --trainer_args.output_dir "./expsBOFT/seed43/" --trainer_args.learning_rate=8e-4 \
45
+ --run_text "boft" --trainer_args.per_device_train_batch_size 32 \
46
+ --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
47
+ --trainer_args.gradient_accumulation_steps 2 \
48
+ --trainer_args.num_train_epochs 2.0 --data.dataset_split train \
49
+ --trainer_args.eval_strategy '"no"' \
50
+ --trainer_args.load_best_model_at_end False \
51
+ --trainer_args.save_strategy '"no"' \
52
+ --trainer_args.logging_step 50 \
53
+ --trainer_args.report_to none --trainer_args.warmup_steps 100 \
54
+ --seed 43
55
+ date +"%F %T"
56
+
57
+
58
+ accelerate launch --main_process_port 41353 -m src.testLlama \
59
+ --config_path $OMINI_CONFIG --trainer_args.output_dir "./expsOFT/seed43/" --trainer_args.learning_rate=8e-4 \
60
+ --run_text "oft" --trainer_args.per_device_train_batch_size 64 \
61
+ --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
62
+ --trainer_args.gradient_accumulation_steps 1 \
63
+ --trainer_args.num_train_epochs 2.0 --data.dataset_split train \
64
+ --trainer_args.eval_strategy '"no"' \
65
+ --trainer_args.load_best_model_at_end False \
66
+ --trainer_args.save_strategy '"no"' \
67
+ --trainer_args.logging_step 50 \
68
+ --trainer_args.report_to none --trainer_args.warmup_steps 100 \
69
+ --seed 43
70
+
71
+ date +"%F %T"
72
+
73
+ accelerate launch --main_process_port 41353 -m src.testLlama \
74
+ --config_path $OMINI_CONFIG --trainer_args.output_dir "./expsOFT/seed44/" --trainer_args.learning_rate=8e-4 \
75
+ --run_text "oft" --trainer_args.per_device_train_batch_size 64 \
76
+ --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
77
+ --trainer_args.gradient_accumulation_steps 1 \
78
+ --trainer_args.num_train_epochs 2.0 --data.dataset_split train \
79
+ --trainer_args.eval_strategy '"no"' \
80
+ --trainer_args.load_best_model_at_end False \
81
+ --trainer_args.save_strategy '"no"' \
82
+ --trainer_args.logging_step 50 \
83
+ --trainer_args.report_to none --trainer_args.warmup_steps 100 \
84
+ --seed 44
85
+
86
+ date +"%F %T"
87
+
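
repro.sh drives four runs of the same entry point (src.testLlama) with the commonsense.yaml config: the "boft" and "oft" variants, each with seeds 43 and 44, all at learning rate 8e-4, r=16, num_rotations=1, and an effective batch size of 64 (32 with gradient accumulation 2 for the BOFT runs, 64 with no accumulation for the OFT runs), with evaluation, checkpoint saving, and W&B reporting disabled.
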
nl_tasks/rpeft/__init__.py ADDED
@@ -0,0 +1,43 @@
1
+ # flake8: noqa
2
+ # There's no way to ignore "F401 '...' imported but unused" warnings in this
3
+ # module, but to preserve other warnings. So, don't check this module at all.
4
+
5
+ # coding=utf-8
6
+ # Copyright 2023-present the HuggingFace Inc. team.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+
20
+ __version__ = "0.0.1"
21
+
22
+ from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING,\
23
+ get_peft_config, get_peft_model #, PEFT_TYPE_TO_TUNER_MAPPING
24
+
25
+
26
+ from .rotation import (
27
+ RotationConfig,
28
+ RotationTuner,
29
+ )
30
+ from .utils import (
31
+ TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
32
+ PeftConfig,
33
+ PeftType,
34
+ PromptLearningConfig,
35
+ TaskType,
36
+ bloom_model_postprocess_past_key_value,
37
+ get_peft_model_state_dict,
38
+ prepare_model_for_int8_training,
39
+ set_peft_model_state_dict,
40
+ shift_tokens_right,
41
+ )
42
+
43
+ from .peft_model import PeftModel
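
For orientation, a minimal, hedged usage sketch of the public API re-exported above (assuming the package is importable as `rpeft`; `r` and `num_rotations` mirror the `--rotation_adapter_config.r` / `--rotation_adapter_config.num_rotations` flags used in the training scripts, while the base model name and `task_type` value are illustrative assumptions, not the exact experimental setup):

from transformers import AutoModelForCausalLM
from rpeft import RotationConfig, get_peft_model

# Wrap a base causal LM with the rotation adapter (sketch only).
base = AutoModelForCausalLM.from_pretrained("gpt2")
config = RotationConfig(r=16, num_rotations=1, task_type="CAUSAL_LM")
model = get_peft_model(base, config)
model.print_trainable_parameters()  # only the rotation parameters should be trainable
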
nl_tasks/rpeft/mapping.py ADDED
@@ -0,0 +1,273 @@
1
+ # coding=utf-8
2
+ # Original License:
3
+ # Copyright 2023-present the HuggingFace Inc. team.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from .peft_model import (
18
+ PeftModel,
19
+ PeftModelForCausalLM,
20
+ PeftModelForSeq2SeqLM,
21
+ PeftModelForSequenceClassification,
22
+ PeftModelForTokenClassification,
23
+ )
24
+ from .rotation import RotationConfig, RotationTuner
25
+ from .utils import PromptLearningConfig
26
+
27
+ from transformers import PreTrainedModel
28
+
29
+ MODEL_TYPE_TO_PEFT_MODEL_MAPPING = {
30
+ "SEQ_CLS": PeftModelForSequenceClassification,
31
+ "SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM,
32
+ "CAUSAL_LM": PeftModelForCausalLM,
33
+ "TOKEN_CLS": PeftModelForTokenClassification,
34
+ }
35
+
36
+ PEFT_TYPE_TO_CONFIG_MAPPING: dict = {
37
+ "ROTATION": RotationConfig,
38
+ }
39
+
40
+ PEFT_TYPE_TO_TUNER_MAPPING: dict = {
41
+ "ROTATION": RotationTuner
42
+ }
43
+
44
+ TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = {
45
+ "t5": ["q", "v"],
46
+ "mt5": ["q", "v"],
47
+ "bart": ["q_proj", "v_proj"],
48
+ "gpt2": ["c_attn"],
49
+ "bloom": ["query_key_value"],
50
+ "blip-2": ["q", "v", "q_proj", "v_proj"],
51
+ "opt": ["q_proj", "v_proj"],
52
+ "gptj": ["q_proj", "v_proj"],
53
+ "gpt_neox": ["query_key_value"],
54
+ "gpt_neo": ["q_proj", "v_proj"],
55
+ "bert": ["query", "value"],
56
+ "roberta": ["query", "value"],
57
+ "xlm-roberta": ["query", "value"],
58
+ "electra": ["query", "value"],
59
+ "deberta-v2": ["query_proj", "value_proj"],
60
+ "deberta": ["in_proj"],
61
+ "layoutlm": ["query", "value"],
62
+ "llama": ["q_proj", "v_proj"],
63
+ "chatglm": ["query_key_value"],
64
+ "gpt_bigcode": ["c_attn"],
65
+ "mpt": ["Wqkv"],
66
+ "RefinedWebModel": ["query_key_value"],
67
+ "RefinedWeb": ["query_key_value"],
68
+ "falcon": ["query_key_value"],
69
+ "btlm": ["c_proj", "c_attn"],
70
+ "codegen": ["qkv_proj"],
71
+ "mistral": ["q_proj", "v_proj"],
72
+ "mixtral": ["q_proj", "v_proj"],
73
+ "stablelm": ["q_proj", "v_proj"],
74
+ "phi": ["q_proj", "v_proj", "fc1", "fc2"],
75
+ "gemma": ["q_proj", "v_proj"],
76
+ }
77
+
78
+
79
+
80
+ def get_peft_config(config_dict):
81
+ """
82
+ Returns a Peft config object from a dictionary.
83
+
84
+ Args:
85
+ config_dict (`Dict[str, Any]`): Dictionary containing the configuration parameters.
86
+ """
87
+
88
+ return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict)
89
+
90
+
91
+ def _prepare_prompt_learning_config(peft_config, model_config):
92
+ if peft_config.num_layers is None:
93
+ if "num_hidden_layers" in model_config:
94
+ num_layers = model_config["num_hidden_layers"]
95
+ elif "num_layers" in model_config:
96
+ num_layers = model_config["num_layers"]
97
+ elif "n_layer" in model_config:
98
+ num_layers = model_config["n_layer"]
99
+ else:
100
+ raise ValueError("Please specify `num_layers` in `peft_config`")
101
+ peft_config.num_layers = num_layers
102
+
103
+ if peft_config.token_dim is None:
104
+ if "hidden_size" in model_config:
105
+ token_dim = model_config["hidden_size"]
106
+ elif "n_embd" in model_config:
107
+ token_dim = model_config["n_embd"]
108
+ elif "d_model" in model_config:
109
+ token_dim = model_config["d_model"]
110
+ else:
111
+ raise ValueError("Please specify `token_dim` in `peft_config`")
112
+ peft_config.token_dim = token_dim
113
+
114
+ if peft_config.num_attention_heads is None:
115
+ if "num_attention_heads" in model_config:
116
+ num_attention_heads = model_config["num_attention_heads"]
117
+ elif "n_head" in model_config:
118
+ num_attention_heads = model_config["n_head"]
119
+ elif "num_heads" in model_config:
120
+ num_attention_heads = model_config["num_heads"]
121
+ elif "encoder_attention_heads" in model_config:
122
+ num_attention_heads = model_config["encoder_attention_heads"]
123
+ else:
124
+ raise ValueError("Please specify `num_attention_heads` in `peft_config`")
125
+ peft_config.num_attention_heads = num_attention_heads
126
+
127
+ if getattr(peft_config, "encoder_hidden_size", None) is None:
128
+ setattr(peft_config, "encoder_hidden_size", token_dim)
129
+
130
+ return peft_config
131
+
132
+
133
+ def _prepare_lora_config(peft_config, model_config):
134
+ if peft_config.target_modules is None:
135
+ if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING:
136
+ raise ValueError("Please specify `target_modules` in `peft_config`")
137
+ peft_config.target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]]
138
+ if len(peft_config.target_modules) == 1:
139
+ peft_config.fan_in_fan_out = True
140
+ peft_config.enable_lora = [True, False, True]
141
+ if peft_config.inference_mode:
142
+ peft_config.merge_weights = True
143
+ return peft_config
144
+
145
+
146
+
147
+ def get_peft_model(model, peft_config,
148
+ adapter_name: str = "default"):
149
+ """
150
+ Returns a Peft model object from a model and a config.
151
+
152
+ Args:
153
+ model ([`transformers.PreTrainedModel`]): Model to be wrapped.
154
+ peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model.
155
+ """
156
+ model_config = model.config.to_dict()
157
+ peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)
158
+ if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys():
159
+ if peft_config.peft_type in ("LORA", "QUANTA"):
160
+ peft_config = _prepare_lora_config(peft_config, model_config)
161
+ return PeftModel(model, peft_config)
162
+ if not isinstance(peft_config, PromptLearningConfig):
163
+ if peft_config.peft_type in ("LORA", "QUANTA"):
164
+ peft_config = _prepare_lora_config(peft_config, model_config)
165
+ else:
166
+ peft_config = _prepare_prompt_learning_config(peft_config, model_config)
167
+ # assert False
168
+ return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](
169
+ model,
170
+ peft_config,
171
+ adapter_name=adapter_name,
172
+ )
173
+
174
+
175
+ # def get_peft_model(
176
+ # model: PreTrainedModel,
177
+ # peft_config,
178
+ # adapter_name: str = "default",
179
+ # mixed: bool = False,
180
+ # autocast_adapter_dtype: bool = True,
181
+ # revision: Optional[str] = None,
182
+ # low_cpu_mem_usage: bool = False,
183
+ # ) -> PeftModel | PeftMixedModel:
184
+ # """
185
+ # Returns a Peft model object from a model and a config, where the model will be modified in-place.
186
+
187
+ # Args:
188
+ # model ([`transformers.PreTrainedModel`]):
189
+ # Model to be wrapped.
190
+ # peft_config ([`PeftConfig`]):
191
+ # Configuration object containing the parameters of the Peft model.
192
+ # adapter_name (`str`, `optional`, defaults to `"default"`):
193
+ # The name of the adapter to be injected, if not provided, the default adapter name is used ("default").
194
+ # mixed (`bool`, `optional`, defaults to `False`):
195
+ # Whether to allow mixing different (compatible) adapter types.
196
+ # autocast_adapter_dtype (`bool`, *optional*):
197
+ # Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights
198
+ # using float16 or bfloat16 to float32, as this is typically required for stable training, and only affect
199
+ # select PEFT tuners.
200
+ # revision (`str`, `optional`, defaults to `main`):
201
+ # The revision of the base model. If this isn't set, the saved peft model will load the `main` revision for
202
+ # the base model
203
+ # low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
204
+ # Create empty adapter weights on meta device. Useful to speed up the loading process. Leave this setting as
205
+ # False if you intend on training the model, unless the adapter weights will be replaced by different weights
206
+ # before training starts.
207
+ # """
208
+ # model_config = BaseTuner.get_model_config(model)
209
+ # old_name = peft_config.base_model_name_or_path
210
+ # new_name = model.__dict__.get("name_or_path", None)
211
+ # peft_config.base_model_name_or_path = new_name
212
+
213
+ # # Especially in notebook environments there could be a case that a user wants to experiment with different
214
+ # # configuration values. However, it is likely that there won't be any changes for new configs on an already
215
+ # # initialized PEFT model. The best we can do is warn the user about it.
216
+ # if any(isinstance(module, BaseTunerLayer) for module in model.modules()):
217
+ # warnings.warn(
218
+ # "You are trying to modify a model with PEFT for a second time. If you want to reload the model with a "
219
+ # "different config, make sure to call `.unload()` before."
220
+ # )
221
+
222
+ # if (old_name is not None) and (old_name != new_name):
223
+ # warnings.warn(
224
+ # f"The PEFT config's `base_model_name_or_path` was renamed from '{old_name}' to '{new_name}'. "
225
+ # "Please ensure that the correct base model is loaded when loading this checkpoint."
226
+ # )
227
+
228
+ # if revision is not None:
229
+ # if peft_config.revision is not None and peft_config.revision != revision:
230
+ # warnings.warn(
231
+ # f"peft config has already set base model revision to {peft_config.revision}, overwriting with revision {revision}"
232
+ # )
233
+ # peft_config.revision = revision
234
+
235
+ # if (
236
+ # (isinstance(peft_config, PEFT_TYPE_TO_CONFIG_MAPPING["LORA"]))
237
+ # and (peft_config.init_lora_weights == "eva")
238
+ # and not low_cpu_mem_usage
239
+ # ):
240
+ # warnings.warn(
241
+ # "lora with eva initialization used with low_cpu_mem_usage=False. "
242
+ # "Setting low_cpu_mem_usage=True can improve the maximum batch size possible for eva initialization."
243
+ # )
244
+
245
+ # prefix = PEFT_TYPE_TO_PREFIX_MAPPING.get(peft_config.peft_type)
246
+ # if prefix and adapter_name in prefix:
247
+ # warnings.warn(
248
+ # f"Adapter name '{adapter_name}' should not be contained in the prefix '{prefix}'. "
249
+ # "This may lead to reinitialization of the adapter weights during loading."
250
+ # )
251
+
252
+ # if mixed:
253
+ # # note: PeftMixedModel does not support autocast_adapter_dtype, so don't pass it
254
+ # return PeftMixedModel(model, peft_config, adapter_name=adapter_name)
255
+
256
+ # # We explicitly exclude prompt learning here since prompt learning is specific to the task and needs special
257
+ # # handling in the PEFT model's forward method.
258
+ # if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not peft_config.is_prompt_learning:
259
+ # return PeftModel(
260
+ # model,
261
+ # peft_config,
262
+ # adapter_name=adapter_name,
263
+ # autocast_adapter_dtype=autocast_adapter_dtype,
264
+ # low_cpu_mem_usage=low_cpu_mem_usage,
265
+ # )
266
+
267
+ # return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](
268
+ # model,
269
+ # peft_config,
270
+ # adapter_name=adapter_name,
271
+ # autocast_adapter_dtype=autocast_adapter_dtype,
272
+ # low_cpu_mem_usage=low_cpu_mem_usage,
273
+ # )
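
A small sketch of how `get_peft_config` and `get_peft_model` above compose; only `"ROTATION"` is registered in `PEFT_TYPE_TO_CONFIG_MAPPING` here, and the extra dict keys (`r`, `num_rotations`) are assumed to be accepted by `RotationConfig`, mirroring the training-script flags:

from transformers import AutoModelForSequenceClassification
from rpeft import get_peft_config, get_peft_model

config_dict = {
    "peft_type": "ROTATION",   # must be a key of PEFT_TYPE_TO_CONFIG_MAPPING
    "task_type": "SEQ_CLS",    # routes to PeftModelForSequenceClassification
    "r": 16,                   # assumed RotationConfig field (see --rotation_adapter_config.r)
    "num_rotations": 1,        # assumed RotationConfig field
}
peft_config = get_peft_config(config_dict)

base = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=2)
model = get_peft_model(base, peft_config)  # wraps the base model in a RotationTuner via PeftModel
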
nl_tasks/rpeft/peft_model.py ADDED
@@ -0,0 +1,922 @@
1
+ # coding=utf-8
2
+ # Original License:
3
+ # Copyright 2023-present the HuggingFace Inc. team.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import inspect
18
+ import os
19
+ import warnings
20
+ from contextlib import contextmanager
21
+
22
+ import torch
23
+ from accelerate import dispatch_model, infer_auto_device_map
24
+ from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules
25
+ from accelerate.utils import get_balanced_memory
26
+ from huggingface_hub import hf_hub_download
27
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28
+ from transformers import PreTrainedModel
29
+ from transformers.modeling_outputs import SequenceClassifierOutput, TokenClassifierOutput
30
+ from transformers.utils import PushToHubMixin
31
+
32
+ import packaging.version
33
+ import transformers
34
+ from typing import Any, Literal, Optional, Union
35
+
36
+ from .rotation import RotationTuner
37
+ from .utils import (
38
+ TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
39
+ WEIGHTS_NAME,
40
+ PeftConfig,
41
+ PeftType,
42
+ PromptLearningConfig,
43
+ TaskType,
44
+ _set_trainable,
45
+ get_peft_model_state_dict,
46
+ set_peft_model_state_dict,
47
+ shift_tokens_right,
48
+ )
49
+
50
+ class PeftModel(PushToHubMixin, torch.nn.Module):
51
+ """
52
+
53
+ """
54
+
55
+ def __init__(self, model, peft_config: PeftConfig, adapter_name: str = "default"):
56
+ super().__init__()
57
+ self.peft_config = peft_config
58
+ self.base_model = model
59
+ self.config = self.base_model.config
60
+ self.modules_to_save = None
61
+ self.active_adapter = adapter_name
62
+ ##### Difference from upstream PEFT: active_adapter is stored here but not used further
63
+ if isinstance(self.peft_config, PromptLearningConfig):
64
+ self._setup_prompt_encoder()
65
+ else:
66
+ if self.peft_config.peft_type == PeftType.ROTATION:
67
+ self.base_model = RotationTuner(model, {adapter_name: peft_config}, adapter_name)
68
+ if getattr(self.peft_config, "modules_to_save", None) is not None:
69
+ self.modules_to_save = self.peft_config.modules_to_save
70
+ _set_trainable(self)
71
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
72
+ self.base_model_torch_dtype = getattr(model, "dtype", None)
73
+
74
+ def save_pretrained(self, save_directory, **kwargs):
75
+ r"""
76
+ Args:
77
+ This function saves the adapter model and the adapter configuration files to a directory, so that it can be
78
+ re-loaded using the `LoraModel.from_pretrained` class method, and also used by the `LoraModel.push_to_hub`
79
+ method.
80
+ save_directory (`str`):
81
+ Directory where the adapter model and configuration files will be saved (will be created if it does not
82
+ exist).
83
+ **kwargs:
84
+ Additional keyword arguments passed along to the `push_to_hub` method.
85
+ """
86
+ if os.path.isfile(save_directory):
87
+ raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
88
+ os.makedirs(save_directory, exist_ok=True)
89
+
90
+ # save only the trainable weights
91
+ output_state_dict = get_peft_model_state_dict(self, kwargs.get("state_dict", None))
92
+ torch.save(output_state_dict, os.path.join(save_directory, WEIGHTS_NAME))
93
+
94
+ # save the config and change the inference mode to `True`
95
+ if self.peft_config.base_model_name_or_path is None:
96
+ self.peft_config.base_model_name_or_path = (
97
+ self.base_model.__dict__.get("name_or_path", None)
98
+ if isinstance(self.peft_config, PromptLearningConfig)
99
+ else self.base_model.model.__dict__.get("name_or_path", None)
100
+ )
101
+ inference_mode = self.peft_config.inference_mode
102
+ self.peft_config.inference_mode = True
103
+ self.peft_config.save_pretrained(save_directory)
104
+ self.peft_config.inference_mode = inference_mode
105
+
106
+ @classmethod
107
+ def from_pretrained(cls, model, model_id, is_trainable = False, **kwargs):
108
+ r"""
109
+ Args:
110
+ Instantiate a `LoraModel` from a pretrained Lora configuration and weights.
111
+ model (`transformers.PreTrainedModel`):
112
+ The model to be adapted. The model should be initialized with the `from_pretrained` method. from
113
+ `transformers` library.
114
+ model_id (`str`):
115
+ The name of the Lora configuration to use. Can be either:
116
+ - A string, the `model id` of a Lora configuration hosted inside a model repo on
117
+ huggingface Hub
118
+ - A path to a directory containing a Lora configuration file saved using the
119
+ `save_pretrained` method, e.g., ``./my_lora_config_directory/``.
120
+ """
121
+ from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING
122
+ # load the config
123
+ config = PEFT_TYPE_TO_CONFIG_MAPPING[PeftConfig.from_pretrained(model_id).peft_type].from_pretrained(model_id)
124
+ config.inference_mode = not is_trainable
125
+
126
+ if getattr(model, "hf_device_map", None) is not None:
127
+ remove_hook_from_submodules(model)
128
+
129
+ if config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys():
130
+ model = cls(model, config)
131
+ else:
132
+ model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](model, config)
133
+
134
+ # load weights if any
135
+ if os.path.exists(os.path.join(model_id, WEIGHTS_NAME)):
136
+ filename = os.path.join(model_id, WEIGHTS_NAME)
137
+ else:
138
+ try:
139
+ filename = hf_hub_download(model_id, WEIGHTS_NAME)
140
+ except: # noqa
141
+ raise ValueError(
142
+ f"Can't find weights for {model_id} in {model_id} or in the Hugging Face Hub. "
143
+ f"Please check that the file {WEIGHTS_NAME} is present at {model_id}."
144
+ )
145
+
146
+ adapters_weights = torch.load(
147
+ filename, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu")
148
+ )
149
+ # load the weights into the model
150
+ model = set_peft_model_state_dict(model, adapters_weights)
151
+ if getattr(model, "hf_device_map", None) is not None:
152
+ device_map = kwargs.get("device_map", "auto")
153
+ max_memory = kwargs.get("max_memory", None)
154
+ no_split_module_classes = model._no_split_modules
155
+ if device_map != "sequential":
156
+ max_memory = get_balanced_memory(
157
+ model,
158
+ max_memory=max_memory,
159
+ no_split_module_classes=no_split_module_classes,
160
+ low_zero=(device_map == "balanced_low_0"),
161
+ )
162
+ if isinstance(device_map, str):
163
+ device_map = infer_auto_device_map(
164
+ model, max_memory=max_memory, no_split_module_classes=no_split_module_classes
165
+ )
166
+ model = dispatch_model(model, device_map=device_map)
167
+ hook = AlignDevicesHook(io_same_device=True)
168
+ if model.peft_config.peft_type == PeftType.LORA or model.peft_config.peft_type == PeftType.BOTTLENECK \
169
+ or model.peft_config.peft_type == "ROTATION":
170
+ add_hook_to_module(model.base_model.model, hook)
171
+ else:
172
+ remove_hook_from_submodules(model.prompt_encoder)
173
+ add_hook_to_module(model.base_model, hook)
174
+ # if model.peft_config.is_prompt_learning:
175
+ # remove_hook_from_submodules(model.prompt_encoder)
176
+ # add_hook_to_module(model.base_model, hook)
177
+ return model
178
+
179
+ def _setup_prompt_encoder(self):
180
+ transformer_backbone = None
181
+ for name, module in self.base_model.named_children():
182
+ for param in module.parameters():
183
+ param.requires_grad = False
184
+ if isinstance(module, PreTrainedModel):
185
+ # Make sure to freeze the Transformers model
186
+ if transformer_backbone is None:
187
+ transformer_backbone = module
188
+ self.transformer_backbone_name = name
189
+
190
+ if self.peft_config.num_transformer_submodules is None:
191
+ self.peft_config.num_transformer_submodules = (
192
+ 2 if self.peft_config.task_type == TaskType.SEQ_2_SEQ_LM else 1
193
+ )
194
+
195
+ for named_param, value in list(transformer_backbone.named_parameters()):
196
+ if value.shape[0] == self.base_model.config.vocab_size:
197
+ self.word_embeddings = transformer_backbone.get_submodule(named_param.replace(".weight", ""))
198
+ break
199
+
200
+ if self.peft_config.peft_type == PeftType.PROMPT_TUNING:
201
+ prompt_encoder = PromptEmbedding(self.peft_config, self.word_embeddings)
202
+ elif self.peft_config.peft_type == PeftType.P_TUNING:
203
+ prompt_encoder = PromptEncoder(self.peft_config)
204
+ elif self.peft_config.peft_type == PeftType.PREFIX_TUNING:
205
+ prompt_encoder = PrefixEncoder(self.peft_config)
206
+ else:
207
+ raise ValueError("Not supported")
208
+ self.prompt_encoder = prompt_encoder
209
+ self.prompt_tokens = torch.arange(
210
+ self.peft_config.num_virtual_tokens * self.peft_config.num_transformer_submodules
211
+ ).long()
212
+
213
+ def get_prompt_embedding_to_save(self):
214
+ """
215
+ Returns the prompt embedding to save when saving the model. Only applicable when `peft_config.peft_type !=
216
+ PeftType.LORA`.
217
+ """
218
+ prompt_tokens = self.prompt_tokens.unsqueeze(0).expand(1, -1).to(self.device)
219
+ if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
220
+ prompt_tokens = prompt_tokens[:, : self.peft_config.num_virtual_tokens]
221
+ prompt_embeddings = self.prompt_encoder(prompt_tokens)
222
+ return prompt_embeddings[0].detach().cpu()
223
+
224
+ def get_prompt(self, batch_size):
225
+ """
226
+ Returns the virtual prompts to use for Peft. Only applicable when `peft_config.peft_type != PeftType.LORA`.
227
+ """
228
+ prompt_tokens = self.prompt_tokens.unsqueeze(0).expand(batch_size, -1).to(self.device)
229
+ if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
230
+ prompt_tokens = prompt_tokens[:, : self.peft_config.num_virtual_tokens]
231
+ if self.peft_config.inference_mode:
232
+ past_key_values = self.prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)
233
+ else:
234
+ past_key_values = self.prompt_encoder(prompt_tokens)
235
+ past_key_values = past_key_values.view(
236
+ batch_size,
237
+ self.peft_config.num_virtual_tokens,
238
+ self.peft_config.num_layers * 2,
239
+ self.peft_config.num_attention_heads,
240
+ self.peft_config.token_dim // self.peft_config.num_attention_heads,
241
+ )
242
+ if self.peft_config.num_transformer_submodules == 2:
243
+ past_key_values = torch.cat([past_key_values, past_key_values], dim=2)
244
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(
245
+ self.peft_config.num_transformer_submodules * 2
246
+ )
247
+ if TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(self.config.model_type, None) is not None:
248
+ post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING[self.config.model_type]
249
+ past_key_values = post_process_fn(past_key_values)
250
+ return past_key_values
251
+ else:
252
+ if self.peft_config.inference_mode:
253
+ prompts = self.prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)
254
+ else:
255
+ prompts = self.prompt_encoder(prompt_tokens)
256
+ return prompts
257
+
258
+ def print_trainable_parameters(self):
259
+ """
260
+ Prints the number of trainable parameters in the model.
261
+ """
262
+ trainable_params = 0
263
+ all_param = 0
264
+ for _, param in self.named_parameters():
265
+ num_params = param.numel()
266
+ # if using DS Zero 3 and the weights are initialized empty
267
+ if num_params == 0 and hasattr(param, "ds_numel"):
268
+ num_params = param.ds_numel
269
+
270
+ all_param += num_params
271
+ if param.requires_grad:
272
+ trainable_params += num_params
273
+ print(
274
+ f"trainable params: {trainable_params:,} || all params: {all_param:,} || trainable: {100 * trainable_params / all_param}%"
275
+ )
276
+
277
+ def __getattr__(self, name: str):
278
+ """Forward missing attributes to the wrapped module."""
279
+ try:
280
+ return super().__getattr__(name) # defer to nn.Module's logic
281
+ except AttributeError:
282
+ return getattr(self.base_model, name)
283
+
284
+ def forward(self, *args, **kwargs):
285
+ """
286
+ Forward pass of the model.
287
+ """
288
+ return self.get_base_model()(*args, **kwargs)
289
+
290
+ @contextmanager
291
+ def disable_adapter(self):
292
+ """
293
+ Disables the adapter module.
294
+ """
295
+ if isinstance(self.peft_config, PromptLearningConfig):
296
+ old_forward = self.forward
297
+ self.forward = self.base_model.forward
298
+ else:
299
+ self.base_model.disable_adapter_layers()
300
+ yield
301
+ if isinstance(self.peft_config, PromptLearningConfig):
302
+ self.forward = old_forward
303
+ else:
304
+ self.base_model.enable_adapter_layers()
305
+
306
+ def get_base_model(self):
307
+ """
308
+ Returns the base model.
309
+ """
310
+ return self.base_model if isinstance(self.peft_config, PromptLearningConfig) else self.base_model.model
311
+
312
+
313
+
314
+ class PeftModelForSequenceClassification(PeftModel):
315
+ """
316
+ """
317
+
318
+ def __init__(self, model, peft_config: PeftConfig, adapter_name: str = "default"):
319
+ super().__init__(model, peft_config, adapter_name)
320
+ self.modules_to_save = ["classifier", "score", "pooler"]
321
+
322
+ # for name, _ in self.base_model.named_children():
323
+ # if any(module_name in name for module_name in self.modules_to_save):
324
+ # self.cls_layer_name = name
325
+ # break
326
+ user_modules = getattr(peft_config, "modules_to_save", None) or []
327
+ default_modules = ["classifier", "score"]
328
+ self.modules_to_save = list(set(user_modules + default_modules))
329
+
330
+ # from .rotation import RotationTuner  # import to check the adapter type below
331
+ if isinstance(self.base_model, RotationTuner):
332
+ real_model = self.base_model.model
333
+ else:
334
+ real_model = self.base_model
335
+
336
+ # 3. Find the actual classifier layer name
337
+ for name, _ in real_model.named_children():
338
+ if any(module_name in name for module_name in self.modules_to_save):
339
+ self.cls_layer_name = name
340
+
341
+ # # to make sure classifier layer is trainable
342
+ _set_trainable(self)
343
+
344
+ def forward(
345
+ self,
346
+ input_ids=None,
347
+ attention_mask=None,
348
+ inputs_embeds=None,
349
+ labels=None,
350
+ output_attentions=None,
351
+ output_hidden_states=None,
352
+ return_dict=None,
353
+ **kwargs,
354
+ ):
355
+ if "num_items_in_batch" in kwargs:
356
+ kwargs.pop("num_items_in_batch")
357
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
358
+
359
+ if not isinstance(self.peft_config, PromptLearningConfig):
360
+ return self.base_model(
361
+ input_ids=input_ids,
362
+ attention_mask=attention_mask,
363
+ inputs_embeds=inputs_embeds,
364
+ labels=labels,
365
+ output_attentions=output_attentions,
366
+ output_hidden_states=output_hidden_states,
367
+ return_dict=return_dict,
368
+ **kwargs,
369
+ )
370
+
371
+ batch_size = input_ids.shape[0]
372
+ if attention_mask is not None:
373
+ # concat prompt attention mask
374
+ prefix_attention_mask = torch.ones(batch_size, self.peft_config.num_virtual_tokens).to(self.device)
375
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
376
+ if kwargs.get("position_ids", None) is not None:
377
+ warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
378
+ kwargs["position_ids"] = None
379
+ kwargs.update(
380
+ {
381
+ "attention_mask": attention_mask,
382
+ "labels": labels,
383
+ "output_attentions": output_attentions,
384
+ "output_hidden_states": output_hidden_states,
385
+ "return_dict": return_dict,
386
+ }
387
+ )
388
+
389
+ if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
390
+ return self._prefix_tuning_forward(input_ids=input_ids, **kwargs)
391
+ else:
392
+ if kwargs.get("token_type_ids", None) is not None:
393
+ kwargs["token_type_ids"] = torch.cat(
394
+ (
395
+ torch.zeros(batch_size, self.peft_config.num_virtual_tokens).to(self.device),
396
+ kwargs["token_type_ids"],
397
+ ),
398
+ dim=1,
399
+ ).long()
400
+ if inputs_embeds is None:
401
+ inputs_embeds = self.word_embeddings(input_ids)
402
+ prompts = self.get_prompt(batch_size=batch_size)
403
+ prompts = prompts.to(inputs_embeds.dtype)
404
+ inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
405
+ return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
406
+
407
+ def _prefix_tuning_forward(
408
+ self,
409
+ input_ids=None,
410
+ attention_mask=None,
411
+ inputs_embeds=None,
412
+ labels=None,
413
+ output_attentions=None,
414
+ output_hidden_states=None,
415
+ return_dict=None,
416
+ **kwargs,
417
+ ):
418
+ batch_size = input_ids.shape[0]
419
+ past_key_values = self.get_prompt(batch_size)
420
+ fwd_params = list(inspect.signature(self.base_model.forward).parameters.keys())
421
+ kwargs.update(
422
+ {
423
+ "input_ids": input_ids,
424
+ "attention_mask": attention_mask,
425
+ "inputs_embeds": inputs_embeds,
426
+ "output_attentions": output_attentions,
427
+ "output_hidden_states": output_hidden_states,
428
+ "return_dict": return_dict,
429
+ "past_key_values": past_key_values,
430
+ }
431
+ )
432
+ if "past_key_values" in fwd_params:
433
+ return self.base_model(labels=labels, **kwargs)
434
+ else:
435
+ transformer_backbone_name = self.base_model.get_submodule(self.transformer_backbone_name)
436
+ fwd_params = list(inspect.signature(transformer_backbone_name.forward).parameters.keys())
437
+ if "past_key_values" not in fwd_params:
438
+ raise ValueError("Model does not support past key values which are required for prefix tuning.")
439
+ outputs = transformer_backbone_name(**kwargs)
440
+ pooled_output = outputs[1] if len(outputs) > 1 else outputs[0]
441
+ if "dropout" in [name for name, _ in list(self.base_model.named_children())]:
442
+ pooled_output = self.base_model.dropout(pooled_output)
443
+ logits = self.base_model.get_submodule(self.cls_layer_name)(pooled_output)
444
+
445
+ loss = None
446
+ if labels is not None:
447
+ if self.config.problem_type is None:
448
+ if self.base_model.num_labels == 1:
449
+ self.config.problem_type = "regression"
450
+ elif self.base_model.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
451
+ self.config.problem_type = "single_label_classification"
452
+ else:
453
+ self.config.problem_type = "multi_label_classification"
454
+
455
+ if self.config.problem_type == "regression":
456
+ loss_fct = MSELoss()
457
+ if self.base_model.num_labels == 1:
458
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
459
+ else:
460
+ loss = loss_fct(logits, labels)
461
+ elif self.config.problem_type == "single_label_classification":
462
+ loss_fct = CrossEntropyLoss()
463
+ loss = loss_fct(logits.view(-1, self.base_model.num_labels), labels.view(-1))
464
+ elif self.config.problem_type == "multi_label_classification":
465
+ loss_fct = BCEWithLogitsLoss()
466
+ loss = loss_fct(logits, labels)
467
+ if not return_dict:
468
+ output = (logits,) + outputs[2:]
469
+ return ((loss,) + output) if loss is not None else output
470
+
471
+ return SequenceClassifierOutput(
472
+ loss=loss,
473
+ logits=logits,
474
+ hidden_states=outputs.hidden_states,
475
+ attentions=outputs.attentions,
476
+ )
477
+
478
+
479
+ class PeftModelForCausalLM(PeftModel):
480
+ """
481
+ Peft model for Causal LM
482
+
483
+ Args:
484
+ model ([`PreTrainedModel`]): Base transformer model
485
+ peft_config ([`PeftConfig`]): Peft config.
486
+
487
+
488
+ Example::
489
+
490
+ >>> from transformers import AutoModelForCausalLM
+ >>> from peft_local_tensor import PeftModelForCausalLM, get_peft_config
+ >>> config = {
+ ...     'peft_type': 'PREFIX_TUNING', 'task_type': 'CAUSAL_LM', 'inference_mode': False, 'num_virtual_tokens': 20,
+ ...     'token_dim': 1280, 'num_transformer_submodules': 1, 'num_attention_heads': 20, 'num_layers': 36,
+ ...     'encoder_hidden_size': 1280, 'prefix_projection': False, 'postprocess_past_key_value_function': None
+ ... }
+ >>> peft_config = get_peft_config(config)
+ >>> model = AutoModelForCausalLM.from_pretrained("gpt2-large")
+ >>> peft_model = PeftModelForCausalLM(model, peft_config)
+ >>> peft_model.print_trainable_parameters()
+ trainable params: 1843200 || all params: 775873280 || trainable%: 0.23756456724479544
499
+ """
500
+
501
+ def __init__(self, model, peft_config: PeftConfig, adapter_name: str = "default"):
502
+ self.prompt_encoder = None #### set to None before super().__init__; exact reason unclear in the original code
503
+ self.modules_to_save = None
504
+ super().__init__(model, peft_config, adapter_name)
505
+ self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation
506
+
507
+ def forward(
508
+ self,
509
+ input_ids=None,
510
+ attention_mask=None,
511
+ inputs_embeds=None,
512
+ labels=None,
513
+ output_attentions=None,
514
+ output_hidden_states=None,
515
+ return_dict=None,
516
+ **kwargs,
517
+ ):
518
+ if not isinstance(self.peft_config, PromptLearningConfig):
519
+ return self.base_model(
520
+ input_ids=input_ids,
521
+ attention_mask=attention_mask,
522
+ inputs_embeds=inputs_embeds,
523
+ labels=labels,
524
+ output_attentions=output_attentions,
525
+ output_hidden_states=output_hidden_states,
526
+ return_dict=return_dict,
527
+ **kwargs,
528
+ )
529
+
530
+ batch_size = input_ids.shape[0]
531
+ if attention_mask is not None:
532
+ # concat prompt attention mask
533
+ prefix_attention_mask = torch.ones(batch_size, self.peft_config.num_virtual_tokens).to(self.device)
534
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
535
+
536
+ if kwargs.get("position_ids", None) is not None:
537
+ warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
538
+ kwargs["position_ids"] = None
539
+ if kwargs.get("token_type_ids", None) is not None:
540
+ warnings.warn("Token type ids are not supported for parameter efficient tuning. Ignoring token type ids")
541
+ kwargs["token_type_ids"] = None
542
+ kwargs.update(
543
+ {
544
+ "attention_mask": attention_mask,
545
+ "labels": labels,
546
+ "output_attentions": output_attentions,
547
+ "output_hidden_states": output_hidden_states,
548
+ "return_dict": return_dict,
549
+ }
550
+ )
551
+
552
+ if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
553
+ past_key_values = self.get_prompt(batch_size)
554
+ return self.base_model(input_ids=input_ids, past_key_values=past_key_values, **kwargs)
555
+ else:
556
+ if inputs_embeds is None:
557
+ inputs_embeds = self.word_embeddings(input_ids)
558
+ # concat prompt labels
559
+ if labels is not None:
560
+ prefix_labels = torch.full((batch_size, self.peft_config.num_virtual_tokens), -100).to(self.device)
561
+ kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1)
562
+ prompts = self.get_prompt(batch_size=batch_size)
563
+ prompts = prompts.to(inputs_embeds.dtype)
564
+ inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
565
+ return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
566
+
567
+ def generate(self, **kwargs):
568
+ self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
569
+ try:
570
+ if not isinstance(self.peft_config, PromptLearningConfig):
571
+ outputs = self.base_model.generate(**kwargs)
572
+ else:
573
+ if "input_ids" not in kwargs:
574
+ raise ValueError("input_ids must be provided for Peft model generation")
575
+ if kwargs.get("attention_mask", None) is not None:
576
+ # concat prompt attention mask
577
+ prefix_attention_mask = torch.ones(
578
+ kwargs["input_ids"].shape[0], self.peft_config.num_virtual_tokens
579
+ ).to(kwargs["input_ids"].device)
580
+ kwargs["attention_mask"] = torch.cat((prefix_attention_mask, kwargs["attention_mask"]), dim=1)
581
+
582
+ if kwargs.get("position_ids", None) is not None:
583
+ warnings.warn(
584
+ "Position ids are not supported for parameter efficient tuning. Ignoring position ids."
585
+ )
586
+ kwargs["position_ids"] = None
587
+ if kwargs.get("token_type_ids", None) is not None:
588
+ warnings.warn(
589
+ "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
590
+ )
591
+ kwargs["token_type_ids"] = None
592
+
593
+ outputs = self.base_model.generate(**kwargs)
594
+ except:
595
+ self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
596
+ raise
597
+ else:
598
+ self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
599
+ return outputs
600
+
601
+ def prepare_inputs_for_generation(self, *args, **kwargs):
602
+ model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
603
+ if isinstance(self.peft_config, PromptLearningConfig):
604
+ if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
605
+ prefix_attention_mask = torch.ones(
606
+ model_kwargs["input_ids"].shape[0], self.peft_config.num_virtual_tokens
607
+ ).to(model_kwargs["input_ids"].device)
608
+ model_kwargs["attention_mask"] = torch.cat(
609
+ (prefix_attention_mask, model_kwargs["attention_mask"]), dim=1
610
+ )
611
+
612
+ if model_kwargs["past_key_values"] is None and self.peft_config.peft_type == PeftType.PREFIX_TUNING:
613
+ past_key_values = self.get_prompt(batch_size=model_kwargs["input_ids"].shape[0])
614
+ if self.base_model_torch_dtype is not None:
615
+ # handle the case for Bloom where it outputs tuple of tuples
616
+ if isinstance(past_key_values[0], tuple):
617
+ past_key_values = tuple(
618
+ tuple(
619
+ past_key_value.to(self.base_model_torch_dtype)
620
+ for past_key_value in past_key_value_tuple
621
+ )
622
+ for past_key_value_tuple in past_key_values
623
+ )
624
+ else:
625
+ past_key_values = tuple(
626
+ past_key_value.to(self.base_model_torch_dtype) for past_key_value in past_key_values
627
+ )
628
+
629
+ model_kwargs["past_key_values"] = past_key_values
630
+ else:
631
+ if model_kwargs["past_key_values"] is None:
632
+ inputs_embeds = self.word_embeddings(model_kwargs["input_ids"])
633
+ prompts = self.get_prompt(batch_size=model_kwargs["input_ids"].shape[0])
634
+ prompts = prompts.to(inputs_embeds.dtype)
635
+ model_kwargs["inputs_embeds"] = torch.cat((prompts, inputs_embeds), dim=1)
636
+ model_kwargs["input_ids"] = None
637
+
638
+ return model_kwargs
639
+
640
+
641
+ class PeftModelForSeq2SeqLM(PeftModel):
642
+ """
643
+
644
+ """
645
+
646
+ def __init__(self, model, peft_config: PeftConfig):
647
+ super().__init__(model, peft_config)
648
+ self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation
649
+ self.base_model_prepare_encoder_decoder_kwargs_for_generation = (
650
+ self.base_model._prepare_encoder_decoder_kwargs_for_generation
651
+ )
652
+
653
+ def forward(
654
+ self,
655
+ input_ids=None,
656
+ attention_mask=None,
657
+ inputs_embeds=None,
658
+ decoder_input_ids=None,
659
+ decoder_attention_mask=None,
660
+ decoder_inputs_embeds=None,
661
+ labels=None,
662
+ output_attentions=None,
663
+ output_hidden_states=None,
664
+ return_dict=None,
665
+ **kwargs,
666
+ ):
667
+ if not isinstance(self.peft_config, PromptLearningConfig):
668
+ return self.base_model(
669
+ input_ids=input_ids,
670
+ attention_mask=attention_mask,
671
+ inputs_embeds=inputs_embeds,
672
+ decoder_input_ids=decoder_input_ids,
673
+ decoder_attention_mask=decoder_attention_mask,
674
+ decoder_inputs_embeds=decoder_inputs_embeds,
675
+ labels=labels,
676
+ output_attentions=output_attentions,
677
+ output_hidden_states=output_hidden_states,
678
+ return_dict=return_dict,
679
+ **kwargs,
680
+ )
681
+
682
+ batch_size = input_ids.shape[0]
683
+ if decoder_attention_mask is not None:
684
+ # concat prompt attention mask
685
+ prefix_attention_mask = torch.ones(batch_size, self.peft_config.num_virtual_tokens).to(self.device)
686
+ decoder_attention_mask = torch.cat((prefix_attention_mask, decoder_attention_mask), dim=1)
687
+
688
+ if kwargs.get("position_ids", None) is not None:
689
+ warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
690
+ kwargs["position_ids"] = None
691
+ if kwargs.get("token_type_ids", None) is not None:
692
+ warnings.warn("Token type ids are not supported for parameter efficient tuning. Ignoring token type ids")
693
+ kwargs["token_type_ids"] = None
694
+ kwargs.update(
695
+ {
696
+ "attention_mask": attention_mask,
697
+ "decoder_attention_mask": decoder_attention_mask,
698
+ "labels": labels,
699
+ "output_attentions": output_attentions,
700
+ "output_hidden_states": output_hidden_states,
701
+ "return_dict": return_dict,
702
+ }
703
+ )
704
+
705
+ if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
706
+ past_key_values = self.get_prompt(batch_size)
707
+ return self.base_model(
708
+ input_ids=input_ids, decoder_input_ids=decoder_input_ids, past_key_values=past_key_values, **kwargs
709
+ )
710
+ else:
711
+ if inputs_embeds is None:
712
+ inputs_embeds = self.word_embeddings(input_ids)
713
+ if decoder_inputs_embeds is None and decoder_input_ids is None:
714
+ decoder_input_ids = shift_tokens_right(
715
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
716
+ )
717
+ decoder_inputs_embeds = self.word_embeddings(decoder_input_ids)
718
+
719
+ if attention_mask is not None:
720
+ # concat prompt attention mask
721
+ prefix_attention_mask = torch.ones(batch_size, self.peft_config.num_virtual_tokens).to(self.device)
722
+ kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)
723
+ # concat prompt labels
724
+ if labels is not None:
725
+ if self.peft_config.num_transformer_submodules == 1:
726
+ kwargs["labels"] = labels
727
+ elif self.peft_config.num_transformer_submodules == 2:
728
+ prefix_labels = torch.full((batch_size, self.peft_config.num_virtual_tokens), -100).to(self.device)
729
+ kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1)
730
+ prompts = self.get_prompt(batch_size=batch_size)
731
+ prompts = prompts.to(inputs_embeds.dtype)
732
+ inputs_embeds = torch.cat((prompts[:, : self.peft_config.num_virtual_tokens], inputs_embeds), dim=1)
733
+ if self.peft_config.num_transformer_submodules == 1:
734
+ return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
735
+ elif self.peft_config.num_transformer_submodules == 2:
736
+ decoder_inputs_embeds = torch.cat(
737
+ (prompts[:, self.peft_config.num_virtual_tokens :], decoder_inputs_embeds), dim=1
738
+ )
739
+ return self.base_model(
740
+ inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, **kwargs
741
+ )
742
+
743
+ def generate(self, **kwargs):
744
+ self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
745
+ self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
746
+ self._prepare_encoder_decoder_kwargs_for_generation
747
+ )
748
+ try:
749
+ if not isinstance(self.peft_config, PromptLearningConfig):
750
+ outputs = self.base_model.generate(**kwargs)
751
+ else:
752
+ if "input_ids" not in kwargs:
753
+ raise ValueError("input_ids must be provided for Peft model generation")
754
+ if kwargs.get("position_ids", None) is not None:
755
+ warnings.warn(
756
+ "Position ids are not supported for parameter efficient tuning. Ignoring position ids."
757
+ )
758
+ kwargs["position_ids"] = None
759
+ if kwargs.get("token_type_ids", None) is not None:
760
+ warnings.warn(
761
+ "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
762
+ )
763
+ kwargs["token_type_ids"] = None
764
+
765
+ if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
766
+ outputs = self.base_model.generate(**kwargs)
767
+ else:
768
+ raise NotImplementedError
769
+ except:
770
+ self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
771
+ self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
772
+ self.base_model_prepare_encoder_decoder_kwargs_for_generation
773
+ )
774
+ raise
775
+ else:
776
+ self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
777
+ self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
778
+ self.base_model_prepare_encoder_decoder_kwargs_for_generation
779
+ )
780
+ return outputs
781
+
782
+ def prepare_inputs_for_generation(self, *args, **kwargs):
783
+ model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
784
+ if model_kwargs["past_key_values"] is None and self.peft_config.peft_type == PeftType.PREFIX_TUNING:
785
+ batch_size = model_kwargs["decoder_input_ids"].shape[0]
786
+ past_key_values = self.get_prompt(batch_size)
787
+ model_kwargs["past_key_values"] = past_key_values
788
+ return model_kwargs
789
+
790
+
791
+ class PeftModelForTokenClassification(PeftModel):
792
+ """
793
+
794
+ """
795
+
796
+ def __init__(self, model, peft_config: PeftConfig):
797
+ super().__init__(model, peft_config)
798
+ self.modules_to_save = ["classifier", "score"]
799
+
800
+ for name, _ in self.base_model.named_children():
801
+ if any(module_name in name for module_name in self.modules_to_save):
802
+ self.cls_layer_name = name
803
+ break
804
+
805
+ # to make sure classifier layer is trainable
806
+ _set_trainable(self)
807
+
808
+ def forward(
809
+ self,
810
+ input_ids=None,
811
+ attention_mask=None,
812
+ inputs_embeds=None,
813
+ labels=None,
814
+ output_attentions=None,
815
+ output_hidden_states=None,
816
+ return_dict=None,
817
+ **kwargs,
818
+ ):
819
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
820
+
821
+ if not isinstance(self.peft_config, PromptLearningConfig):
822
+ return self.base_model(
823
+ input_ids=input_ids,
824
+ attention_mask=attention_mask,
825
+ inputs_embeds=inputs_embeds,
826
+ labels=labels,
827
+ output_attentions=output_attentions,
828
+ output_hidden_states=output_hidden_states,
829
+ return_dict=return_dict,
830
+ **kwargs,
831
+ )
832
+
833
+ batch_size = input_ids.shape[0]
834
+ if attention_mask is not None:
835
+ # concat prompt attention mask
836
+ prefix_attention_mask = torch.ones(batch_size, self.peft_config.num_virtual_tokens).to(self.device)
837
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
838
+ if kwargs.get("position_ids", None) is not None:
839
+ warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
840
+ kwargs["position_ids"] = None
841
+ kwargs.update(
842
+ {
843
+ "attention_mask": attention_mask,
844
+ "labels": labels,
845
+ "output_attentions": output_attentions,
846
+ "output_hidden_states": output_hidden_states,
847
+ "return_dict": return_dict,
848
+ }
849
+ )
850
+
851
+ if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
852
+ return self._prefix_tuning_forward(input_ids=input_ids, **kwargs)
853
+ else:
854
+ if kwargs.get("token_type_ids", None) is not None:
855
+ kwargs["token_type_ids"] = torch.cat(
856
+ (
857
+ torch.zeros(batch_size, self.peft_config.num_virtual_tokens).to(self.device),
858
+ kwargs["token_type_ids"],
859
+ ),
860
+ dim=1,
861
+ ).long()
862
+ if inputs_embeds is None:
863
+ inputs_embeds = self.word_embeddings(input_ids)
864
+ prompts = self.get_prompt(batch_size=batch_size)
865
+ prompts = prompts.to(inputs_embeds.dtype)
866
+ inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
867
+ return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
868
+
869
+ def _prefix_tuning_forward(
870
+ self,
871
+ input_ids=None,
872
+ attention_mask=None,
873
+ inputs_embeds=None,
874
+ labels=None,
875
+ output_attentions=None,
876
+ output_hidden_states=None,
877
+ return_dict=None,
878
+ **kwargs,
879
+ ):
880
+ batch_size = input_ids.shape[0]
881
+ past_key_values = self.get_prompt(batch_size)
882
+ fwd_params = list(inspect.signature(self.base_model.forward).parameters.keys())
883
+ kwargs.update(
884
+ {
885
+ "input_ids": input_ids,
886
+ "attention_mask": attention_mask,
887
+ "inputs_embeds": inputs_embeds,
888
+ "output_attentions": output_attentions,
889
+ "output_hidden_states": output_hidden_states,
890
+ "return_dict": return_dict,
891
+ "past_key_values": past_key_values,
892
+ }
893
+ )
894
+ if "past_key_values" in fwd_params:
895
+ return self.base_model(labels=labels, **kwargs)
896
+ else:
897
+ transformer_backbone_name = self.base_model.get_submodule(self.transformer_backbone_name)
898
+ fwd_params = list(inspect.signature(transformer_backbone_name.forward).parameters.keys())
899
+ if "past_key_values" not in fwd_params:
900
+ raise ValueError("Model does not support past key values which are required for prefix tuning.")
901
+ outputs = transformer_backbone_name(**kwargs)
902
+ sequence_output = outputs[0]
903
+ if "dropout" in [name for name, _ in list(self.base_model.named_children())]:
904
+ sequence_output = self.base_model.dropout(sequence_output)
905
+ logits = self.base_model.get_submodule(self.cls_layer_name)(sequence_output)
906
+
907
+ loss = None
909
+ if labels is not None:
910
+ loss_fct = CrossEntropyLoss()
911
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
912
+
913
+ if not return_dict:
914
+ output = (logits,) + outputs[2:]
915
+ return ((loss,) + output) if loss is not None else output
916
+
917
+ return TokenClassifierOutput(
918
+ loss=loss,
919
+ logits=logits,
920
+ hidden_states=outputs.hidden_states,
921
+ attentions=outputs.attentions,
922
+ )
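
A hedged sketch of the save/load round trip implemented by `save_pretrained` / `PeftModel.from_pretrained` above; the paths and the base model name are placeholders, and the adapter file name is whatever `WEIGHTS_NAME` resolves to in `rpeft.utils`:

from transformers import AutoModelForCausalLM
from rpeft import PeftModel, RotationConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2")
model = get_peft_model(base, RotationConfig(r=16, num_rotations=1, task_type="CAUSAL_LM"))

# ... fine-tune ...
model.save_pretrained("./rotation_adapter")   # writes only the adapter weights plus the config

# Later: rebuild the PEFT model from a fresh base model and the saved adapter directory.
base = AutoModelForCausalLM.from_pretrained("gpt2")
model = PeftModel.from_pretrained(base, "./rotation_adapter", is_trainable=False)
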
nl_tasks/rpeft/rotation/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .rotation_config import RotationConfig
2
+ from .layer import RotationLayer
3
+ from .model import RotationTuner
nl_tasks/rpeft/rotation/layer.py ADDED
@@ -0,0 +1,412 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Optional, Set
4
+
5
+ from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
6
+ import torch.nn.functional as F
7
+
8
+ def inverse_2x2(matrices):
9
+
10
+ # Extract matrix elements
11
+ # matrices[..., 0, 0] corresponds to 'a' in [[a, b], [c, d]]
12
+ a = matrices[..., 0, 0]
13
+ b = matrices[..., 0, 1]
14
+ c = matrices[..., 1, 0]
15
+ d = matrices[..., 1, 1]
16
+
17
+ # Compute determinant
18
+ det = a * d - b * c
19
+
20
+ # Compute inverse using the formula:
21
+ # inv = (1/det) * [[d, -b], [-c, a]]
22
+ inv_det = 1.0 / det
23
+
24
+ # Create output tensor
25
+ inv_matrices = torch.empty_like(matrices)
26
+ inv_matrices[..., 0, 0] = d * inv_det
27
+ inv_matrices[..., 0, 1] = -b * inv_det
28
+ inv_matrices[..., 1, 0] = -c * inv_det
29
+ inv_matrices[..., 1, 1] = a * inv_det
30
+
31
+ return inv_matrices
32
+
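
# --- Hedged, standalone sanity check (editorial sketch, not part of the original file) ---
# It re-derives the adjugate-based batched 2x2 inverse that inverse_2x2 above implements and
# compares it against torch.linalg.inv; the matrices are kept well-conditioned on purpose.
import torch

m = torch.eye(2) + 0.1 * torch.randn(5, 3, 2, 2)       # batch of well-conditioned 2x2 matrices
a, b, c, d = m[..., 0, 0], m[..., 0, 1], m[..., 1, 0], m[..., 1, 1]
inv_det = 1.0 / (a * d - b * c)
inv = torch.empty_like(m)                                # adjugate formula, as in inverse_2x2
inv[..., 0, 0], inv[..., 0, 1] = d * inv_det, -b * inv_det
inv[..., 1, 0], inv[..., 1, 1] = -c * inv_det, a * inv_det
print(torch.allclose(inv, torch.linalg.inv(m), atol=1e-5))   # expected: True
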
33
+ class Rotation(nn.Module):
34
+ """
35
+ Rotation layer based on Cayley transformation for parameter-efficient fine-tuning.
36
+
37
+ This layer implements orthogonal fine-tuning through Cayley transformation:
38
+ h(x) = (I - A)^{-1} (I + A) x
39
+
40
+ where A = XY^T with X = [U; -V] and Y = [V; U]
41
+ """
42
+
43
+ def __init__(self, r, dim, T=1.0, num_rotations=4, drop_out=0.1):
44
+ super().__init__()
45
+ self.r = r
46
+ self.T = T
47
+ self.U = nn.Parameter(torch.randn(num_rotations, r, dim) * 0.002, requires_grad=True)
48
+ self.V = nn.Parameter(torch.randn(num_rotations, r, dim) * 0.0, requires_grad=True)
49
+ # self.U = nn.Parameter(torch.empty(num_rotations, r, dim), requires_grad=True)
50
+ # self.V = nn.Parameter(torch.empty(num_rotations, r, dim), requires_grad=True)
51
+ self.num_rotations = num_rotations
52
+ self.dropout = nn.Dropout(drop_out) if drop_out > 0 else nn.Identity()
53
+ self.dim = dim
54
+
55
+ # self._post_init()
56
+ # @property
57
+ # def U(self):
58
+ # # Calculate U = [a, b] whenever self.U is accessed
59
+ # # This function acts as the 'getter' for self.U
60
+ # return torch.cat([self.a, self.b], dim=-1)
61
+
62
+ # @property
63
+ # def V(self):
64
+ # # Calculate V = [b, a] whenever self.V is accessed
65
+ # # This function acts as the 'getter' for self.V
66
+ # return torch.cat([self.b, self.a], dim=-1)
67
+
68
+ # def _post_init(self):
69
+ # import torch.nn.init as init
70
+ # import math
71
+ # # init.kaiming_uniform_(self.U, a=math.sqrt(1), mode='fan_out')
72
+ # init.normal_(self.U, 0, 1e-2)
73
+ # with torch.no_grad():
74
+ # self.V.data.copy_(self.U.data)
75
+
76
+ def forward(self, x):
77
+ """
78
+ Apply Cayley transformation to input x.
79
+
80
+ A = XY^T where X = [U; -V], Y = [V; U]
81
+ Cayley transformation: h(x) = (I - A)^{-1} (I + A) x
82
+
83
+ Uses Woodbury identity for efficient computation:
84
+ (I - XY^T)^{-1} = I + X (I - Y^T X)^{-1} Y^T
85
+
86
+ Args:
87
+ x: Input tensor of shape (..., dim)
88
+
89
+ Returns:
90
+ Transformed tensor of shape (..., dim)
91
+ """
92
+ x_dtype = x.dtype
93
+
94
+ x = self.dropout(x) # NLU tasks do not use dropout
95
+ X = torch.cat([self.U, -self.V], dim=1) # Shape: (num_rotations, 2r, dim)
96
+ Y = torch.cat([self.V, self.U], dim=1) * self.T # Shape: (num_rotations, 2r, dim)
97
+
98
+ Y_T_X = torch.matmul(Y, X.transpose(1, 2)) # Shape: (num_rotations, 2r, 2r)
99
+ # I_2r = torch.eye(2 * self.r, device=x.device, dtype=x.dtype).repeat(self.num_rotations, 1, 1)
100
+ I_2r = torch.eye(2 * self.r, device=x.device, dtype=x.dtype).unsqueeze(0)
101
+ I_minus_YX = I_2r - Y_T_X
102
+
103
+ if self.r == 1:
104
+ I_minus_YX_inv = inverse_2x2(I_minus_YX)
105
+ else:
106
+ # make it float32
107
+ I_minus_YX = I_minus_YX.to(torch.float32)
108
+ I_minus_YX_inv = torch.linalg.inv(I_minus_YX) # Shape: (num_rotations, 2r, 2r)
109
+ I_minus_YX_inv = I_minus_YX_inv.to(x_dtype)
110
+
111
+ # Yx = torch.einsum("...d,nrd->...nr", x, Y) # Shape: (batch*seq_len, num_rotations, 2r)
112
+
113
+ input_shape = x.shape
114
+ x_flat = x.reshape(-1, self.dim) # Shape: (B, dim)
115
+ Y_reshape = Y.reshape(-1, self.dim) # Shape: (nr, d)
116
+ Yx2_flat = F.linear(x_flat, Y_reshape)
117
+ Yx2 = Yx2_flat.view(-1, self.num_rotations, 2*self.r)
118
+ # is_close = torch.allclose(Yx.view(-1, self.num_rotations, 2*self.r), Yx2, atol=1e-5, rtol=1e-4)
119
+ # if is_close:
120
+ # pass
121
+ # # print("✅ SUCCESS 11: The optimized code produces identical results!")
122
+ # else:
123
+ # print("❌ FAILURE: The results diverge.")
124
+
125
+ ###
126
+ # n of (r,r) @ n of (r,1) -> n of r
127
+ # I_minus_YX_inv_Yx = torch.einsum("nrr,...nr->...nr", I_minus_YX_inv, Yx)
128
+ # I_minus_YX_inv_Yx = torch.einsum("...qr,...r->...q", I_minus_YX_inv, Yx)
129
+
130
+ Yx2_expanded = Yx2.unsqueeze(-1)
131
+ I_minus_YX_inv_ex = I_minus_YX_inv.unsqueeze(0)
132
+ I_minus_YX_inv_Yx2 = I_minus_YX_inv_ex @ Yx2_expanded
133
+ I_minus_YX_inv_Yx2 = I_minus_YX_inv_Yx2.squeeze(-1)
134
+ # is_close = torch.allclose(I_minus_YX_inv_Yx.view(-1, self.num_rotations, 2*self.r), I_minus_YX_inv_Yx2, atol=1e-4, rtol=1e-3)
135
+ # if is_close:
136
+ # pass
137
+ # #print("✅ SUCCESS 22: The optimized code produces identical results!")
138
+ # else:
139
+ # print("❌ FAILURE: The results diverge.")
140
+ # exit()
141
+
142
+ # n of (r,) @ n of (r,d)
143
+ # second_term = torch.einsum("...nr,nrd->...nd", I_minus_YX_inv_Yx, X) # Shape: (batch*seq_len, num_rotations, dim)
144
+ # second_term = torch.einsum("...r, ...rd->...d", I_minus_YX_inv_Yx, X)
145
+
146
+
147
+ # I_minus_YX_inv_Yx_ex = I_minus_YX_inv_Yx2.unsqueeze(-2)
148
+ # X_ex = X.unsqueeze(0)
149
+ # second_term2 = I_minus_YX_inv_Yx_ex @ X_ex
150
+ # second_term2 = second_term2.squeeze(-2)
151
+ # is_close = torch.allclose(second_term, second_term2, atol=1e-5, rtol=1e-4)
152
+ # if is_close:
153
+ # pass
154
+ # # print("✅ SUCCESS 33: The optimized code produces identical results!")
155
+ # else:
156
+ # print("❌ FAILURE: The results diverge.")
157
+
158
+ # second_term = second_term.sum(dim=-2) # Sum over rotations
159
+
160
+ coeffs_flat = I_minus_YX_inv_Yx2.reshape(-1, self.num_rotations * 2 * self.r) # (batch*len, 2n r)
161
+ X_flat = X.reshape(-1, self.dim) #(N*2r, dim)
162
+ second_term3 = torch.matmul(coeffs_flat, X_flat)
163
+ # is_close = torch.allclose(second_term.view(-1, self.dim), second_term3, atol=1e-5, rtol=1e-4)
164
+ # if is_close:
165
+ # pass
166
+ # # print("✅ SUCCESS 44: The optimized code produces identical results!")
167
+ # else:
168
+ # print("❌ FAILURE: The results diverge.")
169
+
170
+ # output = x + 2 * second_term # Shape: (batch*seq_len, dim)
171
+ output = x_flat + 2 * second_term3 # (batch*seq_len, dim)
172
+
173
+ # return output.to(x_dtype)
174
+ return output.view(*input_shape)
175
+
176
+ def get_delta_weight(self):
177
+ """
178
+ Compute the delta weight matrix induced by the rotation layer.
179
+
180
+ Returns:
181
+ Delta weight matrix of shape (dim, dim)
182
+ """
183
+
184
+ X = torch.cat([self.U, -self.V], dim=1) # Shape: (num_rotations, 2r, dim)
185
+ Y = torch.cat([self.V, self.U], dim=1) * self.T # Shape: (num_rotations, 2r, dim)
186
+
187
+ Y_T_X = torch.matmul(Y, X.transpose(1, 2)) # Shape: (num_rotations, 2r, 2r)
188
+ I_2r = torch.eye(2 * self.r, device=X.device, dtype=X.dtype).repeat(self.num_rotations, 1, 1)
189
+ I_minus_YX = I_2r - Y_T_X
190
+
191
+ if self.r == 1:
192
+ I_minus_YX_inv = inverse_2x2(I_minus_YX)
193
+ I_minus_YX_inv_Y = torch.einsum("nRr,nrd->nRd", I_minus_YX_inv, Y) # Shape: (num_rotations, 2r, dim)
194
+ # I_minus_YX_inv_Y = torch.einsum("...rr,...rd->...rd", I_minus_YX_inv, Y) ## reproduce
195
+ else:
196
+ I_minus_YX_inv_Y = torch.linalg.solve(I_minus_YX.to(torch.float32), Y.to(torch.float32)) # Shape: (num_rotations, 2r, dim)
197
+
198
+ # I_minus_YX_inv = torch.linalg.inv(I_minus_YX)
199
+ # I_minus_YX_inv_Y = torch.einsum("...rr,...rd->...rd", I_minus_YX_inv, Y) ## reproduce
200
+
201
+ I_minus_YX_inv_Y = I_minus_YX_inv_Y.to(X.dtype)
202
+
203
+ # I_minus_YX_float = I_minus_YX.float()
204
+ # I_minus_YX_inv = torch.linalg.inv(I_minus_YX_float) # Shape: (num_rotations, 2r, 2r)
205
+ # I_minus_YX_inv = I_minus_YX_inv.to(X.dtype)
206
+
207
+
208
+ # I_minus_YX_inv_Y = torch.einsum("nRr,nrd->nRd", I_minus_YX_inv, Y) # Shape: (num_rotations, 2r, dim)
209
+
210
+
211
+ second_term = torch.einsum("nrd,nrD->ndD", X, I_minus_YX_inv_Y) # Shape: (num_rotations, dim, dim)
212
+ second_term = second_term.sum(dim=0)
213
+ total_delta_weight = 2 * second_term
214
+
215
+ return total_delta_weight
216
+
217
+
218
+ class RotationLayer(BaseTunerLayer):
219
+ """
220
+ Adapter-like wrapper that attaches Rotation modules to a base linear layer.
221
+ """
222
+
223
+ adapter_layer_names: tuple[str, ...] = ("rotation",)
224
+ other_param_names: tuple[str, ...] = ("r", "T", "num_rotations", "scaling")
225
+
226
+ def __init__(self, base_layer: nn.Module, **kwargs):
227
+ # Let BaseTunerLayer do its init (it usually subclasses nn.Module)
228
+ super().__init__()
229
+ # store base layer and adapter containers
230
+ self.base_layer = base_layer
231
+ self.rotation = nn.ModuleDict() # mapping adapter_name -> Rotation module
232
+ self.scaling={} # default scaling per adapter
233
+ self._adapter_config = {} # store r, T, num_rotations per adapter
234
+
235
+ # flags (exposed in a simple way)
236
+ self._disable_adapters = False
237
+ self.merged_adapters: list[str] = []
238
+ self._cast_input_dtype_enabled = True
239
+ self.kwargs = kwargs
240
+
241
+ if isinstance(base_layer, nn.Linear):
242
+ self.in_features = base_layer.in_features
243
+ self.out_features = base_layer.out_features
244
+ else:
245
+ raise NotImplementedError("RotationLayer only supports nn.Linear base layers for now.")
246
+
247
+ @property
248
+ def _available_adapters(self) -> set[str]:
249
+ return set(self.rotation.keys())
250
+
251
+ @property
252
+ def disable_adapters(self) -> bool:
253
+ return self._disable_adapters
254
+
255
+ @property
256
+ def merged(self) -> bool:
257
+ return bool(self.merged_adapters)
258
+
259
+ @property
260
+ def active_adapters(self) -> list[str]:
261
+ # If some external mechanism sets active adapters, prefer it; else use all added adapters.
262
+ return getattr(self, "_active_adapters", list(self.rotation.keys()))
263
+
264
+ def get_base_layer(self) -> nn.Module:
265
+ return self.base_layer
266
+
267
+ def _cast_input_dtype(self, x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
268
+ if not self._cast_input_dtype_enabled:
269
+ return x
270
+ return x.to(dtype)
271
+
272
+ def update_layer(
273
+ self,
274
+ adapter_name: str,
275
+ r: int,
276
+ T: float,
277
+ num_rotations: int,
278
+ drop_out: float = 0.0,
279
+ **kwargs,
280
+ ):
281
+ """
282
+ Add / update a rotation adapter for this layer.
283
+ """
284
+
285
+ if r <= 0:
286
+ raise ValueError(f"r must be positive, got {r}")
287
+ if num_rotations <= 0:
288
+ raise ValueError(f"num_rotations must be positive, got {num_rotations}")
289
+
290
+ rot = Rotation(r=r, dim=self.in_features, T=T, num_rotations=num_rotations, drop_out=drop_out)
291
+ self.rotation[adapter_name] = rot
292
+ self.scaling[adapter_name] = 1.0
293
+ self._adapter_config[adapter_name] = {"r": r, "T": T, "num_rotations": num_rotations}
294
+
295
+ # (optional) helper to set currently active adapters externally
296
+ def set_active_adapters(self, adapters: Optional[list[str]]):
297
+ if adapters is None:
298
+ if hasattr(self, "_active_adapters"):
299
+ delattr(self, "_active_adapters")
300
+ else:
301
+ self._active_adapters = adapters
302
+
303
+
304
+ class Linear(nn.Module, RotationLayer):
305
+ """
306
+ A linear layer with an integrated rotation layer for parameter-efficient fine-tuning.
307
+ """
308
+
309
+ def __init__(self,
310
+ base_layer: nn.Linear,
311
+ adapter_name: str,
312
+ r: int,
313
+ T: float,
314
+ num_rotations: int,
315
+ drop_out: float = 0.0,
316
+ **kwargs):
317
+
318
+ super().__init__()
319
+ RotationLayer.__init__(self, base_layer=base_layer, **kwargs)
320
+
321
+ self._active_adapter = adapter_name
322
+
323
+ self.update_layer(
324
+ adapter_name=adapter_name,
325
+ r=r,
326
+ T=T,
327
+ num_rotations=num_rotations,
328
+ drop_out=drop_out,
329
+ **kwargs,
330
+ )
331
+
332
+ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None):
333
+ """
334
+ Merge the adapter effect into the base layer weights:
335
+ W_merged = W @ R
336
+ where R = I + delta (delta returned by get_delta_weight()).
337
+ """
338
+ adapter_names = check_adapters_to_merge(self, adapter_names)
339
+
340
+ if not adapter_names:
341
+ return
342
+
343
+ base_layer = self.get_base_layer()
344
+ orig_dtype = base_layer.weight.dtype
345
+ # base_layer.weight shape: (out_features, in_features)
346
+ W = base_layer.weight.data # (out, in)
347
+
348
+ for active_adapter in adapter_names:
349
+
350
+ if active_adapter not in self._available_adapters:
351
+ continue
352
+
353
+ delta_R = self.rotation[active_adapter].get_delta_weight() # (in, in)
354
+ R = torch.eye(delta_R.size(0), device=delta_R.device, dtype=delta_R.dtype) + delta_R # (in, in)
355
+ # merged W = W @ R
356
+ merged_W = W.to(R.dtype) @ R
357
+
358
+ if safe_merge and not torch.isfinite(merged_W).all():
359
+ raise ValueError("Merging resulted in non-finite weights. Aborting merge.")
360
+
361
+ base_layer.weight.data = merged_W.contiguous().to(orig_dtype)
362
+
363
+ # mark merged (so unmerge can restore by inverse)
364
+ self.merged_adapters.append(active_adapter)
365
+
366
+
367
+ def unmerge(self):
368
+ """
369
+ Reverse merges in LIFO order (pop merged adapters and invert R).
370
+ """
371
+ base_layer = self.get_base_layer()
372
+ orig_dtype = base_layer.weight.dtype
373
+
374
+ while self.merged_adapters:
375
+ active_adapter = self.merged_adapters.pop()
376
+ if active_adapter not in self._available_adapters:
377
+ continue
378
+ delta_R = self.rotation[active_adapter].get_delta_weight() # (in, in)
379
+ R = torch.eye(delta_R.size(0), device=delta_R.device, dtype=delta_R.dtype) + delta_R
380
+ R_inv = torch.linalg.inv(R)
381
+ merged_W = base_layer.weight.data.to(R.dtype)
382
+ unmerged_W = merged_W @ R_inv
383
+ base_layer.weight.data = unmerged_W.contiguous().to(orig_dtype)
384
+
385
+
386
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
387
+ x_dtype = x.dtype
388
+ base_layer = self.get_base_layer()
389
+
390
+ if self.disable_adapters:
391
+ # if merged, unmerge to ensure base_layer produces original behavior
392
+ if self.merged:
393
+ self.unmerge()
394
+ return base_layer(x, *args, **kwargs).to(x_dtype)
395
+
396
+ if self.merged:
397
+ # if merged into base layer, just forward
398
+ return base_layer(x, *args, **kwargs).to(x_dtype)
399
+
400
+ # otherwise apply active adapters (transform inputs) then call base layer
401
+ for active_adapter in self.active_adapters:
402
+ if active_adapter not in self.rotation:
403
+ continue
404
+ rotation = self.rotation[active_adapter]
405
+ x = self._cast_input_dtype(x, rotation.U.dtype)
406
+ x = rotation(x)
407
+
408
+ return base_layer(x, *args, **kwargs).to(x_dtype)
409
+
410
+ def __repr__(self):
411
+ return f"rotation.{super().__repr__()}"
412
+
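A quick standalone check of the algebra used in `Rotation.forward` and `get_delta_weight` (this sketch uses only PyTorch and the row-major shapes from the code, where A = X^T Y; it is not part of the repository): for skew-symmetric A, the dense Cayley transform (I - A)^{-1}(I + A) agrees with the low-rank Woodbury form I + 2 X^T (I_2r - Y X^T)^{-1} Y, and the result is orthogonal.

```python
# Standalone verification sketch (assumes only PyTorch; dimensions chosen arbitrarily).
import torch

torch.manual_seed(0)
dim, r, T = 64, 2, 1.0
U = torch.randn(r, dim, dtype=torch.float64) * 0.02
V = torch.randn(r, dim, dtype=torch.float64) * 0.02

X = torch.cat([U, -V], dim=0)        # (2r, dim), as in Rotation.forward
Y = torch.cat([V, U], dim=0) * T     # (2r, dim)

A = X.T @ Y                          # (dim, dim), skew-symmetric by construction
I_d = torch.eye(dim, dtype=torch.float64)
I_2r = torch.eye(2 * r, dtype=torch.float64)

R_dense = torch.linalg.solve(I_d - A, I_d + A)                     # (I - A)^{-1}(I + A)
R_lowrank = I_d + 2 * X.T @ torch.linalg.solve(I_2r - Y @ X.T, Y)  # Woodbury form used in forward

print(torch.allclose(R_dense, R_lowrank, atol=1e-9))        # True: identity holds
print(torch.allclose(R_dense.T @ R_dense, I_d, atol=1e-9))  # True: R is orthogonal
```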
nl_tasks/rpeft/rotation/layer_test.py ADDED
@@ -0,0 +1,296 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from omini.rotation.layer import Linear, Rotation
4
+
5
+ def test_rotation_merge():
6
+ """
7
+ Test that merging rotation adapter produces the same output as the unmerged version.
8
+ """
9
+ print("="*60)
10
+ print("Testing Rotation Layer Merge")
11
+ print("="*60)
12
+
13
+ # Set random seed for reproducibility
14
+ torch.manual_seed(42)
15
+
16
+ # Configuration
17
+ in_features = 512
18
+ out_features = 1024
19
+ r = 4
20
+ num_rotations = 4
21
+ T = 1.0
22
+ batch_size = 8
23
+ seq_len = 16
24
+
25
+ # Create base linear layer
26
+ base_layer = nn.Linear(in_features, out_features, bias=True)
27
+
28
+ # Create rotation layer
29
+ rotation_layer = Linear(
30
+ base_layer=base_layer,
31
+ adapter_name="default",
32
+ r=r,
33
+ T=T,
34
+ num_rotations=num_rotations
35
+ )
36
+
37
+ # Create random input
38
+ x = torch.randn(batch_size, seq_len, in_features)
39
+
40
+ # Test 1: Forward pass before merge
41
+ print("\n" + "-"*60)
42
+ print("Test 1: Computing output BEFORE merge")
43
+ print("-"*60)
44
+ rotation_layer.eval()
45
+ with torch.no_grad():
46
+ output_before = rotation_layer(x)
47
+
48
+ print(f"Output shape: {output_before.shape}")
49
+ print(f"Output mean: {output_before.mean().item():.6f}")
50
+ print(f"Output std: {output_before.std().item():.6f}")
51
+ print(f"Output min: {output_before.min().item():.6f}")
52
+ print(f"Output max: {output_before.max().item():.6f}")
53
+
54
+ # Save original weight for verification
55
+ original_weight = base_layer.weight.data.clone()
56
+
57
+ # Test 2: Merge adapter
58
+ print("\n" + "-"*60)
59
+ print("Test 2: Merging adapter")
60
+ print("-"*60)
61
+ rotation_layer.merge(safe_merge=True, adapter_names=["default"])
62
+ print(f"✓ Adapter merged successfully")
63
+ print(f"✓ Merged adapters: {rotation_layer.merged_adapters}")
64
+
65
+ # Check that weights have changed
66
+ weight_diff = (base_layer.weight.data - original_weight).abs().max().item()
67
+ print(f"Max weight change: {weight_diff:.6e}")
68
+
69
+ # Test 3: Forward pass after merge
70
+ print("\n" + "-"*60)
71
+ print("Test 3: Computing output AFTER merge")
72
+ print("-"*60)
73
+ with torch.no_grad():
74
+ output_after = rotation_layer(x)
75
+
76
+ print(f"Output shape: {output_after.shape}")
77
+ print(f"Output mean: {output_after.mean().item():.6f}")
78
+ print(f"Output std: {output_after.std().item():.6f}")
79
+ print(f"Output min: {output_after.min().item():.6f}")
80
+ print(f"Output max: {output_after.max().item():.6f}")
81
+
82
+ # Test 4: Compare outputs
83
+ print("\n" + "-"*60)
84
+ print("Test 4: Comparing outputs")
85
+ print("-"*60)
86
+
87
+ # Compute differences
88
+ abs_diff = (output_after - output_before).abs()
89
+ rel_diff = abs_diff / (output_before.abs() + 1e-8)
90
+
91
+ max_abs_diff = abs_diff.max().item()
92
+ mean_abs_diff = abs_diff.mean().item()
93
+ max_rel_diff = rel_diff.max().item()
94
+ mean_rel_diff = rel_diff.mean().item()
95
+
96
+ print(f"Max absolute difference: {max_abs_diff:.6e}")
97
+ print(f"Mean absolute difference: {mean_abs_diff:.6e}")
98
+ print(f"Max relative difference: {max_rel_diff:.6e}")
99
+ print(f"Mean relative difference: {mean_rel_diff:.6e}")
100
+
101
+ # Check if outputs are close
102
+ atol = 1e-4 # Absolute tolerance
103
+ rtol = 1e-3 # Relative tolerance
104
+
105
+ are_close = torch.allclose(output_before, output_after, atol=atol, rtol=rtol)
106
+
107
+ if are_close:
108
+ print(f"\n✅ PASS: Outputs are identical (within atol={atol}, rtol={rtol})")
109
+ else:
110
+ print(f"\n❌ FAIL: Outputs differ significantly")
111
+ print(f" Expected: atol < {atol}, rtol < {rtol}")
112
+ print(f" Got: max_abs_diff = {max_abs_diff:.6e}, max_rel_diff = {max_rel_diff:.6e}")
113
+
114
+ # Test 5: Unmerge and verify
115
+ print("\n" + "-"*60)
116
+ print("Test 5: Testing unmerge")
117
+ print("-"*60)
118
+ rotation_layer.unmerge()
119
+ print(f"✓ Adapter unmerged")
120
+ print(f"✓ Merged adapters: {rotation_layer.merged_adapters}")
121
+
122
+ with torch.no_grad():
123
+ output_unmerged = rotation_layer(x)
124
+
125
+ unmerge_diff = (output_unmerged - output_before).abs().max().item()
126
+ print(f"Max difference after unmerge: {unmerge_diff:.6e}")
127
+
128
+ unmerge_close = torch.allclose(output_before, output_unmerged, atol=atol, rtol=rtol)
129
+ if unmerge_close:
130
+ print(f"✅ PASS: Unmerge restored original behavior")
131
+ else:
132
+ print(f"❌ FAIL: Unmerge did not restore original behavior")
133
+
134
+ # Test 6: Verify weight restoration
135
+ weight_restored_diff = (base_layer.weight.data - original_weight).abs().max().item()
136
+ print(f"Max weight difference after unmerge: {weight_restored_diff:.6e}")
137
+
138
+ weight_restored = torch.allclose(base_layer.weight.data, original_weight, atol=1e-5)
139
+ if weight_restored:
140
+ print(f"✅ PASS: Original weights restored")
141
+ else:
142
+ print(f"❌ FAIL: Original weights not fully restored")
143
+
144
+ print("\n" + "="*60)
145
+ print("Test Summary")
146
+ print("="*60)
147
+ return are_close and unmerge_close and weight_restored
148
+
149
+
150
+ def test_multiple_merges():
151
+ """
152
+ Test merging and unmerging multiple times.
153
+ """
154
+ print("\n" + "="*60)
155
+ print("Testing Multiple Merge/Unmerge Cycles")
156
+ print("="*60)
157
+
158
+ torch.manual_seed(42)
159
+
160
+ in_features = 256
161
+ out_features = 512
162
+ r = 4
163
+ num_rotations = 4
164
+
165
+ base_layer = nn.Linear(in_features, out_features, bias=True)
166
+ rotation_layer = Linear(
167
+ base_layer=base_layer,
168
+ adapter_name="default",
169
+ r=r,
170
+ T=1.0,
171
+ num_rotations=num_rotations
172
+ )
173
+
174
+ x = torch.randn(4, 8, in_features)
175
+ rotation_layer.eval()
176
+
177
+ # Get original output
178
+ with torch.no_grad():
179
+ original_output = rotation_layer(x)
180
+
181
+ # Test multiple cycles
182
+ all_passed = True
183
+ for cycle in range(3):
184
+ print(f"\nCycle {cycle + 1}:")
185
+
186
+ # Merge
187
+ rotation_layer.merge(safe_merge=True)
188
+ with torch.no_grad():
189
+ merged_output = rotation_layer(x)
190
+
191
+ merge_close = torch.allclose(original_output, merged_output, atol=1e-4, rtol=1e-3)
192
+ print(f" Merge: {'✅ PASS' if merge_close else '❌ FAIL'}")
193
+
194
+ # Unmerge
195
+ rotation_layer.unmerge()
196
+ with torch.no_grad():
197
+ unmerged_output = rotation_layer(x)
198
+
199
+ unmerge_close = torch.allclose(original_output, unmerged_output, atol=1e-4, rtol=1e-3)
200
+ print(f" Unmerge: {'✅ PASS' if unmerge_close else '❌ FAIL'}")
201
+
202
+ all_passed = all_passed and merge_close and unmerge_close
203
+
204
+ return all_passed
205
+
206
+
207
+ def test_with_different_dtypes():
208
+ """
209
+ Test merging with different data types.
210
+ """
211
+ print("\n" + "="*60)
212
+ print("Testing Different Data Types")
213
+ print("="*60)
214
+
215
+ torch.manual_seed(42)
216
+
217
+ dtypes = [torch.float32, torch.float16, torch.bfloat16]
218
+ all_passed = True
219
+
220
+ for dtype in dtypes:
221
+ print(f"\nTesting with dtype: {dtype}")
222
+
223
+ in_features = 256
224
+ out_features = 512
225
+ r = 4
226
+ num_rotations = 4
227
+
228
+ base_layer = nn.Linear(in_features, out_features, bias=True)
229
+ base_layer = base_layer.to(dtype)
230
+
231
+ rotation_layer = Linear(
232
+ base_layer=base_layer,
233
+ adapter_name="default",
234
+ r=r,
235
+ T=1.0,
236
+ num_rotations=num_rotations
237
+ )
238
+ rotation_layer = rotation_layer.to(dtype)
239
+
240
+ x = torch.randn(4, 8, in_features, dtype=dtype)
241
+ rotation_layer.eval()
242
+
243
+ with torch.no_grad():
244
+ output_before = rotation_layer(x)
245
+ rotation_layer.merge(safe_merge=True)
246
+ output_after = rotation_layer(x)
247
+
248
+ # Adjust tolerances based on dtype
249
+ if dtype == torch.float32:
250
+ atol, rtol = 1e-5, 1e-4
251
+ elif dtype == torch.float16:
252
+ atol, rtol = 1e-2, 1e-2
253
+ else: # bfloat16
254
+ atol, rtol = 1e-2, 1e-2
255
+
256
+ are_close = torch.allclose(output_before, output_after, atol=atol, rtol=rtol)
257
+
258
+ if are_close:
259
+ print(f" ✅ PASS")
260
+ else:
261
+ max_diff = (output_after - output_before).abs().max().item()
262
+ print(f" ❌ FAIL (max diff: {max_diff:.6e})")
263
+
264
+ all_passed = all_passed and are_close
265
+
266
+ return all_passed
267
+
268
+
269
+ if __name__ == "__main__":
270
+ print("\n" + "="*60)
271
+ print("ROTATION LAYER MERGE TEST SUITE")
272
+ print("="*60)
273
+
274
+ results = {}
275
+
276
+ # Run all tests
277
+ results["basic_merge"] = test_rotation_merge()
278
+ results["multiple_cycles"] = test_multiple_merges()
279
+ results["different_dtypes"] = test_with_different_dtypes()
280
+
281
+ # Print summary
282
+ print("\n" + "="*60)
283
+ print("FINAL SUMMARY")
284
+ print("="*60)
285
+
286
+ for test_name, passed in results.items():
287
+ status = "✅ PASS" if passed else "❌ FAIL"
288
+ print(f"{test_name}: {status}")
289
+
290
+ all_passed = all(results.values())
291
+ print("\n" + "="*60)
292
+ if all_passed:
293
+ print("🎉 ALL TESTS PASSED!")
294
+ else:
295
+ print("⚠️ SOME TESTS FAILED")
296
+ print("="*60)
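One property the suite above does not test explicitly: because `V` is initialized to zeros, the adapter starts as an exact identity, so a freshly wrapped layer reproduces its frozen base layer. A minimal sketch (assuming the package is importable as `rpeft`, e.g. after installing `nl_tasks` via its `setup.py`; not part of the test file):

```python
import torch
import torch.nn as nn
from rpeft.rotation.layer import Linear  # assumed import path

torch.manual_seed(0)
base = nn.Linear(128, 256)
wrapped = Linear(base_layer=base, adapter_name="default",
                 r=4, T=1.0, num_rotations=4, drop_out=0.0)
wrapped.eval()

x = torch.randn(2, 10, 128)
with torch.no_grad():
    # V starts at zero, so the Cayley rotation is the identity and the wrapped
    # layer matches the frozen base layer before any training.
    assert torch.allclose(wrapped(x), base(x), atol=1e-5)
```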
nl_tasks/rpeft/rotation/model.py ADDED
@@ -0,0 +1,392 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from enum import Enum
6
+ from dataclasses import asdict
7
+ from tqdm import tqdm
8
+
9
+
10
+ from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists, onload_layer
11
+
12
+ from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, ModulesToSaveWrapper, _get_submodules
13
+
14
+ from .layer import RotationLayer, Linear
15
+
16
+ TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
17
+
18
+ class RotationTuner(BaseTuner):
19
+
20
+ prefix: str = "rotation"
21
+ tuner_layer_class = RotationLayer
22
+ target_module_mapping = TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING
23
+
24
+
25
+ @staticmethod
26
+ def _check_target_module_exists(rotation_config, key: str) -> bool:
27
+ return check_target_module_exists(rotation_config, key)
28
+
29
+ def _create_and_replace(
30
+ self,
31
+ rotation_config,
32
+ adapter_name: str,
33
+ target: nn.Module,
34
+ target_name: str,
35
+ parent: nn.Module,
36
+ current_key: str,
37
+ **optional_kwargs,
38
+ ) -> None:
39
+ """
40
+ Create and replace a target module with a rotation-augmented version.
41
+
42
+ This method is called when an existing module is already a RotationLayer
43
+ and needs to have a new adapter added to it.
44
+
45
+ Args:
46
+ rotation_config: Configuration for the rotation adapter
47
+ adapter_name: Name of the adapter to add
48
+ target: The target module to augment
49
+ target_name: Name of the target module
50
+ parent: Parent module containing the target
51
+ current_key: Full key path to the current module
52
+ **optional_kwargs: Additional optional arguments
53
+
54
+ Raises:
55
+ ValueError: If current_key is not provided
56
+ """
57
+
58
+ if current_key is None:
59
+ raise ValueError("current_key must be provided to create Rotation layer")
60
+
61
+ # Check if target is already a RotationLayer
62
+ if isinstance(target, RotationLayer):
63
+ target.update_layer(
64
+ adapter_name=adapter_name,
65
+ r=rotation_config.r,
66
+ T=rotation_config.T,
67
+ num_rotations=rotation_config.num_rotations,
+ drop_out=rotation_config.drop_out,
68
+ )
69
+ else:
70
+ # Create new rotation layer
71
+ new_module = self._create_new_module(
72
+ rotation_config=rotation_config,
73
+ adapter_name=adapter_name,
74
+ target=target,
75
+ **optional_kwargs,
76
+ )
77
+ if new_module is not None:
78
+ self._replace_module(parent, target_name, new_module, target)
79
+
80
+ def _replace_module(self, parent, child_name, new_module, child):
81
+
82
+ setattr(parent, child_name, new_module)
83
+
84
+ # child layer wraps the original module, unpack it
85
+ if hasattr(child, "base_layer"):
86
+ child = child.base_layer
87
+
88
+ meta = torch.device("meta")
89
+ # dispatch to correct device
90
+ for name, module in new_module.named_modules():
91
+ if (self.prefix in name) or ("ranknum" in name):
92
+ if hasattr(child, "qweight"):
93
+ weight = child.qweight
94
+ elif hasattr(child, "W_q"):
95
+ weight = child.W_q
96
+ elif hasattr(child, "weight"):
97
+ weight = child.weight
98
+ elif getattr(child, "in_proj_weight", None) is not None: # MHA
99
+ weight = child.in_proj_weight
100
+ else:
101
+ weight = next(child.parameters())
102
+ if not any(p.device == meta for p in module.parameters()):
103
+ module.to(weight.device)
104
+
105
+ def _mark_only_adapters_as_trainable(self, model):
106
+
107
+ # First, freeze all parameters
108
+ for n, p in model.named_parameters():
109
+ # print(f'{n}, np {p.requires_grad}')
110
+ if self.prefix not in n:
111
+ p.requires_grad = False
112
+ else:
113
+ p.requires_grad = True
114
+
115
+ # Handle bias parameters based on config
116
+ for active_adapter in self.active_adapters:
117
+ bias_config = self.peft_config[active_adapter].bias
118
+
119
+ if bias_config == "none":
120
+ continue
121
+ elif bias_config == "all":
122
+ # Enable all bias parameters
123
+ for n, p in model.named_parameters():
124
+ if "bias" in n:
125
+ p.requires_grad = True
126
+ elif bias_config == "rotation_only":
127
+ # Enable only bias in rotation layers
128
+ for name, m in model.named_modules():
129
+ if isinstance(m, RotationLayer):
130
+ if hasattr(m, "bias") and m.bias is not None:
131
+ m.bias.requires_grad = True
132
+ else:
133
+ raise NotImplementedError(
134
+ f"Requested bias configuration '{bias_config}' is not implemented. "
135
+ f"Supported values: 'none', 'all', 'rotation_only'"
136
+ )
137
+
138
+ @staticmethod
139
+ def _create_new_module(
140
+ rotation_config,
141
+ adapter_name: str,
142
+ target: nn.Module,
143
+ **kwargs,
144
+ ) -> Optional[nn.Module]:
145
+ """
146
+ Create a new rotation-augmented module.
147
+
148
+ Args:
149
+ rotation_config: Configuration for the rotation adapter
150
+ adapter_name: Name of the adapter
151
+ target: Base module to augment
152
+ **kwargs: Additional arguments
153
+
154
+ Returns:
155
+ New RotationLayer module wrapping the target, or None if unsupported
156
+ """
157
+ if isinstance(target, nn.Linear):
158
+ return Linear(
159
+ base_layer=target,
160
+ adapter_name=adapter_name,
161
+ r=rotation_config.r,
162
+ T=rotation_config.T,
163
+ num_rotations=rotation_config.num_rotations,
164
+ drop_out=rotation_config.drop_out,
165
+ **kwargs,
166
+ )
167
+ else:
168
+ # Unsupported layer type
169
+ print(
170
+ f"Rotation layer does not support {type(target).__name__} yet. "
171
+ f"Skipping this module."
172
+ )
173
+ return None
174
+
175
+
176
+ def __getattr__(self, name: str):
177
+ """Forward missing attributes to the wrapped module."""
178
+ try:
179
+ return super().__getattr__(name) # defer to nn.Module's logic
180
+ except AttributeError:
181
+ if name == "model": # see #1892: prevent infinite recursion if class is not initialized
182
+ raise
183
+ return getattr(self.model, name)
184
+
185
+ def get_peft_config_as_dict(self, inference: bool = False):
186
+ config_dict = {}
187
+ for key, value in self.peft_config.items():
188
+ config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()}
189
+ if inference:
190
+ config["inference_mode"] = True
191
+ config_dict[key] = config
192
+ return config_dict
193
+
194
+
195
+ def _set_adapter_layers(self, enabled=True):
196
+ for module in self.model.modules():
197
+ if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
198
+ module.enable_adapters(enabled)
199
+
200
+ def enable_adapter_layers(self) -> None:
201
+ """Enable all adapters.
202
+
203
+ Call this if you have previously disabled all adapters and want to re-enable them.
204
+ """
205
+ self._set_adapter_layers(enabled=True)
206
+
207
+ def disable_adapter_layers(self):
208
+ for active_adapter in self.active_adapters:
209
+ val = self.peft_config[active_adapter].bias
210
+ if val != "none":
211
+ msg = (
212
+ f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same "
213
+ "output as the base model would without adaption."
214
+ )
215
+ print(msg)
216
+ self._set_adapter_layers(enabled=False)
217
+
218
+ def set_adapter(self, adapter_name, inference_mode):
219
+ """Set the active adapter(s).
220
+
221
+ Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is
222
+ not desired, use the following code.
223
+
224
+ ```py
225
+ >>> for name, param in model_peft.named_parameters():
226
+ ... if ...: # some check on name (ex. if 'lora' in name)
227
+ ... param.requires_grad = False
228
+ ```
229
+
230
+ Args:
231
+ adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated.
232
+ """
233
+ for module in self.model.modules():
234
+ if isinstance(module, RotationLayer):
235
+ if module.merged:
236
+ print("Adapter cannot be set when the model is merged. Unmerging the model first.")
237
+ module.unmerge()
238
+ module.set_adapter(adapter_name, inference_mode)
239
+ self.active_adapter = adapter_name
240
+
241
+ def merge_adapter(self, adapter_names: Optional[list[str]] = None) -> None:
242
+ """
243
+ Merge adapter weights into the base model weights.
244
+
245
+ This can speed up inference by eliminating the need for runtime
246
+ rotation computations.
247
+
248
+ Args:
249
+ adapter_names: List of adapter names to merge. If None, merges all
250
+ active adapters.
251
+ """
252
+ for module in self.model.modules():
253
+ if isinstance(module, RotationLayer):
254
+ module.merge(safe_merge=False, adapter_names=adapter_names)
255
+
256
+
257
+ def unmerge_adapter(self) -> None:
258
+ """
259
+ Unmerge adapter weights from the base model weights.
260
+
261
+ This reverses the merge operation, restoring dynamic adapter behavior.
262
+ """
263
+ for module in self.model.modules():
264
+ if isinstance(module, RotationLayer):
265
+ module.unmerge()
266
+
267
+ @staticmethod
268
+ def _prepare_adapter_config(peft_config, model_config):
269
+
270
+ if peft_config.target_modules is None:
271
+ if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING:
272
+ raise ValueError("Please specify `target_modules` in `peft_config`")
273
+ peft_config.target_modules = set(
274
+ TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING[model_config["model_type"]]
275
+ )
276
+
277
+ return peft_config
278
+
279
+
280
+ def _check_new_adapter_config(self, config) -> None:
281
+ """
282
+ Check the validity of a new adapter configuration.
283
+
284
+ Args:
285
+ config: Configuration to validate
286
+
287
+ Raises:
288
+ ValueError: If configuration is invalid
289
+ """
290
+ # Validate rank
291
+ if config.r <= 0:
292
+ raise ValueError(f"r must be positive, got {config.r}")
293
+
294
+ # Validate num_rotations
295
+ if config.num_rotations <= 0:
296
+ raise ValueError(
297
+ f"num_rotations must be positive, got {config.num_rotations}"
298
+ )
299
+
300
+
301
+ # Validate bias configuration
302
+ valid_bias_configs = ["none", "all", "rotation_only"]
303
+ if hasattr(config, "bias") and config.bias not in valid_bias_configs:
304
+ raise ValueError(
305
+ f"Invalid bias configuration '{config.bias}'. "
306
+ f"Must be one of {valid_bias_configs}"
307
+ )
308
+
309
+
310
+ def _unload_and_optionally_merge(
311
+ self,
312
+ merge=True,
313
+ progressbar: bool = False,
314
+ safe_merge: bool = False,
315
+ adapter_names: Optional[list[str]] = None,
316
+ ):
317
+ if merge:
318
+ self._check_merge_allowed()
319
+
320
+ key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
321
+ desc = "Unloading " + ("and merging " if merge else "") + "model"
322
+ for key in tqdm(key_list, disable=not progressbar, desc=desc):
323
+ try:
324
+ parent, target, target_name = _get_submodules(self.model, key)
325
+ except AttributeError:
326
+ continue
327
+ with onload_layer(target):
328
+ if hasattr(target, "unload_and_optionally_merge_module"):
329
+ # if layers have special unloading method, like MultiheadAttention, use that
330
+ unloaded_module = target.unload_and_optionally_merge_module(
331
+ merge=merge, safe_merge=safe_merge, adapter_names=adapter_names
332
+ )
333
+ self._replace_module(parent, target_name, unloaded_module, target)
334
+ elif hasattr(target, "base_layer"):
335
+ if merge:
336
+ target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
337
+ self._replace_module(parent, target_name, target.get_base_layer(), target)
338
+
339
+ return self.model
340
+
341
+ def delete_adapter(self, adapter_name: str) -> None:
342
+ """
343
+ Deletes an existing adapter.
344
+
345
+ Args:
346
+ adapter_name (str): Name of the adapter to be deleted.
347
+ """
348
+ if adapter_name not in list(self.peft_config.keys()):
349
+ raise ValueError(f"Adapter {adapter_name} does not exist")
350
+ del self.peft_config[adapter_name]
351
+
352
+ key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
353
+ new_adapter = None
354
+ for key in key_list:
355
+ _, target, _ = _get_submodules(self.model, key)
356
+ if isinstance(target, RotationLayer):
357
+ target.delete_adapter(adapter_name)
358
+ if new_adapter is None:
359
+ new_adapter = target.active_adapters[:]
360
+
361
+ self.active_adapter = new_adapter or []
362
+ self._delete_auxiliary_adapter(adapter_name, new_active_adapters=new_adapter)
363
+
364
+ def merge_and_unload(
365
+ self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
366
+ ) -> torch.nn.Module:
367
+ r"""
368
+ This method merges the rotation layers into the base model. This is needed if someone wants to use the base model as
369
+ a standalone model.
370
+
371
+ Args:
372
+ progressbar (`bool`):
373
+ whether to show a progressbar indicating the unload and merge process
374
+ safe_merge (`bool`):
375
+ whether to activate the safe merging check to check if there is any potential Nan in the adapter
376
+ weights
377
+ adapter_names (`List[str]`, *optional*):
378
+ The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
379
+ to `None`.
380
+
381
+ """
382
+ return self._unload_and_optionally_merge(
383
+ progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names
384
+ )
385
+
386
+ def unload(self) -> torch.nn.Module:
387
+ """
388
+ Gets back the base model by removing all the rotation modules without merging. This gives back the original base
389
+ model.
390
+ """
391
+ return self._unload_and_optionally_merge(merge=False)
392
+
nl_tasks/rpeft/rotation/rotation_config.py ADDED
@@ -0,0 +1,89 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import List, Optional
3
+ # from peft.config import PeftConfig
4
+ from rpeft.utils import PeftConfig
5
+ from rpeft.utils import PeftType
6
+
7
+ @dataclass
8
+ class RotationConfig(PeftConfig):
9
+ """
10
+ Configuration class for Rotation-based Parameter-Efficient Fine-Tuning.
11
+
12
+ This configuration stores all parameters needed to apply the Rotation method
13
+ (based on Cayley transformation) to a model's linear layers.
14
+
15
+ Args:
16
+ r (`int`):
17
+ The rank parameter for the low-rank approximation in rotation matrices.
18
+ T (`float`, *optional*, defaults to 1.0):
19
+ Temperature parameter for the transformation.
20
+ num_rotations (`int`, *optional*, defaults to 4):
21
+ Number of rotation matrices to use in parallel.
22
+ target_modules (`Union[List[str], str]`):
23
+ Module names to apply rotation to (e.g., ["q_proj", "v_proj"]).
24
+ target_modules_to_skip (`Union[List[str], str]`, *optional*):
25
+ Module names to skip when applying rotation.
26
+ modules_to_save (`Union[List[str], str]`, *optional*):
27
+ Modules to save in addition to rotation parameters.
28
+ layers_to_transform (`Union[List[int], int]`, *optional*):
29
+ Layers to transform. If None, all layers matching target_modules are transformed.
30
32
+ """
33
+
34
+ peft_type: str = field(default="ROTATION", init=False)
35
+ target_modules: Optional[List[str]] = field(
36
+ default=None,
37
+ metadata={
38
+ "help": "List of module names to apply rotation to (e.g., ['q_proj', 'v_proj', 'linear'])"
39
+ },
40
+ )
41
+ target_modules_to_skip: Optional[List[str]] = field(
42
+ default=None,
43
+ metadata={"help": "List of module names to skip when applying rotation"},
44
+ )
45
+ modules_to_save: Optional[List[str]] = field(
46
+ default=None,
47
+ metadata={"help": "List of modules to save in addition to rotation parameters"},
48
+ )
49
+ r: int = field(
50
+ default=8,
51
+ metadata={"help": "Rank parameter for low-rank approximation"},
52
+ )
53
+ T: float = field(
54
+ default=1.0,
55
+ metadata={"help": "Temperature parameter for Cayley transformation"},
56
+ )
57
+ num_rotations: int = field(
58
+ default=4,
59
+ metadata={"help": "Number of rotation matrices to use in parallel"},
60
+ )
61
+
62
+ bias: str = field(
63
+ default="none",
64
+ metadata={
65
+ "help": "Bias training configuration. Options: 'none', 'all', 'rotation_only'"
66
+ }
67
+ )
68
+ layers_to_transform: Optional[List[int]] = field(
69
+ default=None,
70
+ metadata={"help": "Layers to transform. If None, all matching layers are transformed"},
71
+ )
72
+ drop_out: float = field(
73
+ default=0.0,
74
+ metadata={
75
+ 'help': 'Input dropout rate'
76
+ }
77
+ )
78
+
79
+ def __post_init__(self):
80
+ ##### Diff
81
+ self.peft_type = PeftType.ROTATION
82
+ self.target_modules = (
83
+ set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
84
+ )
85
+ self.target_modules_to_skip = (
86
+ set(self.target_modules_to_skip)
87
+ if isinstance(self.target_modules_to_skip, list)
88
+ else self.target_modules_to_skip
89
+ )
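As a quick sanity check on the cost of the hyperparameters above: each wrapped `nn.Linear` adds `U` and `V` of shape `(num_rotations, r, in_features)`, i.e. `2 * num_rotations * r * in_features` trainable parameters. A small standalone arithmetic sketch (not repository code):

```python
def rotation_param_count(in_features: int, r: int, num_rotations: int) -> int:
    # U and V each have shape (num_rotations, r, in_features)
    return 2 * num_rotations * r * in_features

# Defaults (r=8, num_rotations=4) on a 4096-wide projection:
print(rotation_param_count(4096, r=8, num_rotations=4))  # 262144 parameters per wrapped layer
```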
nl_tasks/rpeft/utils/__init__.py ADDED
@@ -0,0 +1,29 @@
1
+ # flake8: noqa
2
+ # There's no way to ignore "F401 '...' imported but unused" warnings in this
3
+ # module, but to preserve other warnings. So, don't check this module at all
4
+
5
+ # coding=utf-8
6
+ # Copyright 2023-present the HuggingFace Inc. team.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ from .adapters_utils import CONFIG_NAME, WEIGHTS_NAME
20
+ from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType
21
+ from .other import (
22
+ TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
23
+ _set_trainable,
24
+ bloom_model_postprocess_past_key_value,
25
+ prepare_model_for_int8_training,
26
+ shift_tokens_right,
27
+ transpose,
28
+ )
29
+ from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict
nl_tasks/rpeft/utils/adapters_utils.py ADDED
@@ -0,0 +1,19 @@
1
+ # coding=utf-8
2
+ # Original License:
3
+ # Copyright 2023-present the HuggingFace Inc. team.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ WEIGHTS_NAME = "adapter_model.bin"
17
+ CONFIG_NAME = "adapter_config.json"
18
+
19
+ # TODO: add automapping and superclass here?
nl_tasks/rpeft/utils/config.py ADDED
@@ -0,0 +1,220 @@
1
+ # coding=utf-8
2
+ # Original License:
3
+ # Copyright 2023-present the HuggingFace Inc. team.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import enum
17
+ import json
18
+ import os
19
+ from dataclasses import asdict, dataclass, field
20
+ from typing import Optional, Union
21
+
22
+ from huggingface_hub import hf_hub_download
23
+ from transformers.utils import PushToHubMixin, http_user_agent
24
+
25
+ from .adapters_utils import CONFIG_NAME
26
+
27
+
28
+ class PeftType(str, enum.Enum):
29
+ PROMPT_TUNING = "PROMPT_TUNING"
30
+ P_TUNING = "P_TUNING"
31
+ PREFIX_TUNING = "PREFIX_TUNING"
32
+ LORA = "LORA"
33
+ BOTTLENECK = "BOTTLENECK"
34
+ QUANTA = "QUANTA"
35
+
36
+ ROTATION = "ROTATION"
37
+
38
+
39
+
40
+ class TaskType(str, enum.Enum):
41
+ SEQ_CLS = "SEQ_CLS"
42
+ SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM"
43
+ CAUSAL_LM = "CAUSAL_LM"
44
+ TOKEN_CLS = "TOKEN_CLS"
45
+
46
+
47
+ @dataclass
48
+ class PeftConfigMixin(PushToHubMixin):
49
+ r"""
50
+ This is the base configuration class for PEFT adapter models. It contains all the methods that are common to all
51
+ PEFT adapter models. This class inherits from `transformers.utils.PushToHubMixin` which contains the methods to
52
+ push your model to the Hub. The method `save_pretrained` will save the configuration of your adapter model in a
53
+ directory. The method `from_pretrained` will load the configuration of your adapter model from a directory.
54
+
55
+ Args:
56
+ peft_type (Union[[`~peft_local_tensor.utils.config.PeftType`], `str`]): The type of Peft method to use.
57
+ """
58
+ peft_type: Optional[PeftType] = field(default=None, metadata={"help": "The type of PEFT model."})
59
+
60
+ @property
61
+ def __dict__(self):
62
+ return asdict(self)
63
+
64
+ def to_dict(self):
65
+ return self.__dict__
66
+
67
+ @classmethod
68
+ def from_pretrained(cls, pretrained_model_name_or_path: str, subfolder: Optional[str] = None, **kwargs):
69
+ r"""
70
+ This method loads the configuration of your adapter model from a directory.
71
+
72
+ Args:
73
+ pretrained_model_name_or_path (`str`):
74
+ The directory or the Hub repository id where the configuration is saved.
75
+ kwargs (additional keyword arguments, *optional*):
76
+ Additional keyword arguments passed along to the child class initialization.
77
+ """
78
+ path = (
79
+ os.path.join(pretrained_model_name_or_path, subfolder)
80
+ if subfolder is not None
81
+ else pretrained_model_name_or_path
82
+ )
83
+
84
+ hf_hub_download_kwargs, class_kwargs, _ = cls._split_kwargs(kwargs)
85
+ if "user_agent" not in hf_hub_download_kwargs:
86
+ hf_hub_download_kwargs["user_agent"] = http_user_agent()
87
+
88
+ if os.path.isfile(os.path.join(path, CONFIG_NAME)):
89
+ config_file = os.path.join(path, CONFIG_NAME)
90
+ else:
91
+ try:
92
+ config_file = hf_hub_download(
93
+ pretrained_model_name_or_path, CONFIG_NAME, subfolder=subfolder, **hf_hub_download_kwargs
94
+ )
95
+ except Exception as exc:
96
+ raise ValueError(f"Can't find '{CONFIG_NAME}' at '{pretrained_model_name_or_path}'") from exc
97
+
98
+ loaded_attributes = cls.from_json_file(config_file)
99
+ kwargs = {**class_kwargs, **loaded_attributes}
100
+ kwargs = cls.check_kwargs(**kwargs)
101
+ return cls.from_peft_type(**kwargs)
102
+
103
+ def save_pretrained(self, save_directory, **kwargs):
104
+ r"""
105
+ This method saves the configuration of your adapter model in a directory.
106
+
107
+ Args:
108
+ save_directory (`str`):
109
+ The directory where the configuration will be saved.
110
+ **kwargs:
111
+ Additional keyword arguments passed along to the `transformers.utils.PushToHubMixin.push_to_hub`
112
+ method.
113
+ """
114
+ if os.path.isfile(save_directory):
115
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
116
+
117
+ os.makedirs(save_directory, exist_ok=True)
118
+ auto_mapping_dict = kwargs.pop("auto_mapping_dict", None)
119
+
120
+ output_dict = self.to_dict()
121
+ # converting set type to list
122
+ for key, value in output_dict.items():
123
+ if isinstance(value, set):
124
+ output_dict[key] = list(value)
125
+
126
+ output_path = os.path.join(save_directory, CONFIG_NAME)
127
+
128
+ # Add auto mapping details for custom models.
129
+ if auto_mapping_dict is not None:
130
+ output_dict["auto_mapping"] = auto_mapping_dict
131
+
132
+ # save it
133
+ with open(output_path, "w") as writer:
134
+ writer.write(json.dumps(output_dict, indent=2, sort_keys=True))
135
+
136
+ @classmethod
137
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
138
+ r"""
139
+ This method loads the configuration of your adapter model from a directory.
140
+
141
+ Args:
142
+ pretrained_model_name_or_path (`str`):
143
+ The directory or the hub-id where the configuration is saved.
144
+ **kwargs:
145
+ Additional keyword arguments passed along to the child class initialization.
146
+ """
147
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, CONFIG_NAME)):
148
+ config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
149
+ else:
150
+ try:
151
+ config_file = hf_hub_download(pretrained_model_name_or_path, CONFIG_NAME)
152
+ except Exception:
153
+ raise ValueError(f"Can't find config.json at '{pretrained_model_name_or_path}'")
154
+
155
+ loaded_attributes = cls.from_json_file(config_file)
156
+
157
+ config = cls(**kwargs)
158
+
159
+ for key, value in loaded_attributes.items():
160
+ if hasattr(config, key):
161
+ setattr(config, key, value)
162
+
163
+ return config
164
+
165
+ @classmethod
166
+ def from_json_file(cls, path_json_file, **kwargs):
167
+ r"""
168
+ Loads a configuration file from a json file.
169
+
170
+ Args:
171
+ path_json_file (`str`):
172
+ The path to the json file.
173
+ """
174
+ with open(path_json_file, "r") as file:
175
+ json_object = json.load(file)
176
+
177
+ return json_object
178
+
179
+
180
+ @dataclass
181
+ class PeftConfig(PeftConfigMixin):
182
+ """
183
+ This is the base configuration class to store the configuration of a :class:`~peft_local_tensor.PeftModel`.
184
+
185
+ Args:
186
+ peft_type (Union[[`~peft_local_tensor.utils.config.PeftType`], `str`]): The type of Peft method to use.
187
+ task_type (Union[[`~peft_local_tensor.utils.config.TaskType`], `str`]): The type of task to perform.
188
+ inference_mode (`bool`, defaults to `False`): Whether to use the Peft model in inference mode.
189
+ """
190
+
191
+ base_model_name_or_path: str = field(default=None, metadata={"help": "The name of the base model to use."})
192
+ revision: Optional[str] = field(default=None, metadata={"help": "The specific base model version to use."})
193
+ peft_type: Union[str, PeftType] = field(default=None, metadata={"help": "Peft type"})
194
+ task_type: Union[str, TaskType] = field(default=None, metadata={"help": "Task type"})
195
+ inference_mode: bool = field(default=False, metadata={"help": "Whether to use inference mode"})
196
+
197
+
198
+ @dataclass
199
+ class PromptLearningConfig(PeftConfig):
200
+ """
201
+ This is the base configuration class to store the configuration of a Union[[`~peft_local_tensor.PrefixTuning`],
202
+ [`~peft_local_tensor.PromptEncoder`], [`~peft_local_tensor.PromptTuning`]].
203
+
204
+ Args:
205
+ num_virtual_tokens (`int`): The number of virtual tokens to use.
206
+ token_dim (`int`): The hidden embedding dimension of the base transformer model.
207
+ num_transformer_submodules (`int`): The number of transformer submodules in the base transformer model.
208
+ num_attention_heads (`int`): The number of attention heads in the base transformer model.
209
+ num_layers (`int`): The number of layers in the base transformer model.
210
+ """
211
+
212
+ num_virtual_tokens: int = field(default=None, metadata={"help": "Number of virtual tokens"})
213
+ token_dim: int = field(
214
+ default=None, metadata={"help": "The hidden embedding dimension of the base transformer model"}
215
+ )
216
+ num_transformer_submodules: Optional[int] = field(
217
+ default=None, metadata={"help": "Number of transformer submodules"}
218
+ )
219
+ num_attention_heads: Optional[int] = field(default=None, metadata={"help": "Number of attention heads"})
220
+ num_layers: Optional[int] = field(default=None, metadata={"help": "Number of transformer layers"})
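A hedged round-trip sketch of the save/load helpers above, using the `RotationConfig` built on this `PeftConfig` (import path assumed from the repository layout): `save_pretrained` writes `adapter_config.json` and `from_pretrained` reads it back.

```python
import tempfile
from rpeft.rotation.rotation_config import RotationConfig  # assumed import path

cfg = RotationConfig(r=8, num_rotations=4, target_modules=["q_proj", "v_proj"])

with tempfile.TemporaryDirectory() as tmp:
    cfg.save_pretrained(tmp)                      # writes adapter_config.json (sets become lists)
    reloaded = RotationConfig.from_pretrained(tmp)
    assert reloaded.r == 8 and reloaded.num_rotations == 4
```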
nl_tasks/rpeft/utils/other.py ADDED
@@ -0,0 +1,160 @@
1
+ # coding=utf-8
2
+ # Original License:
3
+ # Copyright 2023-present the HuggingFace Inc. team.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import torch
18
+
19
+
20
+ # needed for prefix-tuning of bloom model
21
+ def bloom_model_postprocess_past_key_value(past_key_values):
22
+ past_key_values = torch.cat(past_key_values)
23
+ total_layers, batch_size, num_attention_heads, num_virtual_tokens, head_dim = past_key_values.shape
24
+ keys = past_key_values[: total_layers // 2]
25
+ keys = keys.transpose(2, 3).reshape(
26
+ total_layers // 2, batch_size * num_attention_heads, head_dim, num_virtual_tokens
27
+ )
28
+ values = past_key_values[total_layers // 2 :]
29
+ values = values.reshape(total_layers // 2, batch_size * num_attention_heads, num_virtual_tokens, head_dim)
30
+
31
+ return tuple(zip(keys, values))
32
+
33
+
34
+ def prepare_model_for_int8_training(
35
+ model, output_embedding_layer_name="lm_head", use_gradient_checkpointing=True, layer_norm_names=["layer_norm"]
36
+ ):
37
+ r"""
38
+ This method wrapps the entire protocol for preparing a model before running a training. This includes:
39
+ 1- Cast the layernorm in fp32 2- making output embedding layer require grads 3- Add the upcasting of the lm
40
+ head to fp32
41
+
42
+ Args:
43
+ model, (`transformers.PreTrainedModel`):
44
+ The loaded model from `transformers`
45
+ """
46
+ loaded_in_8bit = getattr(model, "is_loaded_in_8bit", False)
47
+
48
+ for name, param in model.named_parameters():
49
+ # freeze base model's layers
50
+ param.requires_grad = False
51
+
52
+ if loaded_in_8bit:
53
+ # cast layer norm in fp32 for stability for 8bit models
54
+ if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
55
+ param.data = param.data.to(torch.float32)
56
+
57
+ if loaded_in_8bit and use_gradient_checkpointing:
58
+ # For backward compatibility
59
+ if hasattr(model, "enable_input_require_grads"):
60
+ model.enable_input_require_grads()
61
+ else:
62
+
63
+ def make_inputs_require_grad(module, input, output):
64
+ output.requires_grad_(True)
65
+
66
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
67
+
68
+ # enable gradient checkpointing for memory efficiency
69
+ model.gradient_checkpointing_enable()
70
+
71
+ if hasattr(model, output_embedding_layer_name):
72
+ output_embedding_layer = getattr(model, output_embedding_layer_name)
73
+ input_dtype = output_embedding_layer.weight.dtype
74
+
75
+ class CastOutputToFloat(torch.nn.Sequential):
76
+ r"""
77
+ Manually cast to the expected dtype of the lm_head, as sometimes there is a final layer norm that is cast
78
+ in fp32
79
+
80
+ """
81
+
82
+ def forward(self, x):
83
+ return super().forward(x.to(input_dtype)).to(torch.float32)
84
+
85
+ setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
86
+
87
+ return model
88
+
89
+
90
+ TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING = {
91
+ "bloom": bloom_model_postprocess_past_key_value,
92
+ }
93
+
94
+
95
+ # copied from transformers.models.bart.modeling_bart
96
+ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
97
+ """
98
+ Shift input ids one token to the right.
99
+
100
+ Args:
101
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input ids
102
+ pad_token_id (`int`): The id of the `padding` token.
103
+ decoder_start_token_id (`int`): The id of the `start` token.
104
+ """
105
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
106
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
107
+ shifted_input_ids[:, 0] = decoder_start_token_id
108
+
109
+ if pad_token_id is None:
110
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
111
+ # replace possible -100 values in labels by `pad_token_id`
112
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
113
+
114
+ return shifted_input_ids
115
+
116
+
117
+ def _set_trainable(model):
118
+ if model.modules_to_save is not None:
119
+ for name, param in model.named_parameters():
120
+ if any(module_name in name for module_name in model.modules_to_save):
121
+ param.requires_grad = True
122
+
123
+
124
+ def fsdp_auto_wrap_policy(model):
125
+ import functools
126
+ import os
127
+
128
+ from accelerate import FullyShardedDataParallelPlugin
129
+ from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
130
+
131
+ from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder
132
+
133
+ def lambda_policy_fn(module):
134
+ if (
135
+ len(list(module.named_children())) == 0
136
+ and getattr(module, "weight", None) is not None
137
+ and module.weight.requires_grad
138
+ ):
139
+ return True
140
+ return False
141
+
142
+ lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
143
+ transformer_wrap_policy = functools.partial(
144
+ transformer_auto_wrap_policy,
145
+ transformer_layer_cls=(
146
+ PrefixEncoder,
147
+ PromptEncoder,
148
+ PromptEmbedding,
149
+ FullyShardedDataParallelPlugin.get_module_class_from_name(
150
+ model, os.environ.get("FSDP_TRANSFORMER_CLS_TO_WRAP", "")
151
+ ),
152
+ ),
153
+ )
154
+
155
+ auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
156
+ return auto_wrap_policy
157
+
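+ # Usage sketch (an assumption based on the usual accelerate + PEFT FSDP recipe, not on
+ # anything defined in this repo): assign the returned policy to the active FSDP plugin
+ # before calling accelerator.prepare(...), e.g.
+ #     accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)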
158
+
159
+ def transpose(weight, fan_in_fan_out):
160
+ return weight.T if fan_in_fan_out else weight
nl_tasks/rpeft/utils/save_and_load.py ADDED
@@ -0,0 +1,166 @@
1
+ # coding=utf-8
2
+ # Original License:
3
+ # Copyright 2023-present the HuggingFace Inc. team.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from .config import PeftType
18
+ import warnings
19
+ import torch
20
+
21
+ def _find_mismatched_keys(
22
+ model: torch.nn.Module, peft_model_state_dict: dict[str, torch.Tensor], ignore_mismatched_sizes: bool = False
23
+ ) -> tuple[dict[str, torch.Tensor], list[tuple[str, tuple[int, ...], tuple[int, ...]]]]:
24
+ if not ignore_mismatched_sizes:
25
+ return peft_model_state_dict, []
26
+
27
+ mismatched = []
28
+ state_dict = model.state_dict()
29
+ for key, tensor in peft_model_state_dict.items():
30
+ if key not in state_dict:
31
+ continue
32
+
33
+ # see https://github.com/huggingface/transformers/blob/09f9f566de83eef1f13ee83b5a1bbeebde5c80c1/src/transformers/modeling_utils.py#L3858-L3864
34
+ if (state_dict[key].shape[-1] == 1) and (state_dict[key].numel() * 2 == tensor.numel()):
35
+ # This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size
36
+ # differences. Without checking the module or parameter type, this seems like a practical way to detect
37
+ # valid 4-bit weights.
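+ # Illustrative example (an assumption about the usual bitsandbytes packing, not taken
+ # from this repo): a Linear weight with 4096 * 4096 = 16777216 fp16 values is typically
+ # stored 4-bit as a packed uint8 tensor of shape (8388608, 1), so shape[-1] == 1 and
+ # state_dict[key].numel() * 2 == tensor.numel() both hold.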
38
+ continue
39
+
40
+ if state_dict[key].shape != tensor.shape:
41
+ mismatched.append((key, tensor.shape, state_dict[key].shape))
42
+
43
+ for key, _, _ in mismatched:
44
+ del peft_model_state_dict[key]
45
+
46
+ return peft_model_state_dict, mismatched
47
+
48
+ def get_peft_model_state_dict(model, state_dict=None):
49
+ """
50
+ Get the state dict of the Peft model.
51
+
52
+ Args:
53
+ model ([`PeftModel`]): The Peft model. When using torch.nn.DistributedDataParallel, DeepSpeed or FSDP,
54
+ the model should be the underlying model/unwrapped model (i.e. model.module).
55
+ state_dict (`dict`, *optional*, defaults to `None`):
56
+ The state dict of the model. If not provided, the state dict of the model
57
+ will be used.
58
+ """
59
+ if state_dict is None:
60
+ state_dict = model.state_dict()
61
+ if model.peft_config.peft_type == PeftType.LORA:
62
+ # to_return = lora_state_dict(model, bias=model.peft_config.bias)
63
+ # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
64
+ # to work directly with the state dict, which is necessary when using DeepSpeed or FSDP
65
+ bias = model.peft_config.bias
66
+ if bias == "none":
67
+ to_return = {k: state_dict[k] for k in state_dict if "lora_" in k}
68
+ elif bias == "all":
69
+ to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k}
70
+ elif bias == "lora_only":
71
+ to_return = {}
72
+ for k in state_dict:
73
+ if "lora_" in k:
74
+ to_return[k] = state_dict[k]
75
+ bias_name = k.split("lora_")[0] + "bias"
76
+ if bias_name in state_dict:
77
+ to_return[bias_name] = state_dict[bias_name]
78
+ else:
79
+ raise NotImplementedError
80
+ elif model.peft_config.peft_type == PeftType.BOTTLENECK:
81
+ # return the state dict of the model with Bottleneck adapters
82
+ bias = model.peft_config.bias
83
+ if bias == "none":
84
+ to_return = {k: state_dict[k] for k in state_dict if "adapter_" in k}
85
+ elif bias == "all":
86
+ to_return = {k: state_dict[k] for k in state_dict if "adapter_" in k or "bias" in k}
87
+ elif bias == "adapter_only":
88
+ to_return = {}
89
+ for k in state_dict:
90
+ if "adapter_" in k:
91
+ to_return[k] = state_dict[k]
92
+ bias_name = k.split("adapter_")[0] + "bias"
93
+ if bias_name in state_dict:
94
+ to_return[bias_name] = state_dict[bias_name]
95
+ else:
96
+ raise NotImplementedError
97
+
98
+ elif model.peft_config.peft_type == PeftType.ROTATION:
99
+ bias = model.peft_config.bias
100
+ if bias == "none":
101
+ to_return = {k: state_dict[k] for k in state_dict if "rotation" in k}
102
+ elif bias == "all":
103
+ to_return = {k: state_dict[k] for k in state_dict if "rotation" in k or "bias" in k}
104
+ elif bias == "rotation_only":
105
+ to_return = {}
106
+ for k in state_dict:
107
+ if "rotation" in k:
108
+ to_return[k] = state_dict[k]
109
+ bias_name = k.split("rotation")[0] + "bias"
110
+ if bias_name in state_dict:
111
+ to_return[bias_name] = state_dict[bias_name]
112
+ else:
113
+ raise NotImplementedError
114
+
115
+ elif model.peft_config.is_prompt_learning:
116
+ to_return = {}
117
+ if model.peft_config.inference_mode:
118
+ prompt_embeddings = model.prompt_encoder.embedding.weight
119
+ else:
120
+ prompt_embeddings = model.get_prompt_embedding_to_save()
121
+ to_return["prompt_embeddings"] = prompt_embeddings
122
+ else:
123
+ raise NotImplementedError
124
+ if model.modules_to_save is not None:
125
+ for key, value in state_dict.items():
126
+ if any(module_name in key for module_name in model.modules_to_save):
127
+ to_return[key] = value
128
+ return to_return
129
+
130
+
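+ # Usage sketch (illustrative only; `peft_model` and `save_path` are hypothetical names):
+ # after fine-tuning with the ROTATION adapter, only the trainable weights need to be
+ # extracted and written to disk, e.g.
+ #     rotation_state = get_peft_model_state_dict(peft_model)
+ #     torch.save(rotation_state, save_path)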
131
+ def set_peft_model_state_dict(model, peft_model_state_dict,
132
+ adapter_name="default",
133
+ ignore_mismatched_sizes: bool = False):
134
+ """
135
+ Set the state dict of the Peft model.
136
+
137
+ Args:
138
+ model ([`PeftModel`]): The Peft model.
139
+ peft_model_state_dict (`dict`): The state dict of the Peft model.
140
+ adapter_name (`str`, *optional*, defaults to `"default"`):
141
+ The name of the adapter whose state dict should be set.
+ ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
+ Whether to drop checkpoint entries whose shapes do not match the model (a warning is emitted for each dropped key).
142
+ """
143
+ peft_model_state_dict, mismatched_keys = _find_mismatched_keys(
144
+ model, peft_model_state_dict, ignore_mismatched_sizes=ignore_mismatched_sizes
145
+ )
146
+ if mismatched_keys:
147
+ # see https://github.com/huggingface/transformers/blob/09f9f566de83eef1f13ee83b5a1bbeebde5c80c1/src/transformers/modeling_utils.py#L4039
148
+ mismatched_warning = "\n".join(
149
+ [
150
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
151
+ for key, shape1, shape2 in mismatched_keys
152
+ ]
153
+ )
154
+ msg = (
155
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint "
156
+ f"and are being ignored because you passed `ignore_mismatched_sizes=True`: {mismatched_warning}."
157
+ )
158
+ warnings.warn(msg)
159
+
160
+ model.load_state_dict(peft_model_state_dict, strict=False)
161
+ if model.peft_config.peft_type != PeftType.LORA and model.peft_config.peft_type != PeftType.BOTTLENECK \
162
+ and model.peft_config.peft_type != PeftType.ROTATION:
163
+ model.prompt_encoder.embedding.load_state_dict(
164
+ {"weight": peft_model_state_dict["prompt_embeddings"]}, strict=True
165
+ )
166
+ return model
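+ # Round-trip sketch (illustrative only, continuing the hypothetical example above):
+ #     loaded = torch.load(save_path, map_location="cpu")
+ #     set_peft_model_state_dict(peft_model, loaded)
+ # Because load_state_dict is called with strict=False, only the adapter tensors present
+ # in `loaded` are overwritten and the frozen base weights are left untouched.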
nl_tasks/scripts/.nfs80e7f26e00566c630000664a ADDED
@@ -0,0 +1,117 @@
1
+ # export OMINI_CONFIG=./config/commonsense.yaml
2
+ export OMINI_CONFIG=./config/commonsense.yaml
3
+ #echo $OMINI_CONFIG
4
+ export TOKENIZERS_PARALLELISM=true
5
+
6
+ # CUDA Include (/cuda.h)
7
+ CUDA_INCLUDE_PATH="/home/work/miniconda3/envs/allm/include"
8
+
9
+ # 3. Add into CPATH & CPLUS_INCLUDE_PATH (C/C++ compiler)
10
+ export CPATH=$CPATH:$CUDA_INCLUDE_PATH
11
+ export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$CUDA_INCLUDE_PATH
12
+ # echo "CPATH is set to: $CPATH"
13
+ # echo "CPLUS_INCLUDE_PATH is set to: $CPLUS_INCLUDE_PATH"
14
+
15
+ #./exps/run_ex15_3ep
16
+ # ./exp_init/run_ex02/ft2
17
+ # ADAPTER = "--model.merge_adapter_path "./exps/run_ex12/ft2" --model.merge_output_path "./exps/run_ex12/merged"
18
+ # export ADAPTER = "--model.merge_adapter_path ./exp395/run_ex01/ft2 --model.merge_output_path ./exp395/run_ex01/merged"
19
+ # accelerate launch --main_process_port 41353 -m src.merge \
20
+ # --config_path $OMINI_CONFIG \
21
+ # --model.merge_adapter_path "./exps/run_ex19_2ep/ft2" --model.merge_output_path "./exps/run_ex19_2ep/merged"
22
+
23
+ # OUTPUT="./exps/run_ex19_2ep/merged"
24
+
25
+ # date +"%F %T"
26
+ # python inference/MATH_infer.py --model $OUTPUT
27
+ # date +"%F %T"
28
+ # python inference/gsm8k_infer.py --model $OUTPUT
29
+ # date +"%F %T"
30
+
31
+
32
+ # MERGE_DIR="./exps/run_ex24"
33
+ # accelerate launch --main_process_port 41353 -m src.merge \
34
+ # --config_path $OMINI_CONFIG \
35
+ # --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
36
+
37
+ # date +"%F %T"
38
+ # python inference/MATH_infer.py --model $MERGE_DIR/merged/
39
+ # date +"%F %T"
40
+ # python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
41
+ # date +"%F %T"
42
+
43
+
44
+ # MERGE_DIR="./exps/run_ex25"
45
+ # accelerate launch --main_process_port 41353 -m src.merge \
46
+ # --config_path $OMINI_CONFIG \
47
+ # --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
48
+
49
+ # date +"%F %T"
50
+ # python inference/MATH_infer.py --model $MERGE_DIR/merged/
51
+ # date +"%F %T"
52
+ # python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
53
+ # date +"%F %T"
54
+
55
+
56
+ # MERGE_DIR="./exps/run_ex26"
57
+ # accelerate launch --main_process_port 41353 -m src.merge \
58
+ # --config_path $OMINI_CONFIG \
59
+ # --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
60
+
61
+ # date +"%F %T"
62
+ # python inference/MATH_infer.py --model $MERGE_DIR/merged/
63
+ # date +"%F %T"
64
+ # python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
65
+ # date +"%F %T"
66
+
67
+
68
+ # MERGE_DIR="./exps/run_ex27"
69
+ # accelerate launch --main_process_port 41353 -m src.merge \
70
+ # --config_path $OMINI_CONFIG \
71
+ # --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
72
+
73
+ # date +"%F %T"
74
+ # python inference/MATH_infer.py --model $MERGE_DIR/merged/
75
+ # date +"%F %T"
76
+ # python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
77
+ # date +"%F %T"
78
+
79
+
80
+ # MERGE_DIR="./exps/run_ex28"
81
+ # accelerate launch --main_process_port 41353 -m src.merge \
82
+ # --config_path $OMINI_CONFIG \
83
+ # --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
84
+
85
+ # date +"%F %T"
86
+ # python inference/MATH_infer.py --model $MERGE_DIR/merged/
87
+ # date +"%F %T"
88
+ # python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
89
+ # date +"%F %T"
90
+
91
+
92
+ MERGE_DIR="./exps/run_ex33"
93
+ # accelerate launch --main_process_port 41353 -m src.merge \
94
+ # --config_path $OMINI_CONFIG \
95
+ # --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
96
+
97
+ # date +"%F %T"
98
+ # python inference/MATH_infer.py --model $MERGE_DIR/merged/
99
+ # date +"%F %T"
100
+ # python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
101
+ # date +"%F %T"
102
+
103
+
104
+ MERGE_DIR="./exps/run_ex34"
105
+ accelerate launch --main_process_port 41353 -m src.merge \
106
+ --config_path $OMINI_CONFIG \
107
+ --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
108
+
109
+ # date +"%F %T"
110
+ # python inference/MATH_infer.py --model $MERGE_DIR/merged/
111
+ date +"%F %T"
112
+ python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
113
+ date +"%F %T"
114
+
115
+
116
+
117
+
nl_tasks/scripts/.nfs80e7f26e0132942e00006649 ADDED
@@ -0,0 +1,341 @@
1
+ export OMINI_CONFIG=./config/commonsense.yaml
2
+
3
+ #echo $OMINI_CONFIG
4
+ export TOKENIZERS_PARALLELISM=true
5
+
6
+ # CUDA Include (/cuda.h)
7
+ CUDA_INCLUDE_PATH="/home/work/miniconda3/envs/allm/include"
8
+
9
+ # 3. Add into CPATH & CPLUS_INCLUDE_PATH (C/C++ compiler)
10
+ export CPATH=$CPATH:$CUDA_INCLUDE_PATH
11
+ export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$CUDA_INCLUDE_PATH
12
+ # echo "CPATH is set to: $CPATH"
13
+ # echo "CPLUS_INCLUDE_PATH is set to: $CPLUS_INCLUDE_PATH"
14
+
15
+ export WANDB_PROJECT="Llama2_7B_FT_Math40k_2"
16
+
17
+ export OMP_NUM_THREADS=1
18
+ export MKL_NUM_THREADS=1
19
+ export OPENBLAS_NUM_THREADS=1
20
+ export NUMEXPR_NUM_THREADS=1
21
+
22
+ date +"%F %T"
23
+
24
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
25
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp_init/run_ex01" --trainer_args.learning_rate=1e-3 \
26
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text "init|kaim_out_u=v"
27
+
28
+ # sleep 5
29
+ # echo "1st exp finishes"
30
+ # date +"%F %T"
31
+ # wandb sync wandb/latest-run
32
+
33
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
34
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp_init/run_ex02" --trainer_args.learning_rate=1e-3 \
35
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text "init|kaim_out_u=v(ratio)"
36
+
37
+ # sleep 5
38
+ # echo "2nd exp finishes"
39
+ # date +"%F %T"
40
+ # wandb sync wandb/latest-run
41
+
42
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
43
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex03" --trainer_args.learning_rate=1e-3 \
44
+ # --rotation_adapter_config.num_rotations 2 --rotation_adapter_config.r 8
45
+
46
+ # sleep 5
47
+ # echo "3rd exp finishes"
48
+ # date +"%F %T"
49
+ # wandb sync wandb/latest-run
50
+
51
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
52
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex04" --trainer_args.learning_rate=2e-3 \
53
+ # --rotation_adapter_config.num_rotations 4 --rotation_adapter_config.r 4
54
+ # sleep 5
55
+ # echo "4th exp finishes"
56
+ # date +"%F %T"
57
+ # wandb sync wandb/latest-run
58
+
59
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
60
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex05" --trainer_args.learning_rate=2e-3 \
61
+ # --rotation_adapter_config.num_rotations 2 --rotation_adapter_config.r 8
62
+ # sleep 5
63
+ # echo "5th exp finishes"
64
+ # date +"%F %T"
65
+ # wandb sync wandb/latest-run
66
+
67
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
68
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex06" --trainer_args.learning_rate=1e-3 \
69
+ # --rotation_adapter_config.num_rotations 4 --rotation_adapter_config.r 4
70
+ # sleep 5
71
+ # echo "6th exp finishes"
72
+ # date +"%F %T"
73
+ # wandb sync wandb/latest-run
74
+
75
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
76
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_exps7" --trainer_args.learning_rate=1e-3 \
77
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16
78
+ # sleep 5
79
+ # echo "7th exp finishes"
80
+ # date +"%F %T"
81
+ # wandb sync wandb/latest-run
82
+
83
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
84
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex08" --trainer_args.learning_rate=2e-3 \
85
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16
86
+ # sleep 5
87
+ # echo "8th exp finishes"
88
+ # date +"%F %T"
89
+ # wandb sync wandb/latest-run
90
+
91
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
92
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex09" --trainer_args.learning_rate=2e-3 \
93
+ # --rotation_adapter_config.num_rotations 16 --rotation_adapter_config.r 1
94
+
95
+ # sleep 5
96
+ # echo "9th exp finishes"
97
+ # date +"%F %T"
98
+ # wandb sync wandb/latest-run
99
+
100
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
101
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex10" --trainer_args.learning_rate=2e-3 \
102
+ # --rotation_adapter_config.num_rotations 8 --rotation_adapter_config.r 2
103
+
104
+ # sleep 5
105
+ # echo "10 exp finishes"
106
+ # date +"%F %T"
107
+ # wandb sync wandb/latest-run
108
+
109
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
110
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex11" --trainer_args.learning_rate=1e-2 \
111
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16
112
+
113
+ # sleep 5
114
+ # echo "11 exp finishes"
115
+ # date +"%F %T"
116
+ # wandb sync wandb/latest-run
117
+
118
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
119
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex12" --trainer_args.learning_rate=1e-2 \
120
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'u=v,def'
121
+
122
+ # sleep 5
123
+ # echo "12 exp finishes"
124
+ # date +"%F %T"
125
+ # wandb sync wandb/latest-run
126
+
127
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
128
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex13" --trainer_args.learning_rate=1e-3 \
129
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'u=vkaim'
130
+
131
+ # sleep 5
132
+ # echo "13 exp finishes"
133
+ # date +"%F %T"
134
+ # wandb sync wandb/latest-run
135
+
136
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
137
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex14" --trainer_args.learning_rate=2e-3 \
138
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'a,b,def'
139
+
140
+ # sleep 5
141
+ # echo "14 exp finishes"
142
+ # date +"%F %T"
143
+ # wandb sync wandb/latest-run
144
+
145
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
146
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex15" --trainer_args.learning_rate=1e-3 \
147
+ # --rotation_adapter_config.num_rotations 4 --rotation_adapter_config.r 4
148
+
149
+ # sleep 5
150
+ # echo "15 exp finishes"
151
+ # date +"%F %T"
152
+
153
+
154
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
155
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex17" --trainer_args.learning_rate=1e-3 \
156
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
157
+ # --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 500 --data.dataset_split train[:41023] --data.split_ratio 0.02493 \
158
+ # --run_text "dropout|fix_token"
159
+ # sleep 5
160
+ # echo "15 exp finishes"
161
+ # date +"%F %T"
162
+ # wandb sync wandb/latest-run
163
+
164
+
165
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
166
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex18" --trainer_args.learning_rate=1e-3 \
167
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
168
+ # --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 500 --data.dataset_split train[:41023] --data.split_ratio 0.02493 \
169
+ # --run_text "dropout|fix_token"
170
+ # sleep 5
171
+ # echo "158exp finishes"
172
+ # date +"%F %T"
173
+ # wandb sync wandb/latest-run
174
+
175
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
176
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex19" --trainer_args.learning_rate=2e-3 \
177
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
178
+ # --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 500 --data.dataset_split train[:41023] --data.split_ratio 0.02493 \
179
+ # --run_text "dropout|fix_token"
180
+ # sleep 5
181
+ # echo "19 exp finishes"
182
+ # date +"%F %T"
183
+ # wandb sync wandb/latest-run
184
+
185
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
186
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex20" --trainer_args.learning_rate=8e-4 \
187
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
188
+ # --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 500 --data.dataset_split train[:41023] --data.split_ratio 0.02493 \
189
+ # --run_text "dropout|fix_token"
190
+ # sleep 5
191
+ # echo "20 exp finishes"
192
+ # date +"%F %T"
193
+ # wandb sync wandb/latest-run
194
+
195
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
196
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex21" --trainer_args.learning_rate=1e-3 \
197
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
198
+ # --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 500 --data.dataset_split train[:41023] --data.split_ratio 0.02493 \
199
+ # --run_text "dropout|2ep|1e3"
200
+ # sleep 5
201
+ # echo "21 exp finishes"
202
+ # date +"%F %T"
203
+ # wandb sync wandb/latest-run
204
+
205
+
206
+ # back to official 40k
207
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
208
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex22" --trainer_args.learning_rate=1e-3 \
209
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
210
+ # --trainer_args.num_train_epochs 2.0 --data.dataset_split train \
211
+ # --run_text "drop0.1|2ep|1e3|40k"
212
+ # sleep 5
213
+ # echo "21 exp finishes"
214
+ # date +"%F %T"
215
+ # wandb sync wandb/latest-run
216
+
217
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
218
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex23" --trainer_args.learning_rate=1e-3 \
219
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
220
+ # --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
221
+ # --run_text "drop0.1|2ep|1e3|40k"
222
+ # sleep 5
223
+ # echo "21 exp finishes"
224
+ # date +"%F %T"
225
+ # wandb sync wandb/latest-run
226
+
227
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
228
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex24" --trainer_args.learning_rate=1e-2 \
229
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
230
+ # --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
231
+ # --run_text "drop0.1|2ep|1e2|40k"
232
+ # sleep 5
233
+ # echo "21 exp finishes"
234
+ # date +"%F %T"
235
+ # wandb sync wandb/latest-run
236
+
237
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
238
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex25" --trainer_args.learning_rate=2e-3 \
239
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
240
+ # --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
241
+ # --run_text "drop0.1|2ep|2e3|40k"
242
+ # sleep 5
243
+ # echo "21 exp finishes"
244
+ # date +"%F %T"
245
+ # wandb sync wandb/latest-run
246
+
247
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
248
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex26" --trainer_args.learning_rate=5e-3 \
249
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
250
+ # --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
251
+ # --run_text "drop0.1|2ep|5e3|40k"
252
+ # sleep 5
253
+ # echo "21 exp finishes"
254
+ # date +"%F %T"
255
+ # wandb sync wandb/latest-run
256
+
257
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
258
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex27" --trainer_args.learning_rate=8e-3 \
259
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
260
+ # --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
261
+ # --run_text "drop0.1|2ep|8e3|40k"
262
+ # sleep 5
263
+ # echo "21 exp finishes"
264
+ # date +"%F %T"
265
+ # wandb sync wandb/latest-run
266
+
267
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
268
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex28" --trainer_args.learning_rate=2e-2 \
269
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
270
+ # --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
271
+ # --run_text "drop0.1|2ep|2e2|40k"
272
+ # sleep 5
273
+ # echo "21 exp finishes"
274
+ # date +"%F %T"
275
+ # wandb sync wandb/latest-run
276
+
277
+
278
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
279
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex29" --trainer_args.learning_rate=5e-3 \
280
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
281
+ # --trainer_args.num_train_epochs 2.0 --data.dataset_split train \
282
+ # --run_text "drop0.1|2ep|initu=v=0.01|40k" --trainer_args.per_device_train_batch_size 48
283
+ # sleep 5
284
+ # echo "29 exp finishes"
285
+ # date +"%F %T"
286
+ # wandb sync wandb/latest-run
287
+
288
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
289
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex30" --trainer_args.learning_rate=1e-3 \
290
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
291
+ # --trainer_args.num_train_epochs 2.0 --data.dataset_split train \
292
+ # --run_text "drop0.1|2ep|initu=v=0.01|40k" --trainer_args.per_device_train_batch_size 48
293
+ # sleep 5
294
+ # echo "29 exp finishes"
295
+ # date +"%F %T"
296
+ # wandb sync wandb/latest-run
297
+
298
+
299
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
300
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex31" --trainer_args.learning_rate=5e-3 \
301
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
302
+ # --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
303
+ # --run_text "drop0.1|2ep|initu=v=0.01|40k" --trainer_args.per_device_train_batch_size 48
304
+ # sleep 5
305
+ # echo "29 exp finishes"
306
+ # date +"%F %T"
307
+ # wandb sync wandb/latest-run
308
+
309
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
310
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex32" --trainer_args.learning_rate=1e-3 \
311
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
312
+ # --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
313
+ # --run_text "drop0.1|2ep|initu=v=0.01|40k" --trainer_args.per_device_train_batch_size 48
314
+ # sleep 5
315
+ # echo "29 exp finishes"
316
+ # date +"%F %T"
317
+ # wandb sync wandb/latest-run
318
+
319
+
320
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
321
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex33" --trainer_args.learning_rate=1e-2 \
322
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
323
+ # --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
324
+ # --run_text "drop0.1|2ep|initu=v=0.01|40k" --trainer_args.per_device_train_batch_size 48
325
+ # sleep 5
326
+ # echo "29 exp finishes"
327
+ # date +"%F %T"
328
+ # wandb sync wandb/latest-run
329
+
330
+
331
+ accelerate launch --main_process_port 41353 -m src.ft_mathQ \
332
+ --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex34" --trainer_args.learning_rate=2e-2 \
333
+ --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
334
+ --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
335
+ --run_text "drop0.1|2ep|initu=v=0.01|40k" --trainer_args.per_device_train_batch_size 48
336
+ sleep 5
337
+ echo "29 exp finishes"
338
+ date +"%F %T"
339
+ wandb sync wandb/latest-run
340
+
341
+ bash scripts/merge.sh
nl_tasks/scripts/copy train_cms_reasoning.sh ADDED
@@ -0,0 +1,133 @@
1
+ export OMINI_CONFIG=./config/commonsense.yaml
2
+
3
+ #echo $OMINI_CONFIG
4
+ export TOKENIZERS_PARALLELISM=true
5
+
6
+ # CUDA Include (/cuda.h)
7
+ CUDA_INCLUDE_PATH="/home/work/miniconda3/envs/allm/include"
8
+
9
+ # 3. Add into CPATH & CPLUS_INCLUDE_PATH (C/C++ compiler)
10
+ export CPATH=$CPATH:$CUDA_INCLUDE_PATH
11
+ export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$CUDA_INCLUDE_PATH
12
+ # echo "CPATH is set to: $CPATH"
13
+ # echo "CPLUS_INCLUDE_PATH is set to: $CPLUS_INCLUDE_PATH"
14
+
15
+ export WANDB_PROJECT="Llama2_7B_FT_Math40k"
16
+
17
+ export OMP_NUM_THREADS=1
18
+ export MKL_NUM_THREADS=1
19
+ export OPENBLAS_NUM_THREADS=1
20
+ export NUMEXPR_NUM_THREADS=1
21
+
22
+ date +"%F %T"
23
+
24
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
25
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_exps1" --trainer_args.learning_rate=1e-5
26
+
27
+ # sleep 5
28
+ # echo "1st exp finishes"
29
+ # date +"%F %T"
30
+
31
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
32
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_exps2" --trainer_args.learning_rate=2e-5
33
+
34
+ # sleep 5
35
+ # echo "2nd exp finishes"
36
+ # date +"%F %T"
37
+
38
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
39
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_exps3" --trainer_args.learning_rate=5e-5
40
+
41
+ # sleep 5
42
+ # echo "3rd exp finishes"
43
+ # date +"%F %T"
44
+
45
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
46
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_exps4" --trainer_args.learning_rate=1e-4
47
+
48
+ # sleep 5
49
+ # echo "4th exp finishes"
50
+ # date +"%F %T"
51
+
52
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
53
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_exps5" --trainer_args.learning_rate=2e-4
54
+
55
+ # sleep 5
56
+ # echo "5th exp finishes"
57
+ # date +"%F %T"
58
+
59
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
60
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_exps6" --trainer_args.learning_rate=5e-4
61
+
62
+ # sleep 5
63
+ # echo "6th exp finishes"
64
+ # date +"%F %T"
65
+
66
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
67
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_exps7" --trainer_args.learning_rate=8e-4
68
+
69
+ # sleep 5
70
+ # echo "7th exp finishes"
71
+ # date +"%F %T"
72
+
73
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
74
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_exps8" --trainer_args.learning_rate=1e-3
75
+
76
+ # sleep 5
77
+ # echo "8th exp finishes"
78
+ # date +"%F %T"
79
+
80
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
81
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_exps9" --trainer_args.learning_rate=2e-3
82
+
83
+ # sleep 5
84
+ # echo "9th exp finishes"
85
+ # date +"%F %T"
86
+
87
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
88
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_exnr10" --trainer_args.learning_rate=1e-3 \
89
+ # --rotation_adapter_config.num_rotations 8 --rotation_adapter_config.r 2
90
+
91
+ # sleep 5
92
+ # echo "10 exp finishes"
93
+ # date +"%F %T"
94
+
95
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
96
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_exnr11" --trainer_args.learning_rate=1e-3 \
97
+ # --rotation_adapter_config.num_rotations 2 --rotation_adapter_config.r 8
98
+
99
+ # sleep 5
100
+ # echo "11 exp finishes"
101
+ # date +"%F %T"
102
+
103
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
104
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_exnr12" --trainer_args.learning_rate=1e-3 \
105
+ # --rotation_adapter_config.num_rotations 16 --rotation_adapter_config.r 1
106
+
107
+ # sleep 5
108
+ # echo "12 exp finishes"
109
+ # date +"%F %T"
110
+
111
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
112
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_exnr13" --trainer_args.learning_rate=1e-3 \
113
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16
114
+
115
+ # sleep 5
116
+ # echo "13 exp finishes"
117
+ # date +"%F %T"
118
+
119
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
120
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_all/exnr14" --trainer_args.learning_rate=2e-3 \
121
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16
122
+
123
+ # sleep 5
124
+ # echo "14 exp finishes"
125
+ # date +"%F %T"
126
+
127
+ accelerate launch --main_process_port 41353 -m src.ft_mathR \
128
+ --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_all/exnr15" --trainer_args.learning_rate=1e-3 \
129
+ --rotation_adapter_config.num_rotations 4 --rotation_adapter_config.r 4
130
+
131
+ # sleep 5
132
+ echo "15 exp finishes"
133
+ date +"%F %T"
nl_tasks/scripts/down_math_train.sh ADDED
@@ -0,0 +1,14 @@
1
+ #!/bin/bash
2
+
3
+ DATASET_ID="meta-math/MetaMathQA"
4
+ LOCAL_DIR="./data/MetaMathQA"
5
+
6
+ # echo "Starting download for dataset: $DATASET_ID..."
7
+ huggingface-cli download $DATASET_ID \
8
+ --repo-type dataset \
9
+ --local-dir $LOCAL_DIR \
10
+ --local-dir-use-symlinks False \
11
+ --resume-download \
12
+ --include "*.json"
13
+
14
+ # echo "Download completed. Data is located at: $LOCAL_DIR"
nl_tasks/scripts/inference.sh ADDED
@@ -0,0 +1,14 @@
1
+ # OUTPUT="./exps/run_ex12/merged"
2
+ # OUTPUT="./exp395/run_ex02/merged"
3
+ # OUTPUT="./exp_init/run_ex02/merged"
4
+ OUTPUT="./exps/run_ex15_3ep/merged"
5
+
6
+ date +"%F %T"
7
+
8
+ echo 'test math'
9
+
10
+ date +"%F %T"
11
+ python inference/gsm8k_infer.py --model $OUTPUT
12
+ date +"%F %T"
13
+ python inference/MATH_infer.py --model $OUTPUT
14
+ date +"%F %T"
nl_tasks/scripts/merge.sh ADDED
@@ -0,0 +1,137 @@
1
+ # export OMINI_CONFIG=./config/commonsense.yaml
2
+ export OMINI_CONFIG=./config/commonsense.yaml
3
+ #echo $OMINI_CONFIG
4
+ export TOKENIZERS_PARALLELISM=true
5
+
6
+ # CUDA Include (/cuda.h)
7
+ CUDA_INCLUDE_PATH="/home/work/miniconda3/envs/allm/include"
8
+
9
+ # 3. Add into CPATH & CPLUS_INCLUDE_PATH (C/C++ compiler)
10
+ export CPATH=$CPATH:$CUDA_INCLUDE_PATH
11
+ export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$CUDA_INCLUDE_PATH
12
+ # echo "CPATH is set to: $CPATH"
13
+ # echo "CPLUS_INCLUDE_PATH is set to: $CPLUS_INCLUDE_PATH"
14
+
15
+ #./exps/run_ex15_3ep
16
+ # ./exp_init/run_ex02/ft2
17
+ # ADAPTER = "--model.merge_adapter_path "./exps/run_ex12/ft2" --model.merge_output_path "./exps/run_ex12/merged"
18
+ # export ADAPTER = "--model.merge_adapter_path ./exp395/run_ex01/ft2 --model.merge_output_path ./exp395/run_ex01/merged"
19
+ # accelerate launch --main_process_port 41353 -m src.merge \
20
+ # --config_path $OMINI_CONFIG \
21
+ # --model.merge_adapter_path "./exps/run_ex19_2ep/ft2" --model.merge_output_path "./exps/run_ex19_2ep/merged"
22
+
23
+ # OUTPUT="./exps/run_ex19_2ep/merged"
24
+
25
+ # date +"%F %T"
26
+ # python inference/MATH_infer.py --model $OUTPUT
27
+ # date +"%F %T"
28
+ # python inference/gsm8k_infer.py --model $OUTPUT
29
+ # date +"%F %T"
30
+
31
+
32
+ # MERGE_DIR="./exps/run_ex24"
33
+ # accelerate launch --main_process_port 41353 -m src.merge \
34
+ # --config_path $OMINI_CONFIG \
35
+ # --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
36
+
37
+ # date +"%F %T"
38
+ # python inference/MATH_infer.py --model $MERGE_DIR/merged/
39
+ # date +"%F %T"
40
+ # python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
41
+ # date +"%F %T"
42
+
43
+
44
+ # MERGE_DIR="./exps/run_ex25"
45
+ # accelerate launch --main_process_port 41353 -m src.merge \
46
+ # --config_path $OMINI_CONFIG \
47
+ # --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
48
+
49
+ # date +"%F %T"
50
+ # python inference/MATH_infer.py --model $MERGE_DIR/merged/
51
+ # date +"%F %T"
52
+ # python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
53
+ # date +"%F %T"
54
+
55
+
56
+ # MERGE_DIR="./exps/run_ex26"
57
+ # accelerate launch --main_process_port 41353 -m src.merge \
58
+ # --config_path $OMINI_CONFIG \
59
+ # --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
60
+
61
+ # date +"%F %T"
62
+ # python inference/MATH_infer.py --model $MERGE_DIR/merged/
63
+ # date +"%F %T"
64
+ # python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
65
+ # date +"%F %T"
66
+
67
+
68
+ # MERGE_DIR="./exps/run_ex27"
69
+ # accelerate launch --main_process_port 41353 -m src.merge \
70
+ # --config_path $OMINI_CONFIG \
71
+ # --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
72
+
73
+ # date +"%F %T"
74
+ # python inference/MATH_infer.py --model $MERGE_DIR/merged/
75
+ # date +"%F %T"
76
+ # python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
77
+ # date +"%F %T"
78
+
79
+
80
+ # MERGE_DIR="./exps/run_ex28"
81
+ # accelerate launch --main_process_port 41353 -m src.merge \
82
+ # --config_path $OMINI_CONFIG \
83
+ # --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
84
+
85
+ # date +"%F %T"
86
+ # python inference/MATH_infer.py --model $MERGE_DIR/merged/
87
+ # date +"%F %T"
88
+ # python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
89
+ # date +"%F %T"
90
+
91
+
92
+ # MERGE_DIR="./exps/run_ex33"
93
+ # accelerate launch --main_process_port 41353 -m src.merge \
94
+ # --config_path $OMINI_CONFIG \
95
+ # --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
96
+
97
+ # date +"%F %T"
98
+ # python inference/MATH_infer.py --model $MERGE_DIR/merged/
99
+ # date +"%F %T"
100
+ # python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
101
+ # date +"%F %T"
102
+
103
+ # 140126
104
+ MERGE_DIR="./exprep/run_ex30"
105
+ accelerate launch --main_process_port 41353 -m src.merge \
106
+ --config_path $OMINI_CONFIG \
107
+ --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
108
+
109
+ date +"%F %T"
110
+ python inference/MATH_infer.py --model $MERGE_DIR/merged/
111
+ date +"%F %T"
112
+ python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
113
+ date +"%F %T"
114
+
115
+ MERGE_DIR="./exprep/run_ex31"
116
+ accelerate launch --main_process_port 41353 -m src.merge \
117
+ --config_path $OMINI_CONFIG \
118
+ --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
119
+
120
+ date +"%F %T"
121
+ python inference/MATH_infer.py --model $MERGE_DIR/merged/
122
+ date +"%F %T"
123
+ python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
124
+ date +"%F %T"
125
+
126
+ MERGE_DIR="./exprep/run_ex32"
127
+ accelerate launch --main_process_port 41353 -m src.merge \
128
+ --config_path $OMINI_CONFIG \
129
+ --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
130
+
131
+ date +"%F %T"
132
+ python inference/MATH_infer.py --model $MERGE_DIR/merged/
133
+ date +"%F %T"
134
+ python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
135
+ date +"%F %T"
136
+
137
+
nl_tasks/scripts/merge_100k.sh ADDED
@@ -0,0 +1,100 @@
1
+ # export OMINI_CONFIG=./config/commonsense.yaml
2
+ export OMINI_CONFIG=./config/math395.yaml
3
+ #echo $OMINI_CONFIG
4
+ export TOKENIZERS_PARALLELISM=true
5
+
6
+ # CUDA Include (/cuda.h)
7
+ CUDA_INCLUDE_PATH="/home/work/miniconda3/envs/allm/include"
8
+
9
+ # 3. Add into CPATH & CPLUS_INCLUDE_PATH (C/C++ compiler)
10
+ export CPATH=$CPATH:$CUDA_INCLUDE_PATH
11
+ export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$CUDA_INCLUDE_PATH
12
+ # echo "CPATH is set to: $CPATH"
13
+ # echo "CPLUS_INCLUDE_PATH is set to: $CPLUS_INCLUDE_PATH"
14
+
15
+ #./exps/run_ex15_3ep
16
+ # ./exp_init/run_ex02/ft2
17
+ # ADAPTER = "--model.merge_adapter_path "./exps/run_ex12/ft2" --model.merge_output_path "./exps/run_ex12/merged"
18
+ # export ADAPTER = "--model.merge_adapter_path ./exp395/run_ex01/ft2 --model.merge_output_path ./exp395/run_ex01/merged"
19
+
20
+ MERGE_DIR="./exp100/run_ex06"
21
+ accelerate launch --main_process_port 41353 -m src.merge \
22
+ --config_path $OMINI_CONFIG \
23
+ --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
24
+
25
+ date +"%F %T"
26
+ python inference/MATH_infer.py --model $MERGE_DIR/merged/
27
+ date +"%F %T"
28
+ python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
29
+ date +"%F %T"
30
+
31
+ MERGE_DIR="./exp100/run_ex07"
32
+ accelerate launch --main_process_port 41353 -m src.merge \
33
+ --config_path $OMINI_CONFIG \
34
+ --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
35
+
36
+ date +"%F %T"
37
+ python inference/MATH_infer.py --model $MERGE_DIR/merged/
38
+ date +"%F %T"
39
+ python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
40
+ date +"%F %T"
41
+
42
+
43
+ MERGE_DIR="./exp100/run_ex08"
44
+ accelerate launch --main_process_port 41353 -m src.merge \
45
+ --config_path $OMINI_CONFIG \
46
+ --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
47
+
48
+ date +"%F %T"
49
+ python inference/MATH_infer.py --model $MERGE_DIR/merged/
50
+ date +"%F %T"
51
+ python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
52
+ date +"%F %T"
53
+
54
+
55
+ MERGE_DIR="./exp100/run_ex09"
56
+ accelerate launch --main_process_port 41353 -m src.merge \
57
+ --config_path $OMINI_CONFIG \
58
+ --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
59
+
60
+ date +"%F %T"
61
+ python inference/MATH_infer.py --model $MERGE_DIR/merged/
62
+ date +"%F %T"
63
+ python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
64
+ date +"%F %T"
65
+
66
+
67
+ MERGE_DIR="./exp100/run_ex10"
68
+ accelerate launch --main_process_port 41353 -m src.merge \
69
+ --config_path $OMINI_CONFIG \
70
+ --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
71
+
72
+ date +"%F %T"
73
+ python inference/MATH_infer.py --model $MERGE_DIR/merged/
74
+ date +"%F %T"
75
+ python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
76
+ date +"%F %T"
77
+
78
+
79
+ MERGE_DIR="./exp100/run_ex11"
80
+ accelerate launch --main_process_port 41353 -m src.merge \
81
+ --config_path $OMINI_CONFIG \
82
+ --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
83
+
84
+ date +"%F %T"
85
+ python inference/MATH_infer.py --model $MERGE_DIR/merged/
86
+ date +"%F %T"
87
+ python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
88
+ date +"%F %T"
89
+
90
+
91
+ MERGE_DIR="./exp100/run_ex12"
92
+ accelerate launch --main_process_port 41353 -m src.merge \
93
+ --config_path $OMINI_CONFIG \
94
+ --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
95
+
96
+ date +"%F %T"
97
+ python inference/MATH_infer.py --model $MERGE_DIR/merged/
98
+ date +"%F %T"
99
+ python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
100
+ date +"%F %T"
nl_tasks/scripts/merge_math.sh ADDED
@@ -0,0 +1,31 @@
1
+ # export OMINI_CONFIG=./config/commonsense.yaml
2
+ export OMINI_CONFIG=./config/math395.yaml
3
+ #echo $OMINI_CONFIG
4
+ export TOKENIZERS_PARALLELISM=true
5
+
6
+ # CUDA Include (/cuda.h)
7
+ CUDA_INCLUDE_PATH="/home/work/miniconda3/envs/allm/include"
8
+
9
+ # 3. Add into CPATH & CPLUS_INCLUDE_PATH (C/C++ compiler)
10
+ export CPATH=$CPATH:$CUDA_INCLUDE_PATH
11
+ export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$CUDA_INCLUDE_PATH
12
+ # echo "CPATH is set to: $CPATH"
13
+ # echo "CPLUS_INCLUDE_PATH is set to: $CPLUS_INCLUDE_PATH"
14
+
15
+ #./exps/run_ex15_3ep
16
+ # ./exp_init/run_ex02/ft2
17
+ # ADAPTER = "--model.merge_adapter_path "./exps/run_ex12/ft2" --model.merge_output_path "./exps/run_ex12/merged"
18
+ # export ADAPTER = "--model.merge_adapter_path ./exp395/run_ex01/ft2 --model.merge_output_path ./exp395/run_ex01/merged"
19
+
20
+ MERGE_DIR="./exp395/run_ex10"
21
+ # accelerate launch --main_process_port 41353 -m src.merge \
22
+ # --config_path $OMINI_CONFIG \
23
+ # --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
24
+
25
+ # # OUTPUT="./exp395/run_ex09/merged"
26
+
27
+ # date +"%F %T"
28
+ # python inference/MATH_infer.py --model $MERGE_DIR/merged/
29
+ date +"%F %T"
30
+ python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
31
+ date +"%F %T"
nl_tasks/scripts/peft_merge.sh ADDED
@@ -0,0 +1,60 @@
1
+ # export OMINI_CONFIG=./config/commonsense.yaml
2
+ export OMINI_CONFIG=./config/commonsense.yaml
3
+ #echo $OMINI_CONFIG
4
+ export TOKENIZERS_PARALLELISM=true
5
+
6
+ # CUDA Include (/cuda.h)
7
+ CUDA_INCLUDE_PATH="/home/work/miniconda3/envs/allm/include"
8
+
9
+ # 3. Add into CPATH & CPLUS_INCLUDE_PATH (C/C++ compiler)
10
+ export CPATH=$CPATH:$CUDA_INCLUDE_PATH
11
+ export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$CUDA_INCLUDE_PATH
12
+ # echo "CPATH is set to: $CPATH"
13
+ # echo "CPLUS_INCLUDE_PATH is set to: $CPLUS_INCLUDE_PATH"
14
+
15
+
16
+
17
+ MERGE_DIR="./expsBOFT/seed43"
18
+ accelerate launch --main_process_port 41353 -m src.peft_merge \
19
+ --config_path $OMINI_CONFIG \
20
+ --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
21
+
22
+ date +"%F %T"
23
+ python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
24
+ date +"%F %T"
25
+ python inference/MATH_infer.py --model $MERGE_DIR/merged/
26
+ date +"%F %T"
27
+
28
+ MERGE_DIR="./expsBOFT/seed44"
29
+ accelerate launch --main_process_port 41353 -m src.peft_merge \
30
+ --config_path $OMINI_CONFIG \
31
+ --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
32
+
33
+ date +"%F %T"
34
+ python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
35
+ date +"%F %T"
36
+ python inference/MATH_infer.py --model $MERGE_DIR/merged/
37
+ date +"%F %T"
38
+
39
+
40
+ MERGE_DIR="./expsOFT/seed43"
41
+ accelerate launch --main_process_port 41353 -m src.peft_merge \
42
+ --config_path $OMINI_CONFIG \
43
+ --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
44
+
45
+ date +"%F %T"
46
+ python inference/MATH_infer.py --model $MERGE_DIR/merged/
47
+ date +"%F %T"
48
+ python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
49
+ date +"%F %T"
50
+
51
+ MERGE_DIR="./expsOFT/seed44"
52
+ accelerate launch --main_process_port 41353 -m src.peft_merge \
53
+ --config_path $OMINI_CONFIG \
54
+ --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
55
+ date +"%F %T"
56
+ python inference/MATH_infer.py --model $MERGE_DIR/merged/
57
+ date +"%F %T"
58
+ python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
59
+ date +"%F %T"
60
+
nl_tasks/scripts/train_100math.sh ADDED
@@ -0,0 +1,184 @@
1
+ export OMINI_CONFIG=./config/math395.yaml
2
+
3
+ #echo $OMINI_CONFIG
4
+ export TOKENIZERS_PARALLELISM=true
5
+
6
+ # CUDA Include (/cuda.h)
7
+ CUDA_INCLUDE_PATH="/home/work/miniconda3/envs/allm/include"
8
+
9
+ # 3. Add into CPATH & CPLUS_INCLUDE_PATH (C/C++ compiler)
10
+ export CPATH=$CPATH:$CUDA_INCLUDE_PATH
11
+ export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$CUDA_INCLUDE_PATH
12
+ # echo "CPATH is set to: $CPATH"
13
+ # echo "CPLUS_INCLUDE_PATH is set to: $CPLUS_INCLUDE_PATH"
14
+
15
+ export WANDB_PROJECT="Llama2_7B_FT_Math_395k"
16
+
17
+ export OMP_NUM_THREADS=1
18
+ export MKL_NUM_THREADS=1
19
+ export OPENBLAS_NUM_THREADS=1
20
+ export NUMEXPR_NUM_THREADS=1
21
+
22
+ date +"%F %T"
23
+
24
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
25
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp100/run_ex01" --trainer_args.learning_rate=5e-3 \
26
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
27
+ # --trainer_args.num_train_epochs 2.0 --data.dataset_split train[:100000] \
28
+ # --run_text 'def|o100k'
29
+
30
+ # sleep 5
31
+ # echo "1st exp finishes"
32
+ # date +"%F %T"
33
+ # wandb sync wandb/latest-run
34
+
35
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
36
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp100/run_ex02" --trainer_args.learning_rate=2e-2 \
37
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
38
+ # --trainer_args.num_train_epochs 2.0 --data.dataset_split train[:100000] \
39
+ # --run_text 'def|o100k'
40
+
41
+ # sleep 5
42
+ # echo "2nd exp finishes"
43
+ # date +"%F %T"
44
+ # wandb sync wandb/latest-run
45
+
46
+ # bash scripts/merge_100k.sh
47
+
48
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
49
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp100/run_ex03" --trainer_args.learning_rate=1e-2 \
50
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
51
+ # --trainer_args.num_train_epochs 2.0 --data.dataset_split train[:100000] \
52
+ # --run_text 'def|o100k'
53
+
54
+ # sleep 5
55
+ # echo "3rd exp finishes"
56
+ # date +"%F %T"
57
+ # wandb sync wandb/latest-run
58
+
59
+
60
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
61
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp100/run_ex04" --trainer_args.learning_rate=5e-2 \
62
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
63
+ # --trainer_args.num_train_epochs 2.0 --data.dataset_split train[:100000] \
64
+ # --run_text 'def|o100k'
65
+
66
+ # sleep 5
67
+ # echo "4th exp finishes"
68
+ # date +"%F %T"
69
+ # wandb sync wandb/latest-run
70
+
71
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
72
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp100/run_ex05" --trainer_args.learning_rate=1e-2 \
73
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
74
+ # --trainer_args.num_train_epochs 3.0 --data.dataset_split train[:100000] \
75
+ # --run_text 'def|o100k'
76
+ # sleep 5
77
+ # echo "5th exp finishes"
78
+ # date +"%F %T"
79
+ # wandb sync wandb/latest-run
80
+ # bash scripts/merge_100math.sh
81
+
82
+ accelerate launch --main_process_port 41353 -m src.ft_mathQ \
83
+ --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp100/run_ex06" --trainer_args.learning_rate=1e-2 \
84
+ --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
85
+ --trainer_args.num_train_epochs 2.0 --data.dataset_split train[:100000] \
86
+ --run_text 'def|o100k|b48' --trainer_args.per_device_train_batch_size 48
87
+ sleep 5
88
+ echo "6th exp finishes"
89
+ date +"%F %T"
90
+ wandb sync wandb/latest-run
91
+
92
+ accelerate launch --main_process_port 41353 -m src.ft_mathQ \
93
+ --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp100/run_ex07" --trainer_args.learning_rate=1e-2 \
94
+ --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
95
+ --trainer_args.num_train_epochs 3.0 --data.dataset_split train[:100000] \
96
+ --run_text 'def|o100k|b48' --trainer_args.per_device_train_batch_size 48
97
+ sleep 5
98
+ echo "6th exp finishes"
99
+ date +"%F %T"
100
+ wandb sync wandb/latest-run
101
+
102
+ accelerate launch --main_process_port 41353 -m src.ft_mathQ \
103
+ --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp100/run_ex08" --trainer_args.learning_rate=2e-2 \
104
+ --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
105
+ --trainer_args.num_train_epochs 2.0 --data.dataset_split train[:100000] \
106
+ --run_text 'def|o100k|b48' --trainer_args.per_device_train_batch_size 48
107
+ sleep 5
108
+ echo "8th exp finishes"
109
+ date +"%F %T"
110
+ wandb sync wandb/latest-run
111
+
112
+ accelerate launch --main_process_port 41353 -m src.ft_mathQ \
113
+ --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp100/run_ex09" --trainer_args.learning_rate=2e-2 \
114
+ --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
115
+ --trainer_args.num_train_epochs 3.0 --data.dataset_split train[:100000] \
116
+ --run_text 'def|o100k|b48' --trainer_args.per_device_train_batch_size 48
117
+ sleep 5
118
+ echo "9th exp finishes"
119
+ date +"%F %T"
120
+ wandb sync wandb/latest-run
121
+
122
+ accelerate launch --main_process_port 41353 -m src.ft_mathQ \
123
+ --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp100/run_ex10" --trainer_args.learning_rate=3e-2 \
124
+ --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
125
+ --trainer_args.num_train_epochs 2.0 --data.dataset_split train[:100000] \
126
+ --run_text 'def|o100k|b48' --trainer_args.per_device_train_batch_size 48
127
+ sleep 5
128
+ echo "10th exp finishes"
129
+ date +"%F %T"
130
+ wandb sync wandb/latest-run
131
+
132
+
133
+
134
+ accelerate launch --main_process_port 41353 -m src.ft_mathQ \
135
+ --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp100/run_ex11" --trainer_args.learning_rate=8e-3 \
136
+ --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
137
+ --trainer_args.num_train_epochs 2.0 --data.dataset_split train[:100000] \
138
+ --run_text 'def|o100k|b48' --trainer_args.per_device_train_batch_size 48
139
+
140
+ sleep 5
141
+ echo "11 exp finishes"
142
+ date +"%F %T"
143
+ wandb sync wandb/latest-run
144
+
145
+
146
+ accelerate launch --main_process_port 41353 -m src.ft_mathQ \
147
+ --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp100/run_ex12" --trainer_args.learning_rate=8e-3 \
148
+ --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
149
+ --trainer_args.num_train_epochs 3.0 --data.dataset_split train[:100000] \
150
+ --run_text 'def|o100k|b48' --trainer_args.per_device_train_batch_size 48
151
+
152
+ sleep 5
153
+ echo "12 exp finishes"
154
+ date +"%F %T"
155
+ wandb sync wandb/latest-run
156
+
157
+
158
+ bash ./scripts/merge_100k.sh
159
+
160
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
161
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex13" --trainer_args.learning_rate=1e-3 \
162
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'u=vkaim'
163
+
164
+ # sleep 5
165
+ # echo "13 exp finishes"
166
+ # date +"%F %T"
167
+ # wandb sync wandb/latest-run
168
+
169
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
170
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex14" --trainer_args.learning_rate=2e-3 \
171
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'a,b,def'
172
+
173
+ # sleep 5
174
+ # echo "14 exp finishes"
175
+ # date +"%F %T"
176
+ # wandb sync wandb/latest-run
177
+
178
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
179
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex15" --trainer_args.learning_rate=1e-3 \
180
+ # --rotation_adapter_config.num_rotations 4 --rotation_adapter_config.r 4
181
+
182
+ # sleep 5
183
+ # echo "15 exp finishes"
184
+ # date +"%F %T"
nl_tasks/scripts/train_cms_reasoning.sh ADDED
@@ -0,0 +1,260 @@
1
+ export OMINI_CONFIG=./config/commonsense.yaml
2
+
3
+ #echo $OMINI_CONFIG
4
+ export TOKENIZERS_PARALLELISM=true
5
+
6
+ # CUDA include path (for cuda.h)
7
+ CUDA_INCLUDE_PATH="/home/work/miniconda3/envs/allm/include"
8
+
9
+ # Add to CPATH & CPLUS_INCLUDE_PATH for the C/C++ compiler
10
+ export CPATH=$CPATH:$CUDA_INCLUDE_PATH
11
+ export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$CUDA_INCLUDE_PATH
12
+ # echo "CPATH is set to: $CPATH"
13
+ # echo "CPLUS_INCLUDE_PATH is set to: $CPLUS_INCLUDE_PATH"
14
+
15
+ export WANDB_PROJECT="Llama2_7B_FT_Math40k_2"
16
+
17
+ export OMP_NUM_THREADS=1
18
+ export MKL_NUM_THREADS=1
19
+ export OPENBLAS_NUM_THREADS=1
20
+ export NUMEXPR_NUM_THREADS=1
21
+
22
+ date +"%F %T"
23
+
24
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
25
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex01" --trainer_args.learning_rate=5e-5 \
26
+ # --rotation_adapter_config.num_rotations 2 --rotation_adapter_config.r 8
27
+
28
+ # sleep 5
29
+ # echo "1st exp finishes"
30
+ # date +"%F %T"
31
+
32
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
33
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex02" --trainer_args.learning_rate=5e-4 \
34
+ # --rotation_adapter_config.num_rotations 2 --rotation_adapter_config.r 8
35
+
36
+ # sleep 5
37
+ # echo "2nd exp finishes"
38
+ # date +"%F %T"
39
+
40
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
41
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex03" --trainer_args.learning_rate=1e-3 \
42
+ # --rotation_adapter_config.num_rotations 2 --rotation_adapter_config.r 8
43
+
44
+ # sleep 5
45
+ # echo "3rd exp finishes"
46
+ # date +"%F %T"
47
+ # wandb sync wandb/latest-run
48
+
49
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
50
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex04" --trainer_args.learning_rate=2e-3 \
51
+ # --rotation_adapter_config.num_rotations 4 --rotation_adapter_config.r 4
52
+ # sleep 5
53
+ # echo "4th exp finishes"
54
+ # date +"%F %T"
55
+ # wandb sync wandb/latest-run
56
+
57
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
58
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex05" --trainer_args.learning_rate=2e-3 \
59
+ # --rotation_adapter_config.num_rotations 2 --rotation_adapter_config.r 8
60
+ # sleep 5
61
+ # echo "5th exp finishes"
62
+ # date +"%F %T"
63
+ # wandb sync wandb/latest-run
64
+
65
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
66
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex06" --trainer_args.learning_rate=1e-3 \
67
+ # --rotation_adapter_config.num_rotations 4 --rotation_adapter_config.r 4
68
+ # sleep 5
69
+ # echo "6th exp finishes"
70
+ # date +"%F %T"
71
+ # wandb sync wandb/latest-run
72
+
73
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
74
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_exps7" --trainer_args.learning_rate=1e-3 \
75
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16
76
+ # sleep 5
77
+ # echo "7th exp finishes"
78
+ # date +"%F %T"
79
+ # wandb sync wandb/latest-run
80
+
81
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
82
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex08" --trainer_args.learning_rate=2e-3 \
83
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16
84
+ # sleep 5
85
+ # echo "8th exp finishes"
86
+ # date +"%F %T"
87
+ # wandb sync wandb/latest-run
88
+
89
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
90
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex09" --trainer_args.learning_rate=2e-3 \
91
+ # --rotation_adapter_config.num_rotations 16 --rotation_adapter_config.r 1
92
+
93
+ # sleep 5
94
+ # echo "9th exp finishes"
95
+ # date +"%F %T"
96
+ # wandb sync wandb/latest-run
97
+
98
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
99
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex10" --trainer_args.learning_rate=2e-3 \
100
+ # --rotation_adapter_config.num_rotations 8 --rotation_adapter_config.r 2
101
+
102
+ # sleep 5
103
+ # echo "10 exp finishes"
104
+ # date +"%F %T"
105
+ # wandb sync wandb/latest-run
106
+
107
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
108
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex11" --trainer_args.learning_rate=1e-2 \
109
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16
110
+
111
+ # sleep 5
112
+ # echo "11 exp finishes"
113
+ # date +"%F %T"
114
+ # wandb sync wandb/latest-run
115
+
116
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
117
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex12" --trainer_args.learning_rate=1e-2 \
118
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'u=v,def'
119
+
120
+ # sleep 5
121
+ # echo "12 exp finishes"
122
+ # date +"%F %T"
123
+ # wandb sync wandb/latest-run
124
+
125
+
126
+ ### continue with 40k
127
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
128
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex13_3ep" --trainer_args.learning_rate=1e-3 \
129
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def' \
130
+ # --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 200
131
+
132
+ # sleep 5
133
+ # echo "13 exp finishes"
134
+ # date +"%F %T"
135
+ # wandb sync wandb/latest-run
136
+
137
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
138
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex14_3ep" --trainer_args.learning_rate=2e-4 \
139
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def' \
140
+ # --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 200
141
+
142
+ # sleep 5
143
+ # echo "14 exp finishes"
144
+ # date +"%F %T"
145
+ # wandb sync wandb/latest-run
146
+
147
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
148
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex15_3ep" --trainer_args.learning_rate=5e-4 \
149
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def' \
150
+ # --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 200
151
+
152
+ # sleep 5
153
+ # echo "15 exp finishes"
154
+ # date +"%F %T"
155
+ # wandb sync wandb/latest-run
156
+
157
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
158
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex16_3ep" --trainer_args.learning_rate=1e-3 \
159
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.05' \
160
+ # --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 200
161
+
162
+ # sleep 5
163
+ # echo "15 exp finishes"
164
+ # date +"%F %T"
165
+ # wandb sync wandb/latest-run
166
+
167
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
168
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex17_3ep" --trainer_args.learning_rate=2e-3 \
169
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.05' \
170
+ # --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 200
171
+
172
+ # sleep 5
173
+ # echo "15 exp finishes"
174
+ # date +"%F %T"
175
+ # wandb sync wandb/latest-run
176
+
177
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
178
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex18_2ep" --trainer_args.learning_rate=1e-3 \
179
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
180
+ # --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 200
181
+
182
+ # sleep 5
183
+ # echo "15 exp finishes"
184
+ # date +"%F %T"
185
+ # wandb sync wandb/latest-run
186
+
187
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
188
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex19_2ep" --trainer_args.learning_rate=5e-3 \
189
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
190
+ # --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 200
191
+
192
+ # sleep 5
193
+ # echo "19 exp finishes"
194
+ # date +"%F %T"
195
+ # wandb sync wandb/latest-run
196
+
197
+
198
+
199
+ # 140126
200
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
201
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex20_2ep" --trainer_args.learning_rate=1e-3 \
202
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
203
+ # --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 100 --seed 11
204
+
205
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
206
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex21_2ep" --trainer_args.learning_rate=1e-3 \
207
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
208
+ # --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 100 --seed 10
209
+
210
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
211
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex24_3ep" --trainer_args.learning_rate=1e-3 \
212
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
213
+ # --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 100 --seed 10
214
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
215
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex25_3ep" --trainer_args.learning_rate=1e-3 \
216
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
217
+ # --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 100 --seed 12
218
+
219
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
220
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex22_2ep" --trainer_args.learning_rate=1e-3 \
221
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
222
+ # --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 100 --seed 12
223
+
224
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
225
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex23_3ep" --trainer_args.learning_rate=1e-3 \
226
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
227
+ # --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 100 --seed 11
228
+
229
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
230
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex26_2ep" --trainer_args.learning_rate=8e-4 \
231
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
232
+ # --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 100 --seed 11
233
+
234
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
235
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex27_2ep" --trainer_args.learning_rate=8e-4 \
236
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
237
+ # --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 100 --seed 10
238
+
239
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
240
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex28_2ep" --trainer_args.learning_rate=2e-3 \
241
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
242
+ # --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 100 --seed 11
243
+
244
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
245
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex29_2ep" --trainer_args.learning_rate=2e-3 \
246
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
247
+ # --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 100 --seed 10
248
+
249
+ accelerate launch --main_process_port 41353 -m src.ft_mathQ \
250
+ --config_path $OMINI_CONFIG --trainer_args.output_dir "./exprep/run_ex30" --trainer_args.learning_rate=8e-4 \
251
+ --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
252
+ --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 100 --seed 20
253
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
254
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exprep/run_ex31" --trainer_args.learning_rate=8e-4 \
255
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
256
+ # --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 100 --seed 21
257
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
258
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exprep/run_ex32" --trainer_args.learning_rate=8e-4 \
259
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
260
+ # --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 100 --seed 22
nl_tasks/scripts/train_initn40k.sh ADDED
@@ -0,0 +1,341 @@
1
+ export OMINI_CONFIG=./config/commonsense.yaml
2
+
3
+ #echo $OMINI_CONFIG
4
+ export TOKENIZERS_PARALLELISM=true
5
+
6
+ # CUDA include path (for cuda.h)
7
+ CUDA_INCLUDE_PATH="/home/work/miniconda3/envs/allm/include"
8
+
9
+ # Add to CPATH & CPLUS_INCLUDE_PATH for the C/C++ compiler
10
+ export CPATH=$CPATH:$CUDA_INCLUDE_PATH
11
+ export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$CUDA_INCLUDE_PATH
12
+ # echo "CPATH is set to: $CPATH"
13
+ # echo "CPLUS_INCLUDE_PATH is set to: $CPLUS_INCLUDE_PATH"
14
+
15
+ export WANDB_PROJECT="Llama2_7B_FT_Math40k_2"
16
+
17
+ export OMP_NUM_THREADS=1
18
+ export MKL_NUM_THREADS=1
19
+ export OPENBLAS_NUM_THREADS=1
20
+ export NUMEXPR_NUM_THREADS=1
21
+
22
+ date +"%F %T"
23
+
24
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
25
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp_init/run_ex01" --trainer_args.learning_rate=1e-3 \
26
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text "init|kaim_out_u=v"
27
+
28
+ # sleep 5
29
+ # echo "1st exp finishes"
30
+ # date +"%F %T"
31
+ # wandb sync wandb/latest-run
32
+
33
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
34
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp_init/run_ex02" --trainer_args.learning_rate=1e-3 \
35
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text "init|kaim_out_u=v(ratio)"
36
+
37
+ # sleep 5
38
+ # echo "2nd exp finishes"
39
+ # date +"%F %T"
40
+ # wandb sync wandb/latest-run
41
+
42
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
43
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex03" --trainer_args.learning_rate=1e-3 \
44
+ # --rotation_adapter_config.num_rotations 2 --rotation_adapter_config.r 8
45
+
46
+ # sleep 5
47
+ # echo "3rd exp finishes"
48
+ # date +"%F %T"
49
+ # wandb sync wandb/latest-run
50
+
51
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
52
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex04" --trainer_args.learning_rate=2e-3 \
53
+ # --rotation_adapter_config.num_rotations 4 --rotation_adapter_config.r 4
54
+ # sleep 5
55
+ # echo "4th exp finishes"
56
+ # date +"%F %T"
57
+ # wandb sync wandb/latest-run
58
+
59
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
60
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex05" --trainer_args.learning_rate=2e-3 \
61
+ # --rotation_adapter_config.num_rotations 2 --rotation_adapter_config.r 8
62
+ # sleep 5
63
+ # echo "5th exp finishes"
64
+ # date +"%F %T"
65
+ # wandb sync wandb/latest-run
66
+
67
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
68
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex06" --trainer_args.learning_rate=1e-3 \
69
+ # --rotation_adapter_config.num_rotations 4 --rotation_adapter_config.r 4
70
+ # sleep 5
71
+ # echo "6th exp finishes"
72
+ # date +"%F %T"
73
+ # wandb sync wandb/latest-run
74
+
75
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
76
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_exps7" --trainer_args.learning_rate=1e-3 \
77
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16
78
+ # sleep 5
79
+ # echo "7th exp finishes"
80
+ # date +"%F %T"
81
+ # wandb sync wandb/latest-run
82
+
83
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
84
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex08" --trainer_args.learning_rate=2e-3 \
85
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16
86
+ # sleep 5
87
+ # echo "8th exp finishes"
88
+ # date +"%F %T"
89
+ # wandb sync wandb/latest-run
90
+
91
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
92
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex09" --trainer_args.learning_rate=2e-3 \
93
+ # --rotation_adapter_config.num_rotations 16 --rotation_adapter_config.r 1
94
+
95
+ # sleep 5
96
+ # echo "9th exp finishes"
97
+ # date +"%F %T"
98
+ # wandb sync wandb/latest-run
99
+
100
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
101
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex10" --trainer_args.learning_rate=2e-3 \
102
+ # --rotation_adapter_config.num_rotations 8 --rotation_adapter_config.r 2
103
+
104
+ # sleep 5
105
+ # echo "10 exp finishes"
106
+ # date +"%F %T"
107
+ # wandb sync wandb/latest-run
108
+
109
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
110
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex11" --trainer_args.learning_rate=1e-2 \
111
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16
112
+
113
+ # sleep 5
114
+ # echo "11 exp finishes"
115
+ # date +"%F %T"
116
+ # wandb sync wandb/latest-run
117
+
118
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
119
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex12" --trainer_args.learning_rate=1e-2 \
120
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'u=v,def'
121
+
122
+ # sleep 5
123
+ # echo "12 exp finishes"
124
+ # date +"%F %T"
125
+ # wandb sync wandb/latest-run
126
+
127
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
128
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex13" --trainer_args.learning_rate=1e-3 \
129
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'u=vkaim'
130
+
131
+ # sleep 5
132
+ # echo "13 exp finishes"
133
+ # date +"%F %T"
134
+ # wandb sync wandb/latest-run
135
+
136
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
137
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex14" --trainer_args.learning_rate=2e-3 \
138
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'a,b,def'
139
+
140
+ # sleep 5
141
+ # echo "14 exp finishes"
142
+ # date +"%F %T"
143
+ # wandb sync wandb/latest-run
144
+
145
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
146
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex15" --trainer_args.learning_rate=1e-3 \
147
+ # --rotation_adapter_config.num_rotations 4 --rotation_adapter_config.r 4
148
+
149
+ # sleep 5
150
+ # echo "15 exp finishes"
151
+ # date +"%F %T"
152
+
153
+
154
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
155
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex17" --trainer_args.learning_rate=1e-3 \
156
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
157
+ # --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 500 --data.dataset_split train[:41023] --data.split_ratio 0.02493 \
158
+ # --run_text "dropout|fix_token"
159
+ # sleep 5
160
+ # echo "15 exp finishes"
161
+ # date +"%F %T"
162
+ # wandb sync wandb/latest-run
163
+
164
+
165
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
166
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex18" --trainer_args.learning_rate=1e-3 \
167
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
168
+ # --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 500 --data.dataset_split train[:41023] --data.split_ratio 0.02493 \
169
+ # --run_text "dropout|fix_token"
170
+ # sleep 5
171
+ # echo "18 exp finishes"
172
+ # date +"%F %T"
173
+ # wandb sync wandb/latest-run
174
+
175
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
176
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex19" --trainer_args.learning_rate=2e-3 \
177
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
178
+ # --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 500 --data.dataset_split train[:41023] --data.split_ratio 0.02493 \
179
+ # --run_text "dropout|fix_token"
180
+ # sleep 5
181
+ # echo "19 exp finishes"
182
+ # date +"%F %T"
183
+ # wandb sync wandb/latest-run
184
+
185
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
186
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex20" --trainer_args.learning_rate=8e-4 \
187
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
188
+ # --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 500 --data.dataset_split train[:41023] --data.split_ratio 0.02493 \
189
+ # --run_text "dropout|fix_token"
190
+ # sleep 5
191
+ # echo "20 exp finishes"
192
+ # date +"%F %T"
193
+ # wandb sync wandb/latest-run
194
+
195
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
196
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex21" --trainer_args.learning_rate=1e-3 \
197
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
198
+ # --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 500 --data.dataset_split train[:41023] --data.split_ratio 0.02493 \
199
+ # --run_text "dropout|2ep|1e3"
200
+ # sleep 5
201
+ # echo "21 exp finishes"
202
+ # date +"%F %T"
203
+ # wandb sync wandb/latest-run
204
+
205
+
206
+ # back to official 40k
207
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
208
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex22" --trainer_args.learning_rate=1e-3 \
209
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
210
+ # --trainer_args.num_train_epochs 2.0 --data.dataset_split train \
211
+ # --run_text "drop0.1|2ep|1e3|40k"
212
+ # sleep 5
213
+ # echo "21 exp finishes"
214
+ # date +"%F %T"
215
+ # wandb sync wandb/latest-run
216
+
217
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
218
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex23" --trainer_args.learning_rate=1e-3 \
219
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
220
+ # --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
221
+ # --run_text "drop0.1|2ep|1e3|40k"
222
+ # sleep 5
223
+ # echo "21 exp finishes"
224
+ # date +"%F %T"
225
+ # wandb sync wandb/latest-run
226
+
227
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
228
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex24" --trainer_args.learning_rate=1e-2 \
229
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
230
+ # --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
231
+ # --run_text "drop0.1|2ep|1e2|40k"
232
+ # sleep 5
233
+ # echo "21 exp finishes"
234
+ # date +"%F %T"
235
+ # wandb sync wandb/latest-run
236
+
237
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
238
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex25" --trainer_args.learning_rate=2e-3 \
239
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
240
+ # --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
241
+ # --run_text "drop0.1|2ep|2e3|40k"
242
+ # sleep 5
243
+ # echo "21 exp finishes"
244
+ # date +"%F %T"
245
+ # wandb sync wandb/latest-run
246
+
247
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
248
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex26" --trainer_args.learning_rate=5e-3 \
249
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
250
+ # --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
251
+ # --run_text "drop0.1|2ep|5e3|40k"
252
+ # sleep 5
253
+ # echo "21 exp finishes"
254
+ # date +"%F %T"
255
+ # wandb sync wandb/latest-run
256
+
257
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
258
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex27" --trainer_args.learning_rate=8e-3 \
259
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
260
+ # --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
261
+ # --run_text "drop0.1|2ep|8e3|40k"
262
+ # sleep 5
263
+ # echo "21 exp finishes"
264
+ # date +"%F %T"
265
+ # wandb sync wandb/latest-run
266
+
267
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
268
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex28" --trainer_args.learning_rate=2e-2 \
269
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
270
+ # --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
271
+ # --run_text "drop0.1|2ep|2e2|40k"
272
+ # sleep 5
273
+ # echo "21 exp finishes"
274
+ # date +"%F %T"
275
+ # wandb sync wandb/latest-run
276
+
277
+
278
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
279
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex29" --trainer_args.learning_rate=5e-3 \
280
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
281
+ # --trainer_args.num_train_epochs 2.0 --data.dataset_split train \
282
+ # --run_text "drop0.1|2ep|initu=v=0.01|40k" --trainer_args.per_device_train_batch_size 48
283
+ # sleep 5
284
+ # echo "29 exp finishes"
285
+ # date +"%F %T"
286
+ # wandb sync wandb/latest-run
287
+
288
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
289
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex30" --trainer_args.learning_rate=1e-3 \
290
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
291
+ # --trainer_args.num_train_epochs 2.0 --data.dataset_split train \
292
+ # --run_text "drop0.1|2ep|initu=v=0.01|40k" --trainer_args.per_device_train_batch_size 48
293
+ # sleep 5
294
+ # echo "29 exp finishes"
295
+ # date +"%F %T"
296
+ # wandb sync wandb/latest-run
297
+
298
+
299
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
300
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex31" --trainer_args.learning_rate=5e-3 \
301
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
302
+ # --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
303
+ # --run_text "drop0.1|2ep|initu=v=0.01|40k" --trainer_args.per_device_train_batch_size 48
304
+ # sleep 5
305
+ # echo "29 exp finishes"
306
+ # date +"%F %T"
307
+ # wandb sync wandb/latest-run
308
+
309
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
310
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex32" --trainer_args.learning_rate=1e-3 \
311
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
312
+ # --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
313
+ # --run_text "drop0.1|2ep|initu=v=0.01|40k" --trainer_args.per_device_train_batch_size 48
314
+ # sleep 5
315
+ # echo "29 exp finishes"
316
+ # date +"%F %T"
317
+ # wandb sync wandb/latest-run
318
+
319
+
320
+ # accelerate launch --main_process_port 41353 -m src.ft_mathQ \
321
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex33" --trainer_args.learning_rate=1e-2 \
322
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
323
+ # --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
324
+ # --run_text "drop0.1|2ep|initu=v=0.01|40k" --trainer_args.per_device_train_batch_size 48
325
+ # sleep 5
326
+ # echo "29 exp finishes"
327
+ # date +"%F %T"
328
+ # wandb sync wandb/latest-run
329
+
330
+
331
+ accelerate launch --main_process_port 41353 -m src.ft_mathQ \
332
+ --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex34" --trainer_args.learning_rate=2e-2 \
333
+ --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
334
+ --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
335
+ --run_text "drop0.1|2ep|initu=v=0.01|40k" --trainer_args.per_device_train_batch_size 48
336
+ sleep 5
337
+ echo "34 exp finishes"
338
+ date +"%F %T"
339
+ wandb sync wandb/latest-run
340
+
341
+ bash scripts/merge.sh
nl_tasks/scripts/train_math.sh ADDED
@@ -0,0 +1,162 @@
1
+ export OMINI_CONFIG=./config/math395.yaml
2
+
3
+ #echo $OMINI_CONFIG
4
+ export TOKENIZERS_PARALLELISM=true
5
+
6
+ # CUDA Include (/cuda.h)
7
+ CUDA_INCLUDE_PATH="/home/work/miniconda3/envs/allm/include"
8
+
9
+ # 3. Add into CPATH & CPLUS_INCLUDE_PATH (C/C++ compiler)
10
+ export CPATH=$CPATH:$CUDA_INCLUDE_PATH
11
+ export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$CUDA_INCLUDE_PATH
12
+ # echo "CPATH is set to: $CPATH"
13
+ # echo "CPLUS_INCLUDE_PATH is set to: $CPLUS_INCLUDE_PATH"
14
+
15
+ export WANDB_PROJECT="Llama2_7B_FT_Math_395k"
16
+
17
+ export OMP_NUM_THREADS=1
18
+ export MKL_NUM_THREADS=1
19
+ export OPENBLAS_NUM_THREADS=1
20
+ export NUMEXPR_NUM_THREADS=1
21
+
22
+ date +"%F %T"
23
+
24
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
25
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex01" --trainer_args.learning_rate=1e-3 \
26
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'def'
27
+
28
+ # sleep 5
29
+ # echo "1st exp finishes"
30
+ # date +"%F %T"
31
+ # wandb sync wandb/latest-run
32
+
33
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
34
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex02" --trainer_args.learning_rate=5e-3 \
35
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16
36
+
37
+ # sleep 5
38
+ # echo "2nd exp finishes"
39
+ # date +"%F %T"
40
+ # wandb sync wandb/latest-run
41
+
42
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
43
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex03" --trainer_args.learning_rate=2e-4 \
44
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16
45
+
46
+ # sleep 5
47
+ # echo "3rd exp finishes"
48
+ # date +"%F %T"
49
+ # wandb sync wandb/latest-run
50
+
51
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
52
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex04" --trainer_args.learning_rate=1e-3 \
53
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16
54
+
55
+ # sleep 5
56
+ # echo "4th exp finishes"
57
+ # date +"%F %T"
58
+ # wandb sync wandb/latest-run
59
+
60
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
61
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex05" --trainer_args.learning_rate=1e-3 \
62
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
63
+ # --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 500 --data.dataset_split train[:101011] --data.split_ratio 0.01
64
+ # sleep 5
65
+ # echo "5th exp finishes"
66
+ # date +"%F %T"
67
+ # wandb sync wandb/latest-run
68
+
69
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
70
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex06" --trainer_args.learning_rate=2e-3 \
71
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
72
+ # --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 500 --data.dataset_split train[:101011] --data.split_ratio 0.01
73
+ # sleep 5
74
+ # echo "6th exp finishes"
75
+ # date +"%F %T"
76
+ # wandb sync wandb/latest-run
77
+
78
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
79
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex0s7" --trainer_args.learning_rate=5e-3 \
80
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
81
+ # --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 500 --data.dataset_split train[:101011] --data.split_ratio 0.01
82
+ # sleep 5
83
+ # echo "7th exp finishes"
84
+ # date +"%F %T"
85
+ # wandb sync wandb/latest-run
86
+
87
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
88
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex08" --trainer_args.learning_rate=1e-4 \
89
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
90
+ # --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 500 --data.dataset_split train[:101011] --data.split_ratio 0.01 \
91
+ # --trainer_args.per_device_train_batch_size 32 --run_text 'u2e2,def'
92
+ # sleep 5
93
+ # echo "8th exp finishes"
94
+ # date +"%F %T"
95
+ # wandb sync wandb/latest-run
96
+
97
+ accelerate launch --main_process_port 41353 -m src.ft_mathR \
98
+ --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex09" --trainer_args.learning_rate=2e-3 \
99
+ --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
100
+ --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 500 --data.dataset_split train[:101011] --data.split_ratio 0.01 \
101
+ --trainer_args.per_device_train_batch_size 32 --run_text 'init=def|fix_token'
102
+
103
+ sleep 5
104
+ echo "9th exp finishes"
105
+ date +"%F %T"
106
+ wandb sync wandb/latest-run
107
+
108
+ accelerate launch --main_process_port 41353 -m src.ft_mathR \
109
+ --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex10" --trainer_args.learning_rate=2e-3 \
110
+ --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
111
+ --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 500 --data.dataset_split train[:101011] --data.split_ratio 0.01 \
112
+ --trainer_args.per_device_train_batch_size 32 --run_text "init=def|fix_token"
113
+
114
+ sleep 5
115
+ echo "10 exp finishes"
116
+ date +"%F %T"
117
+ wandb sync wandb/latest-run
118
+
119
+
120
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
121
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex11" --trainer_args.learning_rate=1e-2 \
122
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16
123
+
124
+ # sleep 5
125
+ # echo "11 exp finishes"
126
+ # date +"%F %T"
127
+ # wandb sync wandb/latest-run
128
+
129
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
130
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex12" --trainer_args.learning_rate=1e-2 \
131
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'u=v,def'
132
+
133
+ # sleep 5
134
+ # echo "12 exp finishes"
135
+ # date +"%F %T"
136
+ # wandb sync wandb/latest-run
137
+
138
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
139
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex13" --trainer_args.learning_rate=1e-3 \
140
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'u=vkaim'
141
+
142
+ # sleep 5
143
+ # echo "13 exp finishes"
144
+ # date +"%F %T"
145
+ # wandb sync wandb/latest-run
146
+
147
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
148
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex14" --trainer_args.learning_rate=2e-3 \
149
+ # --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'a,b,def'
150
+
151
+ # sleep 5
152
+ # echo "14 exp finishes"
153
+ # date +"%F %T"
154
+ # wandb sync wandb/latest-run
155
+
156
+ # accelerate launch --main_process_port 41353 -m src.ft_mathR \
157
+ # --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex15" --trainer_args.learning_rate=1e-3 \
158
+ # --rotation_adapter_config.num_rotations 4 --rotation_adapter_config.r 4
159
+
160
+ # sleep 5
161
+ # echo "15 exp finishes"
162
+ # date +"%F %T"
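
Note: the training scripts above (train_100math.sh, train_cms_reasoning.sh, train_initn40k.sh, train_math.sh) repeat the same accelerate launch -m src.ft_mathQ / src.ft_mathR command while varying only the learning rate, epoch count and output directory. The Python sketch below expresses that sweep pattern with a loop; it is a hypothetical helper, not a file in this upload, and it only reuses flags that already appear in the scripts.

import subprocess

# Hypothetical sweep driver; mirrors the flags used by the scripts above.
SWEEP = [(1e-3, 2.0), (1e-3, 3.0), (2e-3, 3.0)]  # (learning_rate, num_train_epochs)

for i, (lr, epochs) in enumerate(SWEEP, start=1):
    subprocess.run(
        [
            "accelerate", "launch", "--main_process_port", "41353", "-m", "src.ft_mathQ",
            "--config_path", "./config/math395.yaml",
            "--trainer_args.output_dir", f"./exp395/sweep_ex{i:02d}",
            f"--trainer_args.learning_rate={lr}",
            "--rotation_adapter_config.num_rotations", "1",
            "--rotation_adapter_config.r", "16",
            "--trainer_args.num_train_epochs", str(epochs),
        ],
        check=True,  # stop the sweep if one run fails
    )
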
nl_tasks/setup.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright [2024] [Zhuo Chen]
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import setuptools
17
+
18
+ setuptools.setup(
19
+ name="rpeft",
20
+ version="0.0.2",
21
+ author="SDML",
22
+ packages=setuptools.find_packages(),
23
+ install_requires=[
24
+ 'transformers>=4.0.0',
25
+ 'torch>=2.0.0'
26
+ ],
27
+ python_requires='>=3.9',
28
+ )
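
setup.py packages the rpeft directory (installable with pip install -e . from inside nl_tasks) so that the training code can import it. The sketch below illustrates the intended usage under the assumption that rpeft's get_peft_model and RotationConfig follow the PEFT-style interface implied by the imports in src/ft_mathQ.py and by the rotation_adapter_config fields (r, num_rotations) used in the scripts; it is an illustration, not code from this upload.

# Usage sketch only; assumes a PEFT-style API as implied by src/ft_mathQ.py's imports.
from transformers import AutoModelForCausalLM
from rpeft import get_peft_model
from rpeft.rotation import RotationConfig

base = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")  # default in src/config.py's Training dataclass
cfg = RotationConfig(r=16, num_rotations=1, task_type="CAUSAL_LM")  # the setting most runs above use
model = get_peft_model(base, cfg)  # wraps the configured target modules with rotation adapters
if hasattr(model, "print_trainable_parameters"):  # available on HF PEFT models; assumed here
    model.print_trainable_parameters()
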
nl_tasks/src/bb.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
nl_tasks/src/cc.ipynb ADDED
File without changes
nl_tasks/src/config.py ADDED
@@ -0,0 +1,183 @@
1
+ from dataclasses import dataclass, field, fields, asdict
2
+ from typing import Optional, List, Literal, Dict, Any, Union
3
+ from transformers import TrainingArguments, Trainer
4
+ from omegaconf import OmegaConf
5
+ import sys
6
+
7
+
8
+ @dataclass
9
+ class ModelConfig:
10
+ model_name: str = ""
11
+ dropout: float = 0.0
12
+ model_max_seq_length: int = field(default=512)
13
+ data_collator_mode: str=field(default='fixed', metadata={"help": "fixed or dynamic padding in DataCollator"})
14
+ lambda_reg: float = field(default=1e-4, metadata={"help": "The control strength of regularity"})
15
+ adapter_path: Optional[str] = field(default=None)
16
+
17
+ merge_adapter_path: Optional[str] = field(default=None)
18
+ merge_output_path: Optional[str] = field(default=None)
19
+
20
+ @dataclass
21
+ class RotationConfig:
22
+ r: int = field(default=4)
23
+ num_rotations: int = field(default=4)
24
+ task_type: str = "CAUSAL_LM"
25
+ target_modules: List[str] = field(default_factory=lambda: ["q_proj",])
26
+
27
+ @dataclass
28
+ class DataConfig:
29
+ dataset_name: str = 'math'
30
+ split_ratio: float = field(default=0.01)
31
+ path: str = "./nl_tasks/data/MetaMathQA-40K/MetaMathQA-40K.json"
32
+ dataset_split: str = field(default="train[:1000]", metadata={"help": "Dataset split to load: 'train', 'test' or 'eval', optionally sliced, e.g. train[:1000]."})
33
+ adapter_names: List[Optional[str]] = field(default_factory=lambda: ["default"]) ###
34
+ dataset_field: List[str] = field(default_factory=list, metadata={"help": "Fields of dataset input and output."})
35
+
36
+
37
+ @dataclass
38
+ class TrainingOverride:
39
+ optim: str=field(default="adamw_torch") ##
40
+ eval_strategy: str=field(default='no')
41
+ per_device_train_batch_size: int=field(default=8) ##
42
+ per_device_eval_batch_size: int=field(default=8) ##
43
+
44
+ learning_rate: float = field(default=1e-05)
45
+ lr_scheduler_type: str = field(default='cosine')
46
+ # warmup_ratio: float = field(default=0.1)
47
+ warmup_steps: int = field(default=0)
48
+
49
+ gradient_checkpointing: bool = field(default=False)
50
+ gradient_accumulation_steps: int=field(default=1)
51
+ output_dir: str = field(default="runs")
52
+ save_steps: float = field(default=0)
53
+ save_strategy: str =field(default='no')
54
+ # save_total_limit: int=field(default=1) No need any more
55
+ bf16: bool=field(default=False)
56
+ bf16_full_eval: bool=field(default=False)
57
+ save_safetensors: bool=field(default=False)
58
+
59
+ report_to: Union[None, str, list[str]]=field(default="none")
60
+ logging_steps: int=field(default=25) # we use int only
61
+ # logging_first_step: bool=field(default=False)
62
+ eval_steps: Union[None, int] = field(default=None)  # we use int only
63
+
64
+ dataloader_num_workers: int = field(default=1)
65
+ dataloader_pin_memory: bool = field(default=True) ###
66
+ dataloader_persistent_workers: bool=field(default=True) ###
67
+ dataloader_prefetch_factor: int = field(default=1) ###
68
+
69
+ num_train_epochs: float = field(default=1.0)
70
+ max_steps: int=field(default=-1)
71
+ load_best_model_at_end: bool = field(default=True)
72
+
73
+ @dataclass
74
+ class GlueConfig:
75
+ task_name: str = field(default='mnli')
76
+ pad_to_max_length: bool = field(default=True)
77
+
78
+
79
+ @dataclass
80
+ class MainConfig:
81
+ model: ModelConfig = field(default_factory=ModelConfig)
82
+ rotation_adapter_config: RotationConfig = field(default_factory=RotationConfig)
83
+ data: DataConfig = field(default_factory=DataConfig)
84
+ trainer_args: TrainingOverride = field(default_factory=TrainingOverride)
85
+
86
+ glue: GlueConfig = field(default_factory=GlueConfig)
87
+ project_name: str = "llm_rotation"
88
+ seed: int = 42
89
+ run_text: str=field(default='def')
90
+ # device: str = field(default='cpu')
91
+
92
+ @dataclass
93
+ class HFTrainingArguments(TrainingArguments):
94
+ extension: Optional[Dict[str, Any]] = field(
95
+ default=None,
96
+ metadata={"help": "Serialized MainConfig excluding training args"}
97
+ )
98
+
99
+ def convert_to_trainer_args(main_cfg: MainConfig) -> HFTrainingArguments:
100
+ """
101
+ Maps MainConfig to MyTrainingArguments.
102
+ Logic:
103
+ 1. Extract 'training' fields -> Pass to TrainingArguments constructor.
104
+ 2. Pack 'model', 'data', etc. -> Put into 'extension'.
105
+ """
106
+ KEY = "trainer_args"
107
+ # 1. Convert OmegaConf/Dataclass to pure Python dict
108
+ # resolve=True ensures variables like ${model.name} are interpolated
109
+ full_dict = asdict(main_cfg)
110
+
111
+ # 2. Extract Training Arguments
112
+ # These are unpacked as **kwargs to initialize the parent TrainingArguments
113
+ train_args_dict = full_dict.pop(KEY)
114
+
115
+ # 3. The rest (model, data, seed) goes into extension
116
+ extension_payload = full_dict
117
+
118
+ # 4. Initialize MyTrainingArguments
119
+ # Note: We must ensure train_args_dict keys match TrainingArguments fields.
120
+ try:
121
+ args = HFTrainingArguments(**train_args_dict)
122
+ except TypeError as e:
123
+ print(f"Error: Your 'training' config contains keys unknown to HF TrainingArguments: {e}")
124
+ sys.exit(1)
125
+
126
+ # 5. Attach the extension
127
+ args.extension = extension_payload
128
+
129
+ return args
130
+
131
+
132
+
133
+
134
+ @dataclass
135
+ class Training:
136
+ model_name_or_path: Optional[str] = field(default="huggyllama/llama-7b")
137
+ adapter_name_or_path: Optional[str] = field(default=None)
138
+ data_path: str = field(default=None, metadata={"help": "Path to the training data."})
139
+ dataset_split: str = field(
140
+ default="train[:100000]", metadata={"help": "Dataset split to load: 'train', 'test' or 'eval', optionally sliced, e.g. train[:100000]."}
141
+ )
142
+ dataset_field: List[str] = field(
143
+ default=None, metadata={"help": "Fields of dataset input and output."}
144
+ )
145
+ optim: str = field(default="adamw_torch")
146
+ model_max_length: int = field(default=512, metadata={
147
+ "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."}, )
148
+ hrft_r: int = field(default=8, metadata={
149
+ "help": "The rank of the adapter. When passing `None` and `adapter_name_or_path` is also `None`, full fine-tuning is used."})
150
+ init_a: float = field(default=1e-4, metadata={"help": "The initial weights"})
151
+ eps: float = field(default=1e-4, metadata={"help": "The control strength of COFT. The freedom of rotation."})
152
+ lamda: float = field(default=1e-4, metadata={"help": "The control strength of regularity"})
153
+ add_orth: str = field(default='none', metadata={"help": ""})
154
+ init_weights: Literal[True, "pissa"] = field(
155
+ default=True,
156
+ metadata={
157
+ "help": (
158
+ "Passing True (default) results in the LoRA initialization."
159
+ "Passing `pissa` results in PiSSA initialization."
160
+ ),
161
+ },
162
+ )
163
+ extension: Optional[Dict[str, Any]] = field(
164
+ default=None,
165
+ metadata={"help": "Serialized MainConfig excluding training args"}
166
+ )
167
+
168
+ # target_modules: str = (
169
+ # "(.*x_embedder"
170
+ # "|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear"
171
+ # "|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k"
172
+ # "|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q"
173
+ # "|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v"
174
+ # "|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0"
175
+ # "|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2"
176
+ # "|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear"
177
+ # "|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp"
178
+ # "|.*single_transformer_blocks\\.[0-9]+\\.proj_out"
179
+ # "|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k"
180
+ # "|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q"
181
+ # "|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v"
182
+ # "|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)"
183
+ # )
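
convert_to_trainer_args splits MainConfig into the Hugging Face TrainingArguments fields (taken from trainer_args) and everything else, which it stores in the extension dict. A minimal sketch of that flow, using only the dataclasses defined above and assuming it is run from the nl_tasks root so that src.config is importable:

from src.config import MainConfig, convert_to_trainer_args

cfg = MainConfig()                        # dataclass defaults
cfg.trainer_args.learning_rate = 1e-3     # the kind of override the shell scripts pass on the CLI
cfg.rotation_adapter_config.r = 16

args = convert_to_trainer_args(cfg)       # HFTrainingArguments built from cfg.trainer_args
print(args.learning_rate)                               # 0.001
print(args.extension["rotation_adapter_config"]["r"])   # 16 -- non-trainer config rides along in .extension
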
nl_tasks/src/ft_mathQ.py ADDED
@@ -0,0 +1,702 @@
1
+ #
2
+ import sys
3
+ #print('sys.path: ___ ', sys.path)
4
+ #print(f"Current Python Executable: {sys.executable}")
5
+
6
+ ### dynamo warning
7
+ import warnings
8
+
9
+ # Ignore FutureWarning: prims_common.check, Online Softmax
10
+ warnings.filterwarnings("ignore", category=FutureWarning, module='torch._inductor.lowering')
11
+ warnings.filterwarnings("ignore", message=".*Online softmax is disabled on the fly.*", category=UserWarning)
12
+
13
+ warnings.filterwarnings("ignore", message=".*Our suggested max number of worker in current system is 1.*", category=UserWarning)
14
+ warnings.filterwarnings("ignore", message=".*will be initialized from a multivariate normal distribution.*")
15
+ warnings.filterwarnings("ignore", message=".*that differ from the model config and generation config.*", category=UserWarning)
16
+ warnings.filterwarnings("ignore", message=".*torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch..*", category=UserWarning)
17
+
18
+ import torch
19
+ torch.backends.cuda.matmul.fp32_precision = 'tf32'
20
+ # import wandb
21
+ import os
22
+ torch.set_num_threads(1)
23
+ os.environ["OMP_NUM_THREADS"]="1"
24
+ os.environ["MKL_NUM_THREADS"]="1"
25
+ import torch
26
+ print(f"PyTorch version: {torch.__version__}")
27
+ print(f"CUDA available: {torch.cuda.is_available()}")
28
+ print(f"PyTorch built with CUDA version: {torch.version.cuda}")
29
+
30
+ import yaml
31
+ #from peft import LoraConfig, get_peft_model_state_dict
32
+ from torch.utils.data import DataLoader
33
+ import time
34
+ from datetime import datetime
35
+ import math
36
+
37
+ from typing import List, Tuple
38
+
39
+ # import prodigyopt
40
+
41
+
42
+ ###
43
+ import copy
44
+ from dataclasses import field, dataclass, asdict
45
+ from typing import Sequence, Literal, Dict
46
+
47
+ import transformers
48
+ from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
49
+ from transformers import Trainer
50
+ from transformers.modeling_utils import *
51
+ from transformers.trainer import _is_peft_model
52
+ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
53
+ from transformers.data.data_collator import DataCollator
54
+
55
+ from transformers.training_args import TrainingArguments
56
+ from transformers.tokenization_utils_base import PreTrainedTokenizerBase
57
+ from transformers.trainer_callback import TrainerCallback
58
+ from transformers.trainer_utils import EvalPrediction
59
+ from torch.utils.data import Dataset, IterableDataset
60
+ from datasets import load_dataset
61
+ ##
62
+ #from ..pipeline.flux_omini import transformer_forward, encode_images
63
+ # from ...omini.rotation import RotationTuner, RotationConfig
64
+
65
+
66
+ from rpeft.rotation import RotationTuner, RotationConfig
67
+ from rpeft import get_peft_model, PeftModel
68
+ from .config import MainConfig, convert_to_trainer_args
69
+ import pyrallis
70
+ from omegaconf import OmegaConf
71
+ import torch.optim as optim
72
+ import wandb
73
+ from torch.nn.utils.rnn import pad_sequence
74
+
75
+ IGNORE_INDEX = -100
76
+ PROMPT = (
77
+ "Below is an instruction that describes a task. "
78
+ "Write a response that appropriately completes the request.\n\n"
79
+ "### Instruction:\n{instruction}\n\n### Response:"
80
+ )
81
+
82
+ def get_rank():
83
+ try:
84
+ rank = int(os.environ.get("LOCAL_RANK"))
85
+ except (TypeError, ValueError):
86
+ rank = 0
87
+ return rank
88
+
89
+
90
+ def get_config():
91
+ config_path = os.environ.get("OMINI_CONFIG")
92
+ assert config_path is not None, "Please set the OMINI_CONFIG environment variable"
93
+ with open(config_path, "r") as f:
94
+ config = yaml.safe_load(f)
95
+ return config
96
+
97
+
98
+ def init_wandb(wandb_config, run_name):
99
+ import wandb
100
+
101
+ try:
102
+ assert os.environ.get("WANDB_API_KEY") is not None
103
+ wandb.init(
104
+ project=wandb_config["project"],
105
+ name=run_name,
106
+ config={},
107
+ )
108
+ except Exception as e:
109
+ print("Failed to initialize WanDB:", e)
110
+
111
+
112
+
113
+ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
114
+ """Collects the state dict and dump to disk."""
115
+ state_dict = trainer.model.state_dict()
116
+ if trainer.args.should_save:
117
+ cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
118
+ del state_dict
119
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
120
+
121
+
122
+ def smart_tokenizer_and_embedding_resize(
123
+ special_tokens_dict: Dict,
124
+ tokenizer: transformers.PreTrainedTokenizer,
125
+ model: transformers.PreTrainedModel,
126
+ ):
127
+ """Resize tokenizer and embedding.
128
+
129
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
130
+ """
131
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
132
+ model.resize_token_embeddings(len(tokenizer))
133
+
134
+ if num_new_tokens > 0:
135
+ input_embeddings = model.get_input_embeddings().weight.data
136
+ output_embeddings = model.get_output_embeddings().weight.data
137
+
138
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
139
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
140
+
141
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
142
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
143
+
144
+
145
+ def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
146
+ """Tokenize a list of strings."""
147
+ tokenized_list = [
148
+ tokenizer(
149
+ text,
150
+ return_tensors="pt",
151
+ padding="longest",
152
+ max_length=tokenizer.model_max_length,
153
+ truncation=True,
154
+ )
155
+ for text in strings
156
+ ]
157
+ input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
158
+ input_ids_lens = labels_lens = [
159
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
160
+ ]
161
+ return dict(
162
+ input_ids=input_ids,
163
+ labels=labels,
164
+ input_ids_lens=input_ids_lens,
165
+ labels_lens=labels_lens,
166
+ )
167
+
168
+ def preprocess(
169
+ sources: Sequence[str],
170
+ targets: Sequence[str],
171
+ tokenizer: transformers.PreTrainedTokenizer,
172
+ ) -> Dict:
173
+ """Preprocess the data by tokenizing."""
174
+ examples = [s + t for s, t in zip(sources, targets)]
175
+ examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
176
+ input_ids = examples_tokenized["input_ids"]
177
+ labels = copy.deepcopy(input_ids)
178
+ for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
179
+ label[:source_len] = IGNORE_INDEX
180
+ return dict(input_ids=input_ids, labels=labels)
181
+
182
+
183
+ # @dataclass
184
+ # class DataCollatorForSupervisedDataset():
185
+ # """Collate examples for supervised fine-tuning."""
186
+
187
+ # tokenizer: transformers.PreTrainedTokenizer
188
+ # max_length: int = field(default=512)
189
+ # mode: str = field(default="fixed") # dynamic -> dynamo
190
+
191
+ # def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
192
+ # if self.mode == 'dynamic':
193
+ # input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
194
+ # input_ids = [torch.tensor(x) for x in input_ids]
195
+ # input_ids = torch.nn.utils.rnn.pad_sequence(
196
+ # input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
197
+ # )
198
+ # labels = [torch.tensor(x) for x in labels]
199
+ # labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
200
+ # return dict(
201
+ # input_ids=input_ids,
202
+ # labels=labels,
203
+ # attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
204
+ # )
205
+ # elif self.mode == 'fixed':
206
+ # input_ids = [torch.tensor(x["input_ids"][:self.max_length]) for x in instances]
207
+ # input_ids = torch.stack([
208
+ # torch.nn.functional.pad(x, (0, self.max_length - x.size(0)), value=self.tokenizer.pad_token_id)
209
+ # for x in input_ids
210
+ # ])
211
+
212
+ # # Labels
213
+ # labels = [torch.tensor(x["labels"][:self.max_length]) for x in instances]
214
+ # labels = torch.stack([
215
+ # torch.nn.functional.pad(x, (0, self.max_length - x.size(0)), value=IGNORE_INDEX)
216
+ # for x in labels
217
+ # ])
218
+
219
+ # return dict(
220
+ # input_ids=input_ids,
221
+ # labels=labels,
222
+ # attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
223
+ # )
224
+ # else:
225
+ # raise NotImplementedError
226
+
227
+ # @dataclass
228
+ # class DataCollatorForSupervisedDataset(object):
229
+ # tokenizer: transformers.PreTrainedTokenizer
230
+ # max_length: int = field(default=512)
231
+ # mode: str = field(default="fixed") # "dynamic" or "fixed"
232
+
233
+ # def _pad_to_length(self, tensors: Sequence[torch.Tensor], pad_value: int, target_len: int):
234
+ # """Pad a list of 1D tensors to target_len (int) and stack -> (B, target_len)."""
235
+ # batch_size = len(tensors)
236
+ # out = torch.full((batch_size, target_len), pad_value, dtype=tensors[0].dtype)
237
+ # for i, t in enumerate(tensors):
238
+ # L = min(t.size(0), target_len)
239
+ # out[i, :L] = t[:L]
240
+ # return out
241
+
242
+ # def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
243
+ # # Collect raw sequences (lists or tensors)
244
+ # input_seqs = [torch.tensor(x["input_ids"], dtype=torch.long) for x in instances]
245
+ # label_seqs = [torch.tensor(x["labels"], dtype=torch.long) for x in instances]
246
+
247
+ # if self.mode == "dynamic":
248
+ # # pad to the max length present in this batch (<= self.max_length)
249
+ # batch_max_len = min(max([s.size(0) for s in input_seqs]), self.max_length)
250
+ # input_ids = self._pad_to_length(input_seqs, pad_value=self.tokenizer.pad_token_id, target_len=batch_max_len)
251
+ # labels = self._pad_to_length(label_seqs, pad_value=IGNORE_INDEX, target_len=batch_max_len)
252
+ # elif self.mode == "fixed":
253
+ # # always pad/truncate to self.max_length
254
+ # input_ids = self._pad_to_length(input_seqs, pad_value=self.tokenizer.pad_token_id, target_len=self.max_length)
255
+ # labels = self._pad_to_length(label_seqs, pad_value=IGNORE_INDEX, target_len=self.max_length)
256
+ # else:
257
+ # raise NotImplementedError(f"Unknown mode: {self.mode}")
258
+
259
+ # attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
260
+
261
+ # return {
262
+ # "input_ids": input_ids,
263
+ # "labels": labels,
264
+ # "attention_mask": attention_mask
265
+ # }
266
+
267
+ @dataclass
268
+ class DataCollatorForSupervisedDataset():
269
+ tokenizer: transformers.PreTrainedTokenizer
270
+ max_length: int = field(default=512)
271
+ mode: str = field(default="fixed") # "dynamic" or "fixed"
272
+
273
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
274
+ # Extract inputs and labels
275
+ # Assuming instances is a list of dicts like {'input_ids': [...], 'labels': [...]}
276
+ input_ids_list = [torch.tensor(x["input_ids"], dtype=torch.long) for x in instances]
277
+ labels_list = [torch.tensor(x["labels"], dtype=torch.long) for x in instances]
278
+
279
+ # 1. Determine padding logic
280
+ if self.mode == "dynamic":
281
+ # Dynamic padding: pad to the longest sequence in the batch
282
+ # But cap it at self.max_length to prevent OOM
283
+ batch_max_len = max([len(x) for x in input_ids_list])
284
+ target_len = min(batch_max_len, self.max_length)
285
+ else:
286
+ # Fixed padding: always pad to max_length
287
+ target_len = self.max_length
288
+
289
+ # 2. Helper to pad and truncate
290
+ def pad_and_truncate(tensors, padding_value):
291
+ # First, pad everything using PyTorch's optimized utility (batch_first=True)
292
+ padded = pad_sequence(tensors, batch_first=True, padding_value=padding_value)
293
+
294
+ # Handle truncation/extending to exact target_len
295
+ curr_len = padded.shape[1]
296
+ if curr_len > target_len:
297
+ # Truncate if too long (rare if filtered beforehand)
298
+ return padded[:, :target_len]
299
+ elif curr_len < target_len:
300
+ # Pad more if shorter than target_len (happens in fixed mode)
301
+ diff = target_len - curr_len
302
+ padding = torch.full((padded.shape[0], diff), padding_value, dtype=padded.dtype)
303
+ return torch.cat([padded, padding], dim=1)
304
+ else:
305
+ return padded
306
+
307
+ # 3. Apply padding
308
+ # Critical: tokenizer.pad_token_id must NOT be None here
309
+ if self.tokenizer.pad_token_id is None:
310
+ raise ValueError("Tokenizer.pad_token_id is None. Please set it to eos_token_id or unk_token_id.")
311
+
312
+ input_ids = pad_and_truncate(input_ids_list, self.tokenizer.pad_token_id)
313
+ labels = pad_and_truncate(labels_list, IGNORE_INDEX)
314
+
315
+ # 4. Create Attention Mask explicitly
316
+ # .ne() creates Bools, .long() casts to 0s and 1s for compatibility
317
+ attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
318
+
319
+ return {
320
+ "input_ids": input_ids,
321
+ "labels": labels,
322
+ "attention_mask": attention_mask
323
+ }
324
+
325
+ def train_tokenize_function(examples, tokenizer, query, response):
326
+ sources = [PROMPT.format_map(dict(instruction=instruction)) for instruction in examples[query]]
327
+ targets = [f"{output}{tokenizer.eos_token}" for output in examples[response]]
328
+ data_dict = preprocess(sources, targets, tokenizer)
329
+ return data_dict
330
+
331
+
332
+
333
+ ### Trainer
334
+ def default_worker_init_fn(worker_id):
335
+ # limit each worker to a single BLAS thread
336
+ try:
337
+ import numpy as _np
338
+ except Exception:
339
+ _np = None
340
+ torch.set_num_threads(1)
341
+ os.environ.setdefault("OMP_NUM_THREADS", "1")
342
+ os.environ.setdefault("MKL_NUM_THREADS", "1")
343
+ os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
344
+ # Optional: bind CPU affinity per worker to avoid contention (NUMA-aware)
345
+ try:
346
+ cpu_count = os.cpu_count() or 1
347
+ # distribute CPU cores evenly across workers
348
+ num_workers = getattr(torch.utils.data, "_num_workers", None)
349
+ # fallback: if not available, compute from environment variable or pass externally
350
+ # We'll do a simple round-robin assignment using worker_id
351
+ # assign a small mask of cores to this worker (e.g., chunk size 4)
352
+ chunk = max(1, cpu_count // max(1, min(64, cpu_count)))
353
+ start = (worker_id * chunk) % cpu_count
354
+ end = start + chunk
355
+ mask = set(range(start, min(end, cpu_count)))
356
+ try:
357
+ os.sched_setaffinity(0, mask)
358
+ except Exception:
359
+ pass
360
+ except Exception:
361
+ pass
362
+
363
+ def set_seed(seed: int):
364
+ # random.seed(seed)
365
+ # np.random.seed(seed)
366
+ torch.manual_seed(seed)
367
+ torch.cuda.manual_seed_all(seed)
368
+ transformers.set_seed(seed)
369
+
370
+
371
+ @pyrallis.wrap()
372
+ def main(mainCfg: MainConfig):
373
+ #mainCfg = get_config()
374
+ #print(mainCfg)
375
+ print('='*120)
376
+ # print(OmegaConf.to_yaml(mainCfg))
377
+ # print('-'*40)
378
+ #
379
+ # print((training_args))
380
+ set_seed(mainCfg.seed)
381
+ training_args = convert_to_trainer_args(mainCfg)
382
+
383
+ # wandb
384
+ ENTITY = "nvan-13-korea-university"
385
+ PROJECT = os.environ.get("WANDB_PROJECT")
386
+ api = wandb.Api()
387
+ try:
388
+ runs_list = api.runs(f"{ENTITY}/{PROJECT}")
389
+ next_run_num = len(runs_list) + 1
390
+ except Exception as e:
391
+ next_run_num = 1
392
+
393
+ training_args.run_name = f'[{next_run_num}]lr={mainCfg.trainer_args.learning_rate:.1e},b={mainCfg.trainer_args.per_device_train_batch_size},'\
394
+ f'n={mainCfg.rotation_adapter_config.num_rotations},r={mainCfg.rotation_adapter_config.r},'\
395
+ f'init={mainCfg.run_text}'
396
+ # training_args.project = f'Rotation-Llama2-{mainCfg.data.dataset_name}'
397
+
398
+ # print('-'*40)
399
+ # print(training_args.to_json_string())
400
+ # exit()
401
+
402
+ model = AutoModelForCausalLM.from_pretrained(mainCfg.model.model_name,
403
+ device_map="auto", low_cpu_mem_usage=True,
404
+ dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
405
+ attn_implementation="sdpa",
406
+ )
407
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
408
+ print("DEVICE", DEVICE)
409
+ # for name, param in model.named_parameters():
410
+ # if 'q_proj' in name and 'layers.5' in name:
411
+ # print(f"Name: {name} | {param.shape} ")
412
+ # print(f"Name (pretrained): {name} | {param.shape} | {param.data[0:5,0:5]}")
413
+ # print('model', model)
414
+ # exit()
415
+
416
+ total_params_now = sum(p.numel() for p in model.parameters())
417
+ print(f'#params of the pretrained model, {total_params_now:,}')
418
+ # print(model)
419
+ if mainCfg.model.adapter_path is not None:
420
+ print('___ Loading from: ', mainCfg.model.adapter_path)
421
+ model = PeftModel.from_pretrained(model, mainCfg.model.adapter_path, is_trainable = True)
422
+ elif mainCfg.rotation_adapter_config.r is not None:
423
+ rotation_adapter_config = asdict(mainCfg.rotation_adapter_config)
424
+ # rotation_adapter_config[peft_type]
425
+
426
+ for adapter_name in mainCfg.data.adapter_names:
427
+ rotation_config = RotationConfig(**rotation_adapter_config)
428
+ model = get_peft_model(model, rotation_config, adapter_name=adapter_name)
429
+ # model.set_adapter(adapter_name)
430
+
431
+ # import peft
432
+ # from peft import OFTConfig
433
+ # oft_config = OFTConfig(
434
+ # # r=16,
435
+ # oft_block_size=4*mainCfg.rotation_adapter_config.r,
436
+ # use_cayley_neumann=True,
437
+ # target_modules=["q_proj", "v_proj",],
438
+ # module_dropout=0.05, # mainCfg.rotation_adapter_config.drop_out,
439
+ # # task_type="CAUSAL_LM",
440
+ # bias="none",
441
+ # )
442
+
443
+ # for adapter_name in mainCfg.data.adapter_names:
444
+ # model = peft.get_peft_model(model, oft_config, adapter_name=adapter_name)
445
+ else:
446
+ print("Full Parameter Fine-Tuning")
447
+ model = model.to(DEVICE)
448
+
449
+ # print('model', model)
450
+ model.print_trainable_parameters()
451
+ # exit()  # debug-only early exit; kept commented so the training and saving steps below actually run
452
+ # print("Program starts")
453
+ # time.sleep(300)
454
+ # exit()
455
+
456
+ # for name, param in model.named_parameters():
457
+ # if 'q_proj' in name and 'rotation' in name and 'layers.5' in name:
458
+ # print(f"Name: {name} | {param.shape} ")
459
+ # print(f"Name (pretrained): {name} | {param.shape} ")
460
+ # X = param.data
461
+ # print('model', type(model), X.shape)
462
+ # visualize_value_distribution(X)
463
+ # exit()
464
+
465
+ rotation_layers = filter(
466
+ lambda p: p.requires_grad, model.parameters()
467
+ )
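+ # Note: filter() yields a one-shot generator over the trainable (adapter) parameters; it is consumed once when the AdamW optimizer is built further below.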
468
+
469
+ tokenizer = AutoTokenizer.from_pretrained(
470
+ mainCfg.model.model_name,
471
+ model_max_length=mainCfg.model.model_max_seq_length,
472
+ padding_side="right",
473
+ use_fast=True,
474
+ )
475
+
476
+ if tokenizer.pad_token is None:
477
+ if tokenizer.unk_token_id is not None:
478
+ tokenizer.pad_token_id = tokenizer.unk_token_id
479
+ tokenizer.pad_token = tokenizer.unk_token
480
+ print("Set PAD token to UNK token.")
481
+ elif tokenizer.eos_token_id is not None:
482
+ tokenizer.pad_token_id = tokenizer.eos_token_id
483
+ tokenizer.pad_token = tokenizer.eos_token
484
+ print("Set PAD token to EOS token.")
485
+
486
+ if model is not None:
487
+ model.config.pad_token_id = tokenizer.pad_token_id
488
+ if model.config.pad_token_id != tokenizer.pad_token_id:
489
+ raise ValueError("Failed to sync pad_token_id between tokenizer and model config")
490
+
491
+ # local MetaMathQA-40K
492
+ raw_datasets = load_dataset("json", data_files=mainCfg.data.path, split=mainCfg.data.dataset_split)
493
+ #raw_train_datasets = load_dataset("MetaMathQA-40K", split=mainCfg.data.dataset_split)
494
+ # print('raw', type(raw_train_datasets), len(raw_train_datasets))
495
+
496
+ # split a single set
497
+ # split_ratio = mainCfg.data.split_ratio
498
+ # split_data = raw_datasets.train_test_split(test_size=split_ratio, seed=42)
499
+ # raw_train_datasets = split_data['train']
500
+ # raw_valid_datasets = split_data['test']
501
+
502
+ train_dataset = raw_datasets.map(
503
+ train_tokenize_function,
504
+ batched=True,
505
+ batch_size=30000,
506
+ num_proc=32,
507
+ remove_columns=raw_datasets.column_names,
508
+ load_from_cache_file=True,
509
+ desc="Running tokenizer on train dataset",
510
+ fn_kwargs={"tokenizer": tokenizer, "query": mainCfg.data.dataset_field[0],
511
+ "response": mainCfg.data.dataset_field[1]}
512
+ )
513
+
514
+ # valid_dataset = raw_valid_datasets.map(
515
+ # train_tokenize_function,
516
+ # batched=True,
517
+ # batch_size=30000,
518
+ # num_proc=32,
519
+ # remove_columns=raw_train_datasets.column_names,
520
+ # load_from_cache_file=True,
521
+ # desc="Running tokenizer on train dataset",
522
+ # fn_kwargs={"tokenizer": tokenizer, "query": mainCfg.data.dataset_field[0],
523
+ # "response": mainCfg.data.dataset_field[1]}
524
+ # )
525
+ print('- dataset size: ', len(train_dataset))
526
+
527
+
528
+ # print('dataset', type(train_dataset))
529
+ # print('process', len(train_dataset))
530
+ # print(f"Sample features: {train_dataset.column_names}, {train_dataset.num_rows}")
531
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=mainCfg.model.model_max_seq_length,
532
+ #mode=mainCfg.model.data_collator_mode,
533
+ )
534
+ data_module = dict(train_dataset=train_dataset, data_collator=data_collator)
535
+
536
+ optimizer = optim.AdamW(
537
+ rotation_layers,
538
+ lr=mainCfg.trainer_args.learning_rate, #
539
+ eps=1e-8
540
+ )
541
+ # print('model x', model)
542
+ start_time = datetime.now()
543
+ print('start time: ', start_time.strftime("%Y-%m-%d %H:%M:%S"))
544
+ trainer = MyTrainer(model=model, processing_class=tokenizer,
545
+ lamda=mainCfg.model.lambda_reg,
546
+ optimizers=(optimizer, None),
547
+ args=training_args, **data_module)
548
+ model.config.use_cache = False
549
+
550
+
551
+ # now = time.time()
552
+ # for i in range(20):
553
+ # next(iter(trainer.get_train_dataloader()))
554
+ # print('time', time.time()-now)
555
+ # now = time.time()
556
+
557
+
558
+ # dl = trainer.get_train_dataloader()
559
+ # t0 = time.time()
560
+ # for i, batch in enumerate(dl):
561
+ # if i==20: break
562
+ # print("time / 20 batches =", time.time() - t0)
563
+ # exit()
564
+
565
+
566
+ # model2 = model.merge_and_unload()
567
+ # results2 = trainer2.evaluate()
568
+ # print('results2: ', results2)
569
+ # exit()
570
+
571
+ trainer.train()
572
+
573
+ end_time = datetime.now()
574
+ print('end time: ', end_time.strftime("%Y-%m-%d %H:%M:%S"), '| duration: ', end_time - start_time)
575
+ # Save Model (Includes Adapter weights & Config)
576
+ # trainer.save_model(os.path.join(training_args.output_dir, 'ft'))
577
+ # Save Tokenizer
578
+ tokenizer.save_pretrained(os.path.join(training_args.output_dir, 'ft'))
579
+ # Save Training State (Metrics & Logs)
580
+ trainer.save_state()
581
+
582
+ # save peft_config. Or model.base_model.peft_config['default']
583
+ model.peft_config.save_pretrained(os.path.join(training_args.output_dir, 'ft'))
584
+
585
+ # the easiest way
586
+ model.save_pretrained(os.path.join(training_args.output_dir, 'ft2'))
587
+ return
588
+
589
+
590
+
591
+ class MyTrainer(Trainer):
592
+
593
+ def __init__(
594
+ self,
595
+ model: Union[PreTrainedModel, nn.Module] = None,
596
+ args: TrainingArguments = None,
597
+ data_collator: Optional[DataCollator] = None,
598
+ train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
599
+ eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None,
600
+ processing_class: Optional[PreTrainedTokenizerBase] = None,
601
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
602
+ compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
603
+ callbacks: Optional[List[TrainerCallback]] = None,
604
+ optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
605
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
606
+ #run_name: Optional[str] = None,
607
+ #report_to: Optional[Union[str, list[str]]] = None,
608
+ # project
609
+ lamda: float = 1e-4
610
+ ):
611
+ super().__init__(model=model, args=args, data_collator=data_collator,
612
+ train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class,
613
+ model_init=model_init, compute_metrics=compute_metrics, callbacks=callbacks,
614
+ optimizers=optimizers, preprocess_logits_for_metrics=preprocess_logits_for_metrics,
615
+ #run_name=run_name, report_to=report_to
616
+ )
617
+ self.lamda = lamda
618
+
619
+ # def compute_loss(self, model, inputs, return_outputs=False,
620
+ # num_items_in_batch: Optional[torch.Tensor] = None,):
621
+ # """
622
+ # How the loss is computed by Trainer. By default, all models return the loss in the first element.
623
+
624
+ # Subclass and override for custom behavior.
625
+ # """
626
+ # if self.label_smoother is not None and "labels" in inputs:
627
+ # labels = inputs.pop("labels")
628
+ # else:
629
+ # labels = None
630
+ # if self.model_accepts_loss_kwargs:
631
+ # kwargs = {}
632
+ # if num_items_in_batch is not None:
633
+ # kwargs["num_items_in_batch"] = num_items_in_batch
634
+ # inputs = {**inputs, **kwargs}
635
+ # outputs = model(**inputs)
636
+ # # Save past state if it exists
637
+ # # TODO: this needs to be fixed and made cleaner later.
638
+ # if self.args.past_index >= 0:
639
+ # self._past = outputs[self.args.past_index]
640
+
641
+ # if labels is not None:
642
+ # unwrapped_model = unwrap_model(model)
643
+ # if _is_peft_model(unwrapped_model):
644
+ # model_name = unwrapped_model.base_model.model._get_name()
645
+ # else:
646
+ # model_name = unwrapped_model._get_name()
647
+ # if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
648
+ # loss = self.label_smoother(outputs, labels, shift_labels=True)
649
+ # else:
650
+ # loss = self.label_smoother(outputs, labels)
651
+ # else:
652
+ # if isinstance(outputs, dict) and "loss" not in outputs:
653
+ # raise ValueError(
654
+ # "The model did not return a loss from the inputs, only the following keys: "
655
+ # f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
656
+ # )
657
+ # # We don't use .loss here since the model may return tuples instead of ModelOutput.
658
+ # loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
659
+ # # ------------------------------------------------------------------------------
660
+
661
+ # # for name, param in model.named_parameters():
662
+ # # if 'oft_r' in name:
663
+ # # device = param.device
664
+ # # householder_U_norm = param / param.norm(dim=0)
665
+ # # orth_loss = torch.norm(
666
+ # # torch.eye(householder_U_norm.size(1), device=device) - householder_U_norm.t() @ householder_U_norm)
667
+ # # print(self.lamda)
668
+ # # loss = loss + self.lamda * orth_loss.to(loss.device)
669
+
670
+ # # ------------------------------------------------------------------------------
671
+
672
+ # return (loss, outputs) if return_outputs else loss
673
+
674
+ def get_train_dataloader(self):
675
+ # get dataset & sampler from super
676
+ train_dataset = self.train_dataset
677
+ sampler = self._get_train_sampler()
678
+
679
+ # compute effective batch size per step (HF has some routines; we use per_device_train_batch_size)
680
+ batch_size = self.args.train_batch_size if hasattr(self.args, "train_batch_size") else self.args.per_device_train_batch_size
681
+
682
+ # recommended num_workers: start moderate (16), you can tune upward
683
+ num_workers = getattr(self.args, "dataloader_num_workers", 16)
684
+ pin_memory = getattr(self.args, "dataloader_pin_memory", True)
685
+ prefetch_factor = getattr(self.args, "dataloader_prefetch_factor", 2)
686
+ persistent_workers = getattr(self.args, "dataloader_persistent_workers", True)
687
+
688
+ return DataLoader(
689
+ train_dataset,
690
+ batch_size=batch_size,
691
+ sampler=sampler,
692
+ collate_fn=self.data_collator,
693
+ drop_last=self.args.dataloader_drop_last if hasattr(self.args, "dataloader_drop_last") else False,
694
+ num_workers=num_workers,
695
+ pin_memory=pin_memory,
696
+ persistent_workers=persistent_workers,
697
+ prefetch_factor=prefetch_factor,
698
+ worker_init_fn=default_worker_init_fn,
699
+ )
700
+
701
+ if __name__ == "__main__":
702
+ main()
nl_tasks/src/ft_mathR.py ADDED
@@ -0,0 +1,689 @@
1
+ #
2
+ import sys
3
+ #print('sys.path: ___ ', sys.path)
4
+ #print(f"Current Python Executable: {sys.executable}")
5
+
6
+ ### dynamo warning
7
+ import warnings
8
+
9
+ # Ignore FutureWarning: prims_common.check, Online Softmax
10
+ warnings.filterwarnings("ignore", category=FutureWarning, module='torch._inductor.lowering')
11
+ warnings.filterwarnings("ignore", message=".*Online softmax is disabled on the fly.*", category=UserWarning)
12
+
13
+ warnings.filterwarnings("ignore", message=".*Our suggested max number of worker in current system is 1.*", category=UserWarning)
14
+ warnings.filterwarnings("ignore", message=".*will be initialized from a multivariate normal distribution.*")
15
+ warnings.filterwarnings("ignore", message=".*that differ from the model config and generation config.*", category=UserWarning)
16
+ warnings.filterwarnings("ignore", message=".*torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch..*", category=UserWarning)
17
+
18
+ import torch
19
+ torch.backends.cuda.matmul.fp32_precision = 'tf32'
20
+ # import wandb
21
+ import os
22
+ torch.set_num_threads(1)
23
+ os.environ["OMP_NUM_THREADS"]="1"
24
+ os.environ["MKL_NUM_THREADS"]="1"
25
+ import torch
26
+ print(f"PyTorch version: {torch.__version__}")
27
+ print(f"CUDA available: {torch.cuda.is_available()}")
28
+ print(f"PyTorch built with CUDA version: {torch.version.cuda}")
29
+
30
+ import yaml
31
+ #from peft import LoraConfig, get_peft_model_state_dict
32
+ from torch.utils.data import DataLoader
33
+ import time
34
+ from datetime import datetime
35
+ import math
36
+
37
+ from typing import List, Tuple
38
+
39
+ # import prodigyopt
40
+
41
+
42
+ ###
43
+ import copy
44
+ from dataclasses import field, dataclass, asdict
45
+ from typing import Sequence, Literal, Dict
46
+
47
+ import transformers
48
+ from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
49
+ from transformers import Trainer
50
+ from transformers.modeling_utils import *
51
+ from transformers.trainer import _is_peft_model
52
+ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
53
+ from transformers.data.data_collator import DataCollator
54
+
55
+ from transformers.training_args import TrainingArguments
56
+ from transformers.tokenization_utils_base import PreTrainedTokenizerBase
57
+ from transformers.trainer_callback import TrainerCallback
58
+ from transformers.trainer_utils import EvalPrediction
59
+ from torch.utils.data import Dataset, IterableDataset
60
+ from datasets import load_dataset
61
+ ##
62
+ #from ..pipeline.flux_omini import transformer_forward, encode_images
63
+ # from ...omini.rotation import RotationTuner, RotationConfig
64
+
65
+
66
+ from rpeft.rotation import RotationTuner, RotationConfig
67
+ from rpeft import get_peft_model, PeftModel
68
+ from .config import MainConfig, convert_to_trainer_args
69
+ import pyrallis
70
+ from omegaconf import OmegaConf
71
+ import torch.optim as optim
72
+ import wandb
73
+ from torch.nn.utils.rnn import pad_sequence
74
+
75
+ IGNORE_INDEX = -100
76
+ PROMPT = (
77
+ "Below is an instruction that describes a task. "
78
+ "Write a response that appropriately completes the request.\n\n"
79
+ "### Instruction:\n{instruction}\n\n### Response:"
80
+ )
81
+
82
+ def get_rank():
83
+ try:
84
+ rank = int(os.environ.get("LOCAL_RANK"))
85
+ except:
86
+ rank = 0
87
+ return rank
88
+
89
+
90
+ def get_config():
91
+ config_path = os.environ.get("OMINI_CONFIG")
92
+ assert config_path is not None, "Please set the OMINI_CONFIG environment variable"
93
+ with open(config_path, "r") as f:
94
+ config = yaml.safe_load(f)
95
+ return config
96
+
97
+
98
+ def init_wandb(wandb_config, run_name):
99
+ import wandb
100
+
101
+ try:
102
+ assert os.environ.get("WANDB_API_KEY") is not None
103
+ wandb.init(
104
+ project=wandb_config["project"],
105
+ name=run_name,
106
+ config={},
107
+ )
108
+ except Exception as e:
109
+ print("Failed to initialize WanDB:", e)
110
+
111
+
112
+
113
+ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
114
+ """Collects the state dict and dump to disk."""
115
+ state_dict = trainer.model.state_dict()
116
+ if trainer.args.should_save:
117
+ cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
118
+ del state_dict
119
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
120
+
121
+
122
+ def smart_tokenizer_and_embedding_resize(
123
+ special_tokens_dict: Dict,
124
+ tokenizer: transformers.PreTrainedTokenizer,
125
+ model: transformers.PreTrainedModel,
126
+ ):
127
+ """Resize tokenizer and embedding.
128
+
129
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
130
+ """
131
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
132
+ model.resize_token_embeddings(len(tokenizer))
133
+
134
+ if num_new_tokens > 0:
135
+ input_embeddings = model.get_input_embeddings().weight.data
136
+ output_embeddings = model.get_output_embeddings().weight.data
137
+
138
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
139
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
140
+
141
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
142
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
143
+
144
+
145
+ def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
146
+ """Tokenize a list of strings."""
147
+ tokenized_list = [
148
+ tokenizer(
149
+ text,
150
+ return_tensors="pt",
151
+ padding="longest",
152
+ max_length=tokenizer.model_max_length,
153
+ truncation=True,
154
+ )
155
+ for text in strings
156
+ ]
157
+ input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
158
+ input_ids_lens = labels_lens = [
159
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
160
+ ]
161
+ return dict(
162
+ input_ids=input_ids,
163
+ labels=labels,
164
+ input_ids_lens=input_ids_lens,
165
+ labels_lens=labels_lens,
166
+ )
167
+
168
+ def preprocess(
169
+ sources: Sequence[str],
170
+ targets: Sequence[str],
171
+ tokenizer: transformers.PreTrainedTokenizer,
172
+ ) -> Dict:
173
+ """Preprocess the data by tokenizing."""
174
+ examples = [s + t for s, t in zip(sources, targets)]
175
+ examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
176
+ input_ids = examples_tokenized["input_ids"]
177
+ labels = copy.deepcopy(input_ids)
178
+ for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
179
+ label[:source_len] = IGNORE_INDEX
180
+ return dict(input_ids=input_ids, labels=labels)
181
+
182
+
183
+ # @dataclass
184
+ # class DataCollatorForSupervisedDataset():
185
+ # """Collate examples for supervised fine-tuning."""
186
+
187
+ # tokenizer: transformers.PreTrainedTokenizer
188
+ # max_length: int = field(default=512)
189
+ # mode: str = field(default="fixed") # dynamic -> dynamo
190
+
191
+ # def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
192
+ # if self.mode == 'dynamic':
193
+ # input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
194
+ # input_ids = [torch.tensor(x) for x in input_ids]
195
+ # input_ids = torch.nn.utils.rnn.pad_sequence(
196
+ # input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
197
+ # )
198
+ # labels = [torch.tensor(x) for x in labels]
199
+ # labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
200
+ # return dict(
201
+ # input_ids=input_ids,
202
+ # labels=labels,
203
+ # attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
204
+ # )
205
+ # elif self.mode == 'fixed':
206
+ # input_ids = [torch.tensor(x["input_ids"][:self.max_length]) for x in instances]
207
+ # input_ids = torch.stack([
208
+ # torch.nn.functional.pad(x, (0, self.max_length - x.size(0)), value=self.tokenizer.pad_token_id)
209
+ # for x in input_ids
210
+ # ])
211
+
212
+ # # Labels
213
+ # labels = [torch.tensor(x["labels"][:self.max_length]) for x in instances]
214
+ # labels = torch.stack([
215
+ # torch.nn.functional.pad(x, (0, self.max_length - x.size(0)), value=IGNORE_INDEX)
216
+ # for x in labels
217
+ # ])
218
+
219
+ # return dict(
220
+ # input_ids=input_ids,
221
+ # labels=labels,
222
+ # attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
223
+ # )
224
+ # else:
225
+ # raise NotImplementedError
226
+
227
+ # @dataclass
228
+ # class DataCollatorForSupervisedDataset(object):
229
+ # tokenizer: transformers.PreTrainedTokenizer
230
+ # max_length: int = field(default=512)
231
+ # mode: str = field(default="fixed") # "dynamic" or "fixed"
232
+
233
+ # def _pad_to_length(self, tensors: Sequence[torch.Tensor], pad_value: int, target_len: int):
234
+ # """Pad a list of 1D tensors to target_len (int) and stack -> (B, target_len)."""
235
+ # batch_size = len(tensors)
236
+ # out = torch.full((batch_size, target_len), pad_value, dtype=tensors[0].dtype)
237
+ # for i, t in enumerate(tensors):
238
+ # L = min(t.size(0), target_len)
239
+ # out[i, :L] = t[:L]
240
+ # return out
241
+
242
+ # def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
243
+ # # Collect raw sequences (lists or tensors)
244
+ # input_seqs = [torch.tensor(x["input_ids"], dtype=torch.long) for x in instances]
245
+ # label_seqs = [torch.tensor(x["labels"], dtype=torch.long) for x in instances]
246
+
247
+ # if self.mode == "dynamic":
248
+ # # pad to the max length present in this batch (<= self.max_length)
249
+ # batch_max_len = min(max([s.size(0) for s in input_seqs]), self.max_length)
250
+ # input_ids = self._pad_to_length(input_seqs, pad_value=self.tokenizer.pad_token_id, target_len=batch_max_len)
251
+ # labels = self._pad_to_length(label_seqs, pad_value=IGNORE_INDEX, target_len=batch_max_len)
252
+ # elif self.mode == "fixed":
253
+ # # always pad/truncate to self.max_length
254
+ # input_ids = self._pad_to_length(input_seqs, pad_value=self.tokenizer.pad_token_id, target_len=self.max_length)
255
+ # labels = self._pad_to_length(label_seqs, pad_value=IGNORE_INDEX, target_len=self.max_length)
256
+ # else:
257
+ # raise NotImplementedError(f"Unknown mode: {self.mode}")
258
+
259
+ # attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
260
+
261
+ # return {
262
+ # "input_ids": input_ids,
263
+ # "labels": labels,
264
+ # "attention_mask": attention_mask
265
+ # }
266
+
267
+ @dataclass
268
+ class DataCollatorForSupervisedDataset():
269
+ tokenizer: transformers.PreTrainedTokenizer
270
+ max_length: int = field(default=512)
271
+ mode: str = field(default="fixed") # "dynamic" or "fixed"
272
+
273
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
274
+ # Extract inputs and labels
275
+ # Assuming instances is a list of dicts like {'input_ids': [...], 'labels': [...]}
276
+ input_ids_list = [torch.tensor(x["input_ids"], dtype=torch.long) for x in instances]
277
+ labels_list = [torch.tensor(x["labels"], dtype=torch.long) for x in instances]
278
+
279
+ # 1. Determine padding logic
280
+ if self.mode == "dynamic":
281
+ # Dynamic padding: pad to the longest sequence in the batch
282
+ # But cap it at self.max_length to prevent OOM
283
+ batch_max_len = max([len(x) for x in input_ids_list])
284
+ target_len = min(batch_max_len, self.max_length)
285
+ else:
286
+ # Fixed padding: always pad to max_length
287
+ target_len = self.max_length
288
+
289
+ # 2. Helper to pad and truncate
290
+ def pad_and_truncate(tensors, padding_value):
291
+ # First, pad everything using PyTorch's optimized utility (batch_first=True)
292
+ padded = pad_sequence(tensors, batch_first=True, padding_value=padding_value)
293
+
294
+ # Handle truncation/extending to exact target_len
295
+ curr_len = padded.shape[1]
296
+ if curr_len > target_len:
297
+ # Truncate if too long (rare if filtered beforehand)
298
+ return padded[:, :target_len]
299
+ elif curr_len < target_len:
300
+ # Pad more if shorter than target_len (happens in fixed mode)
301
+ diff = target_len - curr_len
302
+ padding = torch.full((padded.shape[0], diff), padding_value, dtype=padded.dtype)
303
+ return torch.cat([padded, padding], dim=1)
304
+ else:
305
+ return padded
306
+
307
+ # 3. Apply padding
308
+ # Critical: tokenizer.pad_token_id must NOT be None here
309
+ if self.tokenizer.pad_token_id is None:
310
+ raise ValueError("Tokenizer.pad_token_id is None. Please set it to eos_token_id or unk_token_id.")
311
+
312
+ input_ids = pad_and_truncate(input_ids_list, self.tokenizer.pad_token_id)
313
+ labels = pad_and_truncate(labels_list, IGNORE_INDEX)
314
+
315
+ # 4. Create Attention Mask explicitly
316
+ # .ne() creates Bools, .long() casts to 0s and 1s for compatibility
317
+ attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
318
+
319
+ return {
320
+ "input_ids": input_ids,
321
+ "labels": labels,
322
+ "attention_mask": attention_mask
323
+ }
324
+
325
+ def train_tokenize_function(examples, tokenizer, query, response):
326
+ sources = [PROMPT.format_map(dict(instruction=instruction)) for instruction in examples[query]]
327
+ targets = [f"{output}{tokenizer.eos_token}" for output in examples[response]]
328
+ data_dict = preprocess(sources, targets, tokenizer)
329
+ return data_dict
330
+
331
+
332
+
333
+ ### Trainer
334
+ def default_worker_init_fn(worker_id):
335
+ # limit each worker to a single BLAS thread
336
+ try:
337
+ import numpy as _np
338
+ except Exception:
339
+ _np = None
340
+ torch.set_num_threads(1)
341
+ os.environ.setdefault("OMP_NUM_THREADS", "1")
342
+ os.environ.setdefault("MKL_NUM_THREADS", "1")
343
+ os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
344
+ # Optional: bind CPU affinity per worker to avoid contention (NUMA-aware)
345
+ try:
346
+ cpu_count = os.cpu_count() or 1
347
+ # distribute CPU cores evenly across workers
348
+ num_workers = getattr(torch.utils.data, "_num_workers", None)
349
+ # fallback: if not available, compute from environment variable or pass externally
350
+ # We'll do a simple round-robin assignment using worker_id
351
+ # assign a small mask of cores to this worker (e.g., chunk size 4)
352
+ chunk = max(1, cpu_count // max(1, min(64, cpu_count)))
353
+ start = (worker_id * chunk) % cpu_count
354
+ end = start + chunk
355
+ mask = set(range(start, min(end, cpu_count)))
356
+ try:
357
+ os.sched_setaffinity(0, mask)
358
+ except Exception:
359
+ pass
360
+ except Exception:
361
+ pass
362
+
363
+ def set_seed(seed: int):
364
+ # random.seed(seed)
365
+ # np.random.seed(seed)
366
+ torch.manual_seed(seed)
367
+ torch.cuda.manual_seed_all(seed)
368
+ transformers.set_seed(seed)
369
+
370
+
371
+ @pyrallis.wrap()
372
+ def main(mainCfg: MainConfig):
373
+ #mainCfg = get_config()
374
+ #print(mainCfg)
375
+ print('='*120)
376
+ # print(OmegaConf.to_yaml(mainCfg))
377
+ # print('-'*40)
378
+ #
379
+ # print((training_args))
380
+ set_seed(mainCfg.seed)
381
+ training_args = convert_to_trainer_args(mainCfg)
382
+
383
+ # wandb
384
+ ENTITY = "nvan-13-korea-university"
385
+ PROJECT = os.environ.get("WANDB_PROJECT")
386
+ api = wandb.Api()
387
+ try:
388
+ runs_list = api.runs(f"{ENTITY}/{PROJECT}")
389
+ next_run_num = len(runs_list) + 1
390
+ except Exception as e:
391
+ next_run_num = 1
392
+
393
+ training_args.run_name = f'[{next_run_num}]lr={mainCfg.trainer_args.learning_rate:.1e},b={mainCfg.trainer_args.per_device_train_batch_size},'\
394
+ f'n={mainCfg.rotation_adapter_config.num_rotations},r={mainCfg.rotation_adapter_config.r},'\
395
+ f'init={mainCfg.run_text}'
396
+ # training_args.project = f'Rotation-Llama2-{mainCfg.data.dataset_name}'
397
+
398
+ # print('-'*40)
399
+ # print(training_args.to_json_string())
400
+ # exit()
401
+
402
+ model = AutoModelForCausalLM.from_pretrained(mainCfg.model.model_name,
403
+ device_map="auto", low_cpu_mem_usage=True,
404
+ dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
405
+ attn_implementation="sdpa",
406
+ )
407
+
408
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
409
+ print("DEVICE", DEVICE)
410
+ # for name, param in model.named_parameters():
411
+ # if 'q_proj' in name and 'layers.5' in name:
412
+ # print(f"Name: {name} | {param.shape} ")
413
+ # print(f"Name (pretrained): {name} | {param.shape} | {param.data[0:5,0:5]}")
414
+ # print('model', model)
415
+ # exit()
416
+
417
+ total_params_now = sum(p.numel() for p in model.parameters())
418
+ print(f'#params of the pretrained model, {total_params_now:,}')
419
+ # print(model)
420
+ if mainCfg.model.adapter_path is not None:
421
+ print('___ Loading from: ', mainCfg.model.adapter_path)
422
+ model = PeftModel.from_pretrained(model, mainCfg.model.adapter_path, is_trainable = True)
423
+ elif mainCfg.rotation_adapter_config.r is not None:
424
+ rotation_adapter_config = asdict(mainCfg.rotation_adapter_config)
425
+ # rotation_adapter_config[peft_type]
426
+
427
+ for adapter_name in mainCfg.data.adapter_names:
428
+ rotation_config = RotationConfig(**rotation_adapter_config)
429
+ model = get_peft_model(model, rotation_config, adapter_name=adapter_name)
430
+ # model.set_adapter(adapter_name)
431
+
432
+ else:
433
+ print("Full Parameter Fine-Tuning")
434
+ model = model.to(DEVICE)
435
+
436
+ # print('model', model)
437
+ model.print_trainable_parameters()
438
+ # print("Program starts")
439
+ # time.sleep(300)
440
+ # exit()
441
+
442
+ # for name, param in model.named_parameters():
443
+ # if 'q_proj' in name and 'rotation' in name and 'layers.5' in name:
444
+ # print(f"Name: {name} | {param.shape} ")
445
+ # print(f"Name (pretrained): {name} | {param.shape} ")
446
+ # X = param.data
447
+ # print('model', type(model), X.shape)
448
+ # visualize_value_distribution(X)
449
+ # exit()
450
+
451
+ rotation_layers = filter(
452
+ lambda p: p.requires_grad, model.parameters()
453
+ )
454
+
455
+ tokenizer = AutoTokenizer.from_pretrained(
456
+ mainCfg.model.model_name,
457
+ model_max_length=mainCfg.model.model_max_seq_length,
458
+ padding_side="right",
459
+ use_fast=True,
460
+ )
461
+
462
+ if tokenizer.pad_token is None:
463
+ if tokenizer.unk_token_id is not None:
464
+ tokenizer.pad_token_id = tokenizer.unk_token_id
465
+ tokenizer.pad_token = tokenizer.unk_token
466
+ print("Set PAD token to UNK token.")
467
+ elif tokenizer.eos_token_id is not None:
468
+ tokenizer.pad_token_id = tokenizer.eos_token_id
469
+ tokenizer.pad_token = tokenizer.eos_token
470
+ print("Set PAD token to EOS token.")
471
+
472
+ if model is not None:
473
+ model.config.pad_token_id = tokenizer.pad_token_id
474
+ if model.config.pad_token_id != tokenizer.pad_token_id:
475
+ raise ValueError("Failed to sync pad_token_id between tokenizer and model config")
476
+
477
+ # local MetaMathQA-40K
478
+ raw_datasets = load_dataset("json", data_files=mainCfg.data.path, split=mainCfg.data.dataset_split)
479
+ #raw_train_datasets = load_dataset("MetaMathQA-40K", split=mainCfg.data.dataset_split)
480
+ # print('raw', type(raw_train_datasets), len(raw_train_datasets))
481
+
482
+ # split a single set
483
+ split_ratio = mainCfg.data.split_ratio
484
+ split_data = raw_datasets.train_test_split(test_size=split_ratio, seed=42)
485
+ raw_train_datasets = split_data['train']
486
+ raw_valid_datasets = split_data['test']
487
+
488
+ train_dataset = raw_train_datasets.map(
489
+ train_tokenize_function,
490
+ batched=True,
491
+ batch_size=30000,
492
+ num_proc=32,
493
+ remove_columns=raw_train_datasets.column_names,
494
+ load_from_cache_file=True,
495
+ desc="Running tokenizer on train dataset",
496
+ fn_kwargs={"tokenizer": tokenizer, "query": mainCfg.data.dataset_field[0],
497
+ "response": mainCfg.data.dataset_field[1]}
498
+ )
499
+
500
+ valid_dataset = raw_valid_datasets.map(
501
+ train_tokenize_function,
502
+ batched=True,
503
+ batch_size=30000,
504
+ num_proc=32,
505
+ remove_columns=raw_train_datasets.column_names,
506
+ load_from_cache_file=True,
507
+ desc="Running tokenizer on train dataset",
508
+ fn_kwargs={"tokenizer": tokenizer, "query": mainCfg.data.dataset_field[0],
509
+ "response": mainCfg.data.dataset_field[1]}
510
+ )
511
+ print('- dataset size: ', len(valid_dataset), len(train_dataset))
512
+
513
+
514
+ # print('dataset', type(train_dataset))
515
+ # print('process', len(train_dataset))
516
+ # print(f"Sample features: {train_dataset.column_names}, {train_dataset.num_rows}")
517
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=mainCfg.model.model_max_seq_length,
518
+ #mode=mainCfg.model.data_collator_mode,
519
+ )
520
+ data_module = dict(train_dataset=train_dataset, data_collator=data_collator, eval_dataset=valid_dataset)
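+ # Passing eval_dataset lets the Trainer run periodic evaluation according to the evaluation settings in training_args.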
521
+
522
+ optimizer = optim.AdamW(
523
+ rotation_layers,
524
+ lr=mainCfg.trainer_args.learning_rate, #
525
+ eps=1e-8
526
+ )
527
+ # print('model x', model)
528
+ start_time = datetime.now()
529
+ print('start time: ', start_time.strftime("%Y-%m-%d %H:%M:%S"))
530
+ trainer = MyTrainer(model=model, processing_class=tokenizer,
531
+ lamda=mainCfg.model.lambda_reg,
532
+ optimizers=(optimizer, None),
533
+ args=training_args, **data_module)
534
+ model.config.use_cache = False
535
+
536
+
537
+ # now = time.time()
538
+ # for i in range(20):
539
+ # next(iter(trainer.get_train_dataloader()))
540
+ # print('time', time.time()-now)
541
+ # now = time.time()
542
+
543
+
544
+ # dl = trainer.get_train_dataloader()
545
+ # t0 = time.time()
546
+ # for i, batch in enumerate(dl):
547
+ # if i==20: break
548
+ # print("time / 20 batches =", time.time() - t0)
549
+ # exit()
550
+
551
+
552
+ # model2 = model.merge_and_unload()
553
+ # results2 = trainer2.evaluate()
554
+ # print('results2: ', results2)
555
+ # exit()
556
+
557
+ start_time = datetime.now()
558
+ trainer.train()
559
+
560
+ end_time = datetime.now()
561
+ print('end time: ', end_time.strftime("%Y-%m-%d %H:%M:%S"), '| duration: ', end_time - start_time)
562
+ # Save Model (Includes Adapter weights & Config)
563
+ # trainer.save_model(os.path.join(training_args.output_dir, 'ft'))
564
+ # Save Tokenizer
565
+ tokenizer.save_pretrained(os.path.join(training_args.output_dir, 'ft'))
566
+ # Save Training State (Metrics & Logs)
567
+ trainer.save_state()
568
+
569
+ # save peft_config. Or model.base_model.peft_config['default']
570
+ model.peft_config.save_pretrained(os.path.join(training_args.output_dir, 'ft'))
571
+
572
+ # the easiest way
573
+ model.save_pretrained(os.path.join(training_args.output_dir, 'ft2'))
574
+ return
575
+
576
+
577
+
578
+ class MyTrainer(Trainer):
579
+
580
+ def __init__(
581
+ self,
582
+ model: Union[PreTrainedModel, nn.Module] = None,
583
+ args: TrainingArguments = None,
584
+ data_collator: Optional[DataCollator] = None,
585
+ train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
586
+ eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None,
587
+ processing_class: Optional[PreTrainedTokenizerBase] = None,
588
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
589
+ compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
590
+ callbacks: Optional[List[TrainerCallback]] = None,
591
+ optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
592
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
593
+ #run_name: Optional[str] = None,
594
+ #report_to: Optional[Union[str, list[str]]] = None,
595
+ # project
596
+ lamda: float = 1e-4
597
+ ):
598
+ super().__init__(model=model, args=args, data_collator=data_collator,
599
+ train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class,
600
+ model_init=model_init, compute_metrics=compute_metrics, callbacks=callbacks,
601
+ optimizers=optimizers, preprocess_logits_for_metrics=preprocess_logits_for_metrics,
602
+ #run_name=run_name, report_to=report_to
603
+ )
604
+ self.lamda = lamda
605
+
606
+ # def compute_loss(self, model, inputs, return_outputs=False,
607
+ # num_items_in_batch: Optional[torch.Tensor] = None,):
608
+ # """
609
+ # How the loss is computed by Trainer. By default, all models return the loss in the first element.
610
+
611
+ # Subclass and override for custom behavior.
612
+ # """
613
+ # if self.label_smoother is not None and "labels" in inputs:
614
+ # labels = inputs.pop("labels")
615
+ # else:
616
+ # labels = None
617
+ # if self.model_accepts_loss_kwargs:
618
+ # kwargs = {}
619
+ # if num_items_in_batch is not None:
620
+ # kwargs["num_items_in_batch"] = num_items_in_batch
621
+ # inputs = {**inputs, **kwargs}
622
+ # outputs = model(**inputs)
623
+ # # Save past state if it exists
624
+ # # TODO: this needs to be fixed and made cleaner later.
625
+ # if self.args.past_index >= 0:
626
+ # self._past = outputs[self.args.past_index]
627
+
628
+ # if labels is not None:
629
+ # unwrapped_model = unwrap_model(model)
630
+ # if _is_peft_model(unwrapped_model):
631
+ # model_name = unwrapped_model.base_model.model._get_name()
632
+ # else:
633
+ # model_name = unwrapped_model._get_name()
634
+ # if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
635
+ # loss = self.label_smoother(outputs, labels, shift_labels=True)
636
+ # else:
637
+ # loss = self.label_smoother(outputs, labels)
638
+ # else:
639
+ # if isinstance(outputs, dict) and "loss" not in outputs:
640
+ # raise ValueError(
641
+ # "The model did not return a loss from the inputs, only the following keys: "
642
+ # f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
643
+ # )
644
+ # # We don't use .loss here since the model may return tuples instead of ModelOutput.
645
+ # loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
646
+ # # ------------------------------------------------------------------------------
647
+
648
+ # # for name, param in model.named_parameters():
649
+ # # if 'oft_r' in name:
650
+ # # device = param.device
651
+ # # householder_U_norm = param / param.norm(dim=0)
652
+ # # orth_loss = torch.norm(
653
+ # # torch.eye(householder_U_norm.size(1), device=device) - householder_U_norm.t() @ householder_U_norm)
654
+ # # print(self.lamda)
655
+ # # loss = loss + self.lamda * orth_loss.to(loss.device)
656
+
657
+ # # ------------------------------------------------------------------------------
658
+
659
+ # return (loss, outputs) if return_outputs else loss
660
+
661
+ def get_train_dataloader(self):
662
+ # get dataset & sampler from super
663
+ train_dataset = self.train_dataset
664
+ sampler = self._get_train_sampler()
665
+
666
+ # compute effective batch size per step (HF has some routines; we use per_device_train_batch_size)
667
+ batch_size = self.args.train_batch_size if hasattr(self.args, "train_batch_size") else self.args.per_device_train_batch_size
668
+
669
+ # recommended num_workers: start moderate (16), you can tune upward
670
+ num_workers = getattr(self.args, "dataloader_num_workers", 16)
671
+ pin_memory = getattr(self.args, "dataloader_pin_memory", True)
672
+ prefetch_factor = getattr(self.args, "dataloader_prefetch_factor", 2)
673
+ persistent_workers = getattr(self.args, "dataloader_persistent_workers", True)
674
+
675
+ return DataLoader(
676
+ train_dataset,
677
+ batch_size=batch_size,
678
+ sampler=sampler,
679
+ collate_fn=self.data_collator,
680
+ drop_last=self.args.dataloader_drop_last if hasattr(self.args, "dataloader_drop_last") else False,
681
+ num_workers=num_workers,
682
+ pin_memory=pin_memory,
683
+ persistent_workers=persistent_workers,
684
+ prefetch_factor=prefetch_factor,
685
+ worker_init_fn=default_worker_init_fn,
686
+ )
687
+
688
+ if __name__ == "__main__":
689
+ main()
nl_tasks/src/merge.py ADDED
@@ -0,0 +1,82 @@
1
+ import torch
2
+ # import wandb
3
+ import os
4
+ import yaml
5
+ from peft import LoraConfig, get_peft_model_state_dict
6
+ from torch.utils.data import DataLoader
7
+ import time
8
+
9
+ from typing import List, Tuple
10
+
11
+ # import prodigyopt
12
+
13
+
14
+ ###
15
+ import copy
16
+ from dataclasses import field, dataclass, asdict
17
+ from typing import Sequence, Literal, Dict
18
+
19
+ import transformers
20
+ from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
21
+ from transformers import Trainer
22
+ from transformers.modeling_utils import *
23
+ from transformers.trainer import _is_peft_model
24
+ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
25
+ from transformers.data.data_collator import DataCollator
26
+
27
+ from transformers.training_args import TrainingArguments
28
+ from transformers.tokenization_utils_base import PreTrainedTokenizerBase
29
+ from transformers.trainer_callback import TrainerCallback
30
+ from transformers.trainer_utils import EvalPrediction
31
+ from torch.utils.data import Dataset, IterableDataset
32
+ from datasets import load_dataset
33
+ ##
34
+ #from ..pipeline.flux_omini import transformer_forward, encode_images
35
+ # from ...omini.rotation import RotationTuner, RotationConfig
36
+ from rpeft.rotation import RotationTuner, RotationConfig
37
+ from rpeft import get_peft_model, PeftModel
38
+ from .config import MainConfig, convert_to_trainer_args
39
+ import pyrallis
40
+ from omegaconf import OmegaConf
41
+
42
+ import argparse
43
+ IGNORE_INDEX = -100
44
+ DEFAULT_PAD_TOKEN = "[PAD]"
45
+ DEFAULT_EOS_TOKEN = "</s>"
46
+ DEFAULT_BOS_TOKEN = "</s>"
47
+ DEFAULT_UNK_TOKEN = "</s>"
48
+ PROMPT = (
49
+ "Below is an instruction that describes a task. "
50
+ "Write a response that appropriately completes the request.\n\n"
51
+ "### Instruction:\n{instruction}\n\n### Response:"
52
+ )
53
+
54
+ # parser = argparse.ArgumentParser(description='Merge Adapter to Base Model')
55
+ # parser.add_argument('--base_mode', type=str)
56
+ # parser.add_argument('--adapter_path', type=str)
57
+ # parser.add_argument('--output_path', type=str)
58
+ # args = parser.parse_args()
59
+
60
+ @pyrallis.wrap()
61
+ def main(mainCfg: MainConfig):
62
+ print('='*120)
63
+ model_name = mainCfg.model.model_name
64
+ # adapter = mainCfg.trainer_args.output_dir + '/ft2'
65
+ # output_path = mainCfg.trainer_args.output_dir + '/merge/'
66
+ adapter = mainCfg.model.merge_adapter_path
67
+ output_path = mainCfg.model.merge_output_path
68
+
69
+ model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto",)
70
+ tokenizer = AutoTokenizer.from_pretrained(model_name, device_map='auto')
71
+
72
+ # config = PeftConfig.from_pretrained(args.adapter)
73
+ model = PeftModel.from_pretrained(model, adapter)
74
+ model = model.merge_and_unload()
75
+ model.save_pretrained(output_path, safe_serialization=False)
76
+ tokenizer.save_pretrained(output_path)
77
+ # print(model)
78
+ print('merge.py ends', adapter, output_path)
79
+ return
80
+
81
+ if __name__ == "__main__":
82
+ main()
nl_tasks/src/peft_merge.py ADDED
@@ -0,0 +1,82 @@
1
+ import torch
2
+ # import wandb
3
+ import os
4
+ import yaml
5
+ from peft import LoraConfig, get_peft_model_state_dict
6
+ from torch.utils.data import DataLoader
7
+ import time
8
+
9
+ from typing import List, Tuple
10
+
11
+ # import prodigyopt
12
+
13
+
14
+ ###
15
+ import copy
16
+ from dataclasses import field, dataclass, asdict
17
+ from typing import Sequence, Literal, Dict
18
+
19
+ import transformers
20
+ from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
21
+ from transformers import Trainer
22
+ from transformers.modeling_utils import *
23
+ from transformers.trainer import _is_peft_model
24
+ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
25
+ from transformers.data.data_collator import DataCollator
26
+
27
+ from transformers.training_args import TrainingArguments
28
+ from transformers.tokenization_utils_base import PreTrainedTokenizerBase
29
+ from transformers.trainer_callback import TrainerCallback
30
+ from transformers.trainer_utils import EvalPrediction
31
+ from torch.utils.data import Dataset, IterableDataset
32
+ from datasets import load_dataset
33
+ ##
34
+ #from ..pipeline.flux_omini import transformer_forward, encode_images
35
+ # from ...omini.rotation import RotationTuner, RotationConfig
36
+ from rpeft.rotation import RotationTuner, RotationConfig
37
+ from peft import get_peft_model, PeftModel
38
+ from .config import MainConfig, convert_to_trainer_args
39
+ import pyrallis
40
+ from omegaconf import OmegaConf
41
+
42
+ import argparse
43
+ IGNORE_INDEX = -100
44
+ DEFAULT_PAD_TOKEN = "[PAD]"
45
+ DEFAULT_EOS_TOKEN = "</s>"
46
+ DEFAULT_BOS_TOKEN = "</s>"
47
+ DEFAULT_UNK_TOKEN = "</s>"
48
+ PROMPT = (
49
+ "Below is an instruction that describes a task. "
50
+ "Write a response that appropriately completes the request.\n\n"
51
+ "### Instruction:\n{instruction}\n\n### Response:"
52
+ )
53
+
54
+ # parser = argparse.ArgumentParser(description='Merge Adapter to Base Model')
55
+ # parser.add_argument('--base_mode', type=str)
56
+ # parser.add_argument('--adapter_path', type=str)
57
+ # parser.add_argument('--output_path', type=str)
58
+ # args = parser.parse_args()
59
+
60
+ @pyrallis.wrap()
61
+ def main(mainCfg: MainConfig):
62
+ print('='*120)
63
+ model_name = mainCfg.model.model_name
64
+ # adapter = mainCfg.trainer_args.output_dir + '/ft2'
65
+ # output_path = mainCfg.trainer_args.output_dir + '/merge/'
66
+ adapter = mainCfg.model.merge_adapter_path
67
+ output_path = mainCfg.model.merge_output_path
68
+
69
+ model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto",)
70
+ tokenizer = AutoTokenizer.from_pretrained(model_name)  # tokenizers are device-agnostic; device_map is not a tokenizer argument
71
+
72
+ # config = PeftConfig.from_pretrained(args.adapter)
73
+ model = PeftModel.from_pretrained(model, adapter)
74
+ model = model.merge_and_unload()
75
+ model.save_pretrained(output_path, safe_serialization=False)
76
+ tokenizer.save_pretrained(output_path)
77
+ # print(model)
78
+ print('peft_merge.py ends', adapter, output_path)
79
+ return
80
+
81
+ if __name__ == "__main__":
82
+ main()
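A quick sanity check for merge_and_unload() is sketched below: compare logits from the adapter-wrapped model against the merged checkpoint on the same input. All three paths are placeholders (they correspond to mainCfg.model.model_name, merge_adapter_path, and merge_output_path), and the check is meant to be run once after merging, not as part of this script.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_name = "BASE_MODEL_NAME_OR_PATH"  # placeholder for mainCfg.model.model_name
adapter_dir = "ADAPTER_DIR"            # placeholder for merge_adapter_path
merged_dir = "MERGED_DIR"              # placeholder for merge_output_path

tok = AutoTokenizer.from_pretrained(base_name)
x = tok("### Instruction:\nSay hello.\n\n### Response:", return_tensors="pt")

with torch.no_grad():
    adapted = PeftModel.from_pretrained(AutoModelForCausalLM.from_pretrained(base_name), adapter_dir)
    ref = adapted(**x).logits
    merged = AutoModelForCausalLM.from_pretrained(merged_dir)
    out = merged(**x).logits

# Expect a value near zero (up to floating-point and serialization rounding).
print("max |logit diff|:", (ref - out).abs().max().item())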
nl_tasks/src/testLlama.py ADDED
@@ -0,0 +1,702 @@
1
+ #
2
+ import sys
3
+ #print('sys.path: ___ ', sys.path)
4
+ #print(f"Current Python Executable: {sys.executable}")
5
+
6
+ ### dynamo warning
7
+ import warnings
8
+
9
+ # Ignore FutureWarning: prims_common.check, Online Softmax
10
+ warnings.filterwarnings("ignore", category=FutureWarning, module='torch._inductor.lowering')
11
+ warnings.filterwarnings("ignore", message=".*Online softmax is disabled on the fly.*", category=UserWarning)
12
+
13
+ warnings.filterwarnings("ignore", message=".*Our suggested max number of worker in current system is 1.*", category=UserWarning)
14
+ warnings.filterwarnings("ignore", message=".*will be initialized from a multivariate normal distribution.*")
15
+ warnings.filterwarnings("ignore", message=".*that differ from the model config and generation config.*", category=UserWarning)
16
+ warnings.filterwarnings("ignore", message=".*torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch..*", category=UserWarning)
17
+
18
+ import torch
19
+ torch.backends.cuda.matmul.fp32_precision = 'tf32'
20
+ # import wandb
21
+ import os
+ import json  # needed by ExperimentMonitorCallback._save_log
22
+ torch.set_num_threads(1)
23
+ os.environ["OMP_NUM_THREADS"]="1"
24
+ os.environ["MKL_NUM_THREADS"]="1"
25
+ import torch
26
+ print(f"PyTorch version: {torch.__version__}")
27
+ print(f"CUDA available: {torch.cuda.is_available()}")
28
+ print(f"PyTorch built with CUDA version: {torch.version.cuda}")
29
+
30
+ import yaml
31
+ #from peft import LoraConfig, get_peft_model_state_dict
32
+ from torch.utils.data import DataLoader
33
+ import time
34
+ from datetime import datetime
35
+ import math
36
+
37
+ from typing import List, Tuple
38
+
39
+ # import prodigyopt
40
+
41
+
42
+ ###
43
+ import copy
44
+ from dataclasses import field, dataclass, asdict
45
+ from typing import Sequence, Literal, Dict
46
+
47
+ import transformers
48
+ from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
49
+ from transformers import Trainer
50
+ from transformers.modeling_utils import *
51
+ from transformers.trainer import _is_peft_model
52
+ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
53
+ from transformers.data.data_collator import DataCollator
54
+
55
+ from transformers.training_args import TrainingArguments
56
+ from transformers.tokenization_utils_base import PreTrainedTokenizerBase
57
+ from transformers.trainer_callback import TrainerCallback
58
+ from transformers.trainer_utils import EvalPrediction
59
+ from torch.utils.data import Dataset, IterableDataset
60
+ from datasets import load_dataset
61
+ ##
62
+ #from ..pipeline.flux_omini import transformer_forward, encode_images
63
+ # from ...omini.rotation import RotationTuner, RotationConfig
64
+
65
+
66
+ from rpeft.rotation import RotationTuner, RotationConfig
67
+ from rpeft import get_peft_model, PeftModel
68
+ from .config import MainConfig, convert_to_trainer_args
69
+ import pyrallis
70
+ from omegaconf import OmegaConf
71
+ import torch.optim as optim
72
+ import wandb
73
+ from torch.nn.utils.rnn import pad_sequence
74
+
75
+ IGNORE_INDEX = -100
76
+ PROMPT = (
77
+ "Below is an instruction that describes a task. "
78
+ "Write a response that appropriately completes the request.\n\n"
79
+ "### Instruction:\n{instruction}\n\n### Response:"
80
+ )
81
+
82
+
83
+ import platform
84
+ from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
85
+ class ExperimentMonitorCallback(TrainerCallback):
86
+ """
87
+ Callback to monitor training performance and log system stats to a JSON file.
88
+ It captures:
89
+ 1. Experiment Metadata (GPU info, Batch size, Learning rate, etc.)
90
+ 2. Runtime Metrics (Avg time/step, Throughput)
91
+ 3. Memory Metrics (Allocated, Reserved, and Peak usage)
92
+ """
93
+
94
+ def __init__(self, log_file_path: str, run_name: str = "experiment", log_interval: int = 100):
95
+ # Store logging configuration
96
+ self.log_file_path = log_file_path
97
+ self.run_name = run_name
98
+ self.log_interval = log_interval
99
+
100
+ # Timing variables
101
+ self.start_time = None
102
+ self.last_log_time = None
103
+
104
+ # Data container to be saved
105
+ self.log_data = {
106
+ "metadata": {},
107
+ "metrics": []
108
+ }
109
+
110
+ def _get_gpu_info(self):
111
+ # Helper to get GPU details if available
112
+ if torch.cuda.is_available():
113
+ return {
114
+ "name": torch.cuda.get_device_name(0),
115
+ "count": torch.cuda.device_count(),
116
+ "capability": torch.cuda.get_device_capability(0)
117
+ }
118
+ return "CPU_ONLY"
119
+
120
+ def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
121
+ # Initialize timing
122
+ self.start_time = time.perf_counter()
123
+ self.last_log_time = self.start_time
124
+
125
+ # Reset peak memory stats to ensure we capture peaks specific to this run
126
+ if torch.cuda.is_available():
127
+ torch.cuda.reset_peak_memory_stats()
128
+
129
+ # Capture experiment metadata
130
+ self.log_data["metadata"] = {
131
+ "run_name": self.run_name,
132
+ "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
133
+ "python_version": platform.python_version(),
134
+ "pytorch_version": torch.__version__,
135
+ "gpu_info": self._get_gpu_info(),
136
+ "configuration": {
137
+ "batch_size_per_device": args.per_device_train_batch_size,
138
+ "learning_rate": args.learning_rate,
139
+ "max_steps": args.max_steps,
140
+ "num_train_epochs": args.num_train_epochs,
141
+ "fp16": args.fp16,
142
+ "bf16": args.bf16,
143
+ "optim": args.optim,
144
+ }
145
+ }
146
+
147
+ # Create/Overwrite the file with initial metadata
148
+ self._save_log()
149
+ # print(f"[{self.run_name}] Experiment started. Logging to {self.log_file_path}")
150
+
151
+ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
152
+ current_step = state.global_step
153
+
154
+ # Perform logging only at specified intervals
155
+ if current_step > 0 and current_step % self.log_interval == 0:
156
+ current_time = time.perf_counter()
157
+
158
+ # Calculate time elapsed since the last log
159
+ elapsed_since_last = current_time - self.last_log_time
160
+ avg_time_per_step = elapsed_since_last / self.log_interval
161
+
162
+ # Memory Statistics (in GB)
163
+ mem_stats = {}
164
+ if torch.cuda.is_available():
165
+ # Current usage
166
+ mem_stats["allocated_gb"] = torch.cuda.memory_allocated() / 1024**3
167
+ mem_stats["reserved_gb"] = torch.cuda.memory_reserved() / 1024**3
168
+ # Peak usage since start (Long-term peak)
169
+ mem_stats["peak_allocated_gb"] = torch.cuda.max_memory_allocated() / 1024**3
170
+
171
+ # Construct metric entry
172
+ metric_entry = {
173
+ "step": current_step,
174
+ "epoch": state.epoch,
175
+ "timestamp": datetime.now().isoformat(),
176
+ "performance": {
177
+ "avg_time_per_step_s": round(avg_time_per_step, 4),
178
+ "steps_per_second": round(1.0 / avg_time_per_step, 2)
179
+ },
180
+ "memory": mem_stats
181
+ }
182
+
183
+ # Append to internal list and save to file
184
+ self.log_data["metrics"].append(metric_entry)
185
+ self._save_log()
186
+
187
+ # Update last log time
188
+ self.last_log_time = current_time
189
+
190
+ # Optional: Print a brief summary to console
191
+ print(f" -> Step {current_step}: {avg_time_per_step*1000:.1f}s/step |"\
192
+ f"Peak Mem: {mem_stats.get('peak_allocated_gb', 0):.2f} GB |"\
193
+ f"Reserved: {mem_stats.get('reserved_gb', 0):.2f} GB")
194
+
195
+ def _save_log(self):
196
+ # Dump the entire data structure to JSON
197
+ # For very long training runs, appending to a JSONL (lines) file might be more efficient,
198
+ # but standard JSON is easier to read for analysis.
199
+ try:
200
+ with open(self.log_file_path, 'w', encoding='utf-8') as f:
201
+ json.dump(self.log_data, f, indent=4)
202
+ except Exception as e:
203
+ print(f"Error saving experiment log: {e}")
204
+
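+ # Illustrative sketch only (values are made up): the JSON written by _save_log() has the shape
+ # {
+ #   "metadata": {"run_name": ..., "gpu_info": ..., "configuration": {...}},
+ #   "metrics": [
+ #     {"step": 100, "epoch": 0.12,
+ #      "performance": {"avg_time_per_step_s": 1.83, "steps_per_second": 0.55},
+ #      "memory": {"allocated_gb": 19.4, "reserved_gb": 22.1, "peak_allocated_gb": 21.7}},
+ #     ...
+ #   ]
+ # }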
205
+ def get_rank():
206
+ try:
207
+ rank = int(os.environ.get("LOCAL_RANK"))
208
+ except (TypeError, ValueError):
209
+ rank = 0
210
+ return rank
211
+
212
+
213
+ def get_config():
214
+ config_path = os.environ.get("OMINI_CONFIG")
215
+ assert config_path is not None, "Please set the OMINI_CONFIG environment variable"
216
+ with open(config_path, "r") as f:
217
+ config = yaml.safe_load(f)
218
+ return config
219
+
220
+
221
+ def init_wandb(wandb_config, run_name):
222
+ import wandb
223
+
224
+ try:
225
+ assert os.environ.get("WANDB_API_KEY") is not None
226
+ wandb.init(
227
+ project=wandb_config["project"],
228
+ name=run_name,
229
+ config={},
230
+ )
231
+ except Exception as e:
232
+ print("Failed to initialize WanDB:", e)
233
+
234
+
235
+
236
+ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
237
+ """Collects the state dict and dump to disk."""
238
+ state_dict = trainer.model.state_dict()
239
+ if trainer.args.should_save:
240
+ cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
241
+ del state_dict
242
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
243
+
244
+
245
+ def smart_tokenizer_and_embedding_resize(
246
+ special_tokens_dict: Dict,
247
+ tokenizer: transformers.PreTrainedTokenizer,
248
+ model: transformers.PreTrainedModel,
249
+ ):
250
+ """Resize tokenizer and embedding.
251
+
252
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
253
+ """
254
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
255
+ model.resize_token_embeddings(len(tokenizer))
256
+
257
+ if num_new_tokens > 0:
258
+ input_embeddings = model.get_input_embeddings().weight.data
259
+ output_embeddings = model.get_output_embeddings().weight.data
260
+
261
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
262
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
263
+
264
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
265
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
266
+
267
+
268
+ def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
269
+ """Tokenize a list of strings."""
270
+ tokenized_list = [
271
+ tokenizer(
272
+ text,
273
+ return_tensors="pt",
274
+ padding="longest",
275
+ max_length=tokenizer.model_max_length,
276
+ truncation=True,
277
+ )
278
+ for text in strings
279
+ ]
280
+ input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
281
+ input_ids_lens = labels_lens = [
282
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
283
+ ]
284
+ return dict(
285
+ input_ids=input_ids,
286
+ labels=labels,
287
+ input_ids_lens=input_ids_lens,
288
+ labels_lens=labels_lens,
289
+ )
290
+
291
+ def preprocess(
292
+ sources: Sequence[str],
293
+ targets: Sequence[str],
294
+ tokenizer: transformers.PreTrainedTokenizer,
295
+ ) -> Dict:
296
+ """Preprocess the data by tokenizing."""
297
+ examples = [s + t for s, t in zip(sources, targets)]
298
+ examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
299
+ input_ids = examples_tokenized["input_ids"]
300
+ labels = copy.deepcopy(input_ids)
301
+ for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
302
+ label[:source_len] = IGNORE_INDEX
303
+ return dict(input_ids=input_ids, labels=labels)
304
+
305
+
306
+
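+ # Illustrative sketch only (the instruction below is made up): preprocess() masks the prompt
+ # so that only response tokens contribute to the loss, e.g.:
+ #   src = PROMPT.format_map(dict(instruction="Add 2 and 3."))
+ #   tgt = "The answer is 5." + tokenizer.eos_token
+ #   ex = preprocess([src], [tgt], tokenizer)
+ #   # ex["labels"][0][:n_prompt_tokens] are all IGNORE_INDEX (-100); the rest keep the response ids.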
307
+ @dataclass
308
+ class DataCollatorForSupervisedDataset():
309
+ tokenizer: transformers.PreTrainedTokenizer
310
+ max_length: int = field(default=512)
311
+ mode: str = field(default="fixed") # "dynamic" or "fixed"
312
+
313
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
314
+ # Extract inputs and labels
315
+ # Assuming instances is a list of dicts like {'input_ids': [...], 'labels': [...]}
316
+ input_ids_list = [torch.tensor(x["input_ids"], dtype=torch.long) for x in instances]
317
+ labels_list = [torch.tensor(x["labels"], dtype=torch.long) for x in instances]
318
+
319
+ # 1. Determine padding logic
320
+ if self.mode == "dynamic":
321
+ # Dynamic padding: pad to the longest sequence in the batch
322
+ # But cap it at self.max_length to prevent OOM
323
+ batch_max_len = max([len(x) for x in input_ids_list])
324
+ target_len = min(batch_max_len, self.max_length)
325
+ else:
326
+ # Fixed padding: always pad to max_length
327
+ target_len = self.max_length
328
+
329
+ # 2. Helper to pad and truncate
330
+ def pad_and_truncate(tensors, padding_value):
331
+ # First, pad everything using PyTorch's optimized utility (batch_first=True)
332
+ padded = pad_sequence(tensors, batch_first=True, padding_value=padding_value)
333
+
334
+ # Handle truncation/extending to exact target_len
335
+ curr_len = padded.shape[1]
336
+ if curr_len > target_len:
337
+ # Truncate if too long (rare if filtered beforehand)
338
+ return padded[:, :target_len]
339
+ elif curr_len < target_len:
340
+ # Pad more if shorter than target_len (happens in fixed mode)
341
+ diff = target_len - curr_len
342
+ padding = torch.full((padded.shape[0], diff), padding_value, dtype=padded.dtype)
343
+ return torch.cat([padded, padding], dim=1)
344
+ else:
345
+ return padded
346
+
347
+ # 3. Apply padding
348
+ # Critical: tokenizer.pad_token_id must NOT be None here
349
+ if self.tokenizer.pad_token_id is None:
350
+ raise ValueError("Tokenizer.pad_token_id is None. Please set it to eos_token_id or unk_token_id.")
351
+
352
+ input_ids = pad_and_truncate(input_ids_list, self.tokenizer.pad_token_id)
353
+ labels = pad_and_truncate(labels_list, IGNORE_INDEX)
354
+
355
+ # 4. Create Attention Mask explicitly
356
+ # .ne() creates Bools, .long() casts to 0s and 1s for compatibility
357
+ attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
358
+
359
+ return {
360
+ "input_ids": input_ids,
361
+ "labels": labels,
362
+ "attention_mask": attention_mask
363
+ }
364
+
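+ # Illustrative note: with max_length=512, a batch with sequence lengths [87, 212] is padded to
+ # 212 in "dynamic" mode (batch max, capped at max_length) but to 512 in the default "fixed"
+ # mode; labels are padded with IGNORE_INDEX so padding never contributes to the loss, and
+ # attention_mask is 0 on padded positions.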
365
+ def train_tokenize_function(examples, tokenizer, query, response):
366
+ sources = [PROMPT.format_map(dict(instruction=instruction)) for instruction in examples[query]]
367
+ targets = [f"{output}{tokenizer.eos_token}" for output in examples[response]]
368
+ data_dict = preprocess(sources, targets, tokenizer)
369
+ return data_dict
370
+
371
+
372
+
373
+ ### Trainer
374
+ def default_worker_init_fn(worker_id):
375
+ # limit each dataloader worker to a single BLAS thread
376
+ try:
377
+ import numpy as _np
378
+ except Exception:
379
+ _np = None
380
+ torch.set_num_threads(1)
381
+ os.environ.setdefault("OMP_NUM_THREADS", "1")
382
+ os.environ.setdefault("MKL_NUM_THREADS", "1")
383
+ os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
384
+ # Optional: bind CPU affinity per worker to avoid contention (NUMA-aware)
385
+ try:
386
+ cpu_count = os.cpu_count() or 1
387
+ # split the available CPU cores evenly across workers
388
+ num_workers = getattr(torch.utils.data, "_num_workers", None)
389
+ # fallback: if not available, compute from environment variable or pass externally
390
+ # We'll do a simple round-robin assignment using worker_id
391
+ # assign a small mask of cores to this worker (e.g., chunk size 4)
392
+ chunk = max(1, cpu_count // max(1, min(64, cpu_count)))
393
+ start = (worker_id * chunk) % cpu_count
394
+ end = start + chunk
395
+ mask = set(range(start, min(end, cpu_count)))
396
+ try:
397
+ os.sched_setaffinity(0, mask)
398
+ except Exception:
399
+ pass
400
+ except Exception:
401
+ pass
402
+
403
+ def set_seed(seed: int):
404
+ # random.seed(seed)
405
+ # np.random.seed(seed)
406
+ torch.manual_seed(seed)
407
+ torch.cuda.manual_seed_all(seed)
408
+ transformers.set_seed(seed)
409
+
410
+
411
+ @pyrallis.wrap()
412
+ def main(mainCfg: MainConfig):
413
+ #mainCfg = get_config()
414
+ #print(mainCfg)
415
+ print('='*120)
416
+ # print(OmegaConf.to_yaml(mainCfg))
417
+ # print('-'*40)
418
+ #
419
+ # print((training_args))
420
+ set_seed(mainCfg.seed)
421
+ training_args = convert_to_trainer_args(mainCfg)
422
+
423
+ # wandb
424
+ ENTITY = "nvan-13-korea-university"
425
+ PROJECT = os.environ.get("WANDB_PROJECT")
426
+ api = wandb.Api()
427
+ try:
428
+ runs_list = api.runs(f"{ENTITY}/{PROJECT}")
429
+ next_run_num = len(runs_list) + 1
430
+ except Exception as e:
431
+ next_run_num = 1
432
+
433
+ training_args.run_name = f'[{next_run_num}]lr={mainCfg.trainer_args.learning_rate:.1e},b={mainCfg.trainer_args.per_device_train_batch_size},'\
434
+ f'n={mainCfg.rotation_adapter_config.num_rotations},r={mainCfg.rotation_adapter_config.r},'\
435
+ f'init={mainCfg.run_text}'
436
+ # training_args.project = f'Rotation-Llama2-{mainCfg.data.dataset_name}'
437
+
438
+ # print('-'*40)
439
+ # print(training_args.to_json_string())
440
+ # exit()
441
+
442
+ model = AutoModelForCausalLM.from_pretrained(mainCfg.model.model_name,
443
+ device_map="auto", low_cpu_mem_usage=True,
444
+ dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
445
+ # attn_implementation="sdpa",
446
+ )
447
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
448
+ print("DEVICE", model.device)
449
+
450
+ # for name, param in model.named_parameters():
451
+ # if 'q_proj' in name and 'layers.5' in name:
452
+ # print(f"Name: {name} | {param.shape} ")
453
+ # print(f"Name (pretrained): {name} | {param.shape} | {param.data[0:5,0:5]}")
454
+ # print('model', model)
455
+ # exit()
456
+
457
+ total_params_now = sum(p.numel() for p in model.parameters())
458
+ print(f'#params of the pretrained model, {total_params_now:,}')
459
+ # print(model)
460
+ if mainCfg.model.adapter_path is not None:
461
+ print('___ Loading from: ', mainCfg.model.adapter_path)
462
+ model = PeftModel.from_pretrained(model, mainCfg.model.adapter_path, is_trainable = True)
463
+ elif mainCfg.rotation_adapter_config.r is not None:
464
+ import peft
465
+ if mainCfg.run_text == 'loco':
466
+ rotation_adapter_config = asdict(mainCfg.rotation_adapter_config)
467
+
468
+ for adapter_name in mainCfg.data.adapter_names:
469
+ rotation_config = RotationConfig(**rotation_adapter_config)
470
+ model = get_peft_model(model, rotation_config, adapter_name=adapter_name)
471
+ print('loaded a LoCo model, batch = ', training_args.per_device_train_batch_size)
472
+ elif mainCfg.run_text == 'boft':
473
+ from peft import BOFTConfig
474
+ boft_config = BOFTConfig(
475
+ boft_block_size=mainCfg.rotation_adapter_config.r,
476
+ boft_n_butterfly_factor=2*mainCfg.rotation_adapter_config.num_rotations,
477
+ target_modules=["q_proj", "v_proj",],
478
+ boft_dropout=0.05, #mainCfg.rotation_adapter_config.drop_out,
479
+ bias="none",
480
+ # task_type="CAUSAL_LM",
481
+ )
482
+
483
+ for adapter_name in mainCfg.data.adapter_names:
484
+ model = peft.get_peft_model(model, boft_config, adapter_name=adapter_name)
485
+ print('loaded a BOFT model, batch = ', training_args.per_device_train_batch_size)
486
+ elif mainCfg.run_text == 'hra':
487
+ from peft import HRAConfig
488
+ hra_config = HRAConfig(
489
+ r=2*mainCfg.rotation_adapter_config.r,
490
+ target_modules=["q_proj", "v_proj",],
491
+ init_weights=True,
492
+ # task_type="CAUSAL_LM",
493
+ )
494
+
495
+ for adapter_name in mainCfg.data.adapter_names:
496
+ model = peft.get_peft_model(model, hra_config, adapter_name=adapter_name)
497
+ print('loaded a HRA model, batch = ', training_args.per_device_train_batch_size)
498
+ elif mainCfg.run_text == 'oft':
499
+ from peft import HRAConfig, OFTConfig
500
+
501
+ oft_config = OFTConfig(
502
+ # r=16,
503
+ oft_block_size=4*mainCfg.rotation_adapter_config.r,
504
+ use_cayley_neumann=True,
505
+ target_modules=["q_proj", "v_proj",],
506
+ module_dropout=0.05, # mainCfg.rotation_adapter_config.drop_out,
507
+ # task_type="CAUSAL_LM",
508
+ bias="none",
509
+ )
510
+
511
+ for adapter_name in mainCfg.data.adapter_names:
512
+ model = peft.get_peft_model(model, oft_config, adapter_name=adapter_name)
513
+ print('loaded a OFT model, batch = ', training_args.per_device_train_batch_size)
514
+ else:
515
+ raise KeyError('wrong model names')
516
+
517
+ else:
518
+ print("Full Parameter Fine-Tuning")
519
+ model = model.to(DEVICE)
520
+
521
+ # print('model', model)
522
+ model.print_trainable_parameters()
523
+
524
+ # print("Program starts")
525
+ # time.sleep(300)
526
+ # exit()
527
+
528
+ # for name, param in model.named_parameters():
529
+ # if 'q_proj' in name and 'rotation' in name and 'layers.5' in name:
530
+ # print(f"Name: {name} | {param.shape} ")
531
+ # print(f"Name (pretrained): {name} | {param.shape} ")
532
+ # X = param.data
533
+ # print('model', type(model), X.shape)
534
+ # visualize_value_distribution(X)
535
+ # exit()
536
+
537
+ rotation_layers = filter(
538
+ lambda p: p.requires_grad, model.parameters()
539
+ )
540
+
541
+ tokenizer = AutoTokenizer.from_pretrained(
542
+ mainCfg.model.model_name,
543
+ model_max_length=mainCfg.model.model_max_seq_length,
544
+ padding_side="right",
545
+ use_fast=True,
546
+ )
547
+
548
+ if tokenizer.pad_token is None:
549
+ if tokenizer.unk_token_id is not None:
550
+ tokenizer.pad_token_id = tokenizer.unk_token_id
551
+ tokenizer.pad_token = tokenizer.unk_token
552
+ print("Set PAD token to UNK token.")
553
+ elif tokenizer.eos_token_id is not None:
554
+ tokenizer.pad_token_id = tokenizer.eos_token_id
555
+ tokenizer.pad_token = tokenizer.eos_token
556
+ print("Set PAD token to EOS token.")
557
+
558
+ if model is not None:
559
+ model.config.pad_token_id = tokenizer.pad_token_id
560
+ if model.config.pad_token_id != tokenizer.pad_token_id:
561
+ raise ValueError("Failed to sync pad_token_id between tokenizer and model config")
562
+
563
+ # local MetaMathQA-40K
564
+ raw_datasets = load_dataset("json", data_files=mainCfg.data.path, split=mainCfg.data.dataset_split)
565
+
566
+ train_dataset = raw_datasets.map(
567
+ train_tokenize_function,
568
+ batched=True,
569
+ batch_size=30000,
570
+ num_proc=32,
571
+ remove_columns=raw_datasets.column_names,
572
+ load_from_cache_file=True,
573
+ desc="Running tokenizer on train dataset",
574
+ fn_kwargs={"tokenizer": tokenizer, "query": mainCfg.data.dataset_field[0],
575
+ "response": mainCfg.data.dataset_field[1]}
576
+ )
577
+
578
+ # valid_dataset = raw_valid_datasets.map(
579
+ # train_tokenize_function,
580
+ # batched=True,
581
+ # batch_size=30000,
582
+ # num_proc=32,
583
+ # remove_columns=raw_train_datasets.column_names,
584
+ # load_from_cache_file=True,
585
+ # desc="Running tokenizer on train dataset",
586
+ # fn_kwargs={"tokenizer": tokenizer, "query": mainCfg.data.dataset_field[0],
587
+ # "response": mainCfg.data.dataset_field[1]}
588
+ # )
589
+ print('- dataset size: ', len(train_dataset))
590
+
591
+
592
+ # print('dataset', type(train_dataset))
593
+ # print('process', len(train_dataset))
594
+ # print(f"Sample features: {train_dataset.column_names}, {train_dataset.num_rows}")
595
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=mainCfg.model.model_max_seq_length,
596
+ #mode=mainCfg.model.data_collator_mode,
597
+ )
598
+ data_module = dict(train_dataset=train_dataset, data_collator=data_collator)
599
+
600
+ optimizer = optim.AdamW(
601
+ rotation_layers,
602
+ lr=mainCfg.trainer_args.learning_rate, #
603
+ eps=1e-8
604
+ )
605
+ # print('model x', model)
606
+ start_time = datetime.now()
607
+ print('start time: ', start_time.strftime("%Y-%m-%d %H:%M:%S"))
608
+
609
+ monitor = ExperimentMonitorCallback(
610
+ log_file_path="./training_metrics_bs8.json",
611
+ run_name="Experiment_BatchSize_8",
612
+ log_interval=10 # Will calculate the average over every 10 steps
613
+ )
614
+ training_args.remove_unused_columns = False
615
+ training_args.torch_compile=False
616
+ trainer = MyTrainer(model=model, processing_class=tokenizer,
617
+ lamda=mainCfg.model.lambda_reg,
618
+ optimizers=(optimizer, None),
619
+ args=training_args, **data_module,
620
+ callbacks=[monitor],
621
+ )
622
+ model.config.use_cache = False
623
+
624
+ trainer.train()
625
+
626
+ end_time = datetime.now()
627
+ print('end time: ', end_time.strftime("%Y-%m-%d %H:%M:%S"), '| duration: ', end_time - start_time)
628
+
629
+ # Save Model (Includes Adapter weights & Config)
630
+ # trainer.save_model(os.path.join(training_args.output_dir, 'ft'))
631
+ # Save Tokenizer
632
+ tokenizer.save_pretrained(os.path.join(training_args.output_dir, 'ft'))
633
+ # Save Training State (Metrics & Logs)
634
+ trainer.save_state()
635
+
636
+ # save peft_config. Or model.base_model.peft_config['default']
637
+ # model.peft_config.save_pretrained(os.path.join(training_args.output_dir, 'ft'))
638
+
639
+ # the easiest way
640
+ model.save_pretrained(os.path.join(training_args.output_dir, 'ft2'))
641
+ return
642
+
643
+
644
+
645
+ class MyTrainer(Trainer):
646
+
647
+ def __init__(
648
+ self,
649
+ model: Union[PreTrainedModel, nn.Module] = None,
650
+ args: TrainingArguments = None,
651
+ data_collator: Optional[DataCollator] = None,
652
+ train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
653
+ eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None,
654
+ processing_class: Optional[PreTrainedTokenizerBase] = None,
655
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
656
+ compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
657
+ callbacks: Optional[List[TrainerCallback]] = None,
658
+ optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
659
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
660
+ #run_name: Optional[str] = None,
661
+ #report_to: Optional[Union[str, list[str]]] = None,
662
+ # project
663
+ lamda: float = 1e-4
664
+ ):
665
+ super().__init__(model=model, args=args, data_collator=data_collator,
666
+ train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class,
667
+ model_init=model_init, compute_metrics=compute_metrics, callbacks=callbacks,
668
+ optimizers=optimizers, preprocess_logits_for_metrics=preprocess_logits_for_metrics,
669
+ #run_name=run_name, report_to=report_to
670
+ )
671
+ self.lamda = lamda  # regularization weight (mainCfg.model.lambda_reg); not used by the overrides in this class
672
+
673
+
674
+ def get_train_dataloader(self):
675
+ # get dataset & sampler from super
676
+ train_dataset = self.train_dataset
677
+ sampler = self._get_train_sampler()
678
+
679
+ # compute effective batch size per step (HF has some routines; we use per_device_train_batch_size)
680
+ batch_size = self.args.train_batch_size if hasattr(self.args, "train_batch_size") else self.args.per_device_train_batch_size
681
+
682
+ # recommended num_workers: start moderate (16), you can tune upward
683
+ num_workers = getattr(self.args, "dataloader_num_workers", 16)
684
+ pin_memory = getattr(self.args, "dataloader_pin_memory", True)
685
+ prefetch_factor = getattr(self.args, "dataloader_prefetch_factor", 2)
686
+ persistent_workers = getattr(self.args, "dataloader_persistent_workers", True)
687
+
688
+ return DataLoader(
689
+ train_dataset,
690
+ batch_size=batch_size,
691
+ sampler=sampler,
692
+ collate_fn=self.data_collator,
693
+ drop_last=self.args.dataloader_drop_last if hasattr(self.args, "dataloader_drop_last") else False,
694
+ num_workers=num_workers,
695
+ pin_memory=pin_memory,
696
+ persistent_workers=persistent_workers,
697
+ prefetch_factor=prefetch_factor,
698
+ worker_init_fn=default_worker_init_fn,
699
+ )
700
+
701
+ if __name__ == "__main__":
702
+ main()