import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRALinear(nn.Module):
    """
    LoRA linear layer that holds weight/bias directly, keeping the same
    state_dict key structure as nn.Linear.

    state_dict layout:
    - weight: original weight (same as nn.Linear)
    - bias: original bias (same as nn.Linear)
    - lora_A: LoRA low-rank matrix A
    - lora_B: LoRA low-rank matrix B

    The benefit of this design: pretrained weights can be loaded without any
    key remapping. A short usage sketch follows the class definition.
    """

    def __init__(
        self,
        base: nn.Linear,
        r: int,
        alpha: float = 1.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        assert isinstance(base, nn.Linear), "LoRALinear only supports wrapping nn.Linear."

        self.in_features = base.in_features
        self.out_features = base.out_features
        self.r = r
        self.alpha = alpha
        self._base_scaling = alpha / r if r > 0 else 0.0

        # Keep the effective scaling in a non-persistent buffer so LoRA can be
        # enabled/disabled without changing the module structure (torch.compile friendly).
        self.register_buffer("scaling", torch.tensor(self._base_scaling), persistent=False)

        # Reuse the base layer's parameters directly so state_dict keys match nn.Linear.
        self.weight = base.weight
        self.bias = base.bias

        if r > 0:
            self.lora_A = nn.Parameter(torch.zeros(r, self.in_features))
            self.lora_B = nn.Parameter(torch.zeros(self.out_features, r))
            # Standard LoRA init: A is Kaiming-uniform, B is zero, so the adapter
            # starts out as an exact no-op.
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)
        else:
            self.register_parameter("lora_A", None)
            self.register_parameter("lora_B", None)

        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Frozen base projection: y = x @ W^T + b.
        result = F.linear(x, self.weight, self.bias)
        if self.r <= 0 or self.lora_A is None:
            return result

        # Low-rank update, scaled by alpha / r. Dropout is applied to the input of
        # the LoRA branch, as in the reference LoRA implementation.
        lora_out = F.linear(F.linear(self.dropout(x), self.lora_A), self.lora_B)
        return result + lora_out * self.scaling

    def reset_lora_parameters(self):
        """Reset the LoRA parameters to their initial state."""
        if self.r > 0 and self.lora_A is not None:
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def set_enabled(self, enabled: bool):
        """Enable/disable LoRA (controlled via the scaling buffer, which keeps it torch.compile friendly)."""
        self.scaling.fill_(self._base_scaling if enabled else 0.0)

    @property
    def enabled(self) -> bool:
        return self.scaling.item() != 0.0
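

# A minimal usage sketch for a single layer (illustrative shapes and hyperparameters,
# not taken from any specific model in this repo):
#
#     base = nn.Linear(16, 32)
#     lora = LoRALinear(base, r=4, alpha=8.0, dropout=0.05)
#     assert set(lora.state_dict().keys()) == {"weight", "bias", "lora_A", "lora_B"}
#     lora.set_enabled(False)  # behaves exactly like the frozen base nn.Linear again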


def _get_parent_module(root: nn.Module, name: str) -> Optional[nn.Module]:
    """
    Given a fully qualified name such as 'layers.0.self_attn.q_proj', return its
    parent module (i.e. the module that directly owns q_proj).
    """
    parts = name.split(".")
    if len(parts) == 1:
        return root
    parent = root
    for p in parts[:-1]:
        if not hasattr(parent, p):
            return None
        parent = getattr(parent, p)
    return parent
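

# For example, with a (hypothetical) model that has model.layers[0].self_attn.q_proj,
# _get_parent_module(model, "layers.0.self_attn.q_proj") returns model.layers[0].self_attn,
# so the caller can then setattr(parent, "q_proj", ...) on it.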


def apply_lora_to_named_linear_modules(
    root: nn.Module,
    *,
    target_submodule_names: list[str],
    r: int,
    alpha: float,
    dropout: float,
) -> None:
    """
    Inject LoRA into every nn.Linear inside the given module (and its submodules)
    whose name ends with one of target_submodule_names.

    For example, with target_submodule_names=["q_proj", "v_proj"], every nn.Linear
    named *.q_proj / *.v_proj is replaced by a LoRALinear.
    """
    for full_name, module in list(root.named_modules()):
        if not isinstance(module, nn.Linear):
            continue
        short_name = full_name.split(".")[-1]
        if short_name not in target_submodule_names:
            continue

        parent = _get_parent_module(root, full_name)
        if parent is None:
            continue

        # Replace the plain nn.Linear with a LoRALinear that reuses its weight/bias.
        lora_layer = LoRALinear(
            base=module,
            r=r,
            alpha=alpha,
            dropout=dropout,
        )
        setattr(parent, short_name, lora_layer)
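

# A minimal runnable sketch of the injection helper. The tiny attention-style module
# below is purely illustrative and not part of any real model in this repo.
if __name__ == "__main__":
    class _ToyAttention(nn.Module):
        def __init__(self):
            super().__init__()
            self.q_proj = nn.Linear(64, 64)
            self.k_proj = nn.Linear(64, 64)
            self.v_proj = nn.Linear(64, 64)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.q_proj(x) + self.k_proj(x) + self.v_proj(x)

    model = _ToyAttention()
    apply_lora_to_named_linear_modules(
        model,
        target_submodule_names=["q_proj", "v_proj"],
        r=8,
        alpha=16.0,
        dropout=0.05,
    )
    assert isinstance(model.q_proj, LoRALinear)
    assert isinstance(model.v_proj, LoRALinear)
    assert isinstance(model.k_proj, nn.Linear)  # not targeted, left untouched

    out = model(torch.randn(2, 64))
    print("output shape:", tuple(out.shape))  # (2, 64)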