# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import re
import types
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn

from swift.utils import get_logger
from swift.utils.torch_utils import find_sub_module
from .restuning_components import ResTuner, detach_tensors, probe_input_pre_hook, probe_output_hook
from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput

logger = get_logger()


@dataclass
class ResTuningConfig(SwiftConfig):
"""
The configuration class for the ResTuning module.
    ResTuning is a flexible, parameter-efficient and memory-efficient tuning paradigm framework.
    See 'Res-Tuning: A Flexible and Efficient Tuning Paradigm via Unbinding Tuner from Backbone'
    by Jiang et al. (2023).
Args:
        dims(`Union[List[int], int]`): The dimensions of the hidden states.
        root_modules(`str`): The root module to be replaced, can be a regex string.
        root_modules_hook(`str`): The hook type of root modules, can be "input" or "output".
        stem_modules(`Union[List[str], str]`): The stem modules to be replaced,
            can be a regex string or a name list in full-match format.
        stem_modules_hook(`str`): The hook type of stem modules, can be "input" or "output".
        target_modules(`str`): The target module to be replaced, can be a regex string.
        target_modules_hook(`str`): The hook type of target modules, can be "input" or "output".
        target_hidden_pos(`Union[int, str]`): The position of the hidden state in the target module's
            input or output (depending on `target_modules_hook`).
        tuner_cfg(`Union[List[Dict], Dict, str]`): The configuration of the tuning module,
            can be a string or a customized config.
        use_upsample(bool): Whether to use an auxiliary upsample module.
        upsample_out_channels(List[int]): The number of output channels when `use_upsample` is True.
        zero_init_last(bool): Whether to zero-initialize the last Linear in every sub tuner.
        use_bypass(bool): Whether to use the bypass structure.
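
    Example:
        A minimal sketch; the regex patterns and the `tuner_cfg` value below are illustrative
        assumptions and need to be adapted to the backbone actually being tuned:

            config = ResTuningConfig(
                dims=768,
                root_modules=r'.*blocks.0$',
                stem_modules=r'.*blocks.\d+$',
                target_modules=r'.*norm$',
                tuner_cfg='res_adapter',
            )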
"""
dims: Optional[Union[List[int], int]] = field(
default=None, metadata={'help': 'The dimensions of the hidden states'})
    root_modules: Optional[str] = field(
        default=None,
        metadata={
            'help':
            'The root module to be replaced, can be a regex string (the first match is used) or a full-match name'
        })
root_modules_hook: str = field(
default='input', metadata={'help': 'The hook type of root modules, can be "input" or "output"'})
    stem_modules: Optional[Union[List[str], str]] = field(
        default=None,
        metadata={'help': 'The stem modules to be replaced, can be a regex string or a name list in full-match format'})
stem_modules_hook: str = field(
default='output', metadata={'help': 'The hook type of stem modules, can be "input" or "output"'})
    target_modules: Optional[str] = field(
        default=None,
        metadata={
            'help':
            'The target module to be replaced, can be a regex string (the first match is used) or a full-match name'
        })
target_modules_hook: str = field(
default='input', metadata={'help': 'The hook type of target modules, can be "input" or "output"'})
    target_hidden_pos: Optional[Union[int, str]] = field(
        default=None, metadata={'help': 'The position of the hidden state in the target module input or output'})
    tuner_cfg: Optional[Union[List[Dict], Dict, str]] = field(
        default=None,
        metadata={'help': 'The configuration of the tuning module, can be a string or a customized config'})
use_upsample: bool = field(default=False, metadata={'help': 'Whether to use auxiliary upsample module'})
    upsample_out_channels: Optional[List[int]] = field(
        default=None, metadata={'help': 'The number of output channels when "use_upsample" is set to "True"'})
zero_init_last: bool = field(default=False, metadata={'help': 'Zero init last weight'})
use_bypass: bool = field(default=True, metadata={'help': 'Whether to use bypass'})
def __post_init__(self):
from .mapping import SwiftTuners
self.swift_type = SwiftTuners.RESTUNING
self.target_hidden_pos = 0 if self.target_hidden_pos is None else self.target_hidden_pos
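

# Typical usage (a sketch): a ResTuningConfig instance is passed to the generic Swift tuner
# entry point, e.g. `Swift.prepare_model(model, config)`, which dispatches to
# `ResTuning.prepare_model` below. The exact entry point may differ across ms-swift versions.
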
class ResTuning(SwiftAdapter):
@staticmethod
def prepare_model(model: nn.Module, config: ResTuningConfig, adapter_name: str) -> SwiftOutput:
"""Prepare a model with `ResTuningConfig`"""
def _forward_seq(self, input, *args, **kwargs):
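            # Run only the modules that belonged to the original nn.Sequential target;
            # anything appended afterwards (e.g. by other adapters) is skipped here.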
for idx, module in enumerate(self):
if idx >= len(self.origin_module_keys):
continue
input = module(input)
return input
def _forward_target(self, *args, **kwargs):
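            # Inject the Res-Tuning bypass output into the target module: either rewrite the
            # selected hidden state before calling the original forward (hook type 'input'),
            # or post-process the original forward's output (hook type 'output').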
if self.target_modules_hook == 'input':
if isinstance(self.target_hidden_pos, int):
args = list(args)
_arg = args[self.target_hidden_pos]
else:
_arg = kwargs[self.target_hidden_pos]
args_main = _forward_restuning(self, _arg)
if isinstance(self.target_hidden_pos, int):
args[self.target_hidden_pos] = args_main
else:
kwargs[self.target_hidden_pos] = args_main
args_main = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs)
else:
_args_main = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs)
_arg = _args_main[self.target_hidden_pos] if isinstance(_args_main, (tuple, list, dict)) else _args_main
args_main = _forward_restuning(self, _arg)
if type(_args_main) != type(args_main):
_args_main[self.target_hidden_pos] = args_main
args_main = _args_main
return args_main
def _forward_restuning(self, origin_arg):
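            # Collect the features probed by the root/stem hooks and feed them, together with
            # the target hidden state, to the ResTuningBypassModule attached to the top module.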
probe_results = []
root_module_ins = self.root_module_ins_list[0]
stem_module_ins_list = self.stem_module_ins_list
top_module = model.get_submodule('')
if root_module_ins:
if root_module_ins.root_modules_hook == 'input':
probe_results.append(root_module_ins.probe_input_data)
else:
probe_results.append(root_module_ins.probe_output_data)
for i, st_mod in enumerate(stem_module_ins_list):
if i == 0 and root_module_ins is None:
probe_results.append(st_mod.probe_input_data)
if st_mod.stem_modules_hook == 'input':
probe_results.append(st_mod.probe_input_data)
else:
probe_results.append(st_mod.probe_output_data)
args_main = getattr(top_module, f'restuning_{adapter_name}')(probe_results, origin_arg)
return args_main
# 1. Matching the root module
module_keys = [key for key, _ in model.named_modules()]
root_module_ins_list = []
if config.root_modules:
for module_key in module_keys:
if re.fullmatch(config.root_modules, module_key):
root_module = model.get_submodule(module_key)
logger.info(f'Matching root module [{module_key}] of type {type(root_module)}')
if isinstance(root_module, (nn.ModuleList, nn.ModuleDict)):
logger.warning(
f'Type of {type(root_module)} may not be supported because of its customized forward')
if config.root_modules_hook == 'input':
root_module.register_forward_pre_hook(probe_input_pre_hook)
else:
root_module.register_forward_hook(probe_output_hook)
root_module.root_modules_hook = config.root_modules_hook
root_module_ins_list.append(root_module)
break
if len(root_module_ins_list) == 0:
logger.error('Cannot match root modules')
# 2. Matching the stem module
stem_module_ins_list = []
stem_module_ins_index = []
for module_key in module_keys:
if (isinstance(config.stem_modules, str) and re.fullmatch(config.stem_modules, module_key)) or \
(isinstance(config.stem_modules, list) and module_key in config.stem_modules):
stem_module = model.get_submodule(module_key)
if isinstance(config.stem_modules, list):
stem_module_ins_index.append(config.stem_modules.index(module_key))
logger.info(f'Matching stem module [{module_key}] of type {type(stem_module)}')
if isinstance(stem_module, (nn.ModuleList, nn.ModuleDict)):
logger.warning(
f'Type of {type(stem_module)} may not be supported because of its customized forward')
if len(root_module_ins_list) == 0 and len(stem_module_ins_list) == 0:
stem_module.register_forward_pre_hook(probe_input_pre_hook)
if config.stem_modules_hook == 'input':
stem_module.register_forward_pre_hook(probe_input_pre_hook)
else:
stem_module.register_forward_hook(probe_output_hook)
stem_module.stem_modules_hook = config.stem_modules_hook
stem_module_ins_list.append(stem_module)
if isinstance(config.stem_modules, list):
stem_module_ins_list = [
stem_module_ins_list[stem_module_ins_index.index(i)] for i in range(len(stem_module_ins_index))
]
depth = len(stem_module_ins_list)
if len(stem_module_ins_list) == 0:
            raise Exception('Cannot match stem modules')
# 3. Init restuning module
if len(stem_module_ins_list) != 0:
top_module = model.get_submodule('')
restuning_module = ResTuningBypassModule(config.dims, depth, adapter_name, config.use_upsample,
config.upsample_out_channels, config.zero_init_last,
config.tuner_cfg)
setattr(top_module, f'restuning_{adapter_name}', restuning_module)
# 4. Matching the target module
target_module_ins = None
for module_key in module_keys:
if re.fullmatch(config.target_modules, module_key):
tgt_module = model.get_submodule(module_key)
logger.info(f'Matching target module [{module_key}] of type {type(tgt_module)}')
if isinstance(tgt_module, (nn.ModuleList, nn.ModuleDict)):
raise Exception(
f'Type of {type(tgt_module)} may not be supported because of its customized forward')
tgt_module.target_modules_hook = config.target_modules_hook
tgt_module.target_hidden_pos = config.target_hidden_pos
tgt_module.root_module_ins_list = root_module_ins_list
tgt_module.stem_module_ins_list = stem_module_ins_list
target_module_ins = tgt_module
if isinstance(tgt_module, nn.Sequential) and not hasattr(tgt_module, 'origin_module_keys'):
tgt_module.origin_module_keys = copy.deepcopy(list(tgt_module._modules.keys()))
setattr(tgt_module, f'forward_origin_{adapter_name}', types.MethodType(_forward_seq, tgt_module))
else:
setattr(tgt_module, f'forward_origin_{adapter_name}', tgt_module.forward)
tgt_module.forward = types.MethodType(_forward_target, tgt_module)
if target_module_ins is None:
raise Exception('Cannot match target modules')
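
        # Only the parameters of this adapter's bypass module (named `restuning_<adapter_name>`)
        # are kept in checkpoints; the backbone weights are left untouched.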
def state_dict_callback(state_dict, adapter_name, **kwargs):
return {key: value for key, value in state_dict.items() if f'restuning_{adapter_name}' in key}
def mark_trainable_callback(model):
return
return SwiftOutput(
config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)
@staticmethod
def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
modules = find_sub_module(module, f'restuning_{adapter_name}')
for _module in modules:
_module: ActivationMixin
_module: nn.Module
_module.set_activation(adapter_name, activate)
SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload)
class ResTuningBypassModule(nn.Module, ActivationMixin):
"""The implementation of ResTuningBypass method.
"""
def __init__(
self,
dims,
depth,
adapter_name,
use_upsample=False,
upsample_out_channels=None,
zero_init_last=False,
tuner_cfg=None,
):
super(ResTuningBypassModule, self).__init__()
super(nn.Module, self).__init__('')
self.adapter_name = adapter_name
self.bypass_blocks = nn.Sequential(*[
ResTunerBypassBlock(
dim=dims[i] if isinstance(dims, list) else dims,
layer_num=i,
depth=depth,
use_upsample=use_upsample,
upsample_out_channels=upsample_out_channels[i] if isinstance(upsample_out_channels, list
) else upsample_out_channels,
zero_init_last=zero_init_last,
tuner_cfg=tuner_cfg[i] if isinstance(tuner_cfg, list) else tuner_cfg) for i in range(depth)
])
self.mark_all_sub_modules_as_plugin()
def forward(self, x_list, origin_arg, **kwargs):
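        # `x_list` holds the probed features in stem order; the first entry seeds the bypass
        # stream, and each bypass block then fuses one stem feature with the running bypass state.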
if not self.is_activated(self.adapter_name):
return origin_arg
x_bypass = detach_tensors(x_list.pop(0))
x_bypass = x_bypass[0] if isinstance(x_bypass, (list, tuple)) else x_bypass
x_list = detach_tensors(x_list)
x_list = [_x[0] if isinstance(_x, (list, tuple)) else _x for _x in x_list]
for i, (bp_blk, x_stem) in enumerate(zip(self.bypass_blocks, x_list)):
target_size = x_list[i + 1].shape[2:] if i < len(x_list) - 1 else None
x_bypass = bp_blk(x_stem, x_bypass, target_size, **kwargs)
return x_bypass
class ResTunerBypassBlock(nn.Module):
def __init__(self, dim, layer_num=-1, depth=-1, use_upsample=False, zero_init_last=False, tuner_cfg=None, **kwargs):
super().__init__()
self.layer_num = layer_num
self.depth = depth
        if isinstance(tuner_cfg, str):
            lateral_cfg = tuner_cfg
            vertical_cfg = tuner_cfg
            aux_cfg = 'upsample' if use_upsample and layer_num != depth - 1 else None
        elif isinstance(tuner_cfg, dict):
            lateral_cfg = tuner_cfg.get('lateral_cfg', None)
            vertical_cfg = tuner_cfg.get('vertical_cfg', None)
            aux_cfg = tuner_cfg.get('aux_cfg', None)
        else:
            lateral_cfg = vertical_cfg = aux_cfg = None
self.lateral_tuner = ResTuner(dim, layer_num, depth, zero_init_last, 'lateral', lateral_cfg, **kwargs)
self.vertical_tuner = ResTuner(dim, layer_num, depth, zero_init_last, 'vertical', vertical_cfg, **kwargs)
if aux_cfg and len(aux_cfg) != 0:
self.aux_tuner = ResTuner(dim, layer_num, depth, zero_init_last, 'aux', aux_cfg, **kwargs)
def forward(self, x_stem, x_bypass, target_size=None, **kwargs):
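        # Fuse the lateral branch (current stem feature) with the vertical branch (bypass state
        # from the previous block), then optionally upsample towards the next stem resolution.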
x_lateral = self.lateral_tuner(x_stem)
x_vertical = self.vertical_tuner(x_bypass)
x_bypass_out = x_lateral + x_vertical
if hasattr(self, 'aux_tuner'):
x_bypass_out = self.aux_tuner(x_bypass_out, target_size)
return x_bypass_out