# Enxin's picture
# Upload folder using huggingface_hub
# 96fe658 verified
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import re
import types
from collections import OrderedDict
from dataclasses import dataclass, field
from functools import partial
from itertools import repeat
from typing import Union
import torch
from torch import nn
from swift.utils.logger import get_logger
from swift.utils.torch_utils import find_sub_module
from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput
logger = get_logger()
@dataclass
class SideConfig(SwiftConfig):
    """Settings consumed by the Side-Tuning tuner.

    Side-Tuning only needs to train one side network; the output of the
    pre-trained model and the output of the side network are weighted
    together.

    Reference: 'Side-Tuning: A Baseline for Network Adaptation via Additive
    Side Networks' by Zhang et al. (2019), https://arxiv.org/abs/1912.13503

    Args:
        dim: The dimension of the hidden states.
        target_modules: Regular expression selecting the modules to wrap; it
            is applied with ``re.fullmatch`` against each sub-module name.
        side_module_name: The name of the additive side network.
        source_hidden_pos: Where the hidden state is found in the target
            module's inputs: an ``int`` indexes the positional arguments, a
            ``str`` names a keyword argument.
        target_hidden_pos: Where the hidden state is found in the target
            module's output: an ``int`` index or a ``str`` key.
    """

    dim: int = field(default=None, metadata=dict(help='The dimension of the hidden states'))

    target_modules: str = field(
        default=None,
        metadata=dict(help='The target module to be replaced, in full match format'),
    )

    side_module_name: str = field(
        default='fcn4',
        metadata=dict(help='The name of the additive side networks'),
    )

    source_hidden_pos: Union[str, int] = field(
        default=0,
        metadata=dict(
            help='The position of the hidden state input to the target module, can be int (args) or str (kwargs)'),
    )

    target_hidden_pos: Union[str, int] = field(
        default=0,
        metadata=dict(
            help='The position of the hidden state output from the target module, can be int (args) or str (kwargs)'),
    )

    def __post_init__(self):
        """Tag this config with the SIDE tuner type."""
        # Imported at call time rather than at module level
        # (presumably to avoid a circular import -- TODO confirm).
        from .mapping import SwiftTuners

        self.swift_type = SwiftTuners.SIDE
