# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import re
import types
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn

from swift.utils import get_logger
from swift.utils.torch_utils import find_sub_module
from .restuning_components import ResTuner, detach_tensors, probe_input_pre_hook, probe_output_hook
from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput

logger = get_logger()


@dataclass
class ResTuningConfig(SwiftConfig):
    """
    The configuration class for the ResTuning module.

    ResTuning is a flexible parameter-efficient and memory-efficient tuning paradigm framework.
    'Res-Tuning: A Flexible and Efficient Tuning Paradigm via Unbinding Tuner from Backbone'
    by Jiang et al. (2023)

    Args:
        dims(`Union[List[int], int]`): The dimensions of the hidden states
        root_modules(`str`): The root module to be replaced, can be a regex string
        root_modules_hook(`str`): The hook type of root modules, can be "input" or "output"
        stem_modules(`Union[List[str], str]`): The stem modules to be replaced,
            can be a regex string or a name list in full-match format
        stem_modules_hook(`Union[List[str], str]`): The hook type of stem modules, can be "input" or "output"
        target_modules(`str`): The target module to be replaced, can be a regex string
        target_modules_hook(`str`): The hook type of target modules, can be "input" or "output"
        target_hidden_pos(`Union[int, str]`): The position (index or keyword name) of the hidden state
            in the target module's input or output
        tuner_cfg(`Union[List[Dict], Dict, str]`): The configuration of the tuning module,
            can be a string or a customized config
        use_upsample(bool): Whether to use an auxiliary upsample module
        upsample_out_channels(List[int]): The output channels of the upsample modules when `use_upsample` is True
        zero_init_last(bool): Whether to zero-initialize the last Linear in every sub tuner
        use_bypass(bool): Whether to use the bypass structure

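    Example:
        A minimal sketch for a ViT-style backbone. The module-name patterns and the
        'res_adapter' tuner type below are illustrative; adapt them to the actual backbone
        layout and to the tuner types available in `restuning_components`::

            config = ResTuningConfig(
                dims=768,
                root_modules=r'.*blocks.0$',
                stem_modules=r'.*blocks.[0-9]+$',
                target_modules=r'.*norm',
                tuner_cfg='res_adapter')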
    """

    dims: Optional[Union[List[int], int]] = field(
        default=None, metadata={'help': 'The dimensions of the hidden states'})

    root_modules: str = field(
        default=None,
        metadata={
            'help':
            'The root module to be replaced, can be a regex string (the first matching module is used) or full-match format'
        })

    root_modules_hook: str = field(
        default='input', metadata={'help': 'The hook type of root modules, can be "input" or "output"'})

    stem_modules: Optional[Union[List[str], str]] = field(
        default=None,
        metadata={'help': 'The stem modules to be replaced, can be a regex string or a name list in full-match format'})

    stem_modules_hook: str = field(
        default='output', metadata={'help': 'The hook type of stem modules, can be "input" or "output"'})

    target_modules: str = field(
        default=None,
        metadata={
            'help':
            'The target module to be replaced, can be a regex string (the first matching module is used) or full-match format'
        })

    target_modules_hook: str = field(
        default='input', metadata={'help': 'The hook type of target modules, can be "input" or "output"'})

    target_hidden_pos: Union[int, str] = field(
        default=None, metadata={'help': 'The position (index or key) of the hidden state in the target module input or output'})

    tuner_cfg: Optional[Union[List[Dict], Dict, str]] = field(
        default=None, metadata={'help': 'The configuration of the tuning module, can be a string or a customized config'})

    use_upsample: bool = field(default=False, metadata={'help': 'Whether to use auxiliary upsample module'})

    upsample_out_channels: List[int] = field(
        default=None, metadata={'help': 'The number of output channels when "use_upsample" is set to "True"'})

    zero_init_last: bool = field(default=False, metadata={'help': 'Zero init last weight'})

    use_bypass: bool = field(default=True, metadata={'help': 'Whether to use bypass'})

    def __post_init__(self):
        from .mapping import SwiftTuners
        self.swift_type = SwiftTuners.RESTUNING
        self.target_hidden_pos = 0 if self.target_hidden_pos is None else self.target_hidden_pos


class ResTuning(SwiftAdapter):

    @staticmethod
    def prepare_model(model: nn.Module, config: ResTuningConfig, adapter_name: str) -> SwiftOutput:
        """Prepare a model with `ResTuningConfig`"""

        def _forward_seq(self, input, *args, **kwargs):
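            # Run only the sub-modules that belonged to the original nn.Sequential target;
            # anything appended beyond `origin_module_keys` is skipped.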
            for idx, module in enumerate(self):
                if idx >= len(self.origin_module_keys):
                    continue
                input = module(input)
            return input

        def _forward_target(self, *args, **kwargs):
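            # 'input' hook: run the bypass tuner on the probed hidden state before calling the
            # original forward; 'output' hook: call the original forward first, then run the bypass
            # tuner on its hidden state and put the result back at `target_hidden_pos` if needed.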
            if self.target_modules_hook == 'input':
                if isinstance(self.target_hidden_pos, int):
                    args = list(args)
                    _arg = args[self.target_hidden_pos]
                else:
                    _arg = kwargs[self.target_hidden_pos]
                args_main = _forward_restuning(self, _arg)
                if isinstance(self.target_hidden_pos, int):
                    args[self.target_hidden_pos] = args_main
                else:
                    kwargs[self.target_hidden_pos] = args_main
                args_main = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs)
            else:
                _args_main = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs)
                _arg = _args_main[self.target_hidden_pos] if isinstance(_args_main, (tuple, list, dict)) else _args_main
                args_main = _forward_restuning(self, _arg)
                if type(_args_main) != type(args_main):
                    _args_main[self.target_hidden_pos] = args_main
                    args_main = _args_main
            return args_main

        def _forward_restuning(self, origin_arg):
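            # Gather the hidden states probed by the root/stem forward hooks and feed them,
            # together with the target's hidden state, through the ResTuningBypassModule that
            # was attached to the top-level module.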
            probe_results = []
            root_module_ins = self.root_module_ins_list[0]
            stem_module_ins_list = self.stem_module_ins_list
            top_module = model.get_submodule('')
            if root_module_ins:
                if root_module_ins.root_modules_hook == 'input':
                    probe_results.append(root_module_ins.probe_input_data)
                else:
                    probe_results.append(root_module_ins.probe_output_data)
            for i, st_mod in enumerate(stem_module_ins_list):
                if i == 0 and root_module_ins is None:
                    probe_results.append(st_mod.probe_input_data)
                if st_mod.stem_modules_hook == 'input':
                    probe_results.append(st_mod.probe_input_data)
                else:
                    probe_results.append(st_mod.probe_output_data)
            args_main = getattr(top_module, f'restuning_{adapter_name}')(probe_results, origin_arg)
            return args_main

        # 1. Matching the root module
        module_keys = [key for key, _ in model.named_modules()]
        root_module_ins_list = []
        if config.root_modules:
            for module_key in module_keys:
                if re.fullmatch(config.root_modules, module_key):
                    root_module = model.get_submodule(module_key)
                    logger.info(f'Matching root module [{module_key}] of type {type(root_module)}')
                    if isinstance(root_module, (nn.ModuleList, nn.ModuleDict)):
                        logger.warning(
                            f'Type of {type(root_module)} may not be supported because of its customized forward')
                    if config.root_modules_hook == 'input':
                        root_module.register_forward_pre_hook(probe_input_pre_hook)
                    else:
                        root_module.register_forward_hook(probe_output_hook)
                    root_module.root_modules_hook = config.root_modules_hook
                    root_module_ins_list.append(root_module)
                    break
            if len(root_module_ins_list) == 0:
                logger.error('Cannot match root modules')

        # 2. Matching the stem module
        stem_module_ins_list = []
        stem_module_ins_index = []
        for module_key in module_keys:
            if (isinstance(config.stem_modules, str) and re.fullmatch(config.stem_modules, module_key)) or \
                    (isinstance(config.stem_modules, list) and module_key in config.stem_modules):
                stem_module = model.get_submodule(module_key)
                if isinstance(config.stem_modules, list):
                    stem_module_ins_index.append(config.stem_modules.index(module_key))
                logger.info(f'Matching stem module [{module_key}] of type {type(stem_module)}')
                if isinstance(stem_module, (nn.ModuleList, nn.ModuleDict)):
                    logger.warning(
                        f'Type of {type(stem_module)} may not be supported because of its customized forward')
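                # Without a root module, additionally probe the first matched stem module's input
                # so it can seed the bypass stream in `_forward_restuning`.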
                if len(root_module_ins_list) == 0 and len(stem_module_ins_list) == 0:
                    stem_module.register_forward_pre_hook(probe_input_pre_hook)
                if config.stem_modules_hook == 'input':
                    stem_module.register_forward_pre_hook(probe_input_pre_hook)
                else:
                    stem_module.register_forward_hook(probe_output_hook)
                stem_module.stem_modules_hook = config.stem_modules_hook
                stem_module_ins_list.append(stem_module)
        if isinstance(config.stem_modules, list):
            stem_module_ins_list = [
                stem_module_ins_list[stem_module_ins_index.index(i)] for i in range(len(stem_module_ins_index))
            ]
        depth = len(stem_module_ins_list)
        if len(stem_module_ins_list) == 0:
            raise Exception('Cannot match stem modules')

        # 3. Init restuning module
        if len(stem_module_ins_list) != 0:
            top_module = model.get_submodule('')
            restuning_module = ResTuningBypassModule(config.dims, depth, adapter_name, config.use_upsample,
                                                     config.upsample_out_channels, config.zero_init_last,
                                                     config.tuner_cfg)
            setattr(top_module, f'restuning_{adapter_name}', restuning_module)

        # 4. Matching the target module
        target_module_ins = None
        for module_key in module_keys:
            if re.fullmatch(config.target_modules, module_key):
                tgt_module = model.get_submodule(module_key)
                logger.info(f'Matching target module [{module_key}] of type {type(tgt_module)}')
                if isinstance(tgt_module, (nn.ModuleList, nn.ModuleDict)):
                    raise Exception(
                        f'Type of {type(tgt_module)} may not be supported because of its customized forward')

                tgt_module.target_modules_hook = config.target_modules_hook
                tgt_module.target_hidden_pos = config.target_hidden_pos
                tgt_module.root_module_ins_list = root_module_ins_list
                tgt_module.stem_module_ins_list = stem_module_ins_list
                target_module_ins = tgt_module

                if isinstance(tgt_module, nn.Sequential) and not hasattr(tgt_module, 'origin_module_keys'):
                    tgt_module.origin_module_keys = copy.deepcopy(list(tgt_module._modules.keys()))

                    setattr(tgt_module, f'forward_origin_{adapter_name}', types.MethodType(_forward_seq, tgt_module))
                else:
                    setattr(tgt_module, f'forward_origin_{adapter_name}', tgt_module.forward)
                tgt_module.forward = types.MethodType(_forward_target, tgt_module)
        if target_module_ins is None:
            raise Exception('Cannot match target modules')

        def state_dict_callback(state_dict, adapter_name, **kwargs):
            return {key: value for key, value in state_dict.items() if f'restuning_{adapter_name}' in key}

        def mark_trainable_callback(model):
            return

        return SwiftOutput(
            config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
        modules = find_sub_module(module, f'restuning_{adapter_name}')
        for _module in modules:
            _module: ActivationMixin
            _module: nn.Module
            _module.set_activation(adapter_name, activate)
            SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload)


class ResTuningBypassModule(nn.Module, ActivationMixin):
    """The implementation of ResTuningBypass method.
    """

    def __init__(
        self,
        dims,
        depth,
        adapter_name,
        use_upsample=False,
        upsample_out_channels=None,
        zero_init_last=False,
        tuner_cfg=None,
    ):
        super(ResTuningBypassModule, self).__init__()
        # Initialize ActivationMixin (the next class after nn.Module in the MRO) with an empty module key
        super(nn.Module, self).__init__('')
        self.adapter_name = adapter_name

        self.bypass_blocks = nn.Sequential(*[
            ResTunerBypassBlock(
                dim=dims[i] if isinstance(dims, list) else dims,
                layer_num=i,
                depth=depth,
                use_upsample=use_upsample,
                upsample_out_channels=upsample_out_channels[i] if isinstance(upsample_out_channels, list
                                                                             ) else upsample_out_channels,
                zero_init_last=zero_init_last,
                tuner_cfg=tuner_cfg[i] if isinstance(tuner_cfg, list) else tuner_cfg) for i in range(depth)
        ])
        self.mark_all_sub_modules_as_plugin()

    def forward(self, x_list, origin_arg, **kwargs):
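        # `x_list` holds the probed hidden states: the first entry seeds the bypass stream and the
        # remaining entries are the per-stage stem features fused block by block.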
        if not self.is_activated(self.adapter_name):
            return origin_arg
        x_bypass = detach_tensors(x_list.pop(0))
        x_bypass = x_bypass[0] if isinstance(x_bypass, (list, tuple)) else x_bypass
        x_list = detach_tensors(x_list)
        x_list = [_x[0] if isinstance(_x, (list, tuple)) else _x for _x in x_list]
        for i, (bp_blk, x_stem) in enumerate(zip(self.bypass_blocks, x_list)):
            target_size = x_list[i + 1].shape[2:] if i < len(x_list) - 1 else None
            x_bypass = bp_blk(x_stem, x_bypass, target_size, **kwargs)
        return x_bypass


class ResTunerBypassBlock(nn.Module):

    def __init__(self, dim, layer_num=-1, depth=-1, use_upsample=False, zero_init_last=False, tuner_cfg=None, **kwargs):
        super().__init__()
        self.layer_num = layer_num
        self.depth = depth

        # Resolve the lateral/vertical/auxiliary tuner configs; default to None so an absent or
        # unsupported tuner_cfg does not leave these names undefined below.
        lateral_cfg = vertical_cfg = aux_cfg = None
        if isinstance(tuner_cfg, str):
            lateral_cfg = tuner_cfg
            vertical_cfg = tuner_cfg
            aux_cfg = 'upsample' if use_upsample and layer_num != depth - 1 else None
        elif isinstance(tuner_cfg, dict):
            lateral_cfg = tuner_cfg.get('lateral_cfg')
            vertical_cfg = tuner_cfg.get('vertical_cfg')
            aux_cfg = tuner_cfg.get('aux_cfg')

        self.lateral_tuner = ResTuner(dim, layer_num, depth, zero_init_last, 'lateral', lateral_cfg, **kwargs)
        self.vertical_tuner = ResTuner(dim, layer_num, depth, zero_init_last, 'vertical', vertical_cfg, **kwargs)
        if aux_cfg and len(aux_cfg) != 0:
            self.aux_tuner = ResTuner(dim, layer_num, depth, zero_init_last, 'aux', aux_cfg, **kwargs)

    def forward(self, x_stem, x_bypass, target_size=None, **kwargs):
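        # Fuse the lateral branch (stem feature) with the vertical branch (running bypass feature);
        # the optional aux tuner (e.g. upsample) adapts the fused feature to the next stage's size.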
        x_lateral = self.lateral_tuner(x_stem)
        x_vertical = self.vertical_tuner(x_bypass)

        x_bypass_out = x_lateral + x_vertical
        if hasattr(self, 'aux_tuner'):
            x_bypass_out = self.aux_tuner(x_bypass_out, target_size)
        return x_bypass_out