from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
from ..utils.quant import QuantLinears, log_bypass, log_suspect
class ModuleCustomSD(nn.Module):
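    """nn.Module with overridable state-dict handling: subclasses can supply a
    custom serialization via `custom_state_dict()` and hook into loading
    through the registered pre/post load-state-dict hooks."""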
def __init__(self):
super().__init__()
self._register_load_state_dict_pre_hook(self.load_weight_prehook)
self.register_load_state_dict_post_hook(self.load_weight_hook)
def load_weight_prehook(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
pass
def load_weight_hook(self, module, incompatible_keys):
pass
def custom_state_dict(self):
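        """Return a dict of tensors to serialize instead of the default state
        dict, or None to fall back to `nn.Module.state_dict`."""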
return None
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
# TODO: Remove `args` and the parsing logic when BC allows.
if len(args) > 0:
if destination is None:
destination = args[0]
if len(args) > 1 and prefix == "":
prefix = args[1]
if len(args) > 2 and keep_vars is False:
keep_vars = args[2]
# DeprecationWarning is ignored by default
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict()
local_metadata = dict(version=self._version)
if hasattr(destination, "_metadata"):
destination._metadata[prefix[:-1]] = local_metadata
if (custom_sd := self.custom_state_dict()) is not None:
for k, v in custom_sd.items():
destination[f"{prefix}{k}"] = v
return destination
else:
return super().state_dict(
*args, destination=destination, prefix=prefix, keep_vars=keep_vars
)
class LycorisBaseModule(ModuleCustomSD):
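    """Common base for LyCORIS algorithm modules. Wraps an existing
    Linear/Conv1d/Conv2d/Conv3d/LayerNorm/GroupNorm layer, records the
    functional op and kwargs needed to re-apply it, and provides shared
    dropout, bypass-mode, apply/restore, and weight-merging plumbing."""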
name: str
dtype_tensor: torch.Tensor
support_module = {}
weight_list = []
weight_list_det = []
def __init__(
self,
lora_name,
org_module: nn.Module,
multiplier=1.0,
dropout=0.0,
rank_dropout=0.0,
module_dropout=0.0,
rank_dropout_scale=False,
bypass_mode=None,
**kwargs,
):
"""if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__()
self.lora_name = lora_name
self.not_supported = False
self.module = type(org_module)
if isinstance(org_module, nn.Linear):
self.module_type = "linear"
self.shape = (org_module.out_features, org_module.in_features)
self.op = F.linear
self.dim = org_module.out_features
self.kw_dict = {}
elif isinstance(org_module, nn.Conv1d):
self.module_type = "conv1d"
self.shape = (
org_module.out_channels,
org_module.in_channels,
*org_module.kernel_size,
)
self.op = F.conv1d
self.dim = org_module.out_channels
self.kw_dict = {
"stride": org_module.stride,
"padding": org_module.padding,
"dilation": org_module.dilation,
"groups": org_module.groups,
}
elif isinstance(org_module, nn.Conv2d):
self.module_type = "conv2d"
self.shape = (
org_module.out_channels,
org_module.in_channels,
*org_module.kernel_size,
)
self.op = F.conv2d
self.dim = org_module.out_channels
self.kw_dict = {
"stride": org_module.stride,
"padding": org_module.padding,
"dilation": org_module.dilation,
"groups": org_module.groups,
}
elif isinstance(org_module, nn.Conv3d):
self.module_type = "conv3d"
self.shape = (
org_module.out_channels,
org_module.in_channels,
*org_module.kernel_size,
)
self.op = F.conv3d
self.dim = org_module.out_channels
self.kw_dict = {
"stride": org_module.stride,
"padding": org_module.padding,
"dilation": org_module.dilation,
"groups": org_module.groups,
}
elif isinstance(org_module, nn.LayerNorm):
self.module_type = "layernorm"
self.shape = tuple(org_module.normalized_shape)
self.op = F.layer_norm
self.dim = org_module.normalized_shape[0]
self.kw_dict = {
"normalized_shape": org_module.normalized_shape,
"eps": org_module.eps,
}
elif isinstance(org_module, nn.GroupNorm):
self.module_type = "groupnorm"
self.shape = (org_module.num_channels,)
self.op = F.group_norm
self.group_num = org_module.num_groups
self.dim = org_module.num_channels
self.kw_dict = {"num_groups": org_module.num_groups, "eps": org_module.eps}
else:
self.not_supported = True
self.module_type = "unknown"
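        # Dummy non-persistent buffer; only used to track this module's
        # current device/dtype (see the `dtype`/`device` properties below).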
self.register_buffer("dtype_tensor", torch.tensor(0.0), persistent=False)
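        # Quantized linears (and non-vanilla nn.Linear subclasses) can only be
        # handled in bypass mode, so force or default bypass_mode accordingly.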
self.is_quant = False
if isinstance(org_module, QuantLinears):
if not bypass_mode:
log_bypass()
self.is_quant = True
bypass_mode = True
if (
isinstance(org_module, nn.Linear)
and org_module.__class__.__name__ != "Linear"
):
if bypass_mode is None:
log_suspect()
bypass_mode = True
        if bypass_mode:
self.is_quant = True
self.bypass_mode = bypass_mode
self.dropout = dropout
self.rank_dropout = rank_dropout
self.rank_dropout_scale = rank_dropout_scale
self.module_dropout = module_dropout
        ## Dropout handling
        # LoKr/LoHa/OFT/BOFT cannot easily follow kohya's rank_dropout definition,
        # so we redefine the dropout procedure here:
        # g(x) = Wx + drop(B @ rank_drop(A @ x))   for LoCon (lora), bypass mode
        # g(x) = Wx + drop(ΔW @ x)                 for any algo except LoCon (lora), bypass mode
        # g(x) = (W + B @ rank_drop(A)) @ x        for LoCon (lora), rebuild mode
        # g(x) = (W + rank_drop(ΔW)) @ x           for any algo except LoCon (lora), rebuild mode
self.drop = nn.Identity() if dropout == 0 else nn.Dropout(dropout)
self.rank_drop = (
nn.Identity() if rank_dropout == 0 else nn.Dropout(rank_dropout)
)
self.multiplier = multiplier
self.org_forward = org_module.forward
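        # Stored in a plain list so the original module is not registered as a
        # submodule (keeps its parameters out of this wrapper's state dict).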
self.org_module = [org_module]
@classmethod
def parametrize(cls, org_module, attr, *args, **kwargs):
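        """Attach this LyCORIS module as a torch parametrization on
        `getattr(org_module, attr)`, using a proxy Linear/Conv layer whose
        weight is the target parameter."""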
from .full import FullModule
if cls is FullModule:
raise RuntimeError("FullModule cannot be used for parametrize.")
target_param = getattr(org_module, attr)
kwargs["bypass_mode"] = False
if target_param.dim() == 2:
            # nn.Linear(in_features, out_features); a 2-D weight is laid out as
            # (out_features, in_features), so the shape is reversed here.
            proxy_module = nn.Linear(
                target_param.shape[1], target_param.shape[0], bias=False
            )
proxy_module.weight = target_param
elif target_param.dim() > 2:
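            # Pick a conv proxy by parameter rank: 3 dims -> Conv1d,
            # 4 dims -> Conv2d, 5 dims -> Conv3d (other ranks are unsupported).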
module_type = [
None,
None,
None,
nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
None,
None,
][target_param.dim()]
proxy_module = module_type(
target_param.shape[0],
target_param.shape[1],
*target_param.shape[2:],
bias=False,
)
proxy_module.weight = target_param
module_obj = cls("", proxy_module, *args, **kwargs)
module_obj.forward = module_obj.parametrize_forward
module_obj.to(target_param)
parametrize.register_parametrization(org_module, attr, module_obj)
return module_obj
@classmethod
def algo_check(cls, state_dict, lora_name):
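        """Return True if `state_dict` holds this algorithm's detection keys for `lora_name`."""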
return any(f"{lora_name}.{k}" in state_dict for k in cls.weight_list_det)
@classmethod
def extract_state_dict(cls, state_dict, lora_name):
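        """Collect this algorithm's weights for `lora_name` (missing entries are None)."""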
return [state_dict.get(f"{lora_name}.{k}", None) for k in cls.weight_list]
@classmethod
def make_module_from_state_dict(cls, lora_name, orig_module, *weights):
raise NotImplementedError
@property
def dtype(self):
return self.dtype_tensor.dtype
@property
def device(self):
return self.dtype_tensor.device
@property
def org_weight(self):
return self.org_module[0].weight
@org_weight.setter
def org_weight(self, value):
self.org_module[0].weight.data.copy_(value)
def apply_to(self, **kwargs):
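        """Replace the wrapped module's forward with this module's forward."""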
if self.not_supported:
return
self.org_forward = self.org_module[0].forward
self.org_module[0].forward = self.forward
def restore(self):
if self.not_supported:
return
self.org_module[0].forward = self.org_forward
def merge_to(self, multiplier=1.0):
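        """Bake this module's weight delta (and bias, if produced) into the
        wrapped module, scaled by `multiplier`, then restore this module's
        original device/dtype."""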
if self.not_supported:
return
self_device = next(self.parameters()).device
self_dtype = next(self.parameters()).dtype
self.to(self.org_weight)
weight, bias = self.get_merged_weight(
multiplier, self.org_weight.shape, self.org_weight.device
)
self.org_weight = weight.to(self.org_weight)
if bias is not None:
bias = bias.to(self.org_weight)
if self.org_module[0].bias is not None:
self.org_module[0].bias.data.copy_(bias)
else:
self.org_module[0].bias = nn.Parameter(bias)
self.to(self_device, self_dtype)
def get_diff_weight(self, multiplier=1.0, shape=None, device=None):
raise NotImplementedError
def get_merged_weight(self, multiplier=1.0, shape=None, device=None):
raise NotImplementedError
@torch.no_grad()
def apply_max_norm(self, max_norm, device=None):
return None, None
def bypass_forward_diff(self, x, scale=1):
raise NotImplementedError
def bypass_forward(self, x, scale=1):
raise NotImplementedError
def parametrize_forward(self, x: torch.Tensor, *args, **kwargs):
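        """Parametrization entry point: `x` is the parametrized tensor;
        return the merged weight computed for its shape/device/dtype."""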
return self.get_merged_weight(
multiplier=self.multiplier, shape=x.shape, device=x.device
)[0].to(x.dtype)
def forward(self, *args, **kwargs):
raise NotImplementedError
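

# Illustrative usage sketch (not part of this file's API). `SomeLycorisModule`
# stands in for any concrete subclass that implements get_diff_weight /
# get_merged_weight / bypass_forward / forward:
#
#     module = SomeLycorisModule("lora_unet_proj_in", base_linear, multiplier=1.0)
#     module.apply_to()      # route base_linear.forward through this module
#     module.merge_to(1.0)   # bake the learned delta into base_linear's weights
#     module.restore()       # put the original forward back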