mordecaaiiaart committed on
Commit
deea016
1 Parent(s): 8d0b84d

Upload tilevae.py

Files changed (1)
  1. tilevae.py +753 -0
tilevae.py ADDED
@@ -0,0 +1,753 @@
'''
# ------------------------------------------------------------------------
#
#   Tiled VAE
#
#   Introducing a revolutionary new optimization designed to make
#   the VAE work with giant images on limited VRAM!
#   Say goodbye to the frustration of OOM and hello to seamless output!
#
# ------------------------------------------------------------------------
#
#   This script is a wild hack that splits the image into tiles,
#   encodes each tile separately, and merges the result back together.
#
#   Advantages:
#   - The VAE can now work with giant images on limited VRAM
#       (~10 GB for 8K images!)
#   - The merged output is completely seamless without any post-processing.
#
#   Drawbacks:
#   - NaNs always appear for 8k images when you use the fp16 (half) VAE.
#       You must use --no-half-vae to disable the half VAE for such giant images.
#   - Gradient calculation is not compatible with this hack. It
#       will break any backward() or torch.autograd.grad() that passes through the VAE.
#       (But you can still use the VAE to generate training data.)
#
#   How it works:
#   1. The image is split into tiles, which are then padded with 11/32 pixels in the decoder/encoder.
#   2. When Fast Mode is disabled:
#       1. The original VAE forward is decomposed into a task queue and a task worker, which starts to process each tile.
#       2. When GroupNorm is needed, it suspends, stores the current GroupNorm mean and var, sends everything to RAM, and moves on to the next tile.
#       3. After all GroupNorm means and vars are summarized, it applies group norm to the tiles and continues.
#       4. A zigzag execution order is used to reduce unnecessary data transfer.
#   3. When Fast Mode is enabled:
#       1. The original input is downsampled and passed to a separate task queue.
#       2. Its group norm parameters are recorded and used by all tiles' task queues.
#       3. Each tile is processed separately without any RAM-VRAM data transfer.
#   4. After all tiles are processed, tiles are written to a result buffer and returned.
#   Encoder color fix = only estimate GroupNorm before downsampling, i.e., run in a semi-fast mode.
#
#   Enjoy!
#
#   @Author: LI YI @ Nanyang Technological University - Singapore
#   @Date: 2023-03-02
#   @License: CC BY-NC-SA 4.0
#
#   Please give me a star if you like this project!
#
# -------------------------------------------------------------------------
'''

import gc
import math
from time import time
from tqdm import tqdm

import torch
import torch.version
import torch.nn.functional as F
import gradio as gr

import modules.scripts as scripts
import modules.devices as devices
from modules.shared import state
from modules.ui import gr_show
from modules.processing import opt_f
from modules.sd_vae_approx import cheap_approximation
from ldm.modules.diffusionmodules.model import AttnBlock, MemoryEfficientAttnBlock

from tile_utils.attn import get_attn_func
from tile_utils.typing import Processing


def get_rcmd_enc_tsize():
    if torch.cuda.is_available() and devices.device not in ['cpu', devices.cpu]:
        total_memory = torch.cuda.get_device_properties(devices.device).total_memory // 2**20
        if total_memory > 16*1000:   ENCODER_TILE_SIZE = 3072
        elif total_memory > 12*1000: ENCODER_TILE_SIZE = 2048
        elif total_memory > 8*1000:  ENCODER_TILE_SIZE = 1536
        else:                        ENCODER_TILE_SIZE = 960
    else:                            ENCODER_TILE_SIZE = 512
    return ENCODER_TILE_SIZE


def get_rcmd_dec_tsize():
    if torch.cuda.is_available() and devices.device not in ['cpu', devices.cpu]:
        total_memory = torch.cuda.get_device_properties(devices.device).total_memory // 2**20
        if total_memory > 30*1000:   DECODER_TILE_SIZE = 256
        elif total_memory > 16*1000: DECODER_TILE_SIZE = 192
        elif total_memory > 12*1000: DECODER_TILE_SIZE = 128
        elif total_memory > 8*1000:  DECODER_TILE_SIZE = 96
        else:                        DECODER_TILE_SIZE = 64
    else:                            DECODER_TILE_SIZE = 64
    return DECODER_TILE_SIZE

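# Note (added for clarity): the recommended encoder tile size above is measured in image
# pixels, while the decoder tile size is measured in latent pixels (1 latent pixel = 8 image
# pixels), which is why the recommended decoder values are much smaller.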

def inplace_nonlinearity(x):
    # Test: fix for NaNs
    return F.silu(x, inplace=True)


def attn2task(task_queue, net):
    attn_forward = get_attn_func()
    task_queue.append(('store_res', lambda x: x))
    task_queue.append(('pre_norm', net.norm))
    task_queue.append(('attn', lambda x, net=net: attn_forward(net, x)))
    task_queue.append(['add_res', None])


def resblock2task(queue, block):
    """
    Turn a ResNetBlock into a sequence of tasks and append to the task queue

    @param queue: the target task queue
    @param block: ResNetBlock

    """
    if block.in_channels != block.out_channels:
        if block.use_conv_shortcut:
            queue.append(('store_res', block.conv_shortcut))
        else:
            queue.append(('store_res', block.nin_shortcut))
    else:
        queue.append(('store_res', lambda x: x))
    queue.append(('pre_norm', block.norm1))
    queue.append(('silu', inplace_nonlinearity))
    queue.append(('conv1', block.conv1))
    queue.append(('pre_norm', block.norm2))
    queue.append(('silu', inplace_nonlinearity))
    queue.append(('conv2', block.conv2))
    queue.append(['add_res', None])


def build_sampling(task_queue, net, is_decoder):
    """
    Build the sampling part of a task queue
    @param task_queue: the target task queue
    @param net: the network
    @param is_decoder: whether we are building the decoder (True) or the encoder (False)
    """
    if is_decoder:
        resblock2task(task_queue, net.mid.block_1)
        attn2task(task_queue, net.mid.attn_1)
        resblock2task(task_queue, net.mid.block_2)
        resolution_iter = reversed(range(net.num_resolutions))
        block_ids = net.num_res_blocks + 1
        condition = 0
        module = net.up
        func_name = 'upsample'
    else:
        resolution_iter = range(net.num_resolutions)
        block_ids = net.num_res_blocks
        condition = net.num_resolutions - 1
        module = net.down
        func_name = 'downsample'

    for i_level in resolution_iter:
        for i_block in range(block_ids):
            resblock2task(task_queue, module[i_level].block[i_block])
        if i_level != condition:
            task_queue.append((func_name, getattr(module[i_level], func_name)))

    if not is_decoder:
        resblock2task(task_queue, net.mid.block_1)
        attn2task(task_queue, net.mid.attn_1)
        resblock2task(task_queue, net.mid.block_2)


def build_task_queue(net, is_decoder):
    """
    Build a single task queue for the encoder or decoder
    @param net: the VAE decoder or encoder network
    @param is_decoder: whether we are building the decoder (True) or the encoder (False)
    @return: the task queue
    """
    task_queue = []
    task_queue.append(('conv_in', net.conv_in))

    # construct the sampling part of the task queue
    # because the encoder and decoder share the same architecture, we extract the sampling part
    build_sampling(task_queue, net, is_decoder)

    if not is_decoder or not net.give_pre_end:
        task_queue.append(('pre_norm', net.norm_out))
        task_queue.append(('silu', inplace_nonlinearity))
        task_queue.append(('conv_out', net.conv_out))
        if is_decoder and net.tanh_out:
            task_queue.append(('tanh', torch.tanh))

    return task_queue

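# For reference (added for clarity): a task queue built by build_task_queue() above is just a
# flat list of [task_name, payload] pairs that the tile worker executes in order, e.g.
# (hypothetical excerpt):
#   ('conv_in', Conv2d), ('store_res', <shortcut>), ('pre_norm', GroupNorm),
#   ('silu', inplace_nonlinearity), ('conv1', Conv2d), ..., ['add_res', None]
# 'add_res' entries are mutable lists so the stored residual can be filled in later.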

def clone_task_queue(task_queue):
    """
    Clone a task queue
    @param task_queue: the task queue to be cloned
    @return: the cloned task queue
    """
    return [[item for item in task] for task in task_queue]


def get_var_mean(input, num_groups, eps=1e-6):
    """
    Get mean and var for group norm
    """
    b, c = input.size(0), input.size(1)
    channel_in_group = int(c/num_groups)
    input_reshaped = input.contiguous().view(1, int(b * num_groups), channel_in_group, *input.size()[2:])
    var, mean = torch.var_mean(input_reshaped, dim=[0, 2, 3, 4], unbiased=False)
    return var, mean


def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6):
    """
    Custom group norm with fixed mean and var

    @param input: input tensor
    @param num_groups: number of groups. by default, num_groups = 32
    @param mean: mean, must be pre-calculated by get_var_mean
    @param var: var, must be pre-calculated by get_var_mean
    @param weight: weight, should be fetched from the original group norm
    @param bias: bias, should be fetched from the original group norm
    @param eps: epsilon, by default, eps = 1e-6 to match the original group norm

    @return: normalized tensor
    """
    b, c = input.size(0), input.size(1)
    channel_in_group = int(c/num_groups)
    input_reshaped = input.contiguous().view(
        1, int(b * num_groups), channel_in_group, *input.size()[2:])

    out = F.batch_norm(input_reshaped, mean, var, weight=None, bias=None, training=False, momentum=0, eps=eps)
    out = out.view(b, c, *input.size()[2:])

    # post affine transform
    if weight is not None:
        out *= weight.view(1, -1, 1, 1)
    if bias is not None:
        out += bias.view(1, -1, 1, 1)
    return out

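# Note (added for clarity): custom_group_norm() above is plain GroupNorm with externally
# supplied statistics:
#   y = (x - mean) / sqrt(var + eps) * weight + bias
# Reshaping to (1, B*num_groups, C//num_groups, H, W) lets F.batch_norm apply the
# per-(sample, group) mean/var in a single fused call.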

def crop_valid_region(x, input_bbox, target_bbox, is_decoder):
    """
    Crop the valid region from the tile
    @param x: input tile
    @param input_bbox: original input bounding box
    @param target_bbox: output bounding box
    @param is_decoder: whether the tile comes from the decoder (latent -> image) path
    @return: cropped tile
    """
    padded_bbox = [i * 8 if is_decoder else i//8 for i in input_bbox]
    margin = [target_bbox[i] - padded_bbox[i] for i in range(4)]
    return x[:, :, margin[2]:x.size(2)+margin[3], margin[0]:x.size(3)+margin[1]]


# ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓

def perfcount(fn):
    def wrapper(*args, **kwargs):
        ts = time()

        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(devices.device)
        devices.torch_gc()
        gc.collect()

        ret = fn(*args, **kwargs)

        devices.torch_gc()
        gc.collect()
        if torch.cuda.is_available():
            vram = torch.cuda.max_memory_allocated(devices.device) / 2**20
            print(f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB')
        else:
            print(f'[Tiled VAE]: Done in {time() - ts:.3f}s')

        return ret
    return wrapper

# ↑↑↑ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↑↑↑


class GroupNormParam:

    def __init__(self):
        self.var_list = []
        self.mean_list = []
        self.pixel_list = []
        self.weight = None
        self.bias = None

    def add_tile(self, tile, layer):
        var, mean = get_var_mean(tile, 32)
        # For giant images, the variance can be larger than max float16
        # In this case we create a copy in float32
        if var.dtype == torch.float16 and var.isinf().any():
            fp32_tile = tile.float()
            var, mean = get_var_mean(fp32_tile, 32)
        # ============= DEBUG: test for infinite =============
        # if torch.isinf(var).any():
        #     print('var: ', var)
        # ====================================================
        self.var_list.append(var)
        self.mean_list.append(mean)
        self.pixel_list.append(
            tile.shape[2]*tile.shape[3])
        if hasattr(layer, 'weight'):
            self.weight = layer.weight
            self.bias = layer.bias
        else:
            self.weight = None
            self.bias = None

    def summary(self):
        """
        summarize the mean and var and return a function
        that applies group norm on each tile
        """
        if len(self.var_list) == 0: return None

        var = torch.vstack(self.var_list)
        mean = torch.vstack(self.mean_list)
        max_value = max(self.pixel_list)
        pixels = torch.tensor(self.pixel_list, dtype=torch.float32, device=devices.device) / max_value
        sum_pixels = torch.sum(pixels)
        pixels = pixels.unsqueeze(1) / sum_pixels
        var = torch.sum(var * pixels, dim=0)
        mean = torch.sum(mean * pixels, dim=0)
        return lambda x: custom_group_norm(x, 32, mean, var, self.weight, self.bias)

    @staticmethod
    def from_tile(tile, norm):
        """
        create a function from a single tile without summary
        """
        var, mean = get_var_mean(tile, 32)
        if var.dtype == torch.float16 and var.isinf().any():
            fp32_tile = tile.float()
            var, mean = get_var_mean(fp32_tile, 32)
            # on Apple MPS we need to convert back to float16
            if var.device.type == 'mps':
                # clamp to avoid overflow
                var = torch.clamp(var, 0, 60000)
                var = var.half()
                mean = mean.half()
        if hasattr(norm, 'weight'):
            weight = norm.weight
            bias = norm.bias
        else:
            weight = None
            bias = None

        def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias):
            return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6)
        return group_norm_func


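# Note (added for clarity): GroupNormParam.summary() above merges the per-tile statistics by
# weighting each tile's mean/var with its pixel count, so the result approximates the GroupNorm
# statistics that would have been computed on the full, untiled activation.
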
class VAEHook:

    def __init__(self, net, tile_size, is_decoder: bool, fast_decoder: bool, fast_encoder: bool, color_fix: bool, to_gpu: bool = False):
        self.net = net                  # encoder | decoder
        self.tile_size = tile_size
        self.is_decoder = is_decoder
        self.fast_mode = (fast_encoder and not is_decoder) or (fast_decoder and is_decoder)
        self.color_fix = color_fix and not is_decoder
        self.to_gpu = to_gpu
        self.pad = 11 if is_decoder else 32   # FIXME: magic number

    def __call__(self, x):
        original_device = next(self.net.parameters()).device
        try:
            if self.to_gpu:
                self.net = self.net.to(devices.get_optimal_device())

            B, C, H, W = x.shape
            if max(H, W) <= self.pad * 2 + self.tile_size:
                print("[Tiled VAE]: the input size is tiny and does not need tiling.")
                return self.net.original_forward(x)
            else:
                return self.vae_tile_forward(x)
        finally:
            self.net = self.net.to(original_device)


    def get_best_tile_size(self, lowerbound, upperbound):
        """
        Get the best tile size for GPU memory
        """
        divider = 32
        while divider >= 2:
            remainder = lowerbound % divider
            if remainder == 0:
                return lowerbound
            candidate = lowerbound - remainder + divider
            if candidate <= upperbound:
                return candidate
            divider //= 2
        return lowerbound
    def split_tiles(self, h, w):
        """
        Tool function to split the image into tiles
        @param h: height of the image
        @param w: width of the image
        @return: tile_input_bboxes, tile_output_bboxes
        """
        tile_input_bboxes, tile_output_bboxes = [], []
        tile_size = self.tile_size
        pad = self.pad
        num_height_tiles = math.ceil((h - 2 * pad) / tile_size)
        num_width_tiles = math.ceil((w - 2 * pad) / tile_size)
        # If any of the numbers are 0, we let it be 1
        # This is to deal with long and thin images
        num_height_tiles = max(num_height_tiles, 1)
        num_width_tiles = max(num_width_tiles, 1)

        # Suggestion from https://github.com/Kahsolt: auto shrink the tile size
        real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles)
        real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles)
        real_tile_height = self.get_best_tile_size(real_tile_height, tile_size)
        real_tile_width = self.get_best_tile_size(real_tile_width, tile_size)

        print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' +
              f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}')

        for i in range(num_height_tiles):
            for j in range(num_width_tiles):
                # bbox: [x1, x2, y1, y2]
                # the padding is unnecessary for image borders, so we directly start from (pad, pad)
                input_bbox = [
                    pad + j * real_tile_width,
                    min(pad + (j + 1) * real_tile_width, w),
                    pad + i * real_tile_height,
                    min(pad + (i + 1) * real_tile_height, h),
                ]

                # if the output bbox is close to the image boundary, we extend it to the image boundary
                output_bbox = [
                    input_bbox[0] if input_bbox[0] > pad else 0,
                    input_bbox[1] if input_bbox[1] < w - pad else w,
                    input_bbox[2] if input_bbox[2] > pad else 0,
                    input_bbox[3] if input_bbox[3] < h - pad else h,
                ]

                # scale to get the final output bbox
                output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox]
                tile_output_bboxes.append(output_bbox)

                # expand the input bbox by pad pixels on each side
                tile_input_bboxes.append([
                    max(0, input_bbox[0] - pad),
                    min(w, input_bbox[1] + pad),
                    max(0, input_bbox[2] - pad),
                    min(h, input_bbox[3] + pad),
                ])

        return tile_input_bboxes, tile_output_bboxes

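    # Note (added for clarity): input bboxes overlap their neighbours by 2*pad pixels so that
    # convolutions see enough context, while the output bboxes tile the result plane exactly;
    # crop_valid_region() later cuts the padded margins off each processed tile.
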
    @torch.no_grad()
    def estimate_group_norm(self, z, task_queue, color_fix):
        device = z.device
        tile = z
        last_id = len(task_queue) - 1
        while last_id >= 0 and task_queue[last_id][0] != 'pre_norm':
            last_id -= 1
        if last_id <= 0 or task_queue[last_id][0] != 'pre_norm':
            raise ValueError('No group norm found in the task queue')
        # estimate until the last group norm
        for i in range(last_id + 1):
            task = task_queue[i]
            if task[0] == 'pre_norm':
                group_norm_func = GroupNormParam.from_tile(tile, task[1])
                task_queue[i] = ('apply_norm', group_norm_func)
                if i == last_id:
                    return True
                tile = group_norm_func(tile)
            elif task[0] == 'store_res':
                task_id = i + 1
                while task_id < last_id and task_queue[task_id][0] != 'add_res':
                    task_id += 1
                if task_id >= last_id:
                    continue
                task_queue[task_id][1] = task[1](tile)
            elif task[0] == 'add_res':
                tile += task[1].to(device)
                task[1] = None
            elif color_fix and task[0] == 'downsample':
                for j in range(i, last_id + 1):
                    if task_queue[j][0] == 'store_res':
                        task_queue[j] = ('store_res_cpu', task_queue[j][1])
                return True
            else:
                tile = task[1](tile)
                try:
                    devices.test_for_nans(tile, "vae")
                except:
                    print('NaN detected in fast mode estimation. Fast mode disabled.')
                    return False

        raise IndexError('Should not reach here')

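    # Note (added for clarity): with color_fix enabled, estimate_group_norm() above stops at the
    # first 'downsample' task; the remaining 'store_res' tasks are turned into 'store_res_cpu'
    # so their residuals are kept in RAM instead of VRAM. This is the "semi-fast" mode mentioned
    # in the header docstring.
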
    @perfcount
    @torch.no_grad()
    def vae_tile_forward(self, z):
        """
        Decode a latent vector z into an image in a tiled manner.
        @param z: latent vector
        @return: image
        """
        device = next(self.net.parameters()).device
        net = self.net
        tile_size = self.tile_size
        is_decoder = self.is_decoder

        z = z.detach()  # detach the input to avoid backprop

        N, height, width = z.shape[0], z.shape[2], z.shape[3]
        net.last_z_shape = z.shape

        # Split the input into tiles and build a task queue for each tile
        print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}')

        in_bboxes, out_bboxes = self.split_tiles(height, width)

        # Prepare tiles by splitting the input latents
        tiles = []
        for input_bbox in in_bboxes:
            tile = z[:, :, input_bbox[2]:input_bbox[3], input_bbox[0]:input_bbox[1]].cpu()
            tiles.append(tile)

        num_tiles = len(tiles)
        num_completed = 0

        # Build task queues
        single_task_queue = build_task_queue(net, is_decoder)
        if self.fast_mode:
            # Fast mode: downsample the input image to the tile size,
            # then estimate the group norm parameters on the downsampled image
            scale_factor = tile_size / max(height, width)
            z = z.to(device)
            downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact')
            # use nearest-exact to keep the statistics as close as possible
            print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image')

            # ======= Special thanks to @Kahsolt for the distribution shift issue ======= #
            # The downsampling will heavily distort its mean and std, so we need to recover it.
            std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True)
            std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True)
            downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old
            del std_old, mean_old, std_new, mean_new
            # occasionally std_new is too small or too large, which exceeds the range of float16,
            # so we need to clamp it to z's range.
            downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max())
            estimate_task_queue = clone_task_queue(single_task_queue)
            if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix):
                single_task_queue = estimate_task_queue
            del downsampled_z

        task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)]

        # Dummy result
        result = None
        result_approx = None
        try:
            with devices.autocast():
                result_approx = torch.cat([F.interpolate(cheap_approximation(x).unsqueeze(0), scale_factor=opt_f, mode='nearest-exact') for x in z], dim=0).cpu()
        except: pass
        # Free memory of the input latent tensor
        del z

        # Task queue execution
        pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: ")

        # execute the tasks back and forth when switching tiles so that we always
        # keep one tile on the GPU to reduce unnecessary data transfer
        forward = True
        interrupted = False
        #state.interrupted = interrupted
        while True:
            if state.interrupted: interrupted = True; break

            group_norm_param = GroupNormParam()
            for i in range(num_tiles) if forward else reversed(range(num_tiles)):
                if state.interrupted: interrupted = True; break

                tile = tiles[i].to(device)
                input_bbox = in_bboxes[i]
                task_queue = task_queues[i]

                interrupted = False
                while len(task_queue) > 0:
                    if state.interrupted: interrupted = True; break

                    # DEBUG: current task
                    # print('Running task: ', task_queue[0][0], ' on tile ', i, '/', num_tiles, ' with shape ', tile.shape)
                    task = task_queue.pop(0)
                    if task[0] == 'pre_norm':
                        group_norm_param.add_tile(tile, task[1])
                        break
                    elif task[0] == 'store_res' or task[0] == 'store_res_cpu':
                        task_id = 0
                        res = task[1](tile)
                        if not self.fast_mode or task[0] == 'store_res_cpu':
                            res = res.cpu()
                        while task_queue[task_id][0] != 'add_res':
                            task_id += 1
                        task_queue[task_id][1] = res
                    elif task[0] == 'add_res':
                        tile += task[1].to(device)
                        task[1] = None
                    else:
                        tile = task[1](tile)
                    pbar.update(1)

                if interrupted: break

                # check for NaNs in the tile.
                # If there are NaNs, we abort the process to save the user's time
                devices.test_for_nans(tile, "vae")

                if len(task_queue) == 0:
                    tiles[i] = None
                    num_completed += 1
                    if result is None:  # NOTE: dim C varies between cases, so it can only be initialized dynamically
                        result = torch.zeros((N, tile.shape[1], height * 8 if is_decoder else height // 8, width * 8 if is_decoder else width // 8), device=device, requires_grad=False)
                    result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder)
                    del tile
                elif i == num_tiles - 1 and forward:
                    forward = False
                    tiles[i] = tile
                elif i == 0 and not forward:
                    forward = True
                    tiles[i] = tile
                else:
                    tiles[i] = tile.cpu()
                    del tile

            if interrupted: break
            if num_completed == num_tiles: break

            # insert the group norm task at the head of each task queue
            group_norm_func = group_norm_param.summary()
            if group_norm_func is not None:
                for i in range(num_tiles):
                    task_queue = task_queues[i]
                    task_queue.insert(0, ('apply_norm', group_norm_func))

        # Done!
        pbar.close()
        return result if result is not None else result_approx.to(device)

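# Minimal usage sketch (added for illustration; assumes an ldm-style VAE whose encoder/decoder
# expose .forward, mirroring what Script.process() below does through the webui):
#   encoder = sd_model.first_stage_model.encoder
#   encoder.original_forward = encoder.forward
#   encoder.forward = VAEHook(encoder, tile_size=1536, is_decoder=False,
#                             fast_decoder=True, fast_encoder=False, color_fix=False, to_gpu=True)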

class Script(scripts.Script):

    def __init__(self):
        self.hooked = False

    def title(self):
        return "Tiled VAE"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        tab = 't2i' if not is_img2img else 'i2i'
        uid = lambda name: f'MD-{tab}-{name}'

        with gr.Accordion('Tiled VAE', open=False, elem_id=f'MDV-{tab}'):
            with gr.Row() as tab_enable:
                enabled = gr.Checkbox(label='Enable Tiled VAE', value=False, elem_id=uid('enable'))
                vae_to_gpu = gr.Checkbox(label='Move VAE to GPU (if possible)', value=True, elem_id=uid('vae2gpu'))

            gr.HTML('<p style="margin-bottom:0.8em"> Recommended: set the tile sizes as large as possible before hitting "CUDA error: out of memory". </p>')
            with gr.Row() as tab_size:
                encoder_tile_size = gr.Slider(label='Encoder Tile Size', minimum=256, maximum=4096, step=16, value=get_rcmd_enc_tsize(), elem_id=uid('enc-size'))
                decoder_tile_size = gr.Slider(label='Decoder Tile Size', minimum=48, maximum=512, step=16, value=get_rcmd_dec_tsize(), elem_id=uid('dec-size'))
                reset = gr.Button(value='↻ Reset', variant='tool')
                reset.click(fn=lambda: [get_rcmd_enc_tsize(), get_rcmd_dec_tsize()], outputs=[encoder_tile_size, decoder_tile_size], show_progress=False)

            with gr.Row() as tab_param:
                fast_encoder = gr.Checkbox(label='Fast Encoder', value=True, elem_id=uid('fastenc'))
                color_fix = gr.Checkbox(label='Fast Encoder Color Fix', value=False, visible=True, elem_id=uid('fastenc-colorfix'))
                fast_decoder = gr.Checkbox(label='Fast Decoder', value=True, elem_id=uid('fastdec'))

            fast_encoder.change(fn=gr_show, inputs=fast_encoder, outputs=color_fix, show_progress=False)

        return [
            enabled,
            encoder_tile_size, decoder_tile_size,
            vae_to_gpu, fast_decoder, fast_encoder, color_fix,
        ]

    def process(self, p: Processing,
            enabled: bool,
            encoder_tile_size: int, decoder_tile_size: int,
            vae_to_gpu: bool, fast_decoder: bool, fast_encoder: bool, color_fix: bool
        ):
        # NOTE: the UI values passed in above are overridden here with fixed settings
        enabled = True
        encoder_tile_size = 1536
        decoder_tile_size = 96
        vae_to_gpu = True
        fast_decoder = True
        fast_encoder = False
        color_fix = False
        # for shorthand
        vae = p.sd_model.first_stage_model
        encoder = vae.encoder
        decoder = vae.decoder

        # undo the hijack if disabled (in case the last run crashed)
        if not enabled:
            if self.hooked:
                if isinstance(encoder.forward, VAEHook):
                    encoder.forward.net = None
                    encoder.forward = encoder.original_forward
                if isinstance(decoder.forward, VAEHook):
                    decoder.forward.net = None
                    decoder.forward = decoder.original_forward
                self.hooked = False
            return

        if devices.get_optimal_device_name().startswith('cuda') and vae.device == devices.cpu and not vae_to_gpu:
            print("[Tiled VAE] warn: VAE is not on GPU, check 'Move VAE to GPU' if possible.")

        # do the hijack
        kwargs = {
            'fast_decoder': fast_decoder,
            'fast_encoder': fast_encoder,
            'color_fix': color_fix,
            'to_gpu': vae_to_gpu,
        }

        # save the original forward (only once)
        if not hasattr(encoder, 'original_forward'): setattr(encoder, 'original_forward', encoder.forward)
        if not hasattr(decoder, 'original_forward'): setattr(decoder, 'original_forward', decoder.forward)

        self.hooked = True

        encoder.forward = VAEHook(encoder, encoder_tile_size, is_decoder=False, **kwargs)
        decoder.forward = VAEHook(decoder, decoder_tile_size, is_decoder=True, **kwargs)

    def postprocess(self, p: Processing, processed, enabled: bool, *args):
        if not enabled: return

        vae = p.sd_model.first_stage_model
        encoder = vae.encoder
        decoder = vae.decoder
        if isinstance(encoder.forward, VAEHook):
            encoder.forward.net = None
            encoder.forward = encoder.original_forward
        if isinstance(decoder.forward, VAEHook):
            decoder.forward.net = None
            decoder.forward = decoder.original_forward