class Side(SwiftAdapter):
    """Tuner that attaches an additive side network to selected sub-modules."""

    @staticmethod
    def prepare_model(model: nn.Module, config: SideConfig, adapter_name: str) -> SwiftOutput:
        """Prepare a model with `SideConfig`.

        For every sub-module whose name fully matches ``config.target_modules``:

        * the current forward is kept as ``forward_origin_{adapter_name}``;
        * a ``SideModule`` is attached as ``side_{adapter_name}``;
        * ``forward`` is replaced by a wrapper that runs the original forward,
          feeds the input hidden state and the original output hidden state to
          the side module, and puts the blended result back into the output.

        Args:
            model: The model whose sub-modules are patched in place.
            config: The side-tuning configuration.
            adapter_name: Name used to namespace the attributes added to each
                target module.

        Returns:
            A ``SwiftOutput`` holding the config, a state-dict filter that
            keeps only keys containing ``side_{adapter_name}``, and a no-op
            mark-trainable callback.

        Raises:
            Exception: If a matched module is an ``nn.ModuleList`` or
                ``nn.ModuleDict``.
        """

        # Both wrappers only depend on `config` / `adapter_name` (never on the
        # loop variables below), so they are defined once and bound per module.
        def _forward(self, *args, **kwargs):
            args_main = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs)
            if isinstance(config.source_hidden_pos, int):
                x = args[config.source_hidden_pos]
            else:
                x = kwargs[config.source_hidden_pos]

            if not isinstance(args_main, (tuple, list, dict)):
                # Bare output: the blended result replaces it entirely.
                return getattr(self, f'side_{adapter_name}')(x, args_main)

            x_main = args_main[config.target_hidden_pos]
            out = getattr(self, f'side_{adapter_name}')(x, x_main)
            if isinstance(args_main, tuple):
                # Tuples are immutable, so item assignment would raise
                # TypeError; rebuild the tuple with the blended entry instead.
                items = list(args_main)
                items[config.target_hidden_pos] = out
                if hasattr(args_main, '_fields'):
                    # namedtuple: keep the original type and field names.
                    return type(args_main)(*items)
                return tuple(items)
            args_main[config.target_hidden_pos] = out
            return args_main

        def forward_seq(self, input, *args, **kwargs):  # noqa: A002 - name may be passed by keyword
            # Read the key list from `self`, not from the enclosing loop
            # variable: that variable is rebound to the last matched module by
            # the time this runs.
            n_origin = len(self.tgt_module_keys)
            for idx, module in enumerate(self):
                if idx >= n_origin:
                    # Children added after patching (the side modules) are
                    # not part of the original sequential forward.
                    break
                input = module(input)
            return input

        # Snapshot the names first: the loop registers new sub-modules.
        module_keys = [key for key, _ in model.named_modules()]
        for module_key in module_keys:
            if not re.fullmatch(config.target_modules, module_key):  # noqa
                continue
            tgt_module = model.get_submodule(module_key)
            logger.info(f'Matching target module [{module_key}] of type {type(tgt_module)}')
            if isinstance(tgt_module, (nn.ModuleList, nn.ModuleDict)):
                raise Exception(
                    f'Type of {type(tgt_module)} may not be supported because of its customized forward')

            if isinstance(tgt_module, nn.Sequential) and not hasattr(tgt_module, 'tgt_module_keys'):
                # First time this Sequential is patched: remember which
                # children belong to the original pipeline.
                tgt_module.tgt_module_keys = copy.deepcopy(list(tgt_module._modules.keys()))
                setattr(tgt_module, f'forward_origin_{adapter_name}', types.MethodType(forward_seq, tgt_module))
            else:
                setattr(tgt_module, f'forward_origin_{adapter_name}', tgt_module.forward)
            tgt_module.forward = types.MethodType(_forward, tgt_module)
            side_module = SideModule(config.dim, adapter_name, module_key, config.side_module_name)
            setattr(tgt_module, f'side_{adapter_name}', side_module)
            logger.info(f'Side modules(module_key): {module_key}.side_{adapter_name}')

        def state_dict_callback(state_dict, adapter_name, **kwargs):
            # Keep only the entries that belong to this adapter's side modules.
            return {key: value for key, value in state_dict.items() if f'side_{adapter_name}' in key}

        def mark_trainable_callback(model):
            return

        return SwiftOutput(
            config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
        """Switch the side modules of one adapter on or off.

        Args:
            module: Root module searched for ``side_{adapter_name}`` sub-modules.
            adapter_name: Name of the adapter to toggle.
            activate: Activation flag forwarded to ``set_activation``.
            offload: Forwarded unchanged to ``SwiftAdapter.save_memory``.
        """
        modules = find_sub_module(module, f'side_{adapter_name}')
        for _module in modules:
            _module: ActivationMixin
            _module: nn.Module
            _module.set_activation(adapter_name, activate)
            SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload)
