Fabrice-TIERCELIN
commited on
Upload 5 files
Browse files- SUPIR/utils/colorfix.py +120 -0
- SUPIR/utils/devices.py +138 -0
- SUPIR/utils/face_restoration_helper.py +514 -0
- SUPIR/utils/file.py +79 -0
- SUPIR/utils/tilevae.py +971 -0
SUPIR/utils/colorfix.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
# --------------------------------------------------------------------------------
|
3 |
+
# Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
|
4 |
+
# --------------------------------------------------------------------------------
|
5 |
+
'''
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from PIL import Image
|
9 |
+
from torch import Tensor
|
10 |
+
from torch.nn import functional as F
|
11 |
+
|
12 |
+
from torchvision.transforms import ToTensor, ToPILImage
|
13 |
+
|
14 |
+
def adain_color_fix(target: Image, source: Image):
|
15 |
+
# Convert images to tensors
|
16 |
+
to_tensor = ToTensor()
|
17 |
+
target_tensor = to_tensor(target).unsqueeze(0)
|
18 |
+
source_tensor = to_tensor(source).unsqueeze(0)
|
19 |
+
|
20 |
+
# Apply adaptive instance normalization
|
21 |
+
result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)
|
22 |
+
|
23 |
+
# Convert tensor back to image
|
24 |
+
to_image = ToPILImage()
|
25 |
+
result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
|
26 |
+
|
27 |
+
return result_image
|
28 |
+
|
29 |
+
def wavelet_color_fix(target: Image, source: Image):
|
30 |
+
# Convert images to tensors
|
31 |
+
to_tensor = ToTensor()
|
32 |
+
target_tensor = to_tensor(target).unsqueeze(0)
|
33 |
+
source_tensor = to_tensor(source).unsqueeze(0)
|
34 |
+
|
35 |
+
# Apply wavelet reconstruction
|
36 |
+
result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
|
37 |
+
|
38 |
+
# Convert tensor back to image
|
39 |
+
to_image = ToPILImage()
|
40 |
+
result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
|
41 |
+
|
42 |
+
return result_image
|
43 |
+
|
44 |
+
def calc_mean_std(feat: Tensor, eps=1e-5):
|
45 |
+
"""Calculate mean and std for adaptive_instance_normalization.
|
46 |
+
Args:
|
47 |
+
feat (Tensor): 4D tensor.
|
48 |
+
eps (float): A small value added to the variance to avoid
|
49 |
+
divide-by-zero. Default: 1e-5.
|
50 |
+
"""
|
51 |
+
size = feat.size()
|
52 |
+
assert len(size) == 4, 'The input feature should be 4D tensor.'
|
53 |
+
b, c = size[:2]
|
54 |
+
feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
|
55 |
+
feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
|
56 |
+
feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
|
57 |
+
return feat_mean, feat_std
|
58 |
+
|
59 |
+
def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
|
60 |
+
"""Adaptive instance normalization.
|
61 |
+
Adjust the reference features to have the similar color and illuminations
|
62 |
+
as those in the degradate features.
|
63 |
+
Args:
|
64 |
+
content_feat (Tensor): The reference feature.
|
65 |
+
style_feat (Tensor): The degradate features.
|
66 |
+
"""
|
67 |
+
size = content_feat.size()
|
68 |
+
style_mean, style_std = calc_mean_std(style_feat)
|
69 |
+
content_mean, content_std = calc_mean_std(content_feat)
|
70 |
+
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
71 |
+
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
72 |
+
|
73 |
+
def wavelet_blur(image: Tensor, radius: int):
|
74 |
+
"""
|
75 |
+
Apply wavelet blur to the input tensor.
|
76 |
+
"""
|
77 |
+
# input shape: (1, 3, H, W)
|
78 |
+
# convolution kernel
|
79 |
+
kernel_vals = [
|
80 |
+
[0.0625, 0.125, 0.0625],
|
81 |
+
[0.125, 0.25, 0.125],
|
82 |
+
[0.0625, 0.125, 0.0625],
|
83 |
+
]
|
84 |
+
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
|
85 |
+
# add channel dimensions to the kernel to make it a 4D tensor
|
86 |
+
kernel = kernel[None, None]
|
87 |
+
# repeat the kernel across all input channels
|
88 |
+
kernel = kernel.repeat(3, 1, 1, 1)
|
89 |
+
image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
|
90 |
+
# apply convolution
|
91 |
+
output = F.conv2d(image, kernel, groups=3, dilation=radius)
|
92 |
+
return output
|
93 |
+
|
94 |
+
def wavelet_decomposition(image: Tensor, levels=5):
|
95 |
+
"""
|
96 |
+
Apply wavelet decomposition to the input tensor.
|
97 |
+
This function only returns the low frequency & the high frequency.
|
98 |
+
"""
|
99 |
+
high_freq = torch.zeros_like(image)
|
100 |
+
for i in range(levels):
|
101 |
+
radius = 2 ** i
|
102 |
+
low_freq = wavelet_blur(image, radius)
|
103 |
+
high_freq += (image - low_freq)
|
104 |
+
image = low_freq
|
105 |
+
|
106 |
+
return high_freq, low_freq
|
107 |
+
|
108 |
+
def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
|
109 |
+
"""
|
110 |
+
Apply wavelet decomposition, so that the content will have the same color as the style.
|
111 |
+
"""
|
112 |
+
# calculate the wavelet decomposition of the content feature
|
113 |
+
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
|
114 |
+
del content_low_freq
|
115 |
+
# calculate the wavelet decomposition of the style feature
|
116 |
+
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
|
117 |
+
del style_high_freq
|
118 |
+
# reconstruct the content feature with the style's high frequency
|
119 |
+
return content_high_freq + style_low_freq
|
120 |
+
|
SUPIR/utils/devices.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import contextlib
|
3 |
+
from functools import lru_cache
|
4 |
+
|
5 |
+
import torch
|
6 |
+
#from modules import errors
|
7 |
+
|
8 |
+
if sys.platform == "darwin":
|
9 |
+
from modules import mac_specific
|
10 |
+
|
11 |
+
|
12 |
+
def has_mps() -> bool:
|
13 |
+
if sys.platform != "darwin":
|
14 |
+
return False
|
15 |
+
else:
|
16 |
+
return mac_specific.has_mps
|
17 |
+
|
18 |
+
|
19 |
+
def get_cuda_device_string():
|
20 |
+
return "cuda"
|
21 |
+
|
22 |
+
|
23 |
+
def get_optimal_device_name():
|
24 |
+
if torch.cuda.is_available():
|
25 |
+
return get_cuda_device_string()
|
26 |
+
|
27 |
+
if has_mps():
|
28 |
+
return "mps"
|
29 |
+
|
30 |
+
return "cpu"
|
31 |
+
|
32 |
+
|
33 |
+
def get_optimal_device():
|
34 |
+
return torch.device(get_optimal_device_name())
|
35 |
+
|
36 |
+
|
37 |
+
def get_device_for(task):
|
38 |
+
return get_optimal_device()
|
39 |
+
|
40 |
+
|
41 |
+
def torch_gc():
|
42 |
+
|
43 |
+
if torch.cuda.is_available():
|
44 |
+
with torch.cuda.device(get_cuda_device_string()):
|
45 |
+
torch.cuda.empty_cache()
|
46 |
+
torch.cuda.ipc_collect()
|
47 |
+
|
48 |
+
if has_mps():
|
49 |
+
mac_specific.torch_mps_gc()
|
50 |
+
|
51 |
+
|
52 |
+
def enable_tf32():
|
53 |
+
if torch.cuda.is_available():
|
54 |
+
|
55 |
+
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
|
56 |
+
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
|
57 |
+
if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
|
58 |
+
torch.backends.cudnn.benchmark = True
|
59 |
+
|
60 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
61 |
+
torch.backends.cudnn.allow_tf32 = True
|
62 |
+
|
63 |
+
|
64 |
+
enable_tf32()
|
65 |
+
#errors.run(enable_tf32, "Enabling TF32")
|
66 |
+
|
67 |
+
cpu = torch.device("cpu")
|
68 |
+
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = torch.device("cuda")
|
69 |
+
dtype = torch.float16
|
70 |
+
dtype_vae = torch.float16
|
71 |
+
dtype_unet = torch.float16
|
72 |
+
unet_needs_upcast = False
|
73 |
+
|
74 |
+
|
75 |
+
def cond_cast_unet(input):
|
76 |
+
return input.to(dtype_unet) if unet_needs_upcast else input
|
77 |
+
|
78 |
+
|
79 |
+
def cond_cast_float(input):
|
80 |
+
return input.float() if unet_needs_upcast else input
|
81 |
+
|
82 |
+
|
83 |
+
def randn(seed, shape):
|
84 |
+
torch.manual_seed(seed)
|
85 |
+
return torch.randn(shape, device=device)
|
86 |
+
|
87 |
+
|
88 |
+
def randn_without_seed(shape):
|
89 |
+
return torch.randn(shape, device=device)
|
90 |
+
|
91 |
+
|
92 |
+
def autocast(disable=False):
|
93 |
+
if disable:
|
94 |
+
return contextlib.nullcontext()
|
95 |
+
|
96 |
+
return torch.autocast("cuda")
|
97 |
+
|
98 |
+
|
99 |
+
def without_autocast(disable=False):
|
100 |
+
return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
|
101 |
+
|
102 |
+
|
103 |
+
class NansException(Exception):
|
104 |
+
pass
|
105 |
+
|
106 |
+
|
107 |
+
def test_for_nans(x, where):
|
108 |
+
if not torch.all(torch.isnan(x)).item():
|
109 |
+
return
|
110 |
+
|
111 |
+
if where == "unet":
|
112 |
+
message = "A tensor with all NaNs was produced in Unet."
|
113 |
+
|
114 |
+
elif where == "vae":
|
115 |
+
message = "A tensor with all NaNs was produced in VAE."
|
116 |
+
|
117 |
+
else:
|
118 |
+
message = "A tensor with all NaNs was produced."
|
119 |
+
|
120 |
+
message += " Use --disable-nan-check commandline argument to disable this check."
|
121 |
+
|
122 |
+
raise NansException(message)
|
123 |
+
|
124 |
+
|
125 |
+
@lru_cache
|
126 |
+
def first_time_calculation():
|
127 |
+
"""
|
128 |
+
just do any calculation with pytorch layers - the first time this is done it allocaltes about 700MB of memory and
|
129 |
+
spends about 2.7 seconds doing that, at least wih NVidia.
|
130 |
+
"""
|
131 |
+
|
132 |
+
x = torch.zeros((1, 1)).to(device, dtype)
|
133 |
+
linear = torch.nn.Linear(1, 1).to(device, dtype)
|
134 |
+
linear(x)
|
135 |
+
|
136 |
+
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
|
137 |
+
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
|
138 |
+
conv2d(x)
|
SUPIR/utils/face_restoration_helper.py
ADDED
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
from torchvision.transforms.functional import normalize
|
6 |
+
|
7 |
+
from facexlib.detection import init_detection_model
|
8 |
+
from facexlib.parsing import init_parsing_model
|
9 |
+
from facexlib.utils.misc import img2tensor, imwrite
|
10 |
+
|
11 |
+
from .file import load_file_from_url
|
12 |
+
|
13 |
+
|
14 |
+
def get_largest_face(det_faces, h, w):
|
15 |
+
def get_location(val, length):
|
16 |
+
if val < 0:
|
17 |
+
return 0
|
18 |
+
elif val > length:
|
19 |
+
return length
|
20 |
+
else:
|
21 |
+
return val
|
22 |
+
|
23 |
+
face_areas = []
|
24 |
+
for det_face in det_faces:
|
25 |
+
left = get_location(det_face[0], w)
|
26 |
+
right = get_location(det_face[2], w)
|
27 |
+
top = get_location(det_face[1], h)
|
28 |
+
bottom = get_location(det_face[3], h)
|
29 |
+
face_area = (right - left) * (bottom - top)
|
30 |
+
face_areas.append(face_area)
|
31 |
+
largest_idx = face_areas.index(max(face_areas))
|
32 |
+
return det_faces[largest_idx], largest_idx
|
33 |
+
|
34 |
+
|
35 |
+
def get_center_face(det_faces, h=0, w=0, center=None):
|
36 |
+
if center is not None:
|
37 |
+
center = np.array(center)
|
38 |
+
else:
|
39 |
+
center = np.array([w / 2, h / 2])
|
40 |
+
center_dist = []
|
41 |
+
for det_face in det_faces:
|
42 |
+
face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2])
|
43 |
+
dist = np.linalg.norm(face_center - center)
|
44 |
+
center_dist.append(dist)
|
45 |
+
center_idx = center_dist.index(min(center_dist))
|
46 |
+
return det_faces[center_idx], center_idx
|
47 |
+
|
48 |
+
|
49 |
+
class FaceRestoreHelper(object):
|
50 |
+
"""Helper for the face restoration pipeline (base class)."""
|
51 |
+
|
52 |
+
def __init__(self,
|
53 |
+
upscale_factor,
|
54 |
+
face_size=512,
|
55 |
+
crop_ratio=(1, 1),
|
56 |
+
det_model='retinaface_resnet50',
|
57 |
+
save_ext='png',
|
58 |
+
template_3points=False,
|
59 |
+
pad_blur=False,
|
60 |
+
use_parse=False,
|
61 |
+
device=None):
|
62 |
+
self.template_3points = template_3points # improve robustness
|
63 |
+
self.upscale_factor = int(upscale_factor)
|
64 |
+
# the cropped face ratio based on the square face
|
65 |
+
self.crop_ratio = crop_ratio # (h, w)
|
66 |
+
assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1'
|
67 |
+
self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
|
68 |
+
self.det_model = det_model
|
69 |
+
|
70 |
+
if self.det_model == 'dlib':
|
71 |
+
# standard 5 landmarks for FFHQ faces with 1024 x 1024
|
72 |
+
self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941],
|
73 |
+
[337.91089109, 488.38613861], [437.95049505, 493.51485149],
|
74 |
+
[513.58415842, 678.5049505]])
|
75 |
+
self.face_template = self.face_template / (1024 // face_size)
|
76 |
+
elif self.template_3points:
|
77 |
+
self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
|
78 |
+
else:
|
79 |
+
# standard 5 landmarks for FFHQ faces with 512 x 512
|
80 |
+
# facexlib
|
81 |
+
self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
|
82 |
+
[201.26117, 371.41043], [313.08905, 371.15118]])
|
83 |
+
|
84 |
+
# dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54
|
85 |
+
# self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
|
86 |
+
# [198.22603, 372.82502], [313.91018, 372.75659]])
|
87 |
+
|
88 |
+
self.face_template = self.face_template * (face_size / 512.0)
|
89 |
+
if self.crop_ratio[0] > 1:
|
90 |
+
self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
|
91 |
+
if self.crop_ratio[1] > 1:
|
92 |
+
self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
|
93 |
+
self.save_ext = save_ext
|
94 |
+
self.pad_blur = pad_blur
|
95 |
+
if self.pad_blur is True:
|
96 |
+
self.template_3points = False
|
97 |
+
|
98 |
+
self.all_landmarks_5 = []
|
99 |
+
self.det_faces = []
|
100 |
+
self.affine_matrices = []
|
101 |
+
self.inverse_affine_matrices = []
|
102 |
+
self.cropped_faces = []
|
103 |
+
self.restored_faces = []
|
104 |
+
self.pad_input_imgs = []
|
105 |
+
|
106 |
+
if device is None:
|
107 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
108 |
+
# self.device = get_device()
|
109 |
+
else:
|
110 |
+
self.device = device
|
111 |
+
|
112 |
+
# init face detection model
|
113 |
+
self.face_detector = init_detection_model(det_model, half=False, device=self.device)
|
114 |
+
|
115 |
+
# init face parsing model
|
116 |
+
self.use_parse = use_parse
|
117 |
+
self.face_parse = init_parsing_model(model_name='parsenet', device=self.device)
|
118 |
+
|
119 |
+
def set_upscale_factor(self, upscale_factor):
|
120 |
+
self.upscale_factor = upscale_factor
|
121 |
+
|
122 |
+
def read_image(self, img):
|
123 |
+
"""img can be image path or cv2 loaded image."""
|
124 |
+
# self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
|
125 |
+
if isinstance(img, str):
|
126 |
+
img = cv2.imread(img)
|
127 |
+
|
128 |
+
if np.max(img) > 256: # 16-bit image
|
129 |
+
img = img / 65535 * 255
|
130 |
+
if len(img.shape) == 2: # gray image
|
131 |
+
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
132 |
+
elif img.shape[2] == 4: # BGRA image with alpha channel
|
133 |
+
img = img[:, :, 0:3]
|
134 |
+
|
135 |
+
self.input_img = img
|
136 |
+
# self.is_gray = is_gray(img, threshold=10)
|
137 |
+
# if self.is_gray:
|
138 |
+
# print('Grayscale input: True')
|
139 |
+
|
140 |
+
if min(self.input_img.shape[:2]) < 512:
|
141 |
+
f = 512.0 / min(self.input_img.shape[:2])
|
142 |
+
self.input_img = cv2.resize(self.input_img, (0, 0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)
|
143 |
+
|
144 |
+
def init_dlib(self, detection_path, landmark5_path):
|
145 |
+
"""Initialize the dlib detectors and predictors."""
|
146 |
+
try:
|
147 |
+
import dlib
|
148 |
+
except ImportError:
|
149 |
+
print('Please install dlib by running:' 'conda install -c conda-forge dlib')
|
150 |
+
detection_path = load_file_from_url(url=detection_path, model_dir='weights/dlib', progress=True, file_name=None)
|
151 |
+
landmark5_path = load_file_from_url(url=landmark5_path, model_dir='weights/dlib', progress=True, file_name=None)
|
152 |
+
face_detector = dlib.cnn_face_detection_model_v1(detection_path)
|
153 |
+
shape_predictor_5 = dlib.shape_predictor(landmark5_path)
|
154 |
+
return face_detector, shape_predictor_5
|
155 |
+
|
156 |
+
def get_face_landmarks_5_dlib(self,
|
157 |
+
only_keep_largest=False,
|
158 |
+
scale=1):
|
159 |
+
det_faces = self.face_detector(self.input_img, scale)
|
160 |
+
|
161 |
+
if len(det_faces) == 0:
|
162 |
+
print('No face detected. Try to increase upsample_num_times.')
|
163 |
+
return 0
|
164 |
+
else:
|
165 |
+
if only_keep_largest:
|
166 |
+
print('Detect several faces and only keep the largest.')
|
167 |
+
face_areas = []
|
168 |
+
for i in range(len(det_faces)):
|
169 |
+
face_area = (det_faces[i].rect.right() - det_faces[i].rect.left()) * (
|
170 |
+
det_faces[i].rect.bottom() - det_faces[i].rect.top())
|
171 |
+
face_areas.append(face_area)
|
172 |
+
largest_idx = face_areas.index(max(face_areas))
|
173 |
+
self.det_faces = [det_faces[largest_idx]]
|
174 |
+
else:
|
175 |
+
self.det_faces = det_faces
|
176 |
+
|
177 |
+
if len(self.det_faces) == 0:
|
178 |
+
return 0
|
179 |
+
|
180 |
+
for face in self.det_faces:
|
181 |
+
shape = self.shape_predictor_5(self.input_img, face.rect)
|
182 |
+
landmark = np.array([[part.x, part.y] for part in shape.parts()])
|
183 |
+
self.all_landmarks_5.append(landmark)
|
184 |
+
|
185 |
+
return len(self.all_landmarks_5)
|
186 |
+
|
187 |
+
def get_face_landmarks_5(self,
|
188 |
+
only_keep_largest=False,
|
189 |
+
only_center_face=False,
|
190 |
+
resize=None,
|
191 |
+
blur_ratio=0.01,
|
192 |
+
eye_dist_threshold=None):
|
193 |
+
if self.det_model == 'dlib':
|
194 |
+
return self.get_face_landmarks_5_dlib(only_keep_largest)
|
195 |
+
|
196 |
+
if resize is None:
|
197 |
+
scale = 1
|
198 |
+
input_img = self.input_img
|
199 |
+
else:
|
200 |
+
h, w = self.input_img.shape[0:2]
|
201 |
+
scale = resize / min(h, w)
|
202 |
+
scale = max(1, scale) # always scale up
|
203 |
+
h, w = int(h * scale), int(w * scale)
|
204 |
+
interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
|
205 |
+
input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)
|
206 |
+
|
207 |
+
with torch.no_grad():
|
208 |
+
bboxes = self.face_detector.detect_faces(input_img)
|
209 |
+
|
210 |
+
if bboxes is None or bboxes.shape[0] == 0:
|
211 |
+
return 0
|
212 |
+
else:
|
213 |
+
bboxes = bboxes / scale
|
214 |
+
|
215 |
+
for bbox in bboxes:
|
216 |
+
# remove faces with too small eye distance: side faces or too small faces
|
217 |
+
eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]])
|
218 |
+
if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
|
219 |
+
continue
|
220 |
+
|
221 |
+
if self.template_3points:
|
222 |
+
landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
|
223 |
+
else:
|
224 |
+
landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
|
225 |
+
self.all_landmarks_5.append(landmark)
|
226 |
+
self.det_faces.append(bbox[0:5])
|
227 |
+
|
228 |
+
if len(self.det_faces) == 0:
|
229 |
+
return 0
|
230 |
+
if only_keep_largest:
|
231 |
+
h, w, _ = self.input_img.shape
|
232 |
+
self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
|
233 |
+
self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
|
234 |
+
elif only_center_face:
|
235 |
+
h, w, _ = self.input_img.shape
|
236 |
+
self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
|
237 |
+
self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]
|
238 |
+
|
239 |
+
# pad blurry images
|
240 |
+
if self.pad_blur:
|
241 |
+
self.pad_input_imgs = []
|
242 |
+
for landmarks in self.all_landmarks_5:
|
243 |
+
# get landmarks
|
244 |
+
eye_left = landmarks[0, :]
|
245 |
+
eye_right = landmarks[1, :]
|
246 |
+
eye_avg = (eye_left + eye_right) * 0.5
|
247 |
+
mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
|
248 |
+
eye_to_eye = eye_right - eye_left
|
249 |
+
eye_to_mouth = mouth_avg - eye_avg
|
250 |
+
|
251 |
+
# Get the oriented crop rectangle
|
252 |
+
# x: half width of the oriented crop rectangle
|
253 |
+
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
|
254 |
+
# - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
|
255 |
+
# norm with the hypotenuse: get the direction
|
256 |
+
x /= np.hypot(*x) # get the hypotenuse of a right triangle
|
257 |
+
rect_scale = 1.5
|
258 |
+
x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
|
259 |
+
# y: half height of the oriented crop rectangle
|
260 |
+
y = np.flipud(x) * [-1, 1]
|
261 |
+
|
262 |
+
# c: center
|
263 |
+
c = eye_avg + eye_to_mouth * 0.1
|
264 |
+
# quad: (left_top, left_bottom, right_bottom, right_top)
|
265 |
+
quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
|
266 |
+
# qsize: side length of the square
|
267 |
+
qsize = np.hypot(*x) * 2
|
268 |
+
border = max(int(np.rint(qsize * 0.1)), 3)
|
269 |
+
|
270 |
+
# get pad
|
271 |
+
# pad: (width_left, height_top, width_right, height_bottom)
|
272 |
+
pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
|
273 |
+
int(np.ceil(max(quad[:, 1]))))
|
274 |
+
pad = [
|
275 |
+
max(-pad[0] + border, 1),
|
276 |
+
max(-pad[1] + border, 1),
|
277 |
+
max(pad[2] - self.input_img.shape[0] + border, 1),
|
278 |
+
max(pad[3] - self.input_img.shape[1] + border, 1)
|
279 |
+
]
|
280 |
+
|
281 |
+
if max(pad) > 1:
|
282 |
+
# pad image
|
283 |
+
pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
|
284 |
+
# modify landmark coords
|
285 |
+
landmarks[:, 0] += pad[0]
|
286 |
+
landmarks[:, 1] += pad[1]
|
287 |
+
# blur pad images
|
288 |
+
h, w, _ = pad_img.shape
|
289 |
+
y, x, _ = np.ogrid[:h, :w, :1]
|
290 |
+
mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
|
291 |
+
np.float32(w - 1 - x) / pad[2]),
|
292 |
+
1.0 - np.minimum(np.float32(y) / pad[1],
|
293 |
+
np.float32(h - 1 - y) / pad[3]))
|
294 |
+
blur = int(qsize * blur_ratio)
|
295 |
+
if blur % 2 == 0:
|
296 |
+
blur += 1
|
297 |
+
blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
|
298 |
+
# blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)
|
299 |
+
|
300 |
+
pad_img = pad_img.astype('float32')
|
301 |
+
pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
|
302 |
+
pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
|
303 |
+
pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255]
|
304 |
+
self.pad_input_imgs.append(pad_img)
|
305 |
+
else:
|
306 |
+
self.pad_input_imgs.append(np.copy(self.input_img))
|
307 |
+
|
308 |
+
return len(self.all_landmarks_5)
|
309 |
+
|
310 |
+
def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
|
311 |
+
"""Align and warp faces with face template.
|
312 |
+
"""
|
313 |
+
if self.pad_blur:
|
314 |
+
assert len(self.pad_input_imgs) == len(
|
315 |
+
self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
|
316 |
+
for idx, landmark in enumerate(self.all_landmarks_5):
|
317 |
+
# use 5 landmarks to get affine matrix
|
318 |
+
# use cv2.LMEDS method for the equivalence to skimage transform
|
319 |
+
# ref: https://blog.csdn.net/yichxi/article/details/115827338
|
320 |
+
affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
|
321 |
+
self.affine_matrices.append(affine_matrix)
|
322 |
+
# warp and crop faces
|
323 |
+
if border_mode == 'constant':
|
324 |
+
border_mode = cv2.BORDER_CONSTANT
|
325 |
+
elif border_mode == 'reflect101':
|
326 |
+
border_mode = cv2.BORDER_REFLECT101
|
327 |
+
elif border_mode == 'reflect':
|
328 |
+
border_mode = cv2.BORDER_REFLECT
|
329 |
+
if self.pad_blur:
|
330 |
+
input_img = self.pad_input_imgs[idx]
|
331 |
+
else:
|
332 |
+
input_img = self.input_img
|
333 |
+
cropped_face = cv2.warpAffine(
|
334 |
+
input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)) # gray
|
335 |
+
self.cropped_faces.append(cropped_face)
|
336 |
+
# save the cropped face
|
337 |
+
if save_cropped_path is not None:
|
338 |
+
path = os.path.splitext(save_cropped_path)[0]
|
339 |
+
save_path = f'{path}_{idx:02d}.{self.save_ext}'
|
340 |
+
imwrite(cropped_face, save_path)
|
341 |
+
|
342 |
+
def get_inverse_affine(self, save_inverse_affine_path=None):
|
343 |
+
"""Get inverse affine matrix."""
|
344 |
+
for idx, affine_matrix in enumerate(self.affine_matrices):
|
345 |
+
inverse_affine = cv2.invertAffineTransform(affine_matrix)
|
346 |
+
inverse_affine *= self.upscale_factor
|
347 |
+
self.inverse_affine_matrices.append(inverse_affine)
|
348 |
+
# save inverse affine matrices
|
349 |
+
if save_inverse_affine_path is not None:
|
350 |
+
path, _ = os.path.splitext(save_inverse_affine_path)
|
351 |
+
save_path = f'{path}_{idx:02d}.pth'
|
352 |
+
torch.save(inverse_affine, save_path)
|
353 |
+
|
354 |
+
def add_restored_face(self, restored_face, input_face=None):
|
355 |
+
# if self.is_gray:
|
356 |
+
# restored_face = bgr2gray(restored_face) # convert img into grayscale
|
357 |
+
# if input_face is not None:
|
358 |
+
# restored_face = adain_npy(restored_face, input_face) # transfer the color
|
359 |
+
self.restored_faces.append(restored_face)
|
360 |
+
|
361 |
+
def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
|
362 |
+
h, w, _ = self.input_img.shape
|
363 |
+
h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
|
364 |
+
|
365 |
+
if upsample_img is None:
|
366 |
+
# simply resize the background
|
367 |
+
# upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
|
368 |
+
upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR)
|
369 |
+
else:
|
370 |
+
upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
|
371 |
+
|
372 |
+
assert len(self.restored_faces) == len(
|
373 |
+
self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.')
|
374 |
+
|
375 |
+
inv_mask_borders = []
|
376 |
+
for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
|
377 |
+
if face_upsampler is not None:
|
378 |
+
restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0]
|
379 |
+
inverse_affine /= self.upscale_factor
|
380 |
+
inverse_affine[:, 2] *= self.upscale_factor
|
381 |
+
face_size = (self.face_size[0] * self.upscale_factor, self.face_size[1] * self.upscale_factor)
|
382 |
+
else:
|
383 |
+
# Add an offset to inverse affine matrix, for more precise back alignment
|
384 |
+
if self.upscale_factor > 1:
|
385 |
+
extra_offset = 0.5 * self.upscale_factor
|
386 |
+
else:
|
387 |
+
extra_offset = 0
|
388 |
+
inverse_affine[:, 2] += extra_offset
|
389 |
+
face_size = self.face_size
|
390 |
+
inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))
|
391 |
+
|
392 |
+
# if draw_box or not self.use_parse: # use square parse maps
|
393 |
+
# mask = np.ones(face_size, dtype=np.float32)
|
394 |
+
# inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
|
395 |
+
# # remove the black borders
|
396 |
+
# inv_mask_erosion = cv2.erode(
|
397 |
+
# inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
|
398 |
+
# pasted_face = inv_mask_erosion[:, :, None] * inv_restored
|
399 |
+
# total_face_area = np.sum(inv_mask_erosion) # // 3
|
400 |
+
# # add border
|
401 |
+
# if draw_box:
|
402 |
+
# h, w = face_size
|
403 |
+
# mask_border = np.ones((h, w, 3), dtype=np.float32)
|
404 |
+
# border = int(1400/np.sqrt(total_face_area))
|
405 |
+
# mask_border[border:h-border, border:w-border,:] = 0
|
406 |
+
# inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
|
407 |
+
# inv_mask_borders.append(inv_mask_border)
|
408 |
+
# if not self.use_parse:
|
409 |
+
# # compute the fusion edge based on the area of face
|
410 |
+
# w_edge = int(total_face_area**0.5) // 20
|
411 |
+
# erosion_radius = w_edge * 2
|
412 |
+
# inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
|
413 |
+
# blur_size = w_edge * 2
|
414 |
+
# inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
|
415 |
+
# if len(upsample_img.shape) == 2: # upsample_img is gray image
|
416 |
+
# upsample_img = upsample_img[:, :, None]
|
417 |
+
# inv_soft_mask = inv_soft_mask[:, :, None]
|
418 |
+
|
419 |
+
# always use square mask
|
420 |
+
mask = np.ones(face_size, dtype=np.float32)
|
421 |
+
inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
|
422 |
+
# remove the black borders
|
423 |
+
inv_mask_erosion = cv2.erode(
|
424 |
+
inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
|
425 |
+
pasted_face = inv_mask_erosion[:, :, None] * inv_restored
|
426 |
+
total_face_area = np.sum(inv_mask_erosion) # // 3
|
427 |
+
# add border
|
428 |
+
if draw_box:
|
429 |
+
h, w = face_size
|
430 |
+
mask_border = np.ones((h, w, 3), dtype=np.float32)
|
431 |
+
border = int(1400 / np.sqrt(total_face_area))
|
432 |
+
mask_border[border:h - border, border:w - border, :] = 0
|
433 |
+
inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
|
434 |
+
inv_mask_borders.append(inv_mask_border)
|
435 |
+
# compute the fusion edge based on the area of face
|
436 |
+
w_edge = int(total_face_area ** 0.5) // 20
|
437 |
+
erosion_radius = w_edge * 2
|
438 |
+
inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
|
439 |
+
blur_size = w_edge * 2
|
440 |
+
inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
|
441 |
+
if len(upsample_img.shape) == 2: # upsample_img is gray image
|
442 |
+
upsample_img = upsample_img[:, :, None]
|
443 |
+
inv_soft_mask = inv_soft_mask[:, :, None]
|
444 |
+
|
445 |
+
# parse mask
|
446 |
+
if self.use_parse:
|
447 |
+
# inference
|
448 |
+
face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
|
449 |
+
face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
|
450 |
+
normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
451 |
+
face_input = torch.unsqueeze(face_input, 0).to(self.device)
|
452 |
+
with torch.no_grad():
|
453 |
+
out = self.face_parse(face_input)[0]
|
454 |
+
out = out.argmax(dim=1).squeeze().cpu().numpy()
|
455 |
+
|
456 |
+
parse_mask = np.zeros(out.shape)
|
457 |
+
MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
|
458 |
+
for idx, color in enumerate(MASK_COLORMAP):
|
459 |
+
parse_mask[out == idx] = color
|
460 |
+
# blur the mask
|
461 |
+
parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
|
462 |
+
parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
|
463 |
+
# remove the black borders
|
464 |
+
thres = 10
|
465 |
+
parse_mask[:thres, :] = 0
|
466 |
+
parse_mask[-thres:, :] = 0
|
467 |
+
parse_mask[:, :thres] = 0
|
468 |
+
parse_mask[:, -thres:] = 0
|
469 |
+
parse_mask = parse_mask / 255.
|
470 |
+
|
471 |
+
parse_mask = cv2.resize(parse_mask, face_size)
|
472 |
+
parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3)
|
473 |
+
inv_soft_parse_mask = parse_mask[:, :, None]
|
474 |
+
# pasted_face = inv_restored
|
475 |
+
fuse_mask = (inv_soft_parse_mask < inv_soft_mask).astype('int')
|
476 |
+
inv_soft_mask = inv_soft_parse_mask * fuse_mask + inv_soft_mask * (1 - fuse_mask)
|
477 |
+
|
478 |
+
if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4: # alpha channel
|
479 |
+
alpha = upsample_img[:, :, 3:]
|
480 |
+
upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
|
481 |
+
upsample_img = np.concatenate((upsample_img, alpha), axis=2)
|
482 |
+
else:
|
483 |
+
upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
|
484 |
+
|
485 |
+
if np.max(upsample_img) > 256: # 16-bit image
|
486 |
+
upsample_img = upsample_img.astype(np.uint16)
|
487 |
+
else:
|
488 |
+
upsample_img = upsample_img.astype(np.uint8)
|
489 |
+
|
490 |
+
# draw bounding box
|
491 |
+
if draw_box:
|
492 |
+
# upsample_input_img = cv2.resize(input_img, (w_up, h_up))
|
493 |
+
img_color = np.ones([*upsample_img.shape], dtype=np.float32)
|
494 |
+
img_color[:, :, 0] = 0
|
495 |
+
img_color[:, :, 1] = 255
|
496 |
+
img_color[:, :, 2] = 0
|
497 |
+
for inv_mask_border in inv_mask_borders:
|
498 |
+
upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img
|
499 |
+
# upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img
|
500 |
+
|
501 |
+
if save_path is not None:
|
502 |
+
path = os.path.splitext(save_path)[0]
|
503 |
+
save_path = f'{path}.{self.save_ext}'
|
504 |
+
imwrite(upsample_img, save_path)
|
505 |
+
return upsample_img
|
506 |
+
|
507 |
+
def clean_all(self):
|
508 |
+
self.all_landmarks_5 = []
|
509 |
+
self.restored_faces = []
|
510 |
+
self.affine_matrices = []
|
511 |
+
self.cropped_faces = []
|
512 |
+
self.inverse_affine_matrices = []
|
513 |
+
self.det_faces = []
|
514 |
+
self.pad_input_imgs = []
|
SUPIR/utils/file.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List, Tuple
|
3 |
+
|
4 |
+
from urllib.parse import urlparse
|
5 |
+
from torch.hub import download_url_to_file, get_dir
|
6 |
+
|
7 |
+
|
8 |
+
def load_file_list(file_list_path: str) -> List[str]:
|
9 |
+
files = []
|
10 |
+
# each line in file list contains a path of an image
|
11 |
+
with open(file_list_path, "r") as fin:
|
12 |
+
for line in fin:
|
13 |
+
path = line.strip()
|
14 |
+
if path:
|
15 |
+
files.append(path)
|
16 |
+
return files
|
17 |
+
|
18 |
+
|
19 |
+
def list_image_files(
|
20 |
+
img_dir: str,
|
21 |
+
exts: Tuple[str]=(".jpg", ".png", ".jpeg"),
|
22 |
+
follow_links: bool=False,
|
23 |
+
log_progress: bool=False,
|
24 |
+
log_every_n_files: int=10000,
|
25 |
+
max_size: int=-1
|
26 |
+
) -> List[str]:
|
27 |
+
files = []
|
28 |
+
for dir_path, _, file_names in os.walk(img_dir, followlinks=follow_links):
|
29 |
+
early_stop = False
|
30 |
+
for file_name in file_names:
|
31 |
+
if os.path.splitext(file_name)[1].lower() in exts:
|
32 |
+
if max_size >= 0 and len(files) >= max_size:
|
33 |
+
early_stop = True
|
34 |
+
break
|
35 |
+
files.append(os.path.join(dir_path, file_name))
|
36 |
+
if log_progress and len(files) % log_every_n_files == 0:
|
37 |
+
print(f"find {len(files)} images in {img_dir}")
|
38 |
+
if early_stop:
|
39 |
+
break
|
40 |
+
return files
|
41 |
+
|
42 |
+
|
43 |
+
def get_file_name_parts(file_path: str) -> Tuple[str, str, str]:
|
44 |
+
parent_path, file_name = os.path.split(file_path)
|
45 |
+
stem, ext = os.path.splitext(file_name)
|
46 |
+
return parent_path, stem, ext
|
47 |
+
|
48 |
+
|
49 |
+
# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/download_util.py/
|
50 |
+
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
|
51 |
+
"""Load file form http url, will download models if necessary.
|
52 |
+
|
53 |
+
Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
|
54 |
+
|
55 |
+
Args:
|
56 |
+
url (str): URL to be downloaded.
|
57 |
+
model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
|
58 |
+
Default: None.
|
59 |
+
progress (bool): Whether to show the download progress. Default: True.
|
60 |
+
file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
str: The path to the downloaded file.
|
64 |
+
"""
|
65 |
+
if model_dir is None: # use the pytorch hub_dir
|
66 |
+
hub_dir = get_dir()
|
67 |
+
model_dir = os.path.join(hub_dir, 'checkpoints')
|
68 |
+
|
69 |
+
os.makedirs(model_dir, exist_ok=True)
|
70 |
+
|
71 |
+
parts = urlparse(url)
|
72 |
+
filename = os.path.basename(parts.path)
|
73 |
+
if file_name is not None:
|
74 |
+
filename = file_name
|
75 |
+
cached_file = os.path.abspath(os.path.join(model_dir, filename))
|
76 |
+
if not os.path.exists(cached_file):
|
77 |
+
print(f'Downloading: "{url}" to {cached_file}\n')
|
78 |
+
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
|
79 |
+
return cached_file
|
SUPIR/utils/tilevae.py
ADDED
@@ -0,0 +1,971 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
#
|
3 |
+
# Ultimate VAE Tile Optimization
|
4 |
+
#
|
5 |
+
# Introducing a revolutionary new optimization designed to make
|
6 |
+
# the VAE work with giant images on limited VRAM!
|
7 |
+
# Say goodbye to the frustration of OOM and hello to seamless output!
|
8 |
+
#
|
9 |
+
# ------------------------------------------------------------------------
|
10 |
+
#
|
11 |
+
# This script is a wild hack that splits the image into tiles,
|
12 |
+
# encodes each tile separately, and merges the result back together.
|
13 |
+
#
|
14 |
+
# Advantages:
|
15 |
+
# - The VAE can now work with giant images on limited VRAM
|
16 |
+
# (~10 GB for 8K images!)
|
17 |
+
# - The merged output is completely seamless without any post-processing.
|
18 |
+
#
|
19 |
+
# Drawbacks:
|
20 |
+
# - Giant RAM needed. To store the intermediate results for a 4096x4096
|
21 |
+
# images, you need 32 GB RAM it consumes ~20GB); for 8192x8192
|
22 |
+
# you need 128 GB RAM machine (it consumes ~100 GB)
|
23 |
+
# - NaNs always appear in for 8k images when you use fp16 (half) VAE
|
24 |
+
# You must use --no-half-vae to disable half VAE for that giant image.
|
25 |
+
# - Slow speed. With default tile size, it takes around 50/200 seconds
|
26 |
+
# to encode/decode a 4096x4096 image; and 200/900 seconds to encode/decode
|
27 |
+
# a 8192x8192 image. (The speed is limited by both the GPU and the CPU.)
|
28 |
+
# - The gradient calculation is not compatible with this hack. It
|
29 |
+
# will break any backward() or torch.autograd.grad() that passes VAE.
|
30 |
+
# (But you can still use the VAE to generate training data.)
|
31 |
+
#
|
32 |
+
# How it works:
|
33 |
+
# 1) The image is split into tiles.
|
34 |
+
# - To ensure perfect results, each tile is padded with 32 pixels
|
35 |
+
# on each side.
|
36 |
+
# - Then the conv2d/silu/upsample/downsample can produce identical
|
37 |
+
# results to the original image without splitting.
|
38 |
+
# 2) The original forward is decomposed into a task queue and a task worker.
|
39 |
+
# - The task queue is a list of functions that will be executed in order.
|
40 |
+
# - The task worker is a loop that executes the tasks in the queue.
|
41 |
+
# 3) The task queue is executed for each tile.
|
42 |
+
# - Current tile is sent to GPU.
|
43 |
+
# - local operations are directly executed.
|
44 |
+
# - Group norm calculation is temporarily suspended until the mean
|
45 |
+
# and var of all tiles are calculated.
|
46 |
+
# - The residual is pre-calculated and stored and addded back later.
|
47 |
+
# - When need to go to the next tile, the current tile is send to cpu.
|
48 |
+
# 4) After all tiles are processed, tiles are merged on cpu and return.
|
49 |
+
#
|
50 |
+
# Enjoy!
|
51 |
+
#
|
52 |
+
# @author: LI YI @ Nanyang Technological University - Singapore
|
53 |
+
# @date: 2023-03-02
|
54 |
+
# @license: MIT License
|
55 |
+
#
|
56 |
+
# Please give me a star if you like this project!
|
57 |
+
#
|
58 |
+
# -------------------------------------------------------------------------
|
59 |
+
|
60 |
+
import gc
|
61 |
+
from time import time
|
62 |
+
import math
|
63 |
+
from tqdm import tqdm
|
64 |
+
|
65 |
+
import torch
|
66 |
+
import torch.version
|
67 |
+
import torch.nn.functional as F
|
68 |
+
from einops import rearrange
|
69 |
+
from diffusers.utils.import_utils import is_xformers_available
|
70 |
+
|
71 |
+
import SUPIR.utils.devices as devices
|
72 |
+
|
73 |
+
try:
|
74 |
+
import xformers
|
75 |
+
import xformers.ops
|
76 |
+
except ImportError:
|
77 |
+
pass
|
78 |
+
|
79 |
+
sd_flag = True
|
80 |
+
|
81 |
+
def get_recommend_encoder_tile_size():
|
82 |
+
if torch.cuda.is_available():
|
83 |
+
total_memory = torch.cuda.get_device_properties(
|
84 |
+
devices.device).total_memory // 2**20
|
85 |
+
if total_memory > 16*1000:
|
86 |
+
ENCODER_TILE_SIZE = 3072
|
87 |
+
elif total_memory > 12*1000:
|
88 |
+
ENCODER_TILE_SIZE = 2048
|
89 |
+
elif total_memory > 8*1000:
|
90 |
+
ENCODER_TILE_SIZE = 1536
|
91 |
+
else:
|
92 |
+
ENCODER_TILE_SIZE = 960
|
93 |
+
else:
|
94 |
+
ENCODER_TILE_SIZE = 512
|
95 |
+
return ENCODER_TILE_SIZE
|
96 |
+
|
97 |
+
|
98 |
+
def get_recommend_decoder_tile_size():
|
99 |
+
if torch.cuda.is_available():
|
100 |
+
total_memory = torch.cuda.get_device_properties(
|
101 |
+
devices.device).total_memory // 2**20
|
102 |
+
if total_memory > 30*1000:
|
103 |
+
DECODER_TILE_SIZE = 256
|
104 |
+
elif total_memory > 16*1000:
|
105 |
+
DECODER_TILE_SIZE = 192
|
106 |
+
elif total_memory > 12*1000:
|
107 |
+
DECODER_TILE_SIZE = 128
|
108 |
+
elif total_memory > 8*1000:
|
109 |
+
DECODER_TILE_SIZE = 96
|
110 |
+
else:
|
111 |
+
DECODER_TILE_SIZE = 64
|
112 |
+
else:
|
113 |
+
DECODER_TILE_SIZE = 64
|
114 |
+
return DECODER_TILE_SIZE
|
115 |
+
|
116 |
+
|
117 |
+
if 'global const':
|
118 |
+
DEFAULT_ENABLED = False
|
119 |
+
DEFAULT_MOVE_TO_GPU = False
|
120 |
+
DEFAULT_FAST_ENCODER = True
|
121 |
+
DEFAULT_FAST_DECODER = True
|
122 |
+
DEFAULT_COLOR_FIX = 0
|
123 |
+
DEFAULT_ENCODER_TILE_SIZE = get_recommend_encoder_tile_size()
|
124 |
+
DEFAULT_DECODER_TILE_SIZE = get_recommend_decoder_tile_size()
|
125 |
+
|
126 |
+
|
127 |
+
# inplace version of silu
|
128 |
+
def inplace_nonlinearity(x):
|
129 |
+
# Test: fix for Nans
|
130 |
+
return F.silu(x, inplace=True)
|
131 |
+
|
132 |
+
# extracted from ldm.modules.diffusionmodules.model
|
133 |
+
|
134 |
+
# from diffusers lib
|
135 |
+
def attn_forward_new(self, h_):
|
136 |
+
batch_size, channel, height, width = h_.shape
|
137 |
+
hidden_states = h_.view(batch_size, channel, height * width).transpose(1, 2)
|
138 |
+
|
139 |
+
attention_mask = None
|
140 |
+
encoder_hidden_states = None
|
141 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
142 |
+
attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
143 |
+
|
144 |
+
query = self.to_q(hidden_states)
|
145 |
+
|
146 |
+
if encoder_hidden_states is None:
|
147 |
+
encoder_hidden_states = hidden_states
|
148 |
+
elif self.norm_cross:
|
149 |
+
encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)
|
150 |
+
|
151 |
+
key = self.to_k(encoder_hidden_states)
|
152 |
+
value = self.to_v(encoder_hidden_states)
|
153 |
+
|
154 |
+
query = self.head_to_batch_dim(query)
|
155 |
+
key = self.head_to_batch_dim(key)
|
156 |
+
value = self.head_to_batch_dim(value)
|
157 |
+
|
158 |
+
attention_probs = self.get_attention_scores(query, key, attention_mask)
|
159 |
+
hidden_states = torch.bmm(attention_probs, value)
|
160 |
+
hidden_states = self.batch_to_head_dim(hidden_states)
|
161 |
+
|
162 |
+
# linear proj
|
163 |
+
hidden_states = self.to_out[0](hidden_states)
|
164 |
+
# dropout
|
165 |
+
hidden_states = self.to_out[1](hidden_states)
|
166 |
+
|
167 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
168 |
+
|
169 |
+
return hidden_states
|
170 |
+
|
171 |
+
def attn_forward_new_pt2_0(self, hidden_states,):
|
172 |
+
scale = 1
|
173 |
+
attention_mask = None
|
174 |
+
encoder_hidden_states = None
|
175 |
+
|
176 |
+
input_ndim = hidden_states.ndim
|
177 |
+
|
178 |
+
if input_ndim == 4:
|
179 |
+
batch_size, channel, height, width = hidden_states.shape
|
180 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
181 |
+
|
182 |
+
batch_size, sequence_length, _ = (
|
183 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
184 |
+
)
|
185 |
+
|
186 |
+
if attention_mask is not None:
|
187 |
+
attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
188 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
189 |
+
# (batch, heads, source_length, target_length)
|
190 |
+
attention_mask = attention_mask.view(batch_size, self.heads, -1, attention_mask.shape[-1])
|
191 |
+
|
192 |
+
if self.group_norm is not None:
|
193 |
+
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
194 |
+
|
195 |
+
query = self.to_q(hidden_states, scale=scale)
|
196 |
+
|
197 |
+
if encoder_hidden_states is None:
|
198 |
+
encoder_hidden_states = hidden_states
|
199 |
+
elif self.norm_cross:
|
200 |
+
encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)
|
201 |
+
|
202 |
+
key = self.to_k(encoder_hidden_states, scale=scale)
|
203 |
+
value = self.to_v(encoder_hidden_states, scale=scale)
|
204 |
+
|
205 |
+
inner_dim = key.shape[-1]
|
206 |
+
head_dim = inner_dim // self.heads
|
207 |
+
|
208 |
+
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
209 |
+
|
210 |
+
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
211 |
+
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
212 |
+
|
213 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
214 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
215 |
+
hidden_states = F.scaled_dot_product_attention(
|
216 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
217 |
+
)
|
218 |
+
|
219 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
|
220 |
+
hidden_states = hidden_states.to(query.dtype)
|
221 |
+
|
222 |
+
# linear proj
|
223 |
+
hidden_states = self.to_out[0](hidden_states, scale=scale)
|
224 |
+
# dropout
|
225 |
+
hidden_states = self.to_out[1](hidden_states)
|
226 |
+
|
227 |
+
if input_ndim == 4:
|
228 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
229 |
+
|
230 |
+
return hidden_states
|
231 |
+
|
232 |
+
def attn_forward_new_xformers(self, hidden_states):
|
233 |
+
scale = 1
|
234 |
+
attention_op = None
|
235 |
+
attention_mask = None
|
236 |
+
encoder_hidden_states = None
|
237 |
+
|
238 |
+
input_ndim = hidden_states.ndim
|
239 |
+
|
240 |
+
if input_ndim == 4:
|
241 |
+
batch_size, channel, height, width = hidden_states.shape
|
242 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
243 |
+
|
244 |
+
batch_size, key_tokens, _ = (
|
245 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
246 |
+
)
|
247 |
+
|
248 |
+
attention_mask = self.prepare_attention_mask(attention_mask, key_tokens, batch_size)
|
249 |
+
if attention_mask is not None:
|
250 |
+
# expand our mask's singleton query_tokens dimension:
|
251 |
+
# [batch*heads, 1, key_tokens] ->
|
252 |
+
# [batch*heads, query_tokens, key_tokens]
|
253 |
+
# so that it can be added as a bias onto the attention scores that xformers computes:
|
254 |
+
# [batch*heads, query_tokens, key_tokens]
|
255 |
+
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
|
256 |
+
_, query_tokens, _ = hidden_states.shape
|
257 |
+
attention_mask = attention_mask.expand(-1, query_tokens, -1)
|
258 |
+
|
259 |
+
if self.group_norm is not None:
|
260 |
+
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
261 |
+
|
262 |
+
query = self.to_q(hidden_states, scale=scale)
|
263 |
+
|
264 |
+
if encoder_hidden_states is None:
|
265 |
+
encoder_hidden_states = hidden_states
|
266 |
+
elif self.norm_cross:
|
267 |
+
encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)
|
268 |
+
|
269 |
+
key = self.to_k(encoder_hidden_states, scale=scale)
|
270 |
+
value = self.to_v(encoder_hidden_states, scale=scale)
|
271 |
+
|
272 |
+
query = self.head_to_batch_dim(query).contiguous()
|
273 |
+
key = self.head_to_batch_dim(key).contiguous()
|
274 |
+
value = self.head_to_batch_dim(value).contiguous()
|
275 |
+
|
276 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
277 |
+
query, key, value, attn_bias=attention_mask, op=attention_op#, scale=scale
|
278 |
+
)
|
279 |
+
hidden_states = hidden_states.to(query.dtype)
|
280 |
+
hidden_states = self.batch_to_head_dim(hidden_states)
|
281 |
+
|
282 |
+
# linear proj
|
283 |
+
hidden_states = self.to_out[0](hidden_states, scale=scale)
|
284 |
+
# dropout
|
285 |
+
hidden_states = self.to_out[1](hidden_states)
|
286 |
+
|
287 |
+
if input_ndim == 4:
|
288 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
289 |
+
|
290 |
+
return hidden_states
|
291 |
+
|
292 |
+
def attn_forward(self, h_):
|
293 |
+
q = self.q(h_)
|
294 |
+
k = self.k(h_)
|
295 |
+
v = self.v(h_)
|
296 |
+
|
297 |
+
# compute attention
|
298 |
+
b, c, h, w = q.shape
|
299 |
+
q = q.reshape(b, c, h*w)
|
300 |
+
q = q.permute(0, 2, 1) # b,hw,c
|
301 |
+
k = k.reshape(b, c, h*w) # b,c,hw
|
302 |
+
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
303 |
+
w_ = w_ * (int(c)**(-0.5))
|
304 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
305 |
+
|
306 |
+
# attend to values
|
307 |
+
v = v.reshape(b, c, h*w)
|
308 |
+
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
309 |
+
# b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
310 |
+
h_ = torch.bmm(v, w_)
|
311 |
+
h_ = h_.reshape(b, c, h, w)
|
312 |
+
|
313 |
+
h_ = self.proj_out(h_)
|
314 |
+
|
315 |
+
return h_
|
316 |
+
|
317 |
+
|
318 |
+
def xformer_attn_forward(self, h_):
|
319 |
+
q = self.q(h_)
|
320 |
+
k = self.k(h_)
|
321 |
+
v = self.v(h_)
|
322 |
+
|
323 |
+
# compute attention
|
324 |
+
B, C, H, W = q.shape
|
325 |
+
q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
|
326 |
+
|
327 |
+
q, k, v = map(
|
328 |
+
lambda t: t.unsqueeze(3)
|
329 |
+
.reshape(B, t.shape[1], 1, C)
|
330 |
+
.permute(0, 2, 1, 3)
|
331 |
+
.reshape(B * 1, t.shape[1], C)
|
332 |
+
.contiguous(),
|
333 |
+
(q, k, v),
|
334 |
+
)
|
335 |
+
out = xformers.ops.memory_efficient_attention(
|
336 |
+
q, k, v, attn_bias=None, op=self.attention_op)
|
337 |
+
|
338 |
+
out = (
|
339 |
+
out.unsqueeze(0)
|
340 |
+
.reshape(B, 1, out.shape[1], C)
|
341 |
+
.permute(0, 2, 1, 3)
|
342 |
+
.reshape(B, out.shape[1], C)
|
343 |
+
)
|
344 |
+
out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
|
345 |
+
out = self.proj_out(out)
|
346 |
+
return out
|
347 |
+
|
348 |
+
|
349 |
+
def attn2task(task_queue, net):
|
350 |
+
if False: #isinstance(net, AttnBlock):
|
351 |
+
task_queue.append(('store_res', lambda x: x))
|
352 |
+
task_queue.append(('pre_norm', net.norm))
|
353 |
+
task_queue.append(('attn', lambda x, net=net: attn_forward(net, x)))
|
354 |
+
task_queue.append(['add_res', None])
|
355 |
+
elif False: #isinstance(net, MemoryEfficientAttnBlock):
|
356 |
+
task_queue.append(('store_res', lambda x: x))
|
357 |
+
task_queue.append(('pre_norm', net.norm))
|
358 |
+
task_queue.append(
|
359 |
+
('attn', lambda x, net=net: xformer_attn_forward(net, x)))
|
360 |
+
task_queue.append(['add_res', None])
|
361 |
+
else:
|
362 |
+
task_queue.append(('store_res', lambda x: x))
|
363 |
+
task_queue.append(('pre_norm', net.norm))
|
364 |
+
if is_xformers_available:
|
365 |
+
# task_queue.append(('attn', lambda x, net=net: attn_forward_new_xformers(net, x)))
|
366 |
+
task_queue.append(
|
367 |
+
('attn', lambda x, net=net: xformer_attn_forward(net, x)))
|
368 |
+
elif hasattr(F, "scaled_dot_product_attention"):
|
369 |
+
task_queue.append(('attn', lambda x, net=net: attn_forward_new_pt2_0(net, x)))
|
370 |
+
else:
|
371 |
+
task_queue.append(('attn', lambda x, net=net: attn_forward_new(net, x)))
|
372 |
+
task_queue.append(['add_res', None])
|
373 |
+
|
374 |
+
def resblock2task(queue, block):
|
375 |
+
"""
|
376 |
+
Turn a ResNetBlock into a sequence of tasks and append to the task queue
|
377 |
+
|
378 |
+
@param queue: the target task queue
|
379 |
+
@param block: ResNetBlock
|
380 |
+
|
381 |
+
"""
|
382 |
+
if block.in_channels != block.out_channels:
|
383 |
+
if sd_flag:
|
384 |
+
if block.use_conv_shortcut:
|
385 |
+
queue.append(('store_res', block.conv_shortcut))
|
386 |
+
else:
|
387 |
+
queue.append(('store_res', block.nin_shortcut))
|
388 |
+
else:
|
389 |
+
if block.use_in_shortcut:
|
390 |
+
queue.append(('store_res', block.conv_shortcut))
|
391 |
+
else:
|
392 |
+
queue.append(('store_res', block.nin_shortcut))
|
393 |
+
|
394 |
+
else:
|
395 |
+
queue.append(('store_res', lambda x: x))
|
396 |
+
queue.append(('pre_norm', block.norm1))
|
397 |
+
queue.append(('silu', inplace_nonlinearity))
|
398 |
+
queue.append(('conv1', block.conv1))
|
399 |
+
queue.append(('pre_norm', block.norm2))
|
400 |
+
queue.append(('silu', inplace_nonlinearity))
|
401 |
+
queue.append(('conv2', block.conv2))
|
402 |
+
queue.append(['add_res', None])
|
403 |
+
|
404 |
+
|
405 |
+
def build_sampling(task_queue, net, is_decoder):
|
406 |
+
"""
|
407 |
+
Build the sampling part of a task queue
|
408 |
+
@param task_queue: the target task queue
|
409 |
+
@param net: the network
|
410 |
+
@param is_decoder: currently building decoder or encoder
|
411 |
+
"""
|
412 |
+
if is_decoder:
|
413 |
+
if sd_flag:
|
414 |
+
resblock2task(task_queue, net.mid.block_1)
|
415 |
+
attn2task(task_queue, net.mid.attn_1)
|
416 |
+
print(task_queue)
|
417 |
+
resblock2task(task_queue, net.mid.block_2)
|
418 |
+
resolution_iter = reversed(range(net.num_resolutions))
|
419 |
+
block_ids = net.num_res_blocks + 1
|
420 |
+
condition = 0
|
421 |
+
module = net.up
|
422 |
+
func_name = 'upsample'
|
423 |
+
else:
|
424 |
+
resblock2task(task_queue, net.mid_block.resnets[0])
|
425 |
+
attn2task(task_queue, net.mid_block.attentions[0])
|
426 |
+
resblock2task(task_queue, net.mid_block.resnets[1])
|
427 |
+
resolution_iter = (range(len(net.up_blocks))) # net.num_resolutions = 3
|
428 |
+
block_ids = 2 + 1
|
429 |
+
condition = len(net.up_blocks) - 1
|
430 |
+
module = net.up_blocks
|
431 |
+
func_name = 'upsamplers'
|
432 |
+
else:
|
433 |
+
if sd_flag:
|
434 |
+
resolution_iter = range(net.num_resolutions)
|
435 |
+
block_ids = net.num_res_blocks
|
436 |
+
condition = net.num_resolutions - 1
|
437 |
+
module = net.down
|
438 |
+
func_name = 'downsample'
|
439 |
+
else:
|
440 |
+
resolution_iter = range(len(net.down_blocks))
|
441 |
+
block_ids = 2
|
442 |
+
condition = len(net.down_blocks) - 1
|
443 |
+
module = net.down_blocks
|
444 |
+
func_name = 'downsamplers'
|
445 |
+
|
446 |
+
for i_level in resolution_iter:
|
447 |
+
for i_block in range(block_ids):
|
448 |
+
if sd_flag:
|
449 |
+
resblock2task(task_queue, module[i_level].block[i_block])
|
450 |
+
else:
|
451 |
+
resblock2task(task_queue, module[i_level].resnets[i_block])
|
452 |
+
if i_level != condition:
|
453 |
+
if sd_flag:
|
454 |
+
task_queue.append((func_name, getattr(module[i_level], func_name)))
|
455 |
+
else:
|
456 |
+
if is_decoder:
|
457 |
+
task_queue.append((func_name, module[i_level].upsamplers[0]))
|
458 |
+
else:
|
459 |
+
task_queue.append((func_name, module[i_level].downsamplers[0]))
|
460 |
+
|
461 |
+
if not is_decoder:
|
462 |
+
if sd_flag:
|
463 |
+
resblock2task(task_queue, net.mid.block_1)
|
464 |
+
attn2task(task_queue, net.mid.attn_1)
|
465 |
+
resblock2task(task_queue, net.mid.block_2)
|
466 |
+
else:
|
467 |
+
resblock2task(task_queue, net.mid_block.resnets[0])
|
468 |
+
attn2task(task_queue, net.mid_block.attentions[0])
|
469 |
+
resblock2task(task_queue, net.mid_block.resnets[1])
|
470 |
+
|
471 |
+
|
472 |
+
def build_task_queue(net, is_decoder):
|
473 |
+
"""
|
474 |
+
Build a single task queue for the encoder or decoder
|
475 |
+
@param net: the VAE decoder or encoder network
|
476 |
+
@param is_decoder: currently building decoder or encoder
|
477 |
+
@return: the task queue
|
478 |
+
"""
|
479 |
+
task_queue = []
|
480 |
+
task_queue.append(('conv_in', net.conv_in))
|
481 |
+
|
482 |
+
# construct the sampling part of the task queue
|
483 |
+
# because encoder and decoder share the same architecture, we extract the sampling part
|
484 |
+
build_sampling(task_queue, net, is_decoder)
|
485 |
+
if is_decoder and not sd_flag:
|
486 |
+
net.give_pre_end = False
|
487 |
+
net.tanh_out = False
|
488 |
+
|
489 |
+
if not is_decoder or not net.give_pre_end:
|
490 |
+
if sd_flag:
|
491 |
+
task_queue.append(('pre_norm', net.norm_out))
|
492 |
+
else:
|
493 |
+
task_queue.append(('pre_norm', net.conv_norm_out))
|
494 |
+
task_queue.append(('silu', inplace_nonlinearity))
|
495 |
+
task_queue.append(('conv_out', net.conv_out))
|
496 |
+
if is_decoder and net.tanh_out:
|
497 |
+
task_queue.append(('tanh', torch.tanh))
|
498 |
+
|
499 |
+
return task_queue
|
500 |
+
|
501 |
+
|
502 |
+
def clone_task_queue(task_queue):
|
503 |
+
"""
|
504 |
+
Clone a task queue
|
505 |
+
@param task_queue: the task queue to be cloned
|
506 |
+
@return: the cloned task queue
|
507 |
+
"""
|
508 |
+
return [[item for item in task] for task in task_queue]
|
509 |
+
|
510 |
+
|
511 |
+
def get_var_mean(input, num_groups, eps=1e-6):
|
512 |
+
"""
|
513 |
+
Get mean and var for group norm
|
514 |
+
"""
|
515 |
+
b, c = input.size(0), input.size(1)
|
516 |
+
channel_in_group = int(c/num_groups)
|
517 |
+
input_reshaped = input.contiguous().view(
|
518 |
+
1, int(b * num_groups), channel_in_group, *input.size()[2:])
|
519 |
+
var, mean = torch.var_mean(
|
520 |
+
input_reshaped, dim=[0, 2, 3, 4], unbiased=False)
|
521 |
+
return var, mean
|
522 |
+
|
523 |
+
|
524 |
+
def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6):
|
525 |
+
"""
|
526 |
+
Custom group norm with fixed mean and var
|
527 |
+
|
528 |
+
@param input: input tensor
|
529 |
+
@param num_groups: number of groups. by default, num_groups = 32
|
530 |
+
@param mean: mean, must be pre-calculated by get_var_mean
|
531 |
+
@param var: var, must be pre-calculated by get_var_mean
|
532 |
+
@param weight: weight, should be fetched from the original group norm
|
533 |
+
@param bias: bias, should be fetched from the original group norm
|
534 |
+
@param eps: epsilon, by default, eps = 1e-6 to match the original group norm
|
535 |
+
|
536 |
+
@return: normalized tensor
|
537 |
+
"""
|
538 |
+
b, c = input.size(0), input.size(1)
|
539 |
+
channel_in_group = int(c/num_groups)
|
540 |
+
input_reshaped = input.contiguous().view(
|
541 |
+
1, int(b * num_groups), channel_in_group, *input.size()[2:])
|
542 |
+
|
543 |
+
out = F.batch_norm(input_reshaped, mean, var, weight=None, bias=None,
|
544 |
+
training=False, momentum=0, eps=eps)
|
545 |
+
|
546 |
+
out = out.view(b, c, *input.size()[2:])
|
547 |
+
|
548 |
+
# post affine transform
|
549 |
+
if weight is not None:
|
550 |
+
out *= weight.view(1, -1, 1, 1)
|
551 |
+
if bias is not None:
|
552 |
+
out += bias.view(1, -1, 1, 1)
|
553 |
+
return out
|
554 |
+
|
555 |
+
|
556 |
+
def crop_valid_region(x, input_bbox, target_bbox, is_decoder):
|
557 |
+
"""
|
558 |
+
Crop the valid region from the tile
|
559 |
+
@param x: input tile
|
560 |
+
@param input_bbox: original input bounding box
|
561 |
+
@param target_bbox: output bounding box
|
562 |
+
@param scale: scale factor
|
563 |
+
@return: cropped tile
|
564 |
+
"""
|
565 |
+
padded_bbox = [i * 8 if is_decoder else i//8 for i in input_bbox]
|
566 |
+
margin = [target_bbox[i] - padded_bbox[i] for i in range(4)]
|
567 |
+
return x[:, :, margin[2]:x.size(2)+margin[3], margin[0]:x.size(3)+margin[1]]
|
568 |
+
|
569 |
+
# ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓
|
570 |
+
|
571 |
+
|
572 |
+
def perfcount(fn):
|
573 |
+
def wrapper(*args, **kwargs):
|
574 |
+
ts = time()
|
575 |
+
|
576 |
+
if torch.cuda.is_available():
|
577 |
+
torch.cuda.reset_peak_memory_stats(devices.device)
|
578 |
+
devices.torch_gc()
|
579 |
+
gc.collect()
|
580 |
+
|
581 |
+
ret = fn(*args, **kwargs)
|
582 |
+
|
583 |
+
devices.torch_gc()
|
584 |
+
gc.collect()
|
585 |
+
if torch.cuda.is_available():
|
586 |
+
vram = torch.cuda.max_memory_allocated(devices.device) / 2**20
|
587 |
+
torch.cuda.reset_peak_memory_stats(devices.device)
|
588 |
+
print(
|
589 |
+
f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB')
|
590 |
+
else:
|
591 |
+
print(f'[Tiled VAE]: Done in {time() - ts:.3f}s')
|
592 |
+
|
593 |
+
return ret
|
594 |
+
return wrapper
|
595 |
+
|
596 |
+
# copy end :)
|
597 |
+
|
598 |
+
|
599 |
+
class GroupNormParam:
|
600 |
+
def __init__(self):
|
601 |
+
self.var_list = []
|
602 |
+
self.mean_list = []
|
603 |
+
self.pixel_list = []
|
604 |
+
self.weight = None
|
605 |
+
self.bias = None
|
606 |
+
|
607 |
+
def add_tile(self, tile, layer):
|
608 |
+
var, mean = get_var_mean(tile, 32)
|
609 |
+
# For giant images, the variance can be larger than max float16
|
610 |
+
# In this case we create a copy to float32
|
611 |
+
if var.dtype == torch.float16 and var.isinf().any():
|
612 |
+
fp32_tile = tile.float()
|
613 |
+
var, mean = get_var_mean(fp32_tile, 32)
|
614 |
+
# ============= DEBUG: test for infinite =============
|
615 |
+
# if torch.isinf(var).any():
|
616 |
+
# print('var: ', var)
|
617 |
+
# ====================================================
|
618 |
+
self.var_list.append(var)
|
619 |
+
self.mean_list.append(mean)
|
620 |
+
self.pixel_list.append(
|
621 |
+
tile.shape[2]*tile.shape[3])
|
622 |
+
if hasattr(layer, 'weight'):
|
623 |
+
self.weight = layer.weight
|
624 |
+
self.bias = layer.bias
|
625 |
+
else:
|
626 |
+
self.weight = None
|
627 |
+
self.bias = None
|
628 |
+
|
629 |
+
def summary(self):
|
630 |
+
"""
|
631 |
+
summarize the mean and var and return a function
|
632 |
+
that apply group norm on each tile
|
633 |
+
"""
|
634 |
+
if len(self.var_list) == 0:
|
635 |
+
return None
|
636 |
+
var = torch.vstack(self.var_list)
|
637 |
+
mean = torch.vstack(self.mean_list)
|
638 |
+
max_value = max(self.pixel_list)
|
639 |
+
pixels = torch.tensor(
|
640 |
+
self.pixel_list, dtype=torch.float32, device=devices.device) / max_value
|
641 |
+
sum_pixels = torch.sum(pixels)
|
642 |
+
pixels = pixels.unsqueeze(
|
643 |
+
1) / sum_pixels
|
644 |
+
var = torch.sum(
|
645 |
+
var * pixels, dim=0)
|
646 |
+
mean = torch.sum(
|
647 |
+
mean * pixels, dim=0)
|
648 |
+
return lambda x: custom_group_norm(x, 32, mean, var, self.weight, self.bias)
|
649 |
+
|
650 |
+
@staticmethod
|
651 |
+
def from_tile(tile, norm):
|
652 |
+
"""
|
653 |
+
create a function from a single tile without summary
|
654 |
+
"""
|
655 |
+
var, mean = get_var_mean(tile, 32)
|
656 |
+
if var.dtype == torch.float16 and var.isinf().any():
|
657 |
+
fp32_tile = tile.float()
|
658 |
+
var, mean = get_var_mean(fp32_tile, 32)
|
659 |
+
# if it is a macbook, we need to convert back to float16
|
660 |
+
if var.device.type == 'mps':
|
661 |
+
# clamp to avoid overflow
|
662 |
+
var = torch.clamp(var, 0, 60000)
|
663 |
+
var = var.half()
|
664 |
+
mean = mean.half()
|
665 |
+
if hasattr(norm, 'weight'):
|
666 |
+
weight = norm.weight
|
667 |
+
bias = norm.bias
|
668 |
+
else:
|
669 |
+
weight = None
|
670 |
+
bias = None
|
671 |
+
|
672 |
+
def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias):
|
673 |
+
return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6)
|
674 |
+
return group_norm_func
|
675 |
+
|
676 |
+
|
677 |
+
class VAEHook:
|
678 |
+
def __init__(self, net, tile_size, is_decoder, fast_decoder, fast_encoder, color_fix, to_gpu=False):
|
679 |
+
self.net = net # encoder | decoder
|
680 |
+
self.tile_size = tile_size
|
681 |
+
self.is_decoder = is_decoder
|
682 |
+
self.fast_mode = (fast_encoder and not is_decoder) or (
|
683 |
+
fast_decoder and is_decoder)
|
684 |
+
self.color_fix = color_fix and not is_decoder
|
685 |
+
self.to_gpu = to_gpu
|
686 |
+
self.pad = 11 if is_decoder else 32
|
687 |
+
|
688 |
+
def __call__(self, x):
|
689 |
+
B, C, H, W = x.shape
|
690 |
+
original_device = next(self.net.parameters()).device
|
691 |
+
try:
|
692 |
+
if self.to_gpu:
|
693 |
+
self.net.to(devices.get_optimal_device())
|
694 |
+
if max(H, W) <= self.pad * 2 + self.tile_size:
|
695 |
+
print("[Tiled VAE]: the input size is tiny and unnecessary to tile.")
|
696 |
+
return self.net.original_forward(x)
|
697 |
+
else:
|
698 |
+
return self.vae_tile_forward(x)
|
699 |
+
finally:
|
700 |
+
self.net.to(original_device)
|
701 |
+
|
702 |
+
def get_best_tile_size(self, lowerbound, upperbound):
|
703 |
+
"""
|
704 |
+
Get the best tile size for GPU memory
|
705 |
+
"""
|
706 |
+
divider = 32
|
707 |
+
while divider >= 2:
|
708 |
+
remainer = lowerbound % divider
|
709 |
+
if remainer == 0:
|
710 |
+
return lowerbound
|
711 |
+
candidate = lowerbound - remainer + divider
|
712 |
+
if candidate <= upperbound:
|
713 |
+
return candidate
|
714 |
+
divider //= 2
|
715 |
+
return lowerbound
|
716 |
+
|
717 |
+
def split_tiles(self, h, w):
|
718 |
+
"""
|
719 |
+
Tool function to split the image into tiles
|
720 |
+
@param h: height of the image
|
721 |
+
@param w: width of the image
|
722 |
+
@return: tile_input_bboxes, tile_output_bboxes
|
723 |
+
"""
|
724 |
+
tile_input_bboxes, tile_output_bboxes = [], []
|
725 |
+
tile_size = self.tile_size
|
726 |
+
pad = self.pad
|
727 |
+
num_height_tiles = math.ceil((h - 2 * pad) / tile_size)
|
728 |
+
num_width_tiles = math.ceil((w - 2 * pad) / tile_size)
|
729 |
+
# If any of the numbers are 0, we let it be 1
|
730 |
+
# This is to deal with long and thin images
|
731 |
+
num_height_tiles = max(num_height_tiles, 1)
|
732 |
+
num_width_tiles = max(num_width_tiles, 1)
|
733 |
+
|
734 |
+
# Suggestions from https://github.com/Kahsolt: auto shrink the tile size
|
735 |
+
real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles)
|
736 |
+
real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles)
|
737 |
+
real_tile_height = self.get_best_tile_size(real_tile_height, tile_size)
|
738 |
+
real_tile_width = self.get_best_tile_size(real_tile_width, tile_size)
|
739 |
+
|
740 |
+
print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' +
|
741 |
+
f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}')
|
742 |
+
|
743 |
+
for i in range(num_height_tiles):
|
744 |
+
for j in range(num_width_tiles):
|
745 |
+
# bbox: [x1, x2, y1, y2]
|
746 |
+
# the padding is is unnessary for image borders. So we directly start from (32, 32)
|
747 |
+
input_bbox = [
|
748 |
+
pad + j * real_tile_width,
|
749 |
+
min(pad + (j + 1) * real_tile_width, w),
|
750 |
+
pad + i * real_tile_height,
|
751 |
+
min(pad + (i + 1) * real_tile_height, h),
|
752 |
+
]
|
753 |
+
|
754 |
+
# if the output bbox is close to the image boundary, we extend it to the image boundary
|
755 |
+
output_bbox = [
|
756 |
+
input_bbox[0] if input_bbox[0] > pad else 0,
|
757 |
+
input_bbox[1] if input_bbox[1] < w - pad else w,
|
758 |
+
input_bbox[2] if input_bbox[2] > pad else 0,
|
759 |
+
input_bbox[3] if input_bbox[3] < h - pad else h,
|
760 |
+
]
|
761 |
+
|
762 |
+
# scale to get the final output bbox
|
763 |
+
output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox]
|
764 |
+
tile_output_bboxes.append(output_bbox)
|
765 |
+
|
766 |
+
# indistinguishable expand the input bbox by pad pixels
|
767 |
+
tile_input_bboxes.append([
|
768 |
+
max(0, input_bbox[0] - pad),
|
769 |
+
min(w, input_bbox[1] + pad),
|
770 |
+
max(0, input_bbox[2] - pad),
|
771 |
+
min(h, input_bbox[3] + pad),
|
772 |
+
])
|
773 |
+
|
774 |
+
return tile_input_bboxes, tile_output_bboxes
|
775 |
+
|
776 |
+
@torch.no_grad()
|
777 |
+
def estimate_group_norm(self, z, task_queue, color_fix):
|
778 |
+
device = z.device
|
779 |
+
tile = z
|
780 |
+
last_id = len(task_queue) - 1
|
781 |
+
while last_id >= 0 and task_queue[last_id][0] != 'pre_norm':
|
782 |
+
last_id -= 1
|
783 |
+
if last_id <= 0 or task_queue[last_id][0] != 'pre_norm':
|
784 |
+
raise ValueError('No group norm found in the task queue')
|
785 |
+
# estimate until the last group norm
|
786 |
+
for i in range(last_id + 1):
|
787 |
+
task = task_queue[i]
|
788 |
+
if task[0] == 'pre_norm':
|
789 |
+
group_norm_func = GroupNormParam.from_tile(tile, task[1])
|
790 |
+
task_queue[i] = ('apply_norm', group_norm_func)
|
791 |
+
if i == last_id:
|
792 |
+
return True
|
793 |
+
tile = group_norm_func(tile)
|
794 |
+
elif task[0] == 'store_res':
|
795 |
+
task_id = i + 1
|
796 |
+
while task_id < last_id and task_queue[task_id][0] != 'add_res':
|
797 |
+
task_id += 1
|
798 |
+
if task_id >= last_id:
|
799 |
+
continue
|
800 |
+
task_queue[task_id][1] = task[1](tile)
|
801 |
+
elif task[0] == 'add_res':
|
802 |
+
tile += task[1].to(device)
|
803 |
+
task[1] = None
|
804 |
+
elif color_fix and task[0] == 'downsample':
|
805 |
+
for j in range(i, last_id + 1):
|
806 |
+
if task_queue[j][0] == 'store_res':
|
807 |
+
task_queue[j] = ('store_res_cpu', task_queue[j][1])
|
808 |
+
return True
|
809 |
+
else:
|
810 |
+
tile = task[1](tile)
|
811 |
+
try:
|
812 |
+
devices.test_for_nans(tile, "vae")
|
813 |
+
except:
|
814 |
+
print(f'Nan detected in fast mode estimation. Fast mode disabled.')
|
815 |
+
return False
|
816 |
+
|
817 |
+
raise IndexError('Should not reach here')
|
818 |
+
|
819 |
+
@perfcount
|
820 |
+
@torch.no_grad()
|
821 |
+
def vae_tile_forward(self, z):
|
822 |
+
"""
|
823 |
+
Decode a latent vector z into an image in a tiled manner.
|
824 |
+
@param z: latent vector
|
825 |
+
@return: image
|
826 |
+
"""
|
827 |
+
device = next(self.net.parameters()).device
|
828 |
+
dtype = z.dtype
|
829 |
+
net = self.net
|
830 |
+
tile_size = self.tile_size
|
831 |
+
is_decoder = self.is_decoder
|
832 |
+
|
833 |
+
z = z.detach() # detach the input to avoid backprop
|
834 |
+
|
835 |
+
N, height, width = z.shape[0], z.shape[2], z.shape[3]
|
836 |
+
net.last_z_shape = z.shape
|
837 |
+
|
838 |
+
# Split the input into tiles and build a task queue for each tile
|
839 |
+
print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}')
|
840 |
+
|
841 |
+
in_bboxes, out_bboxes = self.split_tiles(height, width)
|
842 |
+
|
843 |
+
# Prepare tiles by split the input latents
|
844 |
+
tiles = []
|
845 |
+
for input_bbox in in_bboxes:
|
846 |
+
tile = z[:, :, input_bbox[2]:input_bbox[3], input_bbox[0]:input_bbox[1]].cpu()
|
847 |
+
tiles.append(tile)
|
848 |
+
|
849 |
+
num_tiles = len(tiles)
|
850 |
+
num_completed = 0
|
851 |
+
|
852 |
+
# Build task queues
|
853 |
+
single_task_queue = build_task_queue(net, is_decoder)
|
854 |
+
#print(single_task_queue)
|
855 |
+
if self.fast_mode:
|
856 |
+
# Fast mode: downsample the input image to the tile size,
|
857 |
+
# then estimate the group norm parameters on the downsampled image
|
858 |
+
scale_factor = tile_size / max(height, width)
|
859 |
+
z = z.to(device)
|
860 |
+
downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact')
|
861 |
+
# use nearest-exact to keep statictics as close as possible
|
862 |
+
print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image')
|
863 |
+
|
864 |
+
# ======= Special thanks to @Kahsolt for distribution shift issue ======= #
|
865 |
+
# The downsampling will heavily distort its mean and std, so we need to recover it.
|
866 |
+
std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True)
|
867 |
+
std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True)
|
868 |
+
downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old
|
869 |
+
del std_old, mean_old, std_new, mean_new
|
870 |
+
# occasionally the std_new is too small or too large, which exceeds the range of float16
|
871 |
+
# so we need to clamp it to max z's range.
|
872 |
+
downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max())
|
873 |
+
estimate_task_queue = clone_task_queue(single_task_queue)
|
874 |
+
if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix):
|
875 |
+
single_task_queue = estimate_task_queue
|
876 |
+
del downsampled_z
|
877 |
+
|
878 |
+
task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)]
|
879 |
+
|
880 |
+
# Dummy result
|
881 |
+
result = None
|
882 |
+
result_approx = None
|
883 |
+
#try:
|
884 |
+
# with devices.autocast():
|
885 |
+
# result_approx = torch.cat([F.interpolate(cheap_approximation(x).unsqueeze(0), scale_factor=opt_f, mode='nearest-exact') for x in z], dim=0).cpu()
|
886 |
+
#except: pass
|
887 |
+
# Free memory of input latent tensor
|
888 |
+
del z
|
889 |
+
|
890 |
+
# Task queue execution
|
891 |
+
pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: ")
|
892 |
+
|
893 |
+
# execute the task back and forth when switch tiles so that we always
|
894 |
+
# keep one tile on the GPU to reduce unnecessary data transfer
|
895 |
+
forward = True
|
896 |
+
interrupted = False
|
897 |
+
#state.interrupted = interrupted
|
898 |
+
while True:
|
899 |
+
#if state.interrupted: interrupted = True ; break
|
900 |
+
|
901 |
+
group_norm_param = GroupNormParam()
|
902 |
+
for i in range(num_tiles) if forward else reversed(range(num_tiles)):
|
903 |
+
#if state.interrupted: interrupted = True ; break
|
904 |
+
|
905 |
+
tile = tiles[i].to(device)
|
906 |
+
input_bbox = in_bboxes[i]
|
907 |
+
task_queue = task_queues[i]
|
908 |
+
|
909 |
+
interrupted = False
|
910 |
+
while len(task_queue) > 0:
|
911 |
+
#if state.interrupted: interrupted = True ; break
|
912 |
+
|
913 |
+
# DEBUG: current task
|
914 |
+
# print('Running task: ', task_queue[0][0], ' on tile ', i, '/', num_tiles, ' with shape ', tile.shape)
|
915 |
+
task = task_queue.pop(0)
|
916 |
+
if task[0] == 'pre_norm':
|
917 |
+
group_norm_param.add_tile(tile, task[1])
|
918 |
+
break
|
919 |
+
elif task[0] == 'store_res' or task[0] == 'store_res_cpu':
|
920 |
+
task_id = 0
|
921 |
+
res = task[1](tile)
|
922 |
+
if not self.fast_mode or task[0] == 'store_res_cpu':
|
923 |
+
res = res.cpu()
|
924 |
+
while task_queue[task_id][0] != 'add_res':
|
925 |
+
task_id += 1
|
926 |
+
task_queue[task_id][1] = res
|
927 |
+
elif task[0] == 'add_res':
|
928 |
+
tile += task[1].to(device)
|
929 |
+
task[1] = None
|
930 |
+
else:
|
931 |
+
tile = task[1](tile)
|
932 |
+
#print(tiles[i].shape, tile.shape, task)
|
933 |
+
pbar.update(1)
|
934 |
+
|
935 |
+
if interrupted: break
|
936 |
+
|
937 |
+
# check for NaNs in the tile.
|
938 |
+
# If there are NaNs, we abort the process to save user's time
|
939 |
+
#devices.test_for_nans(tile, "vae")
|
940 |
+
|
941 |
+
#print(tiles[i].shape, tile.shape, i, num_tiles)
|
942 |
+
if len(task_queue) == 0:
|
943 |
+
tiles[i] = None
|
944 |
+
num_completed += 1
|
945 |
+
if result is None: # NOTE: dim C varies from different cases, can only be inited dynamically
|
946 |
+
result = torch.zeros((N, tile.shape[1], height * 8 if is_decoder else height // 8, width * 8 if is_decoder else width // 8), device=device, requires_grad=False)
|
947 |
+
result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder)
|
948 |
+
del tile
|
949 |
+
elif i == num_tiles - 1 and forward:
|
950 |
+
forward = False
|
951 |
+
tiles[i] = tile
|
952 |
+
elif i == 0 and not forward:
|
953 |
+
forward = True
|
954 |
+
tiles[i] = tile
|
955 |
+
else:
|
956 |
+
tiles[i] = tile.cpu()
|
957 |
+
del tile
|
958 |
+
|
959 |
+
if interrupted: break
|
960 |
+
if num_completed == num_tiles: break
|
961 |
+
|
962 |
+
# insert the group norm task to the head of each task queue
|
963 |
+
group_norm_func = group_norm_param.summary()
|
964 |
+
if group_norm_func is not None:
|
965 |
+
for i in range(num_tiles):
|
966 |
+
task_queue = task_queues[i]
|
967 |
+
task_queue.insert(0, ('apply_norm', group_norm_func))
|
968 |
+
|
969 |
+
# Done!
|
970 |
+
pbar.close()
|
971 |
+
return result.to(dtype) if result is not None else result_approx.to(device)
|