# Copyright (c) Alibaba, Inc. and its affiliates.
from copy import deepcopy
from dataclasses import dataclass, field, fields
from typing import Optional

import torch
from torch import nn

from swift.llm import MODEL_ARCH_MAPPING, HfConfigFactory, ModelKeys
from swift.utils.logger import get_logger
from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput

logger = get_logger()


@dataclass
class LLaMAProConfig(SwiftConfig):
    """
    The configuration class for the LLaMAPro module.

    See https://arxiv.org/abs/2401.02415

    Args:
        model_type(`str`): LLaMAPro only supports a subset of LLM model types, because some
            model-specific attributes need to be modified manually.
        num_new_blocks(`int`): The number of new blocks to add.
        num_groups(`int`): The number of groups the new blocks are split into. Defaults to
            `num_new_blocks`, which means a single new block is inserted after every
            `num_hidden_layers/num_new_blocks` original layers.
    """
    model_type: str = field(
        default=None, metadata={
            'choices': list(MODEL_ARCH_MAPPING.keys()),
        })

    num_new_blocks: int = None

    num_groups: Optional[int] = None

    def __post_init__(self):
        from .mapping import SwiftTuners
        self.swift_type = SwiftTuners.LLAMAPRO


class LLaMAPro(SwiftAdapter):

    @staticmethod
    def prepare_model(model: nn.Module, config: LLaMAProConfig, adapter_name: str) -> SwiftOutput:
        """Prepare a model with `LLaMAProConfig`"""
        num_hidden_layers = HfConfigFactory.get_config_attr(model.config, 'num_hidden_layers')
        if num_hidden_layers is None:
            num_hidden_layers = HfConfigFactory.get_config_attr(model.config, 'num_layers')

        assert num_hidden_layers is not None, 'Cannot find the number of layers in the model config'
        assert num_hidden_layers % config.num_new_blocks == 0, \
            f'The number of model layers {num_hidden_layers} should be divisible by {config.num_new_blocks}'
        if config.num_groups is None:
            config.num_groups = config.num_new_blocks

        # The except branch below may change config.model_type, which causes a `model not found` error
        # when using internvl, so the original model_type is restored before returning.
        origin_model_type = config.model_type
        model_type = origin_model_type
        num_stride = num_hidden_layers // config.num_groups

        try:
            module_list = LLaMAPro._find_module_list(config, model)
        except AssertionError as e:
            model_type = LLaMAPro.search_correct_model_type(model)
            if model_type is None:
                language_model_name = SwiftAdapter.get_model_key_mapping(config.model_type, config).language_model
                if language_model_name:
                    if isinstance(language_model_name, str):
                        language_model_name = [language_model_name]
                    language_model = model.get_submodule(language_model_name[0])
                    model_type = LLaMAPro.search_correct_model_type(language_model)
                    if model_type:
                        model = language_model

            if model_type:
                config.model_type = model_type
                module_list = LLaMAPro._find_module_list(config, model)
            else:
                raise e

        new_module_list = nn.ModuleList()
        new_module_idx = []
        for idx, module in enumerate(module_list):
            new_module_list.append(module)
            if (idx + 1) % num_stride == 0:
                new_module = deepcopy(module)
                ActivationMixin.mark_all_sub_modules_as_plugin(new_module)
                new_module_list.append(new_module)
                new_module_idx.append(idx + 1 + len(new_module_idx))

        LLaMAPro._update_module_weight(config, new_module_list, new_module_idx)
        LLaMAPro._update_module_attr(config, new_module_list)
        model.config.num_hidden_layers = len(new_module_list)
        LLaMAPro._set_module_list(config, model, new_module_list)

        def activate_module(activate: bool):
            if activate:
                LLaMAPro._update_module_attr(config, new_module_list)
                LLaMAPro._set_module_list(config, model, new_module_list)
            else:
                LLaMAPro._update_module_attr(config, module_list)
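                # then restore the original (pre-expansion) block list on the model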
                LLaMAPro._set_module_list(config, model, module_list)

        def state_dict_callback(state_dict, adapter_name, **kwargs):
            model_key_mapping = LLaMAPro.get_model_key_mapping(model_type, config)
            new_module_list = [model_key_mapping.module_list + f'.{i}' for i in new_module_idx]
            return {
                key: value
                for key, value in state_dict.items() if any([m_part in key for m_part in new_module_list])
            }

        def mark_trainable_callback(model):
            model_key_mapping = LLaMAPro.get_model_key_mapping(model_type, config)
            new_module_list = [model_key_mapping.module_list + f'.{i}' for i in new_module_idx]
            for name, parameter in model.named_parameters():
                parameter: nn.Parameter
                if any([m_part in name for m_part in new_module_list]):
                    parameter.requires_grad = True

        config.model_type = origin_model_type
        model.activate_module = activate_module
        return SwiftOutput(
            config=config,
            state_dict_callback=state_dict_callback,
            mark_trainable_callback=mark_trainable_callback)

    @staticmethod
    def _update_module_attr(config: LLaMAProConfig, module_list):
        model_type = config.model_type
        model_key_mapping = LLaMAPro.get_model_key_mapping(model_type, config)
        attention = model_key_mapping.attention
        attention = attention.split('{}.')[1]
        if model_type == 'phi3-small':
            raise ValueError('phi3-small does not support llamapro currently')
        if model_type in ('llama', 'mistral', 'qwen2', 'yi', 'gemma', 'deepseek', 'openbuddy', 'xverse', 'orion',
                          'bluelm', 'ziya', 'skywork', 'deepseek-v2', 'minicpm', 'phi3', 'internlm2'):
            for idx, module in enumerate(module_list):
                try:
                    getattr(module, attention).layer_idx = idx
                except AttributeError:
                    getattr(module, 'cross_attn').layer_idx = idx
        elif model_type in ('chatglm', 'glm4'):
            for idx, module in enumerate(module_list):
                getattr(module, attention).layer_number = idx
        elif model_type in ('phi2', ):
            for idx, module in enumerate(module_list):
                getattr(module, attention).block_idx = idx
        else:
            for idx, module in enumerate(module_list):
                attrs = [
                    attr for attr in dir(getattr(module_list[0], attention))
                    if attr in ('layer_idx', 'layer_number', 'block_idx')
                ]
                assert len(attrs) <= 1
                if attrs:
                    setattr(getattr(module, attention), attrs[0], idx)
                else:
                    logger.warn(f'model_type: {model_type} seems to have no layer_idx; if you encounter any problem, '
                                'please give us feedback.')

    @classmethod
    def get_model_key_mapping(cls, model_type, config) -> ModelKeys:
        model_key_mapping = SwiftAdapter.get_model_key_mapping(model_type, config)
        assert model_key_mapping.o_proj is not None and model_key_mapping.down_proj is not None, \
            'LLaMAPro only supports models with o_proj and down_proj components.'
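        # Both projections are required because `_update_module_weight` zero-initializes them in the
        # copied blocks, which makes every newly inserted block an identity mapping at initialization.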
        return model_key_mapping

    @classmethod
    def search_correct_model_type(cls, module: nn.Module):
        for arch_name, arch_type in MODEL_ARCH_MAPPING.items():
            arch_type: ModelKeys
            if getattr(arch_type, 'module_list') is None:
                # Need to be a LLM arch
                continue

            matched = True
            for f in fields(arch_type):
                arch_str = getattr(arch_type, f.name)
                if f.name == 'arch_name' or arch_str is None:
                    continue
                arch_str = arch_str.replace('{}', '0')
                try:
                    sub_module = module.get_submodule(arch_str)
                    if sub_module is None:
                        matched = False
                except AttributeError:
                    matched = False
                if not matched:
                    break

            if matched:
                return arch_name

    @staticmethod
    def _update_module_weight(config: LLaMAProConfig, module_list, new_module_idx):
        model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
        o_proj = model_key_mapping.o_proj.split('{}.')[1]
        down_proj = model_key_mapping.down_proj.split('{}.')[1]

        for idx, module in enumerate(module_list):
            if idx not in new_module_idx:
                continue
            _o_proj: nn.Linear = module.get_submodule(o_proj)
            _down_proj: nn.Linear = module.get_submodule(down_proj)
            _o_proj.weight.data = torch.zeros_like(_o_proj.weight.data)
            _down_proj.weight.data = torch.zeros_like(_down_proj.weight.data)
            if hasattr(_o_proj, 'bias') and _o_proj.bias is not None:
                _o_proj.bias.data = torch.zeros_like(_o_proj.bias)
            if hasattr(_down_proj, 'bias') and _down_proj.bias is not None:
                _down_proj.bias.data = torch.zeros_like(_down_proj.bias)

    @staticmethod
    def _set_module_list(config, module: nn.Module, module_list: nn.ModuleList):
        model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
        idx = model_key_mapping.module_list.rfind('.')
        parent = module.get_submodule(model_key_mapping.module_list[:idx])
        setattr(parent, model_key_mapping.module_list[idx + 1:], module_list)

    @staticmethod
    def _find_module_list(config, module: nn.Module) -> nn.ModuleList:
        model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
        return module.get_submodule(model_key_mapping.module_list)

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
        module.activate_module(activate)

    @staticmethod
    def has_additional_modules():
        return True
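

# Example usage (a minimal sketch, not part of this module): expand a decoder-only LLM with
# LLaMA-Pro blocks through the Swift tuner entry point. The import path `swift.tuners`, the
# `Swift.prepare_model` call, and the placeholder model id are assumptions about the surrounding
# package; adapt them to your environment.
#
#     from transformers import AutoModelForCausalLM
#     from swift.tuners import LLaMAProConfig, Swift
#
#     model = AutoModelForCausalLM.from_pretrained('<llama-style-model-id>')  # placeholder id
#     config = LLaMAProConfig(model_type='llama', num_new_blocks=4)  # insert 4 copied blocks
#     model = Swift.prepare_model(model, config)
#     # The copied blocks start as identity mappings (zeroed o_proj/down_proj) and are the
#     # parameters marked trainable by mark_trainable_callback.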