# Copyright (c) Alibaba, Inc. and its affiliates.
from copy import deepcopy
from dataclasses import dataclass, field, fields
from typing import Optional

import torch
from torch import nn

from swift.llm import MODEL_ARCH_MAPPING, HfConfigFactory, ModelKeys
from swift.utils.logger import get_logger
from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput

logger = get_logger()


@dataclass
class LLaMAProConfig(SwiftConfig):
    """
    The configuration class for the LLaMAPro module.

    See https://arxiv.org/abs/2401.02415

    Args:
        model_type(`str`): LLaMAPro only support parts of the LLM models because of the variables need to be manually
            modified.
        num_new_blocks(`int`): How many new blocks need to be added
        num_groups(`int`): The groups of new blocks are split to. Default equals to `num_new_blocks` which means each
            single layer will be inserted into every `num_hidden_layers/num_new_blocks` original layers.
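
    Example:
        A minimal sketch (illustrative; assumes a supported LLaMA-style model whose layer count
        is divisible by `num_new_blocks` has already been loaded, e.g. via `swift.llm`, and is
        bound to `model`):

        >>> config = LLaMAProConfig(model_type='llama', num_new_blocks=4)
        >>> output = LLaMAPro.prepare_model(model, config, adapter_name='default')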
    """
    model_type: str = field(
        default=None, metadata={
            'choices': list(MODEL_ARCH_MAPPING.keys()),
        })

    num_new_blocks: int = None

    num_groups: Optional[int] = None

    def __post_init__(self):
        from .mapping import SwiftTuners
        self.swift_type = SwiftTuners.LLAMAPRO


class LLaMAPro(SwiftAdapter):

    @staticmethod
    def prepare_model(model: nn.Module, config: LLaMAProConfig, adapter_name: str) -> SwiftOutput:
        """Prepare a model with `LLaMAProConfig`"""
        num_hidden_layers = HfConfigFactory.get_config_attr(model.config, 'num_hidden_layers')
        if num_hidden_layers is None:
            num_hidden_layers = HfConfigFactory.get_config_attr(model.config, 'num_layers')
        assert num_hidden_layers is not None, 'Cannot find the number of layers in the model config'
        assert num_hidden_layers % config.num_new_blocks == 0, f'The number of model layers {num_hidden_layers} ' \
                                                               f'should be divisible by {config.num_new_blocks}'
        if config.num_groups is None:
            config.num_groups = config.num_new_blocks

        # The except block below may overwrite config.model_type; the original value is restored
        # at the end of this function, otherwise a `model not found` error occurs when using internvl.
        origin_model_type = config.model_type
        model_type = origin_model_type
        num_stride = num_hidden_layers // config.num_groups
        try:
            module_list = LLaMAPro._find_module_list(config, model)
        except AssertionError as e:
            model_type = LLaMAPro.search_correct_model_type(model)
            if model_type is None:
                language_model_name = SwiftAdapter.get_model_key_mapping(config.model_type, config).language_model
                if language_model_name:
                    if isinstance(language_model_name, str):
                        language_model_name = [language_model_name]
                    language_model = model.get_submodule(language_model_name[0])
                    model_type = LLaMAPro.search_correct_model_type(language_model)
                    if model_type:
                        model = language_model

            if model_type:
                config.model_type = model_type
                module_list = LLaMAPro._find_module_list(config, model)
            else:
                raise e

        new_module_list = nn.ModuleList()
        new_module_idx = []
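        # Interleave new blocks: after every `num_stride` original blocks, append a deep
        # copy of the current block. Its output projections are zeroed in
        # `_update_module_weight` below, so with residual connections the new block
        # initially acts as an identity mapping (the LLaMA-Pro block-expansion scheme).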
        for idx, module in enumerate(module_list):
            new_module_list.append(module)
            if (idx + 1) % num_stride == 0:
                new_module = deepcopy(module)
                ActivationMixin.mark_all_sub_modules_as_plugin(new_module)
                new_module_list.append(new_module)
                new_module_idx.append(idx + 1 + len(new_module_idx))

        LLaMAPro._update_module_weight(config, new_module_list, new_module_idx)
        LLaMAPro._update_module_attr(config, new_module_list)
        model.config.num_hidden_layers = len(new_module_list)
        LLaMAPro._set_module_list(config, model, new_module_list)

        def activate_module(activate: bool):
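            # Swap between the expanded layer list (adapter active) and the original
            # layer list (adapter inactive).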
            if activate:
                LLaMAPro._update_module_attr(config, new_module_list)
                LLaMAPro._set_module_list(config, model, new_module_list)
            else:
                LLaMAPro._update_module_attr(config, module_list)
                LLaMAPro._set_module_list(config, model, module_list)

        def state_dict_callback(state_dict, adapter_name, **kwargs):
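            # Keep only the parameters that belong to the newly inserted blocks.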
            model_key_mapping = LLaMAPro.get_model_key_mapping(model_type, config)
            new_module_list = [model_key_mapping.module_list + f'.{i}' for i in new_module_idx]
            return {
                key: value
                for key, value in state_dict.items() if any([m_part in key for m_part in new_module_list])
            }

        def mark_trainable_callback(model):
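            # Mark the parameters of the newly inserted blocks as trainable.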
            model_key_mapping = LLaMAPro.get_model_key_mapping(model_type, config)
            new_module_list = [model_key_mapping.module_list + f'.{i}' for i in new_module_idx]
            for name, parameter in model.named_parameters():
                parameter: nn.Parameter
                if any([m_part in name for m_part in new_module_list]):
                    parameter.requires_grad = True

        config.model_type = origin_model_type
        model.activate_module = activate_module
        return SwiftOutput(
            config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)

    @staticmethod
    def _update_module_attr(config: LLaMAProConfig, module_list):
        model_type = config.model_type
        model_key_mapping = LLaMAPro.get_model_key_mapping(model_type, config)
        attention = model_key_mapping.attention
        attention = attention.split('{}.')[1]
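        # After insertion the decoder blocks are re-numbered, so the per-layer index
        # attribute (named differently across architectures) must be rewritten to match.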
        if model_type == 'phi3-small':
            raise ValueError('LLaMAPro does not currently support phi3-small')
        if model_type in ('llama', 'mistral', 'qwen2', 'yi', 'gemma', 'deepseek', 'openbuddy', 'xverse', 'orion',
                          'bluelm', 'ziya', 'skywork', 'deepseek-v2', 'minicpm', 'phi3', 'internlm2'):
            for idx, module in enumerate(module_list):
                try:
                    getattr(module, attention).layer_idx = idx
                except AttributeError:
                    getattr(module, 'cross_attn').layer_idx = idx
        elif model_type in ('chatglm', 'glm4'):
            for idx, module in enumerate(module_list):
                getattr(module, attention).layer_number = idx
        elif model_type in ('phi2', ):
            for idx, module in enumerate(module_list):
                getattr(module, attention).block_idx = idx
        else:
            for idx, module in enumerate(module_list):
                attrs = [
                    attr for attr in dir(getattr(module_list[0], attention))
                    if attr in ('layer_idx', 'layer_number', 'block_idx')
                ]
                assert len(attrs) <= 1
                if attrs:
                    setattr(getattr(module, attention), attrs[0], idx)
                else:
                    logger.warning(f'model_type: {model_type} seems to have no layer_idx attribute. '
                                   f'If you encounter anything wrong, please give us feedback.')

    @classmethod
    def get_model_key_mapping(cls, model_type, config) -> ModelKeys:

        model_key_mapping = SwiftAdapter.get_model_key_mapping(model_type, config)
        assert model_key_mapping.o_proj is not None and model_key_mapping.down_proj is not None, \
            'LLaMAPro only supports models with o_proj and down_proj components.'
        return model_key_mapping

    @classmethod
    def search_correct_model_type(cls, module: nn.Module):
        for arch_name, arch_type in MODEL_ARCH_MAPPING.items():
            arch_type: ModelKeys
            if getattr(arch_type, 'module_list') is None:
                # Need to be a LLM arch
                continue

            matched = True
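            # Every non-empty key pattern of this architecture must resolve to a
            # submodule (with the layer-index placeholder `{}` set to 0) for a match.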
            for f in fields(arch_type):
                arch_str = getattr(arch_type, f.name)
                if f.name == 'arch_name' or arch_str is None:
                    continue

                arch_str = arch_str.replace('{}', '0')
                try:
                    sub_module = module.get_submodule(arch_str)
                    if sub_module is None:
                        matched = False
                except AttributeError:
                    matched = False

                if not matched:
                    break

            if matched:
                return arch_name

    @staticmethod
    def _update_module_weight(config: LLaMAProConfig, module_list, new_module_idx):
        model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
        o_proj = model_key_mapping.o_proj.split('{}.')[1]
        down_proj = model_key_mapping.down_proj.split('{}.')[1]

        for idx, module in enumerate(module_list):
            if idx not in new_module_idx:
                continue
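            # Zero the attention output projection and the MLP down projection of each
            # newly inserted block, so its contribution to the residual stream starts at
            # zero and the expanded model initially reproduces the original outputs.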
            _o_proj: nn.Linear = module.get_submodule(o_proj)
            _down_proj: nn.Linear = module.get_submodule(down_proj)
            _o_proj.weight.data = torch.zeros_like(_o_proj.weight.data)
            _down_proj.weight.data = torch.zeros_like(_down_proj.weight.data)
            if hasattr(_o_proj, 'bias') and _o_proj.bias is not None:
                _o_proj.bias.data = torch.zeros_like(_o_proj.bias)
            if hasattr(_down_proj, 'bias') and _down_proj.bias is not None:
                _down_proj.bias.data = torch.zeros_like(_down_proj.bias)

    @staticmethod
    def _set_module_list(config, module: nn.Module, module_list: nn.ModuleList):
        model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
        idx = model_key_mapping.module_list.rfind('.')
        parent = module.get_submodule(model_key_mapping.module_list[:idx])
        setattr(parent, model_key_mapping.module_list[idx + 1:], module_list)

    @staticmethod
    def _find_module_list(config, module: nn.Module) -> nn.ModuleList:
        model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
        return module.get_submodule(model_key_mapping.module_list)

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
        module.activate_module(activate)

    @staticmethod
    def has_additional_modules():
        return True