Batch upload part 2
This view is limited to 50 files because it contains too many changes.
- .gitattributes +2 -0
- examples/subject.ipynb +216 -0
- examples/subject_1024.ipynb +216 -0
- nl_tasks/README.md +45 -0
- nl_tasks/config/commonsense.yaml +44 -0
- nl_tasks/config/commonsense_opt.yaml +32 -0
- nl_tasks/config/glue.yaml +48 -0
- nl_tasks/config/math395.yaml +46 -0
- nl_tasks/data/MATH_test.jsonl +0 -0
- nl_tasks/data/MetaMathQA-40K/MetaMathQA-40K.json +3 -0
- nl_tasks/data/MetaMathQA/MetaMathQA-395K.json +3 -0
- nl_tasks/data/gsm8k_test.jsonl +0 -0
- nl_tasks/environment.yaml +55 -0
- nl_tasks/exps/run_ex01/trainer_state.json +914 -0
- nl_tasks/repro.sh +87 -0
- nl_tasks/rpeft/__init__.py +43 -0
- nl_tasks/rpeft/mapping.py +273 -0
- nl_tasks/rpeft/peft_model.py +922 -0
- nl_tasks/rpeft/rotation/__init__.py +3 -0
- nl_tasks/rpeft/rotation/layer.py +412 -0
- nl_tasks/rpeft/rotation/layer_test.py +296 -0
- nl_tasks/rpeft/rotation/model.py +392 -0
- nl_tasks/rpeft/rotation/rotation_config.py +89 -0
- nl_tasks/rpeft/utils/__init__.py +29 -0
- nl_tasks/rpeft/utils/adapters_utils.py +19 -0
- nl_tasks/rpeft/utils/config.py +220 -0
- nl_tasks/rpeft/utils/other.py +160 -0
- nl_tasks/rpeft/utils/save_and_load.py +166 -0
- nl_tasks/scripts/.nfs80e7f26e00566c630000664a +117 -0
- nl_tasks/scripts/.nfs80e7f26e0132942e00006649 +341 -0
- nl_tasks/scripts/copy train_cms_reasoning.sh +133 -0
- nl_tasks/scripts/down_math_train.sh +14 -0
- nl_tasks/scripts/inference.sh +14 -0
- nl_tasks/scripts/merge.sh +137 -0
- nl_tasks/scripts/merge_100k.sh +100 -0
- nl_tasks/scripts/merge_math.sh +31 -0
- nl_tasks/scripts/peft_merge.sh +60 -0
- nl_tasks/scripts/train_100math.sh +184 -0
- nl_tasks/scripts/train_cms_reasoning.sh +260 -0
- nl_tasks/scripts/train_initn40k.sh +341 -0
- nl_tasks/scripts/train_math.sh +162 -0
- nl_tasks/setup.py +28 -0
- nl_tasks/src/bb.ipynb +0 -0
- nl_tasks/src/cc.ipynb +0 -0
- nl_tasks/src/config.py +183 -0
- nl_tasks/src/ft_mathQ.py +702 -0
- nl_tasks/src/ft_mathR.py +689 -0
- nl_tasks/src/merge.py +82 -0
- nl_tasks/src/peft_merge.py +82 -0
- nl_tasks/src/testLlama.py +702 -0
.gitattributes
CHANGED
@@ -51,3 +51,5 @@ assets/ominicontrol_art/DistractedBoyfriend.webp filter=lfs diff=lfs merge=lfs -
 assets/ominicontrol_art/PulpFiction.jpg filter=lfs diff=lfs merge=lfs -text
 assets/ominicontrol_art/breakingbad.jpg filter=lfs diff=lfs merge=lfs -text
 assets/ominicontrol_art/oiiai.png filter=lfs diff=lfs merge=lfs -text
+nl_tasks/data/MetaMathQA-40K/MetaMathQA-40K.json filter=lfs diff=lfs merge=lfs -text
+nl_tasks/data/MetaMathQA/MetaMathQA-395K.json filter=lfs diff=lfs merge=lfs -text
examples/subject.ipynb
ADDED
@@ -0,0 +1,216 @@
The new notebook (nbformat 4, Python 3.12.3, "base" kernel) contains eight code cells:

# Cell 1: run from the repository root
import os

os.chdir("..")

# Cell 2: imports
import torch
from diffusers.pipelines import FluxPipeline
from PIL import Image

from omini.pipeline.flux_omini import Condition, generate, seed_everything

# Cell 3: load FLUX.1-schnell and the 512-px subject adapter
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")
pipe.load_lora_weights(
    "Yuanshi/OminiControl",
    weight_name="omini/subject_512.safetensors",
    adapter_name="subject",
)

# Cell 4: subject-driven generation from assets/penguin.jpg
image = Image.open("assets/penguin.jpg").convert("RGB").resize((512, 512))

# For this model, the position_delta is (0, 32).
# For more details of position_delta, please refer to:
# https://github.com/Yuanshi9815/OminiControl/issues/89#issuecomment-2827080344
condition = Condition(image, "subject", position_delta=(0, 32))

prompt = "On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat."

seed_everything(0)

result_img = generate(
    pipe,
    prompt=prompt,
    conditions=[condition],
    num_inference_steps=8,
    height=512,
    width=512,
).images[0]

concat_image = Image.new("RGB", (1024, 512))
concat_image.paste(image, (0, 0))
concat_image.paste(result_img, (512, 0))
concat_image

Cells 5 through 8 repeat the pattern of cell 4 (512x512, 8 steps, position_delta=(0, 32)), differing only in the input image, the prompt, an argument-free seed_everything() call, and in pasting condition.condition rather than image into the comparison canvas:

- Cell 5, assets/tshirt.jpg: "On the beach, a lady sits under a beach umbrella. She's wearing this shirt and has a big smile on her face, with her surfboard behind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple."
- Cell 6, assets/rc_car.jpg: "A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground."
- Cell 7, assets/clock.jpg: "In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall."
- Cell 8, assets/oranges.jpg: "A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show."
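Since cells 4 through 8 differ only in their inputs, the repetition invites a small wrapper. A minimal sketch of such a helper (the subject_generate name and its defaults are mine, not part of the repo; it assumes only the omini Condition/generate API used above):

from PIL import Image
from omini.pipeline.flux_omini import Condition, generate, seed_everything

def subject_generate(pipe, asset_path, prompt, seed=0, size=512, delta=(0, 32)):
    """Run one subject-conditioned generation; return condition and result side by side."""
    image = Image.open(asset_path).convert("RGB").resize((size, size))
    condition = Condition(image, "subject", position_delta=delta)
    seed_everything(seed)
    result = generate(
        pipe,
        prompt=prompt,
        conditions=[condition],
        num_inference_steps=8,
        height=size,
        width=size,
    ).images[0]
    board = Image.new("RGB", (2 * size, size))
    board.paste(image, (0, 0))
    board.paste(result, (size, 0))
    return board

# e.g. subject_generate(pipe, "assets/clock.jpg", "In a Bauhaus style room, ...")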
examples/subject_1024.ipynb
ADDED
@@ -0,0 +1,216 @@
The new notebook mirrors examples/subject.ipynb (same eight cells, same assets and prompts) but targets the 1024-px adapter. The differences:

# Cell 3 loads the beta 1024-px weights instead:
pipe.load_lora_weights(
    "Yuanshi/OminiControl",
    weight_name="omini/subject_1024_beta.safetensors",
    adapter_name="subject",
)

# For this model, the position_delta is (0, -32), not (0, 32).
# For more details of position_delta, please refer to:
# https://github.com/Yuanshi9815/OminiControl/issues/89#issuecomment-2827080344
condition = Condition(image, "subject", position_delta=(0, -32))

Condition images are still resized to 512x512, but every generate() call uses height=1024, width=1024, and the side-by-side canvas grows accordingly:

concat_image = Image.new("RGB", (1024 + 512, 1024))
concat_image.paste(image, (0, 0))
concat_image.paste(result_img, (512, 0))
concat_image

Here every comparison pastes image (not condition.condition). The penguin, tshirt, and clock cells pin the seed with seed_everything(0), while the rc_car and oranges cells call seed_everything() without an argument.
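The only semantic switch between the two notebooks is the sign of position_delta. A tiny lookup capturing just the values the notebooks hard-code (hypothetical helper; the general rule is discussed in the OminiControl issue linked above):

def subject_position_delta(target_size: int) -> tuple[int, int]:
    """position_delta values hard-coded by the two example notebooks."""
    deltas = {
        512: (0, 32),    # omini/subject_512.safetensors
        1024: (0, -32),  # omini/subject_1024_beta.safetensors
    }
    try:
        return deltas[target_size]
    except KeyError:
        raise ValueError(f"no subject adapter wired up for {target_size}px") from None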
nl_tasks/README.md
ADDED
@@ -0,0 +1,45 @@
The new README is a scratch note: shell one-liners for toggling TorchDynamo, followed by pasted `git status` output listing untracked files.

Dynamo
export TORCH_COMPILE_DISABLE=1
unset TORCH_COMPILE_DISABLE
echo $TORCH_COMPILE_DISABLE

Untracked files:
  (use "git add <file>..." to include in what will be committed)
        nl_tasks/README.md
        nl_tasks/config/commonsense_opt.yaml
        nl_tasks/config/glue.yaml
        nl_tasks/config/math395.yaml
        nl_tasks/data/MetaMathQA/
        nl_tasks/data/gsm8k_infer.py
        nl_tasks/exp100/
        nl_tasks/exp395/
        nl_tasks/exp_init/
        nl_tasks/expsBOFT/
        nl_tasks/expsOFT/
        nl_tasks/repro.sh
        nl_tasks/run_all/
        nl_tasks/run_exps/
        nl_tasks/scripts/copy train_cms_reasoning.sh
        nl_tasks/scripts/inference.sh
        nl_tasks/scripts/merge_100k.sh
        nl_tasks/scripts/merge_math.sh
        nl_tasks/scripts/peft_merge.sh
        nl_tasks/scripts/train_100math.sh
        nl_tasks/scripts/train_initn40k.sh
        nl_tasks/scripts/train_math.sh
        nl_tasks/src/ft_mathQ.py
        nl_tasks/src/peft_merge.py
        nl_tasks/src/testLlama.py
        nl_tasks/testLlama.sh
        nl_tasks/training_metrics_bs8.json
        nlu/1Mgrid/
        nlu/_scripts_/
        nlu/glue22_exp/
        nlu/glue_exp00/
        nlu/glue_test/
        nlu/seeds/
        nlu/src/test.py
        nlu/test.sh
        nlu/training_metrics_bs8.json
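The same toggle can be applied from Python. A minimal sketch (assumption: TORCH_COMPILE_DISABLE, the variable the note above exports, is read by recent PyTorch at import time to turn torch.compile into a no-op):

import os

# Must be set before torch is imported for the flag to take effect.
os.environ["TORCH_COMPILE_DISABLE"] = "1"

import torch

@torch.compile  # left as a no-op while the flag is set
def square(x):
    return x * x

print(square(torch.ones(2)))  # runs eagerly, without Dynamo tracing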
nl_tasks/config/commonsense.yaml
ADDED
@@ -0,0 +1,44 @@

model:
  model_name: meta-llama/Llama-2-7b-hf  # alternatives tried: facebook/opt-125m, openai-community/gpt2, EleutherAI/pythia-160m, Qwen/Qwen2.5-0.5B
  # adapter_path: "./run_all/exnr15/ft2"
  # adapter_path: './exp_init/run_ex01/ft2'
  data_collator_mode: 'dynamic'

rotation_adapter_config:
  r: 4
  num_rotations: 4
  # target_modules: ["q_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
  target_modules: ["q_proj", "v_proj"]

data:
  dataset_name: 'math'
  split_ratio: 0.025
  # path: "./data/gsm8k_test.jsonl"
  path: ./data/MetaMathQA-40K/MetaMathQA-40K.json
  # path: ./data/MetaMathQA/MetaMathQA-395K.json
  dataset_split: train
  # dataset_field: [question, answer]
  dataset_field: [query, response]

trainer_args:
  learning_rate: 2e-4
  # eval_strategy: steps
  per_device_train_batch_size: 32
  per_device_eval_batch_size: 64
  # accumulate_grad_batches: 1
  # save_steps: 1000
  gradient_checkpointing: False  # turn off for faster training
  output_dir: "./run_exps"
  # save_path: "runs"
  report_to: wandb
  logging_steps: 25
  # eval_steps: 100
  # dataloader_num_workers: 4
  num_train_epochs: 2.0
  # max_steps: -1
  # device: 'cuda'
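A sketch of how a config like this can be consumed. omegaconf is pinned in nl_tasks/environment.yaml, but the repo's actual loader (nl_tasks/src/config.py) is not shown in this view, so treat the access pattern below as an assumption:

from omegaconf import OmegaConf

cfg = OmegaConf.load("nl_tasks/config/commonsense.yaml")

print(cfg.model.model_name)                              # meta-llama/Llama-2-7b-hf
print(list(cfg.rotation_adapter_config.target_modules))  # ['q_proj', 'v_proj']

# split_ratio carves the eval split out of the training file:
n_records = 40_000                              # MetaMathQA-40K
n_eval = int(n_records * cfg.data.split_ratio)  # 0.025 -> 1000 eval examples
print(n_eval)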
nl_tasks/config/commonsense_opt.yaml
ADDED
@@ -0,0 +1,32 @@

model:
  model_name: facebook/opt-125m  # alternatives tried: openai-community/gpt2, EleutherAI/pythia-160m, Qwen/Qwen2.5-0.5B
  # adapter_path: "./nl_tasks/run_exps/ft2"
  data_collator_mode: 'dynamic'

rotation_adapter_config:
  r: 4
  num_rotations: 2
  target_modules: ["q_proj", "v_proj"]

data:
  dataset_name: 'math'
  # path: "./nl_tasks/data/MetaMathQA-40K"  # MetaMathQA-40K.json
  path: "./data/gsm8k_test.jsonl"
  dataset_split: train[:200]
  dataset_field: [question, answer]

trainer_args:
  learning_rate: 2e-4
  # accumulate_grad_batches: 1
  # dataloader_workers: 5
  # save_interval: 1000
  # sample_interval: 100
  # max_steps: -1
  gradient_checkpointing: False  # turn off for faster training
  output_dir: "./run_exps"
  # save_path: "runs"
  max_steps: 40
  # device: 'cuda'
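With dataset_split: train[:200] and max_steps: 40, this reads as a smoke-test config. A sketch of what the slice syntax does, assuming the loader forwards it to Hugging Face datasets (the real plumbing in nl_tasks/src/config.py is not shown here):

from datasets import load_dataset

# `train[:200]` is HF datasets slice syntax: only the first 200 records are loaded.
ds = load_dataset(
    "json",
    data_files="./data/gsm8k_test.jsonl",
    split="train[:200]",
)
print(len(ds))          # 200
print(ds.column_names)  # expected: ['question', 'answer'], per dataset_field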
nl_tasks/config/glue.yaml
ADDED
@@ -0,0 +1,48 @@

model:
  model_name: microsoft/deberta-v3-base  # alternatives tried: facebook/opt-125m, meta-llama/Llama-2-7b-hf, openai-community/gpt2, EleutherAI/pythia-160m, Qwen/Qwen2.5-0.5B
  # adapter_path: "./run_all/exnr15/ft2"
  # adapter_path: './run_all/run_exps9/ft2'
  # adapter_path: "./exp395/run_ex07/ft2"
  data_collator_mode: 'dynamic'

rotation_adapter_config:
  r: 5
  num_rotations: 1
  # target_modules: ["q_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
  target_modules: ["query_proj", "value_proj", "key_proj", 'attention.output.dense', 'intermediate.dense', 'output.dense']
  task_type: "SEQ_CLS"

data:
  dataset_name: 'math'  # note: despite the file name, this config still points at the MetaMathQA data below
  split_ratio: 0.00258
  # path: "./data/gsm8k_test.jsonl"
  # path: ./data/MetaMathQA-40K/MetaMathQA-40K.json
  path: ./data/MetaMathQA/MetaMathQA-395K.json
  dataset_split: train[:100000]
  # dataset_field: [question, answer]
  dataset_field: [query, response]

trainer_args:
  learning_rate: 2e-4
  warmup_ratio: 0.01
  # eval_strategy: steps
  per_device_train_batch_size: 32
  per_device_eval_batch_size: 64
  # accumulate_grad_batches: 1
  # save_steps: 1000
  gradient_checkpointing: False  # turn off for faster training
  output_dir: "./exps/run_exps"
  # save_path: "runs"
  # report_to: wandb
  logging_steps: 200
  # eval_steps: 1000
  # dataloader_num_workers: 4
  num_train_epochs: 2.0
  # max_steps: 21
  # torch_compile: False
  # device: 'cuda'
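The target_modules list switches from LLaMA-style names to DeBERTa-v3 module names. A quick sanity check that those suffixes actually occur in the model, using the standard transformers API (num_labels=2 is an arbitrary placeholder):

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "microsoft/deberta-v3-base", num_labels=2
)

targets = [
    "query_proj", "value_proj", "key_proj",
    "attention.output.dense", "intermediate.dense", "output.dense",
]
for t in targets:
    hits = sum(name.endswith(t) for name, _ in model.named_modules())
    print(f"{t}: {hits} matching modules")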
nl_tasks/config/math395.yaml
ADDED
@@ -0,0 +1,46 @@

model:
  model_name: meta-llama/Llama-2-7b-hf  # alternatives tried: facebook/opt-125m, openai-community/gpt2, EleutherAI/pythia-160m, Qwen/Qwen2.5-0.5B
  # adapter_path: "./run_all/exnr15/ft2"
  # adapter_path: './run_all/run_exps9/ft2'
  # adapter_path: "./exp395/run_ex07/ft2"
  data_collator_mode: 'dynamic'

rotation_adapter_config:
  r: 16
  num_rotations: 1
  # target_modules: ["q_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
  target_modules: ["q_proj", "v_proj"]

data:
  dataset_name: 'math'
  split_ratio: 0.00258
  # path: "./data/gsm8k_test.jsonl"
  # path: ./data/MetaMathQA-40K/MetaMathQA-40K.json
  path: ./data/MetaMathQA/MetaMathQA-395K.json
  dataset_split: train[:100000]
  # dataset_field: [question, answer]
  dataset_field: [query, response]

trainer_args:
  learning_rate: 2e-4
  warmup_ratio: 0.01
  # eval_strategy: steps
  per_device_train_batch_size: 32
  per_device_eval_batch_size: 64
  # accumulate_grad_batches: 1
  # save_steps: 1000
  gradient_checkpointing: False  # turn off for faster training
  output_dir: "./exps/run_exps"
  # save_path: "runs"
  report_to: wandb
  logging_steps: 200
  # eval_steps: 1000
  # dataloader_num_workers: 4
  num_train_epochs: 2.0
  # max_steps: 21
  # device: 'cuda'
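For scale: assuming split_ratio is the held-out fraction (as the 40K run log elsewhere in this upload suggests), 0.00258 of the train[:100000] slice reserves about 258 examples for evaluation (0.00258 x 100,000 = 258), which is roughly four batches at per_device_eval_batch_size: 64.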
nl_tasks/data/MATH_test.jsonl
ADDED
The diff for this file is too large to render.
nl_tasks/data/MetaMathQA-40K/MetaMathQA-40K.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c884f10e8aa1229a6e73a6bba2c9134ee0c7b7de92a02a7b8c9459085a59e117
+size 31076207
nl_tasks/data/MetaMathQA/MetaMathQA-395K.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fb39a5d8c05c042ece92eae37dfd5ea414a5979df2bf3ad3b86411bef8205725
+size 395626321
nl_tasks/data/gsm8k_test.jsonl
ADDED
The diff for this file is too large to render.
nl_tasks/environment.yaml
ADDED
@@ -0,0 +1,55 @@
# environment.yml
name: allm  # the name of the environment

channels:  # the conda channels to search for packages
  # - pytorch
  - conda-forge
  # - dnachun
  # - anaconda
channel_priority: strict

dependencies:
  # packages to install with conda
  - python=3.11.3
  # - pytorch-cuda=12.4
  # - pytorch
  # - numpy
  - transformers>=4.55
  - einops
  - jaxtyping
  - tensorboard
  - omegaconf
  - accelerate
  - peft
  - wandb
  - scipy
  - pandas
  - matplotlib
  - scikit-image
  - scikit-learn
  - joblib
  - pillow
  - datasets
  ## NO - huggingface_hub
  - tqdm
  - nltk
  - future
  - defusedxml
  - ipdb
  - torchinfo
  - timm
  - graphviz  # anaconda::graphviz
  - dnachun::torchviz
  - pip:
      # - draccus
      - fraction
      - vllm
nl_tasks/exps/run_ex01/trainer_state.json
ADDED
@@ -0,0 +1,914 @@
The new trainer state records a completed two-epoch run in nl_tasks/exps/run_ex01: global_step 2438, eval_steps 100, and no best-checkpoint tracking (best_global_step, best_metric, and best_model_checkpoint are all null). Training loss is logged every 25 steps, falling from 0.751 at step 25 to roughly 0.27 late in the run, with gradient norms between about 0.07 and 0.38; the learning rate warms up to about 5.0e-5 by step 250 and then decays smoothly. The evaluation checkpoints rendered in this view (the page truncates the log mid-entry near step 1875):

step   epoch   train loss   eval loss   eval samples/s
 100   0.082        0.486      0.4428             50.9
 200   0.164        0.377      0.3538             51.4
 300   0.246        0.337      0.3253             51.4
 400   0.328        0.323      0.3124             51.4
 500   0.410        0.301      0.3037             51.3
 600   0.492        0.321      0.2986             51.4
 700   0.574        0.304      0.2926             51.4
 800   0.656        0.303      0.2896             51.4
 900   0.738        0.296      0.2854             51.4
1000   0.820        0.282      0.2818             51.4
1100   0.902        0.289      0.2788             51.4
1200   0.984        0.297      0.2768             51.4
1300   1.066        0.285      0.2753             51.5
1400   1.148        0.267      0.2738             51.5
1500   1.231        0.288      0.2718             51.4
1600   1.313        0.269      0.2701             51.3
1700   1.395        0.270      0.2689             51.3
1800   1.477        0.264      0.2682             51.4
|
| 677 |
+
"learning_rate": 7.718998704582739e-06,
|
| 678 |
+
"loss": 0.2708,
|
| 679 |
+
"step": 1875
|
| 680 |
+
},
|
| 681 |
+
{
|
| 682 |
+
"epoch": 1.5586546349466777,
|
| 683 |
+
"grad_norm": 0.33183860778808594,
|
| 684 |
+
"learning_rate": 7.0835028969730185e-06,
|
| 685 |
+
"loss": 0.2727,
|
| 686 |
+
"step": 1900
|
| 687 |
+
},
|
| 688 |
+
{
|
| 689 |
+
"epoch": 1.5586546349466777,
|
| 690 |
+
"eval_loss": 0.26749491691589355,
|
| 691 |
+
"eval_runtime": 19.4495,
|
| 692 |
+
"eval_samples_per_second": 51.415,
|
| 693 |
+
"eval_steps_per_second": 0.823,
|
| 694 |
+
"step": 1900
|
| 695 |
+
},
|
| 696 |
+
{
|
| 697 |
+
"epoch": 1.579163248564397,
|
| 698 |
+
"grad_norm": 0.3297445774078369,
|
| 699 |
+
"learning_rate": 6.470963989323764e-06,
|
| 700 |
+
"loss": 0.2792,
|
| 701 |
+
"step": 1925
|
| 702 |
+
},
|
| 703 |
+
{
|
| 704 |
+
"epoch": 1.5996718621821167,
|
| 705 |
+
"grad_norm": 0.33076536655426025,
|
| 706 |
+
"learning_rate": 5.8821668445656924e-06,
|
| 707 |
+
"loss": 0.2749,
|
| 708 |
+
"step": 1950
|
| 709 |
+
},
|
| 710 |
+
{
|
| 711 |
+
"epoch": 1.620180475799836,
|
| 712 |
+
"grad_norm": 0.290751188993454,
|
| 713 |
+
"learning_rate": 5.317865904656497e-06,
|
| 714 |
+
"loss": 0.2653,
|
| 715 |
+
"step": 1975
|
| 716 |
+
},
|
| 717 |
+
{
|
| 718 |
+
"epoch": 1.6406890894175554,
|
| 719 |
+
"grad_norm": 0.30254530906677246,
|
| 720 |
+
"learning_rate": 4.778784223893601e-06,
|
| 721 |
+
"loss": 0.2767,
|
| 722 |
+
"step": 2000
|
| 723 |
+
},
|
| 724 |
+
{
|
| 725 |
+
"epoch": 1.6406890894175554,
|
| 726 |
+
"eval_loss": 0.26705402135849,
|
| 727 |
+
"eval_runtime": 19.451,
|
| 728 |
+
"eval_samples_per_second": 51.411,
|
| 729 |
+
"eval_steps_per_second": 0.823,
|
| 730 |
+
"step": 2000
|
| 731 |
+
},
|
| 732 |
+
{
|
| 733 |
+
"epoch": 1.661197703035275,
|
| 734 |
+
"grad_norm": 0.30978214740753174,
|
| 735 |
+
"learning_rate": 4.265612542444827e-06,
|
| 736 |
+
"loss": 0.2661,
|
| 737 |
+
"step": 2025
|
| 738 |
+
},
|
| 739 |
+
{
|
| 740 |
+
"epoch": 1.6817063166529942,
|
| 741 |
+
"grad_norm": 0.3454751968383789,
|
| 742 |
+
"learning_rate": 3.7790084012840453e-06,
|
| 743 |
+
"loss": 0.2717,
|
| 744 |
+
"step": 2050
|
| 745 |
+
},
|
| 746 |
+
{
|
| 747 |
+
"epoch": 1.7022149302707137,
|
| 748 |
+
"grad_norm": 0.3721317946910858,
|
| 749 |
+
"learning_rate": 3.319595299665873e-06,
|
| 750 |
+
"loss": 0.2767,
|
| 751 |
+
"step": 2075
|
| 752 |
+
},
|
| 753 |
+
{
|
| 754 |
+
"epoch": 1.7227235438884332,
|
| 755 |
+
"grad_norm": 0.33564668893814087,
|
| 756 |
+
"learning_rate": 2.8879618962189326e-06,
|
| 757 |
+
"loss": 0.2614,
|
| 758 |
+
"step": 2100
|
| 759 |
+
},
|
| 760 |
+
{
|
| 761 |
+
"epoch": 1.7227235438884332,
|
| 762 |
+
"eval_loss": 0.26674506068229675,
|
| 763 |
+
"eval_runtime": 19.4651,
|
| 764 |
+
"eval_samples_per_second": 51.374,
|
| 765 |
+
"eval_steps_per_second": 0.822,
|
| 766 |
+
"step": 2100
|
| 767 |
+
},
|
| 768 |
+
{
|
| 769 |
+
"epoch": 1.7432321575061525,
|
| 770 |
+
"grad_norm": 0.31175196170806885,
|
| 771 |
+
"learning_rate": 2.484661254681381e-06,
|
| 772 |
+
"loss": 0.2689,
|
| 773 |
+
"step": 2125
|
| 774 |
+
},
|
| 775 |
+
{
|
| 776 |
+
"epoch": 1.7637407711238722,
|
| 777 |
+
"grad_norm": 0.27521854639053345,
|
| 778 |
+
"learning_rate": 2.110210135245147e-06,
|
| 779 |
+
"loss": 0.266,
|
| 780 |
+
"step": 2150
|
| 781 |
+
},
|
| 782 |
+
{
|
| 783 |
+
"epoch": 1.7842493847415914,
|
| 784 |
+
"grad_norm": 0.301544189453125,
|
| 785 |
+
"learning_rate": 1.765088332416917e-06,
|
| 786 |
+
"loss": 0.2746,
|
| 787 |
+
"step": 2175
|
| 788 |
+
},
|
| 789 |
+
{
|
| 790 |
+
"epoch": 1.804757998359311,
|
| 791 |
+
"grad_norm": 0.29790034890174866,
|
| 792 |
+
"learning_rate": 1.4497380602442378e-06,
|
| 793 |
+
"loss": 0.2671,
|
| 794 |
+
"step": 2200
|
| 795 |
+
},
|
| 796 |
+
{
|
| 797 |
+
"epoch": 1.804757998359311,
|
| 798 |
+
"eval_loss": 0.2663249373435974,
|
| 799 |
+
"eval_runtime": 19.4746,
|
| 800 |
+
"eval_samples_per_second": 51.349,
|
| 801 |
+
"eval_steps_per_second": 0.822,
|
| 802 |
+
"step": 2200
|
| 803 |
+
},
|
| 804 |
+
{
|
| 805 |
+
"epoch": 1.8252666119770304,
|
| 806 |
+
"grad_norm": 0.3353317677974701,
|
| 807 |
+
"learning_rate": 1.1645633856944977e-06,
|
| 808 |
+
"loss": 0.2693,
|
| 809 |
+
"step": 2225
|
| 810 |
+
},
|
| 811 |
+
{
|
| 812 |
+
"epoch": 1.8457752255947497,
|
| 813 |
+
"grad_norm": 0.2950851023197174,
|
| 814 |
+
"learning_rate": 9.099297109128407e-07,
|
| 815 |
+
"loss": 0.2704,
|
| 816 |
+
"step": 2250
|
| 817 |
+
},
|
| 818 |
+
{
|
| 819 |
+
"epoch": 1.8662838392124692,
|
| 820 |
+
"grad_norm": 0.3311655521392822,
|
| 821 |
+
"learning_rate": 6.861633050223526e-07,
|
| 822 |
+
"loss": 0.2757,
|
| 823 |
+
"step": 2275
|
| 824 |
+
},
|
| 825 |
+
{
|
| 826 |
+
"epoch": 1.8867924528301887,
|
| 827 |
+
"grad_norm": 0.3037591576576233,
|
| 828 |
+
"learning_rate": 4.935508860664601e-07,
|
| 829 |
+
"loss": 0.2684,
|
| 830 |
+
"step": 2300
|
| 831 |
+
},
|
| 832 |
+
{
|
| 833 |
+
"epoch": 1.8867924528301887,
|
| 834 |
+
"eval_loss": 0.26620855927467346,
|
| 835 |
+
"eval_runtime": 19.477,
|
| 836 |
+
"eval_samples_per_second": 51.343,
|
| 837 |
+
"eval_steps_per_second": 0.821,
|
| 838 |
+
"step": 2300
|
| 839 |
+
},
|
| 840 |
+
{
|
| 841 |
+
"epoch": 1.907301066447908,
|
| 842 |
+
"grad_norm": 0.3440220057964325,
|
| 843 |
+
"learning_rate": 3.323392536292436e-07,
|
| 844 |
+
"loss": 0.2773,
|
| 845 |
+
"step": 2325
|
| 846 |
+
},
|
| 847 |
+
{
|
| 848 |
+
"epoch": 1.9278096800656277,
|
| 849 |
+
"grad_norm": 0.33441346883773804,
|
| 850 |
+
"learning_rate": 2.0273497260433204e-07,
|
| 851 |
+
"loss": 0.2582,
|
| 852 |
+
"step": 2350
|
| 853 |
+
},
|
| 854 |
+
{
|
| 855 |
+
"epoch": 1.948318293683347,
|
| 856 |
+
"grad_norm": 0.28934305906295776,
|
| 857 |
+
"learning_rate": 1.0490410851763943e-07,
|
| 858 |
+
"loss": 0.2653,
|
| 859 |
+
"step": 2375
|
| 860 |
+
},
|
| 861 |
+
{
|
| 862 |
+
"epoch": 1.9688269073010665,
|
| 863 |
+
"grad_norm": 0.304415225982666,
|
| 864 |
+
"learning_rate": 3.8972014743038356e-08,
|
| 865 |
+
"loss": 0.268,
|
| 866 |
+
"step": 2400
|
| 867 |
+
},
|
| 868 |
+
{
|
| 869 |
+
"epoch": 1.9688269073010665,
|
| 870 |
+
"eval_loss": 0.26615387201309204,
|
| 871 |
+
"eval_runtime": 19.4583,
|
| 872 |
+
"eval_samples_per_second": 51.392,
|
| 873 |
+
"eval_steps_per_second": 0.822,
|
| 874 |
+
"step": 2400
|
| 875 |
+
},
|
| 876 |
+
{
|
| 877 |
+
"epoch": 1.989335520918786,
|
| 878 |
+
"grad_norm": 0.2703385055065155,
|
| 879 |
+
"learning_rate": 5.023171883647426e-09,
|
| 880 |
+
"loss": 0.263,
|
| 881 |
+
"step": 2425
|
| 882 |
+
},
|
| 883 |
+
{
|
| 884 |
+
"epoch": 2.0,
|
| 885 |
+
"step": 2438,
|
| 886 |
+
"total_flos": 1.58523627405312e+18,
|
| 887 |
+
"train_loss": 0.3074496070120157,
|
| 888 |
+
"train_runtime": 3055.5986,
|
| 889 |
+
"train_samples_per_second": 25.527,
|
| 890 |
+
"train_steps_per_second": 0.798
|
| 891 |
+
}
|
| 892 |
+
],
|
| 893 |
+
"logging_steps": 25,
|
| 894 |
+
"max_steps": 2438,
|
| 895 |
+
"num_input_tokens_seen": 0,
|
| 896 |
+
"num_train_epochs": 2,
|
| 897 |
+
"save_steps": 500,
|
| 898 |
+
"stateful_callbacks": {
|
| 899 |
+
"TrainerControl": {
|
| 900 |
+
"args": {
|
| 901 |
+
"should_epoch_stop": false,
|
| 902 |
+
"should_evaluate": false,
|
| 903 |
+
"should_log": false,
|
| 904 |
+
"should_save": true,
|
| 905 |
+
"should_training_stop": true
|
| 906 |
+
},
|
| 907 |
+
"attributes": {}
|
| 908 |
+
}
|
| 909 |
+
},
|
| 910 |
+
"total_flos": 1.58523627405312e+18,
|
| 911 |
+
"train_batch_size": 32,
|
| 912 |
+
"trial_name": null,
|
| 913 |
+
"trial_params": null
|
| 914 |
+
}
|
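These run statistics are easier to query than to read. Below is a minimal sketch (assuming the file lives at the path this run writes, nl_tasks/exps/run_ex01/trainer_state.json) that loads the log history and reports the last train loss and the best eval checkpoint:

import json

with open("nl_tasks/exps/run_ex01/trainer_state.json") as f:
    state = json.load(f)

# Train records carry a "loss" key; eval records carry "eval_loss".
train_logs = [e for e in state["log_history"] if "loss" in e]
eval_logs = [e for e in state["log_history"] if "eval_loss" in e]

best = min(eval_logs, key=lambda e: e["eval_loss"])
print(f"last train loss: {train_logs[-1]['loss']}")
print(f"best eval loss:  {best['eval_loss']:.4f} at step {best['step']}")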
nl_tasks/repro.sh
ADDED
@@ -0,0 +1,87 @@
+export OMINI_CONFIG=./config/commonsense.yaml
+
+#echo $OMINI_CONFIG
+export TOKENIZERS_PARALLELISM=true
+
+# CUDA include directory (for cuda.h)
+CUDA_INCLUDE_PATH="/home/work/miniconda3/envs/allm/include"
+
+# Add it to CPATH and CPLUS_INCLUDE_PATH (C/C++ compiler search paths)
+export CPATH=$CPATH:$CUDA_INCLUDE_PATH
+export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$CUDA_INCLUDE_PATH
+# echo "CPATH is set to: $CPATH"
+# echo "CPLUS_INCLUDE_PATH is set to: $CPLUS_INCLUDE_PATH"
+
+export WANDB_PROJECT="Llama2_7B_FT_Math40k_2"
+
+export OMP_NUM_THREADS=1
+export MKL_NUM_THREADS=1
+export OPENBLAS_NUM_THREADS=1
+export NUMEXPR_NUM_THREADS=1
+
+date +"%F %T"
+TEXT=("oft" "boft" "loco" "hra")
+
+# --run_text "$text" --dynamo_backend no
+export ACCELERATE_DYNAMO_BACKEND="no"
+# --trainer_args.max_steps=81 \
+
+accelerate launch --dynamo_backend no --main_process_port 41353 -m src.testLlama \
+    --config_path $OMINI_CONFIG --trainer_args.output_dir "./expsBOFT/seed44/" --trainer_args.learning_rate=8e-4 \
+    --run_text "boft" --trainer_args.per_device_train_batch_size 32 \
+    --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
+    --trainer_args.gradient_accumulation_steps 2 \
+    --trainer_args.num_train_epochs 2.0 --data.dataset_split train \
+    --trainer_args.eval_strategy '"no"' \
+    --trainer_args.load_best_model_at_end False \
+    --trainer_args.save_strategy '"no"' \
+    --trainer_args.logging_steps 50 \
+    --trainer_args.report_to none --trainer_args.warmup_steps 100 \
+    --seed 44
+date +"%F %T"
+
+accelerate launch --dynamo_backend no --main_process_port 41353 -m src.testLlama \
+    --config_path $OMINI_CONFIG --trainer_args.output_dir "./expsBOFT/seed43/" --trainer_args.learning_rate=8e-4 \
+    --run_text "boft" --trainer_args.per_device_train_batch_size 32 \
+    --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
+    --trainer_args.gradient_accumulation_steps 2 \
+    --trainer_args.num_train_epochs 2.0 --data.dataset_split train \
+    --trainer_args.eval_strategy '"no"' \
+    --trainer_args.load_best_model_at_end False \
+    --trainer_args.save_strategy '"no"' \
+    --trainer_args.logging_steps 50 \
+    --trainer_args.report_to none --trainer_args.warmup_steps 100 \
+    --seed 43
+date +"%F %T"
+
+
+accelerate launch --main_process_port 41353 -m src.testLlama \
+    --config_path $OMINI_CONFIG --trainer_args.output_dir "./expsOFT/seed43/" --trainer_args.learning_rate=8e-4 \
+    --run_text "oft" --trainer_args.per_device_train_batch_size 64 \
+    --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
+    --trainer_args.gradient_accumulation_steps 1 \
+    --trainer_args.num_train_epochs 2.0 --data.dataset_split train \
+    --trainer_args.eval_strategy '"no"' \
+    --trainer_args.load_best_model_at_end False \
+    --trainer_args.save_strategy '"no"' \
+    --trainer_args.logging_steps 50 \
+    --trainer_args.report_to none --trainer_args.warmup_steps 100 \
+    --seed 43
+
+date +"%F %T"
+
+accelerate launch --main_process_port 41353 -m src.testLlama \
+    --config_path $OMINI_CONFIG --trainer_args.output_dir "./expsOFT/seed44/" --trainer_args.learning_rate=8e-4 \
+    --run_text "oft" --trainer_args.per_device_train_batch_size 64 \
+    --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
+    --trainer_args.gradient_accumulation_steps 1 \
+    --trainer_args.num_train_epochs 2.0 --data.dataset_split train \
+    --trainer_args.eval_strategy '"no"' \
+    --trainer_args.load_best_model_at_end False \
+    --trainer_args.save_strategy '"no"' \
+    --trainer_args.logging_steps 50 \
+    --trainer_args.report_to none --trainer_args.warmup_steps 100 \
+    --seed 44
+
+date +"%F %T"
+
nl_tasks/rpeft/__init__.py
ADDED
@@ -0,0 +1,43 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# coding=utf-8
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__version__ = "0.0.1"
+
+from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING, \
+    get_peft_config, get_peft_model  # , PEFT_TYPE_TO_TUNER_MAPPING
+
+
+from .rotation import (
+    RotationConfig,
+    RotationTuner,
+)
+from .utils import (
+    TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
+    PeftConfig,
+    PeftType,
+    PromptLearningConfig,
+    TaskType,
+    bloom_model_postprocess_past_key_value,
+    get_peft_model_state_dict,
+    prepare_model_for_int8_training,
+    set_peft_model_state_dict,
+    shift_tokens_right,
+)
+
+from .peft_model import PeftModel
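For orientation, a hedged sketch of how these exports fit together: wrapping a causal LM with the ROTATION adapter via get_peft_model. The RotationConfig field names (r, num_rotations) are assumptions carried over from the repro.sh flags, not a verified signature:

from transformers import AutoModelForCausalLM

from rpeft import RotationConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
# Field names below mirror the --rotation_adapter_config.* CLI flags (assumed).
config = RotationConfig(task_type="CAUSAL_LM", r=16, num_rotations=1)

model = get_peft_model(base, config)  # dispatches to PeftModelForCausalLM
model.print_trainable_parameters()    # only the adapter parameters require grad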
nl_tasks/rpeft/mapping.py
ADDED
@@ -0,0 +1,273 @@
+# coding=utf-8
+# Original License:
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .peft_model import (
+    PeftModel,
+    PeftModelForCausalLM,
+    PeftModelForSeq2SeqLM,
+    PeftModelForSequenceClassification,
+    PeftModelForTokenClassification,
+)
+from .rotation import RotationConfig, RotationTuner
+from .utils import PromptLearningConfig
+
+from transformers import PreTrainedModel
+
+MODEL_TYPE_TO_PEFT_MODEL_MAPPING = {
+    "SEQ_CLS": PeftModelForSequenceClassification,
+    "SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM,
+    "CAUSAL_LM": PeftModelForCausalLM,
+    "TOKEN_CLS": PeftModelForTokenClassification,
+}
+
+PEFT_TYPE_TO_CONFIG_MAPPING: dict = {
+    "ROTATION": RotationConfig,
+}
+
+PEFT_TYPE_TO_TUNER_MAPPING: dict = {
+    "ROTATION": RotationTuner
+}
+
+TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = {
+    "t5": ["q", "v"],
+    "mt5": ["q", "v"],
+    "bart": ["q_proj", "v_proj"],
+    "gpt2": ["c_attn"],
+    "bloom": ["query_key_value"],
+    "blip-2": ["q", "v", "q_proj", "v_proj"],
+    "opt": ["q_proj", "v_proj"],
+    "gptj": ["q_proj", "v_proj"],
+    "gpt_neox": ["query_key_value"],
+    "gpt_neo": ["q_proj", "v_proj"],
+    "bert": ["query", "value"],
+    "roberta": ["query", "value"],
+    "xlm-roberta": ["query", "value"],
+    "electra": ["query", "value"],
+    "deberta-v2": ["query_proj", "value_proj"],
+    "deberta": ["in_proj"],
+    "layoutlm": ["query", "value"],
+    "llama": ["q_proj", "v_proj"],
+    "chatglm": ["query_key_value"],
+    "gpt_bigcode": ["c_attn"],
+    "mpt": ["Wqkv"],
+    "RefinedWebModel": ["query_key_value"],
+    "RefinedWeb": ["query_key_value"],
+    "falcon": ["query_key_value"],
+    "btlm": ["c_proj", "c_attn"],
+    "codegen": ["qkv_proj"],
+    "mistral": ["q_proj", "v_proj"],
+    "mixtral": ["q_proj", "v_proj"],
+    "stablelm": ["q_proj", "v_proj"],
+    "phi": ["q_proj", "v_proj", "fc1", "fc2"],
+    "gemma": ["q_proj", "v_proj"],
+}
+
+
+def get_peft_config(config_dict):
+    """
+    Returns a Peft config object from a dictionary.
+
+    Args:
+        config_dict (`Dict[str, Any]`): Dictionary containing the configuration parameters.
+    """
+
+    return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict)
+
+
+def _prepare_prompt_learning_config(peft_config, model_config):
+    if peft_config.num_layers is None:
+        if "num_hidden_layers" in model_config:
+            num_layers = model_config["num_hidden_layers"]
+        elif "num_layers" in model_config:
+            num_layers = model_config["num_layers"]
+        elif "n_layer" in model_config:
+            num_layers = model_config["n_layer"]
+        else:
+            raise ValueError("Please specify `num_layers` in `peft_config`")
+        peft_config.num_layers = num_layers
+
+    if peft_config.token_dim is None:
+        if "hidden_size" in model_config:
+            token_dim = model_config["hidden_size"]
+        elif "n_embd" in model_config:
+            token_dim = model_config["n_embd"]
+        elif "d_model" in model_config:
+            token_dim = model_config["d_model"]
+        else:
+            raise ValueError("Please specify `token_dim` in `peft_config`")
+        peft_config.token_dim = token_dim
+
+    if peft_config.num_attention_heads is None:
+        if "num_attention_heads" in model_config:
+            num_attention_heads = model_config["num_attention_heads"]
+        elif "n_head" in model_config:
+            num_attention_heads = model_config["n_head"]
+        elif "num_heads" in model_config:
+            num_attention_heads = model_config["num_heads"]
+        elif "encoder_attention_heads" in model_config:
+            num_attention_heads = model_config["encoder_attention_heads"]
+        else:
+            raise ValueError("Please specify `num_attention_heads` in `peft_config`")
+        peft_config.num_attention_heads = num_attention_heads
+
+    if getattr(peft_config, "encoder_hidden_size", None) is None:
+        setattr(peft_config, "encoder_hidden_size", peft_config.token_dim)
+
+    return peft_config
+
+
+def _prepare_lora_config(peft_config, model_config):
+    if peft_config.target_modules is None:
+        if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING:
+            raise ValueError("Please specify `target_modules` in `peft_config`")
+        peft_config.target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]]
+    if len(peft_config.target_modules) == 1:
+        peft_config.fan_in_fan_out = True
+        peft_config.enable_lora = [True, False, True]
+    if peft_config.inference_mode:
+        peft_config.merge_weights = True
+    return peft_config
+
+
+def get_peft_model(model, peft_config,
+                   adapter_name: str = "default"):
+    """
+    Returns a Peft model object from a model and a config.
+
+    Args:
+        model ([`transformers.PreTrainedModel`]): Model to be wrapped.
+        peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model.
+    """
+    model_config = model.config.to_dict()
+    peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)
+    if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys():
+        if peft_config.peft_type in ("LORA", "QUANTA", "ROTATION"):
+            peft_config = _prepare_lora_config(peft_config, model_config)
+        return PeftModel(model, peft_config)
+    if not isinstance(peft_config, PromptLearningConfig):
+        if peft_config.peft_type in ("LORA", "QUANTA", "ROTATION"):
+            peft_config = _prepare_lora_config(peft_config, model_config)
+    else:
+        peft_config = _prepare_prompt_learning_config(peft_config, model_config)
+    # assert False
+    return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](
+        model,
+        peft_config,
+        adapter_name=adapter_name,
+    )
+
+
+# def get_peft_model(
+#     model: PreTrainedModel,
+#     peft_config,
+#     adapter_name: str = "default",
+#     mixed: bool = False,
+#     autocast_adapter_dtype: bool = True,
+#     revision: Optional[str] = None,
+#     low_cpu_mem_usage: bool = False,
+# ) -> PeftModel | PeftMixedModel:
+#     """
+#     Returns a Peft model object from a model and a config, where the model will be modified in-place.
+
+#     Args:
+#         model ([`transformers.PreTrainedModel`]):
+#             Model to be wrapped.
+#         peft_config ([`PeftConfig`]):
+#             Configuration object containing the parameters of the Peft model.
+#         adapter_name (`str`, `optional`, defaults to `"default"`):
+#             The name of the adapter to be injected, if not provided, the default adapter name is used ("default").
+#         mixed (`bool`, `optional`, defaults to `False`):
+#             Whether to allow mixing different (compatible) adapter types.
+#         autocast_adapter_dtype (`bool`, *optional*):
+#             Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights
+#             using float16 or bfloat16 to float32, as this is typically required for stable training, and only affect
+#             select PEFT tuners.
+#         revision (`str`, `optional`, defaults to `main`):
+#             The revision of the base model. If this isn't set, the saved peft model will load the `main` revision for
+#             the base model
+#         low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
+#             Create empty adapter weights on meta device. Useful to speed up the loading process. Leave this setting as
+#             False if you intend on training the model, unless the adapter weights will be replaced by different weights
+#             before training starts.
+#     """
+#     model_config = BaseTuner.get_model_config(model)
+#     old_name = peft_config.base_model_name_or_path
+#     new_name = model.__dict__.get("name_or_path", None)
+#     peft_config.base_model_name_or_path = new_name

+#     # Especially in notebook environments there could be a case that a user wants to experiment with different
+#     # configuration values. However, it is likely that there won't be any changes for new configs on an already
+#     # initialized PEFT model. The best we can do is warn the user about it.
+#     if any(isinstance(module, BaseTunerLayer) for module in model.modules()):
+#         warnings.warn(
+#             "You are trying to modify a model with PEFT for a second time. If you want to reload the model with a "
+#             "different config, make sure to call `.unload()` before."
+#         )

+#     if (old_name is not None) and (old_name != new_name):
+#         warnings.warn(
+#             f"The PEFT config's `base_model_name_or_path` was renamed from '{old_name}' to '{new_name}'. "
+#             "Please ensure that the correct base model is loaded when loading this checkpoint."
+#         )

+#     if revision is not None:
+#         if peft_config.revision is not None and peft_config.revision != revision:
+#             warnings.warn(
+#                 f"peft config has already set base model revision to {peft_config.revision}, overwriting with revision {revision}"
+#             )
+#         peft_config.revision = revision

+#     if (
+#         (isinstance(peft_config, PEFT_TYPE_TO_CONFIG_MAPPING["LORA"]))
+#         and (peft_config.init_lora_weights == "eva")
+#         and not low_cpu_mem_usage
+#     ):
+#         warnings.warn(
+#             "lora with eva initialization used with low_cpu_mem_usage=False. "
+#             "Setting low_cpu_mem_usage=True can improve the maximum batch size possible for eva initialization."
+#         )

+#     prefix = PEFT_TYPE_TO_PREFIX_MAPPING.get(peft_config.peft_type)
+#     if prefix and adapter_name in prefix:
+#         warnings.warn(
+#             f"Adapter name '{adapter_name}' should not be contained in the prefix '{prefix}'. "
+#             "This may lead to reinitialization of the adapter weights during loading."
+#         )

+#     if mixed:
+#         # note: PeftMixedModel does not support autocast_adapter_dtype, so don't pass it
+#         return PeftMixedModel(model, peft_config, adapter_name=adapter_name)

+#     # We explicitly exclude prompt learning here since prompt learning is specific to the task and needs special
+#     # handling in the PEFT model's forward method.
+#     if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not peft_config.is_prompt_learning:
+#         return PeftModel(
+#             model,
+#             peft_config,
+#             adapter_name=adapter_name,
+#             autocast_adapter_dtype=autocast_adapter_dtype,
+#             low_cpu_mem_usage=low_cpu_mem_usage,
+#         )

+#     return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](
+#         model,
+#         peft_config,
+#         adapter_name=adapter_name,
+#         autocast_adapter_dtype=autocast_adapter_dtype,
+#         low_cpu_mem_usage=low_cpu_mem_usage,
+#     )
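get_peft_config above is the dictionary entry point to the same machinery. A small sketch of the round-trip; the r and num_rotations fields are assumed to exist on RotationConfig (mirroring the CLI flags), so treat this as illustrative:

from rpeft import get_peft_config

# "peft_type" selects the config class via PEFT_TYPE_TO_CONFIG_MAPPING;
# "ROTATION" is the only type registered in this package.
cfg = get_peft_config({
    "peft_type": "ROTATION",
    "task_type": "CAUSAL_LM",
    "r": 16,             # assumed field, mirrors --rotation_adapter_config.r
    "num_rotations": 1,  # assumed field, mirrors the CLI flag
})
print(type(cfg).__name__)  # -> "RotationConfig"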
nl_tasks/rpeft/peft_model.py
ADDED
|
@@ -0,0 +1,922 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Original License:
|
| 3 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import inspect
|
| 18 |
+
import os
|
| 19 |
+
import warnings
|
| 20 |
+
from contextlib import contextmanager
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
from accelerate import dispatch_model, infer_auto_device_map
|
| 24 |
+
from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules
|
| 25 |
+
from accelerate.utils import get_balanced_memory
|
| 26 |
+
from huggingface_hub import hf_hub_download
|
| 27 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 28 |
+
from transformers import PreTrainedModel
|
| 29 |
+
from transformers.modeling_outputs import SequenceClassifierOutput, TokenClassifierOutput
|
| 30 |
+
from transformers.utils import PushToHubMixin
|
| 31 |
+
|
| 32 |
+
import packaging.version
|
| 33 |
+
import transformers
|
| 34 |
+
from typing import Any, Literal, Optional, Union
|
| 35 |
+
|
| 36 |
+
from .rotation import RotationTuner
|
| 37 |
+
from .utils import (
|
| 38 |
+
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
|
| 39 |
+
WEIGHTS_NAME,
|
| 40 |
+
PeftConfig,
|
| 41 |
+
PeftType,
|
| 42 |
+
PromptLearningConfig,
|
| 43 |
+
TaskType,
|
| 44 |
+
_set_trainable,
|
| 45 |
+
get_peft_model_state_dict,
|
| 46 |
+
set_peft_model_state_dict,
|
| 47 |
+
shift_tokens_right,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
class PeftModel(PushToHubMixin, torch.nn.Module):
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
def __init__(self, model, peft_config: PeftConfig, adapter_name: str = "default"):
|
| 56 |
+
super().__init__()
|
| 57 |
+
self.peft_config = peft_config
|
| 58 |
+
self.base_model = model
|
| 59 |
+
self.config = self.base_model.config
|
| 60 |
+
self.modules_to_save = None
|
| 61 |
+
self.active_adapter = adapter_name
|
| 62 |
+
##### Diff do nothing with active_adapter
|
| 63 |
+
if isinstance(self.peft_config, PromptLearningConfig):
|
| 64 |
+
self._setup_prompt_encoder()
|
| 65 |
+
else:
|
| 66 |
+
if self.peft_config.peft_type == PeftType.ROTATION:
|
| 67 |
+
self.base_model = RotationTuner(model, {adapter_name: peft_config}, adapter_name)
|
| 68 |
+
if getattr(self.peft_config, "modules_to_save", None) is not None:
|
| 69 |
+
self.modules_to_save = self.peft_config.modules_to_save
|
| 70 |
+
_set_trainable(self)
|
| 71 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 72 |
+
self.base_model_torch_dtype = getattr(model, "dtype", None)
|
| 73 |
+
|
| 74 |
+
def save_pretrained(self, save_directory, **kwargs):
|
| 75 |
+
r"""
|
| 76 |
+
Args:
|
| 77 |
+
This function saves the adapter model and the adapter configuration files to a directory, so that it can be
|
| 78 |
+
re-loaded using the `LoraModel.from_pretrained` class method, and also used by the `LoraModel.push_to_hub`
|
| 79 |
+
method.
|
| 80 |
+
save_directory (`str`):
|
| 81 |
+
Directory where the adapter model and configuration files will be saved (will be created if it does not
|
| 82 |
+
exist).
|
| 83 |
+
**kwargs:
|
| 84 |
+
Additional keyword arguments passed along to the `push_to_hub` method.
|
| 85 |
+
"""
|
| 86 |
+
if os.path.isfile(save_directory):
|
| 87 |
+
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
|
| 88 |
+
os.makedirs(save_directory, exist_ok=True)
|
| 89 |
+
|
| 90 |
+
# save only the trainable weights
|
| 91 |
+
output_state_dict = get_peft_model_state_dict(self, kwargs.get("state_dict", None))
|
| 92 |
+
torch.save(output_state_dict, os.path.join(save_directory, WEIGHTS_NAME))
|
| 93 |
+
|
| 94 |
+
# save the config and change the inference mode to `True`
|
| 95 |
+
if self.peft_config.base_model_name_or_path is None:
|
| 96 |
+
self.peft_config.base_model_name_or_path = (
|
| 97 |
+
self.base_model.__dict__.get("name_or_path", None)
|
| 98 |
+
if isinstance(self.peft_config, PromptLearningConfig)
|
| 99 |
+
else self.base_model.model.__dict__.get("name_or_path", None)
|
| 100 |
+
)
|
| 101 |
+
inference_mode = self.peft_config.inference_mode
|
| 102 |
+
self.peft_config.inference_mode = True
|
| 103 |
+
self.peft_config.save_pretrained(save_directory)
|
| 104 |
+
self.peft_config.inference_mode = inference_mode
|
| 105 |
+
|
| 106 |
+
@classmethod
|
| 107 |
+
def from_pretrained(cls, model, model_id, is_trainable = False, **kwargs):
|
| 108 |
+
r"""
|
| 109 |
+
Args:
|
| 110 |
+
Instantiate a `LoraModel` from a pretrained Lora configuration and weights.
|
| 111 |
+
model (`transformers.PreTrainedModel`):
|
| 112 |
+
The model to be adapted. The model should be initialized with the `from_pretrained` method. from
|
| 113 |
+
`transformers` library.
|
| 114 |
+
model_id (`str`):
|
| 115 |
+
The name of the Lora configuration to use. Can be either:
|
| 116 |
+
- A string, the `model id` of a Lora configuration hosted inside a model repo on
|
| 117 |
+
huggingface Hub
|
| 118 |
+
- A path to a directory containing a Lora configuration file saved using the
|
| 119 |
+
`save_pretrained` method, e.g., ``./my_lora_config_directory/``.
|
| 120 |
+
"""
|
| 121 |
+
from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING
|
| 122 |
+
# load the config
|
| 123 |
+
config = PEFT_TYPE_TO_CONFIG_MAPPING[PeftConfig.from_pretrained(model_id).peft_type].from_pretrained(model_id)
|
| 124 |
+
config.inference_mode = not is_trainable
|
| 125 |
+
|
| 126 |
+
if getattr(model, "hf_device_map", None) is not None:
|
| 127 |
+
remove_hook_from_submodules(model)
|
| 128 |
+
|
| 129 |
+
if config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys():
|
| 130 |
+
model = cls(model, config)
|
| 131 |
+
else:
|
| 132 |
+
model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](model, config)
|
| 133 |
+
|
| 134 |
+
# load weights if any
|
| 135 |
+
if os.path.exists(os.path.join(model_id, WEIGHTS_NAME)):
|
| 136 |
+
filename = os.path.join(model_id, WEIGHTS_NAME)
|
| 137 |
+
else:
|
| 138 |
+
try:
|
| 139 |
+
filename = hf_hub_download(model_id, WEIGHTS_NAME)
|
| 140 |
+
except: # noqa
|
| 141 |
+
raise ValueError(
|
| 142 |
+
f"Can't find weights for {model_id} in {model_id} or in the Hugging Face Hub. "
|
| 143 |
+
f"Please check that the file {WEIGHTS_NAME} is present at {model_id}."
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
adapters_weights = torch.load(
|
| 147 |
+
filename, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 148 |
+
)
|
| 149 |
+
# load the weights into the model
|
| 150 |
+
model = set_peft_model_state_dict(model, adapters_weights)
|
| 151 |
+
if getattr(model, "hf_device_map", None) is not None:
|
| 152 |
+
device_map = kwargs.get("device_map", "auto")
|
| 153 |
+
max_memory = kwargs.get("max_memory", None)
|
| 154 |
+
no_split_module_classes = model._no_split_modules
|
| 155 |
+
if device_map != "sequential":
|
| 156 |
+
max_memory = get_balanced_memory(
|
| 157 |
+
model,
|
| 158 |
+
max_memory=max_memory,
|
| 159 |
+
no_split_module_classes=no_split_module_classes,
|
| 160 |
+
low_zero=(device_map == "balanced_low_0"),
|
| 161 |
+
)
|
| 162 |
+
if isinstance(device_map, str):
|
| 163 |
+
device_map = infer_auto_device_map(
|
| 164 |
+
model, max_memory=max_memory, no_split_module_classes=no_split_module_classes
|
| 165 |
+
)
|
| 166 |
+
model = dispatch_model(model, device_map=device_map)
|
| 167 |
+
hook = AlignDevicesHook(io_same_device=True)
|
| 168 |
+
if model.peft_config.peft_type == PeftType.LORA or model.peft_config.peft_type == PeftType.BOTTLENECK \
|
| 169 |
+
or model.peft_config.peft_type == "ROTATION":
|
| 170 |
+
add_hook_to_module(model.base_model.model, hook)
|
| 171 |
+
else:
|
| 172 |
+
remove_hook_from_submodules(model.prompt_encoder)
|
| 173 |
+
add_hook_to_module(model.base_model, hook)
|
| 174 |
+
# if model.peft_config.is_prompt_learning:
|
| 175 |
+
# remove_hook_from_submodules(model.prompt_encoder)
|
| 176 |
+
# add_hook_to_module(model.base_model, hook)
|
| 177 |
+
return model
|
| 178 |
+
|
| 179 |
+
def _setup_prompt_encoder(self):
|
| 180 |
+
transformer_backbone = None
|
| 181 |
+
for name, module in self.base_model.named_children():
|
| 182 |
+
for param in module.parameters():
|
| 183 |
+
param.requires_grad = False
|
| 184 |
+
if isinstance(module, PreTrainedModel):
|
| 185 |
+
# Make sure to freeze Tranformers model
|
| 186 |
+
if transformer_backbone is None:
|
| 187 |
+
transformer_backbone = module
|
| 188 |
+
self.transformer_backbone_name = name
|
| 189 |
+
|
| 190 |
+
if self.peft_config.num_transformer_submodules is None:
|
| 191 |
+
self.peft_config.num_transformer_submodules = (
|
| 192 |
+
2 if self.peft_config.task_type == TaskType.SEQ_2_SEQ_LM else 1
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
for named_param, value in list(transformer_backbone.named_parameters()):
|
| 196 |
+
if value.shape[0] == self.base_model.config.vocab_size:
|
| 197 |
+
self.word_embeddings = transformer_backbone.get_submodule(named_param.replace(".weight", ""))
|
| 198 |
+
break
|
| 199 |
+
|
| 200 |
+
if self.peft_config.peft_type == PeftType.PROMPT_TUNING:
|
| 201 |
+
prompt_encoder = PromptEmbedding(self.peft_config, self.word_embeddings)
|
| 202 |
+
elif self.peft_config.peft_type == PeftType.P_TUNING:
|
| 203 |
+
prompt_encoder = PromptEncoder(self.peft_config)
|
| 204 |
+
elif self.peft_config.peft_type == PeftType.PREFIX_TUNING:
|
| 205 |
+
prompt_encoder = PrefixEncoder(self.peft_config)
|
| 206 |
+
else:
|
| 207 |
+
raise ValueError("Not supported")
|
| 208 |
+
self.prompt_encoder = prompt_encoder
|
| 209 |
+
self.prompt_tokens = torch.arange(
|
| 210 |
+
self.peft_config.num_virtual_tokens * self.peft_config.num_transformer_submodules
|
| 211 |
+
).long()
|
| 212 |
+
|
| 213 |
+
def get_prompt_embedding_to_save(self):
|
| 214 |
+
"""
|
| 215 |
+
Returns the prompt embedding to save when saving the model. Only applicable when `peft_config.peft_type !=
|
| 216 |
+
PeftType.LORA`.
|
| 217 |
+
"""
|
| 218 |
+
prompt_tokens = self.prompt_tokens.unsqueeze(0).expand(1, -1).to(self.device)
|
| 219 |
+
if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
|
| 220 |
+
prompt_tokens = prompt_tokens[:, : self.peft_config.num_virtual_tokens]
|
| 221 |
+
prompt_embeddings = self.prompt_encoder(prompt_tokens)
|
| 222 |
+
return prompt_embeddings[0].detach().cpu()
|
| 223 |
+
|
| 224 |
+
def get_prompt(self, batch_size):
|
| 225 |
+
"""
|
| 226 |
+
Returns the virtual prompts to use for Peft. Only applicable when `peft_config.peft_type != PeftType.LORA`.
|
| 227 |
+
"""
|
| 228 |
+
prompt_tokens = self.prompt_tokens.unsqueeze(0).expand(batch_size, -1).to(self.device)
|
| 229 |
+
if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
|
| 230 |
+
prompt_tokens = prompt_tokens[:, : self.peft_config.num_virtual_tokens]
|
| 231 |
+
if self.peft_config.inference_mode:
|
| 232 |
+
past_key_values = self.prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)
|
| 233 |
+
else:
|
| 234 |
+
past_key_values = self.prompt_encoder(prompt_tokens)
|
| 235 |
+
past_key_values = past_key_values.view(
|
| 236 |
+
batch_size,
|
| 237 |
+
self.peft_config.num_virtual_tokens,
|
| 238 |
+
self.peft_config.num_layers * 2,
|
| 239 |
+
self.peft_config.num_attention_heads,
|
| 240 |
+
self.peft_config.token_dim // self.peft_config.num_attention_heads,
|
| 241 |
+
)
|
| 242 |
+
if self.peft_config.num_transformer_submodules == 2:
|
| 243 |
+
past_key_values = torch.cat([past_key_values, past_key_values], dim=2)
|
| 244 |
+
past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(
|
| 245 |
+
self.peft_config.num_transformer_submodules * 2
|
| 246 |
+
)
|
| 247 |
+
if TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(self.config.model_type, None) is not None:
|
| 248 |
+
post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING[self.config.model_type]
|
| 249 |
+
past_key_values = post_process_fn(past_key_values)
|
| 250 |
+
return past_key_values
|
| 251 |
+
else:
|
| 252 |
+
if self.peft_config.inference_mode:
|
| 253 |
+
prompts = self.prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)
|
| 254 |
+
else:
|
| 255 |
+
prompts = self.prompt_encoder(prompt_tokens)
|
| 256 |
+
return prompts
|
| 257 |
+
|
| 258 |
+
def print_trainable_parameters(self):
|
| 259 |
+
"""
|
| 260 |
+
Prints the number of trainable parameters in the model.
|
| 261 |
+
"""
|
| 262 |
+
trainable_params = 0
|
| 263 |
+
all_param = 0
|
| 264 |
+
for _, param in self.named_parameters():
|
| 265 |
+
num_params = param.numel()
|
| 266 |
+
# if using DS Zero 3 and the weights are initialized empty
|
| 267 |
+
if num_params == 0 and hasattr(param, "ds_numel"):
|
| 268 |
+
num_params = param.ds_numel
|
| 269 |
+
|
| 270 |
+
all_param += num_params
|
| 271 |
+
if param.requires_grad:
|
| 272 |
+
trainable_params += num_params
|
| 273 |
+
print(
|
| 274 |
+
f"trainable params: {trainable_params:,} || all params: {all_param:,} || trainable: {100 * trainable_params / all_param}%"
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
def __getattr__(self, name: str):
|
| 278 |
+
"""Forward missing attributes to the wrapped module."""
|
| 279 |
+
try:
|
| 280 |
+
return super().__getattr__(name) # defer to nn.Module's logic
|
| 281 |
+
except AttributeError:
|
| 282 |
+
return getattr(self.base_model, name)
|
| 283 |
+
|
| 284 |
+
def forward(self, *args, **kwargs):
|
| 285 |
+
"""
|
| 286 |
+
Forward pass of the model.
|
| 287 |
+
"""
|
| 288 |
+
return self.get_base_model()(*args, **kwargs)
|
| 289 |
+
|
| 290 |
+
@contextmanager
|
| 291 |
+
def disable_adapter(self):
|
| 292 |
+
"""
|
| 293 |
+
Disables the adapter module.
|
| 294 |
+
"""
|
| 295 |
+
if isinstance(self.peft_config, PromptLearningConfig):
|
| 296 |
+
old_forward = self.forward
|
| 297 |
+
self.forward = self.base_model.forward
|
| 298 |
+
else:
|
| 299 |
+
self.base_model.disable_adapter_layers()
|
| 300 |
+
yield
|
| 301 |
+
if isinstance(self.peft_config, PromptLearningConfig):
|
| 302 |
+
self.forward = old_forward
|
| 303 |
+
else:
|
| 304 |
+
self.base_model.enable_adapter_layers()
|
| 305 |
+
|
| 306 |
+
def get_base_model(self):
|
| 307 |
+
"""
|
| 308 |
+
Returns the base model.
|
| 309 |
+
"""
|
| 310 |
+
return self.base_model if isinstance(self.peft_config, PromptLearningConfig) else self.base_model.model
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
|
class PeftModelForSequenceClassification(PeftModel):
    """
    Peft model for sequence classification tasks.
    """

    def __init__(self, model, peft_config: PeftConfig, adapter_name: str = "default"):
        super().__init__(model, peft_config, adapter_name)
        self.modules_to_save = ["classifier", "score", "pooler"]

        # for name, _ in self.base_model.named_children():
        #     if any(module_name in name for module_name in self.modules_to_save):
        #         self.cls_layer_name = name
        #         break
        user_modules = getattr(peft_config, "modules_to_save", None) or []
        default_modules = ["classifier", "score"]
        self.modules_to_save = list(set(user_modules + default_modules))

        # from .rotation import RotationTuner  # import used to check the type
        if isinstance(self.base_model, RotationTuner):
            real_model = self.base_model.model
        else:
            real_model = self.base_model

        # find the actual name of the classification head layer
        for name, _ in real_model.named_children():
            if any(module_name in name for module_name in self.modules_to_save):
                self.cls_layer_name = name

        # make sure the classifier layer is trainable
        _set_trainable(self)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        if "num_items_in_batch" in kwargs:
            kwargs.pop("num_items_in_batch")
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if not isinstance(self.peft_config, PromptLearningConfig):
            return self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                labels=labels,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs,
            )

        batch_size = input_ids.shape[0]
        if attention_mask is not None:
            # concat prompt attention mask
            prefix_attention_mask = torch.ones(batch_size, self.peft_config.num_virtual_tokens).to(self.device)
            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
        if kwargs.get("position_ids", None) is not None:
            warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
            kwargs["position_ids"] = None
        kwargs.update(
            {
                "attention_mask": attention_mask,
                "labels": labels,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
            }
        )

        if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
            return self._prefix_tuning_forward(input_ids=input_ids, **kwargs)
        else:
            if kwargs.get("token_type_ids", None) is not None:
                kwargs["token_type_ids"] = torch.cat(
                    (
                        torch.zeros(batch_size, self.peft_config.num_virtual_tokens).to(self.device),
                        kwargs["token_type_ids"],
                    ),
                    dim=1,
                ).long()
            if inputs_embeds is None:
                inputs_embeds = self.word_embeddings(input_ids)
            prompts = self.get_prompt(batch_size=batch_size)
            prompts = prompts.to(inputs_embeds.dtype)
            inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
            return self.base_model(inputs_embeds=inputs_embeds, **kwargs)

    def _prefix_tuning_forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        batch_size = input_ids.shape[0]
        past_key_values = self.get_prompt(batch_size)
        fwd_params = list(inspect.signature(self.base_model.forward).parameters.keys())
        kwargs.update(
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "inputs_embeds": inputs_embeds,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
                "past_key_values": past_key_values,
            }
        )
        if "past_key_values" in fwd_params:
            return self.base_model(labels=labels, **kwargs)
        else:
            transformer_backbone_name = self.base_model.get_submodule(self.transformer_backbone_name)
            fwd_params = list(inspect.signature(transformer_backbone_name.forward).parameters.keys())
            if "past_key_values" not in fwd_params:
                raise ValueError("Model does not support past key values which are required for prefix tuning.")
            outputs = transformer_backbone_name(**kwargs)
            pooled_output = outputs[1] if len(outputs) > 1 else outputs[0]
            if "dropout" in [name for name, _ in list(self.base_model.named_children())]:
                pooled_output = self.base_model.dropout(pooled_output)
            logits = self.base_model.get_submodule(self.cls_layer_name)(pooled_output)

            loss = None
            if labels is not None:
                if self.config.problem_type is None:
                    if self.base_model.num_labels == 1:
                        self.config.problem_type = "regression"
                    elif self.base_model.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                        self.config.problem_type = "single_label_classification"
                    else:
                        self.config.problem_type = "multi_label_classification"

                if self.config.problem_type == "regression":
                    loss_fct = MSELoss()
                    if self.base_model.num_labels == 1:
                        loss = loss_fct(logits.squeeze(), labels.squeeze())
                    else:
                        loss = loss_fct(logits, labels)
                elif self.config.problem_type == "single_label_classification":
                    loss_fct = CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, self.base_model.num_labels), labels.view(-1))
                elif self.config.problem_type == "multi_label_classification":
                    loss_fct = BCEWithLogitsLoss()
                    loss = loss_fct(logits, labels)
            if not return_dict:
                output = (logits,) + outputs[2:]
                return ((loss,) + output) if loss is not None else output

            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

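As a minimal sketch (not part of the diff) of the head-discovery scan in `__init__` above; `TinyClassifier` is a made-up stand-in for a backbone with a classification head:

import torch.nn as nn

class TinyClassifier(nn.Module):
    # hypothetical stand-in, only to exercise the named_children() scan
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.classifier = nn.Linear(8, 2)

modules_to_save = ["classifier", "score"]
cls_layer_name = next(
    name
    for name, _ in TinyClassifier().named_children()
    if any(module_name in name for module_name in modules_to_save)
)
print(cls_layer_name)  # -> "classifier"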
class PeftModelForCausalLM(PeftModel):
    """
    Peft model for Causal LM.

    Args:
        model ([`PreTrainedModel`]): Base transformer model
        peft_config ([`PeftConfig`]): Peft config.


    Example::

        >>> from transformers import AutoModelForCausalLM
        >>> from peft_local_tensor import PeftModelForCausalLM, get_peft_config
        >>> config = {
        ...     'peft_type': 'PREFIX_TUNING', 'task_type': 'CAUSAL_LM', 'inference_mode': False,
        ...     'num_virtual_tokens': 20, 'token_dim': 1280, 'num_transformer_submodules': 1,
        ...     'num_attention_heads': 20, 'num_layers': 36, 'encoder_hidden_size': 1280,
        ...     'prefix_projection': False, 'postprocess_past_key_value_function': None,
        ... }
        >>> peft_config = get_peft_config(config)
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2-large")
        >>> peft_model = PeftModelForCausalLM(model, peft_config)
        >>> peft_model.print_trainable_parameters()
        trainable params: 1843200 || all params: 775873280 || trainable%: 0.23756456724479544
    """

    def __init__(self, model, peft_config: PeftConfig, adapter_name: str = "default"):
        self.prompt_encoder = None  # set before super().__init__ (original author note: "don't know why")
        self.modules_to_save = None
        super().__init__(model, peft_config, adapter_name)
        self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        if not isinstance(self.peft_config, PromptLearningConfig):
            return self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                labels=labels,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs,
            )

        batch_size = input_ids.shape[0]
        if attention_mask is not None:
            # concat prompt attention mask
            prefix_attention_mask = torch.ones(batch_size, self.peft_config.num_virtual_tokens).to(self.device)
            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        if kwargs.get("position_ids", None) is not None:
            warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
            kwargs["position_ids"] = None
        if kwargs.get("token_type_ids", None) is not None:
            warnings.warn("Token type ids are not supported for parameter efficient tuning. Ignoring token type ids")
            kwargs["token_type_ids"] = None
        kwargs.update(
            {
                "attention_mask": attention_mask,
                "labels": labels,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
            }
        )

        if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
            past_key_values = self.get_prompt(batch_size)
            return self.base_model(input_ids=input_ids, past_key_values=past_key_values, **kwargs)
        else:
            if inputs_embeds is None:
                inputs_embeds = self.word_embeddings(input_ids)
            # concat prompt labels
            if labels is not None:
                prefix_labels = torch.full((batch_size, self.peft_config.num_virtual_tokens), -100).to(self.device)
                kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1)
            prompts = self.get_prompt(batch_size=batch_size)
            prompts = prompts.to(inputs_embeds.dtype)
            inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
            return self.base_model(inputs_embeds=inputs_embeds, **kwargs)

    def generate(self, **kwargs):
        self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
        try:
            if not isinstance(self.peft_config, PromptLearningConfig):
                outputs = self.base_model.generate(**kwargs)
            else:
                if "input_ids" not in kwargs:
                    raise ValueError("input_ids must be provided for Peft model generation")
                if kwargs.get("attention_mask", None) is not None:
                    # concat prompt attention mask
                    prefix_attention_mask = torch.ones(
                        kwargs["input_ids"].shape[0], self.peft_config.num_virtual_tokens
                    ).to(kwargs["input_ids"].device)
                    kwargs["attention_mask"] = torch.cat((prefix_attention_mask, kwargs["attention_mask"]), dim=1)

                if kwargs.get("position_ids", None) is not None:
                    warnings.warn(
                        "Position ids are not supported for parameter efficient tuning. Ignoring position ids."
                    )
                    kwargs["position_ids"] = None
                if kwargs.get("token_type_ids", None) is not None:
                    warnings.warn(
                        "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
                    )
                    kwargs["token_type_ids"] = None

                outputs = self.base_model.generate(**kwargs)
        except:
            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
            raise
        else:
            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
            return outputs

    def prepare_inputs_for_generation(self, *args, **kwargs):
        model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
        if isinstance(self.peft_config, PromptLearningConfig):
            if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
                prefix_attention_mask = torch.ones(
                    model_kwargs["input_ids"].shape[0], self.peft_config.num_virtual_tokens
                ).to(model_kwargs["input_ids"].device)
                model_kwargs["attention_mask"] = torch.cat(
                    (prefix_attention_mask, model_kwargs["attention_mask"]), dim=1
                )

            if model_kwargs["past_key_values"] is None and self.peft_config.peft_type == PeftType.PREFIX_TUNING:
                past_key_values = self.get_prompt(batch_size=model_kwargs["input_ids"].shape[0])
                if self.base_model_torch_dtype is not None:
                    # handle the case for Bloom where it outputs tuple of tuples
                    if isinstance(past_key_values[0], tuple):
                        past_key_values = tuple(
                            tuple(
                                past_key_value.to(self.base_model_torch_dtype)
                                for past_key_value in past_key_value_tuple
                            )
                            for past_key_value_tuple in past_key_values
                        )
                    else:
                        past_key_values = tuple(
                            past_key_value.to(self.base_model_torch_dtype) for past_key_value in past_key_values
                        )

                model_kwargs["past_key_values"] = past_key_values
            else:
                if model_kwargs["past_key_values"] is None:
                    inputs_embeds = self.word_embeddings(model_kwargs["input_ids"])
                    prompts = self.get_prompt(batch_size=model_kwargs["input_ids"].shape[0])
                    prompts = prompts.to(inputs_embeds.dtype)
                    model_kwargs["inputs_embeds"] = torch.cat((prompts, inputs_embeds), dim=1)
                    model_kwargs["input_ids"] = None

        return model_kwargs

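A compact sketch (not part of the diff) of the virtual-token bookkeeping in forward above: attention masks get a block of ones, labels get a block of -100 so the prompt positions are ignored by the cross-entropy loss:

import torch

batch_size, seq_len, num_virtual_tokens = 2, 5, 20
attention_mask = torch.ones(batch_size, seq_len)
labels = torch.randint(0, 100, (batch_size, seq_len))

prefix_attention_mask = torch.ones(batch_size, num_virtual_tokens)
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

prefix_labels = torch.full((batch_size, num_virtual_tokens), -100)  # ignored by CE loss
labels = torch.cat((prefix_labels, labels), dim=1)

print(attention_mask.shape, labels.shape)  # torch.Size([2, 25]) torch.Size([2, 25])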
class PeftModelForSeq2SeqLM(PeftModel):
    """
    Peft model for Seq2Seq LM.
    """

    def __init__(self, model, peft_config: PeftConfig):
        super().__init__(model, peft_config)
        self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation
        self.base_model_prepare_encoder_decoder_kwargs_for_generation = (
            self.base_model._prepare_encoder_decoder_kwargs_for_generation
        )

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        decoder_inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        if not isinstance(self.peft_config, PromptLearningConfig):
            return self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                decoder_inputs_embeds=decoder_inputs_embeds,
                labels=labels,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs,
            )

        batch_size = input_ids.shape[0]
        if decoder_attention_mask is not None:
            # concat prompt attention mask
            prefix_attention_mask = torch.ones(batch_size, self.peft_config.num_virtual_tokens).to(self.device)
            decoder_attention_mask = torch.cat((prefix_attention_mask, decoder_attention_mask), dim=1)

        if kwargs.get("position_ids", None) is not None:
            warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
            kwargs["position_ids"] = None
        if kwargs.get("token_type_ids", None) is not None:
            warnings.warn("Token type ids are not supported for parameter efficient tuning. Ignoring token type ids")
            kwargs["token_type_ids"] = None
        kwargs.update(
            {
                "attention_mask": attention_mask,
                "decoder_attention_mask": decoder_attention_mask,
                "labels": labels,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
            }
        )

        if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
            past_key_values = self.get_prompt(batch_size)
            return self.base_model(
                input_ids=input_ids, decoder_input_ids=decoder_input_ids, past_key_values=past_key_values, **kwargs
            )
        else:
            if inputs_embeds is None:
                inputs_embeds = self.word_embeddings(input_ids)
            if decoder_inputs_embeds is None and decoder_input_ids is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )
                decoder_inputs_embeds = self.word_embeddings(decoder_input_ids)

            if attention_mask is not None:
                # concat prompt attention mask
                prefix_attention_mask = torch.ones(batch_size, self.peft_config.num_virtual_tokens).to(self.device)
                kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)
            # concat prompt labels
            if labels is not None:
                if self.peft_config.num_transformer_submodules == 1:
                    kwargs["labels"] = labels
                elif self.peft_config.num_transformer_submodules == 2:
                    prefix_labels = torch.full((batch_size, self.peft_config.num_virtual_tokens), -100).to(self.device)
                    kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1)
            prompts = self.get_prompt(batch_size=batch_size)
            prompts = prompts.to(inputs_embeds.dtype)
            inputs_embeds = torch.cat((prompts[:, : self.peft_config.num_virtual_tokens], inputs_embeds), dim=1)
            if self.peft_config.num_transformer_submodules == 1:
                return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
            elif self.peft_config.num_transformer_submodules == 2:
                decoder_inputs_embeds = torch.cat(
                    (prompts[:, self.peft_config.num_virtual_tokens :], decoder_inputs_embeds), dim=1
                )
                return self.base_model(
                    inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, **kwargs
                )

    def generate(self, **kwargs):
        self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
        self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
            self._prepare_encoder_decoder_kwargs_for_generation
        )
        try:
            if not isinstance(self.peft_config, PromptLearningConfig):
                outputs = self.base_model.generate(**kwargs)
            else:
                if "input_ids" not in kwargs:
                    raise ValueError("input_ids must be provided for Peft model generation")
                if kwargs.get("position_ids", None) is not None:
                    warnings.warn(
                        "Position ids are not supported for parameter efficient tuning. Ignoring position ids."
                    )
                    kwargs["position_ids"] = None
                if kwargs.get("token_type_ids", None) is not None:
                    warnings.warn(
                        "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
                    )
                    kwargs["token_type_ids"] = None

                if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
                    outputs = self.base_model.generate(**kwargs)
                else:
                    raise NotImplementedError
        except:
            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
            self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
                self.base_model_prepare_encoder_decoder_kwargs_for_generation
            )
            raise
        else:
            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
            self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
                self.base_model_prepare_encoder_decoder_kwargs_for_generation
            )
            return outputs

    def prepare_inputs_for_generation(self, *args, **kwargs):
        model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
        if model_kwargs["past_key_values"] is None and self.peft_config.peft_type == PeftType.PREFIX_TUNING:
            batch_size = model_kwargs["decoder_input_ids"].shape[0]
            past_key_values = self.get_prompt(batch_size)
            model_kwargs["past_key_values"] = past_key_values
        return model_kwargs

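A short sketch (not part of the diff) of the prompt split used above for two transformer submodules (encoder + decoder). That `get_prompt` returns 2 * num_virtual_tokens rows per example in that case is an assumption implied by the slicing:

import torch

num_virtual_tokens, token_dim = 20, 32
prompts = torch.randn(4, 2 * num_virtual_tokens, token_dim)
encoder_prompts = prompts[:, :num_virtual_tokens]   # prepended to inputs_embeds
decoder_prompts = prompts[:, num_virtual_tokens:]   # prepended to decoder_inputs_embeds
print(encoder_prompts.shape, decoder_prompts.shape)  # (4, 20, 32) twice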
class PeftModelForTokenClassification(PeftModel):
    """
    Peft model for token classification tasks.
    """

    def __init__(self, model, peft_config: PeftConfig):
        super().__init__(model, peft_config)
        self.modules_to_save = ["classifier", "score"]

        for name, _ in self.base_model.named_children():
            if any(module_name in name for module_name in self.modules_to_save):
                self.cls_layer_name = name
                break

        # to make sure classifier layer is trainable
        _set_trainable(self)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if not isinstance(self.peft_config, PromptLearningConfig):
            return self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                labels=labels,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs,
            )

        batch_size = input_ids.shape[0]
        if attention_mask is not None:
            # concat prompt attention mask
            prefix_attention_mask = torch.ones(batch_size, self.peft_config.num_virtual_tokens).to(self.device)
            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
        if kwargs.get("position_ids", None) is not None:
            warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
            kwargs["position_ids"] = None
        kwargs.update(
            {
                "attention_mask": attention_mask,
                "labels": labels,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
            }
        )

        if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
            return self._prefix_tuning_forward(input_ids=input_ids, **kwargs)
        else:
            if kwargs.get("token_type_ids", None) is not None:
                kwargs["token_type_ids"] = torch.cat(
                    (
                        torch.zeros(batch_size, self.peft_config.num_virtual_tokens).to(self.device),
                        kwargs["token_type_ids"],
                    ),
                    dim=1,
                ).long()
            if inputs_embeds is None:
                inputs_embeds = self.word_embeddings(input_ids)
            prompts = self.get_prompt(batch_size=batch_size)
            prompts = prompts.to(inputs_embeds.dtype)
            inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
            return self.base_model(inputs_embeds=inputs_embeds, **kwargs)

    def _prefix_tuning_forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        batch_size = input_ids.shape[0]
        past_key_values = self.get_prompt(batch_size)
        fwd_params = list(inspect.signature(self.base_model.forward).parameters.keys())
        kwargs.update(
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "inputs_embeds": inputs_embeds,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
                "past_key_values": past_key_values,
            }
        )
        if "past_key_values" in fwd_params:
            return self.base_model(labels=labels, **kwargs)
        else:
            transformer_backbone_name = self.base_model.get_submodule(self.transformer_backbone_name)
            fwd_params = list(inspect.signature(transformer_backbone_name.forward).parameters.keys())
            if "past_key_values" not in fwd_params:
                raise ValueError("Model does not support past key values which are required for prefix tuning.")
            outputs = transformer_backbone_name(**kwargs)
            sequence_output = outputs[0]
            if "dropout" in [name for name, _ in list(self.base_model.named_children())]:
                sequence_output = self.base_model.dropout(sequence_output)
            logits = self.base_model.get_submodule(self.cls_layer_name)(sequence_output)

            loss = None
            if labels is not None:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

            if not return_dict:
                output = (logits,) + outputs[2:]
                return ((loss,) + output) if loss is not None else output

            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
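The loss dispatch in `_prefix_tuning_forward` of the sequence-classification class above follows the standard `problem_type` convention; a compact sketch of that decision (the helper name `pick_loss_fct` is illustrative only, not part of the repo):

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

def pick_loss_fct(num_labels, labels):
    # illustrative helper mirroring the problem_type branches above
    if num_labels == 1:
        return MSELoss()                       # "regression"
    if labels.dtype in (torch.long, torch.int):
        return CrossEntropyLoss()              # "single_label_classification"
    return BCEWithLogitsLoss()                 # "multi_label_classification"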
nl_tasks/rpeft/rotation/__init__.py ADDED
@@ -0,0 +1,3 @@

from .rotation_config import RotationConfig
from .layer import RotationLayer
from .model import RotationTuner
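A minimal sketch of consuming these exports; only the three re-exported names above are given by the diff, so the import path (package installed as `rpeft` per the tree) is an assumption:

# assumed import path, based on the nl_tasks/rpeft/ layout above
from rpeft.rotation import RotationConfig, RotationLayer, RotationTuner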
nl_tasks/rpeft/rotation/layer.py ADDED
@@ -0,0 +1,412 @@
import torch
import torch.nn as nn
from typing import Optional, Set

from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
import torch.nn.functional as F


def inverse_2x2(matrices):
    # Extract matrix elements
    # matrices[..., 0, 0] corresponds to 'a' in [[a, b], [c, d]]
    a = matrices[..., 0, 0]
    b = matrices[..., 0, 1]
    c = matrices[..., 1, 0]
    d = matrices[..., 1, 1]

    # Compute determinant
    det = a * d - b * c

    # Compute inverse using the formula:
    # inv = (1/det) * [[d, -b], [-c, a]]
    inv_det = 1.0 / det

    # Create output tensor
    inv_matrices = torch.empty_like(matrices)
    inv_matrices[..., 0, 0] = d * inv_det
    inv_matrices[..., 0, 1] = -b * inv_det
    inv_matrices[..., 1, 0] = -c * inv_det
    inv_matrices[..., 1, 1] = a * inv_det

    return inv_matrices

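A quick sanity check, not in the diff, that the closed-form batched 2x2 inverse above agrees with torch.linalg.inv (the well-conditioned test batch is an assumption made for the sketch; `inverse_2x2` is the function above):

import torch

m = torch.randn(5, 2, 2) + 2 * torch.eye(2)    # adding 2*I keeps determinants away from zero
assert torch.allclose(inverse_2x2(m), torch.linalg.inv(m), atol=1e-5)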
class Rotation(nn.Module):
    """
    Rotation layer based on Cayley transformation for parameter-efficient fine-tuning.

    This layer implements orthogonal fine-tuning through Cayley transformation:
        h(x) = (I - A)^{-1} (I + A) x

    where A = XY^T with X = [U; -V] and Y = [V; U]
    """

    def __init__(self, r, dim, T=1.0, num_rotations=4, drop_out=0.1):
        super().__init__()
        self.r = r
        self.T = T
        self.U = nn.Parameter(torch.randn(num_rotations, r, dim) * 0.002, requires_grad=True)
        self.V = nn.Parameter(torch.randn(num_rotations, r, dim) * 0.0, requires_grad=True)
        # self.U = nn.Parameter(torch.empty(num_rotations, r, dim), requires_grad=True)
        # self.V = nn.Parameter(torch.empty(num_rotations, r, dim), requires_grad=True)
        self.num_rotations = num_rotations
        self.dropout = nn.Dropout(drop_out) if drop_out > 0 else nn.Identity()
        self.dim = dim

        # self._post_init()

    # @property
    # def U(self):
    #     # Calculate U = [a, b] whenever self.U is accessed
    #     # This function acts as the 'getter' for self.U
    #     return torch.cat([self.a, self.b], dim=-1)

    # @property
    # def V(self):
    #     # Calculate V = [b, a] whenever self.V is accessed
    #     # This function acts as the 'getter' for self.V
    #     return torch.cat([self.b, self.a], dim=-1)

    # def _post_init(self):
    #     import torch.nn.init as init
    #     import math
    #     # init.kaiming_uniform_(self.U, a=math.sqrt(1), mode='fan_out')
    #     init.normal_(self.U, 0, 1e-2)
    #     with torch.no_grad():
    #         self.V.data.copy_(self.U.data)

    def forward(self, x):
        """
        Apply Cayley transformation to input x.

        A = XY^T where X = [U; -V], Y = [V; U]
        Cayley transformation: h(x) = (I - A)^{-1} (I + A) x

        Uses Woodbury identity for efficient computation:
            (I - XY^T)^{-1} = I + X (I - Y^T X)^{-1} Y^T

        Args:
            x: Input tensor of shape (..., dim)

        Returns:
            Transformed tensor of shape (..., dim)
        """
        x_dtype = x.dtype

        x = self.dropout(x)  # NLU tasks do not use dropout
        X = torch.cat([self.U, -self.V], dim=1)  # Shape: (num_rotations, 2r, dim)
        Y = torch.cat([self.V, self.U], dim=1) * self.T  # Shape: (num_rotations, 2r, dim)

        Y_T_X = torch.matmul(Y, X.transpose(1, 2))  # Shape: (num_rotations, 2r, 2r)
        # I_2r = torch.eye(2 * self.r, device=x.device, dtype=x.dtype).repeat(self.num_rotations, 1, 1)
        I_2r = torch.eye(2 * self.r, device=x.device, dtype=x.dtype).unsqueeze(0)
        I_minus_YX = I_2r - Y_T_X

        if self.r == 1:
            I_minus_YX_inv = inverse_2x2(I_minus_YX)
        else:
            # make it float32
            I_minus_YX = I_minus_YX.to(torch.float32)
            I_minus_YX_inv = torch.linalg.inv(I_minus_YX)  # Shape: (num_rotations, 2r, 2r)
            I_minus_YX_inv = I_minus_YX_inv.to(x_dtype)

        # Yx = torch.einsum("...d,nrd->...nr", x, Y)  # Shape: (batch*seq_len, num_rotations, 2r)

        input_shape = x.shape
        x_flat = x.reshape(-1, self.dim)  # Shape: (B, dim)
        Y_reshape = Y.reshape(-1, self.dim)  # Shape: (nr, d)
        Yx2_flat = F.linear(x_flat, Y_reshape)
        Yx2 = Yx2_flat.view(-1, self.num_rotations, 2 * self.r)
        # is_close = torch.allclose(Yx.view(-1, self.num_rotations, 2*self.r), Yx2, atol=1e-5, rtol=1e-4)
        # if is_close:
        #     pass
        #     # print("✅ SUCCESS 11: The optimized code produces identical results!")
        # else:
        #     print("❌ FAILURE: The results diverge.")

        ###
        # n of (r,r) @ n of (r,1) -> n of r
        # I_minus_YX_inv_Yx = torch.einsum("nrr,...nr->...nr", I_minus_YX_inv, Yx)
        # I_minus_YX_inv_Yx = torch.einsum("...qr,...r->...q", I_minus_YX_inv, Yx)

        Yx2_expanded = Yx2.unsqueeze(-1)
        I_minus_YX_inv_ex = I_minus_YX_inv.unsqueeze(0)
        I_minus_YX_inv_Yx2 = I_minus_YX_inv_ex @ Yx2_expanded
        I_minus_YX_inv_Yx2 = I_minus_YX_inv_Yx2.squeeze(-1)
        # is_close = torch.allclose(I_minus_YX_inv_Yx.view(-1, self.num_rotations, 2*self.r), I_minus_YX_inv_Yx2, atol=1e-4, rtol=1e-3)
        # if is_close:
        #     pass
        #     # print("✅ SUCCESS 22: The optimized code produces identical results!")
        # else:
        #     print("❌ FAILURE: The results diverge.")
        #     exit()

        # n of (r,) @ n of (r,d)
        # second_term = torch.einsum("...nr,nrd->...nd", I_minus_YX_inv_Yx, X)  # Shape: (batch*seq_len, num_rotations, dim)
        # second_term = torch.einsum("...r, ...rd->...d", I_minus_YX_inv_Yx, X)

        # I_minus_YX_inv_Yx_ex = I_minus_YX_inv_Yx2.unsqueeze(-2)
        # X_ex = X.unsqueeze(0)
        # second_term2 = I_minus_YX_inv_Yx_ex @ X_ex
        # second_term2 = second_term2.squeeze(-2)
        # is_close = torch.allclose(second_term, second_term2, atol=1e-5, rtol=1e-4)
        # if is_close:
        #     pass
        #     # print("✅ SUCCESS 33: The optimized code produces identical results!")
        # else:
        #     print("❌ FAILURE: The results diverge.")

        # second_term = second_term.sum(dim=-2)  # Sum over rotations

        coeffs_flat = I_minus_YX_inv_Yx2.reshape(-1, self.num_rotations * 2 * self.r)  # (batch*len, 2*n*r)
        X_flat = X.reshape(-1, self.dim)  # (N*2r, dim)
        second_term3 = torch.matmul(coeffs_flat, X_flat)
        # is_close = torch.allclose(second_term.view(-1, self.dim), second_term3, atol=1e-5, rtol=1e-4)
        # if is_close:
        #     pass
        #     # print("✅ SUCCESS 44: The optimized code produces identical results!")
        # else:
        #     print("❌ FAILURE: The results diverge.")

        # output = x + 2 * second_term  # Shape: (batch*seq_len, dim)
        output = x_flat + 2 * second_term3  # (batch*seq_len, dim)

        # return output.to(x_dtype)
        return output.view(*input_shape)

    def get_delta_weight(self):
        """
        Compute the delta weight matrix induced by the rotation layer.

        Returns:
            Delta weight matrix of shape (dim, dim)
        """

        X = torch.cat([self.U, -self.V], dim=1)  # Shape: (num_rotations, 2r, dim)
        Y = torch.cat([self.V, self.U], dim=1) * self.T  # Shape: (num_rotations, 2r, dim)

        Y_T_X = torch.matmul(Y, X.transpose(1, 2))  # Shape: (num_rotations, 2r, 2r)
        I_2r = torch.eye(2 * self.r, device=X.device, dtype=X.dtype).repeat(self.num_rotations, 1, 1)
        I_minus_YX = I_2r - Y_T_X

        if self.r == 1:
            I_minus_YX_inv = inverse_2x2(I_minus_YX)
            I_minus_YX_inv_Y = torch.einsum("nRr,nrd->nRd", I_minus_YX_inv, Y)  # Shape: (num_rotations, 2r, dim)
            # I_minus_YX_inv_Y = torch.einsum("...rr,...rd->...rd", I_minus_YX_inv, Y)  ## reproduce
        else:
            I_minus_YX_inv_Y = torch.linalg.solve(I_minus_YX.to(torch.float32), Y.to(torch.float32))  # Shape: (num_rotations, 2r, dim)

            # I_minus_YX_inv = torch.linalg.inv(I_minus_YX)
            # I_minus_YX_inv_Y = torch.einsum("...rr,...rd->...rd", I_minus_YX_inv, Y)  ## reproduce

        I_minus_YX_inv_Y = I_minus_YX_inv_Y.to(X.dtype)

        # I_minus_YX_float = I_minus_YX.float()
        # I_minus_YX_inv = torch.linalg.inv(I_minus_YX_float)  # Shape: (num_rotations, 2r, 2r)
        # I_minus_YX_inv = I_minus_YX_inv.to(X.dtype)

        # I_minus_YX_inv_Y = torch.einsum("nRr,nrd->nRd", I_minus_YX_inv, Y)  # Shape: (num_rotations, 2r, dim)

        second_term = torch.einsum("nrd,nrD->ndD", X, I_minus_YX_inv_Y)  # Shape: (num_rotations, dim, dim)
        second_term = second_term.sum(dim=0)
        total_delta_weight = 2 * second_term

        return total_delta_weight

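Not part of the diff, but a useful sanity check on the math above: with X = [U; -V] and Y = [V; U], the induced A = X^T Y = U^T V - V^T U is skew-symmetric, and the Cayley map of any skew-symmetric matrix is orthogonal:

import torch

# sketch: the Cayley map (I - A)^{-1}(I + A) of a skew-symmetric A is orthogonal
dim, r, n = 16, 2, 4
U = torch.randn(n, r, dim) * 0.1
V = torch.randn(n, r, dim) * 0.1
X = torch.cat([U, -V], dim=1)              # (n, 2r, dim)
Y = torch.cat([V, U], dim=1)               # (n, 2r, dim)
A = torch.einsum("nrd,nrD->ndD", X, Y)     # A = X^T Y per rotation; skew-symmetric
I = torch.eye(dim)
R = torch.linalg.solve(I - A, I + A)       # batched (I - A)^{-1}(I + A)
print(torch.allclose(R @ R.transpose(1, 2), I, atol=1e-5))  # True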
class RotationLayer(BaseTunerLayer):
    """
    Adapter-like wrapper that attaches Rotation modules to a base linear layer.
    """

    adapter_layer_names: tuple[str, ...] = ("rotation",)
    other_param_names: tuple[str, ...] = ("r", "T", "num_rotations", "scaling")

    def __init__(self, base_layer: nn.Module, **kwargs):
        # Let BaseTunerLayer do its init (it usually subclasses nn.Module)
        super().__init__()
        # store base layer and adapter containers
        self.base_layer = base_layer
        self.rotation = nn.ModuleDict()  # mapping adapter_name -> Rotation module
        self.scaling = {}  # default scaling per adapter
        self._adapter_config = {}  # store r, T, num_rotations per adapter

        # flags (exposed in a simple way)
        self._disable_adapters = False
        self.merged_adapters: list[str] = []
        self._cast_input_dtype_enabled = True
        self.kwargs = kwargs

        if isinstance(base_layer, nn.Linear):
            self.in_features = base_layer.in_features
            self.out_features = base_layer.out_features
        else:
            raise NotImplementedError("RotationLayer only supports nn.Linear base layers for now.")

    @property
    def _available_adapters(self) -> set[str]:
        return set(self.rotation.keys())

    @property
    def disable_adapters(self) -> bool:
        return self._disable_adapters

    @property
    def merged(self) -> bool:
        return bool(self.merged_adapters)

    @property
    def active_adapters(self) -> list[str]:
        # If some external mechanism sets active adapters, prefer it; else use all added adapters.
        return getattr(self, "_active_adapters", list(self.rotation.keys()))

    def get_base_layer(self) -> nn.Module:
        return self.base_layer

    def _cast_input_dtype(self, x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        if not self._cast_input_dtype_enabled:
            return x
        return x.to(dtype)

    def update_layer(
        self,
        adapter_name: str,
        r: int,
        T: float,
        num_rotations: int,
        drop_out: float,
        **kwargs,
    ):
        """
        Add / update a rotation adapter for this layer.
        """

        if r <= 0:
            raise ValueError(f"r must be positive, got {r}")
        if num_rotations <= 0:
            raise ValueError(f"num_rotations must be positive, got {num_rotations}")

        rot = Rotation(r=r, dim=self.in_features, T=T, num_rotations=num_rotations, drop_out=drop_out)
        self.rotation[adapter_name] = rot
        self.scaling[adapter_name] = 1.0
        self._adapter_config[adapter_name] = {"r": r, "T": T, "num_rotations": num_rotations}

    # (optional) helper to set currently active adapters externally
    def set_active_adapters(self, adapters: Optional[list[str]]):
        if adapters is None:
            if hasattr(self, "_active_adapters"):
                delattr(self, "_active_adapters")
        else:
            self._active_adapters = adapters

class Linear(nn.Module, RotationLayer):
    """
    A linear layer with an integrated rotation layer for parameter-efficient fine-tuning.
    """

    def __init__(
        self,
        base_layer: nn.Linear,
        adapter_name: str,
        r: int,
        T: float,
        num_rotations: int,
        drop_out: float,
        **kwargs,
    ):
        super().__init__()
        RotationLayer.__init__(self, base_layer=base_layer, **kwargs)

        self._active_adapter = adapter_name

        self.update_layer(
            adapter_name=adapter_name,
            r=r,
            T=T,
            num_rotations=num_rotations,
            drop_out=drop_out,
            **kwargs,
        )

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None):
        """
        Merge the adapter effect into the base layer weights:
            W_merged = W @ R
        where R = I + delta (delta returned by get_delta_weight()).
        """
        adapter_names = check_adapters_to_merge(self, adapter_names)

        if not adapter_names:
            return

        base_layer = self.get_base_layer()
        orig_dtype = base_layer.weight.dtype
        # base_layer.weight shape: (out_features, in_features)
        W = base_layer.weight.data  # (out, in)

        for active_adapter in adapter_names:

            if active_adapter not in self._available_adapters:
                continue

            delta_R = self.rotation[active_adapter].get_delta_weight()  # (in, in)
            R = torch.eye(delta_R.size(0), device=delta_R.device, dtype=delta_R.dtype) + delta_R  # (in, in)
            # merged W = W @ R
            merged_W = W.to(R.dtype) @ R

            if safe_merge and not torch.isfinite(merged_W).all():
                raise ValueError("Merging resulted in non-finite weights. Aborting merge.")

            base_layer.weight.data = merged_W.contiguous().to(orig_dtype)

            # mark merged (so unmerge can restore by inverse)
            self.merged_adapters.append(active_adapter)

    def unmerge(self):
        """
        Reverse merges in LIFO order (pop merged adapters and invert R).
        """
        base_layer = self.get_base_layer()
        orig_dtype = base_layer.weight.dtype

        while self.merged_adapters:
            active_adapter = self.merged_adapters.pop()
            if active_adapter not in self._available_adapters:
                continue
            delta_R = self.rotation[active_adapter].get_delta_weight()  # (in, in)
            R = torch.eye(delta_R.size(0), device=delta_R.device, dtype=delta_R.dtype) + delta_R
            R_inv = torch.linalg.inv(R)
            merged_W = base_layer.weight.data.to(R.dtype)
            unmerged_W = merged_W @ R_inv
            base_layer.weight.data = unmerged_W.contiguous().to(orig_dtype)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        x_dtype = x.dtype
        base_layer = self.get_base_layer()

        if self.disable_adapters:
            # if merged, unmerge to ensure base_layer produces original behavior
            if self.merged:
                self.unmerge()
            return base_layer(x, *args, **kwargs).to(x_dtype)

        if self.merged:
            # if merged into base layer, just forward
            return base_layer(x, *args, **kwargs).to(x_dtype)

        # otherwise apply active adapters (transform inputs) then call base layer
        for active_adapter in self.active_adapters:
            if active_adapter not in self.rotation:
                continue
            rotation = self.rotation[active_adapter]
            x = self._cast_input_dtype(x, rotation.U.dtype)
            x = rotation(x)

        return base_layer(x, *args, **kwargs).to(x_dtype)

    def __repr__(self):
        return f"rotation.{super().__repr__()}"
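Before the test file below, a one-screen sketch (not in the diff) of the algebra behind merge/unmerge: fold R = I + delta into W, recover W via R^{-1}. The small random delta stands in for get_delta_weight():

import torch

out_f, in_f = 6, 4
W = torch.randn(out_f, in_f)
delta = 0.01 * torch.randn(in_f, in_f)           # stand-in for get_delta_weight()
R = torch.eye(in_f) + delta
W_merged = W @ R                                 # what Linear.merge does
W_restored = W_merged @ torch.linalg.inv(R)      # what Linear.unmerge does
print(torch.allclose(W, W_restored, atol=1e-5))  # True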
nl_tasks/rpeft/rotation/layer_test.py ADDED
@@ -0,0 +1,296 @@
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from omini.rotation.layer import Linear, Rotation
|
| 4 |
+
|
| 5 |
+
def test_rotation_merge():
|
| 6 |
+
"""
|
| 7 |
+
Test that merging rotation adapter produces the same output as the unmerged version.
|
| 8 |
+
"""
|
| 9 |
+
print("="*60)
|
| 10 |
+
print("Testing Rotation Layer Merge")
|
| 11 |
+
print("="*60)
|
| 12 |
+
|
| 13 |
+
# Set random seed for reproducibility
|
| 14 |
+
torch.manual_seed(42)
|
| 15 |
+
|
| 16 |
+
# Configuration
|
| 17 |
+
in_features = 512
|
| 18 |
+
out_features = 1024
|
| 19 |
+
r = 4
|
| 20 |
+
num_rotations = 4
|
| 21 |
+
T = 1.0
|
| 22 |
+
batch_size = 8
|
| 23 |
+
seq_len = 16
|
| 24 |
+
|
| 25 |
+
# Create base linear layer
|
| 26 |
+
base_layer = nn.Linear(in_features, out_features, bias=True)
|
| 27 |
+
|
| 28 |
+
# Create rotation layer
|
| 29 |
+
rotation_layer = Linear(
|
| 30 |
+
base_layer=base_layer,
|
| 31 |
+
adapter_name="default",
|
| 32 |
+
r=r,
|
| 33 |
+
T=T,
|
| 34 |
+
num_rotations=num_rotations
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
# Create random input
|
| 38 |
+
x = torch.randn(batch_size, seq_len, in_features)
|
| 39 |
+
|
| 40 |
+
# Test 1: Forward pass before merge
|
| 41 |
+
print("\n" + "-"*60)
|
| 42 |
+
print("Test 1: Computing output BEFORE merge")
|
| 43 |
+
print("-"*60)
|
| 44 |
+
rotation_layer.eval()
|
| 45 |
+
with torch.no_grad():
|
| 46 |
+
output_before = rotation_layer(x)
|
| 47 |
+
|
| 48 |
+
print(f"Output shape: {output_before.shape}")
|
| 49 |
+
print(f"Output mean: {output_before.mean().item():.6f}")
|
| 50 |
+
print(f"Output std: {output_before.std().item():.6f}")
|
| 51 |
+
print(f"Output min: {output_before.min().item():.6f}")
|
| 52 |
+
print(f"Output max: {output_before.max().item():.6f}")
|
| 53 |
+
|
| 54 |
+
# Save original weight for verification
|
| 55 |
+
original_weight = base_layer.weight.data.clone()
|
| 56 |
+
|
| 57 |
+
# Test 2: Merge adapter
|
| 58 |
+
print("\n" + "-"*60)
|
| 59 |
+
print("Test 2: Merging adapter")
|
| 60 |
+
print("-"*60)
|
| 61 |
+
rotation_layer.merge(safe_merge=True, adapter_names=["default"])
|
| 62 |
+
print(f"✓ Adapter merged successfully")
|
| 63 |
+
print(f"✓ Merged adapters: {rotation_layer.merged_adapters}")
|
| 64 |
+
|
| 65 |
+
# Check that weights have changed
|
| 66 |
+
weight_diff = (base_layer.weight.data - original_weight).abs().max().item()
|
| 67 |
+
print(f"Max weight change: {weight_diff:.6e}")
|
| 68 |
+
|
| 69 |
+
# Test 3: Forward pass after merge
|
| 70 |
+
print("\n" + "-"*60)
|
| 71 |
+
print("Test 3: Computing output AFTER merge")
|
| 72 |
+
print("-"*60)
|
| 73 |
+
with torch.no_grad():
|
| 74 |
+
output_after = rotation_layer(x)
|
| 75 |
+
|
| 76 |
+
print(f"Output shape: {output_after.shape}")
|
| 77 |
+
print(f"Output mean: {output_after.mean().item():.6f}")
|
| 78 |
+
print(f"Output std: {output_after.std().item():.6f}")
|
| 79 |
+
print(f"Output min: {output_after.min().item():.6f}")
|
| 80 |
+
print(f"Output max: {output_after.max().item():.6f}")
|
| 81 |
+
|
| 82 |
+
# Test 4: Compare outputs
|
| 83 |
+
print("\n" + "-"*60)
|
| 84 |
+
print("Test 4: Comparing outputs")
|
| 85 |
+
print("-"*60)
|
| 86 |
+
|
| 87 |
+
# Compute differences
|
| 88 |
+
abs_diff = (output_after - output_before).abs()
|
| 89 |
+
rel_diff = abs_diff / (output_before.abs() + 1e-8)
|
| 90 |
+
|
| 91 |
+
max_abs_diff = abs_diff.max().item()
|
| 92 |
+
mean_abs_diff = abs_diff.mean().item()
|
| 93 |
+
max_rel_diff = rel_diff.max().item()
|
| 94 |
+
mean_rel_diff = rel_diff.mean().item()
|
| 95 |
+
|
| 96 |
+
print(f"Max absolute difference: {max_abs_diff:.6e}")
|
| 97 |
+
print(f"Mean absolute difference: {mean_abs_diff:.6e}")
|
| 98 |
+
print(f"Max relative difference: {max_rel_diff:.6e}")
|
| 99 |
+
print(f"Mean relative difference: {mean_rel_diff:.6e}")
|
| 100 |
+
|
| 101 |
+
# Check if outputs are close
|
| 102 |
+
atol = 1e-4 # Absolute tolerance
|
| 103 |
+
rtol = 1e-3 # Relative tolerance
|
| 104 |
+
|
| 105 |
+
are_close = torch.allclose(output_before, output_after, atol=atol, rtol=rtol)
|
| 106 |
+
|
| 107 |
+
if are_close:
|
| 108 |
+
print(f"\n✅ PASS: Outputs are identical (within atol={atol}, rtol={rtol})")
|
| 109 |
+
else:
|
| 110 |
+
print(f"\n❌ FAIL: Outputs differ significantly")
|
| 111 |
+
print(f" Expected: atol < {atol}, rtol < {rtol}")
|
| 112 |
+
print(f" Got: max_abs_diff = {max_abs_diff:.6e}, max_rel_diff = {max_rel_diff:.6e}")
|
| 113 |
+
|
| 114 |
+
# Test 5: Unmerge and verify
|
| 115 |
+
print("\n" + "-"*60)
|
| 116 |
+
print("Test 5: Testing unmerge")
|
| 117 |
+
print("-"*60)
|
| 118 |
+
rotation_layer.unmerge()
|
| 119 |
+
print(f"✓ Adapter unmerged")
|
| 120 |
+
print(f"✓ Merged adapters: {rotation_layer.merged_adapters}")
|
| 121 |
+
|
| 122 |
+
with torch.no_grad():
|
| 123 |
+
output_unmerged = rotation_layer(x)
|
| 124 |
+
|
| 125 |
+
unmerge_diff = (output_unmerged - output_before).abs().max().item()
|
| 126 |
+
print(f"Max difference after unmerge: {unmerge_diff:.6e}")
|
| 127 |
+
|
| 128 |
+
unmerge_close = torch.allclose(output_before, output_unmerged, atol=atol, rtol=rtol)
|
| 129 |
+
if unmerge_close:
|
| 130 |
+
print(f"✅ PASS: Unmerge restored original behavior")
|
| 131 |
+
else:
|
| 132 |
+
print(f"❌ FAIL: Unmerge did not restore original behavior")
|
| 133 |
+
|
| 134 |
+
# Test 6: Verify weight restoration
|
| 135 |
+
weight_restored_diff = (base_layer.weight.data - original_weight).abs().max().item()
|
| 136 |
+
print(f"Max weight difference after unmerge: {weight_restored_diff:.6e}")
|
| 137 |
+
|
| 138 |
+
weight_restored = torch.allclose(base_layer.weight.data, original_weight, atol=1e-5)
|
| 139 |
+
if weight_restored:
|
| 140 |
+
print(f"✅ PASS: Original weights restored")
|
| 141 |
+
else:
|
| 142 |
+
print(f"❌ FAIL: Original weights not fully restored")
|
| 143 |
+
|
| 144 |
+
print("\n" + "="*60)
|
| 145 |
+
print("Test Summary")
|
| 146 |
+
print("="*60)
|
| 147 |
+
return are_close and unmerge_close and weight_restored
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def test_multiple_merges():
|
| 151 |
+
"""
|
| 152 |
+
Test merging and unmerging multiple times.
|
| 153 |
+
"""
|
| 154 |
+
print("\n" + "="*60)
|
| 155 |
+
print("Testing Multiple Merge/Unmerge Cycles")
|
| 156 |
+
print("="*60)
|
| 157 |
+
|
| 158 |
+
torch.manual_seed(42)
|
| 159 |
+
|
| 160 |
+
in_features = 256
|
+    out_features = 512
+    r = 4
+    num_rotations = 4
+
+    base_layer = nn.Linear(in_features, out_features, bias=True)
+    rotation_layer = Linear(
+        base_layer=base_layer,
+        adapter_name="default",
+        r=r,
+        T=1.0,
+        num_rotations=num_rotations
+    )
+
+    x = torch.randn(4, 8, in_features)
+    rotation_layer.eval()
+
+    # Get original output
+    with torch.no_grad():
+        original_output = rotation_layer(x)
+
+    # Test multiple cycles
+    all_passed = True
+    for cycle in range(3):
+        print(f"\nCycle {cycle + 1}:")
+
+        # Merge
+        rotation_layer.merge(safe_merge=True)
+        with torch.no_grad():
+            merged_output = rotation_layer(x)
+
+        merge_close = torch.allclose(original_output, merged_output, atol=1e-4, rtol=1e-3)
+        print(f"  Merge: {'✅ PASS' if merge_close else '❌ FAIL'}")
+
+        # Unmerge
+        rotation_layer.unmerge()
+        with torch.no_grad():
+            unmerged_output = rotation_layer(x)
+
+        unmerge_close = torch.allclose(original_output, unmerged_output, atol=1e-4, rtol=1e-3)
+        print(f"  Unmerge: {'✅ PASS' if unmerge_close else '❌ FAIL'}")
+
+        all_passed = all_passed and merge_close and unmerge_close
+
+    return all_passed
+
+
+def test_with_different_dtypes():
+    """
+    Test merging with different data types.
+    """
+    print("\n" + "="*60)
+    print("Testing Different Data Types")
+    print("="*60)
+
+    torch.manual_seed(42)
+
+    dtypes = [torch.float32, torch.float16, torch.bfloat16]
+    all_passed = True
+
+    for dtype in dtypes:
+        print(f"\nTesting with dtype: {dtype}")
+
+        in_features = 256
+        out_features = 512
+        r = 4
+        num_rotations = 4
+
+        base_layer = nn.Linear(in_features, out_features, bias=True)
+        base_layer = base_layer.to(dtype)
+
+        rotation_layer = Linear(
+            base_layer=base_layer,
+            adapter_name="default",
+            r=r,
+            T=1.0,
+            num_rotations=num_rotations
+        )
+        rotation_layer = rotation_layer.to(dtype)
+
+        x = torch.randn(4, 8, in_features, dtype=dtype)
+        rotation_layer.eval()
+
+        with torch.no_grad():
+            output_before = rotation_layer(x)
+            rotation_layer.merge(safe_merge=True)
+            output_after = rotation_layer(x)
+
+        # Adjust tolerances based on dtype
+        if dtype == torch.float32:
+            atol, rtol = 1e-5, 1e-4
+        elif dtype == torch.float16:
+            atol, rtol = 1e-2, 1e-2
+        else:  # bfloat16
+            atol, rtol = 1e-2, 1e-2
+
+        are_close = torch.allclose(output_before, output_after, atol=atol, rtol=rtol)
+
+        if are_close:
+            print(f"  ✅ PASS")
+        else:
+            max_diff = (output_after - output_before).abs().max().item()
+            print(f"  ❌ FAIL (max diff: {max_diff:.6e})")
+
+        all_passed = all_passed and are_close
+
+    return all_passed
+
+
+if __name__ == "__main__":
+    print("\n" + "="*60)
+    print("ROTATION LAYER MERGE TEST SUITE")
+    print("="*60)
+
+    results = {}
+
+    # Run all tests
+    results["basic_merge"] = test_rotation_merge()
+    results["multiple_cycles"] = test_multiple_merges()
+    results["different_dtypes"] = test_with_different_dtypes()
+
+    # Print summary
+    print("\n" + "="*60)
+    print("FINAL SUMMARY")
+    print("="*60)
+
+    for test_name, passed in results.items():
+        status = "✅ PASS" if passed else "❌ FAIL"
+        print(f"{test_name}: {status}")
+
+    all_passed = all(results.values())
+    print("\n" + "="*60)
+    if all_passed:
+        print("🎉 ALL TESTS PASSED!")
+    else:
+        print("⚠️ SOME TESTS FAILED")
+    print("="*60)
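The merge/unmerge checks above hinge on the adapter's rotation being exactly orthogonal, so folding it into the base weight and folding it back is lossless up to float error. A minimal standalone sketch of that property, under the assumption that the layer parameterizes its rotation with a Cayley transform R = (I - A)(I + A)^-1 of a skew-symmetric A (layer.py is the authority on the actual construction):

import torch

torch.manual_seed(0)
r = 4
M = torch.randn(r, r)
A = M - M.T                            # skew-symmetric generator
I = torch.eye(r)
R = (I - A) @ torch.linalg.inv(I + A)  # Cayley transform
# R is orthogonal up to float error, which is why merged and
# unmerged outputs agree within the atol/rtol used in the tests
print(torch.allclose(R.T @ R, I, atol=1e-5))  # True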
nl_tasks/rpeft/rotation/model.py
ADDED
@@ -0,0 +1,392 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from enum import Enum
+from dataclasses import asdict
+from tqdm import tqdm
+
+
+from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists, onload_layer
+
+from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, ModulesToSaveWrapper, _get_submodules
+
+from .layer import RotationLayer, Linear
+
+TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
+
+class RotationTuner(BaseTuner):
+
+    prefix: str = "rotation"
+    tuner_layer_class = RotationLayer
+    target_module_mapping = TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING
+
+
+    @staticmethod
+    def _check_target_module_exists(rotation_config, key: str) -> bool:
+        return check_target_module_exists(rotation_config, key)
+
+    def _create_and_replace(
+        self,
+        rotation_config,
+        adapter_name: str,
+        target: nn.Module,
+        target_name: str,
+        parent: nn.Module,
+        current_key: str,
+        **optional_kwargs,
+    ) -> None:
+        """
+        Create and replace a target module with a rotation-augmented version.
+
+        If the target is already a RotationLayer, a new adapter is added to it;
+        otherwise the target is replaced with a rotation-augmented module.
+
+        Args:
+            rotation_config: Configuration for the rotation adapter
+            adapter_name: Name of the adapter to add
+            target: The target module to augment
+            target_name: Name of the target module
+            parent: Parent module containing the target
+            current_key: Full key path to the current module
+            **optional_kwargs: Additional optional arguments
+
+        Raises:
+            ValueError: If current_key is not provided
+        """
+
+        if current_key is None:
+            raise ValueError("current_key must be provided to create Rotation layer")
+
+        # Check if target is already a RotationLayer
+        if isinstance(target, RotationLayer):
+            target.update_layer(
+                adapter_name=adapter_name,
+                r=rotation_config.r,
+                T=rotation_config.T,
+                num_rotations=rotation_config.num_rotations,
+            )
+        else:
+            # Create new rotation layer
+            new_module = self._create_new_module(
+                rotation_config=rotation_config,
+                adapter_name=adapter_name,
+                target=target,
+                **optional_kwargs,
+            )
+            if new_module is not None:
+                self._replace_module(parent, target_name, new_module, target)
+
+    def _replace_module(self, parent, child_name, new_module, child):
+
+        setattr(parent, child_name, new_module)
+
+        # child layer wraps the original module, unpack it
+        if hasattr(child, "base_layer"):
+            child = child.base_layer
+
+        meta = torch.device("meta")
+        # dispatch to correct device
+        for name, module in new_module.named_modules():
+            if (self.prefix in name) or ("ranknum" in name):
+                if hasattr(child, "qweight"):
+                    weight = child.qweight
+                elif hasattr(child, "W_q"):
+                    weight = child.W_q
+                elif hasattr(child, "weight"):
+                    weight = child.weight
+                elif getattr(child, "in_proj_weight", None) is not None:  # MHA
+                    weight = child.in_proj_weight
+                else:
+                    weight = next(child.parameters())
+                if not any(p.device == meta for p in module.parameters()):
+                    module.to(weight.device)
+
+    def _mark_only_adapters_as_trainable(self, model):
+
+        # First, freeze all parameters
+        for n, p in model.named_parameters():
+            # print(f'{n}, np {p.requires_grad}')
+            if self.prefix not in n:
+                p.requires_grad = False
+            else:
+                p.requires_grad = True
+
+        # Handle bias parameters based on config
+        for active_adapter in self.active_adapters:
+            bias_config = self.peft_config[active_adapter].bias
+
+            if bias_config == "none":
+                continue
+            elif bias_config == "all":
+                # Enable all bias parameters
+                for n, p in model.named_parameters():
+                    if "bias" in n:
+                        p.requires_grad = True
+            elif bias_config == "rotation_only":
+                # Enable only bias in rotation layers
+                for name, m in model.named_modules():
+                    if isinstance(m, RotationLayer):
+                        if hasattr(m, "bias") and m.bias is not None:
+                            m.bias.requires_grad = True
+            else:
+                raise NotImplementedError(
+                    f"Requested bias configuration '{bias_config}' is not implemented. "
+                    f"Supported values: 'none', 'all', 'rotation_only'"
+                )
+
+    @staticmethod
+    def _create_new_module(
+        rotation_config,
+        adapter_name: str,
+        target: nn.Module,
+        **kwargs,
+    ) -> Optional[nn.Module]:
+        """
+        Create a new rotation-augmented module.
+
+        Args:
+            rotation_config: Configuration for the rotation adapter
+            adapter_name: Name of the adapter
+            target: Base module to augment
+            **kwargs: Additional arguments
+
+        Returns:
+            New RotationLayer module wrapping the target, or None if unsupported
+        """
+        if isinstance(target, nn.Linear):
+            return Linear(
+                base_layer=target,
+                adapter_name=adapter_name,
+                r=rotation_config.r,
+                T=rotation_config.T,
+                num_rotations=rotation_config.num_rotations,
+                drop_out=rotation_config.drop_out,
+                **kwargs,
+            )
+        else:
+            # Unsupported layer type
+            print(
+                f"Rotation layer does not support {type(target).__name__} yet. "
+                f"Skipping this module."
+            )
+            return None
+
+
+    def __getattr__(self, name: str):
+        """Forward missing attributes to the wrapped module."""
+        try:
+            return super().__getattr__(name)  # defer to nn.Module's logic
+        except AttributeError:
+            if name == "model":  # see #1892: prevent infinite recursion if class is not initialized
+                raise
+            return getattr(self.model, name)
+
+    def get_peft_config_as_dict(self, inference: bool = False):
+        config_dict = {}
+        for key, value in self.peft_config.items():
+            config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()}
+            if inference:
+                config["inference_mode"] = True
+            config_dict[key] = config
+        return config_dict
+
+
+    def _set_adapter_layers(self, enabled=True):
+        for module in self.model.modules():
+            if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
+                module.enable_adapters(enabled)
+
+    def enable_adapter_layers(self) -> None:
+        """Enable all adapters.
+
+        Call this if you have previously disabled all adapters and want to re-enable them.
+        """
+        self._set_adapter_layers(enabled=True)
+
+    def disable_adapter_layers(self):
+        for active_adapter in self.active_adapters:
+            val = self.peft_config[active_adapter].bias
+            if val != "none":
+                msg = (
+                    f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same "
+                    "output as the base model would without adaptation."
+                )
+                print(msg)
+        self._set_adapter_layers(enabled=False)
+
+    def set_adapter(self, adapter_name, inference_mode):
+        """Set the active adapter(s).
+
+        Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is
+        not desired, use the following code.
+
+        ```py
+        >>> for name, param in model_peft.named_parameters():
+        ...     if ...:  # some check on name (ex. if 'lora' in name)
+        ...         param.requires_grad = False
+        ```
+
+        Args:
+            adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated.
+        """
+        for module in self.model.modules():
+            if isinstance(module, RotationLayer):
+                if module.merged:
+                    print("Adapter cannot be set when the model is merged. Unmerging the model first.")
+                    module.unmerge()
+                module.set_adapter(adapter_name, inference_mode)
+        self.active_adapter = adapter_name
+
+    def merge_adapter(self, adapter_names: Optional[list[str]] = None) -> None:
+        """
+        Merge adapter weights into the base model weights.
+
+        This can speed up inference by eliminating the need for runtime
+        rotation computations.
+
+        Args:
+            adapter_names: List of adapter names to merge. If None, merges all
+                active adapters.
+        """
+        for module in self.model.modules():
+            if isinstance(module, RotationLayer):
+                module.merge(safe_merge=False, adapter_names=adapter_names)
+
+
+    def unmerge_adapter(self) -> None:
+        """
+        Unmerge adapter weights from the base model weights.
+
+        This reverses the merge operation, restoring dynamic adapter behavior.
+        """
+        for module in self.model.modules():
+            if isinstance(module, RotationLayer):
+                module.unmerge()
+
+    @staticmethod
+    def _prepare_adapter_config(peft_config, model_config):
+
+        if peft_config.target_modules is None:
+            if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING:
+                raise ValueError("Please specify `target_modules` in `peft_config`")
+            peft_config.target_modules = set(
+                TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING[model_config["model_type"]]
+            )
+
+        return peft_config
+
+
+    def _check_new_adapter_config(self, config) -> None:
+        """
+        Check the validity of a new adapter configuration.
+
+        Args:
+            config: Configuration to validate
+
+        Raises:
+            ValueError: If configuration is invalid
+        """
+        # Validate rank
+        if config.r <= 0:
+            raise ValueError(f"r must be positive, got {config.r}")
+
+        # Validate num_rotations
+        if config.num_rotations <= 0:
+            raise ValueError(
+                f"num_rotations must be positive, got {config.num_rotations}"
+            )
+
+
+        # Validate bias configuration
+        valid_bias_configs = ["none", "all", "rotation_only"]
+        if hasattr(config, "bias") and config.bias not in valid_bias_configs:
+            raise ValueError(
+                f"Invalid bias configuration '{config.bias}'. "
+                f"Must be one of {valid_bias_configs}"
+            )
+
+
+    def _unload_and_optionally_merge(
+        self,
+        merge=True,
+        progressbar: bool = False,
+        safe_merge: bool = False,
+        adapter_names: Optional[list[str]] = None,
+    ):
+        if merge:
+            self._check_merge_allowed()
+
+        key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
+        desc = "Unloading " + ("and merging " if merge else "") + "model"
+        for key in tqdm(key_list, disable=not progressbar, desc=desc):
+            try:
+                parent, target, target_name = _get_submodules(self.model, key)
+            except AttributeError:
+                continue
+            with onload_layer(target):
+                if hasattr(target, "unload_and_optionally_merge_module"):
+                    # if layers have special unloading method, like MultiheadAttention, use that
+                    unloaded_module = target.unload_and_optionally_merge_module(
+                        merge=merge, safe_merge=safe_merge, adapter_names=adapter_names
+                    )
+                    self._replace_module(parent, target_name, unloaded_module, target)
+                elif hasattr(target, "base_layer"):
+                    if merge:
+                        target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
+                    self._replace_module(parent, target_name, target.get_base_layer(), target)
+
+        return self.model
+
+    def delete_adapter(self, adapter_name: str) -> None:
+        """
+        Deletes an existing adapter.
+
+        Args:
+            adapter_name (str): Name of the adapter to be deleted.
+        """
+        if adapter_name not in list(self.peft_config.keys()):
+            raise ValueError(f"Adapter {adapter_name} does not exist")
+        del self.peft_config[adapter_name]
+
+        key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
+        new_adapter = None
+        for key in key_list:
+            _, target, _ = _get_submodules(self.model, key)
+            if isinstance(target, RotationLayer):
+                target.delete_adapter(adapter_name)
+                if new_adapter is None:
+                    new_adapter = target.active_adapters[:]
+
+        self.active_adapter = new_adapter or []
+        self._delete_auxiliary_adapter(adapter_name, new_active_adapters=new_adapter)
+
+    def merge_and_unload(
+        self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
+    ) -> torch.nn.Module:
+        r"""
+        This method merges the rotation layers into the base model. This is needed if someone wants to use the base
+        model as a standalone model.
+
+        Args:
+            progressbar (`bool`):
+                whether to show a progressbar indicating the unload and merge process
+            safe_merge (`bool`):
+                whether to activate the safe merging check to check if there is any potential Nan in the adapter
+                weights
+            adapter_names (`List[str]`, *optional*):
+                The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
+                to `None`.
+
+        """
+        return self._unload_and_optionally_merge(
+            progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names
+        )
+
+    def unload(self) -> torch.nn.Module:
+        """
+        Gets back the base model by removing all the rotation modules without merging. This gives back the original
+        base model.
+        """
+        return self._unload_and_optionally_merge(merge=False)
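The factory path above can be exercised in isolation. A minimal sketch, assuming the package is importable as `rpeft` (its own modules import it that way) and that `layer.Linear` accepts the keyword arguments passed in `_create_new_module`:

import torch.nn as nn
from rpeft.rotation.model import RotationTuner
from rpeft.rotation.rotation_config import RotationConfig

cfg = RotationConfig(r=4, T=1.0, num_rotations=4, target_modules=["linear"])
base = nn.Linear(256, 512)
# _create_new_module wraps a plain nn.Linear in the rotation-augmented Linear;
# unsupported layer types return None instead of raising
wrapped = RotationTuner._create_new_module(cfg, adapter_name="default", target=base)
print(type(wrapped).__name__)  # "Linear" (the rotation layer from .layer)

Full model injection goes through the usual BaseTuner machinery, which calls `_create_and_replace` for every module key matched by `target_modules`.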
nl_tasks/rpeft/rotation/rotation_config.py
ADDED
@@ -0,0 +1,89 @@
+from dataclasses import dataclass, field
+from typing import List, Optional
+# from peft.config import PeftConfig
+from rpeft.utils import PeftConfig
+from rpeft.utils import PeftType
+
+@dataclass
+class RotationConfig(PeftConfig):
+    """
+    Configuration class for Rotation-based Parameter-Efficient Fine-Tuning.
+
+    This configuration stores all parameters needed to apply the Rotation method
+    (based on Cayley transformation) to a model's linear layers.
+
+    Args:
+        r (`int`):
+            The rank parameter for the low-rank approximation in rotation matrices.
+        T (`float`, *optional*, defaults to 1.0):
+            Temperature parameter for the transformation.
+        num_rotations (`int`, *optional*, defaults to 4):
+            Number of rotation matrices to use in parallel.
+        target_modules (`Union[List[str], str]`):
+            Module names to apply rotation to (e.g., ["q_proj", "v_proj"]).
+        target_modules_to_skip (`Union[List[str], str]`, *optional*):
+            Module names to skip when applying rotation.
+        modules_to_save (`Union[List[str], str]`, *optional*):
+            Modules to save in addition to rotation parameters.
+        layers_to_transform (`Union[List[int], int]`, *optional*):
+            Layers to transform. If None, all layers matching target_modules are transformed.
+    """
+
+    peft_type: str = field(default="ROTATION", init=False)
+    target_modules: Optional[List[str]] = field(
+        default=None,
+        metadata={
+            "help": "List of module names to apply rotation to (e.g., ['q_proj', 'v_proj', 'linear'])"
+        },
+    )
+    target_modules_to_skip: Optional[List[str]] = field(
+        default=None,
+        metadata={"help": "List of module names to skip when applying rotation"},
+    )
+    modules_to_save: Optional[List[str]] = field(
+        default=None,
+        metadata={"help": "List of modules to save in addition to rotation parameters"},
+    )
+    r: int = field(
+        default=8,
+        metadata={"help": "Rank parameter for low-rank approximation"},
+    )
+    T: float = field(
+        default=1.0,
+        metadata={"help": "Temperature parameter for Cayley transformation"},
+    )
+    num_rotations: int = field(
+        default=4,
+        metadata={"help": "Number of rotation matrices to use in parallel"},
+    )
+
+    bias: str = field(
+        default="none",
+        metadata={
+            "help": "Bias training configuration. Options: 'none', 'all', 'rotation_only'"
+        }
+    )
+    layers_to_transform: Optional[List[int]] = field(
+        default=None,
+        metadata={"help": "Layers to transform. If None, all matching layers are transformed"},
+    )
+    drop_out: float = field(
+        default=0.0,
+        metadata={
+            'help': 'input dropout rate'
+        }
+    )
+
+    def __post_init__(self):
+        ##### Diff
+        self.peft_type = PeftType.ROTATION
+        self.target_modules = (
+            set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
+        )
+        self.target_modules_to_skip = (
+            set(self.target_modules_to_skip)
+            if isinstance(self.target_modules_to_skip, list)
+            else self.target_modules_to_skip
+        )
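For reference, a config might be constructed like this (a sketch: the `q_proj`/`v_proj` names are the usual attention projection names and are illustrative, and the `rpeft` import path mirrors how this file imports its own package):

from rpeft.rotation.rotation_config import RotationConfig

config = RotationConfig(
    r=8,                                  # rank of the low-rank parameterization
    T=1.0,                                # temperature of the Cayley transformation
    num_rotations=4,                      # parallel rotation matrices
    target_modules=["q_proj", "v_proj"],
    bias="none",
    drop_out=0.0,
)
# __post_init__ normalizes list-valued module specs to sets
assert isinstance(config.target_modules, set)
print(config.peft_type)  # PeftType.ROTATION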
nl_tasks/rpeft/utils/__init__.py
ADDED
@@ -0,0 +1,29 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all
+
+# coding=utf-8
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .adapters_utils import CONFIG_NAME, WEIGHTS_NAME
+from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType
+from .other import (
+    TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
+    _set_trainable,
+    bloom_model_postprocess_past_key_value,
+    prepare_model_for_int8_training,
+    shift_tokens_right,
+    transpose,
+)
+from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict
nl_tasks/rpeft/utils/adapters_utils.py
ADDED
@@ -0,0 +1,19 @@
+# coding=utf-8
+# Original License:
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+WEIGHTS_NAME = "adapter_model.bin"
+CONFIG_NAME = "adapter_config.json"
+
+# TODO: add automapping and superclass here?
nl_tasks/rpeft/utils/config.py
ADDED
@@ -0,0 +1,220 @@
+# coding=utf-8
+# Original License:
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import enum
+import json
+import os
+from dataclasses import asdict, dataclass, field
+from typing import Optional, Union
+
+from huggingface_hub import hf_hub_download
+from transformers.utils import PushToHubMixin, http_user_agent
+
+from .adapters_utils import CONFIG_NAME
+
+
+class PeftType(str, enum.Enum):
+    PROMPT_TUNING = "PROMPT_TUNING"
+    P_TUNING = "P_TUNING"
+    PREFIX_TUNING = "PREFIX_TUNING"
+    LORA = "LORA"
+    BOTTLENECK = "BOTTLENECK"
+    QUANTA = "QUANTA"
+
+    ROTATION = "ROTATION"
+
+
+
+class TaskType(str, enum.Enum):
+    SEQ_CLS = "SEQ_CLS"
+    SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM"
+    CAUSAL_LM = "CAUSAL_LM"
+    TOKEN_CLS = "TOKEN_CLS"
+
+
+@dataclass
+class PeftConfigMixin(PushToHubMixin):
+    r"""
+    This is the base configuration class for PEFT adapter models. It contains all the methods that are common to all
+    PEFT adapter models. This class inherits from `transformers.utils.PushToHubMixin` which contains the methods to
+    push your model to the Hub. The method `save_pretrained` will save the configuration of your adapter model in a
+    directory. The method `from_pretrained` will load the configuration of your adapter model from a directory.
+
+    Args:
+        peft_type (Union[[`~peft_local_tensor.utils.config.PeftType`], `str`]): The type of Peft method to use.
+    """
+    peft_type: Optional[PeftType] = field(default=None, metadata={"help": "The type of PEFT model."})
+
+    @property
+    def __dict__(self):
+        return asdict(self)
+
+    def to_dict(self):
+        return self.__dict__
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: str, subfolder: Optional[str] = None, **kwargs):
+        r"""
+        This method loads the configuration of your adapter model from a directory.
+
+        Args:
+            pretrained_model_name_or_path (`str`):
+                The directory or the Hub repository id where the configuration is saved.
+            kwargs (additional keyword arguments, *optional*):
+                Additional keyword arguments passed along to the child class initialization.
+        """
+        path = (
+            os.path.join(pretrained_model_name_or_path, subfolder)
+            if subfolder is not None
+            else pretrained_model_name_or_path
+        )
+
+        hf_hub_download_kwargs, class_kwargs, _ = cls._split_kwargs(kwargs)
+        if "user_agent" not in hf_hub_download_kwargs:
+            hf_hub_download_kwargs["user_agent"] = http_user_agent()
+
+        if os.path.isfile(os.path.join(path, CONFIG_NAME)):
+            config_file = os.path.join(path, CONFIG_NAME)
+        else:
+            try:
+                config_file = hf_hub_download(
+                    pretrained_model_name_or_path, CONFIG_NAME, subfolder=subfolder, **hf_hub_download_kwargs
+                )
+            except Exception as exc:
+                raise ValueError(f"Can't find '{CONFIG_NAME}' at '{pretrained_model_name_or_path}'") from exc
+
+        loaded_attributes = cls.from_json_file(config_file)
+        kwargs = {**class_kwargs, **loaded_attributes}
+        kwargs = cls.check_kwargs(**kwargs)
+        return cls.from_peft_type(**kwargs)
+
+    def save_pretrained(self, save_directory, **kwargs):
+        r"""
+        This method saves the configuration of your adapter model in a directory.
+
+        Args:
+            save_directory (`str`):
+                The directory where the configuration will be saved.
+            **kwargs:
+                Additional keyword arguments passed along to the `transformers.utils.PushToHubMixin.push_to_hub`
+                method.
+        """
+        if os.path.isfile(save_directory):
+            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+        os.makedirs(save_directory, exist_ok=True)
+        auto_mapping_dict = kwargs.pop("auto_mapping_dict", None)
+
+        output_dict = self.to_dict()
+        # converting set type to list
+        for key, value in output_dict.items():
+            if isinstance(value, set):
+                output_dict[key] = list(value)
+
+        output_path = os.path.join(save_directory, CONFIG_NAME)
+
+        # Add auto mapping details for custom models.
+        if auto_mapping_dict is not None:
+            output_dict["auto_mapping"] = auto_mapping_dict
+
+        # save it
+        with open(output_path, "w") as writer:
+            writer.write(json.dumps(output_dict, indent=2, sort_keys=True))
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        r"""
+        This method loads the configuration of your adapter model from a directory.
+
+        Args:
+            pretrained_model_name_or_path (`str`):
+                The directory or the hub-id where the configuration is saved.
+            **kwargs:
+                Additional keyword arguments passed along to the child class initialization.
+        """
+        if os.path.isfile(os.path.join(pretrained_model_name_or_path, CONFIG_NAME)):
+            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
+        else:
+            try:
+                config_file = hf_hub_download(pretrained_model_name_or_path, CONFIG_NAME)
+            except Exception:
+                raise ValueError(f"Can't find config.json at '{pretrained_model_name_or_path}'")
+
+        loaded_attributes = cls.from_json_file(config_file)
+
+        config = cls(**kwargs)
+
+        for key, value in loaded_attributes.items():
+            if hasattr(config, key):
+                setattr(config, key, value)
+
+        return config
+
+    @classmethod
+    def from_json_file(cls, path_json_file, **kwargs):
+        r"""
+        Loads a configuration file from a json file.
+
+        Args:
+            path_json_file (`str`):
+                The path to the json file.
+        """
+        with open(path_json_file, "r") as file:
+            json_object = json.load(file)
+
+        return json_object
+
+
+@dataclass
+class PeftConfig(PeftConfigMixin):
+    """
+    This is the base configuration class to store the configuration of a :class:`~peft_local_tensor.PeftModel`.
+
+    Args:
+        peft_type (Union[[`~peft_local_tensor.utils.config.PeftType`], `str`]): The type of Peft method to use.
+        task_type (Union[[`~peft_local_tensor.utils.config.TaskType`], `str`]): The type of task to perform.
+        inference_mode (`bool`, defaults to `False`): Whether to use the Peft model in inference mode.
+    """
+
+    base_model_name_or_path: str = field(default=None, metadata={"help": "The name of the base model to use."})
+    revision: Optional[str] = field(default=None, metadata={"help": "The specific base model version to use."})
+    peft_type: Union[str, PeftType] = field(default=None, metadata={"help": "Peft type"})
+    task_type: Union[str, TaskType] = field(default=None, metadata={"help": "Task type"})
+    inference_mode: bool = field(default=False, metadata={"help": "Whether to use inference mode"})
+
+
+@dataclass
+class PromptLearningConfig(PeftConfig):
+    """
+    This is the base configuration class to store the configuration of a Union[[`~peft_local_tensor.PrefixTuning`],
+    [`~peft_local_tensor.PromptEncoder`], [`~peft_local_tensor.PromptTuning`]].
+
+    Args:
+        num_virtual_tokens (`int`): The number of virtual tokens to use.
+        token_dim (`int`): The hidden embedding dimension of the base transformer model.
+        num_transformer_submodules (`int`): The number of transformer submodules in the base transformer model.
+        num_attention_heads (`int`): The number of attention heads in the base transformer model.
+        num_layers (`int`): The number of layers in the base transformer model.
+    """
+
+    num_virtual_tokens: int = field(default=None, metadata={"help": "Number of virtual tokens"})
+    token_dim: int = field(
+        default=None, metadata={"help": "The hidden embedding dimension of the base transformer model"}
+    )
+    num_transformer_submodules: Optional[int] = field(
+        default=None, metadata={"help": "Number of transformer submodules"}
+    )
+    num_attention_heads: Optional[int] = field(default=None, metadata={"help": "Number of attention heads"})
+    num_layers: Optional[int] = field(default=None, metadata={"help": "Number of transformer layers"})
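Note that the second `from_pretrained` definition shadows the first at class-creation time, so the simpler loader (without `subfolder` support) is the one in effect. A save/load round trip with the methods above, as a minimal sketch (the `rpeft` import path mirrors the package's own imports):

import tempfile
from rpeft.utils import PeftConfig, TaskType

cfg = PeftConfig(peft_type="ROTATION", task_type=TaskType.CAUSAL_LM)
with tempfile.TemporaryDirectory() as tmp:
    cfg.save_pretrained(tmp)                  # writes adapter_config.json
    loaded = PeftConfig.from_pretrained(tmp)  # reads it back field by field
    print(loaded.peft_type, loaded.task_type)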
nl_tasks/rpeft/utils/other.py
ADDED
@@ -0,0 +1,160 @@
+# coding=utf-8
+# Original License:
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+# needed for prefix-tuning of bloom model
+def bloom_model_postprocess_past_key_value(past_key_values):
+    past_key_values = torch.cat(past_key_values)
+    total_layers, batch_size, num_attention_heads, num_virtual_tokens, head_dim = past_key_values.shape
+    keys = past_key_values[: total_layers // 2]
+    keys = keys.transpose(2, 3).reshape(
+        total_layers // 2, batch_size * num_attention_heads, head_dim, num_virtual_tokens
+    )
+    values = past_key_values[total_layers // 2 :]
+    values = values.reshape(total_layers // 2, batch_size * num_attention_heads, num_virtual_tokens, head_dim)
+
+    return tuple(zip(keys, values))
+
+
+def prepare_model_for_int8_training(
+    model, output_embedding_layer_name="lm_head", use_gradient_checkpointing=True, layer_norm_names=["layer_norm"]
+):
+    r"""
+    This method wraps the entire protocol for preparing a model before running a training. This includes:
+        1- casting the layernorm in fp32, 2- making the output embedding layer require grads, 3- upcasting the
+        lm head to fp32
+
+    Args:
+        model (`transformers.PreTrainedModel`):
+            The loaded model from `transformers`
+    """
+    loaded_in_8bit = getattr(model, "is_loaded_in_8bit", False)
+
+    for name, param in model.named_parameters():
+        # freeze base model's layers
+        param.requires_grad = False
+
+        if loaded_in_8bit:
+            # cast layer norm in fp32 for stability for 8bit models
+            if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
+                param.data = param.data.to(torch.float32)
+
+    if loaded_in_8bit and use_gradient_checkpointing:
+        # For backward compatibility
+        if hasattr(model, "enable_input_require_grads"):
+            model.enable_input_require_grads()
+        else:
+
+            def make_inputs_require_grad(module, input, output):
+                output.requires_grad_(True)
+
+            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+        # enable gradient checkpointing for memory efficiency
+        model.gradient_checkpointing_enable()
+
+    if hasattr(model, output_embedding_layer_name):
+        output_embedding_layer = getattr(model, output_embedding_layer_name)
+        input_dtype = output_embedding_layer.weight.dtype
+
+        class CastOutputToFloat(torch.nn.Sequential):
+            r"""
+            Manually cast to the expected dtype of the lm_head, as sometimes there is a final layer norm that is cast
+            in fp32
+
+            """
+
+            def forward(self, x):
+                return super().forward(x.to(input_dtype)).to(torch.float32)
+
+        setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
+
+    return model
+
+
+TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING = {
+    "bloom": bloom_model_postprocess_past_key_value,
+}
+
+
+# copied from transformers.models.bart.modeling_bart
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+    """
+    Shift input ids one token to the right.
+
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input ids
+        pad_token_id (`int`): The id of the `padding` token.
+        decoder_start_token_id (`int`): The id of the `start` token.
+    """
+    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+    shifted_input_ids[:, 0] = decoder_start_token_id
+
+    if pad_token_id is None:
+        raise ValueError("self.model.config.pad_token_id has to be defined.")
+    # replace possible -100 values in labels by `pad_token_id`
+    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+    return shifted_input_ids
+
+
+def _set_trainable(model):
+    if model.modules_to_save is not None:
+        for name, param in model.named_parameters():
+            if any(module_name in name for module_name in model.modules_to_save):
+                param.requires_grad = True
+
+
+def fsdp_auto_wrap_policy(model):
+    import functools
+    import os
+
+    from accelerate import FullyShardedDataParallelPlugin
+    from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
+
+    from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder
+
+    def lambda_policy_fn(module):
+        if (
+            len(list(module.named_children())) == 0
+            and getattr(module, "weight", None) is not None
+            and module.weight.requires_grad
+        ):
+            return True
+        return False
+
+    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
+    transformer_wrap_policy = functools.partial(
+        transformer_auto_wrap_policy,
+        transformer_layer_cls=(
+            PrefixEncoder,
+            PromptEncoder,
+            PromptEmbedding,
+            FullyShardedDataParallelPlugin.get_module_class_from_name(
+                model, os.environ.get("FSDP_TRANSFORMER_CLS_TO_WRAP", "")
+            ),
+        ),
+    )
+
+    auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
+    return auto_wrap_policy
+
+
+def transpose(weight, fan_in_fan_out):
+    return weight.T if fan_in_fan_out else weight
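`shift_tokens_right` is easy to sanity-check by hand (import path mirrors the package's own imports):

import torch
from rpeft.utils.other import shift_tokens_right

input_ids = torch.tensor([[5, 6, 7, 8]])
shifted = shift_tokens_right(input_ids, pad_token_id=0, decoder_start_token_id=2)
print(shifted)  # tensor([[2, 5, 6, 7]]): start token prepended, last token dropped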
nl_tasks/rpeft/utils/save_and_load.py
ADDED
@@ -0,0 +1,166 @@
+# coding=utf-8
+# Original License:
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .config import PeftType
+import warnings
+import torch
+
+def _find_mismatched_keys(
+    model: torch.nn.Module, peft_model_state_dict: dict[str, torch.Tensor], ignore_mismatched_sizes: bool = False
+) -> tuple[dict[str, torch.Tensor], list[tuple[str, tuple[int, ...], tuple[int, ...]]]]:
+    if not ignore_mismatched_sizes:
+        return peft_model_state_dict, []
+
+    mismatched = []
+    state_dict = model.state_dict()
+    for key, tensor in peft_model_state_dict.items():
+        if key not in state_dict:
+            continue
+
+        # see https://github.com/huggingface/transformers/blob/09f9f566de83eef1f13ee83b5a1bbeebde5c80c1/src/transformers/modeling_utils.py#L3858-L3864
+        if (state_dict[key].shape[-1] == 1) and (state_dict[key].numel() * 2 == tensor.numel()):
+            # This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size
+            # differences. Without matching with module type or parameter type it seems like a practical way to detect
+            # valid 4bit weights.
+            continue
+
+        if state_dict[key].shape != tensor.shape:
+            mismatched.append((key, tensor.shape, state_dict[key].shape))
+
+    for key, _, _ in mismatched:
+        del peft_model_state_dict[key]
+
+    return peft_model_state_dict, mismatched
+
+def get_peft_model_state_dict(model, state_dict=None):
+    """
+    Get the state dict of the Peft model.
+
+    Args:
+        model ([`PeftModel`]): The Peft model. When using torch.nn.DistributedDataParallel, DeepSpeed or FSDP,
+            the model should be the underlying model/unwrapped model (i.e. model.module).
+        state_dict (`dict`, *optional*, defaults to `None`):
+            The state dict of the model. If not provided, the state dict of the model
+            will be used.
+    """
+    if state_dict is None:
+        state_dict = model.state_dict()
+    if model.peft_config.peft_type == PeftType.LORA:
+        # to_return = lora_state_dict(model, bias=model.peft_config.bias)
+        # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
+        # to work directly with the state dict, which is necessary when using DeepSpeed or FSDP
+        bias = model.peft_config.bias
+        if bias == "none":
+            to_return = {k: state_dict[k] for k in state_dict if "lora_" in k}
+        elif bias == "all":
+            to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k}
+        elif bias == "lora_only":
+            to_return = {}
+            for k in state_dict:
+                if "lora_" in k:
+                    to_return[k] = state_dict[k]
+                    bias_name = k.split("lora_")[0] + "bias"
+                    if bias_name in state_dict:
+                        to_return[bias_name] = state_dict[bias_name]
+        else:
+            raise NotImplementedError
+    elif model.peft_config.peft_type == PeftType.BOTTLENECK:
+        # return the state dict of the model with Bottleneck adapters
+        bias = model.peft_config.bias
+        if bias == "none":
+            to_return = {k: state_dict[k] for k in state_dict if "adapter_" in k}
+        elif bias == "all":
+            to_return = {k: state_dict[k] for k in state_dict if "adapter_" in k or "bias" in k}
+        elif bias == "adapter_only":
+            to_return = {}
+            for k in state_dict:
+                if "adapter_" in k:
+                    to_return[k] = state_dict[k]
+                    bias_name = k.split("adapter_")[0] + "bias"
+                    if bias_name in state_dict:
+                        to_return[bias_name] = state_dict[bias_name]
+        else:
+            raise NotImplementedError
+
+    elif model.peft_config.peft_type == PeftType.ROTATION:
+        bias = model.peft_config.bias
+        if bias == "none":
+            to_return = {k: state_dict[k] for k in state_dict if "rotation" in k}
+        elif bias == "all":
+            to_return = {k: state_dict[k] for k in state_dict if "rotation" in k or "bias" in k}
+        elif bias == "rotation_only":
+            to_return = {}
+            for k in state_dict:
+                if "rotation" in k:
+                    to_return[k] = state_dict[k]
+                    bias_name = k.split("rotation")[0] + "bias"
+                    if bias_name in state_dict:
+                        to_return[bias_name] = state_dict[bias_name]
+        else:
+            raise NotImplementedError
+
+    elif model.peft_config.is_prompt_learning:
+        to_return = {}
+        if model.peft_config.inference_mode:
+            prompt_embeddings = model.prompt_encoder.embedding.weight
+        else:
+            prompt_embeddings = model.get_prompt_embedding_to_save()
+        to_return["prompt_embeddings"] = prompt_embeddings
+    else:
+        raise NotImplementedError
+    if model.modules_to_save is not None:
+        for key, value in state_dict.items():
+            if any(module_name in key for module_name in model.modules_to_save):
+                to_return[key] = value
+    return to_return
+
+
+def set_peft_model_state_dict(model, peft_model_state_dict,
+                              adapter_name="default",
+                              ignore_mismatched_sizes: bool = False):
+    """
+    Set the state dict of the Peft model.
+
+    Args:
+        model ([`PeftModel`]): The Peft model.
+        peft_model_state_dict (`dict`): The state dict of the Peft model.
+        adapter_name (`str`, *optional*, defaults to `"default"`):
+            The name of the adapter whose state dict should be set.
+    """
+    peft_model_state_dict, mismatched_keys = _find_mismatched_keys(
+        model, peft_model_state_dict, ignore_mismatched_sizes=ignore_mismatched_sizes
+    )
+    if mismatched_keys:
+        # see https://github.com/huggingface/transformers/blob/09f9f566de83eef1f13ee83b5a1bbeebde5c80c1/src/transformers/modeling_utils.py#L4039
+        mismatched_warning = "\n".join(
+            [
+                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+                for key, shape1, shape2 in mismatched_keys
+            ]
+        )
+        msg = (
+            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint "
+            f"and are being ignored because you passed `ignore_mismatched_sizes=True`: {mismatched_warning}."
+        )
+        warnings.warn(msg)
+
+    model.load_state_dict(peft_model_state_dict, strict=False)
+    if model.peft_config.peft_type != PeftType.LORA and model.peft_config.peft_type != PeftType.BOTTLENECK \
+            and model.peft_config.peft_type != PeftType.ROTATION:
+        model.prompt_encoder.embedding.load_state_dict(
+            {"weight": peft_model_state_dict["prompt_embeddings"]}, strict=True
+        )
+    return model
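The key filtering in `get_peft_model_state_dict` can be seen with a synthetic stand-in. Everything below is hypothetical scaffolding, not part of the repo: a tiny module exposing `peft_config`, a `bias` attribute, and `modules_to_save` the way the helper expects.

import torch
import torch.nn as nn
from rpeft.utils import PeftConfig, get_peft_model_state_dict

class TinyRotationModel(nn.Module):
    # hypothetical stand-in: one base weight plus one "rotation" parameter
    def __init__(self):
        super().__init__()
        self.base = nn.Linear(4, 4)
        self.rotation_param = nn.Parameter(torch.zeros(4))
        self.peft_config = PeftConfig(peft_type="ROTATION")
        self.peft_config.bias = "none"   # attribute the helper reads
        self.modules_to_save = None

sd = get_peft_model_state_dict(TinyRotationModel())
print(sorted(sd))  # ['rotation_param'] -- only keys containing "rotation" survive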
nl_tasks/scripts/.nfs80e7f26e00566c630000664a
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+# export OMINI_CONFIG=./config/commonsense.yaml
+export OMINI_CONFIG=./config/commonsense.yaml
+#echo $OMINI_CONFIG
+export TOKENIZERS_PARALLELISM=true
+
+# CUDA Include (/cuda.h)
+CUDA_INCLUDE_PATH="/home/work/miniconda3/envs/allm/include"
+
+# 3. Add into CPATH & CPLUS_INCLUDE_PATH (C/C++ compiler)
+export CPATH=$CPATH:$CUDA_INCLUDE_PATH
+export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$CUDA_INCLUDE_PATH
+# echo "CPATH is set to: $CPATH"
+# echo "CPLUS_INCLUDE_PATH is set to: $CPLUS_INCLUDE_PATH"
+
+#./exps/run_ex15_3ep
+# ./exp_init/run_ex02/ft2
+# ADAPTER = "--model.merge_adapter_path "./exps/run_ex12/ft2" --model.merge_output_path "./exps/run_ex12/merged"
+# export ADAPTER = "--model.merge_adapter_path ./exp395/run_ex01/ft2 --model.merge_output_path ./exp395/run_ex01/merged"
+# accelerate launch --main_process_port 41353 -m src.merge \
+# --config_path $OMINI_CONFIG \
+# --model.merge_adapter_path "./exps/run_ex19_2ep/ft2" --model.merge_output_path "./exps/run_ex19_2ep/merged"
+
+# OUTPUT="./exps/run_ex19_2ep/merged"
+
+# date +"%F %T"
+# python inference/MATH_infer.py --model $OUTPUT
+# date +"%F %T"
+# python inference/gsm8k_infer.py --model $OUTPUT
+# date +"%F %T"
+
+
+# MERGE_DIR="./exps/run_ex24"
+# accelerate launch --main_process_port 41353 -m src.merge \
+# --config_path $OMINI_CONFIG \
+# --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+# date +"%F %T"
+# python inference/MATH_infer.py --model $MERGE_DIR/merged/
+# date +"%F %T"
+# python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+# date +"%F %T"
+
+
+# MERGE_DIR="./exps/run_ex25"
+# accelerate launch --main_process_port 41353 -m src.merge \
+# --config_path $OMINI_CONFIG \
+# --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+# date +"%F %T"
+# python inference/MATH_infer.py --model $MERGE_DIR/merged/
+# date +"%F %T"
+# python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+# date +"%F %T"
+
+
+# MERGE_DIR="./exps/run_ex26"
+# accelerate launch --main_process_port 41353 -m src.merge \
+# --config_path $OMINI_CONFIG \
+# --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+# date +"%F %T"
+# python inference/MATH_infer.py --model $MERGE_DIR/merged/
+# date +"%F %T"
+# python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+# date +"%F %T"
+
+
+# MERGE_DIR="./exps/run_ex27"
+# accelerate launch --main_process_port 41353 -m src.merge \
+# --config_path $OMINI_CONFIG \
+# --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+# date +"%F %T"
+# python inference/MATH_infer.py --model $MERGE_DIR/merged/
+# date +"%F %T"
+# python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+# date +"%F %T"
+
+
+# MERGE_DIR="./exps/run_ex28"
+# accelerate launch --main_process_port 41353 -m src.merge \
+# --config_path $OMINI_CONFIG \
+# --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+# date +"%F %T"
+# python inference/MATH_infer.py --model $MERGE_DIR/merged/
+# date +"%F %T"
+# python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+# date +"%F %T"
+
+
+MERGE_DIR="./exps/run_ex33"
+# accelerate launch --main_process_port 41353 -m src.merge \
+# --config_path $OMINI_CONFIG \
+# --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+# date +"%F %T"
+# python inference/MATH_infer.py --model $MERGE_DIR/merged/
+# date +"%F %T"
+# python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+# date +"%F %T"
+
+
+MERGE_DIR="./exps/run_ex34"
+accelerate launch --main_process_port 41353 -m src.merge \
+--config_path $OMINI_CONFIG \
+--model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+# date +"%F %T"
+# python inference/MATH_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+
+
+
+
nl_tasks/scripts/.nfs80e7f26e0132942e00006649
ADDED
@@ -0,0 +1,341 @@
+export OMINI_CONFIG=./config/commonsense.yaml
+
+#echo $OMINI_CONFIG
+export TOKENIZERS_PARALLELISM=true
+
+# CUDA Include (/cuda.h)
+CUDA_INCLUDE_PATH="/home/work/miniconda3/envs/allm/include"
+
+# 3. Add into CPATH & CPLUS_INCLUDE_PATH (C/C++ compiler)
+export CPATH=$CPATH:$CUDA_INCLUDE_PATH
+export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$CUDA_INCLUDE_PATH
+# echo "CPATH is set to: $CPATH"
+# echo "CPLUS_INCLUDE_PATH is set to: $CPLUS_INCLUDE_PATH"
+
+export WANDB_PROJECT="Llama2_7B_FT_Math40k_2"
+
+export OMP_NUM_THREADS=1
+export MKL_NUM_THREADS=1
+export OPENBLAS_NUM_THREADS=1
+export NUMEXPR_NUM_THREADS=1
+
+date +"%F %T"
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp_init/run_ex01" --trainer_args.learning_rate=1e-3 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text "init|kaim_out_u=v"
+
+# sleep 5
+# echo "1st exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp_init/run_ex02" --trainer_args.learning_rate=1e-3 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text "init|kaim_out_u=v(ratio)"
+
+# sleep 5
+# echo "2nd exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex03" --trainer_args.learning_rate=1e-3 \
+# --rotation_adapter_config.num_rotations 2 --rotation_adapter_config.r 8
+
+# sleep 5
+# echo "3rd exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex04" --trainer_args.learning_rate=2e-3 \
+# --rotation_adapter_config.num_rotations 4 --rotation_adapter_config.r 4
+# sleep 5
+# echo "4th exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex05" --trainer_args.learning_rate=2e-3 \
+# --rotation_adapter_config.num_rotations 2 --rotation_adapter_config.r 8
+# sleep 5
+# echo "5th exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex06" --trainer_args.learning_rate=1e-3 \
+# --rotation_adapter_config.num_rotations 4 --rotation_adapter_config.r 4
+# sleep 5
+# echo "6th exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_exps7" --trainer_args.learning_rate=1e-3 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16
+# sleep 5
+# echo "7th exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex08" --trainer_args.learning_rate=2e-3 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16
+# sleep 5
+# echo "8th exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex09" --trainer_args.learning_rate=2e-3 \
+# --rotation_adapter_config.num_rotations 16 --rotation_adapter_config.r 1
+
+# sleep 5
+# echo "9th exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex10" --trainer_args.learning_rate=2e-3 \
+# --rotation_adapter_config.num_rotations 8 --rotation_adapter_config.r 2
+
+# sleep 5
+# echo "10 exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex11" --trainer_args.learning_rate=1e-2 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16
+
+# sleep 5
+# echo "11 exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex12" --trainer_args.learning_rate=1e-2 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'u=v,def'
+
+# sleep 5
+# echo "12 exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex13" --trainer_args.learning_rate=1e-3 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'u=vkaim'
+
+# sleep 5
+# echo "13 exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex14" --trainer_args.learning_rate=2e-3 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'a,b,def'
+
+# sleep 5
+# echo "14 exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex15" --trainer_args.learning_rate=1e-3 \
+# --rotation_adapter_config.num_rotations 4 --rotation_adapter_config.r 4
+
+# sleep 5
+# echo "15 exp finishes"
+# date +"%F %T"
+
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex17" --trainer_args.learning_rate=1e-3 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
+# --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 500 --data.dataset_split train[:41023] --data.split_ratio 0.02493 \
+# --run_text "dropout|fix_token"
+# sleep 5
+# echo "15 exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex18" --trainer_args.learning_rate=1e-3 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
+# --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 500 --data.dataset_split train[:41023] --data.split_ratio 0.02493 \
+# --run_text "dropout|fix_token"
+# sleep 5
+# echo "158exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex19" --trainer_args.learning_rate=2e-3 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
+# --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 500 --data.dataset_split train[:41023] --data.split_ratio 0.02493 \
+# --run_text "dropout|fix_token"
+# sleep 5
+# echo "19 exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex20" --trainer_args.learning_rate=8e-4 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
+# --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 500 --data.dataset_split train[:41023] --data.split_ratio 0.02493 \
+# --run_text "dropout|fix_token"
+# sleep 5
+# echo "20 exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex21" --trainer_args.learning_rate=1e-3 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
+# --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 500 --data.dataset_split train[:41023] --data.split_ratio 0.02493 \
+# --run_text "dropout|2ep|1e3"
+# sleep 5
+# echo "21 exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+
+# back to official 40k
+# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex22" --trainer_args.learning_rate=1e-3 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
+# --trainer_args.num_train_epochs 2.0 --data.dataset_split train \
+# --run_text "drop0.1|2ep|1e3|40k"
+# sleep 5
+# echo "21 exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex23" --trainer_args.learning_rate=1e-3 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
+# --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
+# --run_text "drop0.1|2ep|1e3|40k"
+# sleep 5
+# echo "21 exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex24" --trainer_args.learning_rate=1e-2 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
+# --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
+# --run_text "drop0.1|2ep|1e2|40k"
+# sleep 5
+# echo "21 exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex25" --trainer_args.learning_rate=2e-3 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
+# --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
+# --run_text "drop0.1|2ep|2e3|40k"
+# sleep 5
+# echo "21 exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex26" --trainer_args.learning_rate=5e-3 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
+# --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
+# --run_text "drop0.1|2ep|5e3|40k"
+# sleep 5
+# echo "21 exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex27" --trainer_args.learning_rate=8e-3 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
+# --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
+# --run_text "drop0.1|2ep|8e3|40k"
+# sleep 5
+# echo "21 exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex28" --trainer_args.learning_rate=2e-2 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
+# --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
+# --run_text "drop0.1|2ep|2e2|40k"
+# sleep 5
+# echo "21 exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex29" --trainer_args.learning_rate=5e-3 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
+# --trainer_args.num_train_epochs 2.0 --data.dataset_split train \
+# --run_text "drop0.1|2ep|initu=v=0.01|40k" --trainer_args.per_device_train_batch_size 48
+# sleep 5
+# echo "29 exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex30" --trainer_args.learning_rate=1e-3 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
+# --trainer_args.num_train_epochs 2.0 --data.dataset_split train \
+# --run_text "drop0.1|2ep|initu=v=0.01|40k" --trainer_args.per_device_train_batch_size 48
+# sleep 5
+# echo "29 exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex31" --trainer_args.learning_rate=5e-3 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
+# --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
+# --run_text "drop0.1|2ep|initu=v=0.01|40k" --trainer_args.per_device_train_batch_size 48
+# sleep 5
+# echo "29 exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex32" --trainer_args.learning_rate=1e-3 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
+# --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
+# --run_text "drop0.1|2ep|initu=v=0.01|40k" --trainer_args.per_device_train_batch_size 48
+# sleep 5
+# echo "29 exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex33" --trainer_args.learning_rate=1e-2 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
+# --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
+# --run_text "drop0.1|2ep|initu=v=0.01|40k" --trainer_args.per_device_train_batch_size 48
+# sleep 5
+# echo "29 exp finishes"
+# date +"%F %T"
+# wandb sync wandb/latest-run
+
+
+accelerate launch --main_process_port 41353 -m src.ft_mathQ \
+--config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex34" --trainer_args.learning_rate=2e-2 \
+--rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
+--trainer_args.num_train_epochs 3.0 --data.dataset_split train \
+--run_text "drop0.1|2ep|initu=v=0.01|40k" --trainer_args.per_device_train_batch_size 48
+sleep 5
+echo "29 exp finishes"
+date +"%F %T"
+wandb sync wandb/latest-run
+
+bash scripts/merge.sh
nl_tasks/scripts/copy train_cms_reasoning.sh
ADDED
@@ -0,0 +1,133 @@
+export OMINI_CONFIG=./config/commonsense.yaml
+
+#echo $OMINI_CONFIG
+export TOKENIZERS_PARALLELISM=true
+
+# CUDA Include (/cuda.h)
+CUDA_INCLUDE_PATH="/home/work/miniconda3/envs/allm/include"
+
+# 3. Add into CPATH & CPLUS_INCLUDE_PATH (C/C++ compiler)
+export CPATH=$CPATH:$CUDA_INCLUDE_PATH
+export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$CUDA_INCLUDE_PATH
+# echo "CPATH is set to: $CPATH"
+# echo "CPLUS_INCLUDE_PATH is set to: $CPLUS_INCLUDE_PATH"
+
+export WANDB_PROJECT="Llama2_7B_FT_Math40k"
+
+export OMP_NUM_THREADS=1
+export MKL_NUM_THREADS=1
+export OPENBLAS_NUM_THREADS=1
+export NUMEXPR_NUM_THREADS=1
+
+date +"%F %T"
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_exps1" --trainer_args.learning_rate=1e-5
+
+# sleep 5
+# echo "1st exp finishes"
+# date +"%F %T"
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_exps2" --trainer_args.learning_rate=2e-5
+
+# sleep 5
+# echo "2nd exp finishes"
+# date +"%F %T"
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_exps3" --trainer_args.learning_rate=5e-5
+
+# sleep 5
+# echo "3rd exp finishes"
+# date +"%F %T"
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_exps4" --trainer_args.learning_rate=1e-4
+
+# sleep 5
+# echo "4th exp finishes"
+# date +"%F %T"
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_exps5" --trainer_args.learning_rate=2e-4
+
+# sleep 5
+# echo "5th exp finishes"
+# date +"%F %T"
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_exps6" --trainer_args.learning_rate=5e-4
+
+# sleep 5
+# echo "6th exp finishes"
+# date +"%F %T"
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_exps7" --trainer_args.learning_rate=8e-4
+
+# sleep 5
+# echo "7th exp finishes"
+# date +"%F %T"
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_exps8" --trainer_args.learning_rate=1e-3
+
+# sleep 5
+# echo "8th exp finishes"
+# date +"%F %T"
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_exps9" --trainer_args.learning_rate=2e-3
+
+# sleep 5
+# echo "9th exp finishes"
+# date +"%F %T"
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_exnr10" --trainer_args.learning_rate=1e-3 \
+# --rotation_adapter_config.num_rotations 8 --rotation_adapter_config.r 2
+
+# sleep 5
+# echo "10 exp finishes"
+# date +"%F %T"
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_exnr11" --trainer_args.learning_rate=1e-3 \
+# --rotation_adapter_config.num_rotations 2 --rotation_adapter_config.r 8
+
+# sleep 5
+# echo "11 exp finishes"
+# date +"%F %T"
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_exnr12" --trainer_args.learning_rate=1e-3 \
+# --rotation_adapter_config.num_rotations 16 --rotation_adapter_config.r 1
+
+# sleep 5
+# echo "12 exp finishes"
+# date +"%F %T"
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_exnr13" --trainer_args.learning_rate=1e-3 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16
+
+# sleep 5
+# echo "13 exp finishes"
+# date +"%F %T"
+
+# accelerate launch --main_process_port 41353 -m src.ft_mathR \
+# --config_path $OMINI_CONFIG --trainer_args.output_dir "./run_all/exnr14" --trainer_args.learning_rate=2e-3 \
+# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16
+
+# sleep 5
+# echo "14 exp finishes"
+# date +"%F %T"
+
+accelerate launch --main_process_port 41353 -m src.ft_mathR \
+--config_path $OMINI_CONFIG --trainer_args.output_dir "./run_all/exnr15" --trainer_args.learning_rate=1e-3 \
+--rotation_adapter_config.num_rotations 4 --rotation_adapter_config.r 4
+
+# sleep 5
+echo "15 exp finishes"
+date +"%F %T"
nl_tasks/scripts/down_math_train.sh
ADDED
@@ -0,0 +1,14 @@
+#!/bin/bash
+
+DATASET_ID="meta-math/MetaMathQA"
+LOCAL_DIR="./data/MetaMathQA"
+
+# echo "Starting download for dataset: $DATASET_ID..."
+huggingface-cli download $DATASET_ID \
+--repo-type dataset \
+--local-dir $LOCAL_DIR \
+--local-dir-use-symlinks False \
+--resume-download \
+--include "*.json"
+
+# echo "Download completed. Data is located at: $LOCAL_DIR"
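For reference, the same download can be driven from Python with the standard `huggingface_hub` API; a minimal sketch mirroring the CLI call above (the paths are taken from the script, nothing else is implied by the upload):

    # Hypothetical Python counterpart of the huggingface-cli call above.
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id="meta-math/MetaMathQA",  # same repo as $DATASET_ID
        repo_type="dataset",
        local_dir="./data/MetaMathQA",   # same target as $LOCAL_DIR
        allow_patterns=["*.json"],       # mirrors --include "*.json"
    )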
nl_tasks/scripts/inference.sh
ADDED
@@ -0,0 +1,14 @@
+# OUTPUT="./exps/run_ex12/merged"
+# OUTPUT="./exp395/run_ex02/merged"
+# OUTPUT="./exp_init/run_ex02/merged"
+OUTPUT="./exps/run_ex15_3ep/merged"
+
+date +"%F %T"
+
+echo 'test math'
+
+date +"%F %T"
+python inference/gsm8k_infer.py --model $OUTPUT
+date +"%F %T"
+python inference/MATH_infer.py --model $OUTPUT
+date +"%F %T"
nl_tasks/scripts/merge.sh
ADDED
@@ -0,0 +1,137 @@
+# export OMINI_CONFIG=./config/commonsense.yaml
+export OMINI_CONFIG=./config/commonsense.yaml
+#echo $OMINI_CONFIG
+export TOKENIZERS_PARALLELISM=true
+
+# CUDA Include (/cuda.h)
+CUDA_INCLUDE_PATH="/home/work/miniconda3/envs/allm/include"
+
+# 3. Add into CPATH & CPLUS_INCLUDE_PATH (C/C++ compiler)
+export CPATH=$CPATH:$CUDA_INCLUDE_PATH
+export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$CUDA_INCLUDE_PATH
+# echo "CPATH is set to: $CPATH"
+# echo "CPLUS_INCLUDE_PATH is set to: $CPLUS_INCLUDE_PATH"
+
+#./exps/run_ex15_3ep
+# ./exp_init/run_ex02/ft2
+# ADAPTER = "--model.merge_adapter_path "./exps/run_ex12/ft2" --model.merge_output_path "./exps/run_ex12/merged"
+# export ADAPTER = "--model.merge_adapter_path ./exp395/run_ex01/ft2 --model.merge_output_path ./exp395/run_ex01/merged"
+# accelerate launch --main_process_port 41353 -m src.merge \
+# --config_path $OMINI_CONFIG \
+# --model.merge_adapter_path "./exps/run_ex19_2ep/ft2" --model.merge_output_path "./exps/run_ex19_2ep/merged"
+
+# OUTPUT="./exps/run_ex19_2ep/merged"
+
+# date +"%F %T"
+# python inference/MATH_infer.py --model $OUTPUT
+# date +"%F %T"
+# python inference/gsm8k_infer.py --model $OUTPUT
+# date +"%F %T"
+
+
+# MERGE_DIR="./exps/run_ex24"
+# accelerate launch --main_process_port 41353 -m src.merge \
+# --config_path $OMINI_CONFIG \
+# --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+# date +"%F %T"
+# python inference/MATH_infer.py --model $MERGE_DIR/merged/
+# date +"%F %T"
+# python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+# date +"%F %T"
+
+
+# MERGE_DIR="./exps/run_ex25"
+# accelerate launch --main_process_port 41353 -m src.merge \
+# --config_path $OMINI_CONFIG \
+# --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+# date +"%F %T"
+# python inference/MATH_infer.py --model $MERGE_DIR/merged/
+# date +"%F %T"
+# python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+# date +"%F %T"
+
+
+# MERGE_DIR="./exps/run_ex26"
+# accelerate launch --main_process_port 41353 -m src.merge \
+# --config_path $OMINI_CONFIG \
+# --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+# date +"%F %T"
+# python inference/MATH_infer.py --model $MERGE_DIR/merged/
+# date +"%F %T"
+# python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+# date +"%F %T"
+
+
+# MERGE_DIR="./exps/run_ex27"
+# accelerate launch --main_process_port 41353 -m src.merge \
+# --config_path $OMINI_CONFIG \
+# --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+# date +"%F %T"
+# python inference/MATH_infer.py --model $MERGE_DIR/merged/
+# date +"%F %T"
+# python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+# date +"%F %T"
+
+
+# MERGE_DIR="./exps/run_ex28"
+# accelerate launch --main_process_port 41353 -m src.merge \
+# --config_path $OMINI_CONFIG \
+# --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+# date +"%F %T"
+# python inference/MATH_infer.py --model $MERGE_DIR/merged/
+# date +"%F %T"
+# python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+# date +"%F %T"
+
+
+# MERGE_DIR="./exps/run_ex33"
+# accelerate launch --main_process_port 41353 -m src.merge \
+# --config_path $OMINI_CONFIG \
+# --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+# date +"%F %T"
+# python inference/MATH_infer.py --model $MERGE_DIR/merged/
+# date +"%F %T"
+# python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+# date +"%F %T"
+
+# 140126
+MERGE_DIR="./exprep/run_ex30"
+accelerate launch --main_process_port 41353 -m src.merge \
+--config_path $OMINI_CONFIG \
+--model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+date +"%F %T"
+python inference/MATH_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+
+MERGE_DIR="./exprep/run_ex31"
+accelerate launch --main_process_port 41353 -m src.merge \
+--config_path $OMINI_CONFIG \
+--model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+date +"%F %T"
+python inference/MATH_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+
+MERGE_DIR="./exprep/run_ex32"
+accelerate launch --main_process_port 41353 -m src.merge \
+--config_path $OMINI_CONFIG \
+--model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+date +"%F %T"
+python inference/MATH_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+
+
nl_tasks/scripts/merge_100k.sh
ADDED
@@ -0,0 +1,100 @@
+# export OMINI_CONFIG=./config/commonsense.yaml
+export OMINI_CONFIG=./config/math395.yaml
+#echo $OMINI_CONFIG
+export TOKENIZERS_PARALLELISM=true
+
+# CUDA Include (/cuda.h)
+CUDA_INCLUDE_PATH="/home/work/miniconda3/envs/allm/include"
+
+# 3. Add into CPATH & CPLUS_INCLUDE_PATH (C/C++ compiler)
+export CPATH=$CPATH:$CUDA_INCLUDE_PATH
+export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$CUDA_INCLUDE_PATH
+# echo "CPATH is set to: $CPATH"
+# echo "CPLUS_INCLUDE_PATH is set to: $CPLUS_INCLUDE_PATH"
+
+#./exps/run_ex15_3ep
+# ./exp_init/run_ex02/ft2
+# ADAPTER = "--model.merge_adapter_path "./exps/run_ex12/ft2" --model.merge_output_path "./exps/run_ex12/merged"
+# export ADAPTER = "--model.merge_adapter_path ./exp395/run_ex01/ft2 --model.merge_output_path ./exp395/run_ex01/merged"
+
+MERGE_DIR="./exp100/run_ex06"
+accelerate launch --main_process_port 41353 -m src.merge \
+--config_path $OMINI_CONFIG \
+--model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+date +"%F %T"
+python inference/MATH_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+
+MERGE_DIR="./exp100/run_ex07"
+accelerate launch --main_process_port 41353 -m src.merge \
+--config_path $OMINI_CONFIG \
+--model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+date +"%F %T"
+python inference/MATH_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+
+
+MERGE_DIR="./exp100/run_ex08"
+accelerate launch --main_process_port 41353 -m src.merge \
+--config_path $OMINI_CONFIG \
+--model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+date +"%F %T"
+python inference/MATH_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+
+
+MERGE_DIR="./exp100/run_ex09"
+accelerate launch --main_process_port 41353 -m src.merge \
+--config_path $OMINI_CONFIG \
+--model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+date +"%F %T"
+python inference/MATH_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+
+
+MERGE_DIR="./exp100/run_ex10"
+accelerate launch --main_process_port 41353 -m src.merge \
+--config_path $OMINI_CONFIG \
+--model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+date +"%F %T"
+python inference/MATH_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+
+
+MERGE_DIR="./exp100/run_ex11"
+accelerate launch --main_process_port 41353 -m src.merge \
+--config_path $OMINI_CONFIG \
+--model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+date +"%F %T"
+python inference/MATH_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+
+
+MERGE_DIR="./exp100/run_ex12"
+accelerate launch --main_process_port 41353 -m src.merge \
+--config_path $OMINI_CONFIG \
+--model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+date +"%F %T"
+python inference/MATH_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
nl_tasks/scripts/merge_math.sh
ADDED
@@ -0,0 +1,31 @@
+# export OMINI_CONFIG=./config/commonsense.yaml
+export OMINI_CONFIG=./config/math395.yaml
+#echo $OMINI_CONFIG
+export TOKENIZERS_PARALLELISM=true
+
+# CUDA Include (/cuda.h)
+CUDA_INCLUDE_PATH="/home/work/miniconda3/envs/allm/include"
+
+# 3. Add into CPATH & CPLUS_INCLUDE_PATH (C/C++ compiler)
+export CPATH=$CPATH:$CUDA_INCLUDE_PATH
+export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$CUDA_INCLUDE_PATH
+# echo "CPATH is set to: $CPATH"
+# echo "CPLUS_INCLUDE_PATH is set to: $CPLUS_INCLUDE_PATH"
+
+#./exps/run_ex15_3ep
+# ./exp_init/run_ex02/ft2
+# ADAPTER = "--model.merge_adapter_path "./exps/run_ex12/ft2" --model.merge_output_path "./exps/run_ex12/merged"
+# export ADAPTER = "--model.merge_adapter_path ./exp395/run_ex01/ft2 --model.merge_output_path ./exp395/run_ex01/merged"
+
+MERGE_DIR="./exp395/run_ex10"
+# accelerate launch --main_process_port 41353 -m src.merge \
+# --config_path $OMINI_CONFIG \
+# --model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+# # OUTPUT="./exp395/run_ex09/merged"
+
+# date +"%F %T"
+# python inference/MATH_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
nl_tasks/scripts/peft_merge.sh
ADDED
@@ -0,0 +1,60 @@
+# export OMINI_CONFIG=./config/commonsense.yaml
+export OMINI_CONFIG=./config/commonsense.yaml
+#echo $OMINI_CONFIG
+export TOKENIZERS_PARALLELISM=true
+
+# CUDA Include (/cuda.h)
+CUDA_INCLUDE_PATH="/home/work/miniconda3/envs/allm/include"
+
+# 3. Add into CPATH & CPLUS_INCLUDE_PATH (C/C++ compiler)
+export CPATH=$CPATH:$CUDA_INCLUDE_PATH
+export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$CUDA_INCLUDE_PATH
+# echo "CPATH is set to: $CPATH"
+# echo "CPLUS_INCLUDE_PATH is set to: $CPLUS_INCLUDE_PATH"
+
+
+
+MERGE_DIR="./expsBOFT/seed43"
+accelerate launch --main_process_port 41353 -m src.peft_merge \
+--config_path $OMINI_CONFIG \
+--model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+date +"%F %T"
+python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+python inference/MATH_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+
+MERGE_DIR="./expsBOFT/seed44"
+accelerate launch --main_process_port 41353 -m src.peft_merge \
+--config_path $OMINI_CONFIG \
+--model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+date +"%F %T"
+python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+python inference/MATH_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+
+
+MERGE_DIR="./expsOFT/seed43"
+accelerate launch --main_process_port 41353 -m src.peft_merge \
+--config_path $OMINI_CONFIG \
+--model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+
+date +"%F %T"
+python inference/MATH_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+
+MERGE_DIR="./expsOFT/seed44"
+accelerate launch --main_process_port 41353 -m src.peft_merge \
+--config_path $OMINI_CONFIG \
+--model.merge_adapter_path $MERGE_DIR/ft2/ --model.merge_output_path $MERGE_DIR/merged/
+date +"%F %T"
+python inference/MATH_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+python inference/gsm8k_infer.py --model $MERGE_DIR/merged/
+date +"%F %T"
+
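`src.peft_merge` itself is defined elsewhere in this upload (nl_tasks/src/peft_merge.py). As a rough sketch under stated assumptions of what a conventional PEFT adapter merge looks like for the OFT/BOFT checkpoints referenced above: the base-model id is an assumption, and `merge_and_unload` is standard `peft` API, not necessarily what peft_merge.py does internally:

    # Hypothetical sketch: fold a trained OFT/BOFT adapter back into its base model.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    base_id = "meta-llama/Llama-2-7b-hf"  # assumed base model (per WANDB project names)
    base = AutoModelForCausalLM.from_pretrained(base_id)
    model = PeftModel.from_pretrained(base, "./expsBOFT/seed43/ft2")  # adapter dir from the script
    merged = model.merge_and_unload()  # adapter weights folded into the base weights
    merged.save_pretrained("./expsBOFT/seed43/merged")
    AutoTokenizer.from_pretrained(base_id).save_pretrained("./expsBOFT/seed43/merged")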
nl_tasks/scripts/train_100math.sh
ADDED
@@ -0,0 +1,184 @@
| 1 |
+
export OMINI_CONFIG=./config/math395.yaml
|
| 2 |
+
|
| 3 |
+
#echo $OMINI_CONFIG
|
| 4 |
+
export TOKENIZERS_PARALLELISM=true
|
| 5 |
+
|
| 6 |
+
# CUDA Include (/cuda.h)
|
| 7 |
+
CUDA_INCLUDE_PATH="/home/work/miniconda3/envs/allm/include"
|
| 8 |
+
|
| 9 |
+
# 3. Add into CPATH & CPLUS_INCLUDE_PATH (C/C++ compiler)
|
| 10 |
+
export CPATH=$CPATH:$CUDA_INCLUDE_PATH
|
| 11 |
+
export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$CUDA_INCLUDE_PATH
|
| 12 |
+
# echo "CPATH is set to: $CPATH"
|
| 13 |
+
# echo "CPLUS_INCLUDE_PATH is set to: $CPLUS_INCLUDE_PATH"
|
| 14 |
+
|
| 15 |
+
export WANDB_PROJECT="Llama2_7B_FT_Math_395k"
|
| 16 |
+
|
| 17 |
+
export OMP_NUM_THREADS=1
|
| 18 |
+
export MKL_NUM_THREADS=1
|
| 19 |
+
export OPENBLAS_NUM_THREADS=1
|
| 20 |
+
export NUMEXPR_NUM_THREADS=1
|
| 21 |
+
|
| 22 |
+
date +"%F %T"
|
| 23 |
+
|
| 24 |
+
# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
|
| 25 |
+
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp100/run_ex01" --trainer_args.learning_rate=5e-3 \
|
| 26 |
+
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
|
| 27 |
+
# --trainer_args.num_train_epochs 2.0 --data.dataset_split train[:100000] \
|
| 28 |
+
# --run_text 'def|o100k'
|
| 29 |
+
|
| 30 |
+
# sleep 5
|
| 31 |
+
# echo "1st exp finishes"
|
| 32 |
+
# date +"%F %T"
|
| 33 |
+
# wandb sync wandb/latest-run
|
| 34 |
+
|
| 35 |
+
# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
|
| 36 |
+
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp100/run_ex02" --trainer_args.learning_rate=2e-2 \
|
| 37 |
+
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
|
| 38 |
+
# --trainer_args.num_train_epochs 2.0 --data.dataset_split train[:100000] \
|
| 39 |
+
# --run_text 'def|o100k'
|
| 40 |
+
|
| 41 |
+
# sleep 5
|
| 42 |
+
# echo "2nd exp finishes"
|
| 43 |
+
# date +"%F %T"
|
| 44 |
+
# wandb sync wandb/latest-run
|
| 45 |
+
|
| 46 |
+
# bash scripts/merge_100k.sh
|
| 47 |
+
|
| 48 |
+
# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
|
| 49 |
+
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp100/run_ex03" --trainer_args.learning_rate=1e-2 \
|
| 50 |
+
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
|
| 51 |
+
# --trainer_args.num_train_epochs 2.0 --data.dataset_split train[:100000] \
|
| 52 |
+
# --run_text 'def|o100k'
|
| 53 |
+
|
| 54 |
+
# sleep 5
|
| 55 |
+
# echo "3rd exp finishes"
|
| 56 |
+
# date +"%F %T"
|
| 57 |
+
# wandb sync wandb/latest-run
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
|
| 61 |
+
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp100/run_ex04" --trainer_args.learning_rate=5e-2 \
|
| 62 |
+
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
|
| 63 |
+
# --trainer_args.num_train_epochs 2.0 --data.dataset_split train[:100000] \
|
| 64 |
+
# --run_text 'def|o100k'
|
| 65 |
+
|
| 66 |
+
# sleep 5
|
| 67 |
+
# echo "4th exp finishes"
|
| 68 |
+
# date +"%F %T"
|
| 69 |
+
# wandb sync wandb/latest-run
|
| 70 |
+
|
| 71 |
+
# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
|
| 72 |
+
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp100/run_ex05" --trainer_args.learning_rate=1e-2 \
|
| 73 |
+
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
|
| 74 |
+
# --trainer_args.num_train_epochs 3.0 --data.dataset_split train[:100000] \
|
| 75 |
+
# --run_text 'def|o100k'
|
| 76 |
+
# sleep 5
|
| 77 |
+
# echo "5th exp finishes"
|
| 78 |
+
# date +"%F %T"
|
| 79 |
+
# wandb sync wandb/latest-run
|
| 80 |
+
# bash scripts/merge_100math.sh
|
| 81 |
+
|
| 82 |
+
accelerate launch --main_process_port 41353 -m src.ft_mathQ \
|
| 83 |
+
--config_path $OMINI_CONFIG --trainer_args.output_dir "./exp100/run_ex06" --trainer_args.learning_rate=1e-2 \
|
| 84 |
+
--rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
|
| 85 |
+
--trainer_args.num_train_epochs 2.0 --data.dataset_split train[:100000] \
|
| 86 |
+
--run_text 'def|o100k|b48' --trainer_args.per_device_train_batch_size 48
|
| 87 |
+
sleep 5
|
| 88 |
+
echo "6th exp finishes"
|
| 89 |
+
date +"%F %T"
|
| 90 |
+
wandb sync wandb/latest-run
|
| 91 |
+
|
| 92 |
+
accelerate launch --main_process_port 41353 -m src.ft_mathQ \
|
| 93 |
+
--config_path $OMINI_CONFIG --trainer_args.output_dir "./exp100/run_ex07" --trainer_args.learning_rate=1e-2 \
|
| 94 |
+
--rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
|
| 95 |
+
--trainer_args.num_train_epochs 3.0 --data.dataset_split train[:100000] \
|
| 96 |
+
--run_text 'def|o100k|b48' --trainer_args.per_device_train_batch_size 48
|
| 97 |
+
sleep 5
|
| 98 |
+
echo "6th exp finishes"
|
| 99 |
+
date +"%F %T"
|
| 100 |
+
wandb sync wandb/latest-run
|

accelerate launch --main_process_port 41353 -m src.ft_mathQ \
--config_path $OMINI_CONFIG --trainer_args.output_dir "./exp100/run_ex08" --trainer_args.learning_rate=2e-2 \
--rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
--trainer_args.num_train_epochs 2.0 --data.dataset_split train[:100000] \
--run_text 'def|o100k|b48' --trainer_args.per_device_train_batch_size 48
sleep 5
echo "8th exp finishes"
date +"%F %T"
wandb sync wandb/latest-run

accelerate launch --main_process_port 41353 -m src.ft_mathQ \
--config_path $OMINI_CONFIG --trainer_args.output_dir "./exp100/run_ex09" --trainer_args.learning_rate=2e-2 \
--rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
--trainer_args.num_train_epochs 3.0 --data.dataset_split train[:100000] \
--run_text 'def|o100k|b48' --trainer_args.per_device_train_batch_size 48
sleep 5
echo "9th exp finishes"
date +"%F %T"
wandb sync wandb/latest-run

accelerate launch --main_process_port 41353 -m src.ft_mathQ \
--config_path $OMINI_CONFIG --trainer_args.output_dir "./exp100/run_ex10" --trainer_args.learning_rate=3e-2 \
--rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
--trainer_args.num_train_epochs 2.0 --data.dataset_split train[:100000] \
--run_text 'def|o100k|b48' --trainer_args.per_device_train_batch_size 48
sleep 5
echo "10th exp finishes"
date +"%F %T"
wandb sync wandb/latest-run



accelerate launch --main_process_port 41353 -m src.ft_mathQ \
--config_path $OMINI_CONFIG --trainer_args.output_dir "./exp100/run_ex11" --trainer_args.learning_rate=8e-3 \
--rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
--trainer_args.num_train_epochs 2.0 --data.dataset_split train[:100000] \
--run_text 'def|o100k|b48' --trainer_args.per_device_train_batch_size 48

sleep 5
echo "11 exp finishes"
date +"%F %T"
wandb sync wandb/latest-run


accelerate launch --main_process_port 41353 -m src.ft_mathQ \
--config_path $OMINI_CONFIG --trainer_args.output_dir "./exp100/run_ex12" --trainer_args.learning_rate=8e-3 \
--rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
--trainer_args.num_train_epochs 3.0 --data.dataset_split train[:100000] \
--run_text 'def|o100k|b48' --trainer_args.per_device_train_batch_size 48

sleep 5
echo "12 exp finishes"
date +"%F %T"
wandb sync wandb/latest-run


bash ./scripts/merge_100k.sh

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex13" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'u=vkaim'

# sleep 5
# echo "13 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex14" --trainer_args.learning_rate=2e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'a,b,def'

# sleep 5
# echo "14 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex15" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 4 --rotation_adapter_config.r 4

# sleep 5
# echo "15 exp finishes"
# date +"%F %T"
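Runs run_ex06 through run_ex12 above differ only in learning rate, epoch count, and output directory; the surrounding sleep/echo/wandb-sync boilerplate is identical. A minimal Python sketch of an equivalent sweep driver (the file name and grid layout are illustrative, not part of this upload; it assumes the same src.ft_mathQ entry point and the OMINI_CONFIG variable the script exports):

# sweep_driver.py -- hypothetical driver, not part of the uploaded scripts.
import os
import subprocess

CONFIG = os.environ["OMINI_CONFIG"]  # same config file the shell script exports
GRID = [  # (run index, learning rate, epochs) mirroring run_ex06..run_ex12
    (6, "1e-2", "2.0"), (7, "1e-2", "3.0"),
    (8, "2e-2", "2.0"), (9, "2e-2", "3.0"),
    (10, "3e-2", "2.0"), (11, "8e-3", "2.0"), (12, "8e-3", "3.0"),
]

for idx, lr, epochs in GRID:
    subprocess.run([
        "accelerate", "launch", "--main_process_port", "41353", "-m", "src.ft_mathQ",
        "--config_path", CONFIG,
        "--trainer_args.output_dir", f"./exp100/run_ex{idx:02d}",
        f"--trainer_args.learning_rate={lr}",
        "--rotation_adapter_config.num_rotations", "1", "--rotation_adapter_config.r", "16",
        "--trainer_args.num_train_epochs", epochs,
        "--data.dataset_split", "train[:100000]",
        "--run_text", "def|o100k|b48",
        "--trainer_args.per_device_train_batch_size", "48",
    ], check=True)                                                      # one training run
    subprocess.run(["wandb", "sync", "wandb/latest-run"], check=True)   # mirror the per-run sync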
nl_tasks/scripts/train_cms_reasoning.sh
ADDED
@@ -0,0 +1,260 @@
export OMINI_CONFIG=./config/commonsense.yaml

#echo $OMINI_CONFIG
export TOKENIZERS_PARALLELISM=true

# CUDA Include (/cuda.h)
CUDA_INCLUDE_PATH="/home/work/miniconda3/envs/allm/include"

# 3. Add into CPATH & CPLUS_INCLUDE_PATH (C/C++ compiler)
export CPATH=$CPATH:$CUDA_INCLUDE_PATH
export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$CUDA_INCLUDE_PATH
# echo "CPATH is set to: $CPATH"
# echo "CPLUS_INCLUDE_PATH is set to: $CPLUS_INCLUDE_PATH"

export WANDB_PROJECT="Llama2_7B_FT_Math40k_2"

export OMP_NUM_THREADS=1
export MKL_NUM_THREADS=1
export OPENBLAS_NUM_THREADS=1
export NUMEXPR_NUM_THREADS=1

date +"%F %T"

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex01" --trainer_args.learning_rate=5e-5 \
# --rotation_adapter_config.num_rotations 2 --rotation_adapter_config.r 8

# sleep 5
# echo "1st exp finishes"
# date +"%F %T"

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex02" --trainer_args.learning_rate=5e-4 \
# --rotation_adapter_config.num_rotations 2 --rotation_adapter_config.r 8

# sleep 5
# echo "2nd exp finishes"
# date +"%F %T"

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex03" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 2 --rotation_adapter_config.r 8

# sleep 5
# echo "3rd exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex04" --trainer_args.learning_rate=2e-3 \
# --rotation_adapter_config.num_rotations 4 --rotation_adapter_config.r 4
# sleep 5
# echo "4th exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex05" --trainer_args.learning_rate=2e-3 \
# --rotation_adapter_config.num_rotations 2 --rotation_adapter_config.r 8
# sleep 5
# echo "5th exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex06" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 4 --rotation_adapter_config.r 4
# sleep 5
# echo "6th exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_exps7" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16
# sleep 5
# echo "7th exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex08" --trainer_args.learning_rate=2e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16
# sleep 5
# echo "8th exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex09" --trainer_args.learning_rate=2e-3 \
# --rotation_adapter_config.num_rotations 16 --rotation_adapter_config.r 1

# sleep 5
# echo "9th exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex10" --trainer_args.learning_rate=2e-3 \
# --rotation_adapter_config.num_rotations 8 --rotation_adapter_config.r 2

# sleep 5
# echo "10 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex11" --trainer_args.learning_rate=1e-2 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16

# sleep 5
# echo "11 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex12" --trainer_args.learning_rate=1e-2 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'u=v,def'

# sleep 5
# echo "12 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run


### continue with 40k
# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex13_3ep" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def' \
# --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 200

# sleep 5
# echo "13 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex14_3ep" --trainer_args.learning_rate=2e-4 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def' \
# --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 200

# sleep 5
# echo "14 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex15_3ep" --trainer_args.learning_rate=5e-4 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def' \
# --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 200

# sleep 5
# echo "15 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex16_3ep" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.05' \
# --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 200

# sleep 5
# echo "15 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex17_3ep" --trainer_args.learning_rate=2e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.05' \
# --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 200

# sleep 5
# echo "15 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex18_2ep" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
# --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 200

# sleep 5
# echo "15 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex19_2ep" --trainer_args.learning_rate=5e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
# --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 200

# sleep 5
# echo "19 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run



# 140126
# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex20_2ep" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
# --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 100 --seed 11

# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex21_2ep" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
# --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 100 --seed 10

# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex24_3ep" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
# --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 100 --seed 10
# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex25_3ep" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
# --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 100 --seed 12

# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex22_2ep" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
# --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 100 --seed 12

# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex23_3ep" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
# --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 100 --seed 11

# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex26_2ep" --trainer_args.learning_rate=8e-4 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
# --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 100 --seed 11

# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex27_2ep" --trainer_args.learning_rate=8e-4 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
# --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 100 --seed 10

# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex28_2ep" --trainer_args.learning_rate=2e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
# --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 100 --seed 11

# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex29_2ep" --trainer_args.learning_rate=2e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
# --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 100 --seed 10

accelerate launch --main_process_port 41353 -m src.ft_mathQ \
--config_path $OMINI_CONFIG --trainer_args.output_dir "./exprep/run_ex30" --trainer_args.learning_rate=8e-4 \
--rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
--trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 100 --seed 20
# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exprep/run_ex31" --trainer_args.learning_rate=8e-4 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
# --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 100 --seed 21
# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exprep/run_ex32" --trainer_args.learning_rate=8e-4 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'init=def|dr0.10' \
# --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 100 --seed 22
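run_ex30 through run_ex32 repeat one configuration (lr 8e-4, r 16, 2 epochs) under seeds 20, 21, and 22. Once all three replicates exist, their final eval losses can be pooled from the trainer_state.json each run writes; a minimal sketch, assuming the standard HF Trainer log_history layout (the helper script itself is hypothetical):

# aggregate_seeds.py -- hypothetical helper; assumes HF Trainer wrote trainer_state.json per run.
import json
import statistics

runs = ["./exprep/run_ex30", "./exprep/run_ex31", "./exprep/run_ex32"]  # seeds 20, 21, 22
losses = []
for run in runs:
    with open(f"{run}/trainer_state.json") as f:
        state = json.load(f)
    # take the last logged eval_loss for this run
    evals = [e["eval_loss"] for e in state["log_history"] if "eval_loss" in e]
    losses.append(evals[-1])

print(f"mean eval_loss = {statistics.mean(losses):.4f} "
      f"+/- {statistics.stdev(losses):.4f} over {len(losses)} seeds")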
nl_tasks/scripts/train_initn40k.sh
ADDED
@@ -0,0 +1,341 @@
export OMINI_CONFIG=./config/commonsense.yaml

#echo $OMINI_CONFIG
export TOKENIZERS_PARALLELISM=true

# CUDA Include (/cuda.h)
CUDA_INCLUDE_PATH="/home/work/miniconda3/envs/allm/include"

# 3. Add into CPATH & CPLUS_INCLUDE_PATH (C/C++ compiler)
export CPATH=$CPATH:$CUDA_INCLUDE_PATH
export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$CUDA_INCLUDE_PATH
# echo "CPATH is set to: $CPATH"
# echo "CPLUS_INCLUDE_PATH is set to: $CPLUS_INCLUDE_PATH"

export WANDB_PROJECT="Llama2_7B_FT_Math40k_2"

export OMP_NUM_THREADS=1
export MKL_NUM_THREADS=1
export OPENBLAS_NUM_THREADS=1
export NUMEXPR_NUM_THREADS=1

date +"%F %T"

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp_init/run_ex01" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text "init|kaim_out_u=v"

# sleep 5
# echo "1st exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp_init/run_ex02" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text "init|kaim_out_u=v(ratio)"

# sleep 5
# echo "2nd exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex03" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 2 --rotation_adapter_config.r 8

# sleep 5
# echo "3rd exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex04" --trainer_args.learning_rate=2e-3 \
# --rotation_adapter_config.num_rotations 4 --rotation_adapter_config.r 4
# sleep 5
# echo "4th exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex05" --trainer_args.learning_rate=2e-3 \
# --rotation_adapter_config.num_rotations 2 --rotation_adapter_config.r 8
# sleep 5
# echo "5th exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex06" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 4 --rotation_adapter_config.r 4
# sleep 5
# echo "6th exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_exps7" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16
# sleep 5
# echo "7th exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex08" --trainer_args.learning_rate=2e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16
# sleep 5
# echo "8th exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex09" --trainer_args.learning_rate=2e-3 \
# --rotation_adapter_config.num_rotations 16 --rotation_adapter_config.r 1

# sleep 5
# echo "9th exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex10" --trainer_args.learning_rate=2e-3 \
# --rotation_adapter_config.num_rotations 8 --rotation_adapter_config.r 2

# sleep 5
# echo "10 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex11" --trainer_args.learning_rate=1e-2 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16

# sleep 5
# echo "11 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex12" --trainer_args.learning_rate=1e-2 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'u=v,def'

# sleep 5
# echo "12 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex13" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'u=vkaim'

# sleep 5
# echo "13 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex14" --trainer_args.learning_rate=2e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'a,b,def'

# sleep 5
# echo "14 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex15" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 4 --rotation_adapter_config.r 4

# sleep 5
# echo "15 exp finishes"
# date +"%F %T"


# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex17" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
# --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 500 --data.dataset_split train[:41023] --data.split_ratio 0.02493 \
# --run_text "dropout|fix_token"
# sleep 5
# echo "15 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run


# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex18" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
# --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 500 --data.dataset_split train[:41023] --data.split_ratio 0.02493 \
# --run_text "dropout|fix_token"
# sleep 5
# echo "158exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex19" --trainer_args.learning_rate=2e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
# --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 500 --data.dataset_split train[:41023] --data.split_ratio 0.02493 \
# --run_text "dropout|fix_token"
# sleep 5
# echo "19 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex20" --trainer_args.learning_rate=8e-4 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
# --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 500 --data.dataset_split train[:41023] --data.split_ratio 0.02493 \
# --run_text "dropout|fix_token"
# sleep 5
# echo "20 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex21" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
# --trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 500 --data.dataset_split train[:41023] --data.split_ratio 0.02493 \
# --run_text "dropout|2ep|1e3"
# sleep 5
# echo "21 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run


# back to official 40k
# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex22" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
# --trainer_args.num_train_epochs 2.0 --data.dataset_split train \
# --run_text "drop0.1|2ep|1e3|40k"
# sleep 5
# echo "21 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex23" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
# --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
# --run_text "drop0.1|2ep|1e3|40k"
# sleep 5
# echo "21 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex24" --trainer_args.learning_rate=1e-2 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
# --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
# --run_text "drop0.1|2ep|1e2|40k"
# sleep 5
# echo "21 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex25" --trainer_args.learning_rate=2e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
# --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
# --run_text "drop0.1|2ep|2e3|40k"
# sleep 5
# echo "21 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex26" --trainer_args.learning_rate=5e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
# --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
# --run_text "drop0.1|2ep|5e3|40k"
# sleep 5
# echo "21 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex27" --trainer_args.learning_rate=8e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
# --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
# --run_text "drop0.1|2ep|8e3|40k"
# sleep 5
# echo "21 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex28" --trainer_args.learning_rate=2e-2 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
# --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
# --run_text "drop0.1|2ep|2e2|40k"
# sleep 5
# echo "21 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run


# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex29" --trainer_args.learning_rate=5e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
# --trainer_args.num_train_epochs 2.0 --data.dataset_split train \
# --run_text "drop0.1|2ep|initu=v=0.01|40k" --trainer_args.per_device_train_batch_size 48
# sleep 5
# echo "29 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex30" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
# --trainer_args.num_train_epochs 2.0 --data.dataset_split train \
# --run_text "drop0.1|2ep|initu=v=0.01|40k" --trainer_args.per_device_train_batch_size 48
# sleep 5
# echo "29 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run


# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex31" --trainer_args.learning_rate=5e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
# --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
# --run_text "drop0.1|2ep|initu=v=0.01|40k" --trainer_args.per_device_train_batch_size 48
# sleep 5
# echo "29 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex32" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
# --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
# --run_text "drop0.1|2ep|initu=v=0.01|40k" --trainer_args.per_device_train_batch_size 48
# sleep 5
# echo "29 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run


# accelerate launch --main_process_port 41353 -m src.ft_mathQ \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex33" --trainer_args.learning_rate=1e-2 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
# --trainer_args.num_train_epochs 3.0 --data.dataset_split train \
# --run_text "drop0.1|2ep|initu=v=0.01|40k" --trainer_args.per_device_train_batch_size 48
# sleep 5
# echo "29 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

accelerate launch --main_process_port 41353 -m src.ft_mathQ \
--config_path $OMINI_CONFIG --trainer_args.output_dir "./exps/run_ex34" --trainer_args.learning_rate=2e-2 \
--rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
--trainer_args.num_train_epochs 3.0 --data.dataset_split train \
--run_text "drop0.1|2ep|initu=v=0.01|40k" --trainer_args.per_device_train_batch_size 48
sleep 5
echo "34 exp finishes"
date +"%F %T"
wandb sync wandb/latest-run

bash scripts/merge.sh
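The script ends by handing off to scripts/merge.sh. Going by the merge_adapter_path / merge_output_path fields in src/config.py, merging presumably folds the trained rotation adapter back into the base weights; the sketch below shows one plausible single-run invocation of src.merge via the same dotted overrides, but the flag names and directory layout are assumptions, not confirmed by this upload:

# merge_one.py -- hypothetical wrapper; flag names mirror ModelConfig in src/config.py.
import os
import subprocess

subprocess.run([
    "python", "-m", "src.merge",
    "--config_path", os.environ["OMINI_CONFIG"],
    "--model.merge_adapter_path", "./exps/run_ex34",        # adapter trained by the run above (assumed layout)
    "--model.merge_output_path", "./exps/run_ex34_merged",  # destination for the merged checkpoint (assumed)
], check=True)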
nl_tasks/scripts/train_math.sh
ADDED
@@ -0,0 +1,162 @@
export OMINI_CONFIG=./config/math395.yaml

#echo $OMINI_CONFIG
export TOKENIZERS_PARALLELISM=true

# CUDA Include (/cuda.h)
CUDA_INCLUDE_PATH="/home/work/miniconda3/envs/allm/include"

# 3. Add into CPATH & CPLUS_INCLUDE_PATH (C/C++ compiler)
export CPATH=$CPATH:$CUDA_INCLUDE_PATH
export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$CUDA_INCLUDE_PATH
# echo "CPATH is set to: $CPATH"
# echo "CPLUS_INCLUDE_PATH is set to: $CPLUS_INCLUDE_PATH"

export WANDB_PROJECT="Llama2_7B_FT_Math_395k"

export OMP_NUM_THREADS=1
export MKL_NUM_THREADS=1
export OPENBLAS_NUM_THREADS=1
export NUMEXPR_NUM_THREADS=1

date +"%F %T"

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex01" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'def'

# sleep 5
# echo "1st exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex02" --trainer_args.learning_rate=5e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16

# sleep 5
# echo "2nd exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex03" --trainer_args.learning_rate=2e-4 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16

# sleep 5
# echo "3rd exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex04" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16

# sleep 5
# echo "4rd exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex05" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
# --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 500 --data.dataset_split train[:101011] --data.split_ratio 0.01
# sleep 5
# echo "5th exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex06" --trainer_args.learning_rate=2e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
# --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 500 --data.dataset_split train[:101011] --data.split_ratio 0.01
# sleep 5
# echo "6th exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex0s7" --trainer_args.learning_rate=5e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
# --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 500 --data.dataset_split train[:101011] --data.split_ratio 0.01
# sleep 5
# echo "7th exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex08" --trainer_args.learning_rate=1e-4 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
# --trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 500 --data.dataset_split train[:101011] --data.split_ratio 0.01 \
# --trainer_args.per_device_train_batch_size 32 --run_text 'u2e2,def'
# sleep 5
# echo "8th exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

accelerate launch --main_process_port 41353 -m src.ft_mathR \
--config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex09" --trainer_args.learning_rate=2e-3 \
--rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
--trainer_args.num_train_epochs 3.0 --trainer_args.eval_steps 500 --data.dataset_split train[:101011] --data.split_ratio 0.01 \
--trainer_args.per_device_train_batch_size 32 --run_text 'init=def|fix_token'

sleep 5
echo "9th exp finishes"
date +"%F %T"
wandb sync wandb/latest-run

accelerate launch --main_process_port 41353 -m src.ft_mathR \
--config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex10" --trainer_args.learning_rate=2e-3 \
--rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 \
--trainer_args.num_train_epochs 2.0 --trainer_args.eval_steps 500 --data.dataset_split train[:101011] --data.split_ratio 0.01 \
--trainer_args.per_device_train_batch_size 32 --run_text "init=def|fix_token"

sleep 5
echo "10 exp finishes"
date +"%F %T"
wandb sync wandb/latest-run


# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex11" --trainer_args.learning_rate=1e-2 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16

# sleep 5
# echo "11 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex12" --trainer_args.learning_rate=1e-2 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'u=v,def'

# sleep 5
# echo "12 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex13" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'u=vkaim'

# sleep 5
# echo "13 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex14" --trainer_args.learning_rate=2e-3 \
# --rotation_adapter_config.num_rotations 1 --rotation_adapter_config.r 16 --run_text 'a,b,def'

# sleep 5
# echo "14 exp finishes"
# date +"%F %T"
# wandb sync wandb/latest-run

# accelerate launch --main_process_port 41353 -m src.ft_mathR \
# --config_path $OMINI_CONFIG --trainer_args.output_dir "./exp395/run_ex15" --trainer_args.learning_rate=1e-3 \
# --rotation_adapter_config.num_rotations 4 --rotation_adapter_config.r 4

# sleep 5
# echo "15 exp finishes"
# date +"%F %T"
nl_tasks/setup.py
ADDED
@@ -0,0 +1,28 @@
# Copyright [2024] [Zhuo Chen]

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import setuptools

setuptools.setup(
    name="rpeft",
    version="0.0.2",
    author="SDML",
    packages=setuptools.find_packages(),
    install_requires=[
        'transformers>=4.0.0',
        'torch>=2.0.0'
    ],
    python_requires='>=3.9',
)
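setup.py registers the in-tree rpeft package (version 0.0.2) with transformers and torch as its only hard dependencies. A minimal post-install sanity check, assuming the package has been installed from nl_tasks (e.g. with an editable pip install); the script name is illustrative:

# check_install.py -- tiny sanity check for the rpeft install.
import importlib.metadata

# setup.py above pins version="0.0.2"; a mismatch means a stale install.
print(importlib.metadata.version("rpeft"))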
nl_tasks/src/bb.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
nl_tasks/src/cc.ipynb
ADDED
File without changes
nl_tasks/src/config.py
ADDED
@@ -0,0 +1,183 @@
from dataclasses import dataclass, field, fields, asdict
from typing import Optional, List, Literal, Dict, Any, Union
from transformers import TrainingArguments, Trainer
from omegaconf import OmegaConf
import sys


@dataclass
class ModelConfig:
    model_name: str = ""
    dropout: float = 0.0
    model_max_seq_length: int = field(default=512)
    data_collator_mode: str = field(default='fixed', metadata={"help": "fixed or dynamic padding in DataCollator"})
    lambda_reg: float = field(default=1e-4, metadata={"help": "The control strength of regularity"})
    adapter_path: Optional[str] = field(default=None)

    merge_adapter_path: Optional[str] = field(default=None)
    merge_output_path: Optional[str] = field(default=None)

@dataclass
class RotationConfig:
    r: int = field(default=4)
    num_rotations: int = field(default=4)
    task_type: str = "CAUSAL_LM"
    target_modules: List[str] = field(default_factory=lambda: ["q_proj",])

@dataclass
class DataConfig:
    dataset_name: str = 'math'
    split_ratio: float = field(default=0.01)
    path: str = "./nl_tasks/data/MetaMathQA-40K/MetaMathQA-40K.json"
    dataset_split: str = field(default="train[:1000]", metadata={"help": "(`['train', 'test', 'eval']`):"})
    adapter_names: List[Optional[str]] = field(default_factory=lambda: ["default"])  ###
    dataset_field: List[str] = field(default_factory=list, metadata={"help": "Fields of dataset input and output."})


@dataclass
class TrainingOverride:
    optim: str = field(default="adamw_torch")  ##
    eval_strategy: str = field(default='no')
    per_device_train_batch_size: int = field(default=8)  ##
    per_device_eval_batch_size: int = field(default=8)  ##

    learning_rate: float = field(default=1e-05)
    lr_scheduler_type: str = field(default='cosine')
    # warmup_ratio: float = field(default=0.1)
    warmup_steps: int = field(default=0)

    gradient_checkpointing: bool = field(default=False)
    gradient_accumulation_steps: int = field(default=1)
    output_dir: str = field(default="runs")
    save_steps: float = field(default=0)
    save_strategy: str = field(default='no')
    # save_total_limit: int = field(default=1)  # no need any more
    bf16: bool = field(default=False)
    bf16_full_eval: bool = field(default=False)
    save_safetensors: bool = field(default=False)

    report_to: Union[None, str, list[str]] = field(default="none")
    logging_steps: int = field(default=25)  # we use int only
    # logging_first_step: bool = field(default=False)
    eval_steps: Union[None, int] = field(default=None)  # we use int only

    dataloader_num_workers: int = field(default=1)
    dataloader_pin_memory: bool = field(default=True)  ###
    dataloader_persistent_workers: bool = field(default=True)  ###
    dataloader_prefetch_factor: int = field(default=1)  ###

    num_train_epochs: float = field(default=1.0)
    max_steps: int = field(default=-1)
    load_best_model_at_end: bool = field(default=True)

@dataclass
class GlueConfig:
    task_name: str = field(default='mnli')
    pad_to_max_length: bool = field(default=True)


@dataclass
class MainConfig:
    model: ModelConfig = field(default_factory=ModelConfig)
    rotation_adapter_config: RotationConfig = field(default_factory=RotationConfig)
    data: DataConfig = field(default_factory=DataConfig)
    trainer_args: TrainingOverride = field(default_factory=TrainingOverride)

    glue: GlueConfig = field(default_factory=GlueConfig)
    project_name: str = "llm_rotation"
    seed: int = 42
    run_text: str = field(default='def')
    # device: str = field(default='cpu')

@dataclass
class HFTrainingArguments(TrainingArguments):
    extension: Optional[Dict[str, Any]] = field(
        default=None,
        metadata={"help": "Serialized MainConfig excluding training args"}
    )

def convert_to_trainer_args(main_cfg: MainConfig) -> HFTrainingArguments:
    """
    Maps MainConfig to HFTrainingArguments.
    Logic:
    1. Extract 'trainer_args' fields -> pass to the TrainingArguments constructor.
    2. Pack 'model', 'data', etc. -> put into 'extension'.
    """
    KEY = "trainer_args"
    # 1. Convert the OmegaConf/dataclass config into a pure Python dict.
    # resolve=True ensures variables like ${model.name} are interpolated.
    full_dict = asdict(main_cfg)

    # 2. Extract the training arguments; these are unpacked as **kwargs to
    # initialize the parent TrainingArguments.
    train_args_dict = full_dict.pop(KEY)

    # 3. The rest (model, data, seed) goes into extension.
    extension_payload = full_dict

    # 4. Initialize HFTrainingArguments.
    # Note: the train_args_dict keys must match TrainingArguments fields.
    try:
        args = HFTrainingArguments(**train_args_dict)
    except TypeError as e:
        print(f"Error: Your 'training' config contains keys unknown to HF TrainingArguments: {e}")
        sys.exit(1)

    # 5. Attach the extension.
    args.extension = extension_payload

    return args
| 130 |
+
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
@dataclass
|
| 135 |
+
class Training:
|
| 136 |
+
model_name_or_path: Optional[str] = field(default="huggyllama/llama-7b")
|
| 137 |
+
adapter_name_or_path: Optional[str] = field(default=None)
|
| 138 |
+
data_path: str = field(default=None, metadata={"help": "Path to the training data."})
|
| 139 |
+
dataset_split: str = field(
|
| 140 |
+
default="train[:100000]", metadata={"help": "(`['train', 'test', 'eval']`):"}
|
| 141 |
+
)
|
| 142 |
+
dataset_field: List[str] = field(
|
| 143 |
+
default=None, metadata={"help": "Fields of dataset input and output."}
|
| 144 |
+
)
|
| 145 |
+
optim: str = field(default="adamw_torch")
|
| 146 |
+
model_max_length: int = field(default=512, metadata={
|
| 147 |
+
"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."}, )
|
| 148 |
+
hrft_r: int = field(default=8, metadata={
|
| 149 |
+
"help": "The rank of the adapter. When passing `None` and `adapter_name_or_path` is also `None`, full fine-tuning is used."})
|
| 150 |
+
init_a: float = field(default=1e-4, metadata={"help": "The initial weights"})
|
| 151 |
+
eps: float = field(default=1e-4, metadata={"help": "The control strength of COFT. The freedom of rotation."})
|
| 152 |
+
lamda: float = field(default=1e-4, metadata={"help": "The control strength of regularity"})
|
| 153 |
+
add_orth: str = field(default='none', metadata={"help": ""})
|
| 154 |
+
init_weights: Literal[True, "pissa"] = field(
|
| 155 |
+
default=True,
|
| 156 |
+
metadata={
|
| 157 |
+
"help": (
|
| 158 |
+
"Passing True (default) results in the LoRA initialization."
|
| 159 |
+
"Passing `pissa` results in PiSSA initialization."
|
| 160 |
+
),
|
| 161 |
+
},
|
| 162 |
+
)
|
| 163 |
+
extension: Optional[Dict[str, Any]] = field(
|
| 164 |
+
default=None,
|
| 165 |
+
metadata={"help": "Serialized MainConfig excluding training args"}
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
# target_modules: str = (
|
| 169 |
+
# "(.*x_embedder"
|
| 170 |
+
# "|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear"
|
| 171 |
+
# "|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k"
|
| 172 |
+
# "|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q"
|
| 173 |
+
# "|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v"
|
| 174 |
+
# "|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0"
|
| 175 |
+
# "|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2"
|
| 176 |
+
# "|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear"
|
| 177 |
+
# "|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp"
|
| 178 |
+
# "|.*single_transformer_blocks\\.[0-9]+\\.proj_out"
|
| 179 |
+
# "|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k"
|
| 180 |
+
# "|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q"
|
| 181 |
+
# "|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v"
|
| 182 |
+
# "|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)"
|
| 183 |
+
# )
|
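Editor's note: a minimal usage sketch for the config machinery above, under the assumption
(suggested by ft_mathQ.py below) that a MainConfig is parsed from the command line with
pyrallis and then split into HF TrainingArguments plus an `extension` payload. The file name
show_config.py and the `show` function are hypothetical; the YAML path is one of the configs
in this upload.

# show_config.py -- hypothetical driver, not part of the upload
import pyrallis
from config import MainConfig, convert_to_trainer_args  # assumes nl_tasks/src is on sys.path

@pyrallis.wrap()
def show(cfg: MainConfig):
    args = convert_to_trainer_args(cfg)
    print(args.learning_rate)             # pulled from cfg.trainer_args
    print(sorted(args.extension.keys()))  # everything outside trainer_args: data, glue, model, ...

if __name__ == "__main__":
    show()  # e.g. python show_config.py --config_path nl_tasks/config/math395.yaml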
nl_tasks/src/ft_mathQ.py
ADDED
@@ -0,0 +1,702 @@
#
import sys
#print('sys.path: ___ ', sys.path)
#print(f"Current Python Executable: {sys.executable}")

### dynamo warning
import warnings

# Ignore FutureWarning: prims_common.check, Online Softmax
warnings.filterwarnings("ignore", category=FutureWarning, module='torch._inductor.lowering')
warnings.filterwarnings("ignore", message=".*Online softmax is disabled on the fly.*", category=UserWarning)

warnings.filterwarnings("ignore", message=".*Our suggested max number of worker in current system is 1.*", category=UserWarning)
warnings.filterwarnings("ignore", message=".*will be initialized from a multivariate normal distribution.*")
warnings.filterwarnings("ignore", message=".*that differ from the model config and generation config.*", category=UserWarning)
warnings.filterwarnings("ignore", message=".*torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch..*", category=UserWarning)

import torch
torch.backends.cuda.matmul.fp32_precision = 'tf32'
# import wandb
import os
torch.set_num_threads(1)
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"PyTorch built with CUDA version: {torch.version.cuda}")

import yaml
#from peft import LoraConfig, get_peft_model_state_dict
from torch.utils.data import DataLoader
import time
from datetime import datetime
import math

from typing import List, Tuple

# import prodigyopt


###
import copy
from dataclasses import field, dataclass, asdict
from typing import Sequence, Literal, Dict, Optional, Union, Callable  # Optional/Union/Callable are used by MyTrainer below

import transformers
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
from transformers import Trainer
from transformers.modeling_utils import *
from transformers.trainer import _is_peft_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.data.data_collator import DataCollator

from transformers.training_args import TrainingArguments
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from torch.utils.data import Dataset, IterableDataset
from torch import nn  # used in MyTrainer's type hints
from datasets import load_dataset
##
#from ..pipeline.flux_omini import transformer_forward, encode_images
# from ...omini.rotation import RotationTuner, RotationConfig


from rpeft.rotation import RotationTuner, RotationConfig
from rpeft import get_peft_model, PeftModel
from .config import MainConfig, convert_to_trainer_args
import pyrallis
from omegaconf import OmegaConf
import torch.optim as optim
import wandb
from torch.nn.utils.rnn import pad_sequence

IGNORE_INDEX = -100
PROMPT = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:"
)

def get_rank():
    try:
        rank = int(os.environ.get("LOCAL_RANK"))
    except (TypeError, ValueError):
        rank = 0
    return rank


def get_config():
    config_path = os.environ.get("OMINI_CONFIG")
    assert config_path is not None, "Please set the OMINI_CONFIG environment variable"
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    return config


def init_wandb(wandb_config, run_name):
    import wandb

    try:
        assert os.environ.get("WANDB_API_KEY") is not None
        wandb.init(
            project=wandb_config["project"],
            name=run_name,
            config={},
        )
    except Exception as e:
        print("Failed to initialize WandB:", e)



def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dumps it to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )

def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)


# @dataclass
# class DataCollatorForSupervisedDataset():
#     """Collate examples for supervised fine-tuning."""

#     tokenizer: transformers.PreTrainedTokenizer
#     max_length: int = field(default=512)
#     mode: str = field(default="fixed")  # dynamic -> dynamo

#     def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
#         if self.mode == 'dynamic':
#             input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
#             input_ids = [torch.tensor(x) for x in input_ids]
#             input_ids = torch.nn.utils.rnn.pad_sequence(
#                 input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
#             )
#             labels = [torch.tensor(x) for x in labels]
#             labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
#             return dict(
#                 input_ids=input_ids,
#                 labels=labels,
#                 attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
#             )
#         elif self.mode == 'fixed':
#             input_ids = [torch.tensor(x["input_ids"][:self.max_length]) for x in instances]
#             input_ids = torch.stack([
#                 torch.nn.functional.pad(x, (0, self.max_length - x.size(0)), value=self.tokenizer.pad_token_id)
#                 for x in input_ids
#             ])

#             # Labels
#             labels = [torch.tensor(x["labels"][:self.max_length]) for x in instances]
#             labels = torch.stack([
#                 torch.nn.functional.pad(x, (0, self.max_length - x.size(0)), value=IGNORE_INDEX)
#                 for x in labels
#             ])

#             return dict(
#                 input_ids=input_ids,
#                 labels=labels,
#                 attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
#             )
#         else:
#             raise NotImplementedError

# @dataclass
# class DataCollatorForSupervisedDataset(object):
#     tokenizer: transformers.PreTrainedTokenizer
#     max_length: int = field(default=512)
#     mode: str = field(default="fixed")  # "dynamic" or "fixed"

#     def _pad_to_length(self, tensors: Sequence[torch.Tensor], pad_value: int, target_len: int):
#         """Pad a list of 1D tensors to target_len (int) and stack -> (B, target_len)."""
#         batch_size = len(tensors)
#         out = torch.full((batch_size, target_len), pad_value, dtype=tensors[0].dtype)
#         for i, t in enumerate(tensors):
#             L = min(t.size(0), target_len)
#             out[i, :L] = t[:L]
#         return out

#     def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
#         # Collect raw sequences (lists or tensors)
#         input_seqs = [torch.tensor(x["input_ids"], dtype=torch.long) for x in instances]
#         label_seqs = [torch.tensor(x["labels"], dtype=torch.long) for x in instances]

#         if self.mode == "dynamic":
#             # pad to the max length present in this batch (<= self.max_length)
#             batch_max_len = min(max([s.size(0) for s in input_seqs]), self.max_length)
#             input_ids = self._pad_to_length(input_seqs, pad_value=self.tokenizer.pad_token_id, target_len=batch_max_len)
#             labels = self._pad_to_length(label_seqs, pad_value=IGNORE_INDEX, target_len=batch_max_len)
#         elif self.mode == "fixed":
#             # always pad/truncate to self.max_length
#             input_ids = self._pad_to_length(input_seqs, pad_value=self.tokenizer.pad_token_id, target_len=self.max_length)
#             labels = self._pad_to_length(label_seqs, pad_value=IGNORE_INDEX, target_len=self.max_length)
#         else:
#             raise NotImplementedError(f"Unknown mode: {self.mode}")

#         attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()

#         return {
#             "input_ids": input_ids,
#             "labels": labels,
#             "attention_mask": attention_mask
#         }

@dataclass
class DataCollatorForSupervisedDataset():
    tokenizer: transformers.PreTrainedTokenizer
    max_length: int = field(default=512)
    mode: str = field(default="fixed")  # "dynamic" or "fixed"

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # Extract inputs and labels.
        # instances is a list of dicts like {'input_ids': [...], 'labels': [...]}
        input_ids_list = [torch.tensor(x["input_ids"], dtype=torch.long) for x in instances]
        labels_list = [torch.tensor(x["labels"], dtype=torch.long) for x in instances]

        # 1. Determine padding logic
        if self.mode == "dynamic":
            # Dynamic padding: pad to the longest sequence in the batch,
            # but cap it at self.max_length to prevent OOM
            batch_max_len = max([len(x) for x in input_ids_list])
            target_len = min(batch_max_len, self.max_length)
        else:
            # Fixed padding: always pad to max_length
            target_len = self.max_length

        # 2. Helper to pad and truncate
        def pad_and_truncate(tensors, padding_value):
            # First, pad everything using PyTorch's optimized utility (batch_first=True)
            padded = pad_sequence(tensors, batch_first=True, padding_value=padding_value)

            # Handle truncation/extending to exact target_len
            curr_len = padded.shape[1]
            if curr_len > target_len:
                # Truncate if too long (rare if filtered beforehand)
                return padded[:, :target_len]
            elif curr_len < target_len:
                # Pad more if shorter than target_len (happens in fixed mode)
                diff = target_len - curr_len
                padding = torch.full((padded.shape[0], diff), padding_value, dtype=padded.dtype)
                return torch.cat([padded, padding], dim=1)
            else:
                return padded

        # 3. Apply padding.
        # Critical: tokenizer.pad_token_id must NOT be None here
        if self.tokenizer.pad_token_id is None:
            raise ValueError("Tokenizer.pad_token_id is None. Please set it to eos_token_id or unk_token_id.")

        input_ids = pad_and_truncate(input_ids_list, self.tokenizer.pad_token_id)
        labels = pad_and_truncate(labels_list, IGNORE_INDEX)

        # 4. Create the attention mask explicitly.
        # .ne() creates bools, .long() casts to 0s and 1s for compatibility
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()

        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask
        }

def train_tokenize_function(examples, tokenizer, query, response):
    sources = [PROMPT.format_map(dict(instruction=instruction)) for instruction in examples[query]]
    targets = [f"{output}{tokenizer.eos_token}" for output in examples[response]]
    data_dict = preprocess(sources, targets, tokenizer)
    return data_dict



### Trainer
def default_worker_init_fn(worker_id):
    # each worker gets a single BLAS thread
    try:
        import numpy as _np
    except Exception:
        _np = None
    torch.set_num_threads(1)
    os.environ.setdefault("OMP_NUM_THREADS", "1")
    os.environ.setdefault("MKL_NUM_THREADS", "1")
    os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
    # Optional: bind CPU affinity per worker to avoid contention (NUMA-aware)
    try:
        cpu_count = os.cpu_count() or 1
        # divide the CPUs evenly across workers
        num_workers = getattr(torch.utils.data, "_num_workers", None)
        # fallback: if not available, compute from environment variable or pass externally
        # We'll do a simple round-robin assignment using worker_id:
        # assign a small mask of cores to this worker (e.g., chunk size 4)
        chunk = max(1, cpu_count // max(1, min(64, cpu_count)))
        start = (worker_id * chunk) % cpu_count
        end = start + chunk
        mask = set(range(start, min(end, cpu_count)))
        try:
            os.sched_setaffinity(0, mask)
        except Exception:
            pass
    except Exception:
        pass

def set_seed(seed: int):
    # random.seed(seed)
    # np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    transformers.set_seed(seed)


@pyrallis.wrap()
def main(mainCfg: MainConfig):
    #mainCfg = get_config()
    #print(mainCfg)
    print('='*120)
    # print(OmegaConf.to_yaml(mainCfg))
    # print('-'*40)
    #
    # print((training_args))
    set_seed(mainCfg.seed)
    training_args = convert_to_trainer_args(mainCfg)

    # wandb
    ENTITY = "nvan-13-korea-university"
    PROJECT = os.environ.get("WANDB_PROJECT")
    api = wandb.Api()
    try:
        runs_list = api.runs(f"{ENTITY}/{PROJECT}")
        next_run_num = len(runs_list) + 1
    except Exception as e:
        next_run_num = 1

    training_args.run_name = f'[{next_run_num}]lr={mainCfg.trainer_args.learning_rate:.1e},b={mainCfg.trainer_args.per_device_train_batch_size},'\
                             f'n={mainCfg.rotation_adapter_config.num_rotations},r={mainCfg.rotation_adapter_config.r},'\
                             f'init={mainCfg.run_text}'
    # training_args.project = f'Rotation-Llama2-{mainCfg.data.dataset_name}'

    # print('-'*40)
    # print(training_args.to_json_string())
    # exit()

    model = AutoModelForCausalLM.from_pretrained(mainCfg.model.model_name,
                                                 device_map="auto", low_cpu_mem_usage=True,
                                                 dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
                                                 attn_implementation="sdpa",
                                                 )
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    print("DEVICE", DEVICE)
    # for name, param in model.named_parameters():
    #     if 'q_proj' in name and 'layers.5' in name:
    #         print(f"Name: {name} | {param.shape} ")
    #         print(f"Name (pretrained): {name} | {param.shape} | {param.data[0:5,0:5]}")
    # print('model', model)
    # exit()

    total_params_now = sum(p.numel() for p in model.parameters())
    print(f'#params of the pretrained model: {total_params_now:,}')
    # print(model)
    if mainCfg.model.adapter_path is not None:
        print('___ Loading from: ', mainCfg.model.adapter_path)
        model = PeftModel.from_pretrained(model, mainCfg.model.adapter_path, is_trainable=True)
    elif mainCfg.rotation_adapter_config.r is not None:
        rotation_adapter_config = asdict(mainCfg.rotation_adapter_config)
        # rotation_adapter_config[peft_type]

        for adapter_name in mainCfg.data.adapter_names:
            rotation_config = RotationConfig(**rotation_adapter_config)
            model = get_peft_model(model, rotation_config, adapter_name=adapter_name)
            # model.set_adapter(adapter_name)

        # import peft
        # from peft import OFTConfig
        # oft_config = OFTConfig(
        #     # r=16,
        #     oft_block_size=4*mainCfg.rotation_adapter_config.r,
        #     use_cayley_neumann=True,
        #     target_modules=["q_proj", "v_proj",],
        #     module_dropout=0.05,  # mainCfg.rotation_adapter_config.drop_out,
        #     # task_type="CAUSAL_LM",
        #     bias="none",
        # )

        # for adapter_name in mainCfg.data.adapter_names:
        #     model = peft.get_peft_model(model, oft_config, adapter_name=adapter_name)
    else:
        print("Full Parameter Fine-Tuning")
    model = model.to(DEVICE)

    # print('model', model)
    model.print_trainable_parameters()
    # exit()  # debug leftover; leaving this active would abort before training starts
    # print("Program starts")
    # time.sleep(300)
    # exit()

    # for name, param in model.named_parameters():
    #     if 'q_proj' in name and 'rotation' in name and 'layers.5' in name:
    #         print(f"Name: {name} | {param.shape} ")
    #         print(f"Name (pretrained): {name} | {param.shape} ")
    #         X = param.data
    #         print('model', type(model), X.shape)
    #         visualize_value_distribution(X)
    # exit()

    rotation_layers = filter(
        lambda p: p.requires_grad, model.parameters()
    )

    tokenizer = AutoTokenizer.from_pretrained(
        mainCfg.model.model_name,
        model_max_length=mainCfg.model.model_max_seq_length,
        padding_side="right",
        use_fast=True,
    )

    if tokenizer.pad_token is None:
        if tokenizer.unk_token_id is not None:
            tokenizer.pad_token_id = tokenizer.unk_token_id
            tokenizer.pad_token = tokenizer.unk_token
            print("Set PAD token to UNK token.")
        elif tokenizer.eos_token_id is not None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
            tokenizer.pad_token = tokenizer.eos_token
            print("Set PAD token to EOS token.")

    if model is not None:
        model.config.pad_token_id = tokenizer.pad_token_id
        if model.config.pad_token_id != tokenizer.pad_token_id:
            raise ValueError("Failed to sync pad_token_id between tokenizer and model config")

    # local MetaMathQA-40K
    raw_datasets = load_dataset("json", data_files=mainCfg.data.path, split=mainCfg.data.dataset_split)
    #raw_train_datasets = load_dataset("MetaMathQA-40K", split=mainCfg.data.dataset_split)
    # print('raw', type(raw_train_datasets), len(raw_train_datasets))

    # split a single set
    # split_ratio = mainCfg.data.split_ratio
    # split_data = raw_datasets.train_test_split(test_size=split_ratio, seed=42)
    # raw_train_datasets = split_data['train']
    # raw_valid_datasets = split_data['test']

    train_dataset = raw_datasets.map(
        train_tokenize_function,
        batched=True,
        batch_size=30000,
        num_proc=32,
        remove_columns=raw_datasets.column_names,
        load_from_cache_file=True,
        desc="Running tokenizer on train dataset",
        fn_kwargs={"tokenizer": tokenizer, "query": mainCfg.data.dataset_field[0],
                   "response": mainCfg.data.dataset_field[1]}
    )

    # valid_dataset = raw_valid_datasets.map(
    #     train_tokenize_function,
    #     batched=True,
    #     batch_size=30000,
    #     num_proc=32,
    #     remove_columns=raw_train_datasets.column_names,
    #     load_from_cache_file=True,
    #     desc="Running tokenizer on train dataset",
    #     fn_kwargs={"tokenizer": tokenizer, "query": mainCfg.data.dataset_field[0],
    #                "response": mainCfg.data.dataset_field[1]}
    # )
    print('- dataset size: ', len(train_dataset))


    # print('dataset', type(train_dataset))
    # print('process', len(train_dataset))
    # print(f"Sample features: {train_dataset.column_names}, {train_dataset.num_rows}")
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=mainCfg.model.model_max_seq_length,
                                                     #mode=mainCfg.model.data_collator_mode,
                                                     )
    data_module = dict(train_dataset=train_dataset, data_collator=data_collator)

    optimizer = optim.AdamW(
        rotation_layers,
        lr=mainCfg.trainer_args.learning_rate,
        eps=1e-8
    )
    # print('model x', model)
    start_time = datetime.now()
    print('start time: ', start_time.strftime("%Y-%m-%d %H:%M:%S"))
    trainer = MyTrainer(model=model, processing_class=tokenizer,
                        lamda=mainCfg.model.lambda_reg,
                        optimizers=(optimizer, None),
                        args=training_args, **data_module)
    model.config.use_cache = False


    # now = time.time()
    # for i in range(20):
    #     next(iter(trainer.get_train_dataloader()))
    # print('time', time.time()-now)
    # now = time.time()


    # dl = trainer.get_train_dataloader()
    # t0 = time.time()
    # for i, batch in enumerate(dl):
    #     if i==20: break
    # print("time / 20 batches =", time.time() - t0)
    # exit()


    # model2 = model.merge_and_unload()
    # results2 = trainer2.evaluate()
    # print('results2: ', results2)
    # exit()

    trainer.train()

    end_time = datetime.now()
    print('end time: ', end_time.strftime("%Y-%m-%d %H:%M:%S"), '| duration: ', end_time - start_time)
    # Save Model (includes adapter weights & config)
    # trainer.save_model(os.path.join(training_args.output_dir, 'ft'))
    # Save Tokenizer
    tokenizer.save_pretrained(os.path.join(training_args.output_dir, 'ft'))
    # Save Training State (metrics & logs)
    trainer.save_state()

    # save the peft_config of the active adapter (peft_config is a dict keyed by adapter name;
    # model.base_model.peft_config['default'] is the same object)
    model.peft_config['default'].save_pretrained(os.path.join(training_args.output_dir, 'ft'))

    # the easiest way
    model.save_pretrained(os.path.join(training_args.output_dir, 'ft2'))
    return




class MyTrainer(Trainer):

    def __init__(
            self,
            model: Union[PreTrainedModel, nn.Module] = None,
            args: TrainingArguments = None,
            data_collator: Optional[DataCollator] = None,
            train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
            eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None,
            processing_class: Optional[PreTrainedTokenizerBase] = None,
            model_init: Optional[Callable[[], PreTrainedModel]] = None,
            compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
            callbacks: Optional[List[TrainerCallback]] = None,
            optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
            preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
            #run_name: Optional[str] = None,
            #report_to: Optional[Union[str, list[str]]] = None,
            # project
            lamda: float = 1e-4
    ):
        super().__init__(model=model, args=args, data_collator=data_collator,
                         train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class,
                         model_init=model_init, compute_metrics=compute_metrics, callbacks=callbacks,
                         optimizers=optimizers, preprocess_logits_for_metrics=preprocess_logits_for_metrics,
                         #run_name=run_name, report_to=report_to
                         )
        self.lamda = lamda

    # def compute_loss(self, model, inputs, return_outputs=False,
    #                  num_items_in_batch: Optional[torch.Tensor] = None,):
    #     """
    #     How the loss is computed by Trainer. By default, all models return the loss in the first element.

    #     Subclass and override for custom behavior.
    #     """
    #     if self.label_smoother is not None and "labels" in inputs:
    #         labels = inputs.pop("labels")
    #     else:
    #         labels = None
    #     if self.model_accepts_loss_kwargs:
    #         kwargs = {}
    #         if num_items_in_batch is not None:
    #             kwargs["num_items_in_batch"] = num_items_in_batch
    #         inputs = {**inputs, **kwargs}
    #     outputs = model(**inputs)
    #     # Save past state if it exists
    #     # TODO: this needs to be fixed and made cleaner later.
    #     if self.args.past_index >= 0:
    #         self._past = outputs[self.args.past_index]

    #     if labels is not None:
    #         unwrapped_model = unwrap_model(model)
    #         if _is_peft_model(unwrapped_model):
    #             model_name = unwrapped_model.base_model.model._get_name()
    #         else:
    #             model_name = unwrapped_model._get_name()
    #         if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
    #             loss = self.label_smoother(outputs, labels, shift_labels=True)
    #         else:
    #             loss = self.label_smoother(outputs, labels)
    #     else:
    #         if isinstance(outputs, dict) and "loss" not in outputs:
    #             raise ValueError(
    #                 "The model did not return a loss from the inputs, only the following keys: "
    #                 f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
    #             )
    #         # We don't use .loss here since the model may return tuples instead of ModelOutput.
    #         loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
    #     # ------------------------------------------------------------------------------

    #     # for name, param in model.named_parameters():
    #     #     if 'oft_r' in name:
    #     #         device = param.device
    #     #         householder_U_norm = param / param.norm(dim=0)
    #     #         orth_loss = torch.norm(
    #     #             torch.eye(householder_U_norm.size(1), device=device) - householder_U_norm.t() @ householder_U_norm)
    #     #         print(self.lamda)
    #     #         loss = loss + self.lamda * orth_loss.to(loss.device)

    #     # ------------------------------------------------------------------------------

    #     return (loss, outputs) if return_outputs else loss

    def get_train_dataloader(self):
        # get dataset & sampler from super
        train_dataset = self.train_dataset
        sampler = self._get_train_sampler()

        # compute effective batch size per step (HF has some routines; we use per_device_train_batch_size)
        batch_size = self.args.train_batch_size if hasattr(self.args, "train_batch_size") else self.args.per_device_train_batch_size

        # recommended num_workers: start moderate (16), you can tune upward
        num_workers = getattr(self.args, "dataloader_num_workers", 16)
        pin_memory = getattr(self.args, "dataloader_pin_memory", True)
        prefetch_factor = getattr(self.args, "dataloader_prefetch_factor", 2)
        persistent_workers = getattr(self.args, "dataloader_persistent_workers", True)

        return DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=sampler,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last if hasattr(self.args, "dataloader_drop_last") else False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
            prefetch_factor=prefetch_factor,
            worker_init_fn=default_worker_init_fn,
        )

if __name__ == "__main__":
    main()
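Editor's note: a minimal sketch of what DataCollatorForSupervisedDataset above produces for
two variable-length examples. SimpleNamespace stands in for a real transformers tokenizer
(the collator only reads pad_token_id from it), and the class plus IGNORE_INDEX are assumed
to be in scope, e.g. copied from ft_mathQ.py.

from types import SimpleNamespace

tok = SimpleNamespace(pad_token_id=0)
collator = DataCollatorForSupervisedDataset(tokenizer=tok, max_length=8, mode="dynamic")
batch = collator([
    {"input_ids": [5, 6, 7],       "labels": [-100, -100, 7]},
    {"input_ids": [5, 6, 7, 8, 9], "labels": [-100, 6, 7, 8, 9]},
])
print(batch["input_ids"])          # tensor([[5, 6, 7, 0, 0], [5, 6, 7, 8, 9]]) -- padded to the batch max
print(batch["attention_mask"][0])  # tensor([1, 1, 1, 0, 0])
print(batch["labels"][0])          # tensor([-100, -100, 7, -100, -100]) -- pads masked with IGNORE_INDEX

With mode="fixed" (the default) the same batch would instead be padded out to max_length=8.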
nl_tasks/src/ft_mathR.py
ADDED
@@ -0,0 +1,689 @@
| 1 |
+
#
|
| 2 |
+
import sys
|
| 3 |
+
#print('sys.path: ___ ', sys.path)
|
| 4 |
+
#print(f"Current Python Executable: {sys.executable}")
|
| 5 |
+
|
| 6 |
+
### dynamo warning
|
| 7 |
+
import warnings
|
| 8 |
+
|
| 9 |
+
# Ignore FutureWarning: prims_common.check, Online Softmax
|
| 10 |
+
warnings.filterwarnings("ignore", category=FutureWarning, module='torch._inductor.lowering')
|
| 11 |
+
warnings.filterwarnings("ignore", message=".*Online softmax is disabled on the fly.*", category=UserWarning)
|
| 12 |
+
|
| 13 |
+
warnings.filterwarnings("ignore", message=".*Our suggested max number of worker in current system is 1.*", category=UserWarning)
|
| 14 |
+
warnings.filterwarnings("ignore", message=".*will be initialized from a multivariate normal distribution.*")
|
| 15 |
+
warnings.filterwarnings("ignore", message=".*that differ from the model config and generation config.*", category=UserWarning)
|
| 16 |
+
warnings.filterwarnings("ignore", message=".*torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch..*", category=UserWarning)
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
torch.backends.cuda.matmul.fp32_precision = 'tf32'
|
| 20 |
+
# import wandb
|
| 21 |
+
import os
|
| 22 |
+
torch.set_num_threads(1)
|
| 23 |
+
os.environ["OMP_NUM_THREADS"]="1"
|
| 24 |
+
os.environ["MKL_NUM_THREADS"]="1"
|
| 25 |
+
import torch
|
| 26 |
+
print(f"PyTorch version: {torch.__version__}")
|
| 27 |
+
print(f"CUDA available: {torch.cuda.is_available()}")
|
| 28 |
+
print(f"PyTorch built with CUDA version: {torch.version.cuda}")
|
| 29 |
+
|
| 30 |
+
import yaml
|
| 31 |
+
#from peft import LoraConfig, get_peft_model_state_dict
|
| 32 |
+
from torch.utils.data import DataLoader
|
| 33 |
+
import time
|
| 34 |
+
from datetime import datetime
|
| 35 |
+
import math
|
| 36 |
+
|
| 37 |
+
from typing import List, Tuple
|
| 38 |
+
|
| 39 |
+
# import prodigyopt
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
###
|
| 43 |
+
import copy
|
| 44 |
+
from dataclasses import field, dataclass, asdict
|
| 45 |
+
from typing import Sequence, Literal, Dict
|
| 46 |
+
|
| 47 |
+
import transformers
|
| 48 |
+
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
|
| 49 |
+
from transformers import Trainer
|
| 50 |
+
from transformers.modeling_utils import *
|
| 51 |
+
from transformers.trainer import _is_peft_model
|
| 52 |
+
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
| 53 |
+
from transformers.data.data_collator import DataCollator
|
| 54 |
+
|
| 55 |
+
from transformers.training_args import TrainingArguments
|
| 56 |
+
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
| 57 |
+
from transformers.trainer_callback import TrainerCallback
|
| 58 |
+
from transformers.trainer_utils import EvalPrediction
|
| 59 |
+
from torch.utils.data import Dataset, IterableDataset
|
| 60 |
+
from datasets import load_dataset
|
| 61 |
+
##
|
| 62 |
+
#from ..pipeline.flux_omini import transformer_forward, encode_images
|
| 63 |
+
# from ...omini.rotation import RotationTuner, RotationConfig
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
from rpeft.rotation import RotationTuner, RotationConfig
|
| 67 |
+
from rpeft import get_peft_model, PeftModel
|
| 68 |
+
from .config import MainConfig, convert_to_trainer_args
|
| 69 |
+
import pyrallis
|
| 70 |
+
from omegaconf import OmegaConf
|
| 71 |
+
import torch.optim as optim
|
| 72 |
+
import wandb
|
| 73 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 74 |
+
|
| 75 |
+
IGNORE_INDEX = -100
|
| 76 |
+
PROMPT = (
|
| 77 |
+
"Below is an instruction that describes a task. "
|
| 78 |
+
"Write a response that appropriately completes the request.\n\n"
|
| 79 |
+
"### Instruction:\n{instruction}\n\n### Response:"
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
def get_rank():
|
| 83 |
+
try:
|
| 84 |
+
rank = int(os.environ.get("LOCAL_RANK"))
|
| 85 |
+
except:
|
| 86 |
+
rank = 0
|
| 87 |
+
return rank
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def get_config():
|
| 91 |
+
config_path = os.environ.get("OMINI_CONFIG")
|
| 92 |
+
assert config_path is not None, "Please set the OMINI_CONFIG environment variable"
|
| 93 |
+
with open(config_path, "r") as f:
|
| 94 |
+
config = yaml.safe_load(f)
|
| 95 |
+
return config
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def init_wandb(wandb_config, run_name):
|
| 99 |
+
import wandb
|
| 100 |
+
|
| 101 |
+
try:
|
| 102 |
+
assert os.environ.get("WANDB_API_KEY") is not None
|
| 103 |
+
wandb.init(
|
| 104 |
+
project=wandb_config["project"],
|
| 105 |
+
name=run_name,
|
| 106 |
+
config={},
|
| 107 |
+
)
|
| 108 |
+
except Exception as e:
|
| 109 |
+
print("Failed to initialize WanDB:", e)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
|
| 114 |
+
"""Collects the state dict and dump to disk."""
|
| 115 |
+
state_dict = trainer.model.state_dict()
|
| 116 |
+
if trainer.args.should_save:
|
| 117 |
+
cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
|
| 118 |
+
del state_dict
|
| 119 |
+
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def smart_tokenizer_and_embedding_resize(
|
| 123 |
+
special_tokens_dict: Dict,
|
| 124 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 125 |
+
model: transformers.PreTrainedModel,
|
| 126 |
+
):
|
| 127 |
+
"""Resize tokenizer and embedding.
|
| 128 |
+
|
| 129 |
+
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
|
| 130 |
+
"""
|
| 131 |
+
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
|
| 132 |
+
model.resize_token_embeddings(len(tokenizer))
|
| 133 |
+
|
| 134 |
+
if num_new_tokens > 0:
|
| 135 |
+
input_embeddings = model.get_input_embeddings().weight.data
|
| 136 |
+
output_embeddings = model.get_output_embeddings().weight.data
|
| 137 |
+
|
| 138 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
| 139 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
| 140 |
+
|
| 141 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
| 142 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
|
| 146 |
+
"""Tokenize a list of strings."""
|
| 147 |
+
tokenized_list = [
|
| 148 |
+
tokenizer(
|
| 149 |
+
text,
|
| 150 |
+
return_tensors="pt",
|
| 151 |
+
padding="longest",
|
| 152 |
+
max_length=tokenizer.model_max_length,
|
| 153 |
+
truncation=True,
|
| 154 |
+
)
|
| 155 |
+
for text in strings
|
| 156 |
+
]
|
| 157 |
+
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
|
| 158 |
+
input_ids_lens = labels_lens = [
|
| 159 |
+
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
|
| 160 |
+
]
|
| 161 |
+
return dict(
|
| 162 |
+
input_ids=input_ids,
|
| 163 |
+
labels=labels,
|
| 164 |
+
input_ids_lens=input_ids_lens,
|
| 165 |
+
labels_lens=labels_lens,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
def preprocess(
|
| 169 |
+
sources: Sequence[str],
|
| 170 |
+
targets: Sequence[str],
|
| 171 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 172 |
+
) -> Dict:
|
| 173 |
+
"""Preprocess the data by tokenizing."""
|
| 174 |
+
examples = [s + t for s, t in zip(sources, targets)]
|
| 175 |
+
examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
|
| 176 |
+
input_ids = examples_tokenized["input_ids"]
|
| 177 |
+
labels = copy.deepcopy(input_ids)
|
| 178 |
+
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
|
| 179 |
+
label[:source_len] = IGNORE_INDEX
|
| 180 |
+
return dict(input_ids=input_ids, labels=labels)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
# @dataclass
|
| 184 |
+
# class DataCollatorForSupervisedDataset():
|
| 185 |
+
# """Collate examples for supervised fine-tuning."""
|
| 186 |
+
|
| 187 |
+
# tokenizer: transformers.PreTrainedTokenizer
|
| 188 |
+
# max_length: int = field(default=512)
|
| 189 |
+
# mode: str = field(default="fixed") # dynamic -> dynamo
|
| 190 |
+
|
| 191 |
+
# def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
| 192 |
+
# if self.mode == 'dynamic':
|
| 193 |
+
# input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
|
| 194 |
+
# input_ids = [torch.tensor(x) for x in input_ids]
|
| 195 |
+
# input_ids = torch.nn.utils.rnn.pad_sequence(
|
| 196 |
+
# input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
|
| 197 |
+
# )
|
| 198 |
+
# labels = [torch.tensor(x) for x in labels]
|
| 199 |
+
# labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
|
| 200 |
+
# return dict(
|
| 201 |
+
# input_ids=input_ids,
|
| 202 |
+
# labels=labels,
|
| 203 |
+
# attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
|
| 204 |
+
# )
|
| 205 |
+
# elif self.mode == 'fixed':
|
| 206 |
+
# input_ids = [torch.tensor(x["input_ids"][:self.max_length]) for x in instances]
|
| 207 |
+
# input_ids = torch.stack([
|
| 208 |
+
# torch.nn.functional.pad(x, (0, self.max_length - x.size(0)), value=self.tokenizer.pad_token_id)
|
| 209 |
+
# for x in input_ids
|
| 210 |
+
# ])
|
| 211 |
+
|
| 212 |
+
# # Labels
|
| 213 |
+
# labels = [torch.tensor(x["labels"][:self.max_length]) for x in instances]
|
| 214 |
+
# labels = torch.stack([
|
| 215 |
+
# torch.nn.functional.pad(x, (0, self.max_length - x.size(0)), value=IGNORE_INDEX)
|
| 216 |
+
# for x in labels
|
| 217 |
+
# ])
|
| 218 |
+
|
| 219 |
+
# return dict(
|
| 220 |
+
# input_ids=input_ids,
|
| 221 |
+
# labels=labels,
|
| 222 |
+
# attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
|
| 223 |
+
# )
|
| 224 |
+
# else:
|
| 225 |
+
# raise NotImplementedError
|
| 226 |
+
|
| 227 |
+
# @dataclass
|
| 228 |
+
# class DataCollatorForSupervisedDataset(object):
|
| 229 |
+
# tokenizer: transformers.PreTrainedTokenizer
|
| 230 |
+
# max_length: int = field(default=512)
|
| 231 |
+
# mode: str = field(default="fixed") # "dynamic" or "fixed"
|
| 232 |
+
|
| 233 |
+
# def _pad_to_length(self, tensors: Sequence[torch.Tensor], pad_value: int, target_len: int):
|
| 234 |
+
# """Pad a list of 1D tensors to target_len (int) and stack -> (B, target_len)."""
|
| 235 |
+
# batch_size = len(tensors)
|
| 236 |
+
# out = torch.full((batch_size, target_len), pad_value, dtype=tensors[0].dtype)
|
| 237 |
+
# for i, t in enumerate(tensors):
|
| 238 |
+
# L = min(t.size(0), target_len)
|
| 239 |
+
# out[i, :L] = t[:L]
|
| 240 |
+
# return out
|
| 241 |
+
|
| 242 |
+
# def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
| 243 |
+
# # Collect raw sequences (lists or tensors)
|
| 244 |
+
# input_seqs = [torch.tensor(x["input_ids"], dtype=torch.long) for x in instances]
|
| 245 |
+
# label_seqs = [torch.tensor(x["labels"], dtype=torch.long) for x in instances]
|
| 246 |
+
|
| 247 |
+
# if self.mode == "dynamic":
|
| 248 |
+
# # pad to the max length present in this batch (<= self.max_length)
|
| 249 |
+
# batch_max_len = min(max([s.size(0) for s in input_seqs]), self.max_length)
|
| 250 |
+
# input_ids = self._pad_to_length(input_seqs, pad_value=self.tokenizer.pad_token_id, target_len=batch_max_len)
|
| 251 |
+
# labels = self._pad_to_length(label_seqs, pad_value=IGNORE_INDEX, target_len=batch_max_len)
|
| 252 |
+
# elif self.mode == "fixed":
|
| 253 |
+
# # always pad/truncate to self.max_length
|
| 254 |
+
# input_ids = self._pad_to_length(input_seqs, pad_value=self.tokenizer.pad_token_id, target_len=self.max_length)
|
| 255 |
+
# labels = self._pad_to_length(label_seqs, pad_value=IGNORE_INDEX, target_len=self.max_length)
|
| 256 |
+
# else:
|
| 257 |
+
# raise NotImplementedError(f"Unknown mode: {self.mode}")
|
| 258 |
+
|
| 259 |
+
# attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
|
| 260 |
+
|
| 261 |
+
# return {
|
| 262 |
+
# "input_ids": input_ids,
|
| 263 |
+
# "labels": labels,
|
| 264 |
+
# "attention_mask": attention_mask
|
| 265 |
+
# }
|
| 266 |
+
|
| 267 |
+
@dataclass
|
| 268 |
+
class DataCollatorForSupervisedDataset():
|
| 269 |
+
tokenizer: transformers.PreTrainedTokenizer
|
| 270 |
+
max_length: int = field(default=512)
|
| 271 |
+
mode: str = field(default="fixed") # "dynamic" or "fixed"
|
| 272 |
+
|
| 273 |
+
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
| 274 |
+
# Extract inputs and labels
|
| 275 |
+
# Assuming instances is a list of dicts like {'input_ids': [...], 'labels': [...]}
|
| 276 |
+
input_ids_list = [torch.tensor(x["input_ids"], dtype=torch.long) for x in instances]
|
| 277 |
+
labels_list = [torch.tensor(x["labels"], dtype=torch.long) for x in instances]
|
| 278 |
+
|
| 279 |
+
# 1. Determine padding logic
|
| 280 |
+
if self.mode == "dynamic":
|
| 281 |
+
# Dynamic padding: pad to the longest sequence in the batch
|
| 282 |
+
# But cap it at self.max_length to prevent OOM
|
| 283 |
+
batch_max_len = max([len(x) for x in input_ids_list])
|
| 284 |
+
target_len = min(batch_max_len, self.max_length)
|
| 285 |
+
else:
|
| 286 |
+
# Fixed padding: always pad to max_length
|
| 287 |
+
target_len = self.max_length
|
| 288 |
+
|
| 289 |
+
# 2. Helper to pad and truncate
|
| 290 |
+
def pad_and_truncate(tensors, padding_value):
|
| 291 |
+
# First, pad everything using PyTorch's optimized utility (batch_first=True)
|
| 292 |
+
padded = pad_sequence(tensors, batch_first=True, padding_value=padding_value)
|
| 293 |
+
|
| 294 |
+
# Handle truncation/extending to exact target_len
|
| 295 |
+
curr_len = padded.shape[1]
|
| 296 |
+
if curr_len > target_len:
|
| 297 |
+
# Truncate if too long (rare if filtered beforehand)
|
| 298 |
+
return padded[:, :target_len]
|
| 299 |
+
elif curr_len < target_len:
|
| 300 |
+
# Pad more if shorter than target_len (happens in fixed mode)
|
| 301 |
+
diff = target_len - curr_len
|
| 302 |
+
padding = torch.full((padded.shape[0], diff), padding_value, dtype=padded.dtype)
|
| 303 |
+
return torch.cat([padded, padding], dim=1)
|
| 304 |
+
else:
|
| 305 |
+
return padded
|
| 306 |
+
|
| 307 |
+
# 3. Apply padding
|
| 308 |
+
# Critical: tokenizer.pad_token_id must NOT be None here
|
| 309 |
+
if self.tokenizer.pad_token_id is None:
|
| 310 |
+
raise ValueError("Tokenizer.pad_token_id is None. Please set it to eos_token_id or unk_token_id.")
|
| 311 |
+
|
| 312 |
+
input_ids = pad_and_truncate(input_ids_list, self.tokenizer.pad_token_id)
|
| 313 |
+
labels = pad_and_truncate(labels_list, IGNORE_INDEX)
|
| 314 |
+
|
| 315 |
+
# 4. Create Attention Mask explicitly
|
| 316 |
+
# .ne() creates Bools, .long() casts to 0s and 1s for compatibility
|
| 317 |
+
attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
|
| 318 |
+
|
| 319 |
+
return {
|
| 320 |
+
"input_ids": input_ids,
|
| 321 |
+
"labels": labels,
|
| 322 |
+
"attention_mask": attention_mask
|
| 323 |
+
}
|
| 324 |
+
|
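# A minimal sketch (hypothetical values) of what the collator above returns for a toy
# batch; `_ToyTok` stands in for any tokenizer object exposing pad_token_id.
from dataclasses import dataclass as _dc

@_dc
class _ToyTok:
    pad_token_id: int = 0

_coll = DataCollatorForSupervisedDataset(tokenizer=_ToyTok(), max_length=8, mode="dynamic")
_batch = _coll([
    {"input_ids": [5, 6, 7], "labels": [IGNORE_INDEX, 6, 7]},
    {"input_ids": [5], "labels": [IGNORE_INDEX]},
])
# dynamic mode pads to the longest sequence in the batch (3 here), capped at max_length
assert _batch["input_ids"].shape == (2, 3)
assert _batch["attention_mask"].tolist() == [[1, 1, 1], [1, 0, 0]]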
| 325 |
+
def train_tokenize_function(examples, tokenizer, query, response):
|
| 326 |
+
sources = [PROMPT.format_map(dict(instruction=instruction)) for instruction in examples[query]]
|
| 327 |
+
targets = [f"{output}{tokenizer.eos_token}" for output in examples[response]]
|
| 328 |
+
data_dict = preprocess(sources, targets, tokenizer)
|
| 329 |
+
return data_dict
|
| 330 |
+
|
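# A minimal sketch (illustrative instruction text) of the prompt wrapping done above:
_src = PROMPT.format_map(dict(instruction="Add 2 and 3."))
# _src == "Below is an instruction that describes a task. Write a response that
# appropriately completes the request.\n\n### Instruction:\nAdd 2 and 3.\n\n### Response:"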
| 331 |
+
|
| 332 |
+
|
| 333 |
+
### Trainer
|
| 334 |
+
def default_worker_init_fn(worker_id):
|
| 335 |
+
# cap each worker at a single BLAS thread
|
| 336 |
+
try:
|
| 337 |
+
import numpy as _np
|
| 338 |
+
except Exception:
|
| 339 |
+
_np = None
|
| 340 |
+
torch.set_num_threads(1)
|
| 341 |
+
os.environ.setdefault("OMP_NUM_THREADS", "1")
|
| 342 |
+
os.environ.setdefault("MKL_NUM_THREADS", "1")
|
| 343 |
+
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
|
| 344 |
+
# Optional: bind CPU affinity per worker to avoid contention (NUMA-aware)
|
| 345 |
+
try:
|
| 346 |
+
cpu_count = os.cpu_count() or 1
|
| 347 |
+
# divide the CPUs evenly among the workers
|
| 348 |
+
num_workers = getattr(torch.utils.data, "_num_workers", None)
|
| 349 |
+
# fallback: if not available, compute from environment variable or pass externally
|
| 350 |
+
# We'll do a simple round-robin assignment using worker_id
|
| 351 |
+
# assign a small contiguous block of cores to this worker (chunk works out to at least 1)
|
| 352 |
+
chunk = max(1, cpu_count // max(1, min(64, cpu_count)))
|
| 353 |
+
start = (worker_id * chunk) % cpu_count
|
| 354 |
+
end = start + chunk
|
| 355 |
+
mask = set(range(start, min(end, cpu_count)))
|
| 356 |
+
try:
|
| 357 |
+
os.sched_setaffinity(0, mask)
|
| 358 |
+
except Exception:
|
| 359 |
+
pass
|
| 360 |
+
except Exception:
|
| 361 |
+
pass
|
| 362 |
+
|
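# A minimal usage sketch (toy dataset, hypothetical worker count): the init fn above is
# meant to be handed to a DataLoader so every spawned worker caps its BLAS threads first.
import torch as _t
from torch.utils.data import DataLoader as _DL, TensorDataset as _TD
_loader = _DL(_TD(_t.arange(16)), batch_size=4, num_workers=2,
              worker_init_fn=default_worker_init_fn)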
| 363 |
+
def set_seed(seed: int):
|
| 364 |
+
# random.seed(seed)
|
| 365 |
+
# np.random.seed(seed)
|
| 366 |
+
torch.manual_seed(seed)
|
| 367 |
+
torch.cuda.manual_seed_all(seed)
|
| 368 |
+
transformers.set_seed(seed)
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
@pyrallis.wrap()
|
| 372 |
+
def main(mainCfg: MainConfig):
|
| 373 |
+
#mainCfg = get_config()
|
| 374 |
+
#print(mainCfg)
|
| 375 |
+
print('='*120)
|
| 376 |
+
# print(OmegaConf.to_yaml(mainCfg))
|
| 377 |
+
# print('-'*40)
|
| 378 |
+
#
|
| 379 |
+
# print((training_args))
|
| 380 |
+
set_seed(mainCfg.seed)
|
| 381 |
+
training_args = convert_to_trainer_args(mainCfg)
|
| 382 |
+
|
| 383 |
+
# wandb
|
| 384 |
+
ENTITY = "nvan-13-korea-university"
|
| 385 |
+
PROJECT = os.environ.get("WANDB_PROJECT")
|
| 386 |
+
api = wandb.Api()
|
| 387 |
+
try:
|
| 388 |
+
runs_list = api.runs(f"{ENTITY}/{PROJECT}")
|
| 389 |
+
next_run_num = len(runs_list) + 1
|
| 390 |
+
except Exception as e:
|
| 391 |
+
next_run_num = 1
|
| 392 |
+
|
| 393 |
+
training_args.run_name = f'[{next_run_num}]lr={mainCfg.trainer_args.learning_rate:.1e},b={mainCfg.trainer_args.per_device_train_batch_size},'\
|
| 394 |
+
f'n={mainCfg.rotation_adapter_config.num_rotations},r={mainCfg.rotation_adapter_config.r},'\
|
| 395 |
+
f'init={mainCfg.run_text}'
|
| 396 |
+
# training_args.project = f'Rotation-Llama2-{mainCfg.data.dataset_name}'
|
| 397 |
+
|
| 398 |
+
# print('-'*40)
|
| 399 |
+
# print(training_args.to_json_string())
|
| 400 |
+
# exit()
|
| 401 |
+
|
| 402 |
+
model = AutoModelForCausalLM.from_pretrained(mainCfg.model.model_name,
|
| 403 |
+
device_map="auto", low_cpu_mem_usage=True,
|
| 404 |
+
dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
|
| 405 |
+
attn_implementation="sdpa",
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 409 |
+
print("DEVICE", DEVICE)
|
| 410 |
+
# for name, param in model.named_parameters():
|
| 411 |
+
# if 'q_proj' in name and 'layers.5' in name:
|
| 412 |
+
# print(f"Name: {name} | {param.shape} ")
|
| 413 |
+
# print(f"Name (pretrained): {name} | {param.shape} | {param.data[0:5,0:5]}")
|
| 414 |
+
# print('model', model)
|
| 415 |
+
# exit()
|
| 416 |
+
|
| 417 |
+
total_params_now = sum(p.numel() for p in model.parameters())
|
| 418 |
+
print(f'#params of the pretrained model: {total_params_now:,}')
|
| 419 |
+
# print(model)
|
| 420 |
+
if mainCfg.model.adapter_path is not None:
|
| 421 |
+
print('___ Loading from: ', mainCfg.model.adapter_path)
|
| 422 |
+
model = PeftModel.from_pretrained(model, mainCfg.model.adapter_path, is_trainable = True)
|
| 423 |
+
elif mainCfg.rotation_adapter_config.r is not None:
|
| 424 |
+
rotation_adapter_config = asdict(mainCfg.rotation_adapter_config)
|
| 425 |
+
# rotation_adapter_config[peft_type]
|
| 426 |
+
|
| 427 |
+
for adapter_name in mainCfg.data.adapter_names:
|
| 428 |
+
rotation_config = RotationConfig(**rotation_adapter_config)
|
| 429 |
+
model = get_peft_model(model, rotation_config, adapter_name=adapter_name)
|
| 430 |
+
# model.set_adapter(adapter_name)
|
| 431 |
+
|
| 432 |
+
else:
|
| 433 |
+
print("Full Parameter Fine-Tuning")
|
| 434 |
+
model = model.to(DEVICE)
|
| 435 |
+
|
| 436 |
+
# print('model', model)
|
| 437 |
+
model.print_trainable_parameters()
|
| 438 |
+
# print("Program starts")
|
| 439 |
+
# time.sleep(300)
|
| 440 |
+
# exit()
|
| 441 |
+
|
| 442 |
+
# for name, param in model.named_parameters():
|
| 443 |
+
# if 'q_proj' in name and 'rotation' in name and 'layers.5' in name:
|
| 444 |
+
# print(f"Name: {name} | {param.shape} ")
|
| 445 |
+
# print(f"Name (pretrained): {name} | {param.shape} ")
|
| 446 |
+
# X = param.data
|
| 447 |
+
# print('model', type(model), X.shape)
|
| 448 |
+
# visualize_value_distribution(X)
|
| 449 |
+
# exit()
|
| 450 |
+
|
| 451 |
+
rotation_layers = filter(
|
| 452 |
+
lambda p: p.requires_grad, model.parameters()
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 456 |
+
mainCfg.model.model_name,
|
| 457 |
+
model_max_length=mainCfg.model.model_max_seq_length,
|
| 458 |
+
padding_side="right",
|
| 459 |
+
use_fast=True,
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
if tokenizer.pad_token is None:
|
| 463 |
+
if tokenizer.unk_token_id is not None:
|
| 464 |
+
tokenizer.pad_token_id = tokenizer.unk_token_id
|
| 465 |
+
tokenizer.pad_token = tokenizer.unk_token
|
| 466 |
+
print("Set PAD token to UNK token.")
|
| 467 |
+
elif tokenizer.eos_token_id is not None:
|
| 468 |
+
tokenizer.pad_token_id = tokenizer.eos_token_id
|
| 469 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 470 |
+
print("Set PAD token to EOS token.")
|
| 471 |
+
|
| 472 |
+
if model is not None:
|
| 473 |
+
model.config.pad_token_id = tokenizer.pad_token_id
|
| 474 |
+
if model.config.pad_token_id != tokenizer.pad_token_id:
|
| 475 |
+
raise ValueError("Failed to sync pad_token_id between tokenizer and model config")
|
| 476 |
+
|
| 477 |
+
# local MetaMathQA-40K
|
| 478 |
+
raw_datasets = load_dataset("json", data_files=mainCfg.data.path, split=mainCfg.data.dataset_split)
|
| 479 |
+
#raw_train_datasets = load_dataset("MetaMathQA-40K", split=mainCfg.data.dataset_split)
|
| 480 |
+
# print('raw', type(raw_train_datasets), len(raw_train_datasets))
|
| 481 |
+
|
| 482 |
+
# split a single set
|
| 483 |
+
split_ratio = mainCfg.data.split_ratio
|
| 484 |
+
split_data = raw_datasets.train_test_split(test_size=split_ratio, seed=42)
|
| 485 |
+
raw_train_datasets = split_data['train']
|
| 486 |
+
raw_valid_datasets = split_data['test']
|
| 487 |
+
|
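# A quick sketch of the split semantics (datasets treats a float test_size as a fraction);
# the toy numbers below are illustrative, not taken from this config:
from datasets import Dataset as _Ds
_parts = _Ds.from_dict({"x": list(range(100))}).train_test_split(test_size=0.05, seed=42)
assert len(_parts["train"]) == 95 and len(_parts["test"]) == 5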
| 488 |
+
train_dataset = raw_train_datasets.map(
|
| 489 |
+
train_tokenize_function,
|
| 490 |
+
batched=True,
|
| 491 |
+
batch_size=30000,
|
| 492 |
+
num_proc=32,
|
| 493 |
+
remove_columns=raw_train_datasets.column_names,
|
| 494 |
+
load_from_cache_file=True,
|
| 495 |
+
desc="Running tokenizer on train dataset",
|
| 496 |
+
fn_kwargs={"tokenizer": tokenizer, "query": mainCfg.data.dataset_field[0],
|
| 497 |
+
"response": mainCfg.data.dataset_field[1]}
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
valid_dataset = raw_valid_datasets.map(
|
| 501 |
+
train_tokenize_function,
|
| 502 |
+
batched=True,
|
| 503 |
+
batch_size=30000,
|
| 504 |
+
num_proc=32,
|
| 505 |
+
remove_columns=raw_valid_datasets.column_names,
|
| 506 |
+
load_from_cache_file=True,
|
| 507 |
+
desc="Running tokenizer on train dataset",
|
| 508 |
+
fn_kwargs={"tokenizer": tokenizer, "query": mainCfg.data.dataset_field[0],
|
| 509 |
+
"response": mainCfg.data.dataset_field[1]}
|
| 510 |
+
)
|
| 511 |
+
print('- dataset size: ', len(valid_dataset), len(train_dataset))
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
# print('dataset', type(train_dataset))
|
| 515 |
+
# print('process', len(train_dataset))
|
| 516 |
+
# print(f"Sample features: {train_dataset.column_names}, {train_dataset.num_rows}")
|
| 517 |
+
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=mainCfg.model.model_max_seq_length,
|
| 518 |
+
#mode=mainCfg.model.data_collator_mode,
|
| 519 |
+
)
|
| 520 |
+
data_module = dict(train_dataset=train_dataset, data_collator=data_collator, eval_dataset=valid_dataset)
|
| 521 |
+
|
| 522 |
+
optimizer = optim.AdamW(
|
| 523 |
+
rotation_layers,
|
| 524 |
+
lr=mainCfg.trainer_args.learning_rate, #
|
| 525 |
+
eps=1e-8
|
| 526 |
+
)
|
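# Sanity sketch: `rotation_layers` is a lazy filter consumed by AdamW above, so the
# optimizer only ever sees parameters with requires_grad=True (the adapter weights).
# Counting them directly confirms the trainable-parameter budget:
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'#params handed to AdamW: {n_trainable:,}')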
| 527 |
+
# print('model x', model)
|
| 528 |
+
start_time = datetime.now()
|
| 529 |
+
print('start time: ', start_time.strftime("%Y-%m-%d %H:%M:%S"))
|
| 530 |
+
trainer = MyTrainer(model=model, processing_class=tokenizer,
|
| 531 |
+
lamda=mainCfg.model.lambda_reg,
|
| 532 |
+
optimizers=(optimizer, None),
|
| 533 |
+
args=training_args, **data_module)
|
| 534 |
+
model.config.use_cache = False
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
# now = time.time()
|
| 538 |
+
# for i in range(20):
|
| 539 |
+
# next(iter(trainer.get_train_dataloader()))
|
| 540 |
+
# print('time', time.time()-now)
|
| 541 |
+
# now = time.time()
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
# dl = trainer.get_train_dataloader()
|
| 545 |
+
# t0 = time.time()
|
| 546 |
+
# for i, batch in enumerate(dl):
|
| 547 |
+
# if i==20: break
|
| 548 |
+
# print("time / 20 batches =", time.time() - t0)
|
| 549 |
+
# exit()
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
# model2 = model.merge_and_unload()
|
| 553 |
+
# results2 = trainer2.evaluate()
|
| 554 |
+
# print('results2: ', results2)
|
| 555 |
+
# exit()
|
| 556 |
+
|
| 557 |
+
start_time = datetime.now()
|
| 558 |
+
trainer.train()
|
| 559 |
+
|
| 560 |
+
end_time = datetime.now()
|
| 561 |
+
print('end time: ', end_time.strftime("%Y-%m-%d %H:%M:%S"), '| duration: ', end_time - start_time)
|
| 562 |
+
# Save Model (Includes Adapter weights & Config)
|
| 563 |
+
# trainer.save_model(os.path.join(training_args.output_dir, 'ft'))
|
| 564 |
+
# Save Tokenizer
|
| 565 |
+
tokenizer.save_pretrained(os.path.join(training_args.output_dir, 'ft'))
|
| 566 |
+
# Save Training State (Metrics & Logs)
|
| 567 |
+
trainer.save_state()
|
| 568 |
+
|
| 569 |
+
# save peft_config: model.peft_config is a dict keyed by adapter name
# (e.g. model.base_model.peft_config['default']), so save each adapter's config
|
| 570 |
+
for _cfg in model.peft_config.values():
    _cfg.save_pretrained(os.path.join(training_args.output_dir, 'ft'))
|
| 571 |
+
|
| 572 |
+
# the easiest way
|
| 573 |
+
model.save_pretrained(os.path.join(training_args.output_dir, 'ft2'))
|
| 574 |
+
return
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
class MyTrainer(Trainer):
|
| 579 |
+
|
| 580 |
+
def __init__(
|
| 581 |
+
self,
|
| 582 |
+
model: Union[PreTrainedModel, nn.Module] = None,
|
| 583 |
+
args: TrainingArguments = None,
|
| 584 |
+
data_collator: Optional[DataCollator] = None,
|
| 585 |
+
train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
|
| 586 |
+
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None,
|
| 587 |
+
processing_class: Optional[PreTrainedTokenizerBase] = None,
|
| 588 |
+
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
| 589 |
+
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
| 590 |
+
callbacks: Optional[List[TrainerCallback]] = None,
|
| 591 |
+
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
| 592 |
+
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
| 593 |
+
#run_name: Optional[str] = None,
|
| 594 |
+
#report_to: Optional[Union[str, list[str]]] = None,
|
| 595 |
+
# project
|
| 596 |
+
lamda: float = 1e-4
|
| 597 |
+
):
|
| 598 |
+
super().__init__(model=model, args=args, data_collator=data_collator,
|
| 599 |
+
train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class,
|
| 600 |
+
model_init=model_init, compute_metrics=compute_metrics, callbacks=callbacks,
|
| 601 |
+
optimizers=optimizers, preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
| 602 |
+
#run_name=run_name, report_to=report_to
|
| 603 |
+
)
|
| 604 |
+
self.lamda = lamda
|
| 605 |
+
|
| 606 |
+
# def compute_loss(self, model, inputs, return_outputs=False,
|
| 607 |
+
# num_items_in_batch: Optional[torch.Tensor] = None,):
|
| 608 |
+
# """
|
| 609 |
+
# How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
| 610 |
+
|
| 611 |
+
# Subclass and override for custom behavior.
|
| 612 |
+
# """
|
| 613 |
+
# if self.label_smoother is not None and "labels" in inputs:
|
| 614 |
+
# labels = inputs.pop("labels")
|
| 615 |
+
# else:
|
| 616 |
+
# labels = None
|
| 617 |
+
# if self.model_accepts_loss_kwargs:
|
| 618 |
+
# kwargs = {}
|
| 619 |
+
# if num_items_in_batch is not None:
|
| 620 |
+
# kwargs["num_items_in_batch"] = num_items_in_batch
|
| 621 |
+
# inputs = {**inputs, **kwargs}
|
| 622 |
+
# outputs = model(**inputs)
|
| 623 |
+
# # Save past state if it exists
|
| 624 |
+
# # TODO: this needs to be fixed and made cleaner later.
|
| 625 |
+
# if self.args.past_index >= 0:
|
| 626 |
+
# self._past = outputs[self.args.past_index]
|
| 627 |
+
|
| 628 |
+
# if labels is not None:
|
| 629 |
+
# unwrapped_model = unwrap_model(model)
|
| 630 |
+
# if _is_peft_model(unwrapped_model):
|
| 631 |
+
# model_name = unwrapped_model.base_model.model._get_name()
|
| 632 |
+
# else:
|
| 633 |
+
# model_name = unwrapped_model._get_name()
|
| 634 |
+
# if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
|
| 635 |
+
# loss = self.label_smoother(outputs, labels, shift_labels=True)
|
| 636 |
+
# else:
|
| 637 |
+
# loss = self.label_smoother(outputs, labels)
|
| 638 |
+
# else:
|
| 639 |
+
# if isinstance(outputs, dict) and "loss" not in outputs:
|
| 640 |
+
# raise ValueError(
|
| 641 |
+
# "The model did not return a loss from the inputs, only the following keys: "
|
| 642 |
+
# f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
|
| 643 |
+
# )
|
| 644 |
+
# # We don't use .loss here since the model may return tuples instead of ModelOutput.
|
| 645 |
+
# loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
|
| 646 |
+
# # ------------------------------------------------------------------------------
|
| 647 |
+
|
| 648 |
+
# # for name, param in model.named_parameters():
|
| 649 |
+
# # if 'oft_r' in name:
|
| 650 |
+
# # device = param.device
|
| 651 |
+
# # householder_U_norm = param / param.norm(dim=0)
|
| 652 |
+
# # orth_loss = torch.norm(
|
| 653 |
+
# # torch.eye(householder_U_norm.size(1), device=device) - householder_U_norm.t() @ householder_U_norm)
|
| 654 |
+
# # print(self.lamda)
|
| 655 |
+
# # loss = loss + self.lamda * orth_loss.to(loss.device)
|
| 656 |
+
|
| 657 |
+
# # ------------------------------------------------------------------------------
|
| 658 |
+
|
| 659 |
+
# return (loss, outputs) if return_outputs else loss
|
| 660 |
+
|
| 661 |
+
def get_train_dataloader(self):
|
| 662 |
+
# get dataset & sampler from super
|
| 663 |
+
train_dataset = self.train_dataset
|
| 664 |
+
sampler = self._get_train_sampler()
|
| 665 |
+
|
| 666 |
+
# compute effective batch size per step (HF has some routines; we use per_device_train_batch_size)
|
| 667 |
+
batch_size = self.args.train_batch_size if hasattr(self.args, "train_batch_size") else self.args.per_device_train_batch_size
|
| 668 |
+
|
| 669 |
+
# recommended num_workers: start moderate (16), you can tune upward
|
| 670 |
+
num_workers = getattr(self.args, "dataloader_num_workers", 16)
|
| 671 |
+
pin_memory = getattr(self.args, "dataloader_pin_memory", True)
|
| 672 |
+
prefetch_factor = getattr(self.args, "dataloader_prefetch_factor", 2)
|
| 673 |
+
persistent_workers = getattr(self.args, "dataloader_persistent_workers", True)
|
| 674 |
+
|
| 675 |
+
return DataLoader(
|
| 676 |
+
train_dataset,
|
| 677 |
+
batch_size=batch_size,
|
| 678 |
+
sampler=sampler,
|
| 679 |
+
collate_fn=self.data_collator,
|
| 680 |
+
drop_last=self.args.dataloader_drop_last if hasattr(self.args, "dataloader_drop_last") else False,
|
| 681 |
+
num_workers=num_workers,
|
| 682 |
+
pin_memory=pin_memory,
|
| 683 |
+
persistent_workers=persistent_workers,
|
| 684 |
+
prefetch_factor=prefetch_factor,
|
| 685 |
+
worker_init_fn=default_worker_init_fn,
|
| 686 |
+
)
|
| 687 |
+
|
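# Worked example of the effective batch size implied by the arguments used above (numbers
# are hypothetical, not read from any config here): with per_device_train_batch_size=8,
# gradient_accumulation_steps=4 and 2 GPUs, one optimizer update consumes
# 8 * 4 * 2 = 64 sequences, while get_train_dataloader yields args.train_batch_size at a time.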
| 688 |
+
if __name__ == "__main__":
|
| 689 |
+
main()
|
nl_tasks/src/merge.py
ADDED
|
@@ -0,0 +1,82 @@
|
| 1 |
+
import torch
|
| 2 |
+
# import wandb
|
| 3 |
+
import os
|
| 4 |
+
import yaml
|
| 5 |
+
from peft import LoraConfig, get_peft_model_state_dict
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
import time
|
| 8 |
+
|
| 9 |
+
from typing import List, Tuple
|
| 10 |
+
|
| 11 |
+
# import prodigyopt
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
###
|
| 15 |
+
import copy
|
| 16 |
+
from dataclasses import field, dataclass, asdict
|
| 17 |
+
from typing import Sequence, Literal, Dict
|
| 18 |
+
|
| 19 |
+
import transformers
|
| 20 |
+
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
|
| 21 |
+
from transformers import Trainer
|
| 22 |
+
from transformers.modeling_utils import *
|
| 23 |
+
from transformers.trainer import _is_peft_model
|
| 24 |
+
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
| 25 |
+
from transformers.data.data_collator import DataCollator
|
| 26 |
+
|
| 27 |
+
from transformers.training_args import TrainingArguments
|
| 28 |
+
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
| 29 |
+
from transformers.trainer_callback import TrainerCallback
|
| 30 |
+
from transformers.trainer_utils import EvalPrediction
|
| 31 |
+
from torch.utils.data import Dataset, IterableDataset
|
| 32 |
+
from datasets import load_dataset
|
| 33 |
+
##
|
| 34 |
+
#from ..pipeline.flux_omini import transformer_forward, encode_images
|
| 35 |
+
# from ...omini.rotation import RotationTuner, RotationConfig
|
| 36 |
+
from rpeft.rotation import RotationTuner, RotationConfig
|
| 37 |
+
from rpeft import get_peft_model, PeftModel
|
| 38 |
+
from .config import MainConfig, convert_to_trainer_args
|
| 39 |
+
import pyrallis
|
| 40 |
+
from omegaconf import OmegaConf
|
| 41 |
+
|
| 42 |
+
import argparse
|
| 43 |
+
IGNORE_INDEX = -100
|
| 44 |
+
DEFAULT_PAD_TOKEN = "[PAD]"
|
| 45 |
+
DEFAULT_EOS_TOKEN = "</s>"
|
| 46 |
+
DEFAULT_BOS_TOKEN = "</s>"
|
| 47 |
+
DEFAULT_UNK_TOKEN = "</s>"
|
| 48 |
+
PROMPT = (
|
| 49 |
+
"Below is an instruction that describes a task. "
|
| 50 |
+
"Write a response that appropriately completes the request.\n\n"
|
| 51 |
+
"### Instruction:\n{instruction}\n\n### Response:"
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
# parser = argparse.ArgumentParser(description='Merge Adapter to Base Model')
|
| 55 |
+
# parser.add_argument('--base_mode', type=str)
|
| 56 |
+
# parser.add_argument('--adapter_path', type=str)
|
| 57 |
+
# parser.add_argument('--output_path', type=str)
|
| 58 |
+
# args = parser.parse_args()
|
| 59 |
+
|
| 60 |
+
@pyrallis.wrap()
|
| 61 |
+
def main(mainCfg: MainConfig):
|
| 62 |
+
print('='*120)
|
| 63 |
+
model_name = mainCfg.model.model_name
|
| 64 |
+
# adapter = mainCfg.trainer_args.output_dir + '/ft2'
|
| 65 |
+
# output_path = mainCfg.trainer_args.output_dir + '/merge/'
|
| 66 |
+
adapter = mainCfg.model.merge_adapter_path
|
| 67 |
+
output_path = mainCfg.model.merge_output_path
|
| 68 |
+
|
| 69 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto",)
|
| 70 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, device_map='auto')
|
| 71 |
+
|
| 72 |
+
# config = PeftConfig.from_pretrained(args.adapter)
|
| 73 |
+
model = PeftModel.from_pretrained(model, adapter)
|
| 74 |
+
model = model.merge_and_unload()
|
| 75 |
+
model.save_pretrained(output_path, safe_serialization=False)
|
| 76 |
+
tokenizer.save_pretrained(output_path)
|
| 77 |
+
# print(model)
|
| 78 |
+
print('merge.py ends', adapter, output_path)
|
| 79 |
+
return
|
| 80 |
+
|
| 81 |
+
if __name__ == "__main__":
|
| 82 |
+
main()
|
nl_tasks/src/peft_merge.py
ADDED
|
@@ -0,0 +1,82 @@
|
| 1 |
+
import torch
|
| 2 |
+
# import wandb
|
| 3 |
+
import os
|
| 4 |
+
import yaml
|
| 5 |
+
from peft import LoraConfig, get_peft_model_state_dict
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
import time
|
| 8 |
+
|
| 9 |
+
from typing import List, Tuple
|
| 10 |
+
|
| 11 |
+
# import prodigyopt
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
###
|
| 15 |
+
import copy
|
| 16 |
+
from dataclasses import field, dataclass, asdict
|
| 17 |
+
from typing import Sequence, Literal, Dict
|
| 18 |
+
|
| 19 |
+
import transformers
|
| 20 |
+
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
|
| 21 |
+
from transformers import Trainer
|
| 22 |
+
from transformers.modeling_utils import *
|
| 23 |
+
from transformers.trainer import _is_peft_model
|
| 24 |
+
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
| 25 |
+
from transformers.data.data_collator import DataCollator
|
| 26 |
+
|
| 27 |
+
from transformers.training_args import TrainingArguments
|
| 28 |
+
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
| 29 |
+
from transformers.trainer_callback import TrainerCallback
|
| 30 |
+
from transformers.trainer_utils import EvalPrediction
|
| 31 |
+
from torch.utils.data import Dataset, IterableDataset
|
| 32 |
+
from datasets import load_dataset
|
| 33 |
+
##
|
| 34 |
+
#from ..pipeline.flux_omini import transformer_forward, encode_images
|
| 35 |
+
# from ...omini.rotation import RotationTuner, RotationConfig
|
| 36 |
+
from rpeft.rotation import RotationTuner, RotationConfig
|
| 37 |
+
from peft import get_peft_model, PeftModel
|
| 38 |
+
from .config import MainConfig, convert_to_trainer_args
|
| 39 |
+
import pyrallis
|
| 40 |
+
from omegaconf import OmegaConf
|
| 41 |
+
|
| 42 |
+
import argparse
|
| 43 |
+
IGNORE_INDEX = -100
|
| 44 |
+
DEFAULT_PAD_TOKEN = "[PAD]"
|
| 45 |
+
DEFAULT_EOS_TOKEN = "</s>"
|
| 46 |
+
DEFAULT_BOS_TOKEN = "</s>"
|
| 47 |
+
DEFAULT_UNK_TOKEN = "</s>"
|
| 48 |
+
PROMPT = (
|
| 49 |
+
"Below is an instruction that describes a task. "
|
| 50 |
+
"Write a response that appropriately completes the request.\n\n"
|
| 51 |
+
"### Instruction:\n{instruction}\n\n### Response:"
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
# parser = argparse.ArgumentParser(description='Merge Adapter to Base Model')
|
| 55 |
+
# parser.add_argument('--base_mode', type=str)
|
| 56 |
+
# parser.add_argument('--adapter_path', type=str)
|
| 57 |
+
# parser.add_argument('--output_path', type=str)
|
| 58 |
+
# args = parser.parse_args()
|
| 59 |
+
|
| 60 |
+
@pyrallis.wrap()
|
| 61 |
+
def main(mainCfg: MainConfig):
|
| 62 |
+
print('='*120)
|
| 63 |
+
model_name = mainCfg.model.model_name
|
| 64 |
+
# adapter = mainCfg.trainer_args.output_dir + '/ft2'
|
| 65 |
+
# output_path = mainCfg.trainer_args.output_dir + '/merge/'
|
| 66 |
+
adapter = mainCfg.model.merge_adapter_path
|
| 67 |
+
output_path = mainCfg.model.merge_output_path
|
| 68 |
+
|
| 69 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto",)
|
| 70 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, device_map='auto')
|
| 71 |
+
|
| 72 |
+
# config = PeftConfig.from_pretrained(args.adapter)
|
| 73 |
+
model = PeftModel.from_pretrained(model, adapter)
|
| 74 |
+
model = model.merge_and_unload()
|
| 75 |
+
model.save_pretrained(output_path, safe_serialization=False)
|
| 76 |
+
tokenizer.save_pretrained(output_path)
|
| 77 |
+
# print(model)
|
| 78 |
+
print('peft_merge.py ends', adapter, output_path)
|
| 79 |
+
return
|
| 80 |
+
|
| 81 |
+
if __name__ == "__main__":
|
| 82 |
+
main()
|
nl_tasks/src/testLlama.py
ADDED
|
@@ -0,0 +1,702 @@
|
| 1 |
+
#
|
| 2 |
+
import sys
|
| 3 |
+
#print('sys.path: ___ ', sys.path)
|
| 4 |
+
#print(f"Current Python Executable: {sys.executable}")
|
| 5 |
+
|
| 6 |
+
### dynamo warning
|
| 7 |
+
import warnings
|
| 8 |
+
|
| 9 |
+
# Ignore FutureWarning: prims_common.check, Online Softmax
|
| 10 |
+
warnings.filterwarnings("ignore", category=FutureWarning, module='torch._inductor.lowering')
|
| 11 |
+
warnings.filterwarnings("ignore", message=".*Online softmax is disabled on the fly.*", category=UserWarning)
|
| 12 |
+
|
| 13 |
+
warnings.filterwarnings("ignore", message=".*Our suggested max number of worker in current system is 1.*", category=UserWarning)
|
| 14 |
+
warnings.filterwarnings("ignore", message=".*will be initialized from a multivariate normal distribution.*")
|
| 15 |
+
warnings.filterwarnings("ignore", message=".*that differ from the model config and generation config.*", category=UserWarning)
|
| 16 |
+
warnings.filterwarnings("ignore", message=".*torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch..*", category=UserWarning)
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
torch.backends.cuda.matmul.fp32_precision = 'tf32'
|
| 20 |
+
# import wandb
|
| 21 |
+
import os
|
| 22 |
+
torch.set_num_threads(1)
|
| 23 |
+
os.environ["OMP_NUM_THREADS"]="1"
|
| 24 |
+
os.environ["MKL_NUM_THREADS"]="1"
|
| 25 |
+
import torch
|
| 26 |
+
print(f"PyTorch version: {torch.__version__}")
|
| 27 |
+
print(f"CUDA available: {torch.cuda.is_available()}")
|
| 28 |
+
print(f"PyTorch built with CUDA version: {torch.version.cuda}")
|
| 29 |
+
|
| 30 |
+
import json  # used by ExperimentMonitorCallback._save_log below
import yaml
|
| 31 |
+
#from peft import LoraConfig, get_peft_model_state_dict
|
| 32 |
+
from torch.utils.data import DataLoader
|
| 33 |
+
import time
|
| 34 |
+
from datetime import datetime
|
| 35 |
+
import math
|
| 36 |
+
|
| 37 |
+
from typing import List, Tuple
|
| 38 |
+
|
| 39 |
+
# import prodigyopt
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
###
|
| 43 |
+
import copy
|
| 44 |
+
from dataclasses import field, dataclass, asdict
|
| 45 |
+
from typing import Sequence, Literal, Dict
|
| 46 |
+
|
| 47 |
+
import transformers
|
| 48 |
+
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
|
| 49 |
+
from transformers import Trainer
|
| 50 |
+
from transformers.modeling_utils import *
|
| 51 |
+
from transformers.trainer import _is_peft_model
|
| 52 |
+
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
| 53 |
+
from transformers.data.data_collator import DataCollator
|
| 54 |
+
|
| 55 |
+
from transformers.training_args import TrainingArguments
|
| 56 |
+
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
| 57 |
+
from transformers.trainer_callback import TrainerCallback
|
| 58 |
+
from transformers.trainer_utils import EvalPrediction
|
| 59 |
+
from torch.utils.data import Dataset, IterableDataset
|
| 60 |
+
from datasets import load_dataset
|
| 61 |
+
##
|
| 62 |
+
#from ..pipeline.flux_omini import transformer_forward, encode_images
|
| 63 |
+
# from ...omini.rotation import RotationTuner, RotationConfig
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
from rpeft.rotation import RotationTuner, RotationConfig
|
| 67 |
+
from rpeft import get_peft_model, PeftModel
|
| 68 |
+
from .config import MainConfig, convert_to_trainer_args
|
| 69 |
+
import pyrallis
|
| 70 |
+
from omegaconf import OmegaConf
|
| 71 |
+
import torch.optim as optim
|
| 72 |
+
import wandb
|
| 73 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 74 |
+
|
| 75 |
+
IGNORE_INDEX = -100
|
| 76 |
+
PROMPT = (
|
| 77 |
+
"Below is an instruction that describes a task. "
|
| 78 |
+
"Write a response that appropriately completes the request.\n\n"
|
| 79 |
+
"### Instruction:\n{instruction}\n\n### Response:"
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
import platform
|
| 84 |
+
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
|
| 85 |
+
class ExperimentMonitorCallback(TrainerCallback):
|
| 86 |
+
"""
|
| 87 |
+
Callback to monitor training performance and log system stats to a JSON file.
|
| 88 |
+
It captures:
|
| 89 |
+
1. Experiment Metadata (GPU info, Batch size, Learning rate, etc.)
|
| 90 |
+
2. Runtime Metrics (Avg time/step, Throughput)
|
| 91 |
+
3. Memory Metrics (Allocated, Reserved, and Peak usage)
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
+
def __init__(self, log_file_path: str, run_name: str = "experiment", log_interval: int = 100):
|
| 95 |
+
# Store the logging configuration
|
| 96 |
+
self.log_file_path = log_file_path
|
| 97 |
+
self.run_name = run_name
|
| 98 |
+
self.log_interval = log_interval
|
| 99 |
+
|
| 100 |
+
# Timing variables
|
| 101 |
+
self.start_time = None
|
| 102 |
+
self.last_log_time = None
|
| 103 |
+
|
| 104 |
+
# Data container to be saved
|
| 105 |
+
self.log_data = {
|
| 106 |
+
"metadata": {},
|
| 107 |
+
"metrics": []
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
def _get_gpu_info(self):
|
| 111 |
+
# Helper to get GPU details if available
|
| 112 |
+
if torch.cuda.is_available():
|
| 113 |
+
return {
|
| 114 |
+
"name": torch.cuda.get_device_name(0),
|
| 115 |
+
"count": torch.cuda.device_count(),
|
| 116 |
+
"capability": torch.cuda.get_device_capability(0)
|
| 117 |
+
}
|
| 118 |
+
return "CPU_ONLY"
|
| 119 |
+
|
| 120 |
+
def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
| 121 |
+
# Initialize timing
|
| 122 |
+
self.start_time = time.perf_counter()
|
| 123 |
+
self.last_log_time = self.start_time
|
| 124 |
+
|
| 125 |
+
# Reset peak memory stats to ensure we capture peaks specific to this run
|
| 126 |
+
if torch.cuda.is_available():
|
| 127 |
+
torch.cuda.reset_peak_memory_stats()
|
| 128 |
+
|
| 129 |
+
# Capture experiment metadata
|
| 130 |
+
self.log_data["metadata"] = {
|
| 131 |
+
"run_name": self.run_name,
|
| 132 |
+
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
| 133 |
+
"python_version": platform.python_version(),
|
| 134 |
+
"pytorch_version": torch.__version__,
|
| 135 |
+
"gpu_info": self._get_gpu_info(),
|
| 136 |
+
"configuration": {
|
| 137 |
+
"batch_size_per_device": args.per_device_train_batch_size,
|
| 138 |
+
"learning_rate": args.learning_rate,
|
| 139 |
+
"max_steps": args.max_steps,
|
| 140 |
+
"num_train_epochs": args.num_train_epochs,
|
| 141 |
+
"fp16": args.fp16,
|
| 142 |
+
"bf16": args.bf16,
|
| 143 |
+
"optim": args.optim,
|
| 144 |
+
}
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
# Create/Overwrite the file with initial metadata
|
| 148 |
+
self._save_log()
|
| 149 |
+
# print(f"[{self.run_name}] Experiment started. Logging to {self.log_file_path}")
|
| 150 |
+
|
| 151 |
+
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
| 152 |
+
current_step = state.global_step
|
| 153 |
+
|
| 154 |
+
# Perform logging only at specified intervals
|
| 155 |
+
if current_step > 0 and current_step % self.log_interval == 0:
|
| 156 |
+
current_time = time.perf_counter()
|
| 157 |
+
|
| 158 |
+
# Calculate time elapsed since the last log
|
| 159 |
+
elapsed_since_last = current_time - self.last_log_time
|
| 160 |
+
avg_time_per_step = elapsed_since_last / self.log_interval
|
| 161 |
+
|
| 162 |
+
# Memory Statistics (in GB)
|
| 163 |
+
mem_stats = {}
|
| 164 |
+
if torch.cuda.is_available():
|
| 165 |
+
# Current usage
|
| 166 |
+
mem_stats["allocated_gb"] = torch.cuda.memory_allocated() / 1024**3
|
| 167 |
+
mem_stats["reserved_gb"] = torch.cuda.memory_reserved() / 1024**3
|
| 168 |
+
# Peak usage since start (Long-term peak)
|
| 169 |
+
mem_stats["peak_allocated_gb"] = torch.cuda.max_memory_allocated() / 1024**3
|
| 170 |
+
|
| 171 |
+
# Construct metric entry
|
| 172 |
+
metric_entry = {
|
| 173 |
+
"step": current_step,
|
| 174 |
+
"epoch": state.epoch,
|
| 175 |
+
"timestamp": datetime.now().isoformat(),
|
| 176 |
+
"performance": {
|
| 177 |
+
"avg_time_per_step_s": round(avg_time_per_step, 4),
|
| 178 |
+
"steps_per_second": round(1.0 / avg_time_per_step, 2)
|
| 179 |
+
},
|
| 180 |
+
"memory": mem_stats
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
# Append to internal list and save to file
|
| 184 |
+
self.log_data["metrics"].append(metric_entry)
|
| 185 |
+
self._save_log()
|
| 186 |
+
|
| 187 |
+
# Update last log time
|
| 188 |
+
self.last_log_time = current_time
|
| 189 |
+
|
| 190 |
+
# Optional: Print a brief summary to console
|
| 191 |
+
print(f" -> Step {current_step}: {avg_time_per_step*1000:.1f}s/step |"\
|
| 192 |
+
f"Peak Mem: {mem_stats.get('peak_allocated_gb', 0):.2f} GB |"\
|
| 193 |
+
f"Reserved: {mem_stats.get('reserved_gb', 0):.2f} GB")
|
| 194 |
+
|
| 195 |
+
def _save_log(self):
|
| 196 |
+
# Dump the entire data structure to JSON
|
| 197 |
+
# For very long training runs, appending to a JSONL (lines) file might be more efficient,
|
| 198 |
+
# but standard JSON is easier to read for analysis.
|
| 199 |
+
try:
|
| 200 |
+
with open(self.log_file_path, 'w', encoding='utf-8') as f:
|
| 201 |
+
json.dump(self.log_data, f, indent=4)
|
| 202 |
+
except Exception as e:
|
| 203 |
+
print(f"Error saving experiment log: {e}")
|
| 204 |
+
|
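# For reference, a sketch of the JSON document _save_log produces (values illustrative):
_example_log = {
    "metadata": {"run_name": "experiment", "gpu_info": "CPU_ONLY",
                 "configuration": {"batch_size_per_device": 8, "learning_rate": 2e-4}},
    "metrics": [{"step": 100, "epoch": 0.1,
                 "performance": {"avg_time_per_step_s": 0.42, "steps_per_second": 2.38},
                 "memory": {"allocated_gb": 10.3, "reserved_gb": 12.0,
                            "peak_allocated_gb": 11.1}}],
}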
| 205 |
+
def get_rank():
|
| 206 |
+
try:
|
| 207 |
+
rank = int(os.environ.get("LOCAL_RANK"))
|
| 208 |
+
except (TypeError, ValueError):  # LOCAL_RANK unset or not an integer
|
| 209 |
+
rank = 0
|
| 210 |
+
return rank
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def get_config():
|
| 214 |
+
config_path = os.environ.get("OMINI_CONFIG")
|
| 215 |
+
assert config_path is not None, "Please set the OMINI_CONFIG environment variable"
|
| 216 |
+
with open(config_path, "r") as f:
|
| 217 |
+
config = yaml.safe_load(f)
|
| 218 |
+
return config
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def init_wandb(wandb_config, run_name):
|
| 222 |
+
import wandb
|
| 223 |
+
|
| 224 |
+
try:
|
| 225 |
+
assert os.environ.get("WANDB_API_KEY") is not None
|
| 226 |
+
wandb.init(
|
| 227 |
+
project=wandb_config["project"],
|
| 228 |
+
name=run_name,
|
| 229 |
+
config={},
|
| 230 |
+
)
|
| 231 |
+
except Exception as e:
|
| 232 |
+
print("Failed to initialize WanDB:", e)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
|
| 237 |
+
"""Collects the state dict and dump to disk."""
|
| 238 |
+
state_dict = trainer.model.state_dict()
|
| 239 |
+
if trainer.args.should_save:
|
| 240 |
+
cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
|
| 241 |
+
del state_dict
|
| 242 |
+
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def smart_tokenizer_and_embedding_resize(
|
| 246 |
+
special_tokens_dict: Dict,
|
| 247 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 248 |
+
model: transformers.PreTrainedModel,
|
| 249 |
+
):
|
| 250 |
+
"""Resize tokenizer and embedding.
|
| 251 |
+
|
| 252 |
+
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
|
| 253 |
+
"""
|
| 254 |
+
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
|
| 255 |
+
model.resize_token_embeddings(len(tokenizer))
|
| 256 |
+
|
| 257 |
+
if num_new_tokens > 0:
|
| 258 |
+
input_embeddings = model.get_input_embeddings().weight.data
|
| 259 |
+
output_embeddings = model.get_output_embeddings().weight.data
|
| 260 |
+
|
| 261 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
| 262 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
| 263 |
+
|
| 264 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
| 265 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
| 266 |
+
|
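# A minimal sketch of the mean-initialization above (shapes are illustrative):
import torch as _t
_emb = _t.randn(10, 4)                               # 10 existing embeddings, dim 4
_new = _emb.mean(dim=0, keepdim=True).expand(2, -1)  # 2 new tokens start at the mean
_emb = _t.cat([_emb, _new])
assert _t.allclose(_emb[-1], _emb[:-2].mean(dim=0))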
| 267 |
+
|
| 268 |
+
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
|
| 269 |
+
"""Tokenize a list of strings."""
|
| 270 |
+
tokenized_list = [
|
| 271 |
+
tokenizer(
|
| 272 |
+
text,
|
| 273 |
+
return_tensors="pt",
|
| 274 |
+
padding="longest",
|
| 275 |
+
max_length=tokenizer.model_max_length,
|
| 276 |
+
truncation=True,
|
| 277 |
+
)
|
| 278 |
+
for text in strings
|
| 279 |
+
]
|
| 280 |
+
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
|
| 281 |
+
input_ids_lens = labels_lens = [
|
| 282 |
+
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
|
| 283 |
+
]
|
| 284 |
+
return dict(
|
| 285 |
+
input_ids=input_ids,
|
| 286 |
+
labels=labels,
|
| 287 |
+
input_ids_lens=input_ids_lens,
|
| 288 |
+
labels_lens=labels_lens,
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
def preprocess(
|
| 292 |
+
sources: Sequence[str],
|
| 293 |
+
targets: Sequence[str],
|
| 294 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 295 |
+
) -> Dict:
|
| 296 |
+
"""Preprocess the data by tokenizing."""
|
| 297 |
+
examples = [s + t for s, t in zip(sources, targets)]
|
| 298 |
+
examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
|
| 299 |
+
input_ids = examples_tokenized["input_ids"]
|
| 300 |
+
labels = copy.deepcopy(input_ids)
|
| 301 |
+
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
|
| 302 |
+
label[:source_len] = IGNORE_INDEX
|
| 303 |
+
return dict(input_ids=input_ids, labels=labels)
|
| 304 |
+
|
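# A minimal sketch of the masking `preprocess` applies (token ids are hypothetical):
# prompt positions are overwritten with IGNORE_INDEX so the loss covers the response only.
_ids = [101, 7, 8, 9, 55, 56, 102]   # 4 prompt tokens followed by 3 response tokens
_lbls = list(_ids)
_lbls[:4] = [IGNORE_INDEX] * 4
assert _lbls == [IGNORE_INDEX] * 4 + [55, 56, 102]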
| 305 |
+
|
| 306 |
+
|
| 307 |
+
@dataclass
|
| 308 |
+
class DataCollatorForSupervisedDataset():
|
| 309 |
+
tokenizer: transformers.PreTrainedTokenizer
|
| 310 |
+
max_length: int = field(default=512)
|
| 311 |
+
mode: str = field(default="fixed") # "dynamic" or "fixed"
|
| 312 |
+
|
| 313 |
+
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
| 314 |
+
# Extract inputs and labels
|
| 315 |
+
# Assuming instances is a list of dicts like {'input_ids': [...], 'labels': [...]}
|
| 316 |
+
input_ids_list = [torch.tensor(x["input_ids"], dtype=torch.long) for x in instances]
|
| 317 |
+
labels_list = [torch.tensor(x["labels"], dtype=torch.long) for x in instances]
|
| 318 |
+
|
| 319 |
+
# 1. Determine padding logic
|
| 320 |
+
if self.mode == "dynamic":
|
| 321 |
+
# Dynamic padding: pad to the longest sequence in the batch
|
| 322 |
+
# But cap it at self.max_length to prevent OOM
|
| 323 |
+
batch_max_len = max([len(x) for x in input_ids_list])
|
| 324 |
+
target_len = min(batch_max_len, self.max_length)
|
| 325 |
+
else:
|
| 326 |
+
# Fixed padding: always pad to max_length
|
| 327 |
+
target_len = self.max_length
|
| 328 |
+
|
| 329 |
+
# 2. Helper to pad and truncate
|
| 330 |
+
def pad_and_truncate(tensors, padding_value):
|
| 331 |
+
# First, pad everything using PyTorch's optimized utility (batch_first=True)
|
| 332 |
+
padded = pad_sequence(tensors, batch_first=True, padding_value=padding_value)
|
| 333 |
+
|
| 334 |
+
# Handle truncation/extending to exact target_len
|
| 335 |
+
curr_len = padded.shape[1]
|
| 336 |
+
if curr_len > target_len:
|
| 337 |
+
# Truncate if too long (rare if filtered beforehand)
|
| 338 |
+
return padded[:, :target_len]
|
| 339 |
+
elif curr_len < target_len:
|
| 340 |
+
# Pad more if shorter than target_len (happens in fixed mode)
|
| 341 |
+
diff = target_len - curr_len
|
| 342 |
+
padding = torch.full((padded.shape[0], diff), padding_value, dtype=padded.dtype)
|
| 343 |
+
return torch.cat([padded, padding], dim=1)
|
| 344 |
+
else:
|
| 345 |
+
return padded
|
| 346 |
+
|
| 347 |
+
# 3. Apply padding
|
| 348 |
+
# Critical: tokenizer.pad_token_id must NOT be None here
|
| 349 |
+
if self.tokenizer.pad_token_id is None:
|
| 350 |
+
raise ValueError("Tokenizer.pad_token_id is None. Please set it to eos_token_id or unk_token_id.")
|
| 351 |
+
|
| 352 |
+
input_ids = pad_and_truncate(input_ids_list, self.tokenizer.pad_token_id)
|
| 353 |
+
labels = pad_and_truncate(labels_list, IGNORE_INDEX)
|
| 354 |
+
|
| 355 |
+
# 4. Create Attention Mask explicitly
|
| 356 |
+
# .ne() creates Bools, .long() casts to 0s and 1s for compatibility
|
| 357 |
+
attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
|
| 358 |
+
|
| 359 |
+
return {
|
| 360 |
+
"input_ids": input_ids,
|
| 361 |
+
"labels": labels,
|
| 362 |
+
"attention_mask": attention_mask
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
def train_tokenize_function(examples, tokenizer, query, response):
|
| 366 |
+
sources = [PROMPT.format_map(dict(instruction=instruction)) for instruction in examples[query]]
|
| 367 |
+
targets = [f"{output}{tokenizer.eos_token}" for output in examples[response]]
|
| 368 |
+
data_dict = preprocess(sources, targets, tokenizer)
|
| 369 |
+
return data_dict
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
### Trainer
|
| 374 |
+
def default_worker_init_fn(worker_id):
|
| 375 |
+
# cap each worker at a single BLAS thread
|
| 376 |
+
try:
|
| 377 |
+
import numpy as _np
|
| 378 |
+
except Exception:
|
| 379 |
+
_np = None
|
| 380 |
+
torch.set_num_threads(1)
|
| 381 |
+
os.environ.setdefault("OMP_NUM_THREADS", "1")
|
| 382 |
+
os.environ.setdefault("MKL_NUM_THREADS", "1")
|
| 383 |
+
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
|
| 384 |
+
# Optional: bind CPU affinity per worker to avoid contention (NUMA-aware)
|
| 385 |
+
try:
|
| 386 |
+
cpu_count = os.cpu_count() or 1
|
| 387 |
+
# divide the CPUs evenly among the workers
|
| 388 |
+
num_workers = getattr(torch.utils.data, "_num_workers", None)
|
| 389 |
+
# fallback: if not available, compute from environment variable or pass externally
|
| 390 |
+
# We'll do a simple round-robin assignment using worker_id
|
| 391 |
+
# assign a small contiguous block of cores to this worker (chunk works out to at least 1)
|
| 392 |
+
chunk = max(1, cpu_count // max(1, min(64, cpu_count)))
|
| 393 |
+
start = (worker_id * chunk) % cpu_count
|
| 394 |
+
end = start + chunk
|
| 395 |
+
mask = set(range(start, min(end, cpu_count)))
|
| 396 |
+
try:
|
| 397 |
+
os.sched_setaffinity(0, mask)
|
| 398 |
+
except Exception:
|
| 399 |
+
pass
|
| 400 |
+
except Exception:
|
| 401 |
+
pass
|
| 402 |
+
|
| 403 |
+
def set_seed(seed: int):
|
| 404 |
+
# random.seed(seed)
|
| 405 |
+
# np.random.seed(seed)
|
| 406 |
+
torch.manual_seed(seed)
|
| 407 |
+
torch.cuda.manual_seed_all(seed)
|
| 408 |
+
transformers.set_seed(seed)
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
@pyrallis.wrap()
|
| 412 |
+
def main(mainCfg: MainConfig):
|
| 413 |
+
#mainCfg = get_config()
|
| 414 |
+
#print(mainCfg)
|
| 415 |
+
print('='*120)
|
| 416 |
+
# print(OmegaConf.to_yaml(mainCfg))
|
| 417 |
+
# print('-'*40)
|
| 418 |
+
#
|
| 419 |
+
# print((training_args))
|
| 420 |
+
set_seed(mainCfg.seed)
|
| 421 |
+
training_args = convert_to_trainer_args(mainCfg)
|
| 422 |
+
|
| 423 |
+
# wandb
|
| 424 |
+
ENTITY = "nvan-13-korea-university"
|
| 425 |
+
PROJECT = os.environ.get("WANDB_PROJECT")
|
| 426 |
+
api = wandb.Api()
|
| 427 |
+
try:
|
| 428 |
+
runs_list = api.runs(f"{ENTITY}/{PROJECT}")
|
| 429 |
+
next_run_num = len(runs_list) + 1
|
| 430 |
+
except Exception as e:
|
| 431 |
+
next_run_num = 1
|
| 432 |
+
|
| 433 |
+
training_args.run_name = f'[{next_run_num}]lr={mainCfg.trainer_args.learning_rate:.1e},b={mainCfg.trainer_args.per_device_train_batch_size},'\
|
| 434 |
+
f'n={mainCfg.rotation_adapter_config.num_rotations},r={mainCfg.rotation_adapter_config.r},'\
|
| 435 |
+
f'init={mainCfg.run_text}'
|
| 436 |
+
# training_args.project = f'Rotation-Llama2-{mainCfg.data.dataset_name}'
|
| 437 |
+
|
| 438 |
+
# print('-'*40)
|
| 439 |
+
# print(training_args.to_json_string())
|
| 440 |
+
# exit()
|
| 441 |
+
|
| 442 |
+
model = AutoModelForCausalLM.from_pretrained(mainCfg.model.model_name,
|
| 443 |
+
device_map="auto", low_cpu_mem_usage=True,
|
| 444 |
+
dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
|
| 445 |
+
# attn_implementation="sdpa",
|
| 446 |
+
)
|
| 447 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 448 |
+
print("DEVICE", model.device)
|
| 449 |
+
|
| 450 |
+
# for name, param in model.named_parameters():
|
| 451 |
+
# if 'q_proj' in name and 'layers.5' in name:
|
| 452 |
+
# print(f"Name: {name} | {param.shape} ")
|
| 453 |
+
# print(f"Name (pretrained): {name} | {param.shape} | {param.data[0:5,0:5]}")
|
| 454 |
+
# print('model', model)
|
| 455 |
+
# exit()
|
| 456 |
+
|
| 457 |
+
total_params_now = sum(p.numel() for p in model.parameters())
|
| 458 |
+
print(f'#params of the pretrained model: {total_params_now:,}')
|
| 459 |
+
# print(model)
|
| 460 |
+
if mainCfg.model.adapter_path is not None:
|
| 461 |
+
print('___ Loading from: ', mainCfg.model.adapter_path)
|
| 462 |
+
model = PeftModel.from_pretrained(model, mainCfg.model.adapter_path, is_trainable = True)
|
| 463 |
+
elif mainCfg.rotation_adapter_config.r is not None:
|
| 464 |
+
import peft
|
| 465 |
+
if mainCfg.run_text == 'loco':
|
| 466 |
+
rotation_adapter_config = asdict(mainCfg.rotation_adapter_config)
|
| 467 |
+
|
| 468 |
+
for adapter_name in mainCfg.data.adapter_names:
|
| 469 |
+
rotation_config = RotationConfig(**rotation_adapter_config)
|
| 470 |
+
model = get_peft_model(model, rotation_config, adapter_name=adapter_name)
|
| 471 |
+
print('loaded a LoCo model, batch = ', training_args.per_device_train_batch_size)
|
| 472 |
+
elif mainCfg.run_text == 'boft':
|
| 473 |
+
from peft import BOFTConfig
|
| 474 |
+
boft_config = BOFTConfig(
|
| 475 |
+
boft_block_size=mainCfg.rotation_adapter_config.r,
|
| 476 |
+
boft_n_butterfly_factor=2*mainCfg.rotation_adapter_config.num_rotations,
|
| 477 |
+
target_modules=["q_proj", "v_proj",],
|
| 478 |
+
boft_dropout=0.05, #mainCfg.rotation_adapter_config.drop_out,
|
| 479 |
+
bias="none",
|
| 480 |
+
# task_type="CAUSAL_LM",
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
for adapter_name in mainCfg.data.adapter_names:
|
| 484 |
+
model = peft.get_peft_model(model, boft_config, adapter_name=adapter_name)
|
| 485 |
+
print('loaded a BOFT model, batch = ', training_args.per_device_train_batch_size)
|
| 486 |
+
elif mainCfg.run_text == 'hra':
|
| 487 |
+
from peft import HRAConfig
|
| 488 |
+
hra_config = HRAConfig(
|
| 489 |
+
r=2*mainCfg.rotation_adapter_config.r,
|
| 490 |
+
target_modules=["q_proj", "v_proj",],
|
| 491 |
+
init_weights=True,
|
| 492 |
+
# task_type="CAUSAL_LM",
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
for adapter_name in mainCfg.data.adapter_names:
|
| 496 |
+
model = peft.get_peft_model(model, hra_config, adapter_name=adapter_name)
|
| 497 |
+
print('loaded an HRA model, batch = ', training_args.per_device_train_batch_size)
|
| 498 |
+
elif mainCfg.run_text == 'oft':
|
| 499 |
+
from peft import HRAConfig, OFTConfig
|
| 500 |
+
|
| 501 |
+
oft_config = OFTConfig(
|
| 502 |
+
# r=16,
|
| 503 |
+
oft_block_size=4*mainCfg.rotation_adapter_config.r,
|
| 504 |
+
use_cayley_neumann=True,
|
| 505 |
+
target_modules=["q_proj", "v_proj",],
|
| 506 |
+
module_dropout=0.05, # mainCfg.rotation_adapter_config.drop_out,
|
| 507 |
+
# task_type="CAUSAL_LM",
|
| 508 |
+
bias="none",
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
for adapter_name in mainCfg.data.adapter_names:
|
| 512 |
+
model = peft.get_peft_model(model, oft_config, adapter_name=adapter_name)
|
| 513 |
+
print('loaded an OFT model, batch = ', training_args.per_device_train_batch_size)
|
| 514 |
+
else:
|
| 515 |
+
raise KeyError(f'unknown run_text / adapter type: {mainCfg.run_text}')
|
| 516 |
+
|
| 517 |
+
else:
|
| 518 |
+
print("Full Parameter Fine-Tuning")
|
| 519 |
+
model = model.to(DEVICE)
|
| 520 |
+
|
| 521 |
+
# print('model', model)
|
| 522 |
+
model.print_trainable_parameters()
|
| 523 |
+
|
| 524 |
+
# print("Program starts")
|
| 525 |
+
# time.sleep(300)
|
| 526 |
+
# exit()
|
| 527 |
+
|
| 528 |
+
# for name, param in model.named_parameters():
|
| 529 |
+
# if 'q_proj' in name and 'rotation' in name and 'layers.5' in name:
|
| 530 |
+
# print(f"Name: {name} | {param.shape} ")
|
| 531 |
+
# print(f"Name (pretrained): {name} | {param.shape} ")
|
| 532 |
+
# X = param.data
|
| 533 |
+
# print('model', type(model), X.shape)
|
| 534 |
+
# visualize_value_distribution(X)
|
| 535 |
+
# exit()
|
| 536 |
+
|
| 537 |
+
rotation_layers = filter(
|
| 538 |
+
lambda p: p.requires_grad, model.parameters()
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 542 |
+
mainCfg.model.model_name,
|
| 543 |
+
model_max_length=mainCfg.model.model_max_seq_length,
|
| 544 |
+
padding_side="right",
|
| 545 |
+
use_fast=True,
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
+    if tokenizer.pad_token is None:
+        if tokenizer.unk_token_id is not None:
+            tokenizer.pad_token_id = tokenizer.unk_token_id
+            tokenizer.pad_token = tokenizer.unk_token
+            print("Set PAD token to UNK token.")
+        elif tokenizer.eos_token_id is not None:
+            tokenizer.pad_token_id = tokenizer.eos_token_id
+            tokenizer.pad_token = tokenizer.eos_token
+            print("Set PAD token to EOS token.")
+
+    if model is not None:
+        model.config.pad_token_id = tokenizer.pad_token_id
+        if model.config.pad_token_id != tokenizer.pad_token_id:
+            raise ValueError("Failed to sync pad_token_id between tokenizer and model config")
+
+    # local MetaMathQA-40K
+    raw_datasets = load_dataset("json", data_files=mainCfg.data.path, split=mainCfg.data.dataset_split)
+
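+    # Tokenize with a batched map: the large batch_size keeps per-call overhead low,
+    # num_proc=32 fans the work across processes, and the result is cached so reruns
+    # skip this step.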
+    train_dataset = raw_datasets.map(
+        train_tokenize_function,
+        batched=True,
+        batch_size=30000,
+        num_proc=32,
+        remove_columns=raw_datasets.column_names,
+        load_from_cache_file=True,
+        desc="Running tokenizer on train dataset",
+        fn_kwargs={"tokenizer": tokenizer, "query": mainCfg.data.dataset_field[0],
+                   "response": mainCfg.data.dataset_field[1]}
+    )
+
+    # valid_dataset = raw_valid_datasets.map(
+    #     train_tokenize_function,
+    #     batched=True,
+    #     batch_size=30000,
+    #     num_proc=32,
+    #     remove_columns=raw_train_datasets.column_names,
+    #     load_from_cache_file=True,
+    #     desc="Running tokenizer on train dataset",
+    #     fn_kwargs={"tokenizer": tokenizer, "query": mainCfg.data.dataset_field[0],
+    #                "response": mainCfg.data.dataset_field[1]}
+    # )
+    print('- dataset size: ', len(train_dataset))
+
+    # print('dataset', type(train_dataset))
+    # print('process', len(train_dataset))
+    # print(f"Sample features: {train_dataset.column_names}, {train_dataset.num_rows}")
+    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=mainCfg.model.model_max_seq_length,
+                                                     # mode=mainCfg.model.data_collator_mode,
+                                                     )
+    data_module = dict(train_dataset=train_dataset, data_collator=data_collator)
+
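+    # AdamW over only the trainable (adapter) parameters; handing it to the trainer
+    # below stops HF from building its own optimizer over the full model.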
+    optimizer = optim.AdamW(
+        rotation_layers,
+        lr=mainCfg.trainer_args.learning_rate,
+        eps=1e-8
+    )
+    # print('model x', model)
+    start_time = datetime.now()
+    print('start time: ', start_time.strftime("%Y-%m-%d %H:%M:%S"))
+
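+    # Project-local callback that appends averaged training metrics to a JSON log
+    # every log_interval steps.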
+    monitor = ExperimentMonitorCallback(
+        log_file_path="./training_metrics_bs8.json",
+        run_name="Experiment_BatchSize_8",
+        log_interval=10  # average over every 10 steps
+    )
+    training_args.remove_unused_columns = False
+    training_args.torch_compile = False
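+    # MyTrainer (defined below) adds the lamda regularization weight and a custom
+    # train dataloader; passing optimizers=(optimizer, None) makes HF use our
+    # optimizer while still creating its default LR scheduler.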
+    trainer = MyTrainer(model=model, processing_class=tokenizer,
+                        lamda=mainCfg.model.lambda_reg,
+                        optimizers=(optimizer, None),
+                        args=training_args, **data_module,
+                        callbacks=[monitor],
+                        )
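+    # The KV cache is only useful for generation; disable it for training (it also
+    # conflicts with gradient checkpointing).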
+    model.config.use_cache = False
+
+    trainer.train()
+
+    end_time = datetime.now()
+    print('end time: ', end_time.strftime("%Y-%m-%d %H:%M:%S"), '| duration: ', end_time - start_time)
+
+    # Save model (includes adapter weights & config)
+    # trainer.save_model(os.path.join(training_args.output_dir, 'ft'))
+    # Save tokenizer
+    tokenizer.save_pretrained(os.path.join(training_args.output_dir, 'ft'))
+    # Save training state (metrics & logs)
+    trainer.save_state()
+
+    # Save peft_config. Or model.base_model.peft_config['default']
+    # model.peft_config.save_pretrained(os.path.join(training_args.output_dir, 'ft'))
+
+    # Simplest route: saves the adapter weights together with their config.
+    model.save_pretrained(os.path.join(training_args.output_dir, 'ft2'))
+    return
+
+
+class MyTrainer(Trainer):
+
+    def __init__(
+        self,
+        model: Union[PreTrainedModel, nn.Module] = None,
+        args: TrainingArguments = None,
+        data_collator: Optional[DataCollator] = None,
+        train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
+        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None,
+        processing_class: Optional[PreTrainedTokenizerBase] = None,
+        model_init: Optional[Callable[[], PreTrainedModel]] = None,
+        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
+        callbacks: Optional[List[TrainerCallback]] = None,
+        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
+        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+        # run_name: Optional[str] = None,
+        # report_to: Optional[Union[str, list[str]]] = None,
+        # project-specific:
+        lamda: float = 1e-4
+    ):
+        super().__init__(model=model, args=args, data_collator=data_collator,
+                         train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class,
+                         model_init=model_init, compute_metrics=compute_metrics, callbacks=callbacks,
+                         optimizers=optimizers, preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+                         # run_name=run_name, report_to=report_to
+                         )
+        self.lamda = lamda  # regularization weight, presumably consumed by a loss override elsewhere
+
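+    # Re-implemented so the DataLoader performance knobs (worker count, pinned
+    # memory, prefetching, persistent workers) are set explicitly in one place.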
+    def get_train_dataloader(self):
+        # dataset and sampler, via the Trainer's own sampler helper
+        train_dataset = self.train_dataset
+        sampler = self._get_train_sampler()
+
+        # effective per-step batch size: prefer HF's computed train_batch_size
+        # (per-device x n_gpu), falling back to per_device_train_batch_size
+        batch_size = self.args.train_batch_size if hasattr(self.args, "train_batch_size") else self.args.per_device_train_batch_size
+
+        # num_workers: start moderate (16) and tune upward as needed
+        num_workers = getattr(self.args, "dataloader_num_workers", 16)
+        pin_memory = getattr(self.args, "dataloader_pin_memory", True)
+        prefetch_factor = getattr(self.args, "dataloader_prefetch_factor", 2)
+        persistent_workers = getattr(self.args, "dataloader_persistent_workers", True)
+
+        # note: PyTorch requires num_workers > 0 whenever prefetch_factor is not
+        # None or persistent_workers is True
+        return DataLoader(
+            train_dataset,
+            batch_size=batch_size,
+            sampler=sampler,
+            collate_fn=self.data_collator,
+            drop_last=self.args.dataloader_drop_last if hasattr(self.args, "dataloader_drop_last") else False,
+            num_workers=num_workers,
+            pin_memory=pin_memory,
+            persistent_workers=persistent_workers,
+            prefetch_factor=prefetch_factor,
+            worker_init_fn=default_worker_init_fn,
+        )
+
+if __name__ == "__main__":
+    main()