import importlib
from collections import OrderedDict
from copy import deepcopy

import torch
import torch.nn as nn
from omegaconf import OmegaConf

# NOTE(review): `lowvram`, `shared`, `scripts`, `timestep_embedding` and
# `UNetModel` are not referenced by the code visible here; they are retained
# because other code may rely on them being importable from this module.
# `modules` is deliberately kept ahead of `ldm`, matching the original order.
from modules import devices, lowvram, shared, scripts
from ldm.modules.diffusionmodules.util import timestep_embedding
from ldm.modules.diffusionmodules.openaimodel import UNetModel


class TorchHijackForUnet:
    """
    This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
    this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
    """

    def __getattr__(self, item):
        """Forward unknown attribute lookups to the ``torch`` module.

        ``__getattr__`` only runs when normal lookup fails, and ``cat`` is a
        regular method, so the ``'cat'`` branch is only reached when this
        method is invoked directly.

        :raises AttributeError: if ``torch`` has no attribute named ``item``.
        """
        if item == 'cat':
            return self.cat

        if hasattr(torch, item):
            return getattr(torch, item)

        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))

    def cat(self, tensors, *args, **kwargs):
        """Concatenate like ``torch.cat``, resizing the first of two tensors if needed.

        When exactly two tensors are given and their last two dimensions
        differ, the first one is resampled (nearest neighbour) to the last two
        dimensions of the second before concatenation. Any other number of
        tensors is passed through to ``torch.cat`` untouched.
        """
        if len(tensors) == 2:
            a, b = tensors
            if a.shape[-2:] != b.shape[-2:]:
                a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")

            tensors = (a, b)

        return torch.cat(tensors, *args, **kwargs)


th = TorchHijackForUnet()


def align(hint, size):
    """Resize ``hint`` to ``size`` with nearest-neighbour interpolation if it differs.

    :param hint: a 4-dimensional tensor; its last two dimensions are compared
        against ``size``.
    :param size: a ``(height, width)`` pair.
    :return: ``hint`` itself when the sizes already match, otherwise the
        resized tensor.
    """
    # Unpacking four values keeps the original requirement that `hint` is 4-D.
    _, _, h1, w1 = hint.shape
    h, w = size
    if h != h1 or w != w1:
        hint = th.nn.functional.interpolate(hint, size=size, mode="nearest")
    return hint


def get_node_name(name, parent_name):
    """Strip ``parent_name`` from the front of ``name``.

    :return: ``(True, remainder)`` when ``name`` is strictly longer than
        ``parent_name`` and starts with it, otherwise ``(False, '')``.
    """
    prefix_len = len(parent_name)
    if len(name) > prefix_len and name[:prefix_len] == parent_name:
        return True, name[prefix_len:]
    return False, ''


