Fabrice-TIERCELIN commited on
Commit
a6c349f
·
verified ·
1 Parent(s): 2df5266

Upload 5 files

Browse files
SUPIR/utils/colorfix.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ # --------------------------------------------------------------------------------
3
+ # Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
4
+ # --------------------------------------------------------------------------------
5
+ '''
6
+
7
+ import torch
8
+ from PIL import Image
9
+ from torch import Tensor
10
+ from torch.nn import functional as F
11
+
12
+ from torchvision.transforms import ToTensor, ToPILImage
13
+
14
def adain_color_fix(target: Image, source: Image):
    """Recolour ``target`` using the per-channel statistics of ``source``.

    Both PIL images are turned into tensors with a leading batch axis,
    ``target`` is re-normalised towards ``source`` through
    ``adaptive_instance_normalization``, and the result is clamped to
    [0, 1] and converted back to a PIL image.
    """
    pil_to_tensor = ToTensor()
    content = pil_to_tensor(target).unsqueeze(0)
    style = pil_to_tensor(source).unsqueeze(0)

    fixed = adaptive_instance_normalization(content, style)

    # Drop the batch axis and clamp in place before building the image.
    return ToPILImage()(fixed.squeeze(0).clamp_(0.0, 1.0))
28
+
29
def wavelet_color_fix(target: Image, source: Image):
    """Recolour ``target`` by combining its detail with the smooth part of ``source``.

    Both PIL images are turned into tensors with a leading batch axis and
    merged by ``wavelet_reconstruction``; the result is clamped to [0, 1]
    and converted back to a PIL image.
    """
    pil_to_tensor = ToTensor()
    content = pil_to_tensor(target).unsqueeze(0)
    style = pil_to_tensor(source).unsqueeze(0)

    fixed = wavelet_reconstruction(content, style)

    # Drop the batch axis and clamp in place before building the image.
    return ToPILImage()(fixed.squeeze(0).clamp_(0.0, 1.0))
43
+
44
def calc_mean_std(feat: Tensor, eps=1e-5):
    """Return the per-sample, per-channel mean and standard deviation of ``feat``.

    Args:
        feat (Tensor): 4D tensor; statistics are taken over everything
            after the first two axes.
        eps (float): Added to the variance before the square root so the
            std is never zero. Default: 1e-5.

    Returns:
        tuple: ``(mean, std)``, each reshaped to ``(b, c, 1, 1)``.
    """
    shape = feat.size()
    assert len(shape) == 4, 'The input feature should be 4D tensor.'
    batch, channels = shape[:2]
    # Flatten the trailing axes once and reuse the view for both statistics.
    flat = feat.reshape(batch, channels, -1)
    feat_std = (flat.var(dim=2) + eps).sqrt().reshape(batch, channels, 1, 1)
    feat_mean = flat.mean(dim=2).reshape(batch, channels, 1, 1)
    return feat_mean, feat_std
58
+
59
def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
    """Adaptive instance normalization.

    Whitens ``content_feat`` with its own per-channel mean/std and then
    rescales and shifts it with the per-channel std/mean of ``style_feat``.

    Args:
        content_feat (Tensor): The reference feature (4D).
        style_feat (Tensor): The degraded feature whose statistics are adopted (4D).
    """
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    # Remove the content statistics ...
    whitened = (content_feat - content_mean.expand_as(content_feat)) / content_std.expand_as(content_feat)
    # ... and apply the style statistics instead.
    return whitened * style_std.expand_as(content_feat) + style_mean.expand_as(content_feat)
72
+
73
def wavelet_blur(image: Tensor, radius: int):
    """Blur ``image`` with a dilated 3x3 binomial kernel, one channel at a time.

    The same kernel is applied to every channel independently (grouped
    convolution). The channel count is read from ``image`` instead of being
    fixed to 3, so grayscale or RGBA tensors work as well.

    Args:
        image (Tensor): Input of shape (N, C, H, W).
        radius (int): Dilation of the kernel, also used as the replicate
            padding on each side so the spatial size is preserved.

    Returns:
        Tensor: Blurred tensor with the same shape as ``image``.
    """
    channels = image.shape[1]
    # 3x3 binomial kernel; the weights sum to 1 so flat regions are unchanged.
    kernel_vals = [
        [0.0625, 0.125, 0.0625],
        [0.125, 0.25, 0.125],
        [0.0625, 0.125, 0.0625],
    ]
    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
    # Shape (C, 1, 3, 3): one copy of the kernel per input channel.
    kernel = kernel[None, None].repeat(channels, 1, 1, 1)
    # Replicate the border so the dilated convolution keeps H and W.
    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
    return F.conv2d(image, kernel, groups=channels, dilation=radius)
93
+
94
def wavelet_decomposition(image: Tensor, levels=5):
    """Split ``image`` into accumulated high-frequency detail and a low-frequency residual.

    At level ``i`` the current image is blurred with radius ``2 ** i``; the
    difference is added to the high-frequency sum and the blurred image is
    carried into the next level.

    Args:
        image (Tensor): Input tensor accepted by ``wavelet_blur``.
        levels (int): Number of blur levels. With ``levels <= 0`` nothing is
            blurred and ``(zeros, image)`` is returned instead of failing on
            an unbound variable.

    Returns:
        tuple: ``(high_freq, low_freq)``; their sum equals the input.
    """
    high_freq = torch.zeros_like(image)
    # Start from the input so the result is defined when the loop does not run.
    low_freq = image
    for i in range(levels):
        radius = 2 ** i
        low_freq = wavelet_blur(image, radius)
        high_freq += (image - low_freq)
        image = low_freq

    return high_freq, low_freq
107
+
108
def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
    """Combine the detail of ``content_feat`` with the smooth component of ``style_feat``.

    Both inputs go through ``wavelet_decomposition``; the result is the
    content's high-frequency part plus the style's low-frequency part.
    """
    # Only the high-frequency detail of the content is kept.
    content_high_freq, _ = wavelet_decomposition(content_feat)
    # Only the low-frequency component of the style is kept.
    _, style_low_freq = wavelet_decomposition(style_feat)
    return content_high_freq + style_low_freq
120
+
SUPIR/utils/devices.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import contextlib
3
+ from functools import lru_cache
4
+
5
+ import torch
6
+ #from modules import errors
7
+
8
+ if sys.platform == "darwin":
9
+ from modules import mac_specific
10
+
11
+
12
def has_mps() -> bool:
    """Return ``mac_specific.has_mps`` on macOS and ``False`` on every other platform."""
    # Short-circuit keeps ``mac_specific`` (only imported on darwin) untouched elsewhere.
    return sys.platform == "darwin" and mac_specific.has_mps
17
+
18
+
19
def get_cuda_device_string():
    """Return the device string used for CUDA work (always ``"cuda"``)."""
    cuda_name = "cuda"
    return cuda_name
21
+
22
+
23
def get_optimal_device_name():
    """Pick a device name: CUDA if available, else MPS if available, else ``"cpu"``."""
    if torch.cuda.is_available():
        return get_cuda_device_string()
    return "mps" if has_mps() else "cpu"
31
+
32
+
33
def get_optimal_device():
    """Return a ``torch.device`` built from ``get_optimal_device_name()``."""
    device_name = get_optimal_device_name()
    return torch.device(device_name)
35
+
36
+
37
def get_device_for(task):
    """Return the device for ``task``; the argument is currently ignored."""
    chosen = get_optimal_device()
    return chosen
39
+
40
+
41
def torch_gc():
    """Release cached accelerator memory on CUDA and, on macOS, on MPS."""
    if torch.cuda.is_available():
        cuda_name = get_cuda_device_string()
        with torch.cuda.device(cuda_name):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    if has_mps():
        mac_specific.torch_mps_gc()
50
+
51
+
52
def enable_tf32():
    """Enable TF32 matmul/cuDNN kernels when CUDA is available; otherwise do nothing."""
    if not torch.cuda.is_available():
        return

    # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
    # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
    capabilities = (torch.cuda.get_device_capability(devid) for devid in range(torch.cuda.device_count()))
    if any(capability == (7, 5) for capability in capabilities):
        torch.backends.cudnn.benchmark = True

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
62
+
63
+
64
# Configure TF32 once at import time.
enable_tf32()
#errors.run(enable_tf32, "Enabling TF32")

# Module-wide defaults read by the helpers in this module.
cpu = torch.device("cpu")
# NOTE(review): the device is fixed to CUDA here and does not go through
# get_optimal_device() - confirm this is intended for CPU/MPS hosts.
device = torch.device("cuda")
# Every task-specific alias shares the same device object.
device_interrogate = device
device_gfpgan = device
device_esrgan = device
device_codeformer = device
# Half precision is the default everywhere.
dtype = torch.float16
dtype_vae = torch.float16
dtype_unet = torch.float16
unet_needs_upcast = False
73
+
74
+
75
def cond_cast_unet(input):
    """Cast ``input`` to ``dtype_unet`` when ``unet_needs_upcast`` is set; otherwise return it unchanged."""
    if unet_needs_upcast:
        return input.to(dtype_unet)
    return input
77
+
78
+
79
def cond_cast_float(input):
    """Cast ``input`` to float32 when ``unet_needs_upcast`` is set; otherwise return it unchanged."""
    if unet_needs_upcast:
        return input.float()
    return input
81
+
82
+
83
def randn(seed, shape):
    """Seed the global torch RNG with ``seed`` and draw standard-normal noise of ``shape`` on ``device``."""
    torch.manual_seed(seed)
    noise = torch.randn(shape, device=device)
    return noise
86
+
87
+
88
def randn_without_seed(shape):
    """Draw standard-normal noise of ``shape`` on ``device`` without reseeding the RNG."""
    noise = torch.randn(shape, device=device)
    return noise
90
+
91
+
92
def autocast(disable=False):
    """Return a CUDA autocast context, or a no-op context when ``disable`` is truthy."""
    return contextlib.nullcontext() if disable else torch.autocast("cuda")
97
+
98
+
99
def without_autocast(disable=False):
    """Return a context that switches CUDA autocast off.

    A no-op context is returned when autocast is not currently enabled or
    when ``disable`` is truthy.
    """
    if torch.is_autocast_enabled() and not disable:
        return torch.autocast("cuda", enabled=False)
    return contextlib.nullcontext()
101
+
102
+
103
class NansException(Exception):
    """Raised by ``test_for_nans`` when a tensor consists only of NaNs."""
105
+
106
+
107
def test_for_nans(x, where):
    """Raise ``NansException`` when every element of ``x`` is NaN.

    Args:
        x: Tensor to inspect.
        where: ``"unet"`` or ``"vae"`` selects a specific message; any other
            value produces the generic one.
    """
    all_nan = torch.all(torch.isnan(x)).item()
    if not all_nan:
        return

    if where == "unet":
        prefix = "A tensor with all NaNs was produced in Unet."
    elif where == "vae":
        prefix = "A tensor with all NaNs was produced in VAE."
    else:
        prefix = "A tensor with all NaNs was produced."

    raise NansException(prefix + " Use --disable-nan-check commandline argument to disable this check.")
123
+
124
+
125
@lru_cache
def first_time_calculation():
    """
    Run one tiny Linear and one tiny Conv2d forward pass on ``device`` with ``dtype``.

    The first calculation with pytorch layers reportedly allocates about 700MB of
    memory and takes about 2.7 seconds, at least with NVidia; ``lru_cache`` makes
    every later call a no-op.
    """

    linear_input = torch.zeros((1, 1)).to(device, dtype)
    torch.nn.Linear(1, 1).to(device, dtype)(linear_input)

    conv_input = torch.zeros((1, 1, 3, 3)).to(device, dtype)
    torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)(conv_input)
SUPIR/utils/face_restoration_helper.py ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import os
4
+ import torch
5
+ from torchvision.transforms.functional import normalize
6
+
7
+ from facexlib.detection import init_detection_model
8
+ from facexlib.parsing import init_parsing_model
9
+ from facexlib.utils.misc import img2tensor, imwrite
10
+
11
+ from .file import load_file_from_url
12
+
13
+
14
def get_largest_face(det_faces, h, w):
    """Return the detection with the largest in-image area and its index.

    Args:
        det_faces: Sequence of boxes indexed as (x1, y1, x2, y2, ...).
        h: Image height used to clip y coordinates.
        w: Image width used to clip x coordinates.

    Returns:
        tuple: ``(det_faces[idx], idx)`` for the first box of maximal area.
    """

    def clamp(coord, upper):
        # Clip a coordinate into [0, upper].
        if coord < 0:
            return 0
        return upper if coord > upper else coord

    areas = []
    for box in det_faces:
        x0, x1 = clamp(box[0], w), clamp(box[2], w)
        y0, y1 = clamp(box[1], h), clamp(box[3], h)
        areas.append((x1 - x0) * (y1 - y0))
    best = areas.index(max(areas))
    return det_faces[best], best
33
+
34
+
35
def get_center_face(det_faces, h=0, w=0, center=None):
    """Return the detection whose box centre is closest to a reference point.

    Args:
        det_faces: Sequence of boxes indexed as (x1, y1, x2, y2, ...).
        h: Image height; used with ``w`` for the default reference point.
        w: Image width.
        center: Optional explicit (x, y) reference point; defaults to the
            image centre ``(w / 2, h / 2)``.

    Returns:
        tuple: ``(det_faces[idx], idx)`` for the first closest box.
    """
    target = np.array([w / 2, h / 2]) if center is None else np.array(center)
    distances = [
        np.linalg.norm(np.array([(box[0] + box[2]) / 2, (box[1] + box[3]) / 2]) - target)
        for box in det_faces
    ]
    nearest = distances.index(min(distances))
    return det_faces[nearest], nearest
47
+
48
+
49
class FaceRestoreHelper(object):
    """Helper for the face restoration pipeline (base class).

    The methods hand data to each other through instance lists:
    ``read_image`` stores the input, ``get_face_landmarks_5`` fills the
    landmarks, ``align_warp_face`` fills ``cropped_faces`` and
    ``affine_matrices``, ``add_restored_face`` collects restored crops,
    ``get_inverse_affine`` fills ``inverse_affine_matrices`` and
    ``paste_faces_to_input_image`` blends everything back. ``clean_all``
    resets the per-image lists.
    """

    def __init__(self,
                 upscale_factor,
                 face_size=512,
                 crop_ratio=(1, 1),
                 det_model='retinaface_resnet50',
                 save_ext='png',
                 template_3points=False,
                 pad_blur=False,
                 use_parse=False,
                 device=None):
        """Build the landmark template and load the detection and parsing models.

        Args:
            upscale_factor: Output scale relative to the input image (stored as int).
            face_size (int): Side length of the square aligned face before
                ``crop_ratio`` is applied.
            crop_ratio (tuple): ``(h, w)`` ratios, each >= 1, enlarging the crop.
            det_model (str): Name given to ``init_detection_model``; ``'dlib'``
                selects the 1024-based dlib landmark template.
            save_ext (str): File extension used when saving images.
            template_3points (bool): Use the 3-point template instead of 5 points.
            pad_blur (bool): Pad and blur the input around each face; this
                needs 5 landmarks and therefore disables ``template_3points``.
            use_parse (bool): Refine the paste mask with the face parsing model.
            device: Torch device; defaults to CUDA when available, else CPU.
        """
        self.upscale_factor = int(upscale_factor)
        # the cropped face ratio based on the square face
        self.crop_ratio = crop_ratio  # (h, w)
        assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1'
        # stored as (width, height), the order cv2.warpAffine expects for dsize
        self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
        self.det_model = det_model
        self.save_ext = save_ext
        self.pad_blur = pad_blur

        # Resolve the 3-point flag BEFORE choosing the template: pad_blur reads
        # landmarks 3 and 4, so it forces 5-point landmarks, and the template
        # must have the same number of points as the landmarks it is matched to.
        self.template_3points = template_3points  # improve robustness
        if self.pad_blur is True:
            self.template_3points = False

        if self.det_model == 'dlib':
            # standard 5 landmarks for FFHQ faces with 1024 x 1024
            self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941],
                                           [337.91089109, 488.38613861], [437.95049505, 493.51485149],
                                           [513.58415842, 678.5049505]])
            self.face_template = self.face_template / (1024 // face_size)
        elif self.template_3points:
            self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
        else:
            # standard 5 landmarks for FFHQ faces with 512 x 512
            # facexlib
            self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
                                           [201.26117, 371.41043], [313.08905, 371.15118]])

        # NOTE(review): this rescale also applies to the dlib template, which was
        # already divided above; the two only cancel for face_size == 512 - confirm.
        self.face_template = self.face_template * (face_size / 512.0)
        # shift the template so the face stays centred in an enlarged crop
        if self.crop_ratio[0] > 1:
            self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
        if self.crop_ratio[1] > 1:
            self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2

        # per-image state, reset by clean_all()
        self.all_landmarks_5 = []
        self.det_faces = []
        self.affine_matrices = []
        self.inverse_affine_matrices = []
        self.cropped_faces = []
        self.restored_faces = []
        self.pad_input_imgs = []

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        # init face detection model
        self.face_detector = init_detection_model(det_model, half=False, device=self.device)

        # init face parsing model
        self.use_parse = use_parse
        self.face_parse = init_parsing_model(model_name='parsenet', device=self.device)

    def set_upscale_factor(self, upscale_factor):
        """Replace the stored upscale factor (stored as given, without int conversion)."""
        self.upscale_factor = upscale_factor

    def read_image(self, img):
        """Store the input image in ``self.input_img``.

        Args:
            img: Image path (read with ``cv2.imread``) or an already loaded
                array. 16-bit values are rescaled to [0, 255], gray images
                are expanded to BGR and an alpha channel is dropped. Images
                whose shorter side is below 512 are upscaled so that side
                becomes 512.
        """
        # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
        if isinstance(img, str):
            img = cv2.imread(img)

        if np.max(img) > 256:  # 16-bit image
            img = img / 65535 * 255
        if len(img.shape) == 2:  # gray image
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:  # BGRA image with alpha channel
            img = img[:, :, 0:3]

        self.input_img = img

        shorter_side = min(self.input_img.shape[:2])
        if shorter_side < 512:
            f = 512.0 / shorter_side
            self.input_img = cv2.resize(self.input_img, (0, 0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)

    def init_dlib(self, detection_path, landmark5_path):
        """Initialize the dlib detectors and predictors.

        Args:
            detection_path: URL of the dlib CNN face detection weights.
            landmark5_path: URL of the dlib 5-landmark shape predictor weights.

        Returns:
            tuple: ``(face_detector, shape_predictor_5)``.

        Raises:
            ImportError: If dlib is not installed. The original code only
                printed a hint and then failed on the unbound ``dlib`` name.
        """
        try:
            import dlib
        except ImportError as err:
            raise ImportError('Please install dlib by running: conda install -c conda-forge dlib') from err
        detection_path = load_file_from_url(url=detection_path, model_dir='weights/dlib', progress=True, file_name=None)
        landmark5_path = load_file_from_url(url=landmark5_path, model_dir='weights/dlib', progress=True, file_name=None)
        face_detector = dlib.cnn_face_detection_model_v1(detection_path)
        shape_predictor_5 = dlib.shape_predictor(landmark5_path)
        return face_detector, shape_predictor_5

    def get_face_landmarks_5_dlib(self,
                                  only_keep_largest=False,
                                  scale=1):
        """Detect faces with the dlib detector and collect 5 landmarks per face.

        NOTE(review): reads ``self.shape_predictor_5``, which nothing in this
        class assigns - presumably set by the caller after ``init_dlib``; confirm.

        Returns:
            int: Number of landmark sets stored in ``self.all_landmarks_5``
            (0 when no face is detected).
        """
        det_faces = self.face_detector(self.input_img, scale)

        if len(det_faces) == 0:
            print('No face detected. Try to increase upsample_num_times.')
            return 0

        if only_keep_largest:
            print('Detect several faces and only keep the largest.')
            face_areas = [
                (det.rect.right() - det.rect.left()) * (det.rect.bottom() - det.rect.top())
                for det in det_faces
            ]
            largest_idx = face_areas.index(max(face_areas))
            self.det_faces = [det_faces[largest_idx]]
        else:
            self.det_faces = det_faces

        if len(self.det_faces) == 0:
            return 0

        for face in self.det_faces:
            shape = self.shape_predictor_5(self.input_img, face.rect)
            landmark = np.array([[part.x, part.y] for part in shape.parts()])
            self.all_landmarks_5.append(landmark)

        return len(self.all_landmarks_5)

    def get_face_landmarks_5(self,
                             only_keep_largest=False,
                             only_center_face=False,
                             resize=None,
                             blur_ratio=0.01,
                             eye_dist_threshold=None):
        """Detect faces and store their landmarks (and padded inputs if ``pad_blur``).

        Args:
            only_keep_largest (bool): Keep only the largest detection.
            only_center_face (bool): Keep only the detection closest to the
                image centre (ignored when ``only_keep_largest`` is set).
            resize: Target length of the shorter side for detection; the image
                is only ever scaled up. ``None`` detects on the stored image.
            blur_ratio (float): Blur kernel size relative to the crop size,
                used for ``pad_blur``.
            eye_dist_threshold: Detections with a smaller eye distance are
                dropped; ``None`` disables the filter.

        Returns:
            int: Number of landmark sets stored in ``self.all_landmarks_5``.
        """
        if self.det_model == 'dlib':
            return self.get_face_landmarks_5_dlib(only_keep_largest)

        if resize is None:
            scale = 1
            input_img = self.input_img
        else:
            h, w = self.input_img.shape[0:2]
            scale = resize / min(h, w)
            scale = max(1, scale)  # always scale up
            h, w = int(h * scale), int(w * scale)
            interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
            input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)

        with torch.no_grad():
            bboxes = self.face_detector.detect_faces(input_img)

        if bboxes is None or bboxes.shape[0] == 0:
            return 0
        # map detections back to the coordinates of self.input_img
        bboxes = bboxes / scale

        for bbox in bboxes:
            # remove faces with too small eye distance: side faces or too small faces
            eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]])
            if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
                continue

            # entries 5.. hold the landmark (x, y) pairs; 0:5 is the box plus score
            if self.template_3points:
                landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
            else:
                landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
            self.all_landmarks_5.append(landmark)
            self.det_faces.append(bbox[0:5])

        if len(self.det_faces) == 0:
            return 0
        if only_keep_largest:
            h, w, _ = self.input_img.shape
            self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
            self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
        elif only_center_face:
            h, w, _ = self.input_img.shape
            self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
            self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]

        # pad blurry images
        if self.pad_blur:
            img_h, img_w = self.input_img.shape[0:2]
            self.pad_input_imgs = []
            for landmarks in self.all_landmarks_5:
                # get landmarks
                eye_left = landmarks[0, :]
                eye_right = landmarks[1, :]
                eye_avg = (eye_left + eye_right) * 0.5
                mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
                eye_to_eye = eye_right - eye_left
                eye_to_mouth = mouth_avg - eye_avg

                # Get the oriented crop rectangle
                # x: half width of the oriented crop rectangle
                x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
                #  - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
                # norm with the hypotenuse: get the direction
                x /= np.hypot(*x)  # get the hypotenuse of a right triangle
                rect_scale = 1.5
                x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
                # y: half height of the oriented crop rectangle
                y = np.flipud(x) * [-1, 1]

                # c: center
                c = eye_avg + eye_to_mouth * 0.1
                # quad: (left_top, left_bottom, right_bottom, right_top)
                quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
                # qsize: side length of the square
                qsize = np.hypot(*x) * 2
                border = max(int(np.rint(qsize * 0.1)), 3)

                # get pad
                # pad: (width_left, height_top, width_right, height_bottom)
                pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
                       int(np.ceil(max(quad[:, 1]))))
                # The right overhang is measured against the image WIDTH and the
                # bottom overhang against the HEIGHT; the original had the two
                # shape indices swapped, which is wrong for non-square images.
                pad = [
                    max(-pad[0] + border, 1),
                    max(-pad[1] + border, 1),
                    max(pad[2] - img_w + border, 1),
                    max(pad[3] - img_h + border, 1)
                ]

                if max(pad) > 1:
                    # pad image
                    pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
                    # modify landmark coords (in place, so align_warp_face sees the shift)
                    landmarks[:, 0] += pad[0]
                    landmarks[:, 1] += pad[1]
                    # blur pad images
                    h, w, _ = pad_img.shape
                    y, x, _ = np.ogrid[:h, :w, :1]
                    # mask is 1 at the outer edge and falls to <= 0 inside the original image
                    mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
                                                       np.float32(w - 1 - x) / pad[2]),
                                      1.0 - np.minimum(np.float32(y) / pad[1],
                                                       np.float32(h - 1 - y) / pad[3]))
                    blur = int(qsize * blur_ratio)
                    if blur % 2 == 0:
                        blur += 1
                    blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))

                    pad_img = pad_img.astype('float32')
                    pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
                    pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
                    pad_img = np.clip(pad_img, 0, 255)  # float32, [0, 255]
                    self.pad_input_imgs.append(pad_img)
                else:
                    self.pad_input_imgs.append(np.copy(self.input_img))

        return len(self.all_landmarks_5)

    def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
        """Align and warp faces with face template.

        For every stored landmark set an affine matrix onto the template is
        estimated, the face is warped to ``self.face_size`` and both are
        appended to ``affine_matrices`` / ``cropped_faces``.

        Args:
            save_cropped_path: If given, each crop is written next to this
                path with an ``_NN`` suffix and ``self.save_ext``.
            border_mode (str): ``'constant'``, ``'reflect101'`` or ``'reflect'``.
        """
        if self.pad_blur:
            assert len(self.pad_input_imgs) == len(
                self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'

        # translate the border mode name once; other values are passed through unchanged
        if border_mode == 'constant':
            border_mode = cv2.BORDER_CONSTANT
        elif border_mode == 'reflect101':
            border_mode = cv2.BORDER_REFLECT101
        elif border_mode == 'reflect':
            border_mode = cv2.BORDER_REFLECT

        for idx, landmark in enumerate(self.all_landmarks_5):
            # use 5 landmarks to get affine matrix
            # use cv2.LMEDS method for the equivalence to skimage transform
            # ref: https://blog.csdn.net/yichxi/article/details/115827338
            affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
            self.affine_matrices.append(affine_matrix)
            # warp and crop faces
            if self.pad_blur:
                input_img = self.pad_input_imgs[idx]
            else:
                input_img = self.input_img
            cropped_face = cv2.warpAffine(
                input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132))  # gray
            self.cropped_faces.append(cropped_face)
            # save the cropped face
            if save_cropped_path is not None:
                path = os.path.splitext(save_cropped_path)[0]
                save_path = f'{path}_{idx:02d}.{self.save_ext}'
                imwrite(cropped_face, save_path)

    def get_inverse_affine(self, save_inverse_affine_path=None):
        """Get inverse affine matrix.

        Each stored affine matrix is inverted, scaled by ``upscale_factor``
        and appended to ``inverse_affine_matrices``; optionally each one is
        saved with ``torch.save`` under an ``_NN.pth`` suffix.
        """
        for idx, affine_matrix in enumerate(self.affine_matrices):
            inverse_affine = cv2.invertAffineTransform(affine_matrix)
            inverse_affine *= self.upscale_factor
            self.inverse_affine_matrices.append(inverse_affine)
            # save inverse affine matrices
            if save_inverse_affine_path is not None:
                path, _ = os.path.splitext(save_inverse_affine_path)
                save_path = f'{path}_{idx:02d}.pth'
                torch.save(inverse_affine, save_path)

    def add_restored_face(self, restored_face, input_face=None):
        """Append a restored face crop; ``input_face`` is accepted but unused."""
        self.restored_faces.append(restored_face)

    def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
        """Warp the restored faces back and blend them into the upscaled image.

        Args:
            save_path: If given, the result is written there with ``self.save_ext``.
            upsample_img: Background to paste onto; when ``None`` the input
                image is resized with bilinear interpolation.
            draw_box (bool): Draw a green border around each pasted face.
            face_upsampler: Optional object with ``enhance(img, outscale=...)``
                used to upsample each restored face before pasting.

        Returns:
            The blended image (uint8, or uint16 if values exceed 256).
        """
        h, w, _ = self.input_img.shape
        h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)

        if upsample_img is None:
            # simply resize the background
            upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR)
        else:
            upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)

        assert len(self.restored_faces) == len(
            self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.')

        inv_mask_borders = []
        for restored_face, stored_inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
            # Adjust a copy: the original modified the stored matrix in place,
            # so calling this method twice applied the correction twice.
            inverse_affine = np.copy(stored_inverse_affine)
            if face_upsampler is not None:
                restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0]
                inverse_affine /= self.upscale_factor
                inverse_affine[:, 2] *= self.upscale_factor
                face_size = (self.face_size[0] * self.upscale_factor, self.face_size[1] * self.upscale_factor)
            else:
                # Add an offset to inverse affine matrix, for more precise back alignment
                if self.upscale_factor > 1:
                    extra_offset = 0.5 * self.upscale_factor
                else:
                    extra_offset = 0
                inverse_affine[:, 2] += extra_offset
                face_size = self.face_size
            inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))

            # face_size is (width, height); numpy shapes are (rows, cols) = (height, width)
            face_w, face_h = face_size

            # always use a rectangular mask covering the whole crop
            mask = np.ones((face_h, face_w), dtype=np.float32)
            inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
            # remove the black borders
            inv_mask_erosion = cv2.erode(
                inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
            pasted_face = inv_mask_erosion[:, :, None] * inv_restored
            total_face_area = np.sum(inv_mask_erosion)  # // 3
            # add border
            if draw_box:
                mask_border = np.ones((face_h, face_w, 3), dtype=np.float32)
                border = int(1400 / np.sqrt(total_face_area))
                mask_border[border:face_h - border, border:face_w - border, :] = 0
                inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
                inv_mask_borders.append(inv_mask_border)
            # compute the fusion edge based on the area of face
            w_edge = int(total_face_area ** 0.5) // 20
            erosion_radius = w_edge * 2
            inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
            blur_size = w_edge * 2
            inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
            if len(upsample_img.shape) == 2:  # upsample_img is gray image
                upsample_img = upsample_img[:, :, None]
            inv_soft_mask = inv_soft_mask[:, :, None]

            # parse mask
            if self.use_parse:
                # inference
                face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
                face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
                normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
                face_input = torch.unsqueeze(face_input, 0).to(self.device)
                with torch.no_grad():
                    out = self.face_parse(face_input)[0]
                out = out.argmax(dim=1).squeeze().cpu().numpy()

                parse_mask = np.zeros(out.shape)
                # per-class mask value, indexed by the parsing label: 255 keeps the class, 0 drops it
                MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
                for label, color in enumerate(MASK_COLORMAP):
                    parse_mask[out == label] = color
                # blur the mask
                parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
                parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
                # remove the black borders
                thres = 10
                parse_mask[:thres, :] = 0
                parse_mask[-thres:, :] = 0
                parse_mask[:, :thres] = 0
                parse_mask[:, -thres:] = 0
                parse_mask = parse_mask / 255.

                parse_mask = cv2.resize(parse_mask, face_size)
                parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3)
                inv_soft_parse_mask = parse_mask[:, :, None]
                # take the element-wise smaller of the parse mask and the soft mask
                fuse_mask = (inv_soft_parse_mask < inv_soft_mask).astype('int')
                inv_soft_mask = inv_soft_parse_mask * fuse_mask + inv_soft_mask * (1 - fuse_mask)

            if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4:  # alpha channel
                alpha = upsample_img[:, :, 3:]
                upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
                upsample_img = np.concatenate((upsample_img, alpha), axis=2)
            else:
                upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img

        if np.max(upsample_img) > 256:  # 16-bit image
            upsample_img = upsample_img.astype(np.uint16)
        else:
            upsample_img = upsample_img.astype(np.uint8)

        # draw bounding box
        if draw_box:
            img_color = np.ones([*upsample_img.shape], dtype=np.float32)
            img_color[:, :, 0] = 0
            img_color[:, :, 1] = 255
            img_color[:, :, 2] = 0
            for inv_mask_border in inv_mask_borders:
                upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img

        if save_path is not None:
            path = os.path.splitext(save_path)[0]
            save_path = f'{path}.{self.save_ext}'
            imwrite(upsample_img, save_path)
        return upsample_img

    def clean_all(self):
        """Reset every per-image list so the helper can process another image."""
        self.all_landmarks_5 = []
        self.restored_faces = []
        self.affine_matrices = []
        self.cropped_faces = []
        self.inverse_affine_matrices = []
        self.det_faces = []
        self.pad_input_imgs = []
SUPIR/utils/file.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Tuple
3
+
4
+ from urllib.parse import urlparse
5
+ from torch.hub import download_url_to_file, get_dir
6
+
7
+
8
def load_file_list(file_list_path: str) -> List[str]:
    """Read a list file and return its non-blank lines.

    Args:
        file_list_path: Path of a text file holding one entry (an image
            path) per line.

    Returns:
        The whitespace-stripped lines of the file, in file order, with
        lines that are empty after stripping left out.
    """
    with open(file_list_path, "r") as fin:
        stripped_lines = (raw_line.strip() for raw_line in fin)
        return [entry for entry in stripped_lines if entry]
17
+
18
+
19
def list_image_files(
    img_dir: str,
    exts: Tuple[str, ...]=(".jpg", ".png", ".jpeg"),
    follow_links: bool=False,
    log_progress: bool=False,
    log_every_n_files: int=10000,
    max_size: int=-1
) -> List[str]:
    """Recursively collect image files below ``img_dir``.

    Args:
        img_dir: Root directory passed to ``os.walk``.
        exts: Accepted extensions (with leading dot). The file extension is
            lower-cased before the membership test.
        follow_links: Forwarded to ``os.walk`` as ``followlinks``.
        log_progress: Print a progress line every ``log_every_n_files``
            collected files.
        log_every_n_files: Interval of the progress line.
        max_size: Maximum number of files to return; a negative value
            means no limit.

    Returns:
        Paths (``dir_path`` joined with the file name) in ``os.walk`` order.
    """
    files: List[str] = []
    if max_size == 0:
        # Nothing can be collected, so skip walking the tree.
        return files

    for dir_path, _, file_names in os.walk(img_dir, followlinks=follow_links):
        for file_name in file_names:
            if os.path.splitext(file_name)[1].lower() not in exts:
                continue
            files.append(os.path.join(dir_path, file_name))
            if log_progress and len(files) % log_every_n_files == 0:
                print(f"find {len(files)} images in {img_dir}")
            # Stop as soon as the limit is reached instead of scanning on
            # until the next matching file happens to show up.
            if 0 <= max_size <= len(files):
                return files
    return files
41
+
42
+
43
def get_file_name_parts(file_path: str) -> Tuple[str, str, str]:
    """Split a path into ``(parent directory, stem, extension)``.

    The extension is whatever ``os.path.splitext`` reports for the file
    name, so it keeps its leading dot and only the last suffix is split off.
    """
    parent_path = os.path.dirname(file_path)
    stem, ext = os.path.splitext(os.path.basename(file_path))
    return parent_path, stem, ext
47
+
48
+
49
+ # https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/download_util.py/
50
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Load file from http url, will download models if necessary.

    Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
            Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.

    Returns:
        str: The path to the downloaded file.

    Raises:
        ValueError: If no file name can be determined (``file_name`` is
            empty, or it is None and the URL path has no final component).
    """
    if model_dir is None:  # use the pytorch hub_dir
        model_dir = os.path.join(get_dir(), 'checkpoints')

    os.makedirs(model_dir, exist_ok=True)

    filename = file_name if file_name is not None else os.path.basename(urlparse(url).path)
    if not filename:
        # With an empty name the cache path would be model_dir itself, which
        # exists, so the directory would be returned as if it were the file.
        raise ValueError(
            f'Cannot determine a file name for "{url}"; pass file_name explicitly.')

    cached_file = os.path.abspath(os.path.join(model_dir, filename))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
SUPIR/utils/tilevae.py ADDED
@@ -0,0 +1,971 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ #
3
+ # Ultimate VAE Tile Optimization
4
+ #
5
+ # Introducing a revolutionary new optimization designed to make
6
+ # the VAE work with giant images on limited VRAM!
7
+ # Say goodbye to the frustration of OOM and hello to seamless output!
8
+ #
9
+ # ------------------------------------------------------------------------
10
+ #
11
+ # This script is a wild hack that splits the image into tiles,
12
+ # encodes each tile separately, and merges the result back together.
13
+ #
14
+ # Advantages:
15
+ # - The VAE can now work with giant images on limited VRAM
16
+ # (~10 GB for 8K images!)
17
+ # - The merged output is completely seamless without any post-processing.
18
+ #
19
+ # Drawbacks:
20
+ # - Giant RAM needed. To store the intermediate results for a 4096x4096
21
+ # images, you need 32 GB RAM (it consumes ~20GB); for 8192x8192
22
+ # you need 128 GB RAM machine (it consumes ~100 GB)
23
+ # - NaNs always appear for 8k images when you use fp16 (half) VAE
24
+ # You must use --no-half-vae to disable half VAE for that giant image.
25
+ # - Slow speed. With default tile size, it takes around 50/200 seconds
26
+ # to encode/decode a 4096x4096 image; and 200/900 seconds to encode/decode
27
+ # a 8192x8192 image. (The speed is limited by both the GPU and the CPU.)
28
+ # - The gradient calculation is not compatible with this hack. It
29
+ # will break any backward() or torch.autograd.grad() that passes VAE.
30
+ # (But you can still use the VAE to generate training data.)
31
+ #
32
+ # How it works:
33
+ # 1) The image is split into tiles.
34
+ # - To ensure perfect results, each tile is padded with 32 pixels
35
+ # on each side.
36
+ # - Then the conv2d/silu/upsample/downsample can produce identical
37
+ # results to the original image without splitting.
38
+ # 2) The original forward is decomposed into a task queue and a task worker.
39
+ # - The task queue is a list of functions that will be executed in order.
40
+ # - The task worker is a loop that executes the tasks in the queue.
41
+ # 3) The task queue is executed for each tile.
42
+ # - Current tile is sent to GPU.
43
+ # - local operations are directly executed.
44
+ # - Group norm calculation is temporarily suspended until the mean
45
+ # and var of all tiles are calculated.
46
+ # - The residual is pre-calculated and stored and added back later.
47
+ # - When need to go to the next tile, the current tile is sent to cpu.
48
+ # 4) After all tiles are processed, tiles are merged on cpu and return.
49
+ #
50
+ # Enjoy!
51
+ #
52
+ # @author: LI YI @ Nanyang Technological University - Singapore
53
+ # @date: 2023-03-02
54
+ # @license: MIT License
55
+ #
56
+ # Please give me a star if you like this project!
57
+ #
58
+ # -------------------------------------------------------------------------
59
+
60
+ import gc
61
+ from time import time
62
+ import math
63
+ from tqdm import tqdm
64
+
65
+ import torch
66
+ import torch.version
67
+ import torch.nn.functional as F
68
+ from einops import rearrange
69
+ from diffusers.utils.import_utils import is_xformers_available
70
+
71
+ import SUPIR.utils.devices as devices
72
+
73
+ try:
74
+ import xformers
75
+ import xformers.ops
76
+ except ImportError:
77
+ pass
78
+
79
+ sd_flag = True
80
+
81
def get_recommend_encoder_tile_size():
    """Pick an encoder tile size from the total memory of ``devices.device``.

    Returns 512 when CUDA is unavailable. Otherwise the device's total
    memory in MiB is compared (strictly greater) against the thresholds
    below and the first match wins; 960 is the fallback.
    """
    if not torch.cuda.is_available():
        return 512

    total_memory = torch.cuda.get_device_properties(
        devices.device).total_memory // 2**20
    # (memory threshold in MiB, tile size), checked from largest to smallest
    thresholds = (
        (16*1000, 3072),
        (12*1000, 2048),
        (8*1000, 1536),
    )
    for minimum_memory, tile_size in thresholds:
        if total_memory > minimum_memory:
            return tile_size
    return 960
96
+
97
+
98
def get_recommend_decoder_tile_size():
    """Pick a decoder tile size from the total memory of ``devices.device``.

    Returns 64 when CUDA is unavailable. Otherwise the device's total
    memory in MiB is compared (strictly greater) against the thresholds
    below and the first match wins; 64 is the fallback.
    """
    if not torch.cuda.is_available():
        return 64

    total_memory = torch.cuda.get_device_properties(
        devices.device).total_memory // 2**20
    # (memory threshold in MiB, tile size), checked from largest to smallest
    thresholds = (
        (30*1000, 256),
        (16*1000, 192),
        (12*1000, 128),
        (8*1000, 96),
    )
    for minimum_memory, tile_size in thresholds:
        if total_memory > minimum_memory:
            return tile_size
    return 64
115
+
116
+
117
# global const
# Module-level defaults. The two tile sizes are computed once, at import
# time, by the recommendation helpers defined above.
DEFAULT_ENABLED = False
DEFAULT_MOVE_TO_GPU = False
DEFAULT_FAST_ENCODER = True
DEFAULT_FAST_DECODER = True
DEFAULT_COLOR_FIX = 0
DEFAULT_ENCODER_TILE_SIZE = get_recommend_encoder_tile_size()
DEFAULT_DECODER_TILE_SIZE = get_recommend_decoder_tile_size()
125
+
126
+
127
+ # inplace version of silu
128
def inplace_nonlinearity(x):
    """Apply SiLU to ``x`` in place and return the result."""
    # Test: fix for Nans
    return torch.nn.functional.silu(x, inplace=True)
131
+
132
+ # extracted from ldm.modules.diffusionmodules.model
133
+
134
+ # from diffusers lib
135
def attn_forward_new(self, h_):
    """Self-attention over the spatial positions of a 4-D feature map.

    ``self`` must provide ``to_q``/``to_k``/``to_v``/``to_out`` and the
    helpers ``prepare_attention_mask``, ``head_to_batch_dim``,
    ``get_attention_scores`` and ``batch_to_head_dim`` (the code was taken
    from the diffusers library). No normalization and no residual add are
    done here.

    @param h_: tensor unpacked as (batch, channel, height, width)
    @return: tensor reshaped back to (batch, channel, height, width)
    """
    batch_size, channel, height, width = h_.shape
    # flatten the spatial grid into a token axis: (batch, height*width, channel)
    tokens = h_.view(batch_size, channel, height * width).transpose(1, 2)

    batch_size, sequence_length, _ = tokens.shape
    # No mask is ever supplied; the helper is still called with None.
    attention_mask = self.prepare_attention_mask(None, sequence_length, batch_size)

    # Self-attention: keys and values come from the same tokens as queries.
    query = self.to_q(tokens)
    key = self.to_k(tokens)
    value = self.to_v(tokens)
    query, key, value = (
        self.head_to_batch_dim(projection) for projection in (query, key, value)
    )

    attention_probs = self.get_attention_scores(query, key, attention_mask)
    attended = self.batch_to_head_dim(torch.bmm(attention_probs, value))

    # to_out[0] is the linear projection, to_out[1] the dropout
    attended = self.to_out[1](self.to_out[0](attended))

    return attended.transpose(-1, -2).reshape(batch_size, channel, height, width)
170
+
171
def attn_forward_new_pt2_0(self, hidden_states,):
    """Self-attention built on ``F.scaled_dot_product_attention``.

    ``self`` must provide ``heads``, ``group_norm`` (may be None),
    ``to_q``/``to_k``/``to_v`` and ``to_out``; the projections are called
    with a ``scale`` keyword. No attention mask and no separate encoder
    states are used, and no residual is added.

    @param hidden_states: 4-D (batch, channel, height, width) or 3-D token tensor
    @return: tensor with the layout of the input (4-D inputs are restored)
    """
    scale = 1
    input_ndim = hidden_states.ndim
    is_feature_map = input_ndim == 4

    if is_feature_map:
        batch_size, channel, height, width = hidden_states.shape
        # flatten the spatial grid into a token axis
        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

    batch_size, _, _ = hidden_states.shape

    # The attention mask is always None here, so no mask preparation happens.

    if self.group_norm is not None:
        hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

    # Self-attention: keys and values come from the same tokens as queries.
    query = self.to_q(hidden_states, scale=scale)
    key = self.to_k(hidden_states, scale=scale)
    value = self.to_v(hidden_states, scale=scale)

    head_dim = key.shape[-1] // self.heads

    def split_heads(projection):
        # (batch, tokens, heads*head_dim) -> (batch, heads, tokens, head_dim)
        return projection.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

    query, key, value = split_heads(query), split_heads(key), split_heads(value)

    # the output of sdp = (batch, num_heads, seq_len, head_dim)
    # TODO: add support for attn.scale when we move to Torch 2.1
    attended = F.scaled_dot_product_attention(
        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
    )

    attended = attended.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
    attended = attended.to(query.dtype)

    # to_out[0] is the linear projection, to_out[1] the dropout
    attended = self.to_out[0](attended, scale=scale)
    attended = self.to_out[1](attended)

    if is_feature_map:
        attended = attended.transpose(-1, -2).reshape(batch_size, channel, height, width)

    return attended
231
+
232
def attn_forward_new_xformers(self, hidden_states):
    """Self-attention built on ``xformers.ops.memory_efficient_attention``.

    ``self`` must provide ``group_norm`` (may be None),
    ``to_q``/``to_k``/``to_v``/``to_out`` and the helpers
    ``prepare_attention_mask``, ``head_to_batch_dim`` and
    ``batch_to_head_dim``; the projections are called with a ``scale``
    keyword. No residual is added.

    @param hidden_states: 4-D (batch, channel, height, width) or 3-D token tensor
    @return: tensor with the layout of the input (4-D inputs are restored)
    """
    scale = 1
    attention_op = None
    input_ndim = hidden_states.ndim
    is_feature_map = input_ndim == 4

    if is_feature_map:
        batch_size, channel, height, width = hidden_states.shape
        # flatten the spatial grid into a token axis
        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

    batch_size, key_tokens, _ = hidden_states.shape

    # No mask is ever supplied; the helper is still called with None.
    attention_mask = self.prepare_attention_mask(None, key_tokens, batch_size)
    if attention_mask is not None:
        # expand our mask's singleton query_tokens dimension:
        #   [batch*heads,            1, key_tokens] ->
        #   [batch*heads, query_tokens, key_tokens]
        # so that it can be added as a bias onto the attention scores that xformers computes.
        # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
        _, query_tokens, _ = hidden_states.shape
        attention_mask = attention_mask.expand(-1, query_tokens, -1)

    if self.group_norm is not None:
        hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

    # Self-attention: keys and values come from the same tokens as queries.
    query = self.to_q(hidden_states, scale=scale)
    key = self.to_k(hidden_states, scale=scale)
    value = self.to_v(hidden_states, scale=scale)
    query, key, value = (
        self.head_to_batch_dim(projection).contiguous()
        for projection in (query, key, value)
    )

    attended = xformers.ops.memory_efficient_attention(
        query, key, value, attn_bias=attention_mask, op=attention_op#, scale=scale
    )
    attended = self.batch_to_head_dim(attended.to(query.dtype))

    # to_out[0] is the linear projection, to_out[1] the dropout
    attended = self.to_out[0](attended, scale=scale)
    attended = self.to_out[1](attended)

    if is_feature_map:
        attended = attended.transpose(-1, -2).reshape(batch_size, channel, height, width)

    return attended
291
+
292
def attn_forward(self, h_):
    """Plain (bmm + softmax) self-attention over spatial positions.

    ``self`` must provide the callables ``q``, ``k``, ``v`` and ``proj_out``.
    No normalization and no residual add are done here.

    @param h_: input feature map; the q/k/v outputs are unpacked as (b, c, h, w)
    @return: ``proj_out`` applied to the attended feature map
    """
    q, k, v = self.q(h_), self.k(h_), self.v(h_)

    b, c, h, w = q.shape
    positions = h * w

    # scores[b, i, j] = sum_c q[b, i, c] * k[b, c, j], scaled by c ** -0.5
    q = q.reshape(b, c, positions).permute(0, 2, 1)   # b,hw,c
    k = k.reshape(b, c, positions)                    # b,c,hw
    scores = torch.bmm(q, k) * (int(c)**(-0.5))
    probs = torch.softmax(scores, dim=2)

    # out[b, c, j] = sum_i v[b, c, i] * probs_t[b, i, j]
    v = v.reshape(b, c, positions)
    probs_t = probs.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)
    attended = torch.bmm(v, probs_t).reshape(b, c, h, w)

    return self.proj_out(attended)
316
+
317
+
318
def xformer_attn_forward(self, h_):
    """Single-head self-attention over spatial positions using xformers.

    ``self`` must provide the callables ``q``, ``k``, ``v`` and ``proj_out``
    and an ``attention_op`` attribute that is forwarded to xformers. No
    normalization and no residual add are done here.

    @param h_: input feature map; the q/k/v outputs are unpacked as (B, C, H, W)
    @return: ``proj_out`` applied to the attended feature map
    """
    q, k, v = self.q(h_), self.k(h_), self.v(h_)

    # compute attention
    B, C, H, W = q.shape

    def to_tokens(feature_map):
        # (B, C, H, W) -> contiguous (B, H*W, C), treated as a single head
        t = rearrange(feature_map, 'b c h w -> b (h w) c')
        t = t.unsqueeze(3).reshape(B, t.shape[1], 1, C)
        t = t.permute(0, 2, 1, 3).reshape(B * 1, t.shape[1], C)
        return t.contiguous()

    q, k, v = to_tokens(q), to_tokens(k), to_tokens(v)

    out = xformers.ops.memory_efficient_attention(
        q, k, v, attn_bias=None, op=self.attention_op)

    n_tokens = out.shape[1]
    out = out.unsqueeze(0).reshape(B, 1, n_tokens, C)
    out = out.permute(0, 2, 1, 3).reshape(B, n_tokens, C)

    out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
    return self.proj_out(out)
347
+
348
+
349
def attn2task(task_queue, net):
    """
    Turn an attention block into a sequence of tasks and append to the task queue

    The tasks are: store the input as the residual, run the block's norm,
    run the attention itself, then add the residual back.

    @param task_queue: the target task queue
    @param net: the attention block; ``net.norm`` is queued as the pre-norm
    """
    task_queue.append(('store_res', lambda x: x))
    task_queue.append(('pre_norm', net.norm))
    # is_xformers_available is a function and must be called. Testing the
    # bare name is always truthy, which made the fallbacks below unreachable.
    if is_xformers_available():
        # task_queue.append(('attn', lambda x, net=net: attn_forward_new_xformers(net, x)))
        task_queue.append(
            ('attn', lambda x, net=net: xformer_attn_forward(net, x)))
    elif hasattr(F, "scaled_dot_product_attention"):
        task_queue.append(('attn', lambda x, net=net: attn_forward_new_pt2_0(net, x)))
    else:
        task_queue.append(('attn', lambda x, net=net: attn_forward_new(net, x)))
    # A list, not a tuple: the residual slot is filled in later.
    task_queue.append(['add_res', None])
373
+
374
def resblock2task(queue, block):
    """
    Turn a ResNetBlock into a sequence of tasks and append to the task queue

    @param queue: the target task queue
    @param block: ResNetBlock

    """
    if block.in_channels == block.out_channels:
        # same width in and out: the residual is the input itself
        shortcut = lambda x: x
    else:
        # The two flavours (selected by sd_flag) name the flag differently.
        use_conv = block.use_conv_shortcut if sd_flag else block.use_in_shortcut
        shortcut = block.conv_shortcut if use_conv else block.nin_shortcut

    body_tasks = (
        ('store_res', shortcut),
        ('pre_norm', block.norm1),
        ('silu', inplace_nonlinearity),
        ('conv1', block.conv1),
        ('pre_norm', block.norm2),
        ('silu', inplace_nonlinearity),
        ('conv2', block.conv2),
    )
    for task in body_tasks:
        queue.append(task)
    # A list, not a tuple: the residual slot is filled in later.
    queue.append(['add_res', None])
403
+
404
+
405
def _append_mid_tasks(task_queue, net):
    """Append the middle stage (resblock, attention, resblock) of ``net``."""
    if sd_flag:
        resblock2task(task_queue, net.mid.block_1)
        attn2task(task_queue, net.mid.attn_1)
        resblock2task(task_queue, net.mid.block_2)
    else:
        resblock2task(task_queue, net.mid_block.resnets[0])
        attn2task(task_queue, net.mid_block.attentions[0])
        resblock2task(task_queue, net.mid_block.resnets[1])


def build_sampling(task_queue, net, is_decoder):
    """
    Build the sampling part of a task queue

    The decoder runs its middle stage first and then the up levels; the
    encoder runs the down levels first and then its middle stage. The
    attribute names used depend on the module-level ``sd_flag``.

    @param task_queue: the target task queue
    @param net: the network
    @param is_decoder: currently building decoder or encoder
    """
    if is_decoder:
        _append_mid_tasks(task_queue, net)
        if sd_flag:
            resolution_iter = reversed(range(net.num_resolutions))
            block_ids = net.num_res_blocks + 1
            condition = 0
            module = net.up
            func_name = 'upsample'
        else:
            resolution_iter = (range(len(net.up_blocks)))  # net.num_resolutions = 3
            block_ids = 2 + 1
            condition = len(net.up_blocks) - 1
            module = net.up_blocks
            func_name = 'upsamplers'
    else:
        if sd_flag:
            resolution_iter = range(net.num_resolutions)
            block_ids = net.num_res_blocks
            condition = net.num_resolutions - 1
            module = net.down
            func_name = 'downsample'
        else:
            resolution_iter = range(len(net.down_blocks))
            block_ids = 2
            condition = len(net.down_blocks) - 1
            module = net.down_blocks
            func_name = 'downsamplers'

    for i_level in resolution_iter:
        level = module[i_level]
        for i_block in range(block_ids):
            if sd_flag:
                resblock2task(task_queue, level.block[i_block])
            else:
                resblock2task(task_queue, level.resnets[i_block])
        # ``condition`` is the one level that gets no resampling layer.
        if i_level != condition:
            if sd_flag:
                task_queue.append((func_name, getattr(level, func_name)))
            elif is_decoder:
                task_queue.append((func_name, level.upsamplers[0]))
            else:
                task_queue.append((func_name, level.downsamplers[0]))

    if not is_decoder:
        _append_mid_tasks(task_queue, net)
470
+
471
+
472
def build_task_queue(net, is_decoder):
    """
    Build a single task queue for the encoder or decoder

    Note: when building a decoder queue with ``sd_flag`` off, ``net`` gets
    ``give_pre_end`` and ``tanh_out`` set to False as a side effect.

    @param net: the VAE decoder or encoder network
    @param is_decoder: currently building decoder or encoder
    @return: the task queue
    """
    task_queue = [('conv_in', net.conv_in)]

    # construct the sampling part of the task queue
    # because encoder and decoder share the same architecture, we extract the sampling part
    build_sampling(task_queue, net, is_decoder)

    if is_decoder and not sd_flag:
        net.give_pre_end = False
        net.tanh_out = False

    # A decoder asked for its pre-end output gets no output head.
    if is_decoder and net.give_pre_end:
        return task_queue

    output_norm = net.norm_out if sd_flag else net.conv_norm_out
    task_queue.append(('pre_norm', output_norm))
    task_queue.append(('silu', inplace_nonlinearity))
    task_queue.append(('conv_out', net.conv_out))
    if is_decoder and net.tanh_out:
        task_queue.append(('tanh', torch.tanh))

    return task_queue
500
+
501
+
502
def clone_task_queue(task_queue):
    """
    Clone a task queue

    Every task becomes a new list, so writing into a slot of the clone does
    not touch the source queue; the items inside each task are shared.

    @param task_queue: the task queue to be cloned
    @return: the cloned task queue
    """
    return [list(task) for task in task_queue]
509
+
510
+
511
def get_var_mean(input, num_groups, eps=1e-6):
    """
    Get mean and var for group norm

    The tensor is viewed as (1, batch*num_groups, channels_per_group, ...)
    and the biased variance and the mean are reduced over every axis except
    the batch*num_groups one. ``eps`` is accepted but not used.

    @return: (var, mean), one entry per (sample, group) pair
    """
    batch, channels = input.size(0), input.size(1)
    channels_per_group = int(channels/num_groups)
    spatial_dims = input.size()[2:]
    grouped = input.contiguous().view(
        1, int(batch * num_groups), channels_per_group, *spatial_dims)
    return torch.var_mean(grouped, dim=[0, 2, 3, 4], unbiased=False)
522
+
523
+
524
def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6):
    """
    Custom group norm with fixed mean and var

    The input is regrouped exactly as in get_var_mean and normalized with
    F.batch_norm in inference mode, so the supplied statistics are used
    as-is; the affine transform is applied afterwards.

    @param input: input tensor
    @param num_groups: number of groups. by default, num_groups = 32
    @param mean: mean, must be pre-calculated by get_var_mean
    @param var: var, must be pre-calculated by get_var_mean
    @param weight: weight, should be fetched from the original group norm
    @param bias: bias, should be fetched from the original group norm
    @param eps: epsilon, by default, eps = 1e-6 to match the original group norm

    @return: normalized tensor
    """
    batch, channels = input.size(0), input.size(1)
    spatial_dims = input.size()[2:]
    channels_per_group = int(channels/num_groups)
    grouped = input.contiguous().view(
        1, int(batch * num_groups), channels_per_group, *spatial_dims)

    normalized = F.batch_norm(grouped, mean, var, weight=None, bias=None,
                              training=False, momentum=0, eps=eps)
    normalized = normalized.view(batch, channels, *spatial_dims)

    # post affine transform (per channel, in place)
    if weight is not None:
        normalized *= weight.view(1, -1, 1, 1)
    if bias is not None:
        normalized += bias.view(1, -1, 1, 1)
    return normalized
554
+
555
+
556
def crop_valid_region(x, input_bbox, target_bbox, is_decoder):
    """
    Crop the valid region from the tile

    The input bbox is first brought to output resolution (times 8 for the
    decoder, floor-divided by 8 for the encoder); the difference to the
    target bbox gives the margins to cut off.

    @param x: input tile, indexed as (batch, channel, height, width)
    @param input_bbox: original input bounding box [x1, x2, y1, y2]
    @param target_bbox: output bounding box [x1, x2, y1, y2]
    @param is_decoder: whether the tile comes from the decoder
    @return: cropped tile
    """
    if is_decoder:
        scaled_bbox = [coord * 8 for coord in input_bbox]
    else:
        scaled_bbox = [coord // 8 for coord in input_bbox]

    left, right, top, bottom = (target_bbox[i] - scaled_bbox[i] for i in range(4))
    # right/bottom are offsets relative to the tile's far edges
    return x[:, :, top:x.size(2) + bottom, left:x.size(3) + right]
568
+
569
+ # ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓
570
+
571
+
572
def perfcount(fn):
    """Decorator that prints how long each call to ``fn`` took.

    When CUDA is available the peak allocated memory on ``devices.device``
    during the call is printed as well. Garbage collection
    (``devices.torch_gc()`` and ``gc.collect()``) runs before and after the
    call. The wrapped function's return value is passed through unchanged.
    """
    import functools  # local import so the block does not rely on a file-level one

    @functools.wraps(fn)  # keep fn's name and docstring on the wrapper
    def wrapper(*args, **kwargs):
        ts = time()

        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(devices.device)
        devices.torch_gc()
        gc.collect()

        ret = fn(*args, **kwargs)

        devices.torch_gc()
        gc.collect()
        if torch.cuda.is_available():
            vram = torch.cuda.max_memory_allocated(devices.device) / 2**20
            torch.cuda.reset_peak_memory_stats(devices.device)
            print(
                f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB')
        else:
            print(f'[Tiled VAE]: Done in {time() - ts:.3f}s')

        return ret
    return wrapper
595
+
596
+ # copy end :)
597
+
598
+
599
class GroupNormParam:
    """Collects group-norm statistics tile by tile (always with 32 groups)."""

    def __init__(self):
        # one entry per tile added through add_tile
        self.var_list, self.mean_list, self.pixel_list = [], [], []
        # affine parameters of the most recently seen layer
        self.weight = None
        self.bias = None

    def add_tile(self, tile, layer):
        """Record var/mean/pixel count of ``tile`` and the affine params of ``layer``."""
        var, mean = get_var_mean(tile, 32)
        # For giant images, the variance can be larger than max float16
        # In this case we create a copy to float32
        if var.dtype == torch.float16 and var.isinf().any():
            var, mean = get_var_mean(tile.float(), 32)
        # ============= DEBUG: test for infinite =============
        # if torch.isinf(var).any():
        #    print('var: ', var)
        # ====================================================
        self.var_list.append(var)
        self.mean_list.append(mean)
        self.pixel_list.append(tile.shape[2] * tile.shape[3])

        has_affine = hasattr(layer, 'weight')
        self.weight = layer.weight if has_affine else None
        self.bias = layer.bias if has_affine else None

    def summary(self):
        """
        summarize the mean and var and return a function
        that apply group norm on each tile

        The per-tile statistics are averaged with weights proportional to
        each tile's pixel count. Returns None when no tile was added.
        """
        if not self.var_list:
            return None

        stacked_var = torch.vstack(self.var_list)
        stacked_mean = torch.vstack(self.mean_list)

        largest = max(self.pixel_list)
        pixels = torch.tensor(
            self.pixel_list, dtype=torch.float32, device=devices.device) / largest
        # normalize so the weights sum to one, as a column for broadcasting
        weights = pixels.unsqueeze(1) / torch.sum(pixels)

        var = torch.sum(stacked_var * weights, dim=0)
        mean = torch.sum(stacked_mean * weights, dim=0)

        def apply_norm(x):
            # weight/bias are read from self at call time
            return custom_group_norm(x, 32, mean, var, self.weight, self.bias)
        return apply_norm

    @staticmethod
    def from_tile(tile, norm):
        """
        create a function from a single tile without summary
        """
        var, mean = get_var_mean(tile, 32)
        if var.dtype == torch.float16 and var.isinf().any():
            var, mean = get_var_mean(tile.float(), 32)
            # if it is a macbook, we need to convert back to float16
            if var.device.type == 'mps':
                # clamp to avoid overflow
                var = torch.clamp(var, 0, 60000).half()
                mean = mean.half()

        has_affine = hasattr(norm, 'weight')
        weight = norm.weight if has_affine else None
        bias = norm.bias if has_affine else None

        def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias):
            return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6)
        return group_norm_func
675
+
676
+
677
+ class VAEHook:
678
+ def __init__(self, net, tile_size, is_decoder, fast_decoder, fast_encoder, color_fix, to_gpu=False):
679
+ self.net = net # encoder | decoder
680
+ self.tile_size = tile_size
681
+ self.is_decoder = is_decoder
682
+ self.fast_mode = (fast_encoder and not is_decoder) or (
683
+ fast_decoder and is_decoder)
684
+ self.color_fix = color_fix and not is_decoder
685
+ self.to_gpu = to_gpu
686
+ self.pad = 11 if is_decoder else 32
687
+
688
+ def __call__(self, x):
689
+ B, C, H, W = x.shape
690
+ original_device = next(self.net.parameters()).device
691
+ try:
692
+ if self.to_gpu:
693
+ self.net.to(devices.get_optimal_device())
694
+ if max(H, W) <= self.pad * 2 + self.tile_size:
695
+ print("[Tiled VAE]: the input size is tiny and unnecessary to tile.")
696
+ return self.net.original_forward(x)
697
+ else:
698
+ return self.vae_tile_forward(x)
699
+ finally:
700
+ self.net.to(original_device)
701
+
702
def get_best_tile_size(self, lowerbound, upperbound):
    """
    Get the best tile size for GPU memory

    Tries the dividers 32, 16, 8, 4, 2 in that order. If ``lowerbound`` is
    already a multiple of the divider it is returned as is; otherwise it is
    rounded up to the next multiple, which is returned when it does not
    exceed ``upperbound``. ``lowerbound`` is the fallback.
    """
    for divider in (32, 16, 8, 4, 2):
        remainder = lowerbound % divider
        if remainder == 0:
            return lowerbound
        rounded_up = lowerbound - remainder + divider
        if rounded_up <= upperbound:
            return rounded_up
    return lowerbound
716
+
717
def split_tiles(self, h, w):
    """
    Work out how an h x w input is cut into tiles.

    Reads self.tile_size, self.pad and self.is_decoder. For every tile two
    boxes are produced, both laid out as [x1, x2, y1, y2]:
      * the input box: the tile grown by `pad` on each side and clipped to
        the image, in input coordinates;
      * the output box: the region of the final result the tile fills,
        multiplied by 8 when self.is_decoder is true and floor-divided by 8
        otherwise.
    A summary line is printed to stdout.
    @param h: height of the image
    @param w: width of the image
    @return: tile_input_bboxes, tile_output_bboxes
    """
    tile_size = self.tile_size
    pad = self.pad
    inner_h = h - 2 * pad
    inner_w = w - 2 * pad

    # Keep at least one tile per axis so long and thin images still work.
    num_height_tiles = max(math.ceil(inner_h / tile_size), 1)
    num_width_tiles = max(math.ceil(inner_w / tile_size), 1)

    # Suggestion from https://github.com/Kahsolt: shrink the tile size
    # automatically so the tiles divide the image evenly.
    real_tile_height = self.get_best_tile_size(math.ceil(inner_h / num_height_tiles), tile_size)
    real_tile_width = self.get_best_tile_size(math.ceil(inner_w / num_width_tiles), tile_size)

    print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' +
          f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}')

    def to_output_scale(coord):
        # Map an input coordinate to the output resolution.
        return coord * 8 if self.is_decoder else coord // 8

    tile_input_bboxes, tile_output_bboxes = [], []
    for row in range(num_height_tiles):
        # Padding is unnecessary at the image borders, so the un-padded
        # tiles start directly at (pad, pad).
        y1 = pad + row * real_tile_height
        y2 = min(pad + (row + 1) * real_tile_height, h)
        for col in range(num_width_tiles):
            x1 = pad + col * real_tile_width
            x2 = min(pad + (col + 1) * real_tile_width, w)

            # An output edge that lies within `pad` of the image boundary
            # is extended all the way to that boundary.
            out_x1 = x1 if x1 > pad else 0
            out_x2 = x2 if x2 < w - pad else w
            out_y1 = y1 if y1 > pad else 0
            out_y2 = y2 if y2 < h - pad else h
            tile_output_bboxes.append(
                [to_output_scale(edge) for edge in (out_x1, out_x2, out_y1, out_y2)]
            )

            # Grow the input box by `pad` on every side, clipped to the image.
            tile_input_bboxes.append([
                max(0, x1 - pad),
                min(w, x2 + pad),
                max(0, y1 - pad),
                min(h, y2 + pad),
            ])

    return tile_input_bboxes, tile_output_bboxes
775
+
776
@torch.no_grad()
def estimate_group_norm(self, z, task_queue, color_fix):
    """
    Run `z` through the task queue up to its last 'pre_norm' task and
    replace each 'pre_norm' on the way with a ready-made 'apply_norm' task.

    The queue is modified in place:
      * every visited 'pre_norm' entry becomes ('apply_norm', func), where
        func is built by GroupNormParam.from_tile from the current tensor;
      * 'store_res' results are written into the matching 'add_res' entry
        (which must therefore be a mutable sequence) and cleared again once
        they have been added;
      * with color_fix enabled, processing stops at the first 'downsample'
        task and every 'store_res' from there up to the last 'pre_norm' is
        renamed to 'store_res_cpu'.
    @param z: tensor the estimation is run on
    @param task_queue: list of (name, payload) tasks, mutated in place
    @param color_fix: if true, stop early at the first 'downsample' task
    @return: True when the queue was prepared, False when the NaN check
             failed on an intermediate result
    @raise ValueError: if no 'pre_norm' task exists at an index above 0
    """
    device = z.device
    tile = z

    # Locate the last 'pre_norm' task; nothing after it needs estimating.
    last_id = len(task_queue) - 1
    while last_id >= 0 and task_queue[last_id][0] != 'pre_norm':
        last_id -= 1
    if last_id <= 0 or task_queue[last_id][0] != 'pre_norm':
        raise ValueError('No group norm found in the task queue')

    # estimate until the last group norm
    for i in range(last_id + 1):
        task = task_queue[i]
        if task[0] == 'pre_norm':
            group_norm_func = GroupNormParam.from_tile(tile, task[1])
            task_queue[i] = ('apply_norm', group_norm_func)
            if i == last_id:
                return True
            tile = group_norm_func(tile)
        elif task[0] == 'store_res':
            # Find the 'add_res' that consumes this residual. If it lies at
            # or beyond the last group norm it is not needed here, and the
            # NaN check below is skipped for this step as well.
            task_id = i + 1
            while task_id < last_id and task_queue[task_id][0] != 'add_res':
                task_id += 1
            if task_id >= last_id:
                continue
            task_queue[task_id][1] = task[1](tile)
        elif task[0] == 'add_res':
            tile += task[1].to(device)
            # Drop the stored residual once it has been consumed.
            task[1] = None
        elif color_fix and task[0] == 'downsample':
            for j in range(i, last_id + 1):
                if task_queue[j][0] == 'store_res':
                    task_queue[j] = ('store_res_cpu', task_queue[j][1])
            return True
        else:
            tile = task[1](tile)

        try:
            devices.test_for_nans(tile, "vae")
        except Exception:
            # Catch only real errors: a bare `except:` would also swallow
            # KeyboardInterrupt / SystemExit and misreport them as NaNs.
            print('Nan detected in fast mode estimation. Fast mode disabled.')
            return False

    raise IndexError('Should not reach here')
818
+
819
@perfcount
@torch.no_grad()
def vae_tile_forward(self, z):
    """
    Run the wrapped network over z tile by tile and stitch the tile outputs
    into one result tensor.

    When self.is_decoder is true the result is 8x larger than z in both
    spatial dimensions, otherwise it is 8x smaller (floor division).
    Progress is reported through print() and a tqdm progress bar.
    @param z: input tensor, indexed as (N, C, height, width)
    @return: the assembled output, cast to the dtype of z
    """
    device = next(self.net.parameters()).device
    dtype = z.dtype
    net = self.net
    tile_size = self.tile_size
    is_decoder = self.is_decoder

    z = z.detach() # detach the input to avoid backprop

    N, height, width = z.shape[0], z.shape[2], z.shape[3]
    net.last_z_shape = z.shape

    # Split the input into tiles and build a task queue for each tile
    print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}')

    in_bboxes, out_bboxes = self.split_tiles(height, width)

    # Prepare tiles by splitting the input latents; bboxes are [x1, x2, y1, y2].
    # Each tile is moved to the CPU until it is its turn to be processed.
    tiles = []
    for input_bbox in in_bboxes:
        tile = z[:, :, input_bbox[2]:input_bbox[3], input_bbox[0]:input_bbox[1]].cpu()
        tiles.append(tile)

    num_tiles = len(tiles)
    num_completed = 0

    # Build task queues
    single_task_queue = build_task_queue(net, is_decoder)
    #print(single_task_queue)
    if self.fast_mode:
        # Fast mode: downsample the input image to the tile size,
        # then estimate the group norm parameters on the downsampled image
        scale_factor = tile_size / max(height, width)
        z = z.to(device)
        downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact')
        # use nearest-exact to keep statistics as close as possible
        print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image')

        # ======= Special thanks to @Kahsolt for distribution shift issue ======= #
        # The downsampling will heavily distort its mean and std, so we need to recover it.
        std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True)
        std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True)
        downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old
        del std_old, mean_old, std_new, mean_new
        # occasionally the std_new is too small or too large, which exceeds the range of float16
        # so we need to clamp it to max z's range.
        downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max())
        # Estimate on a copy: the estimated queue is adopted only if the
        # estimation reports success, otherwise the untouched queue is kept.
        estimate_task_queue = clone_task_queue(single_task_queue)
        if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix):
            single_task_queue = estimate_task_queue
        del downsampled_z

    # Every tile gets its own copy of the queue because queues are consumed
    # (popped) and mutated while the tile is processed.
    task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)]

    # Dummy result
    result = None
    # NOTE(review): result_approx is never assigned anything but None in this
    # version (the approximation code below is commented out).
    result_approx = None
    #try:
    #    with devices.autocast():
    #        result_approx = torch.cat([F.interpolate(cheap_approximation(x).unsqueeze(0), scale_factor=opt_f, mode='nearest-exact') for x in z], dim=0).cpu()
    #except: pass
    # Free memory of input latent tensor
    del z

    # Task queue execution
    pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: ")

    # execute the task back and forth when switch tiles so that we always
    # keep one tile on the GPU to reduce unnecessary data transfer
    forward = True
    # `interrupted` is never set to True here: all interrupt checks are
    # commented out, so the `if interrupted: break` lines are inert.
    interrupted = False
    #state.interrupted = interrupted
    while True:
        #if state.interrupted: interrupted = True ; break

        # Each pass of this loop advances every tile up to its next
        # 'pre_norm' task (or to the end of its queue). The tiles visited in
        # the pass are registered with this collector.
        # presumably it aggregates group-norm statistics across tiles -- confirm in GroupNormParam
        group_norm_param = GroupNormParam()
        for i in range(num_tiles) if forward else reversed(range(num_tiles)):
            #if state.interrupted: interrupted = True ; break

            tile = tiles[i].to(device)
            input_bbox = in_bboxes[i]
            task_queue = task_queues[i]

            interrupted = False
            while len(task_queue) > 0:
                #if state.interrupted: interrupted = True ; break

                # DEBUG: current task
                # print('Running task: ', task_queue[0][0], ' on tile ', i, '/', num_tiles, ' with shape ', tile.shape)
                task = task_queue.pop(0)
                if task[0] == 'pre_norm':
                    # Register this tile with the collector and pause it; it
                    # resumes once 'apply_norm' is pushed onto its queue.
                    # This path skips pbar.update.
                    group_norm_param.add_tile(tile, task[1])
                    break
                elif task[0] == 'store_res' or task[0] == 'store_res_cpu':
                    task_id = 0
                    res = task[1](tile)
                    # Residuals stay on the compute device only in fast mode,
                    # and only for plain 'store_res' tasks.
                    if not self.fast_mode or task[0] == 'store_res_cpu':
                        res = res.cpu()
                    # Hand the residual to the next 'add_res' task in the queue.
                    while task_queue[task_id][0] != 'add_res':
                        task_id += 1
                    task_queue[task_id][1] = res
                elif task[0] == 'add_res':
                    tile += task[1].to(device)
                    # Release the residual once it has been added.
                    task[1] = None
                else:
                    # Any other task: apply its callable to the tile.
                    tile = task[1](tile)
                    #print(tiles[i].shape, tile.shape, task)
                pbar.update(1)

            if interrupted: break

            # check for NaNs in the tile.
            # If there are NaNs, we abort the process to save user's time
            #devices.test_for_nans(tile, "vae")

            #print(tiles[i].shape, tile.shape, i, num_tiles)
            if len(task_queue) == 0:
                # Tile finished: write its valid region into the result.
                tiles[i] = None
                num_completed += 1
                if result is None: # NOTE: dim C varies from different cases, can only be inited dynamically
                    result = torch.zeros((N, tile.shape[1], height * 8 if is_decoder else height // 8, width * 8 if is_decoder else width // 8), device=device, requires_grad=False)
                result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder)
                del tile
            elif i == num_tiles - 1 and forward:
                # End of a forward sweep: reverse direction and keep this
                # tile where it is, since it is processed first next pass.
                forward = False
                tiles[i] = tile
            elif i == 0 and not forward:
                # End of a backward sweep: same idea in the other direction.
                forward = True
                tiles[i] = tile
            else:
                # Park the paused tile on the CPU until the next pass.
                tiles[i] = tile.cpu()
                del tile

        if interrupted: break
        if num_completed == num_tiles: break

        # insert the group norm task to the head of each task queue
        group_norm_func = group_norm_param.summary()
        if group_norm_func is not None:
            for i in range(num_tiles):
                task_queue = task_queues[i]
                task_queue.insert(0, ('apply_norm', group_norm_func))

    # Done!
    pbar.close()
    # NOTE(review): if result were ever None here, result_approx is also None
    # and the fallback would raise AttributeError.
    return result.to(dtype) if result is not None else result_approx.to(device)