{
"cells": [
{
"cell_type": "markdown",
"id": "963e9ae0-ac68-44be-8c7d-fb9842784362",
"metadata": {},
"source": [
"# 4.6 基于llama的基因大模型指令微调"
]
},
{
"cell_type": "markdown",
"id": "182b82c4-d484-4c15-a600-03c3b51367ec",
"metadata": {},
"source": [
"**PEFT**(Parameter-Efficient Fine-Tuning,参数高效微调)是一种优化技术,旨在以最小的参数更新实现对大规模预训练模型(如 GPT、BERT 等)的微调。PEFT 技术通过减少微调所需的参数量,显著降低了存储和计算开销,同时保留模型的性能,特别适合资源受限的场景和领域特定任务的定制化。\n",
"\n",
"---\n",
"\n",
"### **1. 核心思想**\n",
"传统的微调方式需要更新整个预训练模型的所有参数,PEFT 技术通过只调整少量的参数(如特定层或额外添加的小型模块)实现微调目标,大幅减少了训练开销和存储需求。\n",
"\n",
"---\n",
"\n",
"### **2. 常见的 PEFT 方法**\n",
"\n",
"#### **(1)Adapter 模型**\n",
"- 在每一层 Transformer 的输出中插入小型适配器模块,仅训练适配器模块的参数。\n",
"- 原始模型参数保持冻结不变。\n",
"- 优点:适配器模块参数量小,能适应不同任务。\n",
"\n",
"示例方法:\n",
"- **AdapterFusion**\n",
"- **MAD-X**\n",
"\n",
"---\n",
"\n",
"#### **(2)Prefix Tuning**\n",
"- 在 Transformer 的输入前添加一组可学习的前缀向量,这些前缀与模型的注意力机制交互。\n",
"- 只调整前缀向量的参数,而不更新原始模型。\n",
"- 优点:对生成任务效果显著,参数量进一步减少。\n",
"\n",
"---\n",
"\n",
"#### **(3)LoRA(Low-Rank Adaptation)**\n",
"- 将预训练模型中的部分权重分解为两个低秩矩阵,仅调整这些低秩矩阵的参数。\n",
"- 原始权重保持冻结状态。\n",
"- 优点:参数量极小,计算高效。\n",
" \n",
"---\n",
"\n",
"#### **(4)Prompt Tuning**\n",
"- 在输入文本中添加可学习的提示(Prompt)。\n",
"- 适合 NLP 任务中的文本生成、分类等。\n",
"- 优点:实现简单,易于集成到现有框架。\n",
"\n",
"---\n",
"\n",
"### **3. PEFT 的优势**\n",
"\n",
"1. **显著减少参数更新量**:\n",
" - 微调传统的大模型(如 GPT-3)需要更新数百亿参数,而 PEFT 仅需更新百万级别甚至更少的参数。\n",
"\n",
"2. **高效存储**:\n",
" - 每个任务的微调结果只需存储少量额外参数,而不是整个模型。\n",
"\n",
"3. **适用多任务**:\n",
" - 同一预训练模型可以通过不同的 PEFT 模块适配多个任务,无需重新训练。\n",
"\n",
"4. **降低计算开销**:\n",
" - 训练所需的内存和计算显著减少,适合资源有限的环境。\n",
"\n",
"---\n",
"\n",
"### **4. 应用场景**\n",
"\n",
"1. **领域特定任务**:\n",
" - 医疗、法律、金融等领域微调预训练模型。\n",
"\n",
"2. **多任务学习**:\n",
" - 适配多个任务,复用同一模型的预训练权重。\n",
"\n",
"3. **资源受限场景**:\n",
" - 移动设备、边缘设备上的模型部署。\n",
"\n",
"---\n",
"\n",
"### **5. Hugging Face PEFT 库**\n",
"\n",
"Hugging Face 提供了专门的 PEFT 库,支持多种参数高效微调技术:\n",
"- **安装**:\n",
" ```bash\n",
" pip install peft\n",
" ```\n",
"- **使用 LoRA 微调示例**:\n",
" ```python\n",
" from transformers import AutoModelForCausalLM, AutoTokenizer\n",
" from peft import LoraConfig, get_peft_model, TaskType\n",
"\n",
" # 加载模型和分词器\n",
" model_name = \"gpt2\"\n",
" model = AutoModelForCausalLM.from_pretrained(model_name)\n",
" tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"\n",
" # 配置 LoRA\n",
" lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" r=8,\n",
" lora_alpha=32,\n",
" target_modules=[\"q_proj\", \"v_proj\"],\n",
" lora_dropout=0.1,\n",
" bias=\"none\"\n",
" )\n",
"\n",
" # 使用 LoRA 微调模型\n",
" model = get_peft_model(model, lora_config)\n",
" model.print_trainable_parameters()\n",
"\n",
" # 微调代码...\n",
" ```\n",
"\n",
"---\n",
"\n",
"### **6. PEFT 的局限性**\n",
"1. **特定任务限制**:\n",
" - 在一些复杂任务中,PEFT 方法可能不如全量微调效果好。\n",
"\n",
"2. **需要设计合适的模块**:\n",
" - 不同任务需要选择和设计合适的 PEFT 技术。\n",
"\n",
"3. **与模型架构相关**:\n",
" - PEFT 技术可能需要对模型架构进行一定程度的修改。\n",
"\n",
"---\n",
"\n",
"### **7. 总结**\n",
"PEFT 是一个极具潜力的技术,特别适合在有限资源下对大模型进行微调。它在许多领域和任务中已显示出良好的效果,例如 LoRA 和 Adapter 模型已经成为高效微调的主流方法。\n",
"\n",
"如果您需要实现高效微调,可以结合 Hugging Face 的 PEFT 库快速上手。"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "5aa3d240-44e1-4811-8f61-d6ff2500a798",
"metadata": {},
"outputs": [],
"source": [
"import subprocess\n",
"import os\n",
"# 设置环境变量, autodl一般区域\n",
"result = subprocess.run('bash -c \"source /etc/network_turbo && env | grep proxy\"', shell=True, capture_output=True, text=True)\n",
"output = result.stdout\n",
"for line in output.splitlines():\n",
" if '=' in line:\n",
" var, value = line.split('=', 1)\n",
" os.environ[var] = value"
]
},
{
"cell_type": "markdown",
"id": "17bdb69d-3f0f-465e-bd60-2047a088e264",
"metadata": {},
"source": [
"如果您不确定模型中有哪些模块可以微调,可以打印模型结构:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "41a0c049-9134-4d89-aad0-1aa2241a9fca",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4becc479adbc472bb7672d49da16aafd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"generation_config.json: 0%| | 0.00/124 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"transformer\n",
"transformer.wte\n",
"transformer.wpe\n",
"transformer.drop\n",
"transformer.h\n",
"transformer.h.0\n",
"transformer.h.0.ln_1\n",
"transformer.h.0.attn\n",
"transformer.h.0.attn.c_attn\n",
"transformer.h.0.attn.c_proj\n",
"transformer.h.0.attn.attn_dropout\n",
"transformer.h.0.attn.resid_dropout\n",
"transformer.h.0.ln_2\n",
"transformer.h.0.mlp\n",
"transformer.h.0.mlp.c_fc\n",
"transformer.h.0.mlp.c_proj\n",
"transformer.h.0.mlp.act\n",
"transformer.h.0.mlp.dropout\n",
"transformer.h.1\n",
"transformer.h.1.ln_1\n",
"transformer.h.1.attn\n",
"transformer.h.1.attn.c_attn\n",
"transformer.h.1.attn.c_proj\n",
"transformer.h.1.attn.attn_dropout\n",
"transformer.h.1.attn.resid_dropout\n",
"transformer.h.1.ln_2\n",
"transformer.h.1.mlp\n",
"transformer.h.1.mlp.c_fc\n",
"transformer.h.1.mlp.c_proj\n",
"transformer.h.1.mlp.act\n",
"transformer.h.1.mlp.dropout\n",
"transformer.h.2\n",
"transformer.h.2.ln_1\n",
"transformer.h.2.attn\n",
"transformer.h.2.attn.c_attn\n",
"transformer.h.2.attn.c_proj\n",
"transformer.h.2.attn.attn_dropout\n",
"transformer.h.2.attn.resid_dropout\n",
"transformer.h.2.ln_2\n",
"transformer.h.2.mlp\n",
"transformer.h.2.mlp.c_fc\n",
"transformer.h.2.mlp.c_proj\n",
"transformer.h.2.mlp.act\n",
"transformer.h.2.mlp.dropout\n",
"transformer.h.3\n",
"transformer.h.3.ln_1\n",
"transformer.h.3.attn\n",
"transformer.h.3.attn.c_attn\n",
"transformer.h.3.attn.c_proj\n",
"transformer.h.3.attn.attn_dropout\n",
"transformer.h.3.attn.resid_dropout\n",
"transformer.h.3.ln_2\n",
"transformer.h.3.mlp\n",
"transformer.h.3.mlp.c_fc\n",
"transformer.h.3.mlp.c_proj\n",
"transformer.h.3.mlp.act\n",
"transformer.h.3.mlp.dropout\n",
"transformer.h.4\n",
"transformer.h.4.ln_1\n",
"transformer.h.4.attn\n",
"transformer.h.4.attn.c_attn\n",
"transformer.h.4.attn.c_proj\n",
"transformer.h.4.attn.attn_dropout\n",
"transformer.h.4.attn.resid_dropout\n",
"transformer.h.4.ln_2\n",
"transformer.h.4.mlp\n",
"transformer.h.4.mlp.c_fc\n",
"transformer.h.4.mlp.c_proj\n",
"transformer.h.4.mlp.act\n",
"transformer.h.4.mlp.dropout\n",
"transformer.h.5\n",
"transformer.h.5.ln_1\n",
"transformer.h.5.attn\n",
"transformer.h.5.attn.c_attn\n",
"transformer.h.5.attn.c_proj\n",
"transformer.h.5.attn.attn_dropout\n",
"transformer.h.5.attn.resid_dropout\n",
"transformer.h.5.ln_2\n",
"transformer.h.5.mlp\n",
"transformer.h.5.mlp.c_fc\n",
"transformer.h.5.mlp.c_proj\n",
"transformer.h.5.mlp.act\n",
"transformer.h.5.mlp.dropout\n",
"transformer.h.6\n",
"transformer.h.6.ln_1\n",
"transformer.h.6.attn\n",
"transformer.h.6.attn.c_attn\n",
"transformer.h.6.attn.c_proj\n",
"transformer.h.6.attn.attn_dropout\n",
"transformer.h.6.attn.resid_dropout\n",
"transformer.h.6.ln_2\n",
"transformer.h.6.mlp\n",
"transformer.h.6.mlp.c_fc\n",
"transformer.h.6.mlp.c_proj\n",
"transformer.h.6.mlp.act\n",
"transformer.h.6.mlp.dropout\n",
"transformer.h.7\n",
"transformer.h.7.ln_1\n",
"transformer.h.7.attn\n",
"transformer.h.7.attn.c_attn\n",
"transformer.h.7.attn.c_proj\n",
"transformer.h.7.attn.attn_dropout\n",
"transformer.h.7.attn.resid_dropout\n",
"transformer.h.7.ln_2\n",
"transformer.h.7.mlp\n",
"transformer.h.7.mlp.c_fc\n",
"transformer.h.7.mlp.c_proj\n",
"transformer.h.7.mlp.act\n",
"transformer.h.7.mlp.dropout\n",
"transformer.h.8\n",
"transformer.h.8.ln_1\n",
"transformer.h.8.attn\n",
"transformer.h.8.attn.c_attn\n",
"transformer.h.8.attn.c_proj\n",
"transformer.h.8.attn.attn_dropout\n",
"transformer.h.8.attn.resid_dropout\n",
"transformer.h.8.ln_2\n",
"transformer.h.8.mlp\n",
"transformer.h.8.mlp.c_fc\n",
"transformer.h.8.mlp.c_proj\n",
"transformer.h.8.mlp.act\n",
"transformer.h.8.mlp.dropout\n",
"transformer.h.9\n",
"transformer.h.9.ln_1\n",
"transformer.h.9.attn\n",
"transformer.h.9.attn.c_attn\n",
"transformer.h.9.attn.c_proj\n",
"transformer.h.9.attn.attn_dropout\n",
"transformer.h.9.attn.resid_dropout\n",
"transformer.h.9.ln_2\n",
"transformer.h.9.mlp\n",
"transformer.h.9.mlp.c_fc\n",
"transformer.h.9.mlp.c_proj\n",
"transformer.h.9.mlp.act\n",
"transformer.h.9.mlp.dropout\n",
"transformer.h.10\n",
"transformer.h.10.ln_1\n",
"transformer.h.10.attn\n",
"transformer.h.10.attn.c_attn\n",
"transformer.h.10.attn.c_proj\n",
"transformer.h.10.attn.attn_dropout\n",
"transformer.h.10.attn.resid_dropout\n",
"transformer.h.10.ln_2\n",
"transformer.h.10.mlp\n",
"transformer.h.10.mlp.c_fc\n",
"transformer.h.10.mlp.c_proj\n",
"transformer.h.10.mlp.act\n",
"transformer.h.10.mlp.dropout\n",
"transformer.h.11\n",
"transformer.h.11.ln_1\n",
"transformer.h.11.attn\n",
"transformer.h.11.attn.c_attn\n",
"transformer.h.11.attn.c_proj\n",
"transformer.h.11.attn.attn_dropout\n",
"transformer.h.11.attn.resid_dropout\n",
"transformer.h.11.ln_2\n",
"transformer.h.11.mlp\n",
"transformer.h.11.mlp.c_fc\n",
"transformer.h.11.mlp.c_proj\n",
"transformer.h.11.mlp.act\n",
"transformer.h.11.mlp.dropout\n",
"transformer.ln_f\n",
"lm_head\n"
]
}
],
"source": [
"from transformers import AutoModelForCausalLM\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n",
"\n",
"# 打印所有模块名称\n",
"for name, module in model.named_modules():\n",
" print(name)"
]
},
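{
"cell_type": "markdown",
"id": "7f2a9c1e-5b3d-4e6f-8a0b-2c4d6e8f0a1b",
"metadata": {},
"source": [
"From the listing above you can derive candidate `target_modules` programmatically. A minimal sketch (for GPT-2 the attention projections are the fused `c_attn` and the output `c_proj`; `q_proj`/`v_proj` are the LLaMA-style equivalents):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9d8c7b6a-5f4e-4d3c-8b1a-0e9f8d7c6b5a",
"metadata": {},
"outputs": [],
"source": [
"# Collect the weight-bearing leaf modules inside the attention blocks --\n",
"# these are the usual LoRA injection points\n",
"candidates = sorted({\n",
"    name.split('.')[-1]\n",
"    for name, module in model.named_modules()\n",
"    if '.attn.' in name and hasattr(module, 'weight')\n",
"})\n",
"print(candidates)  # for gpt2: ['c_attn', 'c_proj']"
]
},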
{
"cell_type": "code",
"execution_count": null,
"id": "054a2956-9045-4ad5-a878-1bfc84ad4ed8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}