import math
from typing import Union, Callable, List
import torch
import torch.nn.functional as F
import torchvision.transforms.functional
from torch import nn, Tensor
from einops import rearrange
from PIL import Image
from modules.processing import StableDiffusionProcessing, slerp as Slerp
from scripts.sdhook import (
SDHook,
each_unet_attn_layers,
each_unet_transformers,
each_unet_resblock
)
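
# What the hooks below do, in outline: during the selected denoising steps a
# region of the latent is cropped and upscaled, pushed through the hooked
# U-Net layers as extra batch entries alongside the original, then downscaled
# and blended back into the same region (summary inferred from the code below).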
class Upscaler:
    def __init__(self, mode: str, aa: bool):
        # map UI names onto F.interpolate mode names; unknown names pass
        # through unchanged. Note that PyTorch accepts antialias=True only
        # for the bilinear and bicubic modes.
        mode = {
            'nearest': 'nearest-exact',
            'bilinear': 'bilinear',
            'bicubic': 'bicubic',
        }.get(mode.lower(), mode)
        self.mode = mode
        self.aa = bool(aa)
@property
def name(self):
s = self.mode
if self.aa: s += '-aa'
return s
def __call__(self, x: Tensor, scale: float = 2.0):
return F.interpolate(x, scale_factor=scale, mode=self.mode, antialias=self.aa)
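
# A minimal usage sketch (illustrative only, not part of the extension's API):
#
#   up = Upscaler('bicubic', aa=True)
#   x = torch.randn(1, 4, 32, 32)  # a (B, C, H, W) latent-like tensor
#   y = up(x, scale=2.0)           # -> shape (1, 4, 64, 64)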
class Downscaler:
def __init__(self, mode: str, aa: bool):
self._name = mode.lower()
        # interpolation modes go through F.interpolate; the two pooling modes
        # use F.max_pool2d / F.avg_pool2d instead (marked by an empty mode string)
        intp, mode = {
            'nearest': (F.interpolate, 'nearest-exact'),
            'bilinear': (F.interpolate, 'bilinear'),
            'bicubic': (F.interpolate, 'bicubic'),
            'area': (F.interpolate, 'area'),
            'pooling max': (F.max_pool2d, ''),
            'pooling avg': (F.avg_pool2d, ''),
        }[mode.lower()]
self.intp = intp
self.mode = mode
self.aa = bool(aa)
@property
def name(self):
s = self._name
if self.aa: s += '-aa'
return s
    def __call__(self, x: Tensor, scale: float = 2.0):
        # accept the factor either way round: normalize so that `scale` <= 1
        # (the shrink factor) and `scale_inv` >= 1 (the pooling kernel size)
        if scale <= 1:
            scale = float(scale)
            scale_inv = 1 / scale
        else:
            scale_inv = float(scale)
            scale = 1 / scale_inv
        assert scale <= 1
        assert 1 <= scale_inv
        kwargs = {}
        if len(self.mode) != 0:
            # interpolation path
            kwargs['scale_factor'] = scale
            kwargs['mode'] = self.mode
            kwargs['antialias'] = self.aa
        else:
            # pooling path: the kernel size is the integer shrink factor
            kwargs['kernel_size'] = int(scale_inv)
        return self.intp(x, **kwargs)
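
# A minimal usage sketch (illustrative only):
#
#   down = Downscaler('pooling avg', aa=False)
#   x = torch.randn(1, 4, 64, 64)
#   y = down(x, scale=2.0)  # avg-pool with kernel size 2 -> (1, 4, 32, 32)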
def lerp(v0, v1, t):
    return torch.lerp(v0, v1, t)

def slerp(v0, v1, t):
    # spherical interpolation; falls back to lerp when the result contains
    # NaNs (e.g. for zero-norm or (anti)parallel vectors)
    v = Slerp(t, v0, v1)
    if torch.any(torch.isnan(v)).item():
        v = lerp(v0, v1, t)
    return v
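
# Illustrative only: both helpers take (v0, v1, t) and return a tensor, e.g.
#
#   a, b = torch.randn(4, 64), torch.randn(4, 64)
#   mixed = slerp(a, b, 0.3)  # falls back to lerp if slerp yields NaNs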
class Hooker(SDHook):
def __init__(
self,
enabled: bool,
multiply: int,
weight: float,
layers: Union[list,None],
apply_to: List[str],
start_steps: int,
max_steps: int,
up_fn: Callable[[Tensor,float], Tensor],
down_fn: Callable[[Tensor,float], Tensor],
intp: str,
x: float,
y: float,
force_float: bool,
mask_image: Union[Image.Image,None],
):
super().__init__(enabled)
self.multiply = int(multiply)
self.weight = float(weight)
self.layers = layers
self.apply_to = apply_to
self.start_steps = int(start_steps)
self.max_steps = int(max_steps)
self.up = up_fn
self.down = down_fn
        self.x0 = x  # fractional (0..1) left/top position of the target window
        self.y0 = y
self.force_float = force_float
self.mask_image = mask_image
if intp == 'lerp':
self.intp = lerp
elif intp == 'slerp':
self.intp = slerp
else:
raise ValueError(f'invalid interpolation method: {intp}')
        # power-of-2 check via the `m & (m - 1) == 0` bit trick
        if not (1 <= self.multiply and (self.multiply & (self.multiply - 1) == 0)):
            raise ValueError(f'multiplier must be a power of 2, got: {self.multiply}')
        if mask_image is not None:
            if mask_image.mode != 'L':
                raise ValueError(f'mask image must be grayscale (mode "L"), got: {mask_image.mode}')
    def hook_unet(self, p: StableDiffusionProcessing, unet: nn.Module):
        # count denoising steps with a pre-hook on the U-Net itself so that
        # the layer hooks below stay inactive outside [start_steps, max_steps]
        step = 0
        def hook_step_pre(*args, **kwargs):
            nonlocal step
            step += 1
        self.hook_layer_pre(unet, hook_step_pre)
        start_step = self.start_steps
        max_steps = self.max_steps
        M = self.multiply
        def create_pre_hook(name: str, ctx: dict):
            def pre_hook(module: nn.Module, inputs: list):
                # assume we skip; flipped to False only when inputs are modified
                ctx['skipped'] = True
                if step < start_step or max_steps < step:
                    return
                x, *rest = inputs
                dim = x.dim()
                if dim == 3:
                    # attention layer: tokens are a flattened (h w) axis, so
                    # fold them back into a spatial map first
                    bi, ni, chi = x.shape
                    wi, hi, Ni = self.get_size(p, ni)
                    x = rearrange(x, 'b (h w) c -> b c h w', w=wi)
                    if len(rest) != 0:
                        # x. attn.: duplicate the context to match the batch
                        # doubled below
                        rest[0] = torch.concat((rest[0], rest[0]), dim=0)
elif dim == 4:
# resblock, transformer
bi, chi, hi, wi = x.shape
if 0 < len(rest):
t_emb = rest[0] # t_emb (for resblock) or context (for transformer)
rest[0] = torch.concat((t_emb, t_emb), dim=0)
else:
# `out` layer
pass
else:
return
                # extract a (w, h) window, M times smaller than the feature map,
                # anchored at the fractional position (x0, y0) and clamped below
                w, h = wi // M, hi // M
                if w == 0 or h == 0:
                    # input latent is too small to apply
                    return
                s0, t0 = int(wi * self.x0), int(hi * self.y0)
                s1, t1 = s0 + w, t0 + h
if wi < s1:
s1 = wi
s0 = s1 - w
if hi < t1:
t1 = hi
t0 = t1 - h
if s0 < 0 or t0 < 0:
raise ValueError(f'LLuL failed to process: s=({s0},{s1}), t=({t0},{t1})')
                x1 = x[:, :, t0:t1, s0:s1]
                # upscaling
                x1 = self.up(x1, M)
                if x1.shape[-1] < x.shape[-1] or x1.shape[-2] < x.shape[-2]:
                    # integer division can leave the upscaled crop slightly
                    # smaller than the full map; pad the borders by replication
                    dx = x.shape[-1] - x1.shape[-1]
                    dx1 = dx // 2
                    dx2 = dx - dx1
                    dy = x.shape[-2] - x1.shape[-2]
                    dy1 = dy // 2
                    dy2 = dy - dy1
                    x1 = F.pad(x1, (dx1, dx2, dy1, dy2), 'replicate')
                # run the upscaled crop through the layer as extra batch entries
                x = torch.concat((x, x1), dim=0)
if dim == 3:
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
#print('I', tuple(inputs[0].shape), tuple(x.shape))
ctx['skipped'] = False
return x, *rest
return pre_hook
def create_post_hook(name: str, ctx: dict):
def post_hook(module: nn.Module, inputs: list, output: Tensor):
if step < start_step or max_steps < step:
return
if ctx['skipped']:
return
x = output
dim = x.dim()
if dim == 3:
bo, no, cho = x.shape
wo, ho, No = self.get_size(p, no)
x = rearrange(x, 'b (h w) c -> b c h w', w=wo)
elif dim == 4:
bo, cho, ho, wo = x.shape
else:
return
                assert bo % 2 == 0
                # first half is the original batch; second half is the
                # upscaled-crop batch appended by the pre-hook
                x, x1 = x[:bo//2], x[bo//2:]
# downscaling
x1 = self.down(x1, M)
# embed
w, h = x1.shape[-1], x1.shape[-2]
s0, t0 = int(wo * self.x0), int(ho * self.y0)
s1, t1 = s0 + w, t0 + h
if wo < s1:
s1 = wo
s0 = s1 - w
if ho < t1:
t1 = ho
t0 = t1 - h
if s0 < 0 or t0 < 0:
raise ValueError(f'LLuL failed to process: s=({s0},{s1}), t=({t0},{t1})')
                # blend the processed, downscaled crop back into its window
                x[:, :, t0:t1, s0:s1] = self.interpolate(x[:, :, t0:t1, s0:s1], x1, self.weight)
if dim == 3:
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
#print('O', tuple(inputs[0].shape), tuple(x.shape))
return x
return post_hook
        def create_hook(name: str, **kwargs):
            # `ctx` is shared by the pre/post hook pair so that the post-hook
            # can tell whether the pre-hook actually modified the inputs
            ctx = dict()
            ctx.update(kwargs)
            return (
                create_pre_hook(name, ctx),
                create_post_hook(name, ctx)
            )
        def wrap_for_xattn(pre, post):
            # adapt a (pre, post) hook pair to a forward-replacement wrapper,
            # used for the cross-attention layers below
            def f(module: nn.Module, o: Callable, *args, **kwargs):
                inputs = list(args) + list(kwargs.values())
inputs_ = pre(module, inputs)
if inputs_ is not None:
inputs = inputs_
output = o(*inputs)
output_ = post(module, inputs, output)
if output_ is not None:
output = output_
return output
return f
#
# process each attention layers
#
for name, attn in each_unet_attn_layers(unet):
if self.layers is not None:
if not any(layer in name for layer in self.layers):
continue
            # to_q and to_k share an input width only for self-attention;
            # cross-attention keys come from the text context instead
            q_in = attn.to_q.in_features
            k_in = attn.to_k.in_features
            if q_in == k_in:
                # self-attention
if 's. attn.' in self.apply_to:
pre, post = create_hook(name)
self.hook_layer_pre(attn, pre)
self.hook_layer(attn, post)
else:
# cross-attention
if 'x. attn.' in self.apply_to:
pre, post = create_hook(name)
self.hook_forward(attn, wrap_for_xattn(pre, post))
#
# process Resblocks
#
for name, res in each_unet_resblock(unet):
if 'resblock' not in self.apply_to:
continue
if self.layers is not None:
if not any(layer in name for layer in self.layers):
continue
pre, post = create_hook(name)
self.hook_layer_pre(res, pre)
self.hook_layer(res, post)
#
# process Transformers (including s/x-attn)
#
for name, res in each_unet_transformers(unet):
if 'transformer' not in self.apply_to:
continue
if self.layers is not None:
if not any(layer in name for layer in self.layers):
continue
pre, post = create_hook(name)
self.hook_layer_pre(res, pre)
self.hook_layer(res, post)
#
# process OUT
#
if 'out' in self.apply_to:
out = unet.out
pre, post = create_hook('out')
self.hook_layer_pre(out, pre)
self.hook_layer(out, post)
    def get_size(self, p: StableDiffusionProcessing, n: int):
        # recover the spatial size (w, h) of an attention layer's feature map
        # from its token count n, where n == (p.width / N) * (p.height / N)
        # and N is the total downsampling factor at this depth
        wh = p.width * p.height
        N2 = wh // n
        N = int(math.sqrt(N2))
assert N*N == N2, f'N={N}, N2={N2}'
assert p.width % N == 0, f'width={p.width}, N={N}'
assert p.height % N == 0, f'height={p.height}, N={N}'
w, h = p.width // N, p.height // N
assert w * h == n, f'w={w}, h={h}, N={N}, n={n}'
return w, h, N
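
    # Worked example (assuming a 512x512 generation): an attention layer that
    # sees n = 4096 tokens gives N2 = 512*512 // 4096 = 64, hence N = 8 and a
    # spatial map of w = h = 512 // 8 = 64 (and indeed 64 * 64 == 4096).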
    def interpolate(self, v1: Tensor, v2: Tensor, t: float):
        # optionally compute in float32 for stability, casting back at the end
        dtype = v1.dtype
        if self.force_float:
            v1 = v1.float()
            v2 = v2.float()
        if self.mask_image is None:
            v = self.intp(v1, v2, t)
        else:
            # resize the mask to the feature map and use it as a per-pixel
            # interpolation weight scaled by t
            to_w, to_h = v1.shape[-1], v1.shape[-2]
            resized_image = self.mask_image.resize((to_w, to_h), Image.BILINEAR)
            mask = torchvision.transforms.functional.to_tensor(resized_image).to(device=v1.device, dtype=v1.dtype)
            mask.unsqueeze_(0) # (C,H,W) -> (B,C,H,W)
            mask.mul_(t)
            v = self.intp(v1, v2, mask)
        if self.force_float:
            v = v.to(dtype)
        return v
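
# A minimal construction sketch (illustrative; the values and the example layer
# name filter are assumptions, and the surrounding script that normally builds
# this object from the UI is not shown here):
#
#   hooker = Hooker(
#       enabled=True,
#       multiply=2,                         # must be a power of 2
#       weight=0.15,
#       layers=None,                        # or a name filter, e.g. ['input_blocks']
#       apply_to=['s. attn.', 'x. attn.'],
#       start_steps=5,
#       max_steps=30,
#       up_fn=Upscaler('bicubic', aa=True),
#       down_fn=Downscaler('pooling avg', aa=False),
#       intp='lerp',
#       x=0.25, y=0.25,                     # fractional position of the region
#       force_float=True,
#       mask_image=None,
#   )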