import re
from copy import deepcopy
from dataclasses import dataclass
from types import MethodType
from typing import Dict, Optional

import torch
from torch import nn

from swift.utils import get_logger
from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput

logger = get_logger()

@dataclass
class PartConfig(SwiftConfig):
    """
    Freeze the model and train only a selected part of it.

    Args:
        target_modules (`Optional[str]`): A regex pattern matched against the
            full dotted names of the model's modules; every module whose name
            fully matches the pattern is trained.
    """

    target_modules: Optional[str] = None

    def __post_init__(self):
        from .mapping import SwiftTuners
        self.swift_type = SwiftTuners.PART
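
# A minimal usage sketch (assuming the package's usual `Swift.prepare_model`
# entry point; the regex below is illustrative):
#
#     from swift import Swift
#     model = Swift.prepare_model(model, PartConfig(target_modules='.*mlp.*'))
#
# Parameters of matched modules are trained; everything else stays frozen.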

class Part(SwiftAdapter):

    @staticmethod
    def target_module_matched(module_key: str, config: PartConfig):
        return re.fullmatch(config.target_modules, module_key)

    @staticmethod
    def prepare_model(model: nn.Module, config: PartConfig, adapter_name: str):
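        # Wrap every module matched by `config.target_modules`: install a
        # dispatching forward on the original module and attach a trainable
        # deep copy of it under the attribute `_part_<adapter_name>`.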
        name_list = [name for name, _ in model.named_modules(remove_duplicate=False)]
        for name in name_list:
            module: nn.Module = model.get_submodule(name)
            if Part.target_module_matched(name, config) and not getattr(module, 'plugin', False):
                # Tuners such as LoRA keep the original layer in `base_layer`;
                # wrap that inner layer instead.
                if hasattr(module, 'base_layer'):
                    module = module.base_layer
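
                # Dispatch to whichever `_part_*` copy is currently activated,
                # falling back to the frozen original forward when none is.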
                def _forward(self, *args, **kwargs):
                    child_list = [
                        sub_module for name, sub_module in self.named_modules(remove_duplicate=False)
                        if '_part_' in name
                    ]
                    sub_modules = [child for child in child_list if getattr(child, 'activated', False)]
                    assert len(sub_modules) <= 1, 'At most one Part adapter may be activated on a module.'
                    if len(sub_modules) == 1:
                        return sub_modules[0].forward(*args, **kwargs)
                    else:
                        return self.forward_origin(*args, **kwargs)
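
                # Install the dispatcher only once per module, keeping a handle
                # to the original forward for pass-through calls.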
                if not hasattr(module, 'forward_origin'):
                    module.forward_origin = module.forward
                    module.forward = MethodType(_forward, module)
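
                # The adapter's trainable weights are a deep copy of the module;
                # drop any `_part_*` attributes inherited from adapters prepared
                # earlier so copies do not nest.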
                new_module = deepcopy(module)
                for attr in dir(new_module):
                    if '_part_' in attr:
                        delattr(new_module, attr)
                new_module.part_name = adapter_name
                ActivationMixin.mark_all_sub_modules_as_plugin(new_module)
                setattr(module, f'_part_{adapter_name}', new_module)
                new_module.requires_grad_(True)
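
        # Save only the adapter copy's weights; by default the `_part_` prefix
        # is stripped so that saved keys line up with the original module names.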
        def state_dict_callback(state_dict, adapter_name, **kwargs):
            new_state_dict = {}
            for key, value in state_dict.items():
                if f'_part_{adapter_name}.' in key:
                    if kwargs.get('replace_key', True):
                        new_key = key.replace(f'_part_{adapter_name}.', '').replace('base_layer.', '')
                    else:
                        new_key = key
                    new_state_dict[new_key] = value
            return new_state_dict

        def mark_trainable_callback(model: nn.Module):
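            # Nothing extra to mark: `requires_grad_(True)` was already set on
            # each `_part_*` copy in `prepare_model`.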
            pass
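
        # On load, rewrite original-style keys so they land on the adapter copy
        # `_part_<adapter_name>` of each matched module; keys that already
        # contain `_part_` pass through unchanged.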
        def load_state_dict_callback(model: nn.Module, adapter_name: str, state_dict: Dict[str, torch.Tensor]):
            new_state_dict = {}
            for name, module in model.named_modules(remove_duplicate=False):
                module: nn.Module
                if Part.target_module_matched(name, config):
                    for param_name in state_dict:
                        # Require the separating dot so that e.g. `layers.1`
                        # does not also capture parameters of `layers.10`.
                        if param_name.startswith(name + '.'):
                            end = param_name[len(name):]
                            if '_part_' not in param_name:
                                if hasattr(module, 'base_layer'):
                                    new_state_dict[name + f'.base_layer._part_{adapter_name}'
                                                   + end] = state_dict[param_name]
                                else:
                                    new_state_dict[name + f'._part_{adapter_name}' + end] = state_dict[param_name]
                            else:
                                new_state_dict[param_name] = state_dict[param_name]
            return new_state_dict

        return SwiftOutput(
            config=config,
            state_dict_callback=state_dict_callback,
            mark_trainable_callback=mark_trainable_callback,
            load_state_dict_callback=load_state_dict_callback)

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: Optional[str] = None):
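        # Toggle the `activated` flag read by the dispatching forward; the
        # deactivated copy's weights may be offloaded to save memory.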
        name_list = [name for name, _ in module.named_modules(remove_duplicate=False)]
        for name in name_list:
            sub_module: nn.Module = module.get_submodule(name)
            if re.fullmatch(f'.*_part_{adapter_name}$', name):
                sub_module.activated = activate
                SwiftAdapter.save_memory(sub_module, adapter_name, name, activate, offload)