# Copyright (c) Alibaba, Inc. and its affiliates.
import math

import torch
import torch.nn as nn

from swift.utils.logger import get_logger

logger = get_logger()


def detach_tensors(feats):
    """Recursively detach tensors nested inside lists, tuples, and dicts."""
    if isinstance(feats, (list, tuple)):
        feats = [detach_tensors(feat) if feat is not None else None for feat in feats]
    elif isinstance(feats, dict):
        feats = {key: detach_tensors(val) for key, val in feats.items()}
    elif isinstance(feats, torch.Tensor):
        feats = feats.detach()
    # Non-tensor leaves (None, ints, strings, ...) are returned unchanged;
    # calling .detach() on them would raise an AttributeError.
    return feats


def probe_tensors(module, feats, name):
    """Store a detached copy of `feats` on `module` under attribute `name`."""
    feats = detach_tensors(feats)
    setattr(module, name, feats)


def probe_input_pre_hook(self, args):
    """Forward pre-hook that records the module's first positional input."""
    probe_tensors(self, args[0], 'probe_input_data')
    return args


def probe_output_hook(self, args, result):
    """Forward hook that records the module's output."""
    probe_tensors(self, result, 'probe_output_data')
    return result


def choose_weight_type(weight_type, dim):
    """Build the scaling object for the given weighting scheme.

    Returns an nn.Linear gate for 'gate', a learnable nn.Parameter for
    'scale' / 'scale_channel', a constant float for 'scalar_<value>',
    and None otherwise.
    """
    if weight_type == 'gate':
        scaling = nn.Linear(dim, 1)
    elif weight_type == 'scale':
        scaling = nn.Parameter(torch.Tensor(1))
        scaling.data.fill_(1)
    elif weight_type == 'scale_channel':
        scaling = nn.Parameter(torch.Tensor(dim))
        scaling.data.fill_(1)
    elif weight_type and weight_type.startswith('scalar'):
        # e.g. 'scalar_0.5' -> 0.5
        scaling = float(weight_type.split('_')[-1])
    else:
        scaling = None
    return scaling


def get_weight_value(weight_type, scaling, x):
    """Resolve the scaling factor to apply to the adapter output `x`."""
    if weight_type == 'gate':
        # Per-sample gate in (0, 1), averaged over the token dimension.
        scaling = torch.mean(torch.sigmoid(scaling(x)), dim=1).view(-1, 1, 1)
    elif weight_type in ('scale', 'scale_channel') or weight_type.startswith('scalar'):
        pass  # learnable parameter or constant float, used as-is
    else:
        scaling = None
    return scaling


class SCEAdapter(nn.Module):

    def __init__(self,
                 dim,
                 adapter_length,
                 adapter_type=None,
                 adapter_weight=None,
                 act_layer=nn.GELU,
                 zero_init_last=True,
                 use_bias=True):
        super().__init__()
        self.dim = dim
        self.adapter_length = adapter_length
        self.adapter_type = adapter_type
        self.adapter_weight = adapter_weight
        self.zero_init_last = zero_init_last
        # Bottleneck: dim -> adapter_length -> dim
        self.ln1 = nn.Linear(dim, adapter_length, bias=use_bias)
        self.activate = act_layer()
        self.ln2 = nn.Linear(adapter_length, dim, bias=use_bias)
        self.init_weights()
        self.init_scaling()

    def _zero_init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.zeros_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def _kaiming_init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))

    def init_weights(self):
        self._kaiming_init_weights(self.ln1)
        if self.zero_init_last:
            # Zero-init the output projection so the adapter's residual
            # contribution starts at zero (identity mapping at init).
            self._zero_init_weights(self.ln2)
        else:
            self._kaiming_init_weights(self.ln2)

    def init_scaling(self):
        if self.adapter_weight:
            self.scaling = choose_weight_type(self.adapter_weight, self.dim)
        else:
            self.scaling = None

    def forward(self, x, x_shortcut=None, use_shortcut=True, **kwargs):
        if x_shortcut is None:
            x_shortcut = x
        x_shape = x.shape
        if len(x_shape) == 4:
            # (b, d, h, w) -> (b, h*w, d): flatten spatial dims into tokens.
            b, d, h, w = x_shape
            x = x.permute(0, 2, 3, 1).reshape(b, h * w, d)
        out = self.ln2(self.activate(self.ln1(x)))
        if self.adapter_weight:
            scaling = get_weight_value(self.adapter_weight, self.scaling, out)
            out = out * scaling if scaling is not None else out
        if len(x_shape) == 4:
            # Restore the original (b, d, h, w) layout.
            b, d, h, w = x_shape
            out = out.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
        if use_shortcut:
            out = x_shortcut + out
        return out
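

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of wiring the probe hooks onto an SCEAdapter and running
# both a token-sequence batch and an image-style (b, d, h, w) batch through
# it. The shapes and the 'gate' weighting choice below are assumptions made
# for demonstration only.
if __name__ == '__main__':
    adapter = SCEAdapter(dim=64, adapter_length=8, adapter_weight='gate')
    adapter.register_forward_pre_hook(probe_input_pre_hook)
    adapter.register_forward_hook(probe_output_hook)

    # Sequence input: (batch, tokens, dim).
    seq = torch.randn(2, 16, 64)
    out_seq = adapter(seq)
    assert out_seq.shape == seq.shape

    # Image-style input: (batch, dim, height, width) is flattened to tokens
    # inside forward() and restored on the way out.
    img = torch.randn(2, 64, 8, 8)
    out_img = adapter(img)
    assert out_img.shape == img.shape

    # The hooks stored detached copies of the last input and output.
    print(adapter.probe_input_data.shape, adapter.probe_output_data.shape)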