import copy
import re
import types
from collections import OrderedDict
from dataclasses import dataclass, field
from functools import partial
from itertools import repeat
from typing import Optional, Union

import torch
from torch import nn

from swift.utils.logger import get_logger
from swift.utils.torch_utils import find_sub_module
from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput

logger = get_logger()


@dataclass
class SideConfig(SwiftConfig):
    """
    The configuration class for the side module.

    Side-Tuning trains only an additive side network and learns a weighted
    combination of the outputs of the pre-trained model and the side network.
    'Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks'
    by Zhang et al. (2019)
    See https://arxiv.org/abs/1912.13503

    Args:
        dim: The dimension of the hidden states.
        target_modules: The target module to be replaced, matched as a regex full match.
        side_module_name: The name of the additive side network, one of 'fcn4', 'mlp' or 'alexnet'.
        source_hidden_pos: The position of the hidden state passed into the target module,
            an int (positional args) or a str (kwargs).
        target_hidden_pos: The position of the hidden state in the target module's output,
            an int (tuple/list output) or a str (dict output).
    """

    dim: Optional[int] = field(default=None, metadata={'help': 'The dimension of the hidden states'})

    target_modules: Optional[str] = field(
        default=None, metadata={'help': 'The target module to be replaced, in regex full-match format'})

    side_module_name: str = field(default='fcn4', metadata={'help': 'The name of the additive side network'})

    source_hidden_pos: Union[str, int] = field(
        default=0,
        metadata={
            'help': 'The position of the hidden state input to the target module, can be int (args) or str (kwargs)'
        })

    target_hidden_pos: Union[str, int] = field(
        default=0,
        metadata={
            'help': 'The position of the hidden state output from the target module, can be int (args) or str (kwargs)'
        })

    def __post_init__(self):
        from .mapping import SwiftTuners
        self.swift_type = SwiftTuners.SIDE
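

# A minimal usage sketch (hypothetical values: `model`, the 'vision_model.encoder'
# pattern and dim=768 are illustrative and not defined in this file):
#
#     config = SideConfig(dim=768, target_modules='vision_model.encoder', side_module_name='fcn4')
#     output = Side.prepare_model(model, config, adapter_name='default')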


class Side(SwiftAdapter):

    @staticmethod
    def prepare_model(model: nn.Module, config: SideConfig, adapter_name: str) -> SwiftOutput:
        """Prepare a model with `SideConfig`"""
        module_keys = [key for key, _ in model.named_modules()]

        for module_key in module_keys:
            if re.fullmatch(config.target_modules, module_key):
                tgt_module = model.get_submodule(module_key)
                logger.info(f'Matching target module [{module_key}] of type {type(tgt_module)}')
                if isinstance(tgt_module, (nn.ModuleList, nn.ModuleDict)):
                    raise Exception(
                        f'Modules of type {type(tgt_module)} are not supported because of their customized forward')

                # Wrap the target module's forward: run the original forward, feed the same
                # hidden-state input through the side network, and blend the two outputs.
                def _forward(self, *args, **kwargs):
                    args_main = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs)

                    # Fetch the hidden state handed to the target module, positionally or by keyword.
                    if isinstance(config.source_hidden_pos, int):
                        x = args[config.source_hidden_pos]
                    else:
                        x = kwargs[config.source_hidden_pos]

                    x_main = args_main[config.target_hidden_pos] \
                        if isinstance(args_main, (tuple, list, dict)) else args_main
                    out = getattr(self, f'side_{adapter_name}')(x, x_main)
                    if isinstance(args_main, tuple):
                        # Tuples are immutable, so rebuild the output tuple around the blended result.
                        args_main = list(args_main)
                        args_main[config.target_hidden_pos] = out
                        args_main = tuple(args_main)
                    elif isinstance(args_main, (list, dict)):
                        args_main[config.target_hidden_pos] = out
                    else:
                        args_main = out
                    return args_main
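
                # For example, with source_hidden_pos='hidden_states' and target_hidden_pos=0,
                # the side network reads kwargs['hidden_states'] and its blended output replaces
                # the first element of the wrapped module's output tuple.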

                if isinstance(tgt_module, nn.Sequential) and not hasattr(tgt_module, 'tgt_module_keys'):
                    # Snapshot the original sub-module keys: the side module registered below
                    # becomes a child of the Sequential and must be skipped when iterating.
                    tgt_module.tgt_module_keys = copy.deepcopy(list(tgt_module._modules.keys()))

                    def forward_seq(self, input, *args, **kwargs):
                        for idx, module in enumerate(self):
                            if idx >= len(tgt_module.tgt_module_keys):
                                continue
                            input = module(input)
                        return input

                    setattr(tgt_module, f'forward_origin_{adapter_name}', types.MethodType(forward_seq, tgt_module))
                else:
                    setattr(tgt_module, f'forward_origin_{adapter_name}', tgt_module.forward)
                tgt_module.forward = types.MethodType(_forward, tgt_module)
                side_module = SideModule(config.dim, adapter_name, module_key, config.side_module_name)
                setattr(tgt_module, f'side_{adapter_name}', side_module)
                logger.info(f'Side module attached: {module_key}.side_{adapter_name}')

        def state_dict_callback(state_dict, adapter_name, **kwargs):
            # Keep only this adapter's side-network weights in the saved state dict.
            return {key: value for key, value in state_dict.items() if f'side_{adapter_name}' in key}

        def mark_trainable_callback(model):
            # No-op: no extra parameters need to be marked as trainable.
            return

        return SwiftOutput(
            config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
        modules = find_sub_module(module, f'side_{adapter_name}')
        for _module in modules:
            _module: ActivationMixin
            _module.set_activation(adapter_name, activate)
            SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload)
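

# A usage sketch (hypothetical: 'default' is an illustrative adapter name and `model`
# is a module previously prepared with `Side.prepare_model`):
#
#     Side.activate_adapter(model, 'default', activate=False)  # bypass the side network
#     Side.activate_adapter(model, 'default', activate=True)   # re-enable it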


class SideModule(nn.Module, ActivationMixin):
    """The implementation of the vision side-tuning method.

    Side-Tuning trains only an additive side network and learns a weighted
    combination of the outputs of the pre-trained model and the side network.
    'Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks'
    by Zhang et al. (2019)
    See https://arxiv.org/abs/1912.13503

    Args:
        dim: The dimension of the hidden states.
        adapter_name: The name of the adapter this module belongs to.
        module_key: The key of the target module this side module is attached to.
        side_module_name: The name of the additive side network, one of 'fcn4', 'mlp' or 'alexnet'.
    """

    def __init__(self, dim, adapter_name, module_key, side_module_name='fcn4'):
        super(SideModule, self).__init__()
        # Initialize ActivationMixin (the class after nn.Module in the MRO) with the module key.
        super(nn.Module, self).__init__(module_key)
        self.adapter_name = adapter_name

        side_module_name = side_module_name.lower()
        if side_module_name == 'fcn4':
            self.side_net = FCN4(out_dims=dim)
        elif side_module_name == 'mlp':
            self.side_net = Mlp(dim)
        elif side_module_name == 'alexnet':
            import torchvision

            # Note: `pretrained=True` is deprecated in newer torchvision releases
            # in favor of the `weights=` argument.
            mm = torchvision.models.alexnet(pretrained=True)
            self.side_net = nn.Sequential(
                OrderedDict([('features', mm.features), ('avgpool', mm.avgpool), ('flatten', nn.Flatten()),
                             ('fc', nn.Linear(9216, dim, bias=False))]))
        else:
            raise ValueError(f'Unsupported side_module_name: {side_module_name}')
        # Learnable blending weight between the main branch and the side branch.
        self.alpha = nn.Parameter(torch.tensor(0.0))
        self.mark_all_sub_modules_as_plugin()

    def forward(self, x, x_main):
        if not self.is_activated(self.adapter_name):
            return x_main
        alpha_squashed = torch.sigmoid(self.alpha)
        x_side = self.side_net(x)
        # Blend the main-branch and side-branch outputs with a learned sigmoid gate.
        x_out = alpha_squashed * x_main + (1 - alpha_squashed) * x_side
        return x_out
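
    # At initialization alpha is 0, so sigmoid(alpha) = 0.5 and both branches are weighted
    # equally; training shifts the blend toward whichever branch is more useful.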


class FCN4(nn.Module):
    """A simple four-layer convolutional network used as the side network."""

    def __init__(self, out_dims=-1, **kwargs):
        super(FCN4, self).__init__(**kwargs)

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False, dilation=1), nn.GroupNorm(2, 16),
            nn.ReLU())
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=0, bias=False, dilation=1), nn.GroupNorm(2, 16),
            nn.ReLU())
        self.conv3 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=0, bias=False, dilation=1), nn.GroupNorm(2, 32),
            nn.ReLU())
        self.conv4 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=0, bias=False, dilation=1), nn.GroupNorm(2, 64),
            nn.ReLU())
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # Only project to `out_dims` when a positive dimension is given.
        if out_dims > 0:
            self.fc = nn.Linear(64, out_dims)
        else:
            self.fc = None

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        if self.fc is not None:
            x = self.fc(x)
        return x
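
    # Shape sketch for an illustrative (B, 3, 224, 224) input: conv2 and conv3 each roughly
    # halve the spatial size, adaptive pooling collapses the map to (B, 64, 1, 1), flattening
    # yields (B, 64), and the optional fc projects to (B, out_dims).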


class Mlp(nn.Module):
    """MLP as used in Vision Transformer."""

    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            bias=True,
            drop=0.,
            use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # Broadcast scalar `bias`/`drop` settings to both linear layers.
        bias = tuple(repeat(bias, 2))
        drop_probs = tuple(repeat(drop, 2))
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
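
    # With use_conv=True the 1x1 convolutions let this MLP operate on (B, C, H, W) feature
    # maps; with the default nn.Linear layers it operates on the last dimension, e.g. (B, N, C).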