dg845 committed
Commit d3ef367
1 Parent(s): b312363

Upload 2 files


Add testing scripts for UniDiffuser-v0
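For reference, a minimal usage sketch of the sample() entry point defined in both scripts below, assuming the UniDiffuser dependencies (utils, libs, dpm_solver_pp) are importable and the pretrained weights sit under models/ as in get_config(); the prompt and image path are placeholders:

from sample_v0 import sample

# text-to-image: returns a batch of decoded images in [0, 1]; the text output is None
images, _ = sample(mode='t2i', prompt='an elephant under the sea', image=None,
                   sample_steps=50, scale=7.0, seed=1234)

# image-to-text: returns a generated caption for the input image; the image output is None
_, caption = sample(mode='i2t', prompt=None, image='path/to/input.jpg')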

unidiffuser/sample_v0.py ADDED
@@ -0,0 +1,418 @@
1
+ import ml_collections
2
+ import torch
3
+ import random
4
+ import utils
5
+ from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
6
+ from absl import logging
7
+ import einops
8
+ import libs.autoencoder
9
+ import libs.clip
10
+ from torchvision.utils import save_image, make_grid
11
+ import torchvision.transforms as standard_transforms
12
+ import numpy as np
13
+ import clip
14
+ from PIL import Image
15
+ import time
16
+
17
+
18
+ def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
19
+ _betas = (
20
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
21
+ )
22
+ return _betas.numpy()
23
+
24
+
25
+ def prepare_contexts(config, clip_text_model, clip_img_model, clip_img_model_preprocess, autoencoder):
26
+ resolution = config.z_shape[-1] * 8
27
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
28
+
29
+ contexts = torch.randn(config.n_samples, 77, config.clip_text_dim).to(device)
30
+ img_contexts = torch.randn(config.n_samples, 2 * config.z_shape[0], config.z_shape[1], config.z_shape[2])
31
+ clip_imgs = torch.randn(config.n_samples, 1, config.clip_img_dim)
32
+
33
+ if config.mode in ['t2i', 't2i2t']:
34
+ prompts = [ config.prompt ] * config.n_samples
35
+ contexts = clip_text_model.encode(prompts)
36
+
37
+ elif config.mode in ['i2t', 'i2t2i']:
38
+ from PIL import Image
39
+ img_contexts = []
40
+ clip_imgs = []
41
+
42
+ def get_img_feature(image):
43
+ image = np.array(image).astype(np.uint8)
44
+ image = utils.center_crop(resolution, resolution, image)
45
+ clip_img_feature = clip_img_model.encode_image(clip_img_model_preprocess(Image.fromarray(image)).unsqueeze(0).to(device))
46
+
47
+ image = (image / 127.5 - 1.0).astype(np.float32)
48
+ image = einops.rearrange(image, 'h w c -> 1 c h w')
49
+ image = torch.tensor(image, device=device)
50
+ moments = autoencoder.encode_moments(image)
51
+
52
+ return clip_img_feature, moments
53
+
54
+ image = Image.open(config.img).convert('RGB')
55
+ clip_img, img_context = get_img_feature(image)
56
+
57
+ img_contexts.append(img_context)
58
+ clip_imgs.append(clip_img)
59
+ img_contexts = img_contexts * config.n_samples
60
+ clip_imgs = clip_imgs * config.n_samples
61
+
62
+ img_contexts = torch.concat(img_contexts, dim=0)
63
+ clip_imgs = torch.stack(clip_imgs, dim=0)
64
+
65
+ return contexts, img_contexts, clip_imgs
66
+
67
+
68
+ def unpreprocess(v): # to B C H W and [0, 1]
69
+ v = 0.5 * (v + 1.)
70
+ v.clamp_(0., 1.)
71
+ return v
72
+
73
+
74
+ def set_seed(seed: int):
75
+ random.seed(seed)
76
+ np.random.seed(seed)
77
+ torch.manual_seed(seed)
78
+ torch.cuda.manual_seed_all(seed)
79
+
80
+
81
+ def evaluate(config):
82
+ if config.get('benchmark', False):
83
+ torch.backends.cudnn.benchmark = True
84
+ torch.backends.cudnn.deterministic = False
85
+
86
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
87
+ set_seed(config.seed)
88
+
89
+ config = ml_collections.FrozenConfigDict(config)
90
+ utils.set_logger(log_level='info')
91
+
92
+ _betas = stable_diffusion_beta_schedule()
93
+ N = len(_betas)
94
+
95
+ nnet = utils.get_nnet(**config.nnet)
96
+ logging.info(f'load nnet from {config.nnet_path}')
97
+ nnet.load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
98
+ nnet.to(device)
99
+ nnet.eval()
100
+
101
+ use_caption_decoder = config.text_dim < config.clip_text_dim or config.mode != 't2i'
102
+ if use_caption_decoder:
103
+ from libs.caption_decoder import CaptionDecoder
104
+ caption_decoder = CaptionDecoder(device=device, **config.caption_decoder)
105
+ else:
106
+ caption_decoder = None
107
+
108
+ clip_text_model = libs.clip.FrozenCLIPEmbedder(device=device)
109
+ clip_text_model.eval()
110
+ clip_text_model.to(device)
111
+
112
+ autoencoder = libs.autoencoder.get_model(**config.autoencoder)
113
+ autoencoder.to(device)
114
+
115
+ clip_img_model, clip_img_model_preprocess = clip.load("ViT-B/32", device=device, jit=False)
116
+
117
+ empty_context = clip_text_model.encode([''])[0]
118
+
119
+ def split(x):
120
+ C, H, W = config.z_shape
121
+ z_dim = C * H * W
122
+ z, clip_img = x.split([z_dim, config.clip_img_dim], dim=1)
123
+ z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
124
+ clip_img = einops.rearrange(clip_img, 'B (L D) -> B L D', L=1, D=config.clip_img_dim)
125
+ return z, clip_img
126
+
127
+
128
+ def combine(z, clip_img):
129
+ z = einops.rearrange(z, 'B C H W -> B (C H W)')
130
+ clip_img = einops.rearrange(clip_img, 'B L D -> B (L D)')
131
+ return torch.concat([z, clip_img], dim=-1)
132
+
133
+
134
+ def t2i_nnet(x, timesteps, text): # text is the low-dimensional version of the CLIP text embedding
135
+ """
136
+ 1. calculate the conditional model output
137
+ 2. calculate unconditional model output
138
+ config.sample.t2i_cfg_mode == 'empty_token': using the original cfg with the empty string
139
+ config.sample.t2i_cfg_mode == 'true_uncond': using the unconditional model learned by our method
140
+ 3. return linear combination of conditional output and unconditional output
141
+ """
142
+ z, clip_img = split(x)
143
+
144
+ t_text = torch.zeros(timesteps.size(0), dtype=torch.int, device=device)
145
+
146
+ z_out, clip_img_out, text_out = nnet(z, clip_img, text=text, t_img=timesteps, t_text=t_text)
147
+ x_out = combine(z_out, clip_img_out)
148
+
149
+ if config.sample.scale == 0.:
150
+ return x_out
151
+
152
+ if config.sample.t2i_cfg_mode == 'empty_token':
153
+ _empty_context = einops.repeat(empty_context, 'L D -> B L D', B=x.size(0))
154
+ if use_caption_decoder:
155
+ _empty_context = caption_decoder.encode_prefix(_empty_context)
156
+ z_out_uncond, clip_img_out_uncond, text_out_uncond = nnet(z, clip_img, text=_empty_context, t_img=timesteps, t_text=t_text)
157
+ x_out_uncond = combine(z_out_uncond, clip_img_out_uncond)
158
+ elif config.sample.t2i_cfg_mode == 'true_uncond':
159
+ text_N = torch.randn_like(text) # 3 other possible choices
160
+ z_out_uncond, clip_img_out_uncond, text_out_uncond = nnet(z, clip_img, text=text_N, t_img=timesteps, t_text=torch.ones_like(timesteps) * N)
161
+ x_out_uncond = combine(z_out_uncond, clip_img_out_uncond)
162
+ else:
163
+ raise NotImplementedError
164
+
165
+ return x_out + config.sample.scale * (x_out - x_out_uncond)
166
+
167
+
168
+ def i_nnet(x, timesteps):
169
+ z, clip_img = split(x)
170
+ text = torch.randn(x.size(0), 77, config.text_dim, device=device)
171
+ t_text = torch.ones_like(timesteps) * N
172
+ z_out, clip_img_out, text_out = nnet(z, clip_img, text=text, t_img=timesteps, t_text=t_text)
173
+ x_out = combine(z_out, clip_img_out)
174
+ return x_out
175
+
176
+ def t_nnet(x, timesteps):
177
+ z = torch.randn(x.size(0), *config.z_shape, device=device)
178
+ clip_img = torch.randn(x.size(0), 1, config.clip_img_dim, device=device)
179
+ z_out, clip_img_out, text_out = nnet(z, clip_img, text=x, t_img=torch.ones_like(timesteps) * N, t_text=timesteps)
180
+ return text_out
181
+
182
+ def i2t_nnet(x, timesteps, z, clip_img):
183
+ """
184
+ 1. calculate the conditional model output
185
+ 2. calculate unconditional model output
186
+ 3. return linear combination of conditional output and unconditional output
187
+ """
188
+ t_img = torch.zeros(timesteps.size(0), dtype=torch.int, device=device)
189
+
190
+ z_out, clip_img_out, text_out = nnet(z, clip_img, text=x, t_img=t_img, t_text=timesteps)
191
+
192
+ if config.sample.scale == 0.:
193
+ return text_out
194
+
195
+ z_N = torch.randn_like(z) # 3 other possible choices
196
+ clip_img_N = torch.randn_like(clip_img)
197
+ z_out_uncond, clip_img_out_uncond, text_out_uncond = nnet(z_N, clip_img_N, text=x, t_img=torch.ones_like(timesteps) * N, t_text=timesteps)
198
+
199
+ return text_out + config.sample.scale * (text_out - text_out_uncond)
200
+
201
+ def split_joint(x):
202
+ C, H, W = config.z_shape
203
+ z_dim = C * H * W
204
+ z, clip_img, text = x.split([z_dim, config.clip_img_dim, 77 * config.text_dim], dim=1)
205
+ z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
206
+ clip_img = einops.rearrange(clip_img, 'B (L D) -> B L D', L=1, D=config.clip_img_dim)
207
+ text = einops.rearrange(text, 'B (L D) -> B L D', L=77, D=config.text_dim)
208
+ return z, clip_img, text
209
+
210
+ def combine_joint(z, clip_img, text):
211
+ z = einops.rearrange(z, 'B C H W -> B (C H W)')
212
+ clip_img = einops.rearrange(clip_img, 'B L D -> B (L D)')
213
+ text = einops.rearrange(text, 'B L D -> B (L D)')
214
+ return torch.concat([z, clip_img, text], dim=-1)
215
+
216
+ def joint_nnet(x, timesteps):
217
+ z, clip_img, text = split_joint(x)
218
+ z_out, clip_img_out, text_out = nnet(z, clip_img, text=text, t_img=timesteps, t_text=timesteps)
219
+ x_out = combine_joint(z_out, clip_img_out, text_out)
220
+
221
+ if config.sample.scale == 0.:
222
+ return x_out
223
+
224
+ z_noise = torch.randn(x.size(0), *config.z_shape, device=device)
225
+ clip_img_noise = torch.randn(x.size(0), 1, config.clip_img_dim, device=device)
226
+ text_noise = torch.randn(x.size(0), 77, config.text_dim, device=device)
227
+
228
+ _, _, text_out_uncond = nnet(z_noise, clip_img_noise, text=text, t_img=torch.ones_like(timesteps) * N, t_text=timesteps)
229
+ z_out_uncond, clip_img_out_uncond, _ = nnet(z, clip_img, text=text_noise, t_img=timesteps, t_text=torch.ones_like(timesteps) * N)
230
+
231
+ x_out_uncond = combine_joint(z_out_uncond, clip_img_out_uncond, text_out_uncond)
232
+
233
+ return x_out + config.sample.scale * (x_out - x_out_uncond)
234
+
235
+ @torch.cuda.amp.autocast()
236
+ def encode(_batch):
237
+ return autoencoder.encode(_batch)
238
+
239
+ @torch.cuda.amp.autocast()
240
+ def decode(_batch):
241
+ return autoencoder.decode(_batch)
242
+
243
+
244
+ logging.info(config.sample)
245
+ logging.info(f'N={N}')
246
+
247
+ contexts, img_contexts, clip_imgs = prepare_contexts(config, clip_text_model, clip_img_model, clip_img_model_preprocess, autoencoder)
248
+
249
+ contexts = contexts # the clip embedding of conditioned texts
250
+ contexts_low_dim = contexts if not use_caption_decoder else caption_decoder.encode_prefix(contexts) # the low dimensional version of the contexts, which is the input to the nnet
251
+
252
+ img_contexts = img_contexts # img_contexts is the autoencoder moment
253
+ z_img = autoencoder.sample(img_contexts)
254
+ clip_imgs = clip_imgs # the clip embedding of conditioned image
255
+
256
+ if config.mode in ['t2i', 't2i2t']:
257
+ _n_samples = contexts_low_dim.size(0)
258
+ elif config.mode in ['i2t', 'i2t2i']:
259
+ _n_samples = img_contexts.size(0)
260
+ else:
261
+ _n_samples = config.n_samples
262
+
263
+
264
+ def sample_fn(mode, **kwargs):
265
+
266
+ _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
267
+ _clip_img_init = torch.randn(_n_samples, 1, config.clip_img_dim, device=device)
268
+ _text_init = torch.randn(_n_samples, 77, config.text_dim, device=device)
269
+ if mode == 'joint':
270
+ _x_init = combine_joint(_z_init, _clip_img_init, _text_init)
271
+ elif mode in ['t2i', 'i']:
272
+ _x_init = combine(_z_init, _clip_img_init)
273
+ elif mode in ['i2t', 't']:
274
+ _x_init = _text_init
275
+ noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
276
+
277
+ def model_fn(x, t_continuous):
278
+ t = t_continuous * N
279
+ if mode == 'joint':
280
+ return joint_nnet(x, t)
281
+ elif mode == 't2i':
282
+ return t2i_nnet(x, t, **kwargs)
283
+ elif mode == 'i2t':
284
+ return i2t_nnet(x, t, **kwargs)
285
+ elif mode == 'i':
286
+ return i_nnet(x, t)
287
+ elif mode == 't':
288
+ return t_nnet(x, t)
289
+
290
+ dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
291
+ with torch.no_grad():
292
+ with torch.autocast(device_type=device):
293
+ start_time = time.time()
294
+ x = dpm_solver.sample(_x_init, steps=config.sample.sample_steps, eps=1. / N, T=1.)
295
+ end_time = time.time()
296
+ print(f'\ngenerating {_n_samples} samples with {config.sample.sample_steps} steps took {end_time - start_time:.2f}s')
297
+
298
+ # os.makedirs(config.output_path, exist_ok=True)
299
+ if mode == 'joint':
300
+ _z, _clip_img, _text = split_joint(x)
301
+ return _z, _clip_img, _text
302
+ elif mode in ['t2i', 'i']:
303
+ _z, _clip_img = split(x)
304
+ return _z, _clip_img
305
+ elif mode in ['i2t', 't']:
306
+ return x
307
+
308
+ output_images = None
309
+ output_text = None
310
+
311
+ if config.mode in ['joint']:
312
+ _z, _clip_img, _text = sample_fn(config.mode)
313
+ samples = unpreprocess(decode(_z))
314
+ prompts = caption_decoder.generate_captions(_text)
315
+ output_images = samples
316
+ output_text = prompts
317
+
318
+ elif config.mode in ['t2i', 'i', 'i2t2i']:
319
+ if config.mode == 't2i':
320
+ _z, _clip_img = sample_fn(config.mode, text=contexts_low_dim) # conditioned on the text embedding
321
+ elif config.mode == 'i':
322
+ _z, _clip_img = sample_fn(config.mode)
323
+ elif config.mode == 'i2t2i':
324
+ _text = sample_fn('i2t', z=z_img, clip_img=clip_imgs) # conditioned on the image embedding
325
+ _z, _clip_img = sample_fn('t2i', text=_text)
326
+ samples = unpreprocess(decode(_z))
327
+ output_images = samples
328
+
329
+
330
+ elif config.mode in ['i2t', 't', 't2i2t']:
331
+ if config.mode == 'i2t':
332
+ _text = sample_fn(config.mode, z=z_img, clip_img=clip_imgs) # conditioned on the image embedding
333
+ elif config.mode == 't':
334
+ _text = sample_fn(config.mode)
335
+ elif config.mode == 't2i2t':
336
+ _z, _clip_img = sample_fn('t2i', text=contexts_low_dim)
337
+ _text = sample_fn('i2t', z=_z, clip_img=_clip_img)
338
+ samples = caption_decoder.generate_captions(_text)
339
+ logging.info(samples)
340
+ output_text = samples
341
+
342
+ print(f'\nGPU memory usage: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB')
343
+ # print(f'\nresults are saved in {os.path.join(config.output_path, config.mode)} :)')
344
+
345
+ return output_images, output_text
346
+
347
+
348
+ def d(**kwargs):
349
+ """Helper for creating a config dict."""
350
+ return ml_collections.ConfigDict(initial_dictionary=kwargs)
351
+
352
+
353
+ def get_config():
354
+ config = ml_collections.ConfigDict()
355
+
356
+ config.seed = 1234
357
+ config.pred = 'noise_pred'
358
+ config.z_shape = (4, 64, 64)
359
+ config.clip_img_dim = 512
360
+ config.clip_text_dim = 768
361
+ config.text_dim = 64 # reduce dimension
362
+
363
+ config.autoencoder = d(
364
+ pretrained_path='models/autoencoder_kl.pth',
365
+ )
366
+
367
+ config.caption_decoder = d(
368
+ pretrained_path="models/caption_decoder.pth",
369
+ hidden_dim=config.get_ref('text_dim')
370
+ )
371
+
372
+ config.nnet = d(
373
+ name='uvit_multi_post_ln',
374
+ img_size=64,
375
+ in_chans=4,
376
+ patch_size=2,
377
+ embed_dim=1536,
378
+ depth=30,
379
+ num_heads=24,
380
+ mlp_ratio=4,
381
+ qkv_bias=False,
382
+ pos_drop_rate=0.,
383
+ drop_rate=0.,
384
+ attn_drop_rate=0.,
385
+ mlp_time_embed=False,
386
+ text_dim=config.get_ref('text_dim'),
387
+ num_text_tokens=77,
388
+ clip_img_dim=config.get_ref('clip_img_dim'),
389
+ use_checkpoint=True
390
+ )
391
+
392
+ config.sample = d(
393
+ sample_steps=50,
394
+ scale=7.,
395
+ t2i_cfg_mode='true_uncond'
396
+ )
397
+
398
+ return config
399
+
400
+
401
+ def sample(mode, prompt, image, sample_steps=50, scale=7.0, seed=None):
402
+ config = get_config()
403
+
404
+ config.nnet_path = "models/uvit_v0.pth"
405
+ config.n_samples = 1
406
+ config.nrow = 1
407
+
408
+ config.mode = mode
409
+ config.prompt = prompt
410
+ config.img = image
411
+
412
+ config.sample.sample_steps = sample_steps
413
+ config.sample.scale = scale
414
+ if seed is not None:
415
+ config.seed = seed
416
+
417
+ sample_images, sample_text = evaluate(config)
418
+ return sample_images, sample_text
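For reference, split() and combine() above pack the 4x64x64 VAE latent (flattened to 16384 values) and the 512-dimensional CLIP image embedding into a single 16896-dimensional vector per sample. A minimal standalone sketch of that packing, with shapes taken from get_config() and an arbitrary batch size:

import torch
import einops

B, C, H, W, clip_img_dim = 2, 4, 64, 64, 512
z = torch.randn(B, C, H, W)                 # VAE latent
clip_img = torch.randn(B, 1, clip_img_dim)  # CLIP image embedding

# combine: flatten both parts and concatenate along the feature axis
x = torch.concat([
    einops.rearrange(z, 'B C H W -> B (C H W)'),
    einops.rearrange(clip_img, 'B L D -> B (L D)'),
], dim=-1)
assert x.shape == (B, C * H * W + clip_img_dim)  # (2, 16896)

# split: undo the packing
z2, clip2 = x.split([C * H * W, clip_img_dim], dim=1)
z2 = einops.rearrange(z2, 'B (C H W) -> B C H W', C=C, H=H, W=W)
clip2 = einops.rearrange(clip2, 'B (L D) -> B L D', L=1, D=clip_img_dim)
assert torch.equal(z2, z) and torch.equal(clip2, clip_img)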
unidiffuser/sample_v0_test.py ADDED
@@ -0,0 +1,786 @@
1
+ import ml_collections
2
+ import torch
3
+ import random
4
+ import utils
5
+ from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
6
+ from absl import logging
7
+ import einops
8
+ import libs.autoencoder
9
+ import libs.clip
10
+ from torchvision.utils import save_image, make_grid
11
+ import torchvision.transforms as standard_transforms
12
+ import numpy as np
13
+ import clip
14
+ from PIL import Image
15
+ import time
16
+
17
+ from typing import Optional, Union, List, Tuple
18
+
19
+ from torch import nn
20
+ from transformers import (
21
+ CLIPFeatureExtractor,
22
+ CLIPProcessor,
23
+ CLIPTextModel,
24
+ CLIPTokenizer,
25
+ CLIPVisionModel,
26
+ GPT2LMHeadModel,
27
+ GPT2Tokenizer,
28
+ )
29
+
30
+ from libs.autoencoder import Encoder, Decoder
31
+ from libs.clip import AbstractEncoder
32
+ from libs.caption_decoder import generate2, generate_beam
33
+
34
+
35
+ # ----Define Testing Versions of Classes----
36
+
37
+
38
+ class TestAutoencoderKL(nn.Module):
39
+ def __init__(self, ddconfig, embed_dim, pretrained_path, scale_factor=0.18215):
40
+ super().__init__()
41
+ print(f'Create autoencoder with scale_factor={scale_factor}')
42
+ self.encoder = Encoder(**ddconfig)
43
+ self.decoder = Decoder(**ddconfig)
44
+ assert ddconfig["double_z"]
45
+ self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
46
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
47
+ self.embed_dim = embed_dim
48
+ self.scale_factor = scale_factor
49
+ m, u = self.load_state_dict(torch.load(pretrained_path, map_location='cpu'))
50
+ assert len(m) == 0 and len(u) == 0
51
+ self.eval()
52
+ self.requires_grad_(False)
53
+
54
+ def encode_moments(self, x):
55
+ h = self.encoder(x)
56
+ moments = self.quant_conv(h)
57
+ return moments
58
+
59
+ def sample(self, moments, noise=None, generator=None, device="cuda"):
60
+ mean, logvar = torch.chunk(moments, 2, dim=1)
61
+ if noise is None:
62
+ # Generate on CPU.
63
+ noise = randn_tensor(mean.shape, generator=generator)
64
+ # Then move to desired device
65
+ noise = noise.to(device)
66
+ logvar = torch.clamp(logvar, -30.0, 20.0)
67
+ std = torch.exp(0.5 * logvar)
68
+ z = mean + std * noise
69
+ z = self.scale_factor * z
70
+ return z
71
+
72
+ def get_moment_params(self, moments):
73
+ mean, logvar = torch.chunk(moments, 2, dim=1)
74
+ return mean, logvar
75
+
76
+ def encode(self, x):
77
+ moments = self.encode_moments(x)
78
+ # z = self.sample(moments)
79
+ # Instead of sampling from the diagonal gaussian, return its mode (mean)
80
+ mean, logvar = self.get_moment_params(moments)
81
+ return mean
82
+
83
+ def decode(self, z):
84
+ z = (1. / self.scale_factor) * z
85
+ z = self.post_quant_conv(z)
86
+ dec = self.decoder(z)
87
+ return dec
88
+
89
+ def forward(self, inputs, fn):
90
+ if fn == 'encode_moments':
91
+ return self.encode_moments(inputs)
92
+ elif fn == 'encode':
93
+ return self.encode(inputs)
94
+ elif fn == 'decode':
95
+ return self.decode(inputs)
96
+ else:
97
+ raise NotImplementedError
98
+
99
+ def freeze(self):
100
+ self.eval()
101
+ self.requires_grad_(False)
102
+
103
+
104
+ # ----Define Testing Utility Functions----
105
+
106
+
107
+ def get_test_autoencoder(pretrained_path, scale_factor=0.18215):
108
+ ddconfig = dict(
109
+ double_z=True,
110
+ z_channels=4,
111
+ resolution=256,
112
+ in_channels=3,
113
+ out_ch=3,
114
+ ch=128,
115
+ ch_mult=[1, 2, 4, 4],
116
+ num_res_blocks=2,
117
+ attn_resolutions=[],
118
+ dropout=0.0
119
+ )
120
+ vae_scale_factor = 2 ** (len(ddconfig['ch_mult']) - 1)
121
+ return TestAutoencoderKL(ddconfig, 4, pretrained_path, scale_factor), vae_scale_factor
122
+
123
+
124
+ # Modified from diffusers.utils.randn_tensor
125
+ def randn_tensor(
126
+ shape: Union[Tuple, List],
127
+ generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
128
+ device: Optional["torch.device"] = None,
129
+ dtype: Optional["torch.dtype"] = None,
130
+ layout: Optional["torch.layout"] = None,
131
+ ):
132
+ """This is a helper function that allows creating random tensors on the desired `device` with the desired `dtype`. When
133
+ passing a list of generators one can seed each batched size individually. If CPU generators are passed the tensor
134
+ will always be created on CPU.
135
+ """
136
+ # device on which tensor is created defaults to device
137
+ rand_device = device
138
+ batch_size = shape[0]
139
+
140
+ layout = layout or torch.strided
141
+ device = device or torch.device("cpu")
142
+
143
+ if generator is not None:
144
+ gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
145
+ if gen_device_type != device.type and gen_device_type == "cpu":
146
+ rand_device = "cpu"
147
+ if device != "mps":
148
+ logging.info(
149
+ f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
150
+ f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
151
+ f" slightly speed up this function by passing a generator that was created on the {device} device."
152
+ )
153
+ elif gen_device_type != device.type and gen_device_type == "cuda":
154
+ raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
155
+
156
+ if isinstance(generator, list):
157
+ shape = (1,) + shape[1:]
158
+ latents = [
159
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
160
+ for i in range(batch_size)
161
+ ]
162
+ latents = torch.cat(latents, dim=0).to(device)
163
+ else:
164
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
165
+
166
+ return latents
167
+
168
+
169
+ # Sample from the autoencoder latent space directly instead of sampling the autoencoder moment.
170
+ def prepare_latents(
171
+ config,
172
+ clip_text_model,
173
+ clip_img_model,
174
+ clip_img_model_preprocess,
175
+ autoencoder,
176
+ vae_scale_factor,
177
+ device,
178
+ ):
179
+ resolution = config.z_shape[-1] * vae_scale_factor
180
+ # Fix device to CPU for reproducibility.
181
+ latent_device = "cpu"
182
+ latent_torch_device = torch.device(latent_device)
183
+ generator = torch.Generator(device=latent_torch_device).manual_seed(config.seed)
184
+
185
+ contexts = randn_tensor((config.n_samples, 77, config.clip_text_dim), generator=generator, device=latent_torch_device)
186
+ img_contexts = randn_tensor((config.n_samples, config.z_shape[0], config.z_shape[1], config.z_shape[2]), generator=generator, device=latent_torch_device)
187
+ clip_imgs = randn_tensor((config.n_samples, 1, config.clip_img_dim), generator=generator, device=latent_torch_device)
188
+
189
+ if config.mode in ['t2i', 't2i2t']:
190
+ prompts = [ config.prompt ] * config.n_samples
191
+ contexts = clip_text_model.encode(prompts)
192
+ elif config.mode in ['i2t', 'i2t2i']:
193
+ from PIL import Image
194
+ img_contexts = []
195
+ clip_imgs = []
196
+
197
+ def get_img_feature(image):
198
+ image = np.array(image).astype(np.uint8)
199
+ image = utils.center_crop(resolution, resolution, image)
200
+ clip_img_feature = clip_img_model.encode_image(clip_img_model_preprocess(Image.fromarray(image)).unsqueeze(0).to(device))
201
+
202
+ image = (image / 127.5 - 1.0).astype(np.float32)
203
+ image = einops.rearrange(image, 'h w c -> 1 c h w')
204
+ image = torch.tensor(image, device=device)
205
+ # Get moments then get the mode of the moment (diagonal Gaussian) distribution
206
+ moments = autoencoder.encode_moments(image)
207
+ # Sample from the moments
208
+ moments = autoencoder.sample(moments, generator=generator, device=device)
209
+
210
+ return clip_img_feature, moments
211
+
212
+ image = Image.open(config.img).convert('RGB')
213
+ clip_img, img_context = get_img_feature(image)
214
+
215
+ img_contexts.append(img_context)
216
+ clip_imgs.append(clip_img)
217
+ img_contexts = img_contexts * config.n_samples
218
+ clip_imgs = clip_imgs * config.n_samples
219
+
220
+ img_contexts = torch.concat(img_contexts, dim=0)
221
+ clip_imgs = torch.stack(clip_imgs, dim=0)
222
+
223
+ contexts = contexts.to(device)
224
+ img_contexts = img_contexts.to(device)
225
+ clip_imgs = clip_imgs.to(device)
226
+ return contexts, img_contexts, clip_imgs
227
+
228
+
229
+ # ----END----
230
+
231
+
232
+ def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
233
+ _betas = (
234
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
235
+ )
236
+ return _betas.numpy()
237
+
238
+
239
+ def prepare_contexts(config, clip_text_model, clip_img_model, clip_img_model_preprocess, autoencoder):
240
+ resolution = config.z_shape[-1] * 8
241
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
242
+
243
+ contexts = torch.randn(config.n_samples, 77, config.clip_text_dim).to(device)
244
+ img_contexts = torch.randn(config.n_samples, 2 * config.z_shape[0], config.z_shape[1], config.z_shape[2])
245
+ clip_imgs = torch.randn(config.n_samples, 1, config.clip_img_dim)
246
+
247
+ if config.mode in ['t2i', 't2i2t']:
248
+ prompts = [ config.prompt ] * config.n_samples
249
+ contexts = clip_text_model.encode(prompts)
250
+
251
+ elif config.mode in ['i2t', 'i2t2i']:
252
+ from PIL import Image
253
+ img_contexts = []
254
+ clip_imgs = []
255
+
256
+ def get_img_feature(image):
257
+ image = np.array(image).astype(np.uint8)
258
+ image = utils.center_crop(resolution, resolution, image)
259
+ # clip_img_feature = clip_img_model.encode_image(clip_img_model_preprocess(Image.fromarray(image)).unsqueeze(0).to(device))
260
+ clip_inputs = clip_img_model_preprocess(images=image, return_tensors="pt")
261
+ clip_img_feature = clip_img_model(**clip_inputs).image_embeds
262
+
263
+ image = (image / 127.5 - 1.0).astype(np.float32)
264
+ image = einops.rearrange(image, 'h w c -> 1 c h w')
265
+ image = torch.tensor(image, device=device)
266
+ moments = autoencoder.encode_moments(image)
267
+
268
+ return clip_img_feature, moments
269
+
270
+ image = Image.open(config.img).convert('RGB')
271
+ clip_img, img_context = get_img_feature(image)
272
+
273
+ img_contexts.append(img_context)
274
+ clip_imgs.append(clip_img)
275
+ img_contexts = img_contexts * config.n_samples
276
+ clip_imgs = clip_imgs * config.n_samples
277
+
278
+ img_contexts = torch.concat(img_contexts, dim=0)
279
+ clip_imgs = torch.stack(clip_imgs, dim=0)
280
+
281
+ return contexts, img_contexts, clip_imgs
282
+
283
+
284
+ def unpreprocess(v): # to B C H W and [0, 1]
285
+ v = 0.5 * (v + 1.)
286
+ v.clamp_(0., 1.)
287
+ return v
288
+
289
+
290
+ def set_seed(seed: int):
291
+ random.seed(seed)
292
+ np.random.seed(seed)
293
+ torch.manual_seed(seed)
294
+ torch.cuda.manual_seed_all(seed)
295
+
296
+
297
+ def evaluate(config):
298
+ if config.get('benchmark', False):
299
+ torch.backends.cudnn.benchmark = True
300
+ torch.backends.cudnn.deterministic = False
301
+
302
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu'
303
+ device = config.sample.device
304
+ torch_device = torch.device(device)
305
+ set_seed(config.seed)
306
+
307
+ # Instantiate generator
308
+ generator = torch.Generator(device=torch_device).manual_seed(config.seed)
309
+
310
+ config = ml_collections.FrozenConfigDict(config)
311
+ # utils.set_logger(log_level='info')
312
+ # utils.set_logger(log_level='debug', fname="./logs/test.txt")
313
+ utils.set_logger(log_level=config.sample.log_level)
314
+
315
+ _betas = stable_diffusion_beta_schedule()
316
+ N = len(_betas)
317
+
318
+ nnet = utils.get_nnet(**config.nnet)
319
+ logging.info(f'load nnet from {config.nnet_path}')
320
+ nnet.load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
321
+ nnet.to(device)
322
+ nnet.eval()
323
+
324
+ use_caption_decoder = config.text_dim < config.clip_text_dim or config.mode != 't2i'
325
+ if use_caption_decoder:
326
+ from libs.caption_decoder import CaptionDecoder
327
+ caption_decoder = CaptionDecoder(device=device, **config.caption_decoder)
328
+ else:
329
+ caption_decoder = None
330
+
331
+ clip_text_model = libs.clip.FrozenCLIPEmbedder(device=device)
332
+ clip_text_model.eval()
333
+ clip_text_model.to(device)
334
+
335
+ # autoencoder = libs.autoencoder.get_model(**config.autoencoder)
336
+ # Load test autoencoder
337
+ autoencoder, vae_scale_factor = get_test_autoencoder(**config.autoencoder)
338
+ autoencoder.to(device)
339
+ # print(f"VAE scale factor: {vae_scale_factor}")
340
+
341
+ clip_img_model, clip_img_model_preprocess = clip.load("ViT-B/32", device=device, jit=False)
342
+
343
+ empty_context = clip_text_model.encode([''])[0]
344
+
345
+ def split(x):
346
+ C, H, W = config.z_shape
347
+ z_dim = C * H * W
348
+ z, clip_img = x.split([z_dim, config.clip_img_dim], dim=1)
349
+ z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
350
+ clip_img = einops.rearrange(clip_img, 'B (L D) -> B L D', L=1, D=config.clip_img_dim)
351
+ return z, clip_img
352
+
353
+
354
+ def combine(z, clip_img):
355
+ z = einops.rearrange(z, 'B C H W -> B (C H W)')
356
+ clip_img = einops.rearrange(clip_img, 'B L D -> B (L D)')
357
+ return torch.concat([z, clip_img], dim=-1)
358
+
359
+
360
+ def t2i_nnet(x, timesteps, text): # text is the low-dimensional version of the CLIP text embedding
361
+ """
362
+ 1. calculate the conditional model output
363
+ 2. calculate unconditional model output
364
+ config.sample.t2i_cfg_mode == 'empty_token': using the original cfg with the empty string
365
+ config.sample.t2i_cfg_mode == 'true_uncond': using the unconditional model learned by our method
366
+ 3. return linear combination of conditional output and unconditional output
367
+ """
368
+ z, clip_img = split(x)
369
+
370
+ t_text = torch.zeros(timesteps.size(0), dtype=torch.int, device=device)
371
+
372
+ z_out, clip_img_out, text_out = nnet(z, clip_img, text=text, t_img=timesteps, t_text=t_text)
373
+ logging.debug(f"Conditional VAE out: {z_out}")
374
+ logging.debug(f"Conditional VAE out shape: {z_out.shape}")
375
+ logging.debug(f"Conditional CLIP out: {clip_img_out}")
376
+ logging.debug(f"Conditional CLIP out shape: {clip_img_out.shape}")
377
+ x_out = combine(z_out, clip_img_out)
378
+
379
+ if config.sample.scale == 0.:
380
+ return x_out
381
+
382
+ if config.sample.t2i_cfg_mode == 'empty_token':
383
+ _empty_context = einops.repeat(empty_context, 'L D -> B L D', B=x.size(0))
384
+ if use_caption_decoder:
385
+ _empty_context = caption_decoder.encode_prefix(_empty_context)
386
+ z_out_uncond, clip_img_out_uncond, text_out_uncond = nnet(z, clip_img, text=_empty_context, t_img=timesteps, t_text=t_text)
387
+ x_out_uncond = combine(z_out_uncond, clip_img_out_uncond)
388
+ elif config.sample.t2i_cfg_mode == 'true_uncond':
389
+ # text_N = torch.randn_like(text) # 3 other possible choices
390
+ text_N = randn_tensor(text.shape, generator=generator, device=torch_device)
391
+ logging.debug(f"Unconditional random text: {text_N}")
392
+ logging.debug(f"Unconditional random text shape: {text_N.shape}")
393
+ z_out_uncond, clip_img_out_uncond, text_out_uncond = nnet(z, clip_img, text=text_N, t_img=timesteps, t_text=torch.ones_like(timesteps) * N)
394
+ logging.debug(f"Unconditional VAE out: {z_out_uncond}")
395
+ logging.debug(f"Unconditional VAE out shape: {z_out_uncond.shape}")
396
+ logging.debug(f"Unconditional CLIP out: {clip_img_out_uncond}")
397
+ logging.debug(f"Unconditional CLIP out shape: {clip_img_out_uncond.shape}")
398
+ x_out_uncond = combine(z_out_uncond, clip_img_out_uncond)
399
+ else:
400
+ raise NotImplementedError
401
+
402
+ return x_out + config.sample.scale * (x_out - x_out_uncond)
403
+
404
+
405
+ def i_nnet(x, timesteps):
406
+ z, clip_img = split(x)
407
+ # text = torch.randn(x.size(0), 77, config.text_dim, device=device)
408
+ text = randn_tensor((x.size(0), 77, config.text_dim), generator=generator, device=torch_device)
409
+ t_text = torch.ones_like(timesteps) * N
410
+ z_out, clip_img_out, text_out = nnet(z, clip_img, text=text, t_img=timesteps, t_text=t_text)
411
+ x_out = combine(z_out, clip_img_out)
412
+ return x_out
413
+
414
+ def t_nnet(x, timesteps):
415
+ # z = torch.randn(x.size(0), *config.z_shape, device=device)
416
+ # clip_img = torch.randn(x.size(0), 1, config.clip_img_dim, device=device)
417
+ z = randn_tensor((x.size(0), *config.z_shape), generator=generator, device=torch_device)
418
+ clip_img = randn_tensor((x.size(0), 1, config.clip_img_dim), generator=generator, device=torch_device)
419
+ z_out, clip_img_out, text_out = nnet(z, clip_img, text=x, t_img=torch.ones_like(timesteps) * N, t_text=timesteps)
420
+ return text_out
421
+
422
+ def i2t_nnet(x, timesteps, z, clip_img):
423
+ """
424
+ 1. calculate the conditional model output
425
+ 2. calculate unconditional model output
426
+ 3. return linear combination of conditional output and unconditional output
427
+ """
428
+ t_img = torch.zeros(timesteps.size(0), dtype=torch.int, device=device)
429
+
430
+ z_out, clip_img_out, text_out = nnet(z, clip_img, text=x, t_img=t_img, t_text=timesteps)
431
+
432
+ if config.sample.scale == 0.:
433
+ return text_out
434
+
435
+ # z_N = torch.randn_like(z) # 3 other possible choices
436
+ # clip_img_N = torch.randn_like(clip_img)
437
+ z_N = randn_tensor(z.shape, generator=generator, device=torch_device)
438
+ clip_img_N = randn_tensor(clip_img.shape, generator=generator, device=torch_device)
439
+ z_out_uncond, clip_img_out_uncond, text_out_uncond = nnet(z_N, clip_img_N, text=x, t_img=torch.ones_like(timesteps) * N, t_text=timesteps)
440
+
441
+ return text_out + config.sample.scale * (text_out - text_out_uncond)
442
+
443
+ def split_joint(x):
444
+ C, H, W = config.z_shape
445
+ z_dim = C * H * W
446
+ z, clip_img, text = x.split([z_dim, config.clip_img_dim, 77 * config.text_dim], dim=1)
447
+ z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
448
+ clip_img = einops.rearrange(clip_img, 'B (L D) -> B L D', L=1, D=config.clip_img_dim)
449
+ text = einops.rearrange(text, 'B (L D) -> B L D', L=77, D=config.text_dim)
450
+ return z, clip_img, text
451
+
452
+ def combine_joint(z, clip_img, text):
453
+ z = einops.rearrange(z, 'B C H W -> B (C H W)')
454
+ clip_img = einops.rearrange(clip_img, 'B L D -> B (L D)')
455
+ text = einops.rearrange(text, 'B L D -> B (L D)')
456
+ return torch.concat([z, clip_img, text], dim=-1)
457
+
458
+ def joint_nnet(x, timesteps):
459
+ logging.debug(f"Timestep: {timesteps}")
460
+ z, clip_img, text = split_joint(x)
461
+ z_out, clip_img_out, text_out = nnet(z, clip_img, text=text, t_img=timesteps, t_text=timesteps)
462
+ logging.debug(f"Conditional VAE out: {z_out}")
463
+ logging.debug(f"Conditional VAE out shape: {z_out.shape}")
464
+ logging.debug(f"Conditional CLIP out: {clip_img_out}")
465
+ logging.debug(f"Conditional CLIP out shape: {clip_img_out.shape}")
466
+ logging.debug(f"Conditional text out: {text_out}")
467
+ logging.debug(f"Conditional text out shape: {text_out.shape}")
468
+ x_out = combine_joint(z_out, clip_img_out, text_out)
469
+
470
+ if config.sample.scale == 0.:
471
+ return x_out
472
+
473
+ # z_noise = torch.randn(x.size(0), *config.z_shape, device=device)
474
+ # clip_img_noise = torch.randn(x.size(0), 1, config.clip_img_dim, device=device)
475
+ # text_noise = torch.randn(x.size(0), 77, config.text_dim, device=device)
476
+ z_noise = randn_tensor((x.size(0), *config.z_shape), generator=generator, device=torch_device, dtype=z_out.dtype)
477
+ clip_img_noise = randn_tensor((x.size(0), 1, config.clip_img_dim), generator=generator, device=torch_device, dtype=clip_img_out.dtype)
478
+ text_noise = randn_tensor((x.size(0), 77, config.text_dim), generator=generator, device=torch_device, dtype=text_out.dtype)
479
+
480
+ _, _, text_out_uncond = nnet(z_noise, clip_img_noise, text=text, t_img=torch.ones_like(timesteps) * N, t_text=timesteps)
481
+ logging.debug(f"Unconditional text out: {text_out_uncond}")
482
+ logging.debug(f"Unconditional text out shape: {text_out_uncond.shape}")
483
+ z_out_uncond, clip_img_out_uncond, _ = nnet(z, clip_img, text=text_noise, t_img=timesteps, t_text=torch.ones_like(timesteps) * N)
484
+ logging.debug(f"Unconditional VAE out: {z_out_uncond}")
485
+ logging.debug(f"Unconditional VAE out shape: {z_out_uncond.shape}")
486
+ logging.debug(f"Unconditional CLIP out: {clip_img_out_uncond}")
487
+ logging.debug(f"Unconditional CLIP out shape: {clip_img_out_uncond.shape}")
488
+
489
+ x_out_uncond = combine_joint(z_out_uncond, clip_img_out_uncond, text_out_uncond)
490
+
491
+ return x_out + config.sample.scale * (x_out - x_out_uncond)
492
+
493
+ @torch.cuda.amp.autocast()
494
+ def encode(_batch):
495
+ return autoencoder.encode(_batch)
496
+
497
+ @torch.cuda.amp.autocast()
498
+ def decode(_batch):
499
+ return autoencoder.decode(_batch)
500
+
501
+
502
+ logging.info(config.sample)
503
+ logging.info(f'N={N}')
504
+
505
+ # contexts, img_contexts, clip_imgs = prepare_contexts(config, clip_text_model, clip_img_model, clip_img_model_preprocess, autoencoder)
506
+ contexts, img_contexts, clip_imgs = prepare_latents(
507
+ config,
508
+ clip_text_model,
509
+ clip_img_model,
510
+ clip_img_model_preprocess,
511
+ autoencoder,
512
+ vae_scale_factor,
513
+ device,
514
+ )
515
+ logging.debug(f"Text latents: {contexts}")
516
+ logging.debug(f"Text latents shape: {contexts.shape}")
517
+
518
+ contexts = contexts # the clip embedding of conditioned texts
519
+ contexts_low_dim = contexts if not use_caption_decoder else caption_decoder.encode_prefix(contexts) # the low dimensional version of the contexts, which is the input to the nnet
520
+
521
+ logging.debug(f"Low dim text latents: {contexts_low_dim}")
522
+ logging.debug(f"Low dim text latents shape: {contexts_low_dim.shape}")
523
+
524
+ img_contexts = img_contexts # img_contexts is the autoencoder moment
525
+ # z_img = autoencoder.sample(img_contexts, generator=cpu_generator, device=device)
526
+ z_img = img_contexts # sample autoencoder latents directly, no need to call sample()
527
+ clip_imgs = clip_imgs # the clip embedding of conditioned image
528
+
529
+ logging.debug(f"VAE latents: {z_img}")
530
+ logging.debug(f"VAE latents shape: {z_img.shape}")
531
+ logging.debug(f"CLIP latents: {clip_imgs}")
532
+ logging.debug(f"CLIP latents shape: {clip_imgs.shape}")
533
+
534
+ if config.mode in ['t2i', 't2i2t']:
535
+ _n_samples = contexts_low_dim.size(0)
536
+ elif config.mode in ['i2t', 'i2t2i']:
537
+ _n_samples = img_contexts.size(0)
538
+ else:
539
+ _n_samples = config.n_samples
540
+
541
+
542
+ def sample_fn(mode, **kwargs):
543
+
544
+ # _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
545
+ # _clip_img_init = torch.randn(_n_samples, 1, config.clip_img_dim, device=device)
546
+ # _text_init = torch.randn(_n_samples, 77, config.text_dim, device=device)
547
+ _z_init = randn_tensor((_n_samples, *config.z_shape), generator=generator, device=torch_device)
548
+ _clip_img_init = randn_tensor((_n_samples, 1, config.clip_img_dim), generator=generator, device=torch_device)
549
+ _text_init = randn_tensor((_n_samples, 77, config.text_dim), generator=generator, device=torch_device)
550
+ if mode == 'joint':
551
+ _x_init = combine_joint(_z_init, _clip_img_init, _text_init)
552
+ elif mode in ['t2i', 'i']:
553
+ _x_init = combine(_z_init, _clip_img_init)
554
+ elif mode in ['i2t', 't']:
555
+ _x_init = _text_init
556
+
557
+ logging.debug(f"Latents: {_x_init}")
558
+ logging.debug(f"Latents shape: {_x_init.shape}")
559
+
560
+ noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
561
+
562
+ def model_fn(x, t_continuous):
563
+ t = t_continuous * N
564
+ if mode == 'joint':
565
+ return joint_nnet(x, t)
566
+ elif mode == 't2i':
567
+ return t2i_nnet(x, t, **kwargs)
568
+ elif mode == 'i2t':
569
+ return i2t_nnet(x, t, **kwargs)
570
+ elif mode == 'i':
571
+ return i_nnet(x, t)
572
+ elif mode == 't':
573
+ return t_nnet(x, t)
574
+
575
+ dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
576
+ with torch.no_grad():
577
+ with torch.autocast(device_type=device):
578
+ start_time = time.time()
579
+ x = dpm_solver.sample(_x_init, steps=config.sample.sample_steps, eps=1. / N, T=1.)
580
+ end_time = time.time()
581
+ print(f'\ngenerating {_n_samples} samples with {config.sample.sample_steps} steps took {end_time - start_time:.2f}s')
582
+
583
+ # os.makedirs(config.output_path, exist_ok=True)
584
+ if mode == 'joint':
585
+ _z, _clip_img, _text = split_joint(x)
586
+ return _z, _clip_img, _text
587
+ elif mode in ['t2i', 'i']:
588
+ _z, _clip_img = split(x)
589
+ return _z, _clip_img
590
+ elif mode in ['i2t', 't']:
591
+ return x
592
+
593
+ def test_sample_fn(mode, **kwargs):
594
+ if mode == 'joint':
595
+ _x_init = combine_joint(z_img, clip_imgs, contexts_low_dim)
596
+ elif mode in ['t2i', 'i']:
597
+ _x_init = combine(z_img, clip_imgs)
598
+ elif mode in ['i2t', 't']:
599
+ _x_init = contexts_low_dim
600
+
601
+ logging.debug(f"Latents: {_x_init}")
602
+ logging.debug(f"Latents shape: {_x_init.shape}")
603
+
604
+ noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
605
+
606
+ def model_fn(x, t_continuous):
607
+ t = t_continuous * N
608
+ if mode == 'joint':
609
+ noise_pred = joint_nnet(x, t)
610
+ logging.debug(f"Noise pred for time {t}: {noise_pred}")
611
+ logging.debug(f"Noise pred for time {t} shape: {noise_pred.shape}")
612
+ return noise_pred
613
+ # return joint_nnet(x, t)
614
+ elif mode == 't2i':
615
+ noise_pred = t2i_nnet(x, t, **kwargs)
616
+ logging.debug(f"Noise pred for time {t}: {noise_pred}")
617
+ logging.debug(f"Noise pred for time {t} shape: {noise_pred.shape}")
618
+ return noise_pred
619
+ # return t2i_nnet(x, t, **kwargs)
620
+ elif mode == 'i2t':
621
+ return i2t_nnet(x, t, **kwargs)
622
+ elif mode == 'i':
623
+ return i_nnet(x, t)
624
+ elif mode == 't':
625
+ return t_nnet(x, t)
626
+
627
+ dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
628
+ with torch.no_grad():
629
+ # Remove autocast to run in full precision for testing on CPU
630
+ start_time = time.time()
631
+ x = dpm_solver.sample(_x_init, steps=config.sample.sample_steps, eps=1. / N, T=1.)
632
+ end_time = time.time()
633
+ print(f'\ngenerating {_n_samples} samples with {config.sample.sample_steps} steps took {end_time - start_time:.2f}s')
634
+
635
+ logging.debug(f"Full UNet sample: {x}")
636
+ logging.debug(f"Full UNet sample shape: {x.shape}")
637
+
638
+ # os.makedirs(config.output_path, exist_ok=True)
639
+ if mode == 'joint':
640
+ _z, _clip_img, _text = split_joint(x)
641
+ return _z, _clip_img, _text
642
+ elif mode in ['t2i', 'i']:
643
+ _z, _clip_img = split(x)
644
+ return _z, _clip_img
645
+ elif mode in ['i2t', 't']:
646
+ return x
647
+
648
+ output_images = None
649
+ output_text = None
650
+
651
+ if config.mode in ['joint']:
652
+ # _z, _clip_img, _text = sample_fn(config.mode)
653
+ _z, _clip_img, _text = test_sample_fn(config.mode)
654
+
655
+ logging.debug(f"Text output: {_text}")
656
+ logging.debug(f"Text output shape: {_text.shape}")
657
+ logging.debug(f"VAE output: {_z}")
658
+ logging.debug(f"VAE output shape: {_z.shape}")
659
+ logging.debug(f"CLIP output: {_clip_img}")
660
+ logging.debug(f"CLIP output shape: {_clip_img.shape}")
661
+
662
+ samples = unpreprocess(decode(_z))
663
+
664
+ logging.debug(f"VAE decoded sample: {samples}")
665
+ logging.debug(f"VAE decoded sample shape: {samples.shape}")
666
+
667
+ prompts = caption_decoder.generate_captions(_text)
668
+
669
+ logging.debug(f"Generated text: {prompts}")
670
+
671
+ output_images = samples
672
+ output_text = prompts
673
+
674
+ elif config.mode in ['t2i', 'i', 'i2t2i']:
675
+ if config.mode == 't2i':
676
+ # _z, _clip_img = sample_fn(config.mode, text=contexts_low_dim) # conditioned on the text embedding
677
+ _z, _clip_img = test_sample_fn(config.mode, text=contexts_low_dim)
678
+
679
+ logging.debug(f"VAE output: {_z}")
680
+ logging.debug(f"VAE output shape: {_z.shape}")
681
+ logging.debug(f"CLIP output: {_clip_img}")
682
+ logging.debug(f"CLIP output shape: {_clip_img.shape}")
683
+ elif config.mode == 'i':
684
+ # _z, _clip_img = sample_fn(config.mode)
685
+ _z, _clip_img = test_sample_fn(config.mode)
686
+ elif config.mode == 'i2t2i':
687
+ _text = sample_fn('i2t', z=z_img, clip_img=clip_imgs) # conditioned on the image embedding
688
+ _z, _clip_img = sample_fn('t2i', text=_text)
689
+ samples = unpreprocess(decode(_z))
690
+ output_images = samples
691
+
692
+
693
+ elif config.mode in ['i2t', 't', 't2i2t']:
694
+ if config.mode == 'i2t':
695
+ # _text = sample_fn(config.mode, z=z_img, clip_img=clip_imgs) # conditioned on the image embedding
696
+ _text = test_sample_fn(config.mode, z=z_img, clip_img=clip_imgs) # conditioned on the image embedding
697
+ elif config.mode == 't':
698
+ # _text = sample_fn(config.mode)
699
+ _text = test_sample_fn(config.mode)
700
+ elif config.mode == 't2i2t':
701
+ _z, _clip_img = sample_fn('t2i', text=contexts_low_dim)
702
+ _text = sample_fn('i2t', z=_z, clip_img=_clip_img)
703
+ samples = caption_decoder.generate_captions(_text)
704
+ logging.info(samples)
705
+ output_text = samples
706
+
707
+ print(f'\nGPU memory usage: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB')
708
+ # print(f'\nresults are saved in {os.path.join(config.output_path, config.mode)} :)')
709
+
710
+ return output_images, output_text
711
+
712
+
713
+ def d(**kwargs):
714
+ """Helper for creating a config dict."""
715
+ return ml_collections.ConfigDict(initial_dictionary=kwargs)
716
+
717
+
718
+ def get_config():
719
+ config = ml_collections.ConfigDict()
720
+
721
+ config.seed = 0
722
+ config.pred = 'noise_pred'
723
+ config.z_shape = (4, 64, 64)
724
+ config.clip_img_dim = 512
725
+ config.clip_text_dim = 768
726
+ config.text_dim = 64 # reduce dimension
727
+
728
+ config.autoencoder = d(
729
+ pretrained_path='models/autoencoder_kl.pth',
730
+ )
731
+
732
+ config.caption_decoder = d(
733
+ pretrained_path="models/caption_decoder.pth",
734
+ hidden_dim=config.get_ref('text_dim')
735
+ )
736
+
737
+ config.nnet = d(
738
+ name='uvit_multi_post_ln',
739
+ img_size=64,
740
+ in_chans=4,
741
+ patch_size=2,
742
+ embed_dim=1536,
743
+ depth=30,
744
+ num_heads=24,
745
+ mlp_ratio=4,
746
+ qkv_bias=False,
747
+ pos_drop_rate=0.,
748
+ drop_rate=0.,
749
+ attn_drop_rate=0.,
750
+ mlp_time_embed=False,
751
+ text_dim=config.get_ref('text_dim'),
752
+ num_text_tokens=77,
753
+ clip_img_dim=config.get_ref('clip_img_dim'),
754
+ use_checkpoint=True
755
+ )
756
+
757
+ config.sample = d(
758
+ sample_steps=3,
759
+ scale=7.,
760
+ t2i_cfg_mode='true_uncond',
761
+ device="cuda",
762
+ log_level="debug",
763
+ log_dir=None,
764
+ )
765
+
766
+ return config
767
+
768
+
769
+ def sample(mode, prompt, image, sample_steps=50, scale=7.0, seed=None):
770
+ config = get_config()
771
+
772
+ config.nnet_path = "models/uvit_v0.pth"
773
+ config.n_samples = 1
774
+ config.nrow = 1
775
+
776
+ config.mode = mode
777
+ config.prompt = prompt
778
+ config.img = image
779
+
780
+ config.sample.sample_steps = sample_steps
781
+ config.sample.scale = scale
782
+ if seed is not None:
783
+ config.seed = seed
784
+
785
+ sample_images, sample_text = evaluate(config)
786
+ return sample_images, sample_text
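As a usage note, the reproducibility that sample_v0_test.py aims for comes from drawing every latent through randn_tensor with an explicitly seeded torch.Generator rather than the global RNG, generating on CPU and moving the result to the target device afterwards. A minimal sketch of that pattern, independent of the UniDiffuser-specific code:

import torch

def seeded_noise(shape, seed, device='cpu'):
    # A dedicated CPU generator makes the draw independent of global RNG state
    # and of the target device; the tensor is moved only after sampling.
    generator = torch.Generator(device='cpu').manual_seed(seed)
    return torch.randn(shape, generator=generator).to(device)

a = seeded_noise((1, 4, 64, 64), seed=0)
b = seeded_noise((1, 4, 64, 64), seed=0)
assert torch.equal(a, b)  # same seed, same latents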