# Copyright (c) Alibaba, Inc. and its affiliates.
import re
from copy import deepcopy
from dataclasses import dataclass
from types import MethodType
from typing import Dict, Optional

import torch
from torch import nn

from swift.utils import get_logger
from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput

logger = get_logger()


@dataclass
class PartConfig(SwiftConfig):
"""
    Freeze the model and train only a part of it.

    Args:
        target_modules(`Optional[str]`): A regex pattern; modules whose names fully match it are trained.
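
    Example (illustrative sketch; the regex and model below are assumptions,
    not part of this module):
        >>> from swift import Swift
        >>> config = PartConfig(target_modules=r'.*model\.layers\.0\..*')
        >>> model = Swift.prepare_model(model, config)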
"""
    target_modules: Optional[str] = None

    def __post_init__(self):
        from .mapping import SwiftTuners
        self.swift_type = SwiftTuners.PART


class Part(SwiftAdapter):
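    """Tuner that trains full copies of the matched modules while the rest of
    the model stays frozen. For every module whose name fully matches
    ``config.target_modules``, a trainable deep copy is attached as a
    ``_part_{adapter_name}`` child, and the module's forward is patched to
    dispatch to whichever copy is currently activated.
    """
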
@staticmethod
def target_module_matched(module_key: str, config: PartConfig):
        return re.fullmatch(config.target_modules, module_key)

@staticmethod
def prepare_model(model: nn.Module, config: PartConfig, adapter_name: str):
name_list = [name for name, _ in model.named_modules(remove_duplicate=False)]
for name in name_list:
module: nn.Module = model.get_submodule(name)
if Part.target_module_matched(name, config) and not getattr(module, 'plugin', False):
if hasattr(module, 'base_layer'):
module = module.base_layer
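
                # Patched forward: route the call to the single activated
                # `_part_` copy; fall back to the original forward when no
                # copy is active. At most one copy may be active at a time.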
def _forward(self, *args, **kwargs):
child_list = [
sub_module for name, sub_module in self.named_modules(remove_duplicate=False)
if '_part_' in name
]
sub_modules = [child for child in child_list if getattr(child, 'activated', False)]
                    assert len(sub_modules) <= 1, 'At most one part adapter can be activated on a module.'
if len(sub_modules) == 1:
return sub_modules[0].forward(*args, **kwargs)
else:
return self.forward_origin(*args, **kwargs)
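
                # Patch the forward only once per module; additional adapters
                # attached later reuse the same dispatcher.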
if not hasattr(module, 'forward_origin'):
module.forward_origin = module.forward
module.forward = MethodType(_forward, module)
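
                # The trainable copy is a deep copy of the frozen original.
                # Drop any `_part_` attributes inherited from previously
                # attached adapters so copies do not nest recursively, then
                # register the copy as a plugin child of the original module.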
new_module = deepcopy(module)
for attr in dir(new_module):
if '_part_' in attr:
delattr(new_module, attr)
new_module.part_name = adapter_name
ActivationMixin.mark_all_sub_modules_as_plugin(new_module)
setattr(module, f'_part_{adapter_name}', new_module)
new_module.requires_grad_(True)
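
        # Saving: keep only this adapter's parameters. With replace_key=True
        # the `_part_{adapter_name}.` (and any `base_layer.`) prefix is
        # stripped so the checkpoint keys mirror the original model's keys.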
def state_dict_callback(state_dict, adapter_name, **kwargs):
new_state_dict = {}
for key, value in state_dict.items():
if f'_part_{adapter_name}.' in key:
if kwargs.get('replace_key', True):
new_key = key.replace(f'_part_{adapter_name}.', '').replace('base_layer.', '')
else:
new_key = key
new_state_dict[new_key] = value
return new_state_dict
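
        # No extra marking is needed: requires_grad_(True) was already set on
        # each `_part_` copy in the loop above.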
def mark_trainable_callback(model: nn.Module):
pass
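
        # Loading: checkpoint keys that were saved with replace_key=True look
        # like keys of the original model, so remap them onto the adapter's
        # `_part_{adapter_name}` copy (through `base_layer` for wrapped
        # modules); keys already containing `_part_` pass through unchanged.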
def load_state_dict_callback(model: nn.Module, adapter_name: str, state_dict: Dict[str, torch.Tensor]):
new_state_dict = {}
for name, module in model.named_modules(remove_duplicate=False):
module: nn.Module
if Part.target_module_matched(name, config):
for param_name in state_dict:
if param_name.startswith(name):
end = param_name[len(name):]
if '_part_' not in param_name:
if hasattr(module, 'base_layer'):
new_state_dict[name + f'.base_layer._part_{adapter_name}'
+ end] = state_dict[param_name]
else:
new_state_dict[name + f'._part_{adapter_name}' + end] = state_dict[param_name]
else:
new_state_dict[param_name] = state_dict[param_name]
return new_state_dict
return SwiftOutput(
config=config,
state_dict_callback=state_dict_callback,
mark_trainable_callback=mark_trainable_callback,
load_state_dict_callback=load_state_dict_callback)
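
    # Activation flips the `activated` flag that the patched forward checks;
    # `save_memory` may additionally offload an inactive copy according to
    # `offload` (passed through unchanged to SwiftAdapter.save_memory).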
@staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: Optional[str] = None):
name_list = [name for name, _ in module.named_modules(remove_duplicate=False)]
for name in name_list:
sub_module: nn.Module = module.get_submodule(name)
if re.fullmatch(f'.*_part_{adapter_name}$', name):
sub_module.activated = activate
SwiftAdapter.save_memory(sub_module, adapter_name, name, activate, offload)