# Copyright (c) Alibaba, Inc. and its affiliates.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from swift.utils.logger import get_logger

logger = get_logger()


class ResTuner(nn.Module):
    """Dispatch wrapper: builds the concrete tuner module selected by ``tuner_cfg``."""

    def __init__(self, dim=None, layer_num=-1, depth=-1, zero_init_last=False, stage='', tuner_cfg={}, **kwargs):
        super().__init__()
        self.dim = dim
        self.layer_num = layer_num
        self.depth = depth
        self.stage = stage
        self.tuner_cfg = tuner_cfg

        if (isinstance(tuner_cfg, str) and tuner_cfg == 'res_adapter') or \
                (isinstance(tuner_cfg, dict) and 'res_adapter' in tuner_cfg):
            tuner_cfg = tuner_cfg['res_adapter'] if isinstance(tuner_cfg, dict) else tuner_cfg
            self.tuner = ResAdapter(
                dim=dim,
                layer_num=layer_num,
                depth=depth,
                zero_init_last=zero_init_last,
                stage=stage,
                tuner_cfg=tuner_cfg,
                **kwargs)
        elif (isinstance(tuner_cfg, str) and tuner_cfg == 'res_group_adapter') or \
                (isinstance(tuner_cfg, dict) and 'res_group_adapter' in tuner_cfg):
            tuner_cfg = tuner_cfg['res_group_adapter'] if isinstance(tuner_cfg, dict) else tuner_cfg
            self.tuner = ResGroupAdapter(
                dim=dim,
                layer_num=layer_num,
                depth=depth,
                zero_init_last=zero_init_last,
                stage=stage,
                tuner_cfg=tuner_cfg,
                **kwargs)
        elif (isinstance(tuner_cfg, str) and tuner_cfg == 'upsample') or \
                (isinstance(tuner_cfg, dict) and 'upsample' in tuner_cfg):
            tuner_cfg = tuner_cfg['upsample'] if isinstance(tuner_cfg, dict) else tuner_cfg
            if 'upsample_out_channels' in kwargs:
                out_channels = kwargs['upsample_out_channels']
                use_conv = bool(out_channels)
            else:
                out_channels = dim
                use_conv = False
            self.tuner = Upsample(
                channels=dim, use_conv=use_conv, out_channels=out_channels, tuner_cfg=tuner_cfg, **kwargs)
        else:
            self.tuner = Identity()

    def forward(self, x, *args, **kwargs):
        # A 'zero' tuner contributes nothing to the residual branch.
        if self.tuner_cfg == 'zero' or 'zero' in self.tuner_cfg:
            x_out = 0.0
        else:
            x_out = self.tuner(x, *args, **kwargs)
        return x_out
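# Usage sketch (illustrative, not part of the original module): constructing a
# ResTuner that dispatches to the 'res_adapter' branch. The hidden size,
# sequence layout, and config values below are assumptions for demonstration.
def _example_res_adapter():
    tuner = ResTuner(
        dim=64,
        layer_num=0,
        depth=12,
        tuner_cfg={'res_adapter': {'adapter_length': 16}})
    x = torch.randn(2, 10, 64)  # assumed (batch, tokens, dim) layout
    out = tuner(x)
    assert out.shape == x.shape  # adapter output is added onto the shortcut
    return out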
class ResAdapter(nn.Module):

    def __init__(self,
                 dim,
                 layer_num=-1,
                 depth=-1,
                 zero_init_last=False,
                 stage='',
                 tuner_cfg=None,
                 act_layer=nn.GELU,
                 **kwargs):
        super(ResAdapter, self).__init__()
        self.dim = dim
        self.layer_num = layer_num
        self.depth = depth
        self.adapter_length = tuner_cfg['adapter_length'] if 'adapter_length' in tuner_cfg else 32
        self.adapter_type = tuner_cfg['adapter_type'] if 'adapter_type' in tuner_cfg else None
        self.adapter_weight = tuner_cfg['adapter_weight'] if 'adapter_weight' in tuner_cfg else None

        # adapter_length may be a per-layer list, a single bottleneck width, or an
        # (in, hidden, out) triple for asymmetric adapters.
        self.adapter_length = self.adapter_length[self.layer_num] \
            if isinstance(self.adapter_length, list) else self.adapter_length
        assert isinstance(self.adapter_length, int) \
            or (isinstance(self.adapter_length, tuple) and len(self.adapter_length) == 3)
        if isinstance(self.adapter_length, int):
            self.ln1 = nn.Linear(dim, self.adapter_length)
        else:
            self.ln1 = nn.Linear(self.adapter_length[0], self.adapter_length[1])
        self.activate = act_layer()
        if isinstance(self.adapter_length, int):
            self.ln2 = nn.Linear(self.adapter_length, dim)
        else:
            self.ln2 = nn.Linear(self.adapter_length[1], self.adapter_length[2])
            dim = self.adapter_length[2]

        self._xavier_init_weights(self.ln1)
        if zero_init_last and layer_num == depth - 1:
            # Zero-init the last projection so the final adapter starts as a no-op.
            self._zero_init_weights(self.ln2)
        else:
            self._xavier_init_weights(self.ln2)

        self.scaling = init_weight_type(dim, self.adapter_weight)
        self._prepared = False

    def _zero_init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.zeros_(m.weight)
            nn.init.zeros_(m.bias)

    def _kaiming_init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
            nn.init.normal_(m.bias)

    def _xavier_init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            nn.init.normal_(m.bias, std=1e-6)

    def forward(self, x):
        if not self._prepared:
            self.ln1.to(x.device)
            self.activate.to(x.device)
            self.ln2.to(x.device)
            self._prepared = True
        x_dtype = x.dtype
        x = x.to(self.ln1.weight.dtype)

        x_shortcut = x
        if len(x_shortcut.size()) == 4:
            # Flatten (B, C, H, W) to (B, H*W, C) so the linears act per position.
            B, C, H, W = x.size()
            x = x.view(B, C, -1).permute(0, 2, 1)

        x_adapter = self.ln2(self.activate(self.ln1(x)))
        if self.adapter_weight:
            x_adapter = apply_data_weight(x_adapter, self.scaling, self.adapter_weight)

        if len(x_shortcut.size()) == 4:
            # Restore the (B, C, H, W) layout before adding the shortcut.
            x_adapter = x_adapter.permute(0, 2, 1).view(B, x_adapter.size()[-1], H, W)
        x_out = x_shortcut + x_adapter
        return x_out.to(x_dtype)


class ResGroupAdapter(nn.Module):

    def __init__(self,
                 dim,
                 layer_num=-1,
                 depth=-1,
                 zero_init_last=False,
                 stage='',
                 tuner_cfg=None,
                 act_layer=nn.GELU,
                 **kwargs):
        super(ResGroupAdapter, self).__init__()
        self.dim = dim
        self.layer_num = layer_num
        self.depth = depth

        self.adapter_type = tuner_cfg['adapter_type'] if 'adapter_type' in tuner_cfg else None
        self.adapter_weight = tuner_cfg['adapter_weight'] if 'adapter_weight' in tuner_cfg else None
        self.adapter_dim = tuner_cfg['dim'] if 'dim' in tuner_cfg else dim
        self.adapter_head = tuner_cfg['head'] if 'head' in tuner_cfg else 4
        self.adapter_scale_factor = tuner_cfg['scale_factor'] if 'scale_factor' in tuner_cfg else 2

        assert self.adapter_dim % self.adapter_head == 0, 'adapter dim should be divisible by adapter head'
        self.dim_mlp = self.adapter_dim // self.adapter_head

        self.ln1 = nn.Linear(self.dim_mlp, self.dim_mlp * self.adapter_scale_factor)
        self.ln2 = nn.Linear(self.dim_mlp * self.adapter_scale_factor, self.dim_mlp)
        self.activate = act_layer()

        self._kaiming_init_weights(self.ln1)
        if zero_init_last and layer_num == depth - 1:
            self._zero_init_weights(self.ln2)
        else:
            self._kaiming_init_weights(self.ln2)

        self.scaling = init_weight_type(dim, self.adapter_weight)
        self._prepared = False

    def _zero_init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.zeros_(m.weight)
            nn.init.zeros_(m.bias)

    def _kaiming_init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
            nn.init.normal_(m.bias)

    def _xavier_init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            nn.init.normal_(m.bias, std=1e-6)

    def forward(self, x):
        if not self._prepared:
            self.ln1.to(x.device)
            self.activate.to(x.device)
            self.ln2.to(x.device)
            self._prepared = True
        x_dtype = x.dtype
        x = x.to(self.ln1.weight.dtype)

        x_shortcut = x
        batch, inner_dim, height, width = x.shape
        x_adapter = x.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)

        # Split channels into heads, run the bottleneck MLP per head, then merge back.
        x_adapter = rearrange(x_adapter, 'b n (c h) -> (b h) n c', h=self.adapter_head)
        x_adapter = self.ln2(self.activate(self.ln1(x_adapter)))
        x_adapter = rearrange(x_adapter, '(b h) n c -> b n (c h)', h=self.adapter_head)

        if self.adapter_weight:
            x_adapter = apply_data_weight(x_adapter, self.scaling, self.adapter_weight)
        x_adapter = x_adapter.reshape(batch, height, width, -1).permute(0, 3, 1, 2).contiguous()

        x_out = x_shortcut + x_adapter
        return x_out.to(x_dtype)


class Identity(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, inputs, *args, **kwargs):
        return inputs
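# Usage sketch (illustrative, not part of the original module): exercising
# ResGroupAdapter on a 4D feature map via the 'res_group_adapter' branch.
# The channel/head counts below are assumptions; 'dim' must be divisible by 'head'.
def _example_res_group_adapter():
    tuner = ResTuner(
        dim=64,
        layer_num=0,
        depth=12,
        tuner_cfg={'res_group_adapter': {'dim': 64, 'head': 4, 'scale_factor': 2}})
    x = torch.randn(2, 64, 8, 8)  # assumed (batch, channels, height, width) layout
    out = tuner(x)
    assert out.shape == x.shape
    return out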
class Upsample(nn.Module):
    """An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param out_channels: output channels of the optional convolution;
        defaults to ``channels``. (This implementation is 2D-only.)
    :param padding: padding of the optional 3x3 convolution.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, **kwargs):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        if use_conv:
            self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=padding)
        self.init_weights()

    def init_weights(self):

        def _init_weights(m):
            if isinstance(m, nn.Conv2d):
                nn.init.zeros_(m.weight)
                nn.init.zeros_(m.bias)

        self.apply(_init_weights)

    def forward(self, x, target_size=None, *args, **kwargs):
        assert x.shape[1] == self.channels
        # Interpolate in float32, then cast back to the input dtype.
        if target_size is None:
            x = F.interpolate(x.float(), scale_factor=2, mode='nearest').type_as(x)
        else:
            x = F.interpolate(x.float(), target_size, mode='nearest').type_as(x)
        if self.use_conv:
            x = self.conv(x)
        return x


def init_weight_type(dim, weight_type):
    if weight_type is None:
        scaling = None
    elif weight_type == 'gate':
        scaling = nn.Linear(dim, 1)
    elif weight_type == 'scale':
        scaling = nn.Parameter(torch.Tensor(1))
        scaling.data.fill_(1)
    elif weight_type == 'scale_kv':
        scaling_k = nn.Parameter(torch.Tensor(1))
        scaling_k.data.fill_(1)
        scaling_v = nn.Parameter(torch.Tensor(1))
        scaling_v.data.fill_(1)
        scaling = (scaling_k, scaling_v)
    elif weight_type == 'scale_channel':
        scaling = nn.Parameter(torch.Tensor(dim))
        scaling.data.fill_(1)
    elif weight_type == 'scale_kv_channel':
        scaling_k = nn.Parameter(torch.Tensor(dim))
        scaling_k.data.fill_(1)
        scaling_v = nn.Parameter(torch.Tensor(dim))
        scaling_v.data.fill_(1)
        scaling = (scaling_k, scaling_v)
    elif weight_type and weight_type.startswith('scalar'):
        # e.g. 'scalar_0.5' yields a fixed scaling factor of 0.5.
        scaling = float(weight_type.split('_')[-1])
    else:
        scaling = None
    return scaling


def apply_data_weight(data, scaling, weight_type):
    if weight_type in ['gate']:
        # Data-dependent gate: one sigmoid weight per sample, averaged over tokens.
        scaling = torch.mean(torch.sigmoid(scaling(data)), dim=1).view(-1, 1, 1)
    elif weight_type in ['scale', 'scale_channel'] or weight_type.startswith('scalar'):
        scaling = scaling
    else:
        scaling = None
    if scaling is not None:
        data = data * scaling
    return data


def detach_tensors(feats):
    if type(feats) in [list, tuple]:
        feats = [detach_tensors(feat) if feat is not None else None for feat in feats]
    elif isinstance(feats, dict):
        feats = {key: detach_tensors(val) for key, val in feats.items()}
    elif isinstance(feats, torch.Tensor):
        feats = feats.detach()
    # Non-tensor leaves (ints, strings, None, ...) pass through unchanged;
    # calling .detach() on them would raise AttributeError.
    return feats


def probe_tensors(module, feats, name):
    feats = detach_tensors(feats)
    setattr(module, name, feats)


def probe_input_pre_hook(self, args):
    input = args[0]
    probe_tensors(self, input, 'probe_input_data')
    return args


def probe_output_hook(self, args, result):
    output = result
    probe_tensors(self, output, 'probe_output_data')
    return output
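# Usage sketch (illustrative, not part of the original module): registering the
# probe hooks on an arbitrary layer so it records detached copies of its forward
# inputs/outputs under the 'probe_input_data' / 'probe_output_data' attributes.
def _example_probe_hooks():
    layer = nn.Linear(8, 8)
    layer.register_forward_pre_hook(probe_input_pre_hook)
    layer.register_forward_hook(probe_output_hook)
    layer(torch.randn(2, 8))
    return layer.probe_input_data, layer.probe_output_data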