# swift/tuners/restuning_components.py
# Copyright (c) Alibaba, Inc. and its affiliates.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from swift.utils.logger import get_logger

logger = get_logger()

class ResTuner(nn.Module):
    """Dispatches to a concrete tuner module (ResAdapter, ResGroupAdapter, Upsample
    or Identity) according to `tuner_cfg`."""

    def __init__(self, dim=None, layer_num=-1, depth=-1, zero_init_last=False, stage='', tuner_cfg={}, **kwargs):
        super().__init__()
        self.dim = dim
        self.layer_num = layer_num
        self.depth = depth
        self.stage = stage
        self.tuner_cfg = tuner_cfg

        if (isinstance(tuner_cfg, str) and tuner_cfg == 'res_adapter') or \
                (isinstance(tuner_cfg, dict) and 'res_adapter' in tuner_cfg):
            tuner_cfg = tuner_cfg['res_adapter'] if isinstance(tuner_cfg, dict) else tuner_cfg
            self.tuner = ResAdapter(
                dim=dim,
                layer_num=layer_num,
                depth=depth,
                zero_init_last=zero_init_last,
                stage=stage,
                tuner_cfg=tuner_cfg,
                **kwargs)
        elif (isinstance(tuner_cfg, str) and tuner_cfg == 'res_group_adapter') or \
                (isinstance(tuner_cfg, dict) and 'res_group_adapter' in tuner_cfg):
            tuner_cfg = tuner_cfg['res_group_adapter'] if isinstance(tuner_cfg, dict) else tuner_cfg
            self.tuner = ResGroupAdapter(
                dim=dim,
                layer_num=layer_num,
                depth=depth,
                zero_init_last=zero_init_last,
                stage=stage,
                tuner_cfg=tuner_cfg,
                **kwargs)
        elif (isinstance(tuner_cfg, str) and tuner_cfg == 'upsample') or \
                (isinstance(tuner_cfg, dict) and 'upsample' in tuner_cfg):
            tuner_cfg = tuner_cfg['upsample'] if isinstance(tuner_cfg, dict) else tuner_cfg
            if 'upsample_out_channels' in kwargs:
                out_channels = kwargs['upsample_out_channels']
                use_conv = True if out_channels else False
            else:
                out_channels = dim
                use_conv = False
            self.tuner = Upsample(
                channels=dim, use_conv=use_conv, out_channels=out_channels, tuner_cfg=tuner_cfg, **kwargs)
        else:
            self.tuner = Identity()

    def forward(self, x, *args, **kwargs):
        # A 'zero' tuner contributes nothing to the surrounding residual connection.
        if self.tuner_cfg == 'zero' or 'zero' in self.tuner_cfg:
            x_out = 0.0
        else:
            x_out = self.tuner(x, *args, **kwargs)
        return x_out
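
# --- Illustrative usage sketch (not part of the module; sizes are hypothetical) ---
# A ResTuner built with the string config 'res_adapter' wraps a ResAdapter with the
# default bottleneck length of 32 and acts as a shape-preserving residual branch:
#
#     tuner = ResTuner(dim=768, layer_num=3, depth=12, tuner_cfg='res_adapter')
#     tokens = torch.randn(2, 197, 768)          # (batch, tokens, dim)
#     out = tuner(tokens)                        # same shape as `tokens`
#
# With tuner_cfg='zero' (or a dict containing the key 'zero'), forward returns 0.0,
# so the branch adds nothing to the residual sum.
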
class ResAdapter(nn.Module):
    """Bottleneck residual adapter: linear down-projection -> activation -> linear
    up-projection, added back to the input as a residual. Accepts (B, N, C) token
    sequences or (B, C, H, W) feature maps."""

    def __init__(self,
                 dim,
                 layer_num=-1,
                 depth=-1,
                 zero_init_last=False,
                 stage='',
                 tuner_cfg=None,
                 act_layer=nn.GELU,
                 **kwargs):
        super(ResAdapter, self).__init__()
        self.dim = dim
        self.layer_num = layer_num
        self.depth = depth

        self.adapter_length = tuner_cfg['adapter_length'] if 'adapter_length' in tuner_cfg else 32
        self.adapter_type = tuner_cfg['adapter_type'] if 'adapter_type' in tuner_cfg else None
        self.adapter_weight = tuner_cfg['adapter_weight'] if 'adapter_weight' in tuner_cfg else None

        self.adapter_length = self.adapter_length[self.layer_num] if isinstance(self.adapter_length,
                                                                                list) else self.adapter_length
        assert isinstance(self.adapter_length, int) or (isinstance(self.adapter_length, tuple)
                                                        and len(self.adapter_length) == 3)
        if isinstance(self.adapter_length, int):
            self.ln1 = nn.Linear(dim, self.adapter_length)
        else:
            self.ln1 = nn.Linear(self.adapter_length[0], self.adapter_length[1])
        self.activate = act_layer()
        if isinstance(self.adapter_length, int):
            self.ln2 = nn.Linear(self.adapter_length, dim)
        else:
            self.ln2 = nn.Linear(self.adapter_length[1], self.adapter_length[2])
            dim = self.adapter_length[2]

        self._xavier_init_weights(self.ln1)
        if zero_init_last and layer_num == depth - 1:
            self._zero_init_weights(self.ln2)
        else:
            self._xavier_init_weights(self.ln2)

        self.scaling = init_weight_type(dim, self.adapter_weight)
        self._prepared = False

    def _zero_init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.zeros_(m.weight)
            nn.init.zeros_(m.bias)

    def _kaiming_init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
            nn.init.normal_(m.bias)

    def _xavier_init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            nn.init.normal_(m.bias, std=1e-6)

    def forward(self, x):
        if not self._prepared:
            self.ln1.to(x.device)
            self.activate.to(x.device)
            self.ln2.to(x.device)
            self._prepared = True

        x_dtype = x.dtype
        x = x.to(self.ln1.weight.dtype)

        x_shortcut = x
        if len(x_shortcut.size()) == 4:
            # Flatten (B, C, H, W) feature maps into (B, H*W, C) tokens for the linear adapter.
            B, C, N1, N2 = x.size()
            x = x.view(x_shortcut.size()[0], x_shortcut.size()[1], -1).permute(0, 2, 1)

        x_adapter = self.ln2(self.activate(self.ln1(x)))
        if self.adapter_weight:
            x_adapter = apply_data_weight(x_adapter, self.scaling, self.adapter_weight)

        if len(x_shortcut.size()) == 4:
            # Restore the (B, C, H, W) layout before the residual add.
            x_adapter = x_adapter.permute(0, 2, 1).view(x_shortcut.size()[0],
                                                        x_adapter.size()[-1],
                                                        x_shortcut.size()[2],
                                                        x_shortcut.size()[3])

        x_out = x_shortcut + x_adapter
        return x_out.to(x_dtype)
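
# --- Illustrative usage sketch (hypothetical sizes) ---
# ResAdapter also accepts 4D conv features: it flattens (B, C, H, W) to (B, H*W, C),
# applies the bottleneck MLP along the channel dim, and reshapes back before the
# residual add:
#
#     adapter = ResAdapter(dim=256, tuner_cfg={'adapter_length': 16})
#     feat = torch.randn(2, 256, 14, 14)
#     out = adapter(feat)                        # shape preserved: (2, 256, 14, 14)
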
class ResGroupAdapter(nn.Module):
    """Grouped residual adapter: splits the channel dim into `head` groups and runs
    each group through a shared bottleneck MLP, then adds the result back as a
    residual. Expects (B, C, H, W) feature maps."""

    def __init__(self,
                 dim,
                 layer_num=-1,
                 depth=-1,
                 zero_init_last=False,
                 stage='',
                 tuner_cfg=None,
                 act_layer=nn.GELU,
                 **kwargs):
        super(ResGroupAdapter, self).__init__()
        self.dim = dim
        self.layer_num = layer_num
        self.depth = depth

        self.adapter_type = tuner_cfg['adapter_type'] if 'adapter_type' in tuner_cfg else None
        self.adapter_weight = tuner_cfg['adapter_weight'] if 'adapter_weight' in tuner_cfg else None

        self.adapter_dim = tuner_cfg['dim'] if 'dim' in tuner_cfg else dim
        self.adapter_head = tuner_cfg['head'] if 'head' in tuner_cfg else 4
        self.adapter_scale_factor = tuner_cfg['scale_factor'] if 'scale_factor' in tuner_cfg else 2

        assert self.adapter_dim % self.adapter_head == 0, 'adapter dim should be divisible by adapter head'
        self.dim_mlp = self.adapter_dim // self.adapter_head

        self.ln1 = nn.Linear(self.dim_mlp, self.dim_mlp * self.adapter_scale_factor)
        self.ln2 = nn.Linear(self.dim_mlp * self.adapter_scale_factor, self.dim_mlp)
        self.activate = act_layer()

        self._kaiming_init_weights(self.ln1)
        if zero_init_last and layer_num == depth - 1:
            self._zero_init_weights(self.ln2)
        else:
            self._kaiming_init_weights(self.ln2)

        self.scaling = init_weight_type(dim, self.adapter_weight)
        self._prepared = False

    def _zero_init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.zeros_(m.weight)
            nn.init.zeros_(m.bias)

    def _kaiming_init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
            nn.init.normal_(m.bias)

    def _xavier_init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            nn.init.normal_(m.bias, std=1e-6)

    def forward(self, x):
        if not self._prepared:
            self.ln1.to(x.device)
            self.activate.to(x.device)
            self.ln2.to(x.device)
            self._prepared = True

        x_dtype = x.dtype
        x = x.to(self.ln1.weight.dtype)

        x_shortcut = x
        batch, inner_dim, height, width = x.shape
        # Flatten spatial dims to tokens, then fold the head groups into the batch dim
        # so the bottleneck MLP is shared across groups.
        x_adapter = x.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        x_adapter = rearrange(x_adapter, 'b n (c h) -> (b h) n c', h=self.adapter_head)
        x_adapter = self.ln2(self.activate(self.ln1(x_adapter)))
        x_adapter = rearrange(x_adapter, '(b h) n c -> b n (c h)', h=self.adapter_head)

        if self.adapter_weight:
            x_adapter = apply_data_weight(x_adapter, self.scaling, self.adapter_weight)
        x_adapter = x_adapter.reshape(batch, height, width, -1).permute(0, 3, 1, 2).contiguous()

        x_out = x_shortcut + x_adapter
        return x_out.to(x_dtype)
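
# --- Illustrative usage sketch (hypothetical sizes) ---
# With the settings below, ResGroupAdapter splits the 64 channels into head=4 groups
# of 16, runs each group through the shared 16 -> 32 -> 16 bottleneck MLP, and adds
# the result back:
#
#     group_adapter = ResGroupAdapter(dim=64, tuner_cfg={'head': 4, 'scale_factor': 2})
#     feat = torch.randn(2, 64, 8, 8)
#     out = group_adapter(feat)                  # shape preserved: (2, 64, 8, 8)
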
class Identity(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, inputs, *args, **kwargs):
        return inputs

class Upsample(nn.Module):
    """An upsampling layer with an optional convolution.

    :param channels: channels in the inputs (and outputs, unless `out_channels` is set).
    :param use_conv: a bool determining if a 3x3 convolution is applied after upsampling.
    :param out_channels: output channels of the convolution; defaults to `channels`.
    :param padding: padding of the convolution.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, **kwargs):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        if use_conv:
            self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=padding)
        self.init_weights()

    def init_weights(self):

        def _init_weights(m):
            if isinstance(m, nn.Conv2d):
                nn.init.zeros_(m.weight)
                nn.init.zeros_(m.bias)

        self.apply(_init_weights)

    def forward(self, x, target_size=None, *args, **kwargs):
        assert x.shape[1] == self.channels
        # Nearest-neighbour upsampling: 2x by default, or to an explicit target size.
        if target_size is None:
            x = F.interpolate(x.float(), scale_factor=2, mode='nearest').type_as(x)
        else:
            x = F.interpolate(x.float(), target_size, mode='nearest').type_as(x)
        if self.use_conv:
            x = self.conv(x)
        return x
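
# --- Illustrative usage sketch (hypothetical sizes) ---
# Upsample doubles the spatial resolution by default, or interpolates to an explicit
# target size; with use_conv=True the zero-initialised 3x3 conv maps channels -> out_channels:
#
#     up = Upsample(channels=64, use_conv=True, out_channels=32)
#     feat = torch.randn(1, 64, 16, 16)
#     up(feat).shape                             # torch.Size([1, 32, 32, 32])
#     up(feat, target_size=(20, 20)).shape       # torch.Size([1, 32, 20, 20])
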
def init_weight_type(dim, weight_type):
    if weight_type is None:
        scaling = None
    elif weight_type == 'gate':
        scaling = nn.Linear(dim, 1)
    elif weight_type == 'scale':
        scaling = nn.Parameter(torch.Tensor(1))
        scaling.data.fill_(1)
    elif weight_type == 'scale_kv':
        scaling_k = nn.Parameter(torch.Tensor(1))
        scaling_k.data.fill_(1)
        scaling_v = nn.Parameter(torch.Tensor(1))
        scaling_v.data.fill_(1)
        scaling = (scaling_k, scaling_v)
    elif weight_type == 'scale_channel':
        scaling = nn.Parameter(torch.Tensor(dim))
        scaling.data.fill_(1)
    elif weight_type == 'scale_kv_channel':
        scaling_k = nn.Parameter(torch.Tensor(dim))
        scaling_k.data.fill_(1)
        scaling_v = nn.Parameter(torch.Tensor(dim))
        scaling_v.data.fill_(1)
        scaling = (scaling_k, scaling_v)
    elif weight_type and weight_type.startswith('scalar'):
        scaling = float(weight_type.split('_')[-1])
    else:
        scaling = None
    return scaling

def apply_data_weight(data, scaling, weight_type):
    if weight_type in ['gate']:
        scaling = torch.mean(torch.sigmoid(scaling(data)), dim=1).view(-1, 1, 1)
    elif weight_type in ['scale', 'scale_channel'] or weight_type.startswith('scalar'):
        scaling = scaling
    else:
        scaling = None
    if scaling is not None:
        data = data * scaling
    return data
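
# --- Illustrative sketch of the weight-type convention (hypothetical values) ---
# init_weight_type builds the scaling object for a given adapter_weight setting and
# apply_data_weight applies it; e.g. 'scalar_0.5' scales the adapter output by a
# fixed 0.5, while 'gate' derives a per-sample gate from the data itself:
#
#     scaling = init_weight_type(768, 'scalar_0.5')        # -> 0.5
#     scaled = apply_data_weight(torch.randn(2, 10, 768), scaling, 'scalar_0.5')
#
#     gate = init_weight_type(768, 'gate')                 # -> nn.Linear(768, 1)
#     gated = apply_data_weight(torch.randn(2, 10, 768), gate, 'gate')
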
def detach_tensors(feats):
    if type(feats) in [list, tuple]:
        feats = [detach_tensors(feat) if feat is not None else None for feat in feats]
    elif isinstance(feats, dict):
        feats = {key: detach_tensors(val) for key, val in feats.items()}
    elif isinstance(feats, torch.Tensor):
        feats = feats.detach()
    else:
        # Non-tensor leaves (e.g. None, ints, strings) are returned unchanged.
        pass
    return feats

def probe_tensors(module, feats, name):
    feats = detach_tensors(feats)
    setattr(module, name, feats)


def probe_input_pre_hook(self, args):
    input = args[0]
    probe_tensors(self, input, 'probe_input_data')
    return args


def probe_output_hook(self, args, result):
    output = result
    probe_tensors(self, output, 'probe_output_data')
    return output
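
# --- Illustrative usage sketch ---
# The probe hooks can be registered on any nn.Module to stash detached copies of its
# forward inputs/outputs on the module itself; attribute names follow the functions above:
#
#     layer = nn.Linear(8, 8)
#     layer.register_forward_pre_hook(probe_input_pre_hook)
#     layer.register_forward_hook(probe_output_hook)
#     _ = layer(torch.randn(2, 8))
#     layer.probe_input_data                     # detached input tensor
#     layer.probe_output_data                    # detached output tensor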