# Copyright (c) Alibaba, Inc. and its affiliates.
from copy import deepcopy
from dataclasses import dataclass, field, fields
from typing import Optional

import torch
from torch import nn

from swift.llm import MODEL_ARCH_MAPPING, HfConfigFactory, ModelKeys
from swift.utils.logger import get_logger

from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput

logger = get_logger()

@dataclass
class LLaMAProConfig(SwiftConfig):
"""
    The configuration class for the LLaMAPro module.

    See https://arxiv.org/abs/2401.02415

    Args:
        model_type(`str`): LLaMAPro only supports a subset of LLM models, because some model-specific
            attributes need to be modified manually.
        num_new_blocks(`int`): The number of new blocks to add.
        num_groups(`int`): The number of groups the new blocks are split into. Defaults to `num_new_blocks`,
            which means one new layer is inserted after every `num_hidden_layers/num_new_blocks` original layers.
"""
model_type: str = field(
default=None, metadata={
'choices': list(MODEL_ARCH_MAPPING.keys()),
})
    num_new_blocks: Optional[int] = None
num_groups: Optional[int] = None
def __post_init__(self):
from .mapping import SwiftTuners
self.swift_type = SwiftTuners.LLAMAPRO
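

# A minimal usage sketch (illustrative; assumes the generic Swift tuner entry point):
#     config = LLaMAProConfig(model_type='llama', num_new_blocks=4)
#     model = Swift.prepare_model(model, config)
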
class LLaMAPro(SwiftAdapter):
@staticmethod
def prepare_model(model: nn.Module, config: LLaMAProConfig, adapter_name: str) -> SwiftOutput:
"""Prepare a model with `LLaMAProConfig`"""
num_hidden_layers = HfConfigFactory.get_config_attr(model.config, 'num_hidden_layers')
if num_hidden_layers is None:
num_hidden_layers = HfConfigFactory.get_config_attr(model.config, 'num_layers')
        assert num_hidden_layers is not None, 'Cannot find the number of hidden layers in the model config'
        assert num_hidden_layers % config.num_new_blocks == 0, \
            f'The number of model layers {num_hidden_layers} should be divisible by {config.num_new_blocks}'
if config.num_groups is None:
config.num_groups = config.num_new_blocks
        # The except branch below may change `config.model_type`, which would cause a `model not found`
        # error when using internvl, so keep the original value and restore it before returning.
origin_model_type = config.model_type
model_type = origin_model_type
num_stride = num_hidden_layers // config.num_groups
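        # e.g. with 32 hidden layers, num_new_blocks=8 and num_groups=8, num_stride is 4, so one copied
        # block will be inserted after every 4 original layers.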
try:
module_list = LLaMAPro._find_module_list(config, model)
except AssertionError as e:
model_type = LLaMAPro.search_correct_model_type(model)
if model_type is None:
language_model_name = SwiftAdapter.get_model_key_mapping(config.model_type, config).language_model
if language_model_name:
if isinstance(language_model_name, str):
language_model_name = [language_model_name]
language_model = model.get_submodule(language_model_name[0])
model_type = LLaMAPro.search_correct_model_type(language_model)
if model_type:
model = language_model
if model_type:
config.model_type = model_type
module_list = LLaMAPro._find_module_list(config, model)
else:
raise e
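
        # Build the expanded layer list: keep every original block and, after each `num_stride` original
        # layers, append a deep copy of the preceding block; `new_module_idx` records the positions of the
        # inserted copies within the expanded list.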
new_module_list = nn.ModuleList()
new_module_idx = []
for idx, module in enumerate(module_list):
new_module_list.append(module)
if (idx + 1) % num_stride == 0:
new_module = deepcopy(module)
ActivationMixin.mark_all_sub_modules_as_plugin(new_module)
new_module_list.append(new_module)
new_module_idx.append(idx + 1 + len(new_module_idx))
LLaMAPro._update_module_weight(config, new_module_list, new_module_idx)
LLaMAPro._update_module_attr(config, new_module_list)
model.config.num_hidden_layers = len(new_module_list)
LLaMAPro._set_module_list(config, model, new_module_list)
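
        # Activating the adapter installs the expanded module list (with the copied blocks); deactivating
        # it restores the original list, re-applying the matching per-layer attributes in both cases.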
def activate_module(activate: bool):
if activate:
LLaMAPro._update_module_attr(config, new_module_list)
LLaMAPro._set_module_list(config, model, new_module_list)
else:
LLaMAPro._update_module_attr(config, module_list)
LLaMAPro._set_module_list(config, model, module_list)
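
        # The adapter state dict keeps only the parameters that belong to the newly inserted blocks.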
def state_dict_callback(state_dict, adapter_name, **kwargs):
model_key_mapping = LLaMAPro.get_model_key_mapping(model_type, config)
new_module_list = [model_key_mapping.module_list + f'.{i}' for i in new_module_idx]
return {
key: value
for key, value in state_dict.items() if any([m_part in key for m_part in new_module_list])
}
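
        # Only parameters belonging to the newly inserted blocks are marked as trainable.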
def mark_trainable_callback(model):
model_key_mapping = LLaMAPro.get_model_key_mapping(model_type, config)
new_module_list = [model_key_mapping.module_list + f'.{i}' for i in new_module_idx]
for name, parameter in model.named_parameters():
parameter: nn.Parameter
if any([m_part in name for m_part in new_module_list]):
parameter.requires_grad = True
config.model_type = origin_model_type
model.activate_module = activate_module
return SwiftOutput(
config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)
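
    # Re-number the per-layer index attribute (`layer_idx`, `layer_number` or `block_idx`, depending on
    # the architecture) of every block so that attention and KV-cache bookkeeping matches each block's
    # position in the (possibly expanded) module list.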
@staticmethod
def _update_module_attr(config: LLaMAProConfig, module_list):
model_type = config.model_type
model_key_mapping = LLaMAPro.get_model_key_mapping(model_type, config)
attention = model_key_mapping.attention
attention = attention.split('{}.')[1]
if model_type == 'phi3-small':
            raise ValueError('LLaMAPro does not currently support phi3-small')
if model_type in ('llama', 'mistral', 'qwen2', 'yi', 'gemma', 'deepseek', 'openbuddy', 'xverse', 'orion',
'bluelm', 'ziya', 'skywork', 'deepseek-v2', 'minicpm', 'phi3', 'internlm2'):
for idx, module in enumerate(module_list):
try:
getattr(module, attention).layer_idx = idx
except AttributeError:
getattr(module, 'cross_attn').layer_idx = idx
elif model_type in ('chatglm', 'glm4'):
for idx, module in enumerate(module_list):
getattr(module, attention).layer_number = idx
elif model_type in ('phi2', ):
for idx, module in enumerate(module_list):
getattr(module, attention).block_idx = idx
else:
for idx, module in enumerate(module_list):
attrs = [
attr for attr in dir(getattr(module_list[0], attention))
if attr in ('layer_idx', 'layer_number', 'block_idx')
]
assert len(attrs) <= 1
if attrs:
setattr(getattr(module, attention), attrs[0], idx)
else:
                    logger.warning(
                        f'model_type: {model_type} seems to have no layer index attribute. '
                        f'If you encounter any problem, please give us feedback.')

@classmethod
def get_model_key_mapping(cls, model_type, config) -> ModelKeys:
model_key_mapping = SwiftAdapter.get_model_key_mapping(model_type, config)
        assert model_key_mapping.o_proj is not None and model_key_mapping.down_proj is not None, \
            'LLaMAPro only supports models with o_proj and down_proj components.'
return model_key_mapping
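
    # Fallback used when `config.model_type` does not match the wrapped model: probe the module tree and
    # return the first architecture in MODEL_ARCH_MAPPING whose key paths (with `{}` replaced by layer 0)
    # all resolve to existing submodules.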
@classmethod
def search_correct_model_type(cls, module: nn.Module):
for arch_name, arch_type in MODEL_ARCH_MAPPING.items():
arch_type: ModelKeys
if getattr(arch_type, 'module_list') is None:
# Need to be a LLM arch
continue
matched = True
for f in fields(arch_type):
arch_str = getattr(arch_type, f.name)
if f.name == 'arch_name' or arch_str is None:
continue
arch_str = arch_str.replace('{}', '0')
try:
sub_module = module.get_submodule(arch_str)
if sub_module is None:
matched = False
except AttributeError:
matched = False
if not matched:
break
if matched:
return arch_name
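
    # Zero-initialize the attention output projection (o_proj) and the MLP down projection (down_proj) of
    # every newly inserted block, so that at initialization the copied block contributes nothing beyond the
    # residual connection and the expanded model reproduces the original outputs (the LLaMA Pro zero-init).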
@staticmethod
def _update_module_weight(config: LLaMAProConfig, module_list, new_module_idx):
model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
o_proj = model_key_mapping.o_proj.split('{}.')[1]
down_proj = model_key_mapping.down_proj.split('{}.')[1]
for idx, module in enumerate(module_list):
if idx not in new_module_idx:
continue
_o_proj: nn.Linear = module.get_submodule(o_proj)
_down_proj: nn.Linear = module.get_submodule(down_proj)
_o_proj.weight.data = torch.zeros_like(_o_proj.weight.data)
_down_proj.weight.data = torch.zeros_like(_down_proj.weight.data)
if hasattr(_o_proj, 'bias') and _o_proj.bias is not None:
_o_proj.bias.data = torch.zeros_like(_o_proj.bias)
if hasattr(_down_proj, 'bias') and _down_proj.bias is not None:
_down_proj.bias.data = torch.zeros_like(_down_proj.bias)
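
    # Replace the layer container on the parent module (e.g. `model.layers` in LLaMA-style models) with
    # the given module list.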
@staticmethod
def _set_module_list(config, module: nn.Module, module_list: nn.ModuleList):
model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
idx = model_key_mapping.module_list.rfind('.')
parent = module.get_submodule(model_key_mapping.module_list[:idx])
setattr(parent, model_key_mapping.module_list[idx + 1:], module_list)
@staticmethod
def _find_module_list(config, module: nn.Module) -> nn.ModuleList:
model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
return module.get_submodule(model_key_mapping.module_list)
@staticmethod
def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
module.activate_module(activate)
@staticmethod
def has_additional_modules():
return True