from copy import deepcopy
from dataclasses import dataclass, field, fields
from typing import Optional

import torch
from torch import nn

from swift.llm import MODEL_ARCH_MAPPING, HfConfigFactory, ModelKeys
from swift.utils.logger import get_logger
from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput

logger = get_logger()


@dataclass
class LLaMAProConfig(SwiftConfig):
    """
    The configuration class for the LLaMAPro module.

    See https://arxiv.org/abs/2401.02415

    Args:
        model_type(`str`): LLaMAPro only supports a subset of LLM models, because some
            model-specific attributes need to be modified manually.
        num_new_blocks(`int`): The number of new blocks to add.
        num_groups(`int`): The number of groups the new blocks are split into. Defaults to `num_new_blocks`,
            which means a single new layer is inserted after every `num_hidden_layers/num_new_blocks`
            original layers.
    """
    model_type: str = field(
        default=None, metadata={
            'choices': list(MODEL_ARCH_MAPPING.keys()),
        })

    num_new_blocks: int = None

    num_groups: Optional[int] = None

    def __post_init__(self):
        from .mapping import SwiftTuners
        self.swift_type = SwiftTuners.LLAMAPRO
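
# A minimal usage sketch (assuming the package-level `Swift.prepare_model` entry point;
# the model id and block count below are illustrative, not part of this module):
#
#     from swift import Swift
#     from transformers import AutoModelForCausalLM
#
#     model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')
#     config = LLaMAProConfig(model_type='llama', num_new_blocks=4)
#     model = Swift.prepare_model(model, config)  # inserts 4 zero-initialized blocks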


class LLaMAPro(SwiftAdapter):

    @staticmethod
    def prepare_model(model: nn.Module, config: LLaMAProConfig, adapter_name: str) -> SwiftOutput:
        """Prepare a model with `LLaMAProConfig`"""
        num_hidden_layers = HfConfigFactory.get_config_attr(model.config, 'num_hidden_layers')
        if num_hidden_layers is None:
            num_hidden_layers = HfConfigFactory.get_config_attr(model.config, 'num_layers')
        assert num_hidden_layers is not None, 'Cannot find the number of layers in the model config'
        assert num_hidden_layers % config.num_new_blocks == 0, f'Model layers {num_hidden_layers} ' \
            f'should be divisible by {config.num_new_blocks}'
        if config.num_groups is None:
            config.num_groups = config.num_new_blocks

        origin_model_type = config.model_type
        model_type = origin_model_type
        num_stride = num_hidden_layers // config.num_groups
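
        # If the configured model_type does not match this model's module layout, probe the
        # known architectures; for wrapped models (e.g. multimodal ones) also try the inner
        # language model.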
        try:
            module_list = LLaMAPro._find_module_list(config, model)
        except AssertionError as e:
            model_type = LLaMAPro.search_correct_model_type(model)
            if model_type is None:
                language_model_name = SwiftAdapter.get_model_key_mapping(config.model_type, config).language_model
                if language_model_name:
                    if isinstance(language_model_name, str):
                        language_model_name = [language_model_name]
                    language_model = model.get_submodule(language_model_name[0])
                    model_type = LLaMAPro.search_correct_model_type(language_model)
                    if model_type:
                        model = language_model

            if model_type:
                config.model_type = model_type
                module_list = LLaMAPro._find_module_list(config, model)
            else:
                raise e
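
        # Interleave the copied blocks: after every `num_stride` original layers, append a
        # deepcopy of the current layer, mark its submodules as plugin modules, and record
        # the index the copy ends up at in the expanded list.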
        new_module_list = nn.ModuleList()
        new_module_idx = []
        for idx, module in enumerate(module_list):
            new_module_list.append(module)
            if (idx + 1) % num_stride == 0:
                new_module = deepcopy(module)
                ActivationMixin.mark_all_sub_modules_as_plugin(new_module)
                new_module_list.append(new_module)
                new_module_idx.append(idx + 1 + len(new_module_idx))

        LLaMAPro._update_module_weight(config, new_module_list, new_module_idx)
        LLaMAPro._update_module_attr(config, new_module_list)
        model.config.num_hidden_layers = len(new_module_list)
        LLaMAPro._set_module_list(config, model, new_module_list)
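
        # Toggling the adapter swaps between the expanded module list (new blocks active) and
        # the original one, re-numbering the layer attributes each time.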
        def activate_module(activate: bool):
            if activate:
                LLaMAPro._update_module_attr(config, new_module_list)
                LLaMAPro._set_module_list(config, model, new_module_list)
            else:
                LLaMAPro._update_module_attr(config, module_list)
                LLaMAPro._set_module_list(config, model, module_list)
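
        # The adapter's state dict keeps only the parameters whose names fall under the
        # newly inserted block indices.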
        def state_dict_callback(state_dict, adapter_name, **kwargs):
            model_key_mapping = LLaMAPro.get_model_key_mapping(model_type, config)
            new_module_list = [model_key_mapping.module_list + f'.{i}' for i in new_module_idx]
            return {
                key: value
                for key, value in state_dict.items() if any([m_part in key for m_part in new_module_list])
            }
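
        # Likewise, only parameters belonging to the newly inserted blocks are switched to
        # requires_grad=True.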
        def mark_trainable_callback(model):
            model_key_mapping = LLaMAPro.get_model_key_mapping(model_type, config)
            new_module_list = [model_key_mapping.module_list + f'.{i}' for i in new_module_idx]
            for name, parameter in model.named_parameters():
                parameter: nn.Parameter
                if any([m_part in name for m_part in new_module_list]):
                    parameter.requires_grad = True

        config.model_type = origin_model_type
        model.activate_module = activate_module
        return SwiftOutput(
            config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)

    @staticmethod
    def _update_module_attr(config: LLaMAProConfig, module_list):
        model_type = config.model_type
        model_key_mapping = LLaMAPro.get_model_key_mapping(model_type, config)
        attention = model_key_mapping.attention
        attention = attention.split('{}.')[1]
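
        # Attention modules cache their own layer index under model-specific attribute names
        # (layer_idx / layer_number / block_idx); after inserting blocks the indices are
        # renumbered so that per-layer bookkeeping such as the KV cache stays consistent.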
        if model_type == 'phi3-small':
            raise ValueError('phi3-small does not support LLaMAPro currently')
        if model_type in ('llama', 'mistral', 'qwen2', 'yi', 'gemma', 'deepseek', 'openbuddy', 'xverse', 'orion',
                          'bluelm', 'ziya', 'skywork', 'deepseek-v2', 'minicpm', 'phi3', 'internlm2'):
            for idx, module in enumerate(module_list):
                try:
                    getattr(module, attention).layer_idx = idx
                except AttributeError:
                    getattr(module, 'cross_attn').layer_idx = idx
        elif model_type in ('chatglm', 'glm4'):
            for idx, module in enumerate(module_list):
                getattr(module, attention).layer_number = idx
        elif model_type in ('phi2', ):
            for idx, module in enumerate(module_list):
                getattr(module, attention).block_idx = idx
        else:
            for idx, module in enumerate(module_list):
                attrs = [
                    attr for attr in dir(getattr(module_list[0], attention))
                    if attr in ('layer_idx', 'layer_number', 'block_idx')
                ]
                assert len(attrs) <= 1
                if attrs:
                    setattr(getattr(module, attention), attrs[0], idx)
                else:
                    logger.warning(f'model_type: {model_type} seems to have no layer_idx attribute; '
                                   f'if you encounter anything wrong, please give us feedback.')

    @classmethod
    def get_model_key_mapping(cls, model_type, config) -> ModelKeys:
        model_key_mapping = SwiftAdapter.get_model_key_mapping(model_type, config)
        assert model_key_mapping.o_proj is not None and model_key_mapping.down_proj is not None, \
            'LLaMAPro only supports models with o_proj and down_proj components.'
        return model_key_mapping

    @classmethod
    def search_correct_model_type(cls, module: nn.Module):
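        # Treat an architecture as matching when every mapped submodule path (with '{}'
        # instantiated as layer 0) resolves on the given module.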
        for arch_name, arch_type in MODEL_ARCH_MAPPING.items():
            arch_type: ModelKeys
            if getattr(arch_type, 'module_list') is None:
                continue

            matched = True
            for f in fields(arch_type):
                arch_str = getattr(arch_type, f.name)
                if f.name == 'arch_name' or arch_str is None:
                    continue

                arch_str = arch_str.replace('{}', '0')
                try:
                    sub_module = module.get_submodule(arch_str)
                    if sub_module is None:
                        matched = False
                except AttributeError:
                    matched = False

                if not matched:
                    break

            if matched:
                return arch_name

    @staticmethod
    def _update_module_weight(config: LLaMAProConfig, module_list, new_module_idx):
        model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
        o_proj = model_key_mapping.o_proj.split('{}.')[1]
        down_proj = model_key_mapping.down_proj.split('{}.')[1]
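
        # Zero-initialize o_proj and down_proj of each newly inserted block so that its
        # attention and MLP branches output zeros; with the residual connections the copied
        # block starts as an identity mapping, following the LLaMA-Pro initialization.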
        for idx, module in enumerate(module_list):
            if idx not in new_module_idx:
                continue
            _o_proj: nn.Linear = module.get_submodule(o_proj)
            _down_proj: nn.Linear = module.get_submodule(down_proj)
            _o_proj.weight.data = torch.zeros_like(_o_proj.weight.data)
            _down_proj.weight.data = torch.zeros_like(_down_proj.weight.data)
            if hasattr(_o_proj, 'bias') and _o_proj.bias is not None:
                _o_proj.bias.data = torch.zeros_like(_o_proj.bias)
            if hasattr(_down_proj, 'bias') and _down_proj.bias is not None:
                _down_proj.bias.data = torch.zeros_like(_down_proj.bias)

    @staticmethod
    def _set_module_list(config, module: nn.Module, module_list: nn.ModuleList):
        model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
        idx = model_key_mapping.module_list.rfind('.')
        parent = module.get_submodule(model_key_mapping.module_list[:idx])
        setattr(parent, model_key_mapping.module_list[idx + 1:], module_list)

    @staticmethod
    def _find_module_list(config, module: nn.Module) -> nn.ModuleList:
        model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
        return module.get_submodule(model_key_mapping.module_list)

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
        module.activate_module(activate)

    @staticmethod
    def has_additional_modules():
        return True
|