|
|
|
import copy |
|
import re |
|
import types |
|
from dataclasses import dataclass, field |
|
from typing import Dict, List, Optional, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from swift.utils import get_logger |
|
from swift.utils.torch_utils import find_sub_module |
|
from .restuning_components import ResTuner, detach_tensors, probe_input_pre_hook, probe_output_hook |
|
from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput |
|
|
|
logger = get_logger() |
|
|
|
|
|
@dataclass |
|
class ResTuningConfig(SwiftConfig): |
|
""" |
|
The configuration class for the ResTuning module. |
|
|
|
    ResTuning is a flexible, parameter-efficient and memory-efficient tuning framework.
    See 'Res-Tuning: A Flexible and Efficient Tuning Paradigm via Unbinding Tuner from Backbone'
    by Jiang et al. (2023).
|
|
|
Args: |
|
        dims(`Union[List[int], int]`): The dimensions of the hidden states.
        root_modules(`str`): The root module to be replaced, can be a regex string.
        root_modules_hook(`str`): The hook type of root modules, can be "input" or "output".
        stem_modules(`Union[List[str], str]`): The stem modules to be replaced,
            can be a regex string or a list of names in full match format.
        stem_modules_hook(`str`): The hook type of stem modules, can be "input" or "output".
        target_modules(`str`): The target module to be replaced, can be a regex string.
        target_modules_hook(`str`): The hook type of target modules, can be "input" or "output".
        target_hidden_pos(`Union[int, str]`): The position of the hidden state in the target module
            arguments or output, given as an index or a keyword.
        tuner_cfg(`Union[List[Dict], Dict, str]`): The configuration of the tuning module,
            can be a string or a customized config.
        use_upsample(bool): Whether to use the auxiliary upsample module.
        upsample_out_channels(List[int]): The output channels of the upsample modules when `use_upsample` is True.
        zero_init_last(bool): Whether to zero-initialize the last Linear layer in every sub-tuner.
        use_bypass(bool): Whether to use the bypass structure.
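
    Example:
        A minimal illustrative sketch; the `dims` value, module names and the
        `tuner_cfg` string below are hypothetical and depend on the backbone
        being tuned::

            config = ResTuningConfig(
                dims=768,
                stem_modules=['blocks.0', 'blocks.1', 'blocks.2'],
                target_modules='norm',
                tuner_cfg='res_adapter',
            )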
|
|
|
""" |
|
|
|
dims: Optional[Union[List[int], int]] = field( |
|
default=None, metadata={'help': 'The dimensions of the hidden states'}) |
|
|
|
    root_modules: Optional[str] = field(
        default=None,
        metadata={
            'help':
            'The root module to be replaced, can be a regex string (use the first matching module) or full match format'
        })
|
|
|
root_modules_hook: str = field( |
|
default='input', metadata={'help': 'The hook type of root modules, can be "input" or "output"'}) |
|
|
|
    stem_modules: Optional[Union[List[str], str]] = field(
        default=None,
        metadata={'help': 'The stem modules to be replaced, can be a regex string or a name list in full match format'})
|
|
|
stem_modules_hook: str = field( |
|
default='output', metadata={'help': 'The hook type of stem modules, can be "input" or "output"'}) |
|
|
|
    target_modules: Optional[str] = field(
        default=None,
        metadata={
            'help':
            'The target module to be replaced, can be a regex string (use the first matching module) or full match format'
        })
|
|
|
target_modules_hook: str = field( |
|
default='input', metadata={'help': 'The hook type of target modules, can be "input" or "output"'}) |
|
|
|
    target_hidden_pos: Optional[Union[int, str]] = field(
        default=None, metadata={'help': 'The position of the hidden state in the target module arguments or output'})
|
|
|
    tuner_cfg: Optional[Union[List[Dict], Dict, str]] = field(
        default=None, metadata={'help': 'The configuration of the tuning module, can be a string or a customized config'})
|
|
|
use_upsample: bool = field(default=False, metadata={'help': 'Whether to use auxiliary upsample module'}) |
|
|
|
    upsample_out_channels: Optional[List[int]] = field(
        default=None, metadata={'help': 'The number of output channels when "use_upsample" is set to "True"'})
|
|
|
zero_init_last: bool = field(default=False, metadata={'help': 'Zero init last weight'}) |
|
|
|
use_bypass: bool = field(default=True, metadata={'help': 'Whether to use bypass'}) |
|
|
|
def __post_init__(self): |
|
from .mapping import SwiftTuners |
|
self.swift_type = SwiftTuners.RESTUNING |
|
self.target_hidden_pos = 0 if self.target_hidden_pos is None else self.target_hidden_pos |
|
|
|
|
|
class ResTuning(SwiftAdapter): |
|
|
|
@staticmethod |
|
def prepare_model(model: nn.Module, config: ResTuningConfig, adapter_name: str) -> SwiftOutput: |
|
"""Prepare a model with `ResTuningConfig`""" |
|
|
|
def _forward_seq(self, input, *args, **kwargs): |
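            # Run only the original child modules of the Sequential target;
            # children appended after patching are skipped.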
|
for idx, module in enumerate(self): |
|
if idx >= len(self.origin_module_keys): |
|
continue |
|
input = module(input) |
|
return input |
|
|
|
def _forward_target(self, *args, **kwargs): |
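            # Merge the bypass output into the hidden state at `target_hidden_pos`,
            # either before ('input' hook) or after ('output' hook) the original forward.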
|
if self.target_modules_hook == 'input': |
|
if isinstance(self.target_hidden_pos, int): |
|
args = list(args) |
|
_arg = args[self.target_hidden_pos] |
|
else: |
|
_arg = kwargs[self.target_hidden_pos] |
|
args_main = _forward_restuning(self, _arg) |
|
if isinstance(self.target_hidden_pos, int): |
|
args[self.target_hidden_pos] = args_main |
|
else: |
|
kwargs[self.target_hidden_pos] = args_main |
|
args_main = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs) |
|
else: |
|
_args_main = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs) |
|
_arg = _args_main[self.target_hidden_pos] if isinstance(_args_main, (tuple, list, dict)) else _args_main |
|
args_main = _forward_restuning(self, _arg) |
|
if type(_args_main) != type(args_main): |
|
_args_main[self.target_hidden_pos] = args_main |
|
args_main = _args_main |
|
return args_main |
|
|
|
def _forward_restuning(self, origin_arg): |
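            # Gather the features captured by the probe hooks (root first, then stems)
            # and run them through the bypass module attached to the top-level model.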
|
probe_results = [] |
|
root_module_ins = self.root_module_ins_list[0] |
|
stem_module_ins_list = self.stem_module_ins_list |
|
top_module = model.get_submodule('') |
|
if root_module_ins: |
|
if root_module_ins.root_modules_hook == 'input': |
|
probe_results.append(root_module_ins.probe_input_data) |
|
else: |
|
probe_results.append(root_module_ins.probe_output_data) |
|
for i, st_mod in enumerate(stem_module_ins_list): |
|
if i == 0 and root_module_ins is None: |
|
probe_results.append(st_mod.probe_input_data) |
|
if st_mod.stem_modules_hook == 'input': |
|
probe_results.append(st_mod.probe_input_data) |
|
else: |
|
probe_results.append(st_mod.probe_output_data) |
|
args_main = getattr(top_module, f'restuning_{adapter_name}')(probe_results, origin_arg) |
|
return args_main |
|
|
|
|
|
module_keys = [key for key, _ in model.named_modules()] |
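        # 1. Register a probe hook on the first module matching `root_modules`, if configured.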
|
root_module_ins_list = [] |
|
if config.root_modules: |
|
for module_key in module_keys: |
|
if re.fullmatch(config.root_modules, module_key): |
|
root_module = model.get_submodule(module_key) |
|
logger.info(f'Matching root module [{module_key}] of type {type(root_module)}') |
|
if isinstance(root_module, (nn.ModuleList, nn.ModuleDict)): |
|
logger.warning( |
|
f'Type of {type(root_module)} may not be supported because of its customized forward') |
|
if config.root_modules_hook == 'input': |
|
root_module.register_forward_pre_hook(probe_input_pre_hook) |
|
else: |
|
root_module.register_forward_hook(probe_output_hook) |
|
root_module.root_modules_hook = config.root_modules_hook |
|
root_module_ins_list.append(root_module) |
|
break |
|
if len(root_module_ins_list) == 0: |
|
logger.error('Cannot match root modules') |
|
|
|
|
|
stem_module_ins_list = [] |
|
stem_module_ins_index = [] |
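        # 2. Register probe hooks on every module matching `stem_modules`; the probed
        #    hidden states are the per-stage inputs to the bypass network.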
|
for module_key in module_keys: |
|
if (isinstance(config.stem_modules, str) and re.fullmatch(config.stem_modules, module_key)) or \ |
|
(isinstance(config.stem_modules, list) and module_key in config.stem_modules): |
|
stem_module = model.get_submodule(module_key) |
|
if isinstance(config.stem_modules, list): |
|
stem_module_ins_index.append(config.stem_modules.index(module_key)) |
|
logger.info(f'Matching stem module [{module_key}] of type {type(stem_module)}') |
|
if isinstance(stem_module, (nn.ModuleList, nn.ModuleDict)): |
|
logger.warning( |
|
f'Type of {type(stem_module)} may not be supported because of its customized forward') |
|
if len(root_module_ins_list) == 0 and len(stem_module_ins_list) == 0: |
|
stem_module.register_forward_pre_hook(probe_input_pre_hook) |
|
if config.stem_modules_hook == 'input': |
|
stem_module.register_forward_pre_hook(probe_input_pre_hook) |
|
else: |
|
stem_module.register_forward_hook(probe_output_hook) |
|
stem_module.stem_modules_hook = config.stem_modules_hook |
|
stem_module_ins_list.append(stem_module) |
|
if isinstance(config.stem_modules, list): |
|
stem_module_ins_list = [ |
|
stem_module_ins_list[stem_module_ins_index.index(i)] for i in range(len(stem_module_ins_index)) |
|
] |
|
depth = len(stem_module_ins_list) |
|
if len(stem_module_ins_list) == 0: |
|
            raise Exception('Cannot match stem modules')
|
|
|
|
|
if len(stem_module_ins_list) != 0: |
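            # 3. Build the bypass network and attach it to the top-level module as
            #    `restuning_{adapter_name}`.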
|
top_module = model.get_submodule('') |
|
restuning_module = ResTuningBypassModule(config.dims, depth, adapter_name, config.use_upsample, |
|
config.upsample_out_channels, config.zero_init_last, |
|
config.tuner_cfg) |
|
setattr(top_module, f'restuning_{adapter_name}', restuning_module) |
|
|
|
|
|
target_module_ins = None |
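        # 4. Match the target module(s) and patch their forward so the bypass output
        #    is merged back into the main branch.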
|
for module_key in module_keys: |
|
if re.fullmatch(config.target_modules, module_key): |
|
tgt_module = model.get_submodule(module_key) |
|
logger.info(f'Matching target module [{module_key}] of type {type(tgt_module)}') |
|
if isinstance(tgt_module, (nn.ModuleList, nn.ModuleDict)): |
|
raise Exception( |
|
f'Type of {type(tgt_module)} may not be supported because of its customized forward') |
|
|
|
tgt_module.target_modules_hook = config.target_modules_hook |
|
tgt_module.target_hidden_pos = config.target_hidden_pos |
|
tgt_module.root_module_ins_list = root_module_ins_list |
|
tgt_module.stem_module_ins_list = stem_module_ins_list |
|
target_module_ins = tgt_module |
|
|
|
if isinstance(tgt_module, nn.Sequential) and not hasattr(tgt_module, 'origin_module_keys'): |
|
tgt_module.origin_module_keys = copy.deepcopy(list(tgt_module._modules.keys())) |
|
|
|
setattr(tgt_module, f'forward_origin_{adapter_name}', types.MethodType(_forward_seq, tgt_module)) |
|
else: |
|
setattr(tgt_module, f'forward_origin_{adapter_name}', tgt_module.forward) |
|
tgt_module.forward = types.MethodType(_forward_target, tgt_module) |
|
if target_module_ins is None: |
|
raise Exception('Cannot match target modules') |
|
|
|
def state_dict_callback(state_dict, adapter_name, **kwargs): |
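            # Only the bypass parameters (keys containing `restuning_{adapter_name}`)
            # belong to this adapter's state dict.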
|
return {key: value for key, value in state_dict.items() if f'restuning_{adapter_name}' in key} |
|
|
|
def mark_trainable_callback(model): |
|
return |
|
|
|
return SwiftOutput( |
|
config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback) |
|
|
|
@staticmethod |
|
def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None): |
|
modules = find_sub_module(module, f'restuning_{adapter_name}') |
|
for _module in modules: |
|
            _module: Union[ActivationMixin, nn.Module]
|
_module.set_activation(adapter_name, activate) |
|
SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload) |
|
|
|
|
|
class ResTuningBypassModule(nn.Module, ActivationMixin): |
|
"""The implementation of ResTuningBypass method. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
dims, |
|
depth, |
|
adapter_name, |
|
use_upsample=False, |
|
upsample_out_channels=None, |
|
zero_init_last=False, |
|
tuner_cfg=None, |
|
): |
|
super(ResTuningBypassModule, self).__init__() |
|
super(nn.Module, self).__init__('') |
|
self.adapter_name = adapter_name |
|
|
|
self.bypass_blocks = nn.Sequential(*[ |
|
ResTunerBypassBlock( |
|
dim=dims[i] if isinstance(dims, list) else dims, |
|
layer_num=i, |
|
depth=depth, |
|
use_upsample=use_upsample, |
|
upsample_out_channels=upsample_out_channels[i] if isinstance(upsample_out_channels, list |
|
) else upsample_out_channels, |
|
zero_init_last=zero_init_last, |
|
tuner_cfg=tuner_cfg[i] if isinstance(tuner_cfg, list) else tuner_cfg) for i in range(depth) |
|
]) |
|
self.mark_all_sub_modules_as_plugin() |
|
|
|
def forward(self, x_list, origin_arg, **kwargs): |
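        # `x_list` holds the probed features: the first entry seeds the bypass branch,
        # the remaining entries are the per-stage stem features.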
|
if not self.is_activated(self.adapter_name): |
|
return origin_arg |
|
x_bypass = detach_tensors(x_list.pop(0)) |
|
x_bypass = x_bypass[0] if isinstance(x_bypass, (list, tuple)) else x_bypass |
|
x_list = detach_tensors(x_list) |
|
x_list = [_x[0] if isinstance(_x, (list, tuple)) else _x for _x in x_list] |
|
for i, (bp_blk, x_stem) in enumerate(zip(self.bypass_blocks, x_list)): |
|
target_size = x_list[i + 1].shape[2:] if i < len(x_list) - 1 else None |
|
x_bypass = bp_blk(x_stem, x_bypass, target_size, **kwargs) |
|
return x_bypass |
|
|
|
|
|
class ResTunerBypassBlock(nn.Module): |
|
|
|
def __init__(self, dim, layer_num=-1, depth=-1, use_upsample=False, zero_init_last=False, tuner_cfg=None, **kwargs): |
|
super().__init__() |
|
self.layer_num = layer_num |
|
self.depth = depth |
|
|
|
if isinstance(tuner_cfg, str): |
|
lateral_cfg = tuner_cfg |
|
vertical_cfg = tuner_cfg |
|
aux_cfg = 'upsample' if use_upsample and layer_num != depth - 1 else None |
|
        elif isinstance(tuner_cfg, dict):
            lateral_cfg = tuner_cfg.get('lateral_cfg', None)
            vertical_cfg = tuner_cfg.get('vertical_cfg', None)
            aux_cfg = tuner_cfg.get('aux_cfg', None)
        else:
            # An unsupported tuner_cfg type would otherwise surface later as a NameError.
            raise TypeError(f'Unsupported tuner_cfg type: {type(tuner_cfg)}')
|
|
|
self.lateral_tuner = ResTuner(dim, layer_num, depth, zero_init_last, 'lateral', lateral_cfg, **kwargs) |
|
self.vertical_tuner = ResTuner(dim, layer_num, depth, zero_init_last, 'vertical', vertical_cfg, **kwargs) |
|
if aux_cfg and len(aux_cfg) != 0: |
|
self.aux_tuner = ResTuner(dim, layer_num, depth, zero_init_last, 'aux', aux_cfg, **kwargs) |
|
|
|
def forward(self, x_stem, x_bypass, target_size=None, **kwargs): |
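        # The lateral tuner transforms the stem feature, the vertical tuner transforms
        # the running bypass state, and the optional aux tuner (e.g. upsample) adapts
        # the sum to the next stage's spatial size.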
|
x_lateral = self.lateral_tuner(x_stem) |
|
x_vertical = self.vertical_tuner(x_bypass) |
|
|
|
x_bypass_out = x_lateral + x_vertical |
|
if hasattr(self, 'aux_tuner'): |
|
x_bypass_out = self.aux_tuner(x_bypass_out, target_size) |
|
return x_bypass_out |
|
|