from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize

from ..utils.quant import QuantLinears, log_bypass, log_suspect


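# ModuleCustomSD wires up PyTorch's load_state_dict pre/post hooks and lets
# subclasses swap in a custom serialization format via `custom_state_dict()`.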
class ModuleCustomSD(nn.Module):
    """nn.Module with hook points for customized state_dict handling."""

    def __init__(self):
        super().__init__()
        # Subclasses override these hooks to rewrite keys on load.
        self._register_load_state_dict_pre_hook(self.load_weight_prehook)
        self.register_load_state_dict_post_hook(self.load_weight_hook)

    def load_weight_prehook(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        pass

    def load_weight_hook(self, module, incompatible_keys):
        pass

    def custom_state_dict(self):
        return None

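    # `state_dict` keeps the deprecated positional call signature of
    # nn.Module.state_dict for backward compatibility, then substitutes the
    # subclass-provided custom state dict when one exists.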
    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        # Fold the deprecated positional arguments
        # (destination, prefix, keep_vars) into the keyword form.
        if len(args) > 0:
            if destination is None:
                destination = args[0]
            if len(args) > 1 and prefix == "":
                prefix = args[1]
            if len(args) > 2 and keep_vars is False:
                keep_vars = args[2]

        if destination is None:
            destination = OrderedDict()
            destination._metadata = OrderedDict()

        local_metadata = dict(version=self._version)
        if hasattr(destination, "_metadata"):
            destination._metadata[prefix[:-1]] = local_metadata

        if (custom_sd := self.custom_state_dict()) is not None:
            for k, v in custom_sd.items():
                destination[f"{prefix}{k}"] = v
            return destination
        else:
            # The positional args were consumed above, so they must not be
            # forwarded again alongside the keyword arguments.
            return super().state_dict(
                destination=destination, prefix=prefix, keep_vars=keep_vars
            )


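# LycorisBaseModule is the common base for all LyCORIS algorithms: it records
# the wrapped module's functional form, handles bypass/quant detection, and
# defines the merge/extract API that concrete algorithms implement.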
class LycorisBaseModule(ModuleCustomSD):
    name: str
    dtype_tensor: torch.Tensor
    support_module = {}
    weight_list = []
    weight_list_det = []

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        dropout=0.0,
        rank_dropout=0.0,
        module_dropout=0.0,
        rank_dropout_scale=False,
        bypass_mode=None,
        **kwargs,
    ):
        """If alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__()
        self.lora_name = lora_name
        self.not_supported = False

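        # Record the functional op, weight shape, and forward kwargs of the
        # wrapped module so subclasses can rebuild its forward pass.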
        self.module = type(org_module)
        if isinstance(org_module, nn.Linear):
            self.module_type = "linear"
            self.shape = (org_module.out_features, org_module.in_features)
            self.op = F.linear
            self.dim = org_module.out_features
            self.kw_dict = {}
        elif isinstance(org_module, nn.Conv1d):
            self.module_type = "conv1d"
            self.shape = (
                org_module.out_channels,
                org_module.in_channels,
                *org_module.kernel_size,
            )
            self.op = F.conv1d
            self.dim = org_module.out_channels
            self.kw_dict = {
                "stride": org_module.stride,
                "padding": org_module.padding,
                "dilation": org_module.dilation,
                "groups": org_module.groups,
            }
        elif isinstance(org_module, nn.Conv2d):
            self.module_type = "conv2d"
            self.shape = (
                org_module.out_channels,
                org_module.in_channels,
                *org_module.kernel_size,
            )
            self.op = F.conv2d
            self.dim = org_module.out_channels
            self.kw_dict = {
                "stride": org_module.stride,
                "padding": org_module.padding,
                "dilation": org_module.dilation,
                "groups": org_module.groups,
            }
        elif isinstance(org_module, nn.Conv3d):
            self.module_type = "conv3d"
            self.shape = (
                org_module.out_channels,
                org_module.in_channels,
                *org_module.kernel_size,
            )
            self.op = F.conv3d
            self.dim = org_module.out_channels
            self.kw_dict = {
                "stride": org_module.stride,
                "padding": org_module.padding,
                "dilation": org_module.dilation,
                "groups": org_module.groups,
            }
        elif isinstance(org_module, nn.LayerNorm):
            self.module_type = "layernorm"
            self.shape = tuple(org_module.normalized_shape)
            self.op = F.layer_norm
            self.dim = org_module.normalized_shape[0]
            self.kw_dict = {
                "normalized_shape": org_module.normalized_shape,
                "eps": org_module.eps,
            }
        elif isinstance(org_module, nn.GroupNorm):
            self.module_type = "groupnorm"
            self.shape = (org_module.num_channels,)
            self.op = F.group_norm
            self.group_num = org_module.num_groups
            self.dim = org_module.num_channels
            self.kw_dict = {"num_groups": org_module.num_groups, "eps": org_module.eps}
        else:
            self.not_supported = True
            self.module_type = "unknown"

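        # Non-persistent buffer whose only job is to track this module's
        # current device and dtype (see the `dtype`/`device` properties).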
        self.register_buffer("dtype_tensor", torch.tensor(0.0), persistent=False)

        self.is_quant = False
        if isinstance(org_module, QuantLinears):
            if not bypass_mode:
                log_bypass()
            self.is_quant = True
            bypass_mode = True
        if (
            isinstance(org_module, nn.Linear)
            and org_module.__class__.__name__ != "Linear"
        ):
            # A Linear subclass with a different class name is likely a
            # quantized or otherwise patched layer, so default to bypass mode.
            if bypass_mode is None:
                log_suspect()
                bypass_mode = True
        if bypass_mode:
            self.is_quant = True
        self.bypass_mode = bypass_mode
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.rank_dropout_scale = rank_dropout_scale
        self.module_dropout = module_dropout

        self.drop = nn.Identity() if dropout == 0 else nn.Dropout(dropout)
        self.rank_drop = (
            nn.Identity() if rank_dropout == 0 else nn.Dropout(rank_dropout)
        )

        self.multiplier = multiplier
        self.org_forward = org_module.forward
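        # Hold the original module in a plain list so it is not registered as
        # a submodule, keeping its weights out of this module's parameters
        # and state_dict.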
        self.org_module = [org_module]

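    # `parametrize` attaches a LyCORIS module to an arbitrary parameter via
    # torch.nn.utils.parametrize. A proxy Linear/ConvNd is built around the
    # target tensor purely so the normal constructor path can inspect it.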
    @classmethod
    def parametrize(cls, org_module, attr, *args, **kwargs):
        from .full import FullModule

        if cls is FullModule:
            raise RuntimeError("FullModule cannot be used for parametrize.")
        target_param = getattr(org_module, attr)
        kwargs["bypass_mode"] = False
        if target_param.dim() == 2:
            # nn.Linear stores its weight as (out_features, in_features), so
            # in_features comes from shape[1] and out_features from shape[0].
            proxy_module = nn.Linear(
                target_param.shape[1], target_param.shape[0], bias=False
            )
            proxy_module.weight = target_param
        elif target_param.dim() > 2:
            # ConvNd stores its weight as (out_channels, in_channels, *kernel),
            # so an N-dim weight maps to Conv(N-2)d with kernel_size shape[2:].
            module_type = [
                None,
                None,
                None,
                nn.Conv1d,
                nn.Conv2d,
                nn.Conv3d,
                None,
                None,
            ][target_param.dim()]
            proxy_module = module_type(
                target_param.shape[1],
                target_param.shape[0],
                target_param.shape[2:],
                bias=False,
            )
            proxy_module.weight = target_param
        module_obj = cls("", proxy_module, *args, **kwargs)
        module_obj.forward = module_obj.parametrize_forward
        module_obj.to(target_param)
        parametrize.register_parametrization(org_module, attr, module_obj)
        return module_obj

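    # A minimal usage sketch for `parametrize` above (`SomeLycorisModule`
    # stands in for any concrete subclass of this base):
    #
    #     linear = nn.Linear(64, 64)
    #     delta = SomeLycorisModule.parametrize(linear, "weight")
    #     _ = linear.weight  # now recomputed through parametrize_forward
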
    @classmethod
    def algo_check(cls, state_dict, lora_name):
        return any(f"{lora_name}.{k}" in state_dict for k in cls.weight_list_det)

    @classmethod
    def extract_state_dict(cls, state_dict, lora_name):
        return [state_dict.get(f"{lora_name}.{k}", None) for k in cls.weight_list]

    @classmethod
    def make_module_from_state_dict(cls, lora_name, orig_module, *weights):
        raise NotImplementedError

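    # `dtype`/`device` read off the non-persistent `dtype_tensor` buffer,
    # which follows the module through `.to()` / `.cuda()` calls.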
    @property
    def dtype(self):
        return self.dtype_tensor.dtype

    @property
    def device(self):
        return self.dtype_tensor.device

    @property
    def org_weight(self):
        return self.org_module[0].weight

    @org_weight.setter
    def org_weight(self, value):
        self.org_module[0].weight.data.copy_(value)

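    # apply_to/restore hot-swap the wrapped module's `forward`, so the
    # LyCORIS module intercepts calls without altering model structure.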
    def apply_to(self, **kwargs):
        if self.not_supported:
            return
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward

    def restore(self):
        if self.not_supported:
            return
        self.org_module[0].forward = self.org_forward

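    # merge_to bakes the learned delta directly into the wrapped module's
    # weight (and bias, if one is produced), restoring this module's original
    # device/dtype afterwards.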
    def merge_to(self, multiplier=1.0):
        if self.not_supported:
            return
        self_device = next(self.parameters()).device
        self_dtype = next(self.parameters()).dtype
        self.to(self.org_weight)
        weight, bias = self.get_merged_weight(
            multiplier, self.org_weight.shape, self.org_weight.device
        )
        self.org_weight = weight.to(self.org_weight)
        if bias is not None:
            bias = bias.to(self.org_weight)
            if self.org_module[0].bias is not None:
                self.org_module[0].bias.data.copy_(bias)
            else:
                self.org_module[0].bias = nn.Parameter(bias)
        self.to(self_device, self_dtype)

    def get_diff_weight(self, multiplier=1.0, shape=None, device=None):
        raise NotImplementedError

    def get_merged_weight(self, multiplier=1.0, shape=None, device=None):
        raise NotImplementedError

    @torch.no_grad()
    def apply_max_norm(self, max_norm, device=None):
        return None, None

    def bypass_forward_diff(self, x, scale=1):
        raise NotImplementedError

    def bypass_forward(self, x, scale=1):
        raise NotImplementedError

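    # Used as `forward` when registered as a parametrization: `x` is the
    # original parameter tensor, and the returned merged weight replaces it
    # on attribute access.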
    def parametrize_forward(self, x: torch.Tensor, *args, **kwargs):
        return self.get_merged_weight(
            multiplier=self.multiplier, shape=x.shape, device=x.device
        )[0].to(x.dtype)

    def forward(self, *args, **kwargs):
        raise NotImplementedError