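# NOTE: the lines below are Hugging Face file-viewer residue, not source code.
# Uploaded by toto10: "Upload folder using huggingface_hub (#1)", commit 34097e9.
# Page controls: raw | history | blame. Scan status: No virus. File size: 12.8 kB.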
import torch
import torch.nn as nn
import importlib
from collections import OrderedDict
from omegaconf import OmegaConf
from copy import deepcopy
from modules import devices, lowvram, shared, scripts
# Use devices.cond_cast_unet when this webui build provides it; otherwise fall back
# to an identity function so callers can apply it unconditionally.
cond_cast_unet = getattr(devices, 'cond_cast_unet', lambda value: value)
from ldm.modules.diffusionmodules.util import timestep_embedding
from ldm.modules.diffusionmodules.openaimodel import UNetModel
class TorchHijackForUnet:
    """
    Stand-in for the ``torch`` module whose ``cat`` tolerates mismatched spatial sizes.

    Every attribute other than ``cat`` is forwarded to ``torch``. When ``cat`` is given
    exactly two tensors whose last two dimensions differ, the first one is resized with
    nearest-neighbour interpolation to the second one's size before concatenating; this
    makes it possible to create pictures with dimensions that are multiples of 8 rather
    than 64.
    """

    def __getattr__(self, item):
        # Python only calls this when normal lookup fails, so ``cat`` (a real method)
        # never reaches here; everything else is delegated to the torch module.
        if not hasattr(torch, item):
            raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
        return getattr(torch, item)

    def cat(self, tensors, *args, **kwargs):
        if len(tensors) != 2:
            return torch.cat(tensors, *args, **kwargs)
        first, second = tensors
        target_hw = second.shape[-2:]
        if first.shape[-2:] != target_hw:
            first = torch.nn.functional.interpolate(first, target_hw, mode="nearest")
        return torch.cat((first, second), *args, **kwargs)


th = TorchHijackForUnet()
def align(hint, size):
    """Resize a 4-D ``hint`` so its last two dims equal ``size``; otherwise return it as-is.

    ``size`` is unpacked as ``(height, width)``. Resizing uses nearest-neighbour
    interpolation. ``hint.shape`` must have exactly four entries.
    """
    _, _, cur_h, cur_w = hint.shape
    target_h, target_w = size
    if target_h == cur_h and target_w == cur_w:
        return hint
    return torch.nn.functional.interpolate(hint, size=size, mode="nearest")
def get_node_name(name, parent_name):
    """Split ``name`` into its ``parent_name`` prefix and the remaining suffix.

    Returns ``(True, suffix)`` when ``name`` is strictly longer than ``parent_name`` and
    starts with it; otherwise ``(False, '')``.
    """
    cut = len(parent_name)
    if len(name) > cut and name[:cut] == parent_name:
        return True, name[cut:]
    return False, ''
def get_obj_from_str(string, reload=False):
    """Import and return the object named by a dotted path such as ``"pkg.mod.Name"``.

    Everything before the last dot is the module path and the rest is the attribute
    name. With ``reload=True`` the module is re-executed via ``importlib.reload`` first.
    """
    module_path, attr_name = string.rsplit(".", 1)
    if reload:
        stale_module = importlib.import_module(module_path)
        importlib.reload(stale_module)
    owner = importlib.import_module(module_path, package=None)
    return getattr(owner, attr_name)
class PlugableAdapter(nn.Module):
    """Load a T2I-Adapter-style control model and cache its output for one hint.

    The model class is taken from ``config.model.target`` (a dotted import path) and
    built with ``config.model.params``. When that path cannot be imported, the local
    ``Adapter`` class is used instead.

    Args:
        state_dict: weights loaded into the constructed control model.
        config_path: path to the OmegaConf config file.
        lowvram: when False the control model is moved to the "controlnet" device
            immediately; when True it is left where it was created.
        base_model: accepted for call compatibility; not used by this class.
    """

    def __init__(self, state_dict, config_path, lowvram=False, base_model=None) -> None:
        super().__init__()
        self.config = OmegaConf.load(config_path)
        try:
            self.target = self.config.model.target
            model = get_obj_from_str(self.target)
        except ImportError:
            # Deliberate best-effort fallback: the configured module may not be
            # importable in this environment, so use the local Adapter class.
            model = Adapter
        self.control_model = model(**self.config.model.params)
        self.control_model.load_state_dict(state_dict)
        self.lowvram = lowvram
        self.control = None  # cached model output; cleared by reset()
        self.hint_cond = None  # cast hint from the last uncached forward()
        if not self.lowvram:
            self.control_model.to(devices.get_device_for("controlnet"))

    def reset(self):
        """Drop the cached output and hint so the next forward() recomputes them."""
        self.control = None
        self.hint_cond = None

    def forward(self, hint=None, x=None, *args, **kwargs):
        """Return a deep copy of the control model's output for ``hint``.

        The first call computes and caches the output. Later calls return copies of the
        cache, ignoring their arguments, until ``reset()`` is called. ``x`` and any extra
        arguments are accepted for call compatibility and are not used here.
        """
        if self.control is not None:
            return deepcopy(self.control)
        # Cast once and reuse the result for both the stored hint and the model input.
        hint_in = cond_cast_unet(hint)
        self.hint_cond = hint_in
        if hasattr(self.control_model, 'conv_in') and self.control_model.conv_in.in_channels == 64:
            # For Adapter, conv_in sees PixelUnshuffle(8) output (channels * 64), so 64
            # input channels means a single-channel hint: keep only the first channel.
            hint_in = hint_in[:, 0:1, :, :]
        self.control = self.control_model(hint_in)
        # Hand out a copy so callers cannot mutate the cached result.
        return deepcopy(self.control)
def conv_nd(dims, *args, **kwargs):
    """
    Build an ``nn.Conv1d``, ``nn.Conv2d`` or ``nn.Conv3d`` for ``dims`` 1, 2 or 3.

    All other arguments are forwarded to the chosen constructor. Any other ``dims``
    raises ValueError.
    """
    for rank, conv_cls in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
def avg_pool_nd(dims, *args, **kwargs):
    """
    Build an ``nn.AvgPool1d``, ``nn.AvgPool2d`` or ``nn.AvgPool3d`` for ``dims`` 1, 2 or 3.

    All other arguments are forwarded to the chosen constructor. Any other ``dims``
    raises ValueError.
    """
    for rank, pool_cls in ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d)):
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
class Downsample(nn.Module):
    """
    Downsample by a stride of 2, using either a strided convolution or average pooling.
    :param channels: channels in the inputs (and outputs, unless out_channels is given).
    :param use_conv: if true, use a kernel-size-3, stride-2 convolution; otherwise
        average-pool, which requires out_channels to equal channels.
    :param dims: 1, 2 or 3. For 3-D signals only the inner two dimensions are strided.
    :param out_channels: output channels of the convolution; defaults to channels.
    :param padding: padding passed to the convolution.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # 3-D inputs keep their first spatial axis; only the inner two get stride 2.
        stride = (1, 2, 2) if dims == 3 else 2
        if not use_conv:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
            return
        self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)

    def forward(self, x):
        # Dim 1 must match the configured channel count.
        assert x.shape[1] == self.channels
        return self.op(x)
class ResnetBlock(nn.Module):
    """Residual conv block variant whose skip path reads the block input, not in_conv's output.

    NOTE: a later ``class ResnetBlock`` statement in this module rebinds the name, so
    this definition is shadowed once the module finishes importing.

    Structure: optional ``Downsample`` -> optional ``in_conv`` projection -> conv -> ReLU
    -> conv, added to either ``skep(x)`` (when ``sk == False``) or the identity ``x``.
    With ``sk`` true and ``in_c != out_c`` the identity skip adds an ``in_c``-channel
    tensor to an ``out_c``-channel one, so that combination cannot run here.
    """

    def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
        super().__init__()
        ps = ksize // 2
        if in_c != out_c or sk == False:  # noqa: E712 -- equality kept for non-bool sk values
            self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
        else:
            self.in_conv = None
        self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
        if sk == False:  # noqa: E712
            self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
        else:
            self.skep = None
        self.down = down
        if self.down == True:  # noqa: E712
            self.down_opt = Downsample(in_c, use_conv=use_conv)

    def forward(self, x):
        if self.down == True:  # noqa: E712
            x = self.down_opt(x)
        # Bind ``h`` on both paths. It used to be assigned only when in_conv existed,
        # so a block built without in_conv raised UnboundLocalError here.
        h = self.in_conv(x) if self.in_conv is not None else x
        h = self.block1(h)
        h = self.act(h)
        h = self.block2(h)
        # The skip path uses the (possibly downsampled) input, not the in_conv output.
        if self.skep is not None:
            return h + self.skep(x)
        return h + x
class ResnetBlock(nn.Module):
    """Residual conv block used by ``Adapter`` (the last ``ResnetBlock`` defined here wins).

    Optionally downsamples, then projects with ``in_conv`` when present. The residual
    branch is conv -> ReLU -> conv. The skip path is ``skep`` (when ``sk == False``) or
    the identity, and both are applied to the ``in_conv`` output.

    NOTE(review): ``skep`` is ``Conv2d(in_c, out_c)`` but receives the ``in_conv`` output,
    which has ``out_c`` channels, so ``sk == False`` only runs when ``in_c == out_c``.
    Left unchanged: feeding ``skep`` a different tensor would alter outputs for the
    ``in_c == out_c`` case.
    """

    def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
        super().__init__()
        pad = ksize // 2
        # ``== False`` (not ``not sk``) is kept so non-bool values such as None behave as before.
        sk_is_false = sk == False  # noqa: E712
        needs_projection = in_c != out_c or sk_is_false
        self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, pad) if needs_projection else None
        self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, pad)
        self.skep = nn.Conv2d(in_c, out_c, ksize, 1, pad) if sk_is_false else None
        self.down = down
        if self.down == True:  # noqa: E712
            self.down_opt = Downsample(in_c, use_conv=use_conv)

    def forward(self, x):
        if self.down == True:  # noqa: E712
            x = self.down_opt(x)
        if self.in_conv is not None:
            x = self.in_conv(x)
        residual = self.block2(self.act(self.block1(x)))
        shortcut = x if self.skep is None else self.skep(x)
        return residual + shortcut
class Adapter(nn.Module):
    """T2I-Adapter feature extractor returning one feature map per level of ``channels``.

    The input is pixel-unshuffled by 8 (multiplying its channel count by 64, which must
    equal ``cin``) and projected to ``channels[0]`` by ``conv_in``. Each level then runs
    ``nums_rb`` ``ResnetBlock``s. The first block of every level after the first
    downsamples and maps ``channels[i - 1]`` to ``channels[i]``.

    Args:
        channels: output channels per level; ``None`` means ``[320, 640, 1280, 1280]``.
        nums_rb: number of ResnetBlocks per level.
        cin: channel count after the 8x pixel-unshuffle.
        ksize, sk, use_conv: forwarded to every ResnetBlock.
    """

    def __init__(self, channels=None, nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True):
        super(Adapter, self).__init__()
        if channels is None:
            # Fresh list per instance; a list literal default would be shared by all instances.
            channels = [320, 640, 1280, 1280]
        self.unshuffle = nn.PixelUnshuffle(8)
        self.channels = channels
        self.nums_rb = nums_rb
        blocks = []
        for i in range(len(channels)):
            for j in range(nums_rb):
                if i != 0 and j == 0:
                    blocks.append(ResnetBlock(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
                else:
                    blocks.append(ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
        self.body = nn.ModuleList(blocks)
        self.conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1)

    def forward(self, x):
        """Return a list with the output of the last block of each level."""
        x = self.unshuffle(x)
        x = self.conv_in(x)
        features = []
        for i in range(len(self.channels)):
            for j in range(self.nums_rb):
                x = self.body[i * self.nums_rb + j](x)
            features.append(x)
        return features
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` that normalises in float32 and returns the input's dtype (fp16-safe)."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)
class QuickGELU(nn.Module):
    """Sigmoid approximation of GELU: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention, then a 4x-wide MLP, each with a residual add."""

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        hidden = d_model * 4
        mlp_layers = OrderedDict()
        mlp_layers["c_fc"] = nn.Linear(d_model, hidden)
        mlp_layers["gelu"] = QuickGELU()
        mlp_layers["c_proj"] = nn.Linear(hidden, d_model)
        self.mlp = nn.Sequential(mlp_layers)
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        # The stored mask is converted (and kept) in x's dtype/device before use.
        if self.attn_mask is not None:
            self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
        result = self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)
        return result[0]

    def forward(self, x: torch.Tensor):
        after_attn = x + self.attention(self.ln_1(x))
        return after_attn + self.mlp(self.ln_2(after_attn))
class StyleAdapter(nn.Module):
def __init__(self, width=1024, context_dim=768, num_head=8, n_layes=3, num_token=4):
super().__init__()
scale = width ** -0.5
self.transformer_layes = nn.Sequential(*[ResidualAttentionBlock(width, num_head) for _ in range(n_layes)])
self.num_token = num_token
self.style_embedding = nn.Parameter(torch.randn(1, num_token, width) * scale)
self.ln_post = LayerNorm(width)
self.ln_pre = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, context_dim))
def forward(self, x):
# x shape [N, HW+1, C]
style_embedding = self.style_embedding + torch.zeros(
(x.shape[0], self.num_token, self.style_embedding.shape[-1]), device=x.device)
x = torch.cat([x, style_embedding], dim=1)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer_layes(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_post(x[:, -self.num_token:, :])
x = x @ self.proj
return x
class ResnetBlock_light(nn.Module):
    """Light residual block: 3x3 conv -> ReLU -> 3x3 conv, plus an identity skip.

    Stride 1 and padding 1 keep both the channel count and the spatial size unchanged.
    """

    def __init__(self, in_c):
        super().__init__()
        self.block1 = nn.Conv2d(in_c, in_c, kernel_size=3, stride=1, padding=1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(in_c, in_c, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        residual = self.block2(self.act(self.block1(x)))
        return residual + x
class extractor(nn.Module):
    """Light-adapter stage made of pointwise convs around a stack of light residual blocks.

    Pipeline: optional ``Downsample`` (average pooling, ``use_conv=False``) -> 1x1 conv to
    ``inter_c`` -> ``nums_rb`` ``ResnetBlock_light`` blocks -> 1x1 conv to ``out_c``.
    """

    def __init__(self, in_c, inter_c, out_c, nums_rb, down=False):
        super().__init__()
        self.in_conv = nn.Conv2d(in_c, inter_c, kernel_size=1, stride=1, padding=0)
        self.body = nn.Sequential(*[ResnetBlock_light(inter_c) for _ in range(nums_rb)])
        self.out_conv = nn.Conv2d(inter_c, out_c, kernel_size=1, stride=1, padding=0)
        self.down = down
        if self.down == True:  # noqa: E712
            self.down_opt = Downsample(in_c, use_conv=False)

    def forward(self, x):
        if self.down == True:  # noqa: E712
            x = self.down_opt(x)
        return self.out_conv(self.body(self.in_conv(x)))
class Adapter_light(nn.Module):
    """Lightweight T2I-Adapter returning one feature map per level of ``channels``.

    The input is pixel-unshuffled by 8 and passed through one ``extractor`` per level.
    The first level maps ``cin`` channels without downsampling. Each later level
    downsamples and maps ``channels[i - 1]`` to ``channels[i]``, using
    ``channels[i] // 4`` intermediate channels.

    Args:
        channels: output channels per level; ``None`` means ``[320, 640, 1280, 1280]``.
        nums_rb: light residual blocks per extractor.
        cin: channel count after the 8x pixel-unshuffle.
    """

    def __init__(self, channels=None, nums_rb=3, cin=64):
        super(Adapter_light, self).__init__()
        if channels is None:
            # Fresh list per instance; a list literal default would be shared by all instances.
            channels = [320, 640, 1280, 1280]
        self.unshuffle = nn.PixelUnshuffle(8)
        self.channels = channels
        self.nums_rb = nums_rb
        stages = []
        for i, out_c in enumerate(channels):
            in_c = cin if i == 0 else channels[i - 1]
            stages.append(extractor(in_c=in_c, inter_c=out_c // 4, out_c=out_c, nums_rb=nums_rb, down=i != 0))
        self.body = nn.ModuleList(stages)

    def forward(self, x):
        """Return a list with the output of each extractor stage."""
        x = self.unshuffle(x)
        features = []
        for i in range(len(self.channels)):
            x = self.body[i](x)
            features.append(x)
        return features