| """ |
| LoRA (Low-Rank Adaptation) implementation. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import WhisperForConditionalGeneration |
| from peft import LoraConfig, get_peft_model, TaskType, PeftModel |
| from typing import Optional |
| import itertools |
|
|
|
|
| def build_lora_student_model( |
    config,
    device: torch.device
) -> tuple[PeftModel, Optional[nn.Linear]]:
    """Build the LoRA student model and its optional hidden-state projection."""
    base = WhisperForConditionalGeneration.from_pretrained(
        config.model.student_model,
        torch_dtype=torch.float16,
        use_cache=False
    )
    # Gradient checkpointing trades compute for activation memory;
    # use_cache=False above is required for it to take effect during training.
    base.gradient_checkpointing_enable()

    # Freeze the base encoder/decoder up front; after LoRA injection only the
    # adapter weights (and the optional projection layer) are trained.
    for p in itertools.chain(base.model.encoder.parameters(),
                             base.model.decoder.parameters()):
        p.requires_grad = False

    lcfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=config.lora.r,
        lora_alpha=config.lora.alpha,
        lora_dropout=config.lora.dropout,
        target_modules=config.lora.target_modules,
        bias="none"
    )
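
    # Note: for Whisper, valid `target_modules` values include the attention
    # projection names ("q_proj", "k_proj", "v_proj", "out_proj"); which subset
    # is adapted here comes from the experiment config.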
    student = get_peft_model(base, lcfg).to(device)

    # Optional hidden-state projection, created only when hidden-state
    # distillation is enabled (hidden_beta > 0).
    proj = nn.Linear(
        config.model.student_hidden_dim,
        config.model.teacher_hidden_dim
    ).to(device) if config.distillation.hidden_beta > 0 else None
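    # In the training loop, `proj` would map student hidden states into the
    # teacher's hidden size before a feature-matching loss, e.g. (sketch,
    # assuming an MSE feature loss with torch.nn.functional as F):
    #   hidden_loss = F.mse_loss(proj(student_hidden), teacher_hidden)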

    trainable = sum(p.numel() for p in student.parameters() if p.requires_grad)
    total = sum(p.numel() for p in student.parameters())
    if proj is not None:
        proj_params = sum(p.numel() for p in proj.parameters())
        trainable += proj_params
        total += proj_params
    print(f"Trainable parameters: {trainable:,} ({trainable/total*100:.2f}%)")

    return student, proj


def build_teacher_model(
    config,
    device: torch.device
) -> WhisperForConditionalGeneration:
    """Build the frozen teacher model used for distillation."""
    teacher = WhisperForConditionalGeneration.from_pretrained(
        config.model.teacher_model,
        torch_dtype=torch.float16,
        use_cache=False
    ).to(device).eval()

    # The teacher is inference-only: freeze every parameter.
    for p in teacher.parameters():
        p.requires_grad = False

    return teacher
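

# ---------------------------------------------------------------------------
# Illustrative usage (sketch). The config fields mirror the attribute paths
# read above; all concrete values (checkpoints, hidden dims, LoRA
# hyper-parameters) are hypothetical placeholders, not settings from this
# project.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        model=SimpleNamespace(
            student_model="openai/whisper-small",      # assumed student checkpoint
            teacher_model="openai/whisper-large-v3",   # assumed teacher checkpoint
            student_hidden_dim=768,                    # assumed student d_model
            teacher_hidden_dim=1280,                   # assumed teacher d_model
        ),
        lora=SimpleNamespace(
            r=16,
            alpha=32,
            dropout=0.05,
            target_modules=["q_proj", "v_proj"],       # assumed adapter targets
        ),
        distillation=SimpleNamespace(hidden_beta=1.0), # enables the projection
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    student, proj = build_lora_student_model(config, device)
    teacher = build_teacher_model(config, device)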