# Copyright (c) 2021 microsoft
# 2023 Alan (alanfangemail@gmail.com)
# -----------------------------------------------------------------------------
# Licensed under the MIT License (MIT). See LICENSE in the repo root for
# license information.
# -----------------------------------------------------------------------------
import logging
from typing import Dict

import torch
import torch.nn as nn

import wenet.finetune.lora.layers as lora


def get_nested_attr(module, attr_path):
    """Resolve a dotted attribute path (e.g. "encoder.encoders") on a module.

    Returns None if any attribute along the path does not exist.
    """
    attrs = attr_path.split('.')
    for attr in attrs:
        if hasattr(module, attr):
            module = getattr(module, attr)
        else:
            return None
    return module
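
# Usage sketch (illustrative only; the dotted path is an assumption about the
# wenet model layout and may differ for your encoder):
#
#     layers = get_nested_attr(model, "encoder.encoders")
#     # -> the encoder's layer list, or None if any attribute is missing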


def inject_lora(module, lora_config):
    """Replace the configured linear sub-layers of ``module`` with LoRA linears."""
    lora_rank = lora_config["lora_rank"]
    lora_alpha = lora_config["lora_alpha"]
    lora_dropout = lora_config["lora_dropout"]
    for lora_attr in lora_config["lora_list"]:
        if hasattr(module, lora_attr):
            submodule = getattr(module, lora_attr)
            n_feat = submodule.in_features
            lora_linear = lora.Linear(n_feat, n_feat, r=lora_rank,
                                      lora_alpha=lora_alpha,
                                      lora_dropout=lora_dropout)
            setattr(module, lora_attr, lora_linear)


def inject_lora_to_model(model, lora_config):
    """Inject LoRA layers into the attention modules listed in ``lora_config``."""
    lora_modules = []
    for module in lora_config["lora_modules"]:
        submodule = get_nested_attr(model, module)
        for layer in submodule:
            lora_modules.append(layer)
    updated_lora_modules = []
    for i in range(len(lora_modules)):
        for attn_attr in lora_config["lora_attn_attr"]:
            if hasattr(lora_modules[i], attn_attr):
                updated_lora_modules.append(
                    getattr(lora_modules[i], attn_attr))
    for lora_module in updated_lora_modules:
        inject_lora(lora_module, lora_config)
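
# Illustrative lora_config (a sketch, not shipped with this module): the key
# names match what inject_lora / inject_lora_to_model read above, while the
# module paths and attribute names are assumptions about the wenet model.
#
#     lora_config = {
#         "lora_modules": ["encoder.encoders"],         # dotted paths to layer lists
#         "lora_attn_attr": ["self_attn", "src_attn"],  # attention blocks per layer
#         "lora_list": ["linear_q", "linear_k",
#                       "linear_v", "linear_out"],      # linears to wrap with LoRA
#         "lora_rank": 8,
#         "lora_alpha": 8,
#         "lora_dropout": 0.1,
#     }
#     inject_lora_to_model(model, lora_config)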


def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
    logging.info('Freezing all parameters except LoRA modules.')
    for n, p in model.named_parameters():
        if 'lora_' not in n:
            p.requires_grad = False
    if bias == 'none':
        return
    elif bias == 'all':
        for n, p in model.named_parameters():
            if 'bias' in n:
                p.requires_grad = True
    elif bias == 'lora_only':
        for m in model.modules():
            if isinstance(m, lora.LoRALayer) and \
                    hasattr(m, 'bias') and \
                    m.bias is not None:
                m.bias.requires_grad = True
    else:
        raise NotImplementedError
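
# Typical call order (sketch): inject the LoRA layers first, then freeze the
# backbone. bias="lora_only" additionally unfreezes biases owned by LoRA
# layers, bias="all" unfreezes every bias in the model.
#
#     inject_lora_to_model(model, lora_config)
#     mark_only_lora_as_trainable(model, bias="none")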


def lora_state_dict(model: nn.Module,
                    bias: str = 'none') -> Dict[str, torch.Tensor]:
    my_state_dict = model.state_dict()
    if bias == 'none':
        return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k}
    elif bias == 'all':
        return {
            k: my_state_dict[k]
            for k in my_state_dict if 'lora_' in k or 'bias' in k
        }
    elif bias == 'lora_only':
        to_return = {}
        for k in my_state_dict:
            if 'lora_' in k:
                to_return[k] = my_state_dict[k]
                bias_name = k.split('lora_')[0] + 'bias'
                if bias_name in my_state_dict:
                    to_return[bias_name] = my_state_dict[bias_name]
        return to_return
    else:
        raise NotImplementedError
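
# Checkpointing sketch: save only the adapter weights and merge them back with
# strict=False later (the file name is arbitrary).
#
#     adapter = lora_state_dict(model, bias="none")
#     torch.save(adapter, "lora_adapter.pt")
#     model.load_state_dict(torch.load("lora_adapter.pt"), strict=False)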


def get_record_gradient_hook(model, record_dict):
    def record_gradient_hook(grad):
        for n, p in model.named_parameters():
            if p.requires_grad and p.grad is not None:
                if n not in record_dict:
                    record_dict[n] = p.grad.cpu()
                else:
                    record_dict[n] += p.grad.cpu()
                p.grad = None
        return grad

    return record_gradient_hook
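
# Note: the hook moves every gradient that is currently materialized into
# record_dict on the CPU and clears p.grad right away, so the full set of
# gradients never has to live on the GPU at once. estimate_gradient below
# registers one such hook per parameter.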


def estimate_gradient(
    model, dataloader, max_iters: int = 8,
    device: torch.device = torch.device("cpu")
) -> Dict[str, torch.Tensor]:
    r"""
    Estimate the averaged gradient of the model on the given dataset.
    """
    logging.info("Estimating gradients for LoRA init; this may take a while.")
    model.train()
    named_grads = {}
    hooks = []
    requires_grad_states = {}
    for name, param in model.named_parameters():
        requires_grad_states[name] = param.requires_grad
        param.requires_grad = True
        hook = param.register_hook(
            get_record_gradient_hook(model, named_grads))
        hooks.append(hook)
    num = 0
    for _, batch_dict in enumerate(dataloader):
        if max_iters is not None and num >= max_iters:
            break
        num += 1
        outputs = model(batch_dict, device)
        outputs['loss'].backward()
        # flush gradients that are still attached after the final backward call
        get_record_gradient_hook(model, named_grads)(None)
        # make sure the gradient is cleared
        for n, p in model.named_parameters():
            if p.grad is not None:
                p.grad = None
    for n, _ in named_grads.items():
        named_grads[n] /= num
    for hook in hooks:
        hook.remove()
    # recover original requires_grad states
    for name, param in model.named_parameters():
        param.requires_grad = requires_grad_states[name]
    torch.cuda.empty_cache()
    return named_grads
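
# Usage sketch (``train_dataloader`` is a placeholder name): the returned dict
# maps parameter names to gradients averaged over the sampled batches, kept on
# the CPU, and is consumed by the "gradient" mode of reinit_lora_modules below.
#
#     named_grads = estimate_gradient(model, train_dataloader, max_iters=8,
#                                     device=torch.device("cuda"))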


@torch.no_grad()
def reinit_lora_modules(name, module, init_config, **kwargs):
    r"""Refer to https://github.com/Outsider565/LoRA-GA/blob/c185846309ea9012d0bcd46ebd30347dda1c592c/run_exp.py#L67
    Reinitialize the lora model with the given configuration.
    """
    import math
    lora_r = min(module.lora_A.shape)
    a_dim = max(module.lora_A.shape)
    b_dim = max(module.lora_B.shape)
    if init_config.mode == "simple":
        match init_config.lora_A:
            case "gaussian":
                torch.nn.init.normal_(
                    module.lora_A, mean=0.0,
                    std=init_config.lora_A_std
                )
            case "kaiming":
                # https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124
                torch.nn.init.kaiming_uniform_(module.lora_A,
                                               a=math.sqrt(5))
            case "fan_out_kaiming":
                torch.nn.init.kaiming_normal_(
                    module.lora_A, mode="fan_out"
                )
            case "xavier":
                torch.nn.init.xavier_normal_(module.lora_A)
            case "zeros":
                torch.nn.init.zeros_(module.lora_A)
            case "unit":
                torch.nn.init.normal_(
                    module.lora_A, mean=0.0,
                    std=1.0 / (a_dim**0.5)
                )
            case "orthogonal":
                torch.nn.init.orthogonal_(module.lora_A)
            case _:
                raise ValueError(
                    f"Unknown lora_A initialization: {init_config.lora_A}"
                )
        match init_config.lora_B:
            case "gaussian":
                torch.nn.init.normal_(
                    module.lora_B, mean=0.0,
                    std=init_config.lora_B_std
                )
            case "kaiming":
                torch.nn.init.kaiming_normal_(module.lora_B)
            case "fan_out_kaiming":
                torch.nn.init.kaiming_normal_(
                    module.lora_B, mode="fan_out"
                )
            case "xavier":
                torch.nn.init.xavier_normal_(module.lora_B)
            case "zeros":
                torch.nn.init.zeros_(module.lora_B)
            case "unit":
                torch.nn.init.normal_(
                    module.lora_B, mean=0.0,
                    std=1.0 / (b_dim**0.5)
                )
            case "orthogonal":
                torch.nn.init.orthogonal_(module.lora_B)
            case _:
                raise ValueError(
                    f"Unknown lora_B initialization: {init_config.lora_B}"
                )
        if getattr(init_config, 'scale', '') == "stable":
            gamma = init_config.stable_gamma
            m, n = module.weight.shape
            module.lora_B.data *= (m**0.25) / gamma**0.5
            module.lora_A.data *= (n**0.25) / gamma**0.5
elif init_config.mode == "svd":
U, S, V = torch.svd_lowrank(module.weight.float(), q=4 * lora_r,
niter=4)
V = V.T
m, n = module.weight.shape
if init_config.scale == "default":
S = S / module.scaling
module.lora_B = torch.nn.Parameter(
(U[:, :lora_r] * torch.sqrt(S[:lora_r])).contiguous()
)
module.lora_A = torch.nn.Parameter(
(V[:lora_r, :].T * torch.sqrt(S[:lora_r])).T.contiguous()
)
elif init_config.scale == "stable":
gamma = init_config.stable_gamma
module.lora_B = torch.nn.Parameter(
(U[:, :lora_r] * (m**0.25) / gamma**0.5).contiguous()
)
module.lora_A = torch.nn.Parameter(
(V[:lora_r, :] * (n**0.25) / gamma**0.5).contiguous()
)
elif init_config.scale == "unit":
module.lora_B = torch.nn.Parameter((U[:, :lora_r]).contiguous())
module.lora_A = torch.nn.Parameter((V[:lora_r, :]).contiguous())
elif init_config.scale == "normalized":
S_sum = S[:lora_r].sum()
module.lora_B = torch.nn.Parameter(
(U[:, :lora_r] * torch.sqrt(S[:lora_r])
/ torch.sqrt(S_sum) * lora_r**0.5).contiguous()
)
module.lora_A = torch.nn.Parameter(
(V[:lora_r, :].T * torch.sqrt(S[:lora_r])
/ torch.sqrt(S_sum) * lora_r**0.5).T.contiguous()
)
elif init_config.mode == "gradient":
named_grad = kwargs["named_grads"]
grad_name = name + ".weight"
grads = named_grad[grad_name]
U, S, V = torch.svd_lowrank(grads.cuda().float(), q=4 * lora_r, niter=4)
V = V.T
# set direction
if init_config.direction == "ArBr":
B = U[:, 0 : 2 * lora_r : 2]
A = V[1 : 2 * lora_r : 2, :]
elif init_config.direction == "A2rBr":
B = U[:, :lora_r]
A = V[lora_r : 2 * lora_r, :]
elif init_config.direction == "ArB2r":
B = U[:, lora_r : 2 * lora_r]
A = V[:lora_r, :]
scaling_factor = module.scaling
if init_config.scale == "gd":
A = A / scaling_factor
B = B / scaling_factor
elif init_config.scale == "unit":
# Because A,B is orthogonal, do not need to scale
pass
elif init_config.scale == "stable":
m, n = grads.shape
# m: feature_out, n: feature_in
# the scale of output is only related to the feature_out
gamma = init_config.stable_gamma
B = B * m**0.25 / gamma**0.5
A = A * m**0.25 / gamma**0.5
elif init_config.scale == "weightS":
_, S, _ = torch.svd_lowrank(module.weight.float(), q=4 * lora_r,
niter=4)
S = S / module.scaling
avg_s = torch.sqrt(S[:lora_r]).mean().to(A.device)
B = B * avg_s
A = A * avg_s
module.lora_B = torch.nn.Parameter(B.contiguous().cuda())
module.lora_A = torch.nn.Parameter(A.contiguous().cuda())
    with torch.no_grad():
        # dtype may be absent from init_config; leave the parameters unchanged
        if not hasattr(init_config, "dtype"):
            pass
        elif init_config.dtype == "bf16":
            module.lora_A.data = module.lora_A.data.to(torch.bfloat16)
            module.lora_B.data = module.lora_B.data.to(torch.bfloat16)
        elif init_config.dtype == "fp32":
            module.lora_A.data = module.lora_A.data.to(torch.float32)
            module.lora_B.data = module.lora_B.data.to(torch.float32)
        # If lora_A @ lora_B is not zero, subtract the scaled product from the
        # original weight matrix so the initial forward pass is unchanged
        offset = (
            module.lora_B @ module.lora_A
        ).to(module.weight.data.device)
        scaling_factor = module.scaling
        offset *= scaling_factor
        if hasattr(init_config, "norm_clip") and init_config.norm_clip:
            # for numerical stability, the largest entry of the offset must not
            # exceed the largest entry of the weight
            ratio = torch.max(torch.abs(module.weight.data)) / torch.max(
                torch.abs(offset)
            )
            if ratio < 1:
                offset *= ratio
                module.lora_A.data *= ratio**0.5
                module.lora_B.data *= ratio**0.5
                logging.warning(f"Clipping offset by {ratio}")
        try:
            module.weight.data -= offset
        except Exception as e:
            logging.warning(f"{e}")
            breakpoint()
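
# Sketch of applying the gradient-based (LoRA-GA style) re-initialization after
# gradient estimation. The init_config below is an assumption about the
# expected attribute names, mirroring the branches handled above.
#
#     from types import SimpleNamespace
#
#     init_config = SimpleNamespace(mode="gradient", direction="ArB2r",
#                                   scale="stable", stable_gamma=16)
#     named_grads = estimate_gradient(model, train_dataloader)
#     for name, module in model.named_modules():
#         if isinstance(module, lora.Linear):
#             reinit_lora_modules(name, module, init_config,
#                                 named_grads=named_grads)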