import re
import types
from dataclasses import dataclass, field
from typing import List, Optional, Union

import torch
from torch import nn

from swift.tuners.utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput
from swift.utils import get_logger
from swift.utils.torch_utils import find_sub_module
from .scetuning_components import probe_output_hook

logger = get_logger()


@dataclass
class SCETuningConfig(SwiftConfig):
    """
    The configuration class for the SCEdit module.

    'SCEdit: Efficient and Controllable Image Diffusion Generation via Skip Connection Editing' by Jiang et al. (2023)
    See https://arxiv.org/abs/2312.11392

    Args:
        dims(`Union[List[int], int]`): The dimensions of the hidden states.
        target_modules(`Union[List[str], str]`): The target modules to be replaced; can be a regex string
            or a list of module names in full-match format.
        hint_modules(`Union[List[str], str]`): The hint modules whose outputs are probed as extra inputs;
            can be a regex string or a list of module names in full-match format.
        tuner_mode(`str`): Location of the tuner operation; one of 'encoder', 'decoder' or 'identity'.
        tuner_op(`str`): The tuner operation; currently only 'SCEAdapter' is supported.
        down_ratio(`float`): The down-projection ratio of the tuner hidden state.
    """

    dims: Optional[Union[List[int], int]] = field(
        default=None, metadata={'help': 'The dimensions of the hidden states'})

    target_modules: Optional[Union[List[str], str]] = field(
        default=None,
        metadata={'help': 'The target modules to be replaced, can be a regex string or name list of full match format'})

    hint_modules: Optional[Union[List[str], str]] = field(
        default=None,
        metadata={'help': 'The hint modules to be probed, can be a regex string or name list of full match format'})

    tuner_mode: str = field(
        default='decoder',
        metadata={'help': 'Location of the tuner operation. The tuner mode choices: encoder, decoder, and identity'})

    tuner_op: str = field(default='SCEAdapter', metadata={'help': 'The tuner op choices: SCEAdapter'})

    down_ratio: float = field(default=1.0, metadata={'help': 'The dim down ratio of the tuner hidden state'})

    def __post_init__(self):
        from swift.tuners.mapping import SwiftTuners
        self.swift_type = SwiftTuners.SCETUNING
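

# A minimal usage sketch (assumption: the model and module names below are
# hypothetical; match `dims` and the regex against your own model's
# `named_modules()` and hidden sizes):
#
#     config = SCETuningConfig(
#         dims=[320, 640, 1280],
#         target_modules=r'.*decoder\.blocks\.[0-2]',
#         tuner_mode='decoder',
#         down_ratio=1.0)
#     output = SCETuning.prepare_model(model, config, adapter_name='default')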


class SCETuning(SwiftAdapter):

    @staticmethod
    def prepare_model(model: nn.Module, config: SCETuningConfig, adapter_name: str) -> SwiftOutput:
        """Prepare a model with `SCETuningConfig`"""
        module_keys = [key for key, _ in model.named_modules()]

        # Hook the hint modules so that their most recent forward outputs can be
        # fetched later by the tuners (via `probe_output_hook`).
        hint_module_ins_list = []
        if config.hint_modules:
            if isinstance(config.hint_modules, list):
                for module_key in config.hint_modules:
                    assert module_key in module_keys, f'Cannot find module: {module_key}'
                    h_module = model.get_submodule(module_key)
                    logger.info(f'Matched hint module [{module_key}] of type {type(h_module)}')
                    if isinstance(h_module, (nn.ModuleList, nn.ModuleDict)):
                        logger.warning(f'Modules of type {type(h_module)} may not be supported '
                                       'because of their customized forward')
                    h_module.register_forward_hook(probe_output_hook, with_kwargs=True)
                    hint_module_ins_list.append(h_module)
            else:
                for module_key in module_keys:
                    if re.fullmatch(config.hint_modules, module_key):
                        h_module = model.get_submodule(module_key)
                        logger.info(f'Matched hint module [{module_key}] of type {type(h_module)}')
                        if isinstance(h_module, (nn.ModuleList, nn.ModuleDict)):
                            logger.warning(f'Modules of type {type(h_module)} may not be supported '
                                           'because of their customized forward')
                        h_module.register_forward_hook(probe_output_hook, with_kwargs=True)
                        hint_module_ins_list.append(h_module)
            if len(hint_module_ins_list) == 0:
                logger.error('Cannot match any hint module')

        def _get_module(module):
            # When the matched module is an nn.ModuleList, descend into its last
            # entry recursively and return the innermost concrete module.
            if isinstance(module, nn.ModuleList):
                module = module[-1]
                return _get_module(module)
            return module


        target_module_ins_list = []
        assert config.target_modules is not None
        if isinstance(config.target_modules, list):
            for module_key in config.target_modules:
                assert module_key in module_keys, f'Cannot find module: {module_key}'
                t_module = model.get_submodule(module_key)
                logger.info(f'Matched target module [{module_key}] of type {type(t_module)}')
                target_module_ins_list.append(_get_module(t_module))
        else:
            for module_key in module_keys:
                if re.fullmatch(config.target_modules, module_key):
                    t_module = model.get_submodule(module_key)
                    logger.info(f'Matched target module [{module_key}] of type {type(t_module)}')
                    target_module_ins_list.append(_get_module(t_module))
        if len(target_module_ins_list) == 0:
            logger.error('Cannot match any target module')
        if len(hint_module_ins_list) > 0 and len(hint_module_ins_list) != len(target_module_ins_list):
            logger.error('The number of hint modules must equal the number of target modules.')
            assert len(hint_module_ins_list) == len(target_module_ins_list)
        # A single int is broadcast to every target; a list must align one-to-one.
        if isinstance(config.dims, int):
            dims = [config.dims for _ in target_module_ins_list]
        else:
            assert len(config.dims) == len(target_module_ins_list)
            dims = config.dims
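

        # The two wrappers below implement the injection modes. In 'encoder' (and
        # 'identity') mode the tuner post-processes the wrapped module's output; in
        # 'decoder' mode it edits the skip-connection slice of the input before the
        # original forward runs.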
        def _forward_encoder_mode(self, *args, **kwargs):
            # Run the original forward first, then pass its (main) output through the tuner.
            args = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs)
            args_type = type(args)
            if args_type is tuple:
                args = args[0]
            if hasattr(self, 'hint'):
                # The probe hook stored the hint module's latest output here.
                hint_out = self.hint.probe_output_data
                args_main = getattr(self, f'scetuner_{adapter_name}')(args, hint_out)
            else:
                args_main = getattr(self, f'scetuner_{adapter_name}')(args)
            if args_type is tuple:
                args_main = (args_main, )
            return args_main

        def _forward_decoder_mode(self, *args, **kwargs):
            # `args` is always a tuple because of the *args packing; the check mirrors
            # _forward_encoder_mode and keeps the handling symmetric.
            args_type = type(args)
            if args_type is tuple:
                args_sub_tuner = args[0]
                args_sub_extra = args[1:]
                tuner_module = getattr(self, f'scetuner_{adapter_name}')
                # Split the channel dim: the trailing `tuner_module.dim` channels are the
                # skip-connection features that the tuner edits.
                args_hidden, args_res = torch.split(
                    args_sub_tuner, [args_sub_tuner.shape[1] - tuner_module.dim, tuner_module.dim], dim=1)
                if hasattr(self, 'hint'):
                    hint_out = self.hint.probe_output_data
                    args_res_new = tuner_module(args_res, hint_out)
                else:
                    args_res_new = tuner_module(args_res)
                args_sub_tuner_new = torch.cat([args_hidden, args_res_new], dim=1)
                args_main = (args_sub_tuner_new, *args_sub_extra)

            args_main = getattr(self, f'forward_origin_{adapter_name}')(*args_main, **kwargs)
            return args_main
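

        # Patch each matched target module in place: keep the original forward under
        # a renamed attribute, bind the mode-specific wrapper as the new forward, and
        # attach the tuner op (and its hint module, if any) as attributes.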
        for tuner_id, t_module in enumerate(target_module_ins_list):
            setattr(t_module, f'forward_origin_{adapter_name}', getattr(t_module, 'forward'))
            if config.tuner_mode in ('encoder', 'identity'):
                _forward = _forward_encoder_mode
            elif config.tuner_mode == 'decoder':
                _forward = _forward_decoder_mode
            else:
                raise ValueError(f'Unsupported tuner_mode: {config.tuner_mode}')
            setattr(t_module, 'forward', types.MethodType(_forward, t_module))
            tuner_op = SCETunerModule(
                name=config.tuner_op,
                adapter_name=adapter_name,
                module_key=str(tuner_id),
                dim=dims[tuner_id],
                tuner_length=int(dims[tuner_id] * config.down_ratio))
            setattr(t_module, f'scetuner_{adapter_name}', tuner_op)
            if len(hint_module_ins_list) > 0:
                setattr(t_module, 'hint', hint_module_ins_list[tuner_id])

        def state_dict_callback(state_dict, adapter_name, **kwargs):
            # Persist only the tuner parameters, not the frozen base model.
            state_dict_new = {key: value for key, value in state_dict.items() if f'scetuner_{adapter_name}' in key}
            return state_dict_new

        def mark_trainable_callback(model):
            # All trainable parameters already live in the tuner modules.
            return

        return SwiftOutput(
            config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
        modules = find_sub_module(module, f'scetuner_{adapter_name}')
        for _module in modules:
            # Each match is an SCETunerModule, i.e. both an nn.Module and an
            # ActivationMixin; toggle it and optionally offload its weights.
            _module: ActivationMixin
            _module.set_activation(adapter_name, activate)
            SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload)
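

# A toggling sketch (assumption: 'default' is the adapter name that was passed to
# `prepare_model` above):
#
#     SCETuning.activate_adapter(model, adapter_name='default', activate=False)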


class SCETunerModule(nn.Module, ActivationMixin):

    def __init__(self,
                 name,
                 adapter_name,
                 module_key,
                 dim,
                 tuner_length,
                 tuner_type=None,
                 tuner_weight=None,
                 act_layer=nn.GELU,
                 zero_init_last=True,
                 use_bias=True):
        # Initialize both bases explicitly: nn.Module through the normal MRO call,
        # then ActivationMixin (next after nn.Module in the MRO) with the module key.
        super(SCETunerModule, self).__init__()
        super(nn.Module, self).__init__(module_key)
        self.name = name
        self.adapter_name = adapter_name
        self.dim = dim
        if name == 'SCEAdapter':
            from .scetuning_components import SCEAdapter
            self.tuner_op = SCEAdapter(
                dim=dim,
                adapter_length=tuner_length,
                adapter_type=tuner_type,
                adapter_weight=tuner_weight,
                act_layer=act_layer)
        else:
            raise ValueError(f'Unsupported tuner op: {name}')
        self.mark_all_sub_modules_as_plugin()

    def forward(self, x, x_shortcut=None, use_shortcut=True, **kwargs):
        # When the adapter is deactivated, act as an identity mapping.
        if not self.is_activated(self.adapter_name):
            return x
        if self.name == 'SCEAdapter':
            self.tuner_op.to(x.device)
            out = self.tuner_op(x)
        else:
            raise ValueError(f'Unsupported tuner op: {self.name}')
        return out
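

# A shape sketch (assumption: SCEAdapter is a shape-preserving residual adapter;
# check `scetuning_components.SCEAdapter` before relying on this):
#
#     tuner = SCETunerModule(
#         name='SCEAdapter', adapter_name='default', module_key='0',
#         dim=320, tuner_length=160)
#     y = tuner(torch.randn(2, 320, 8, 8))  # expected to match the input shape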