class SideModule(nn.Module, ActivationMixin):
    """Side network plus learnable blending weight for vision side-tuning.

    Side-Tuning only needs to train one side network; the output of the
    pre-trained model and the output of the side network are weighted
    together.

    Reference: 'Side-Tuning: A Baseline for Network Adaptation via Additive
    Side Networks' by Zhang et al. (2019), https://arxiv.org/abs/1912.13503

    Args:
        dim: Output size handed to the selected side network.
        adapter_name: Adapter name checked by ``forward`` for activation.
        module_key: Forwarded to the ``ActivationMixin`` initializer.
        side_module_name: The name of the additive side network; one of
            ``'fcn4'``, ``'mlp'`` or ``'alexnet'`` (case-insensitive).
    """

    def __init__(self, dim, adapter_name, module_key, side_module_name='fcn4'):
        super().__init__()
        # Start the lookup *after* nn.Module in the MRO so that the next base
        # (ActivationMixin for this class) is initialized with module_key.
        super(nn.Module, self).__init__(module_key)
        self.adapter_name = adapter_name

        side_module_name = side_module_name.lower()
        if side_module_name == 'fcn4':
            side_net = FCN4(out_dims=dim)
        elif side_module_name == 'mlp':
            side_net = Mlp(dim)
        elif side_module_name == 'alexnet':
            # torchvision is only needed for this variant, so import it lazily.
            import torchvision

            backbone = torchvision.models.alexnet(pretrained=True)
            layers = OrderedDict()
            layers['features'] = backbone.features
            layers['avgpool'] = backbone.avgpool
            layers['flatten'] = nn.Flatten()
            layers['fc'] = nn.Linear(9216, dim, bias=False)
            side_net = nn.Sequential(layers)
        else:
            raise ValueError(f'Unsupported side_module_name: {side_module_name}')
        self.side_net = side_net

        # Raw blending weight; it is passed through a sigmoid in forward().
        self.alpha = nn.Parameter(torch.tensor(0.0))
        self.mark_all_sub_modules_as_plugin()

    def forward(self, x, x_main):
        """Blend the main output with the side-network output.

        Args:
            x: Input passed to the side network.
            x_main: Output of the wrapped (main) module.

        Returns:
            ``x_main`` unchanged when the adapter is not activated, otherwise
            ``sigmoid(alpha) * x_main + (1 - sigmoid(alpha)) * side_net(x)``.
        """
        if self.is_activated(self.adapter_name):
            gate = torch.sigmoid(self.alpha)
            side_out = self.side_net(x)
            return gate * x_main + (1 - gate) * side_out
        return x_main
class FCN4(nn.Module):
    """Simple four-stage convolutional network used as a side network.

    Each stage is a bias-free 3x3 convolution followed by a two-group
    GroupNorm and a ReLU. The stages are followed by global average pooling
    and, when ``out_dims`` is positive, a linear projection.

    Args:
        out_dims: Size of the final linear layer. Values ``<= 0`` disable the
            projection, in which case the 64 pooled features are returned.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, out_dims=-1, **kwargs):
        super().__init__(**kwargs)

        def make_stage(c_in, c_out, stride, padding):
            # One conv -> norm -> activation stage.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=padding, bias=False, dilation=1),
                nn.GroupNorm(2, c_out),
                nn.ReLU(),
            )

        # Creation order is kept so that parameter initialization draws from
        # the RNG in the same sequence.
        self.conv1 = make_stage(3, 16, stride=1, padding=1)
        self.conv2 = make_stage(16, 16, stride=2, padding=0)
        self.conv3 = make_stage(16, 32, stride=2, padding=0)
        self.conv4 = make_stage(32, 64, stride=1, padding=0)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, out_dims) if out_dims > 0 else None

    def forward(self, x):
        """Run the conv stages, pool globally, and optionally project."""
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)
        pooled = self.pool(x)
        # Collapse everything but the batch axis: (N, 64, 1, 1) -> (N, 64).
        flat = pooled.view(pooled.size(0), -1)
        if self.fc is None:
            return flat
        return self.fc(flat)
class Mlp(nn.Module):
    """Two-layer MLP as used in Vision Transformer.

    The layer order is: fc1 -> activation -> dropout -> norm -> fc2 -> dropout.

    Args:
        in_features: Input feature size.
        hidden_features: Hidden size; falls back to ``in_features`` when falsy.
        out_features: Output size; falls back to ``in_features`` when falsy.
        act_layer: Zero-argument factory for the activation module.
        norm_layer: Optional factory called with ``hidden_features``;
            ``nn.Identity`` is used when it is ``None``.
        bias: Bias flag applied to both projections.
        drop: Dropout probability applied after both projections.
        use_conv: Use 1x1 ``nn.Conv2d`` projections instead of ``nn.Linear``.
    """

    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            bias=True,
            drop=0.,
            use_conv=False,
    ):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        if use_conv:
            make_proj = partial(nn.Conv2d, kernel_size=1)
        else:
            make_proj = nn.Linear

        # The same bias flag and dropout rate are shared by both halves.
        self.fc1 = make_proj(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        if norm_layer is None:
            self.norm = nn.Identity()
        else:
            self.norm = norm_layer(hidden_features)
        self.fc2 = make_proj(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        """Apply the layers in their fixed order and return the result."""
        for layer in (self.fc1, self.act, self.drop1, self.norm, self.fc2, self.drop2):
            x = layer(x)
        return x