# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import re
import types
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union
import torch
import torch.nn as nn
from swift.utils import get_logger
from swift.utils.torch_utils import find_sub_module
from .restuning_components import ResTuner, detach_tensors, probe_input_pre_hook, probe_output_hook
from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput
logger = get_logger()


@dataclass
class ResTuningConfig(SwiftConfig):
"""
The configuration class for the ResTuning module.
    Res-Tuning is a flexible, parameter-efficient and memory-efficient tuning paradigm.
    See 'Res-Tuning: A Flexible and Efficient Tuning Paradigm via Unbinding Tuner from Backbone'
    by Jiang et al. (2023).
Args:
        dims(`Union[List[int], int]`): The dimensions of the hidden states
        root_modules(`str`): The root module to be replaced, can be a regex string
        root_modules_hook(`str`): The hook type of root modules, can be "input" or "output"
        stem_modules(`Union[List[str], str]`): The stem modules to be replaced,
            can be a regex string or a name list in full-match format
        stem_modules_hook(`Union[List[str], str]`): The hook type of stem modules, can be "input" or "output"
        target_modules(`str`): The target module to be replaced, can be a regex string
        target_modules_hook(`str`): The hook type of target modules, can be "input" or "output"
        target_hidden_pos(`Union[int, str]`): The position (index or keyword) of the hidden state
            in the target module's input or output
        tuner_cfg(`Union[List[Dict], Dict, str]`): The configuration of the tuning module,
            can be a string or a customized config
        use_upsample(bool): Whether to use the auxiliary upsample module
        upsample_out_channels(List[int]): The number of output channels when `use_upsample` is True
        zero_init_last(bool): Whether to zero-initialize the last Linear layer in every sub tuner
        use_bypass(bool): Whether to use the bypass structure
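
    Example:
        A minimal illustrative sketch. The regex patterns, dimension and tuner name below are
        hypothetical placeholders; they must be adapted to the module names and hidden sizes of
        the backbone being tuned and to the tuner configs supported by
        `restuning_components.ResTuner`::

            config = ResTuningConfig(
                dims=768,
                stem_modules=r'model\.layers\.\d+',
                target_modules=r'model\.norm',
                target_modules_hook='input',
                tuner_cfg='res_adapter')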
"""
dims: Optional[Union[List[int], int]] = field(
default=None, metadata={'help': 'The dimensions of the hidden states'})
    root_modules: Optional[str] = field(
        default=None,
        metadata={
            'help':
            'The root module to be replaced, can be a regex string (the first matching module is used) or full match format'
        })
root_modules_hook: str = field(
default='input', metadata={'help': 'The hook type of root modules, can be "input" or "output"'})
    stem_modules: Optional[Union[List[str], str]] = field(
        default=None,
        metadata={'help': 'The stem modules to be replaced, can be a regex string or a name list in full-match format'})
stem_modules_hook: str = field(
default='output', metadata={'help': 'The hook type of stem modules, can be "input" or "output"'})
    target_modules: Optional[str] = field(
        default=None,
        metadata={
            'help':
            'The target module to be replaced, can be a regex string (the first matching module is used) or full match format'
        })
target_modules_hook: str = field(
default='input', metadata={'help': 'The hook type of target modules, can be "input" or "output"'})
    target_hidden_pos: Optional[Union[int, str]] = field(
        default=None, metadata={'help': 'The position (index or keyword) of the hidden state in the target module input/output'})
    tuner_cfg: Optional[Union[List[Dict], Dict, str]] = field(
        default=None, metadata={'help': 'The configuration of the tuning module, can be a string or a customized config'})
use_upsample: bool = field(default=False, metadata={'help': 'Whether to use auxiliary upsample module'})
    upsample_out_channels: Optional[List[int]] = field(
        default=None, metadata={'help': 'The number of output channels when "use_upsample" is set to "True"'})
zero_init_last: bool = field(default=False, metadata={'help': 'Zero init last weight'})
use_bypass: bool = field(default=True, metadata={'help': 'Whether to use bypass'})
def __post_init__(self):
from .mapping import SwiftTuners
self.swift_type = SwiftTuners.RESTUNING
self.target_hidden_pos = 0 if self.target_hidden_pos is None else self.target_hidden_pos


class ResTuning(SwiftAdapter):

@staticmethod
def prepare_model(model: nn.Module, config: ResTuningConfig, adapter_name: str) -> SwiftOutput:
"""Prepare a model with `ResTuningConfig`"""
def _forward_seq(self, input, *args, **kwargs):
for idx, module in enumerate(self):
if idx >= len(self.origin_module_keys):
continue
input = module(input)
return input
def _forward_target(self, *args, **kwargs):
if self.target_modules_hook == 'input':
if isinstance(self.target_hidden_pos, int):
args = list(args)
_arg = args[self.target_hidden_pos]
else:
_arg = kwargs[self.target_hidden_pos]
args_main = _forward_restuning(self, _arg)
if isinstance(self.target_hidden_pos, int):
args[self.target_hidden_pos] = args_main
else:
kwargs[self.target_hidden_pos] = args_main
args_main = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs)
else:
_args_main = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs)
_arg = _args_main[self.target_hidden_pos] if isinstance(_args_main, (tuple, list, dict)) else _args_main
args_main = _forward_restuning(self, _arg)
if type(_args_main) != type(args_main):
_args_main[self.target_hidden_pos] = args_main
args_main = _args_main
return args_main
def _forward_restuning(self, origin_arg):
probe_results = []
            root_module_ins = self.root_module_ins_list[0] if len(self.root_module_ins_list) > 0 else None
stem_module_ins_list = self.stem_module_ins_list
top_module = model.get_submodule('')
if root_module_ins:
if root_module_ins.root_modules_hook == 'input':
probe_results.append(root_module_ins.probe_input_data)
else:
probe_results.append(root_module_ins.probe_output_data)
for i, st_mod in enumerate(stem_module_ins_list):
if i == 0 and root_module_ins is None:
probe_results.append(st_mod.probe_input_data)
if st_mod.stem_modules_hook == 'input':
probe_results.append(st_mod.probe_input_data)
else:
probe_results.append(st_mod.probe_output_data)
args_main = getattr(top_module, f'restuning_{adapter_name}')(probe_results, origin_arg)
return args_main
# 1. Matching the root module
module_keys = [key for key, _ in model.named_modules()]
root_module_ins_list = []
if config.root_modules:
for module_key in module_keys:
if re.fullmatch(config.root_modules, module_key):
root_module = model.get_submodule(module_key)
logger.info(f'Matching root module [{module_key}] of type {type(root_module)}')
if isinstance(root_module, (nn.ModuleList, nn.ModuleDict)):
logger.warning(
f'Type of {type(root_module)} may not be supported because of its customized forward')
if config.root_modules_hook == 'input':
root_module.register_forward_pre_hook(probe_input_pre_hook)
else:
root_module.register_forward_hook(probe_output_hook)
root_module.root_modules_hook = config.root_modules_hook
root_module_ins_list.append(root_module)
break
if len(root_module_ins_list) == 0:
logger.error('Cannot match root modules')
# 2. Matching the stem module
stem_module_ins_list = []
stem_module_ins_index = []
for module_key in module_keys:
if (isinstance(config.stem_modules, str) and re.fullmatch(config.stem_modules, module_key)) or \
(isinstance(config.stem_modules, list) and module_key in config.stem_modules):
stem_module = model.get_submodule(module_key)
if isinstance(config.stem_modules, list):
stem_module_ins_index.append(config.stem_modules.index(module_key))
logger.info(f'Matching stem module [{module_key}] of type {type(stem_module)}')
if isinstance(stem_module, (nn.ModuleList, nn.ModuleDict)):
logger.warning(
f'Type of {type(stem_module)} may not be supported because of its customized forward')
if len(root_module_ins_list) == 0 and len(stem_module_ins_list) == 0:
stem_module.register_forward_pre_hook(probe_input_pre_hook)
if config.stem_modules_hook == 'input':
stem_module.register_forward_pre_hook(probe_input_pre_hook)
else:
stem_module.register_forward_hook(probe_output_hook)
stem_module.stem_modules_hook = config.stem_modules_hook
stem_module_ins_list.append(stem_module)
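        # If stem_modules was given as an explicit name list, the user-specified order is restored below,
        # since the named_modules() iteration order may differ from the order in config.stem_modules.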
if isinstance(config.stem_modules, list):
stem_module_ins_list = [
stem_module_ins_list[stem_module_ins_index.index(i)] for i in range(len(stem_module_ins_index))
]
depth = len(stem_module_ins_list)
        if len(stem_module_ins_list) == 0:
            raise Exception('Cannot match stem modules')
# 3. Init restuning module
if len(stem_module_ins_list) != 0:
top_module = model.get_submodule('')
restuning_module = ResTuningBypassModule(config.dims, depth, adapter_name, config.use_upsample,
config.upsample_out_channels, config.zero_init_last,
config.tuner_cfg)
setattr(top_module, f'restuning_{adapter_name}', restuning_module)
# 4. Matching the target module
target_module_ins = None
for module_key in module_keys:
if re.fullmatch(config.target_modules, module_key):
tgt_module = model.get_submodule(module_key)
logger.info(f'Matching target module [{module_key}] of type {type(tgt_module)}')
if isinstance(tgt_module, (nn.ModuleList, nn.ModuleDict)):
raise Exception(
f'Type of {type(tgt_module)} may not be supported because of its customized forward')
tgt_module.target_modules_hook = config.target_modules_hook
tgt_module.target_hidden_pos = config.target_hidden_pos
tgt_module.root_module_ins_list = root_module_ins_list
tgt_module.stem_module_ins_list = stem_module_ins_list
target_module_ins = tgt_module
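                # For nn.Sequential targets, bind a dedicated forward that only iterates the children
                # recorded in origin_module_keys, so sub-modules appended to the Sequential afterwards
                # are skipped by the wrapped original forward.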
if isinstance(tgt_module, nn.Sequential) and not hasattr(tgt_module, 'origin_module_keys'):
tgt_module.origin_module_keys = copy.deepcopy(list(tgt_module._modules.keys()))
setattr(tgt_module, f'forward_origin_{adapter_name}', types.MethodType(_forward_seq, tgt_module))
else:
setattr(tgt_module, f'forward_origin_{adapter_name}', tgt_module.forward)
tgt_module.forward = types.MethodType(_forward_target, tgt_module)
if target_module_ins is None:
raise Exception('Cannot match target modules')
def state_dict_callback(state_dict, adapter_name, **kwargs):
return {key: value for key, value in state_dict.items() if f'restuning_{adapter_name}' in key}
def mark_trainable_callback(model):
return
return SwiftOutput(
config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)
@staticmethod
def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
modules = find_sub_module(module, f'restuning_{adapter_name}')
for _module in modules:
_module: ActivationMixin
_module: nn.Module
_module.set_activation(adapter_name, activate)
SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload)
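

# Usage sketch (illustrative): in ms-swift, this adapter is normally applied through the `Swift.prepare_model`
# facade rather than by calling `ResTuning.prepare_model` directly, e.g.
#     from swift import Swift
#     model = Swift.prepare_model(model, ResTuningConfig(dims=..., stem_modules=..., target_modules=...))
# The exact entry point and accepted argument forms may vary across swift versions.
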
class ResTuningBypassModule(nn.Module, ActivationMixin):
"""The implementation of ResTuningBypass method.
"""
def __init__(
self,
dims,
depth,
adapter_name,
use_upsample=False,
upsample_out_channels=None,
zero_init_last=False,
tuner_cfg=None,
):
super(ResTuningBypassModule, self).__init__()
super(nn.Module, self).__init__('')
self.adapter_name = adapter_name
self.bypass_blocks = nn.Sequential(*[
ResTunerBypassBlock(
dim=dims[i] if isinstance(dims, list) else dims,
layer_num=i,
depth=depth,
use_upsample=use_upsample,
upsample_out_channels=upsample_out_channels[i] if isinstance(upsample_out_channels, list
) else upsample_out_channels,
zero_init_last=zero_init_last,
tuner_cfg=tuner_cfg[i] if isinstance(tuner_cfg, list) else tuner_cfg) for i in range(depth)
])
self.mark_all_sub_modules_as_plugin()
def forward(self, x_list, origin_arg, **kwargs):
if not self.is_activated(self.adapter_name):
return origin_arg
x_bypass = detach_tensors(x_list.pop(0))
x_bypass = x_bypass[0] if isinstance(x_bypass, (list, tuple)) else x_bypass
x_list = detach_tensors(x_list)
x_list = [_x[0] if isinstance(_x, (list, tuple)) else _x for _x in x_list]
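        # The first probed feature (popped above) seeds the bypass state; each remaining stem feature is
        # fused with the running bypass state by its corresponding bypass block, in stem-module order.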
for i, (bp_blk, x_stem) in enumerate(zip(self.bypass_blocks, x_list)):
target_size = x_list[i + 1].shape[2:] if i < len(x_list) - 1 else None
x_bypass = bp_blk(x_stem, x_bypass, target_size, **kwargs)
return x_bypass


class ResTunerBypassBlock(nn.Module):

def __init__(self, dim, layer_num=-1, depth=-1, use_upsample=False, zero_init_last=False, tuner_cfg=None, **kwargs):
super().__init__()
self.layer_num = layer_num
self.depth = depth
        # Initialize to None so an unspecified or unsupported tuner_cfg does not leave these names undefined.
        lateral_cfg, vertical_cfg, aux_cfg = None, None, None
        if isinstance(tuner_cfg, str):
            lateral_cfg = tuner_cfg
            vertical_cfg = tuner_cfg
            aux_cfg = 'upsample' if use_upsample and layer_num != depth - 1 else None
        elif isinstance(tuner_cfg, dict):
            lateral_cfg = tuner_cfg.get('lateral_cfg', None)
            vertical_cfg = tuner_cfg.get('vertical_cfg', None)
            aux_cfg = tuner_cfg.get('aux_cfg', None)
self.lateral_tuner = ResTuner(dim, layer_num, depth, zero_init_last, 'lateral', lateral_cfg, **kwargs)
self.vertical_tuner = ResTuner(dim, layer_num, depth, zero_init_last, 'vertical', vertical_cfg, **kwargs)
if aux_cfg and len(aux_cfg) != 0:
self.aux_tuner = ResTuner(dim, layer_num, depth, zero_init_last, 'aux', aux_cfg, **kwargs)
def forward(self, x_stem, x_bypass, target_size=None, **kwargs):
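        # The lateral tuner transforms the probed stem feature and the vertical tuner propagates the
        # incoming bypass state; their sum becomes the new bypass state, optionally resized towards
        # `target_size` by the auxiliary upsample tuner.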
x_lateral = self.lateral_tuner(x_stem)
x_vertical = self.vertical_tuner(x_bypass)
x_bypass_out = x_lateral + x_vertical
if hasattr(self, 'aux_tuner'):
x_bypass_out = self.aux_tuner(x_bypass_out, target_size)
return x_bypass_out