# Copyright (c) Alibaba, Inc. and its affiliates.
import math

import torch
import torch.nn as nn

from swift.utils.logger import get_logger

logger = get_logger()


def detach_tensors(feats):
    """Recursively detach tensors nested in lists, tuples and dicts."""
    if type(feats) in (list, tuple):
        feats = [detach_tensors(feat) if feat is not None else None for feat in feats]
    elif isinstance(feats, dict):
        feats = {key: detach_tensors(val) for key, val in feats.items()}
    elif isinstance(feats, torch.Tensor):
        feats = feats.detach()
    else:
        # Non-tensor leaves (e.g. ints, None) are returned unchanged; calling
        # .detach() on them would raise an AttributeError.
        pass
    return feats


def probe_tensors(module, feats, name):
    """Cache a detached copy of `feats` on `module` under attribute `name`."""
    feats = detach_tensors(feats)
    setattr(module, name, feats)


def probe_input_pre_hook(self, args):
    """Forward pre-hook that caches the module's first positional input."""
    input = args[0]
    probe_tensors(self, input, 'probe_input_data')
    return args


def probe_output_hook(self, args, result):
    """Forward hook that caches the module's output."""
    output = result
    probe_tensors(self, output, 'probe_output_data')
    return output
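

# Minimal usage sketch (illustrative only, not part of the library): the two hooks
# above match the signatures expected by `register_forward_pre_hook` and
# `register_forward_hook`, so attaching them caches a layer's input/output on the
# layer itself. `layer` below is a hypothetical target module.
#
#   layer = nn.Linear(8, 8)
#   layer.register_forward_pre_hook(probe_input_pre_hook)
#   layer.register_forward_hook(probe_output_hook)
#   _ = layer(torch.randn(2, 8))
#   cached_in = layer.probe_input_data     # detached input tensor
#   cached_out = layer.probe_output_data   # detached output tensor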


def choose_weight_type(weight_type, dim):
    """Build the scaling object used to weight the adapter output."""
    if weight_type == 'gate':
        # Input-dependent gate: a linear layer producing one logit per token.
        scaling = nn.Linear(dim, 1)
    elif weight_type == 'scale':
        # Single learnable scalar, initialized to 1.
        scaling = nn.Parameter(torch.Tensor(1))
        scaling.data.fill_(1)
    elif weight_type == 'scale_channel':
        # One learnable scale per channel, initialized to 1.
        scaling = nn.Parameter(torch.Tensor(dim))
        scaling.data.fill_(1)
    elif weight_type and weight_type.startswith('scalar'):
        # Fixed scalar encoded in the name, e.g. 'scalar_0.5' -> 0.5.
        scaling = float(weight_type.split('_')[-1])
    else:
        scaling = None
    return scaling


def get_weight_value(weight_type, scaling, x):
    """Resolve the scaling built by `choose_weight_type` into a usable factor."""
    if weight_type == 'gate':
        # Per-sample gate in (0, 1): average the sigmoid gate over the token dimension.
        scaling = torch.mean(torch.sigmoid(scaling(x)), dim=1).view(-1, 1, 1)
    elif weight_type in ['scale', 'scale_channel'] or (weight_type and weight_type.startswith('scalar')):
        # Learnable parameter or fixed scalar: use it as-is.
        pass
    else:
        scaling = None
    return scaling
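

# Illustrative sketch of the two helpers above (hypothetical shapes and values):
# 'gate' yields a data-dependent per-sample factor, while 'scalar_0.5' resolves to
# a fixed constant.
#
#   gate = choose_weight_type('gate', dim=768)          # nn.Linear(768, 1)
#   feats = torch.randn(2, 16, 768)                     # (batch, tokens, dim)
#   w = get_weight_value('gate', gate, feats)           # shape (2, 1, 1), values in (0, 1)
#
#   const = choose_weight_type('scalar_0.5', dim=768)   # plain float 0.5
#   w = get_weight_value('scalar_0.5', const, feats)    # 0.5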


class SCEAdapter(nn.Module):
    """Bottleneck adapter for SCE-Tuning: down-project, activate, up-project,
    optionally re-weight, then add the result back onto a shortcut."""

    def __init__(self,
                 dim,
                 adapter_length,
                 adapter_type=None,
                 adapter_weight=None,
                 act_layer=nn.GELU,
                 zero_init_last=True,
                 use_bias=True):
        super(SCEAdapter, self).__init__()
        self.dim = dim
        self.adapter_length = adapter_length
        self.adapter_type = adapter_type
        self.adapter_weight = adapter_weight
        self.zero_init_last = zero_init_last
        # Bottleneck projection: dim -> adapter_length -> dim.
        self.ln1 = nn.Linear(dim, adapter_length, bias=use_bias)
        self.activate = act_layer()
        self.ln2 = nn.Linear(adapter_length, dim, bias=use_bias)
        self.init_weights()
        self.init_scaling()

    def _zero_init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.zeros_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def _kaiming_init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))

    def init_weights(self):
        # Kaiming init for the down-projection; zero init for the up-projection by
        # default, so the adapter starts out as an identity around the shortcut.
        self._kaiming_init_weights(self.ln1)
        if self.zero_init_last:
            self._zero_init_weights(self.ln2)
        else:
            self._kaiming_init_weights(self.ln2)

    def init_scaling(self):
        if self.adapter_weight:
            self.scaling = choose_weight_type(self.adapter_weight, self.dim)
        else:
            self.scaling = None

    def forward(self, x, x_shortcut=None, use_shortcut=True, **kwargs):
        if x_shortcut is None:
            x_shortcut = x
        x_shape = x.shape
        if len(x_shape) == 4:
            # Flatten (b, c, h, w) feature maps into (b, h * w, c) token form.
            b, d, h, w = x_shape
            x = x.permute(0, 2, 3, 1).reshape(b, h * w, d)
        # Bottleneck projection.
        out = self.ln2(self.activate(self.ln1(x)))
        if self.adapter_weight:
            scaling = get_weight_value(self.adapter_weight, self.scaling, out)
            out = out * scaling if scaling is not None else out
        if len(x_shape) == 4:
            # Restore the original (b, c, h, w) layout.
            b, d, h, w = x_shape
            out = out.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
        if use_shortcut:
            out = x_shortcut + out
        return out
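

# Minimal usage sketch (illustrative only; the dimensions below are hypothetical):
# the adapter accepts either (batch, tokens, dim) sequences or (batch, dim, h, w)
# feature maps and returns a tensor of the same shape, added onto the shortcut.
#
#   adapter = SCEAdapter(dim=768, adapter_length=64, adapter_weight='gate')
#   hidden = torch.randn(2, 16, 768)       # (batch, tokens, dim)
#   out = adapter(hidden)                  # same shape as `hidden`
#
#   fmap = torch.randn(2, 768, 8, 8)       # (batch, channels, h, w)
#   out = adapter(fmap)                    # same shape as `fmap`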