def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"package.module.Name"`` to the named object.

    :param string: dotted path; everything before the last dot is imported as
        a module and the final component is looked up on it.
    :param reload: when true, the module is reloaded before the lookup.
    """
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


class PlugableAdapter(nn.Module):
    """Wraps an adapter network built from a config file and caches its output.

    The first ``forward`` call runs the wrapped model and stores the result;
    later calls return a deep copy of the stored result until ``reset`` is
    called.
    """

    def __init__(self, state_dict, config_path, lowvram=False, base_model=None) -> None:
        """
        :param state_dict: weights loaded into the constructed model.
        :param config_path: path of a config readable by ``OmegaConf.load``;
            ``config.model.target`` names the model class and
            ``config.model.params`` supplies its constructor keyword arguments.
        :param lowvram: when false, the model is moved to the device returned
            by ``devices.get_device_for("controlnet")`` right away. (This
            parameter shadows the imported ``lowvram`` module inside this
            method.)
        :param base_model: accepted but not used by this constructor.
        """
        super().__init__()
        config = OmegaConf.load(config_path)
        model = Adapter
        try:
            self.target = config.model.target
            model = get_obj_from_str(config.model.target)
        except ImportError:
            # Deliberate best-effort: if the configured target cannot be
            # imported, fall back to the local `Adapter` class.
            pass

        self.control_model = model(**config.model.params)
        self.control_model.load_state_dict(state_dict)
        self.lowvram = lowvram
        self.control = None
        self.hint_cond = None

        if not self.lowvram:
            self.control_model.to(devices.get_device_for("controlnet"))

    def reset(self):
        """Drop the cached model output and the remembered hint."""
        self.control = None
        self.hint_cond = None

    def forward(self, hint=None, *args, **kwargs):
        """Return the adapter output for ``hint``, computing it at most once.

        Extra positional and keyword arguments are accepted and ignored.

        :return: a deep copy of the cached model output.
        """
        if self.control is not None:
            return deepcopy(self.control)

        self.hint_cond = hint
        has_64ch_input = (
            hasattr(self.control_model, 'conv_in')
            and self.control_model.conv_in.in_channels == 64
        )
        if has_64ch_input:
            # Only the first slice along dim 0 is used, then two leading axes
            # are added. NOTE(review): presumably the single-channel adapter
            # case (1 channel * 8 * 8 after PixelUnshuffle(8) == 64) - confirm.
            hint_in = hint[0].unsqueeze(0).unsqueeze(0)
        else:
            hint_in = hint.unsqueeze(0)

        self.control = self.control_model(hint_in)
        return deepcopy(self.control)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.

    :raises ValueError: if ``dims`` is not 1, 2 or 3.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    if dims == 2:
        return nn.Conv2d(*args, **kwargs)
    if dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.

    :raises ValueError: if ``dims`` is not 1, 2 or 3.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    if dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    if dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    :param out_channels: output channels of the convolution; defaults to
                 ``channels``. Must equal ``channels`` when ``use_conv`` is false.
    :param padding: padding passed to the convolution (unused for pooling).
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
            )
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


class ResnetBlock(nn.Module):
    """Residual block of two convolutions with optional input conv and downsampling.

    :param in_c: channels of the incoming tensor.
    :param out_c: channels produced by the block.
    :param down: when true, the input is first passed through ``Downsample``.
    :param ksize: kernel size of ``in_conv``, ``block2`` and ``skep``.
    :param sk: when true, no ``skep`` convolution is created and the residual
        is added directly; ``in_conv`` is then only created if
        ``in_c != out_c``.
    :param use_conv: forwarded to ``Downsample``.
    """

    def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
        super().__init__()
        ps = ksize // 2
        # Submodules are created in a fixed order so parameter initialisation
        # and state_dict key order stay the same.
        if in_c != out_c or not sk:
            self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
        else:
            self.in_conv = None
        self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
        # NOTE(review): `skep` takes `in_c` channels but is applied in
        # `forward` after `in_conv` has already mapped x to `out_c` channels;
        # with sk=False and in_c != out_c this looks like a channel mismatch.
        # Left unchanged to keep parameter shapes compatible - confirm.
        self.skep = None if sk else nn.Conv2d(in_c, out_c, ksize, 1, ps)

        self.down = down
        if self.down:
            self.down_opt = Downsample(in_c, use_conv=use_conv)

    def forward(self, x):
        if self.down:
            x = self.down_opt(x)
        if self.in_conv is not None:
            x = self.in_conv(x)

        h = self.block2(self.act(self.block1(x)))
        residual = x if self.skep is None else self.skep(x)
        return h + residual


class Adapter(nn.Module):
    """Stack of ``ResnetBlock`` stages producing one feature map per stage.

    :param channels: output channels of each stage; the default is
        ``320, 640, 1280, 1280``. A per-instance copy is stored.
    :param nums_rb: number of residual blocks per stage.
    :param cin: input channels of ``conv_in`` (i.e. after ``PixelUnshuffle(8)``).
    :param ksize: forwarded to every ``ResnetBlock``.
    :param sk: forwarded to every ``ResnetBlock``.
    :param use_conv: forwarded to every ``ResnetBlock``.
    """

    def __init__(self, channels=(320, 640, 1280, 1280), nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True):
        super(Adapter, self).__init__()
        self.unshuffle = nn.PixelUnshuffle(8)
        # Copy so instances never share (or mutate) the default sequence.
        self.channels = list(channels)
        self.nums_rb = nums_rb

        blocks = []
        for i, out_c in enumerate(self.channels):
            for j in range(nums_rb):
                # The first block of every stage but the first downsamples and
                # changes the channel count from the previous stage.
                starts_new_stage = i != 0 and j == 0
                in_c = self.channels[i - 1] if starts_new_stage else out_c
                blocks.append(
                    ResnetBlock(in_c, out_c, down=starts_new_stage, ksize=ksize, sk=sk, use_conv=use_conv)
                )
        self.body = nn.ModuleList(blocks)
        self.conv_in = nn.Conv2d(cin, self.channels[0], 3, 1, 1)

    def forward(self, x):
        """Return a list with the output of each stage, in stage order."""
        # unshuffle
        x = self.unshuffle(x)
        # extract features
        features = []
        x = self.conv_in(x)
        blocks = iter(self.body)
        for _ in self.channels:
            for _ in range(self.nums_rb):
                x = next(blocks)(x)
            features.append(x)

        return features


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        # Normalise in float32, then cast back to the caller's dtype.
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    """Activation computing ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    """Self-attention followed by an MLP, each wrapped in a residual connection.

    :param d_model: embedding width.
    :param n_head: number of attention heads.
    :param attn_mask: optional mask passed to ``nn.MultiheadAttention``.
    """

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()),
                         ("c_proj", nn.Linear(d_model * 4, d_model))]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        # The stored mask is moved to x's dtype/device and kept for later calls.
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class StyleAdapter(nn.Module):
    """Transformer that appends learned style tokens and projects them.

    :param width: token width used by the transformer layers.
    :param context_dim: width of the projected output tokens.
    :param num_head: attention heads per layer.
    :param n_layes: number of ``ResidualAttentionBlock`` layers.
    :param num_token: number of learned style tokens appended to the input.
    """

    def __init__(self, width=1024, context_dim=768, num_head=8, n_layes=3, num_token=4):
        super().__init__()

        scale = width ** -0.5
        self.transformer_layes = nn.Sequential(*[ResidualAttentionBlock(width, num_head) for _ in range(n_layes)])
        self.num_token = num_token
        self.style_embedding = nn.Parameter(torch.randn(1, num_token, width) * scale)
        self.ln_post = LayerNorm(width)
        self.ln_pre = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, context_dim))

    def forward(self, x):
        # x shape [N, HW+1, C]
        # Adding zeros of batch size N broadcasts the (1, num_token, width)
        # embedding across the batch.
        style_embedding = self.style_embedding + torch.zeros(
            (x.shape[0], self.num_token, self.style_embedding.shape[-1]), device=x.device)
        x = torch.cat([x, style_embedding], dim=1)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer_layes(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # Only the trailing style tokens are normalised and projected.
        x = self.ln_post(x[:, -self.num_token:, :])
        x = x @ self.proj

        return x


class ResnetBlock_light(nn.Module):
    """Two 3x3 convolutions with a ReLU in between and an identity residual."""

    def __init__(self, in_c):
        super().__init__()
        self.block1 = nn.Conv2d(in_c, in_c, 3, 1, 1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(in_c, in_c, 3, 1, 1)

    def forward(self, x):
        h = self.block2(self.act(self.block1(x)))
        return h + x


class extractor(nn.Module):
    """1x1 conv in, ``nums_rb`` light residual blocks, 1x1 conv out.

    :param in_c: input channels.
    :param inter_c: channels used inside the residual blocks.
    :param out_c: output channels.
    :param nums_rb: number of ``ResnetBlock_light`` blocks.
    :param down: when true, the input is average-pooled by ``Downsample``
        (``use_conv=False``) first.
    """

    def __init__(self, in_c, inter_c, out_c, nums_rb, down=False):
        super().__init__()
        self.in_conv = nn.Conv2d(in_c, inter_c, 1, 1, 0)
        self.body = nn.Sequential(*[ResnetBlock_light(inter_c) for _ in range(nums_rb)])
        self.out_conv = nn.Conv2d(inter_c, out_c, 1, 1, 0)
        self.down = down
        if self.down:
            self.down_opt = Downsample(in_c, use_conv=False)

    def forward(self, x):
        if self.down:
            x = self.down_opt(x)
        x = self.in_conv(x)
        x = self.body(x)
        x = self.out_conv(x)

        return x


class Adapter_light(nn.Module):
    """Stack of ``extractor`` stages producing one feature map per stage.

    :param channels: output channels of each stage; the default is
        ``320, 640, 1280, 1280``. A per-instance copy is stored.
    :param nums_rb: number of light residual blocks per stage.
    :param cin: input channels of the first stage (i.e. after
        ``PixelUnshuffle(8)``).
    """

    def __init__(self, channels=(320, 640, 1280, 1280), nums_rb=3, cin=64):
        super(Adapter_light, self).__init__()
        self.unshuffle = nn.PixelUnshuffle(8)
        # Copy so instances never share (or mutate) the default sequence.
        self.channels = list(channels)
        self.nums_rb = nums_rb

        stages = []
        for i, out_c in enumerate(self.channels):
            # Every stage but the first takes the previous stage's channels
            # and downsamples.
            is_first = i == 0
            in_c = cin if is_first else self.channels[i - 1]
            stages.append(
                extractor(in_c=in_c, inter_c=out_c // 4, out_c=out_c, nums_rb=nums_rb, down=not is_first)
            )
        self.body = nn.ModuleList(stages)

    def forward(self, x):
        """Return a list with the output of each stage, in stage order."""
        # unshuffle
        x = self.unshuffle(x)
        # extract features
        features = []
        for i in range(len(self.channels)):
            x = self.body[i](x)
            features.append(x)

        return features