# Copyright (c) Alibaba, Inc. and its affiliates.
from copy import deepcopy
from dataclasses import dataclass, field, fields
from typing import Optional
import torch
from torch import nn
from swift.llm import MODEL_ARCH_MAPPING, HfConfigFactory, ModelKeys
from swift.utils.logger import get_logger
from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput
logger = get_logger()
@dataclass
class LLaMAProConfig(SwiftConfig):
"""
The configuration class for the LLaMAPro module.
See https://arxiv.org/abs/2401.02415
Args:
model_type(`str`): LLaMAPro only support parts of the LLM models because of the variables need to be manually
modified.
num_new_blocks(`int`): How many new blocks need to be added
num_groups(`int`): The groups of new blocks are split to. Default equals to `num_new_blocks` which means each
single layer will be inserted into every `num_hidden_layers/num_new_blocks` original layers.
"""
model_type: str = field(
default=None, metadata={
'choices': list(MODEL_ARCH_MAPPING.keys()),
})
num_new_blocks: int = None
num_groups: Optional[int] = None
def __post_init__(self):
from .mapping import SwiftTuners
self.swift_type = SwiftTuners.LLAMAPRO
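# A minimal usage sketch for the config above. This is hedged: it assumes the generic
# `Swift.prepare_model` entry point from `swift` and uses an illustrative `model_type`;
# adapt both to your actual setup.
#
#     from swift import Swift
#     config = LLaMAProConfig(model_type='llama', num_new_blocks=4)
#     model = Swift.prepare_model(model, config)  # inserts and zero-initializes the new blocks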
class LLaMAPro(SwiftAdapter):
@staticmethod
def prepare_model(model: nn.Module, config: LLaMAProConfig, adapter_name: str) -> SwiftOutput:
"""Prepare a model with `LLaMAProConfig`"""
num_hidden_layers = HfConfigFactory.get_config_attr(model.config, 'num_hidden_layers')
if num_hidden_layers is None:
num_hidden_layers = HfConfigFactory.get_config_attr(model.config, 'num_layers')
        assert num_hidden_layers is not None, 'Cannot find the number of layers in the model config'
        assert num_hidden_layers % config.num_new_blocks == 0, f'The number of model layers {num_hidden_layers} ' \
                                                               f'should be divisible by {config.num_new_blocks}'
if config.num_groups is None:
config.num_groups = config.num_new_blocks
        # The except block below may change config.model_type, which would cause a `model not found`
        # error when using internvl, so keep the original value and restore it before returning.
origin_model_type = config.model_type
model_type = origin_model_type
num_stride = num_hidden_layers // config.num_groups
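        # One new block is inserted after every `num_stride` original layers.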
try:
module_list = LLaMAPro._find_module_list(config, model)
except AssertionError as e:
model_type = LLaMAPro.search_correct_model_type(model)
if model_type is None:
language_model_name = SwiftAdapter.get_model_key_mapping(config.model_type, config).language_model
if language_model_name:
if isinstance(language_model_name, str):
language_model_name = [language_model_name]
language_model = model.get_submodule(language_model_name[0])
model_type = LLaMAPro.search_correct_model_type(language_model)
if model_type:
model = language_model
if model_type:
config.model_type = model_type
module_list = LLaMAPro._find_module_list(config, model)
else:
raise e
new_module_list = nn.ModuleList()
new_module_idx = []
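        # Copy the preceding block at every stride boundary, mark the copy as a plugin module so
        # ActivationMixin can track it, and record its index for zero-init and state-dict filtering.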
for idx, module in enumerate(module_list):
new_module_list.append(module)
if (idx + 1) % num_stride == 0:
new_module = deepcopy(module)
ActivationMixin.mark_all_sub_modules_as_plugin(new_module)
new_module_list.append(new_module)
new_module_idx.append(idx + 1 + len(new_module_idx))
LLaMAPro._update_module_weight(config, new_module_list, new_module_idx)
LLaMAPro._update_module_attr(config, new_module_list)
model.config.num_hidden_layers = len(new_module_list)
LLaMAPro._set_module_list(config, model, new_module_list)
def activate_module(activate: bool):
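            # Swap between the expanded module list (adapter active) and the original one (inactive).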
if activate:
LLaMAPro._update_module_attr(config, new_module_list)
LLaMAPro._set_module_list(config, model, new_module_list)
else:
LLaMAPro._update_module_attr(config, module_list)
LLaMAPro._set_module_list(config, model, module_list)
def state_dict_callback(state_dict, adapter_name, **kwargs):
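            # Keep only the parameters of the newly inserted blocks in the saved state dict.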
model_key_mapping = LLaMAPro.get_model_key_mapping(model_type, config)
new_module_list = [model_key_mapping.module_list + f'.{i}' for i in new_module_idx]
return {
key: value
for key, value in state_dict.items() if any([m_part in key for m_part in new_module_list])
}
def mark_trainable_callback(model):
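            # Enable gradients only for the parameters of the newly inserted blocks.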
model_key_mapping = LLaMAPro.get_model_key_mapping(model_type, config)
new_module_list = [model_key_mapping.module_list + f'.{i}' for i in new_module_idx]
for name, parameter in model.named_parameters():
parameter: nn.Parameter
if any([m_part in name for m_part in new_module_list]):
parameter.requires_grad = True
config.model_type = origin_model_type
model.activate_module = activate_module
return SwiftOutput(
config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)
@staticmethod
def _update_module_attr(config: LLaMAProConfig, module_list):
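        # Different model families store the per-layer index under different attribute names
        # (layer_idx / layer_number / block_idx); rewrite it on every block so attention caching
        # stays consistent after new blocks are inserted.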
model_type = config.model_type
model_key_mapping = LLaMAPro.get_model_key_mapping(model_type, config)
attention = model_key_mapping.attention
attention = attention.split('{}.')[1]
if model_type == 'phi3-small':
raise ValueError('phi3-small does not support llamapro currently')
if model_type in ('llama', 'mistral', 'qwen2', 'yi', 'gemma', 'deepseek', 'openbuddy', 'xverse', 'orion',
'bluelm', 'ziya', 'skywork', 'deepseek-v2', 'minicpm', 'phi3', 'internlm2'):
for idx, module in enumerate(module_list):
try:
getattr(module, attention).layer_idx = idx
except AttributeError:
getattr(module, 'cross_attn').layer_idx = idx
elif model_type in ('chatglm', 'glm4'):
for idx, module in enumerate(module_list):
getattr(module, attention).layer_number = idx
elif model_type in ('phi2', ):
for idx, module in enumerate(module_list):
getattr(module, attention).block_idx = idx
else:
for idx, module in enumerate(module_list):
attrs = [
attr for attr in dir(getattr(module_list[0], attention))
if attr in ('layer_idx', 'layer_number', 'block_idx')
]
assert len(attrs) <= 1
if attrs:
setattr(getattr(module, attention), attrs[0], idx)
else:
                    logger.warning(f'model_type: {model_type} seems to have no layer_idx attribute. '
                                   f'If you encounter anything wrong, please give us feedback.')
@classmethod
def get_model_key_mapping(cls, model_type, config) -> ModelKeys:
model_key_mapping = SwiftAdapter.get_model_key_mapping(model_type, config)
        assert model_key_mapping.o_proj is not None and model_key_mapping.down_proj is not None, \
            'LLaMAPro only supports models with o_proj and down_proj components.'
return model_key_mapping
@classmethod
def search_correct_model_type(cls, module: nn.Module):
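        # Probe each known LLM architecture: return the first one whose declared submodule paths
        # (with '{}' replaced by layer 0) all resolve on `module`.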
for arch_name, arch_type in MODEL_ARCH_MAPPING.items():
arch_type: ModelKeys
if getattr(arch_type, 'module_list') is None:
# Need to be a LLM arch
continue
matched = True
for f in fields(arch_type):
arch_str = getattr(arch_type, f.name)
if f.name == 'arch_name' or arch_str is None:
continue
arch_str = arch_str.replace('{}', '0')
try:
sub_module = module.get_submodule(arch_str)
if sub_module is None:
matched = False
except AttributeError:
matched = False
if not matched:
break
if matched:
return arch_name
@staticmethod
def _update_module_weight(config: LLaMAProConfig, module_list, new_module_idx):
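        # Zero-initialize o_proj and down_proj of the newly inserted blocks so each copy initially
        # contributes nothing through its residual connections, i.e. the expanded model starts out
        # functionally equivalent to the original (the LLaMA Pro block-expansion trick).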
model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
o_proj = model_key_mapping.o_proj.split('{}.')[1]
down_proj = model_key_mapping.down_proj.split('{}.')[1]
for idx, module in enumerate(module_list):
if idx not in new_module_idx:
continue
_o_proj: nn.Linear = module.get_submodule(o_proj)
_down_proj: nn.Linear = module.get_submodule(down_proj)
_o_proj.weight.data = torch.zeros_like(_o_proj.weight.data)
_down_proj.weight.data = torch.zeros_like(_down_proj.weight.data)
if hasattr(_o_proj, 'bias') and _o_proj.bias is not None:
_o_proj.bias.data = torch.zeros_like(_o_proj.bias)
if hasattr(_down_proj, 'bias') and _down_proj.bias is not None:
_down_proj.bias.data = torch.zeros_like(_down_proj.bias)
@staticmethod
def _set_module_list(config, module: nn.Module, module_list: nn.ModuleList):
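        # Replace the original block list on its parent module with the given one.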
model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
idx = model_key_mapping.module_list.rfind('.')
parent = module.get_submodule(model_key_mapping.module_list[:idx])
setattr(parent, model_key_mapping.module_list[idx + 1:], module_list)
@staticmethod
def _find_module_list(config, module: nn.Module) -> nn.ModuleList:
model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
return module.get_submodule(model_key_mapping.module_list)
@staticmethod
def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
module.activate_module(activate)
@staticmethod
def has_additional_modules():
return True