chaojiemao committed
Commit e8e3dcf
Parent: 161b0b1

Create ace_inference.py

Files changed (1):
  1. ace_inference.py  +549 -0
ace_inference.py ADDED
@@ -0,0 +1,549 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import math
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from PIL import Image

from scepter.modules.inference.diffusion_inference import (DiffusionInference,
                                                           get_model)
from scepter.modules.model.registry import DIFFUSIONS
from scepter.modules.model.utils.basic_utils import (
    check_list_of_list, to_device, unpack_tensor_into_imagelist)
from scepter.modules.model.utils.basic_utils import \
    pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor
from scepter.modules.utils.distribute import we
from scepter.modules.utils.logger import get_logger
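
# This module provides three pieces: process_edit_image (shared image/mask
# preprocessing), RefinerInference (an optional high-resolution second pass)
# and ACEInference (the main inference entry point).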


def process_edit_image(images,
                       masks,
                       tasks,
                       max_seq_len=1024,
                       max_aspect_ratio=4,
                       d=16,
                       **kwargs):
    """Crop, resize and normalize edit images and their masks so the latent
    sequence length stays within max_seq_len and both sides are divisible
    by the patch size d."""
    if not isinstance(images, list):
        images = [images]
    if not isinstance(masks, list):
        masks = [masks]
    if not isinstance(tasks, list):
        tasks = [tasks]

    img_tensors = []
    mask_tensors = []
    for img, mask, task in zip(images, masks, tasks):
        if mask is None or mask == '':
            mask = Image.new('L', img.size, 0)
        W, H = img.size
        # Center-crop overly elongated images to the maximum aspect ratio.
        if H / W > max_aspect_ratio:
            img = TF.center_crop(img, [int(max_aspect_ratio * W), W])
            mask = TF.center_crop(mask, [int(max_aspect_ratio * W), W])
        elif W / H > max_aspect_ratio:
            img = TF.center_crop(img, [H, int(max_aspect_ratio * H)])
            mask = TF.center_crop(mask, [H, int(max_aspect_ratio * H)])

        H, W = img.height, img.width
        # Downscale so that (H/d) * (W/d) <= max_seq_len, then snap both
        # sides to multiples of d.
        scale = min(1.0, math.sqrt(max_seq_len / ((H / d) * (W / d))))
        rH = int(H * scale) // d * d
        rW = int(W * scale) // d * d

        img = TF.resize(img, (rH, rW),
                        interpolation=TF.InterpolationMode.BICUBIC)
        mask = TF.resize(mask, (rH, rW),
                         interpolation=TF.InterpolationMode.NEAREST_EXACT)

        # Binarize the mask; an all-zero mask means "edit everywhere".
        mask = np.asarray(mask)
        mask = np.where(mask > 128, 1, 0)
        mask = mask.astype(np.float32) if np.any(mask) \
            else np.ones_like(mask).astype(np.float32)

        img_tensor = TF.to_tensor(img).to(we.device_id)
        img_tensor = TF.normalize(img_tensor,
                                  mean=[0.5, 0.5, 0.5],
                                  std=[0.5, 0.5, 0.5])
        mask_tensor = TF.to_tensor(mask).to(we.device_id)
        # For inpainting-style tasks, blank out the masked region.
        if task in ['inpainting', 'Try On', 'Inpainting']:
            mask_indicator = mask_tensor.repeat(3, 1, 1)
            img_tensor[mask_indicator == 1] = -1.0
        img_tensors.append(img_tensor)
        mask_tensors.append(mask_tensor)
    return img_tensors, mask_tensors


class TextEmbedding(nn.Module):
    def __init__(self, embedding_shape):
        super().__init__()
        self.pos = nn.Parameter(data=torch.zeros(embedding_shape))
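
# TextEmbedding is a thin container for the learned positional embeddings that
# ACEInference attaches to text-identifier tokens; its weights are filled in
# lazily inside cond_stage_embeddings below.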


class RefinerInference(DiffusionInference):
    def init_from_cfg(self, cfg):
        super().init_from_cfg(cfg)
        self.diffusion = DIFFUSIONS.build(cfg.MODEL.DIFFUSION, logger=self.logger) \
            if cfg.MODEL.have('DIFFUSION') else None
        self.max_seq_length = cfg.MODEL.get('MAX_SEQ_LENGTH', 4096)
        assert self.diffusion is not None
        self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
        self.dynamic_load(self.diffusion_model, 'diffusion_model')
        self.dynamic_load(self.first_stage_model, 'first_stage_model')

    @torch.no_grad()
    def encode_first_stage(self, x, **kwargs):
        _, dtype = self.get_function_info(self.first_stage_model, 'encode')
        with torch.autocast('cuda',
                            enabled=dtype in ('float16', 'bfloat16'),
                            dtype=getattr(torch, dtype)):

            def run_one_image(u):
                zu = get_model(self.first_stage_model).encode(u)
                if isinstance(zu, (tuple, list)):
                    zu = zu[0]
                return zu

            z = [run_one_image(u.unsqueeze(0) if u.dim() == 3 else u)
                 for u in x]
            return z

    def upscale_resize(self, image, interpolation=T.InterpolationMode.BILINEAR):
        c, H, W = image.shape
        # Upscale so that (H/16) * (W/16) reaches max_seq_length, keeping
        # both sides divisible by 16.
        scale = max(1.0,
                    math.sqrt(self.max_seq_length / ((H / 16) * (W / 16))))
        rH = int(H * scale) // 16 * 16
        rW = int(W * scale) // 16 * 16
        image = T.Resize((rH, rW), interpolation=interpolation,
                         antialias=True)(image)
        return image

    @torch.no_grad()
    def decode_first_stage(self, z):
        _, dtype = self.get_function_info(self.first_stage_model, 'decode')
        with torch.autocast('cuda',
                            enabled=dtype in ('float16', 'bfloat16'),
                            dtype=getattr(torch, dtype)):
            return [get_model(self.first_stage_model).decode(zu) for zu in z]

    def noise_sample(self, num_samples, h, w, seed, device=None,
                     dtype=torch.bfloat16):
        noise = torch.randn(
            num_samples,
            16,
            # allow for packing: round both sides up to multiples of 16
            2 * math.ceil(h / 16),
            2 * math.ceil(w / 16),
            device=device,
            dtype=dtype,
            generator=torch.Generator(device=device).manual_seed(seed),
        )
        return noise

    def refine(self,
               x_samples=None,
               prompt=None,
               reverse_scale=-1.,
               seed=2024,
               use_dynamic_model=False,
               **kwargs):
        self.logger.info(prompt)
        value_input = copy.deepcopy(self.input)
        x_samples = [self.upscale_resize(x) for x in x_samples]

        noise = []
        for x in x_samples:
            noise_ = self.noise_sample(1, x.shape[1], x.shape[2], seed,
                                       device=x.device)
            noise.append(noise_)
        noise, x_shapes = pack_imagelist_into_tensor(noise)
        if reverse_scale > 0:
            # Encode the upscaled samples so sampling can start from a
            # partially re-noised version of them (img2img-style refinement).
            if use_dynamic_model:
                self.dynamic_load(self.first_stage_model, 'first_stage_model')
            x_samples = [x.unsqueeze(0) for x in x_samples]
            x_start = self.encode_first_stage(x_samples, **kwargs)
            if use_dynamic_model:
                self.dynamic_unload(self.first_stage_model,
                                    'first_stage_model',
                                    skip_loaded=True)
            x_start, _ = pack_imagelist_into_tensor(x_start)
        else:
            x_start = None
        # Cond stage
        if use_dynamic_model:
            self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
        function_name, dtype = self.get_function_info(self.cond_stage_model)
        with torch.autocast('cuda',
                            enabled=dtype == 'float16',
                            dtype=getattr(torch, dtype)):
            ctx = getattr(get_model(self.cond_stage_model),
                          function_name)(prompt)
            ctx['x_shapes'] = x_shapes
        if use_dynamic_model:
            self.dynamic_unload(self.cond_stage_model,
                                'cond_stage_model',
                                skip_loaded=True)

        if use_dynamic_model:
            self.dynamic_load(self.diffusion_model, 'diffusion_model')
        function_name, dtype = self.get_function_info(self.diffusion_model)
        with torch.autocast('cuda',
                            enabled=dtype in ('float16', 'bfloat16'),
                            dtype=getattr(torch, dtype)):
            solver_sample = value_input.get('sample', 'flow_euler')
            sample_steps = value_input.get('sample_steps', 20)
            guide_scale = value_input.get('guide_scale', 3.5)
            if guide_scale is not None:
                guide_scale = torch.full((noise.shape[0], ),
                                         guide_scale,
                                         device=noise.device,
                                         dtype=noise.dtype)
            latent = self.diffusion.sample(
                noise=noise,
                sampler=solver_sample,
                model=get_model(self.diffusion_model),
                model_kwargs={'cond': ctx, 'guidance': guide_scale},
                steps=sample_steps,
                show_progress=True,
                guide_scale=guide_scale,
                return_intermediate=None,
                reverse_scale=reverse_scale,
                x=x_start,
                **kwargs).float()
        latent = unpack_tensor_into_imagelist(latent, x_shapes)
        if use_dynamic_model:
            self.dynamic_unload(self.diffusion_model,
                                'diffusion_model',
                                skip_loaded=True)
            self.dynamic_load(self.first_stage_model, 'first_stage_model')
        x_samples = self.decode_first_stage(latent)
        if use_dynamic_model:
            self.dynamic_unload(self.first_stage_model,
                                'first_stage_model',
                                skip_loaded=True)
        return x_samples
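
# ACEInference drives refine() above: reverse_scale > 0 starts sampling from a
# re-noised encoding of the ACE output (img2img-style), while a reverse_scale
# of -1 regenerates the image from pure noise using only the refine prompt.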


class ACEInference(DiffusionInference):
    def __init__(self, logger=None):
        if logger is None:
            logger = get_logger(name='scepter')
        self.logger = logger
        self.loaded_model = {}
        self.loaded_model_name = [
            'diffusion_model', 'first_stage_model', 'cond_stage_model'
        ]

    def init_from_cfg(self, cfg):
        self.name = cfg.NAME
        self.is_default = cfg.get('IS_DEFAULT', False)
        self.use_dynamic_model = cfg.get('USE_DYNAMIC_MODEL', True)
        module_paras = self.load_default(cfg.get('DEFAULT_PARAS', None))
        assert cfg.have('MODEL')

        self.diffusion_model = self.infer_model(
            cfg.MODEL.DIFFUSION_MODEL,
            module_paras.get('DIFFUSION_MODEL',
                             None)) if cfg.MODEL.have('DIFFUSION_MODEL') else None
        self.first_stage_model = self.infer_model(
            cfg.MODEL.FIRST_STAGE_MODEL,
            module_paras.get('FIRST_STAGE_MODEL',
                             None)) if cfg.MODEL.have('FIRST_STAGE_MODEL') else None
        self.cond_stage_model = self.infer_model(
            cfg.MODEL.COND_STAGE_MODEL,
            module_paras.get('COND_STAGE_MODEL',
                             None)) if cfg.MODEL.have('COND_STAGE_MODEL') else None

        # Optional high-resolution refiner pass.
        self.refiner_model_cfg = cfg.get('REFINER_MODEL', None)
        self.ace_prompt = cfg.get('ACE_PROMPT', [])
        if self.refiner_model_cfg:
            self.refiner_module = RefinerInference(self.logger)
            self.refiner_module.init_from_cfg(self.refiner_model_cfg)
        else:
            self.refiner_module = None

        self.diffusion = DIFFUSIONS.build(cfg.MODEL.DIFFUSION,
                                          logger=self.logger)

        # Downsample masks to the latent resolution.
        self.interpolate_func = lambda x: (F.interpolate(
            x.unsqueeze(0),
            scale_factor=1 / self.size_factor,
            mode='nearest-exact') if x is not None else None)
        self.text_identifiers = cfg.MODEL.get('TEXT_IDENTIFIER', [])
        self.use_text_pos_embeddings = cfg.MODEL.get('USE_TEXT_POS_EMBEDDINGS',
                                                     False)
        if self.use_text_pos_embeddings:
            self.text_position_embeddings = TextEmbedding(
                (10, 4096)).eval().requires_grad_(False).to(we.device_id)
        else:
            self.text_position_embeddings = None

        self.max_seq_len = cfg.MODEL.DIFFUSION_MODEL.MAX_SEQ_LEN
        self.scale_factor = cfg.get('SCALE_FACTOR', 0.18215)
        self.size_factor = cfg.get('SIZE_FACTOR', 8)
        self.decoder_bias = cfg.get('DECODER_BIAS', 0)
        self.default_n_prompt = cfg.get('DEFAULT_N_PROMPT', '')
        self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
        self.dynamic_load(self.diffusion_model, 'diffusion_model')
        self.dynamic_load(self.first_stage_model, 'first_stage_model')
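
    # The VAE latents are multiplied by scale_factor on encode and divided by
    # it on decode; the 0.18215 default is the standard Stable Diffusion VAE
    # latent scaling.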

    @torch.no_grad()
    def encode_first_stage(self, x, **kwargs):
        _, dtype = self.get_function_info(self.first_stage_model, 'encode')
        with torch.autocast('cuda',
                            enabled=(dtype != 'float32'),
                            dtype=getattr(torch, dtype)):
            z = [
                self.scale_factor * get_model(self.first_stage_model)._encode(
                    i.unsqueeze(0).to(getattr(torch, dtype))) for i in x
            ]
        return z

    @torch.no_grad()
    def decode_first_stage(self, z):
        _, dtype = self.get_function_info(self.first_stage_model, 'decode')
        with torch.autocast('cuda',
                            enabled=(dtype != 'float32'),
                            dtype=getattr(torch, dtype)):
            x = [
                get_model(self.first_stage_model)._decode(
                    1. / self.scale_factor * i.to(getattr(torch, dtype)))
                for i in z
            ]
        return x

    @torch.no_grad()
    def __call__(self,
                 image=None,
                 mask=None,
                 prompt='',
                 task=None,
                 negative_prompt='',
                 output_height=512,
                 output_width=512,
                 sampler='ddim',
                 sample_steps=20,
                 guide_scale=4.5,
                 guide_rescale=0.5,
                 seed=-1,
                 history_io=None,
                 tar_index=0,
                 **kwargs):
        input_image, input_mask = image, mask
        g = torch.Generator(device=we.device_id)
        seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
        g.manual_seed(int(seed))
        if input_image is not None:
            # input_image and input_mask are expected to be lists here.
            if task is None:
                task = [''] * len(input_image)
            if not isinstance(prompt, list):
                prompt = [prompt] * len(input_image)
            if history_io is not None and len(history_io) > 0:
                # Prepend earlier turns so multi-round edits stay in context.
                his_image, his_masks, his_prompt, his_task = history_io[
                    'image'], history_io['mask'], history_io[
                        'prompt'], history_io['task']
                assert len(his_image) == len(his_masks) == len(
                    his_prompt) == len(his_task)
                input_image = his_image + input_image
                input_mask = his_masks + input_mask
                task = his_task + task
                prompt = his_prompt + [prompt[-1]]
                # Rewrite the generic '{image}' placeholder so each turn
                # refers to its own indexed image.
                prompt = [
                    pp.replace('{image}', f'{{image{i}}}') if i > 0 else pp
                    for i, pp in enumerate(prompt)
                ]
357
+
358
+ edit_image, edit_image_mask = process_edit_image(
359
+ input_image, input_mask, task, max_seq_len=self.max_seq_len)
360
+
361
+ image, image_mask = edit_image[tar_index], edit_image_mask[
362
+ tar_index]
363
+ edit_image, edit_image_mask = [edit_image], [edit_image_mask]
364
+
365
+ else:
366
+ edit_image = edit_image_mask = [[]]
367
+ image = torch.zeros(
368
+ size=[3, int(output_height),
369
+ int(output_width)])
370
+ image_mask = torch.ones(
371
+ size=[1, int(output_height),
372
+ int(output_width)])
373
+ if not isinstance(prompt, list):
374
+ prompt = [prompt]
375
+
376
+ image, image_mask, prompt = [image], [image_mask], [prompt]
377
+ assert check_list_of_list(prompt) and check_list_of_list(
378
+ edit_image) and check_list_of_list(edit_image_mask)
379
+ # Assign Negative Prompt
380
+ if isinstance(negative_prompt, list):
381
+ negative_prompt = negative_prompt[0]
382
+ assert isinstance(negative_prompt, str)
383
+
384
+ n_prompt = copy.deepcopy(prompt)
385
+ for nn_p_id, nn_p in enumerate(n_prompt):
386
+ assert isinstance(nn_p, list)
387
+ n_prompt[nn_p_id][-1] = negative_prompt
388
+
389
+ is_txt_image = sum([len(e_i) for e_i in edit_image]) < 1
390
+ image = to_device(image)
391
+
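
        # Three paths from here: run ACE below (editing, or txt2img when no
        # refiner pass is requested); otherwise skip ACE and let the refiner
        # generate directly; finally, optionally refine the result.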

        refiner_scale = kwargs.pop('refiner_scale', 0.0)
        refiner_prompt = kwargs.pop('refiner_prompt', '')
        use_ace = kwargs.pop('use_ace', True)
        # refiner_scale <= 0: use ACE itself as the txt2img generator.
        if use_ace and (not is_txt_image or refiner_scale <= 0):
            ctx, null_ctx = {}, {}
            # Get Noise Shape
            self.dynamic_load(self.first_stage_model, 'first_stage_model')
            x = self.encode_first_stage(image)
            self.dynamic_unload(self.first_stage_model,
                                'first_stage_model',
                                skip_loaded=True)
            noise = [
                torch.empty(*i.shape, device=we.device_id).normal_(generator=g)
                for i in x
            ]
            noise, x_shapes = pack_imagelist_into_tensor(noise)
            ctx['x_shapes'] = null_ctx['x_shapes'] = x_shapes

            image_mask = to_device(image_mask, strict=False)
            cond_mask = [self.interpolate_func(i) for i in image_mask
                         ] if image_mask is not None else [None] * len(image)
            ctx['x_mask'] = null_ctx['x_mask'] = cond_mask

            # Encode Prompt
            self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
            function_name, dtype = self.get_function_info(
                self.cond_stage_model)
            cont, cont_mask = getattr(get_model(self.cond_stage_model),
                                      function_name)(prompt)
            cont, cont_mask = self.cond_stage_embeddings(
                prompt, edit_image, cont, cont_mask)
            null_cont, null_cont_mask = getattr(
                get_model(self.cond_stage_model), function_name)(n_prompt)
            null_cont, null_cont_mask = self.cond_stage_embeddings(
                prompt, edit_image, null_cont, null_cont_mask)
            self.dynamic_unload(self.cond_stage_model,
                                'cond_stage_model',
                                skip_loaded=False)
            ctx['crossattn'] = cont
            null_ctx['crossattn'] = null_cont

            # Encode Edit Images
            self.dynamic_load(self.first_stage_model, 'first_stage_model')
            edit_image = [to_device(i, strict=False) for i in edit_image]
            edit_image_mask = [
                to_device(i, strict=False) for i in edit_image_mask
            ]
            e_img, e_mask = [], []
            for u, m in zip(edit_image, edit_image_mask):
                if u is None:
                    continue
                if m is None:
                    m = [None] * len(u)
                e_img.append(self.encode_first_stage(u, **kwargs))
                e_mask.append([self.interpolate_func(i) for i in m])
            self.dynamic_unload(self.first_stage_model,
                                'first_stage_model',
                                skip_loaded=True)
            null_ctx['edit'] = ctx['edit'] = e_img
            null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask
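
            # With guide_scale > 1 the sampler receives paired conditional and
            # unconditional kwargs for classifier-free guidance; otherwise it
            # runs a single unguided batch.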

            # Diffusion Process
            self.dynamic_load(self.diffusion_model, 'diffusion_model')
            function_name, dtype = self.get_function_info(self.diffusion_model)
            text_pos_embeddings = (self.text_position_embeddings.pos
                                   if hasattr(self.text_position_embeddings,
                                              'pos') else None)
            with torch.autocast('cuda',
                                enabled=dtype in ('float16', 'bfloat16'),
                                dtype=getattr(torch, dtype)):
                latent = self.diffusion.sample(
                    noise=noise,
                    sampler=sampler,
                    model=get_model(self.diffusion_model),
                    model_kwargs=[{
                        'cond': ctx,
                        'mask': cont_mask,
                        'text_position_embeddings': text_pos_embeddings
                    }, {
                        'cond': null_ctx,
                        'mask': null_cont_mask,
                        'text_position_embeddings': text_pos_embeddings
                    }] if guide_scale is not None and guide_scale > 1 else {
                        'cond': null_ctx,
                        'mask': cont_mask,
                        'text_position_embeddings': text_pos_embeddings
                    },
                    steps=sample_steps,
                    show_progress=True,
                    seed=seed,
                    guide_scale=guide_scale,
                    guide_rescale=guide_rescale,
                    return_intermediate=None,
                    **kwargs)
            self.dynamic_unload(self.diffusion_model,
                                'diffusion_model',
                                skip_loaded=False)

            # Decode to Pixel Space
            self.dynamic_load(self.first_stage_model, 'first_stage_model')
            samples = unpack_tensor_into_imagelist(latent, x_shapes)
            x_samples = self.decode_first_stage(samples)
            self.dynamic_unload(self.first_stage_model,
                                'first_stage_model',
                                skip_loaded=False)
            x_samples = [x.squeeze(0) for x in x_samples]
        else:
            x_samples = image
        if self.refiner_module and refiner_scale > 0:
            if is_txt_image:
                # No input image: fall back to a randomly chosen ACE prompt
                # and regenerate from pure noise in the refiner.
                random.shuffle(self.ace_prompt)
                input_refine_prompt = [
                    self.ace_prompt[0] + refiner_prompt if p[0] == '' else p[0]
                    for p in prompt
                ]
                input_refine_scale = -1.
            else:
                input_refine_prompt = [
                    p[0].replace('{image}', '') + ' ' + refiner_prompt
                    for p in prompt
                ]
                input_refine_scale = refiner_scale
            self.logger.info(input_refine_prompt)

            x_samples = self.refiner_module.refine(
                x_samples,
                reverse_scale=input_refine_scale,
                prompt=input_refine_prompt,
                seed=seed,
                use_dynamic_model=self.use_dynamic_model)

        # Map from [-1, 1] back to [0, 1] and convert to PIL images.
        imgs = [
            torch.clamp((x_i.float() + 1.0) / 2.0 + self.decoder_bias / 255,
                        min=0.0,
                        max=1.0).squeeze(0).permute(1, 2, 0).cpu().numpy()
            for x_i in x_samples
        ]
        imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs]
        return imgs

    def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):
        # Populate the text positional embeddings once, on first use, from
        # the encoded identifier tokens.
        if self.use_text_pos_embeddings and not torch.sum(
                self.text_position_embeddings.pos) > 0:
            identifier_cont, _ = getattr(get_model(self.cond_stage_model),
                                         'encode')(self.text_identifiers,
                                                   return_mask=True)
            self.text_position_embeddings.load_state_dict(
                {'pos': identifier_cont[:, 0, :]})

        # Duplicate the last (target) embedding to the front; prompts without
        # edit images keep only the target embedding.
        cont_, cont_mask_ = [], []
        for pp, edit, c, cm in zip(prompt, edit_image, cont, cont_mask):
            if isinstance(pp, list):
                cont_.append([c[-1], *c] if len(edit) > 0 else [c[-1]])
                cont_mask_.append([cm[-1], *cm] if len(edit) > 0 else [cm[-1]])
            else:
                raise NotImplementedError

        return cont_, cont_mask_
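
A minimal driver for the class above might look like the sketch below. This is
a hedged example, not part of the commit: the Config import path, the YAML
file name, and the prompt/task values are assumptions about how a
scepter-style setup would feed init_from_cfg.

    # Hypothetical usage sketch; cfg_file and all values are illustrative.
    from PIL import Image
    from scepter.modules.utils.config import Config

    cfg = Config(load=True, cfg_file='ace_model.yaml')  # assumed config file
    ace = ACEInference()
    ace.init_from_cfg(cfg)
    outputs = ace(image=[Image.open('input.png').convert('RGB')],
                  mask=[None],
                  prompt=['make the sky sunset-colored {image}'],
                  task=[''],
                  sample_steps=20,
                  guide_scale=4.5,
                  seed=2024)
    outputs[0].save('result.png')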