AlexKM committed
Commit 6be96ac
1 Parent(s): d522f59

Upload predict.py

Files changed (1)
  1. predict.py +730 -0
predict.py ADDED
@@ -0,0 +1,730 @@
+ """
+ Clone the following repositories if you haven't already:
+ - git clone 'https://github.com/openai/CLIP'
+ - git clone 'https://github.com/CompVis/taming-transformers'
+ """
+
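+ # Assumed working-directory layout, inferred from the imports and the get_args()
+ # defaults below (an illustrative sketch, not a requirement stated in this commit):
+ #
+ #   CLIP/                     cloned from openai/CLIP
+ #   taming-transformers/      cloned from CompVis/taming-transformers
+ #   checkpoints/vqgan_imagenet_f16_16384.yaml
+ #   checkpoints/vqgan_imagenet_f16_16384.ckpt
+ #   predict.py
+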
+ import sys
+ import tempfile
+ import warnings
+ import numpy as np
+ from pathlib import Path
+ from urllib.request import urlopen  # used when init_image is a URL
+ import argparse
+ import torch
+ from torch import nn, optim
+ from torch.nn import functional as F
+ from torchvision import transforms
+ from torchvision.transforms import functional as TF
+ from torch.cuda import get_device_properties
+ from omegaconf import OmegaConf
+ from torch_optimizer import DiffGrad, AdamP, RAdam
+ import kornia.augmentation as K
+ import imageio
+ from tqdm import tqdm
+ import cog
+ from CLIP import clip
+ from PIL import ImageFile, Image, PngImagePlugin, ImageChops
+
+ sys.path.append("taming-transformers")
+ from taming.models import cond_transformer, vqgan
+
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
+ torch.backends.cudnn.benchmark = False
+ warnings.filterwarnings("ignore")
+
+
+ class ReplaceGrad(torch.autograd.Function):
+     # Identity in the forward pass; gradients are routed to x_backward instead of x_forward.
+     @staticmethod
+     def forward(ctx, x_forward, x_backward):
+         ctx.shape = x_backward.shape
+         return x_forward
+
+     @staticmethod
+     def backward(ctx, grad_in):
+         return None, grad_in.sum_to_size(ctx.shape)
+
+
+ class ClampWithGrad(torch.autograd.Function):
+     # Clamps in the forward pass; the backward pass zeroes gradient components that
+     # would push out-of-range values even further outside [min, max].
+     @staticmethod
+     def forward(ctx, input, min, max):
+         ctx.min = min
+         ctx.max = max
+         ctx.save_for_backward(input)
+         return input.clamp(min, max)
+
+     @staticmethod
+     def backward(ctx, grad_in):
+         (input,) = ctx.saved_tensors
+         return (
+             grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0),
+             None,
+             None,
+         )
+
+
+ replace_grad = ReplaceGrad.apply
+ clamp_with_grad = ClampWithGrad.apply
+
+
+ class Predictor(cog.Predictor):
+     def setup(self):
+         self.device = torch.device("cuda:0")
+         # Check for GPU and reduce the default image size if low VRAM
+         default_image_size = 512  # >8GB VRAM
+         if not torch.cuda.is_available():
+             default_image_size = 256  # no GPU found
+         elif (
+             get_device_properties(0).total_memory <= 2 ** 33
+         ):  # 2 ** 33 = 8,589,934,592 bytes = 8 GB
+             default_image_size = 318  # <8GB VRAM
+
+         self.args = get_args()
+         self.args.size = [default_image_size, default_image_size]
+         self.model = load_vqgan_model(
+             self.args.vqgan_config, self.args.vqgan_checkpoint
+         ).to(self.device)
+         print("Model loaded!")
+         jit = True if float(torch.__version__[:3]) < 1.8 else False
+         self.perceptor = (
+             clip.load(self.args.clip_model, jit=jit)[0]
+             .eval()
+             .requires_grad_(False)
+             .to(self.device)
+         )
+         cut_size = self.perceptor.visual.input_resolution
+         # choose latest Cutout class as default
+         self.make_cutouts = MakeCutouts(
+             cut_size, self.args.cutn, self.args, cut_pow=self.args.cut_pow
+         )
+
+         self.z_min = self.model.quantize.embedding.weight.min(dim=0).values[
+             None, :, None, None
+         ]
+         self.z_max = self.model.quantize.embedding.weight.max(dim=0).values[
+             None, :, None, None
+         ]
+
+         print("Using device:", self.device)
+         print("Optimising using:", self.args.optimiser)
+
+     @cog.input(
+         "image",
+         type=Path,
+         default=None,
+         help="Initial image (optional). When provided, the prompts are used to create a 'style transfer' effect",
+     )
+     @cog.input(
+         "prompts",
+         type=str,
+         default="A cute, smiling, Nerdy Rodent",
+         help="Prompts for generating images. Multiple prompts can be separated by a pipe: |",
+     )
+     @cog.input(
+         "iterations",
+         type=int,
+         default=300,
+         help="Total iterations for generating images. Use fewer iterations when an initial image is provided",
+     )
+     @cog.input(
+         "display_frequency",
+         type=int,
+         default=20,
+         help="Display frequency for intermediate generated images",
+     )
+     def predict(self, image, prompts, iterations, display_frequency):
+         # gumbel is False
+         e_dim = self.model.quantize.e_dim
+         n_toks = self.model.quantize.n_e
+         f = 2 ** (self.model.decoder.num_resolutions - 1)
+         toksX, toksY = self.args.size[0] // f, self.args.size[1] // f
+         sideX, sideY = toksX * f, toksY * f
+
+         if image is not None:
+             self.args.init_image = str(image)
+             self.args.step_size = 0.25
+             if "http" in self.args.init_image:
+                 img = Image.open(urlopen(self.args.init_image))
+             else:
+                 img = Image.open(self.args.init_image)
+             pil_image = img.convert("RGB")
+             pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
+             pil_tensor = TF.to_tensor(pil_image)
+             z, *_ = self.model.encode(pil_tensor.to(self.device).unsqueeze(0) * 2 - 1)
+         else:
+             one_hot = F.one_hot(
+                 torch.randint(n_toks, [toksY * toksX], device=self.device), n_toks
+             ).float()
+             # gumbel is False
+             z = one_hot @ self.model.quantize.embedding.weight
+             z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
+
+         z_orig = z.clone()
+         z.requires_grad_(True)
+
+         self.opt = get_opt(self.args.optimiser, self.args.step_size, z)
+
+         self.args.display_freq = display_frequency
+         self.args.max_iterations = iterations
+
+         story_phrases = [phrase.strip() for phrase in prompts.split("^")]
+
+         # Make a list of all phrases
+         all_phrases = []
+         for phrase in story_phrases:
+             all_phrases.append(phrase.split("|"))
+
+         # First phrase
+         prompts = all_phrases[0]
+
+         pMs = []
+         for prompt in prompts:
+             txt, weight, stop = split_prompt(prompt)
+             embed = self.perceptor.encode_text(
+                 clip.tokenize(txt).to(self.device)
+             ).float()
+             pMs.append(Prompt(embed, weight, stop).to(self.device))
+         # args.image_prompts is None for now
+         # args.noise_prompt_seeds, args.noise_prompt_weights None for now
+         print(f"Using text prompts: {prompts}")
+         if self.args.init_image:
+             print(f"Using initial image: {self.args.init_image}")
+
+         if self.args.seed is None:
+             seed = torch.seed()
+         else:
+             seed = self.args.seed
+         torch.manual_seed(seed)
+         print(f"Using seed: {seed}")
+         i = 0  # Iteration counter
+         # j = 0  # Zoom video frame counter
+         # p = 1  # Phrase counter
+         # smoother = 0  # Smoother counter
+         # this_video_frame = 0  # for video styling
+
+         out_path = Path(tempfile.mkdtemp()) / "out.png"
+         # Do it
+         for i in range(1, self.args.max_iterations + 1):
+             self.opt.zero_grad(set_to_none=True)
+             lossAll = ascend_txt(
+                 i, z, self.perceptor, self.args, self.model, self.make_cutouts, pMs
+             )
+
+             if i % self.args.display_freq == 0 and i != self.args.max_iterations:
+                 yield checkin(i, lossAll, prompts, self.model, z, out_path)
+
+             loss = sum(lossAll)
+             loss.backward()
+             self.opt.step()
+
+             # with torch.no_grad():
+             with torch.inference_mode():
+                 z.copy_(z.maximum(self.z_min).minimum(self.z_max))
+
+             # Ready to stop yet?
+             if i == self.args.max_iterations:
+                 yield checkin(i, lossAll, prompts, self.model, z, out_path)
+
+
+ @torch.inference_mode()
+ def checkin(i, losses, prompts, model, z, outpath):
+     losses_str = ", ".join(f"{loss.item():g}" for loss in losses)
+     tqdm.write(f"i: {i}, loss: {sum(losses).item():g}, losses: {losses_str}")
+     out = synth(z, model)
+     info = PngImagePlugin.PngInfo()
+     info.add_text("comment", f"{prompts}")
+     TF.to_pil_image(out[0].cpu()).save(str(outpath), pnginfo=info)
+     return outpath
+
+
+ def get_args():
+     vq_parser = argparse.ArgumentParser(description="Image generation using VQGAN+CLIP")
+
+     # Add the arguments
+     vq_parser.add_argument(
+         "-p", "--prompts", type=str, help="Text prompts", default=None, dest="prompts"
+     )
+     vq_parser.add_argument(
+         "-ip",
+         "--image_prompts",
+         type=str,
+         help="Image prompts / target image",
+         default=[],
+         dest="image_prompts",
+     )
+     vq_parser.add_argument(
+         "-i",
+         "--iterations",
+         type=int,
+         help="Number of iterations",
+         default=500,
+         dest="max_iterations",
+     )
+     vq_parser.add_argument(
+         "-se",
+         "--save_every",
+         type=int,
+         help="Save image iterations",
+         default=50,
+         dest="display_freq",
+     )
+     vq_parser.add_argument(
+         "-s",
+         "--size",
+         nargs=2,
+         type=int,
+         help="Image size (width height) (default: %(default)s)",
+         dest="size",
+     )
+     vq_parser.add_argument(
+         "-ii",
+         "--init_image",
+         type=str,
+         help="Initial image",
+         default=None,
+         dest="init_image",
+     )
+     vq_parser.add_argument(
+         "-in",
+         "--init_noise",
+         type=str,
+         help="Initial noise image (pixels or gradient)",
+         default=None,
+         dest="init_noise",
+     )
+     vq_parser.add_argument(
+         "-iw",
+         "--init_weight",
+         type=float,
+         help="Initial weight",
+         default=0.0,
+         dest="init_weight",
+     )
+     vq_parser.add_argument(
+         "-m",
+         "--clip_model",
+         type=str,
+         help="CLIP model (e.g. ViT-B/32, ViT-B/16)",
+         default="ViT-B/32",
+         dest="clip_model",
+     )
+     vq_parser.add_argument(
+         "-conf",
+         "--vqgan_config",
+         type=str,
+         help="VQGAN config",
+         default="checkpoints/vqgan_imagenet_f16_16384.yaml",
+         dest="vqgan_config",
+     )
+     vq_parser.add_argument(
+         "-ckpt",
+         "--vqgan_checkpoint",
+         type=str,
+         help="VQGAN checkpoint",
+         default="checkpoints/vqgan_imagenet_f16_16384.ckpt",
+         dest="vqgan_checkpoint",
+     )
+     vq_parser.add_argument(
+         "-nps",
+         "--noise_prompt_seeds",
+         nargs="*",
+         type=int,
+         help="Noise prompt seeds",
+         default=[],
+         dest="noise_prompt_seeds",
+     )
+     vq_parser.add_argument(
+         "-npw",
+         "--noise_prompt_weights",
+         nargs="*",
+         type=float,
+         help="Noise prompt weights",
+         default=[],
+         dest="noise_prompt_weights",
+     )
+     vq_parser.add_argument(
+         "-lr",
+         "--learning_rate",
+         type=float,
+         help="Learning rate",
+         default=0.1,
+         dest="step_size",
+     )
+     vq_parser.add_argument(
+         "-cutm",
+         "--cut_method",
+         type=str,
+         help="Cut method",
+         choices=["original", "updated", "nrupdated", "updatedpooling", "latest"],
+         default="latest",
+         dest="cut_method",
+     )
+     vq_parser.add_argument(
+         "-cuts", "--num_cuts", type=int, help="Number of cuts", default=32, dest="cutn"
+     )
+     vq_parser.add_argument(
+         "-cutp",
+         "--cut_power",
+         type=float,
+         help="Cut power",
+         default=1.0,
+         dest="cut_pow",
+     )
+     vq_parser.add_argument(
+         "-sd", "--seed", type=int, help="Seed", default=None, dest="seed"
+     )
+     vq_parser.add_argument(
+         "-opt",
+         "--optimiser",
+         type=str,
+         help="Optimiser",
+         choices=[
+             "Adam",
+             "AdamW",
+             "Adagrad",
+             "Adamax",
+             "DiffGrad",
+             "AdamP",
+             "RAdam",
+             "RMSprop",
+         ],
+         default="Adam",
+         dest="optimiser",
+     )
+     vq_parser.add_argument(
+         "-o",
+         "--output",
+         type=str,
+         help="Output filename",
+         default="output.png",
+         dest="output",
+     )
+     vq_parser.add_argument(
+         "-vid",
+         "--video",
+         action="store_true",
+         help="Create video frames?",
+         dest="make_video",
+     )
+     vq_parser.add_argument(
+         "-zvid",
+         "--zoom_video",
+         action="store_true",
+         help="Create zoom video?",
+         dest="make_zoom_video",
+     )
+     vq_parser.add_argument(
+         "-zs",
+         "--zoom_start",
+         type=int,
+         help="Zoom start iteration",
+         default=0,
+         dest="zoom_start",
+     )
+     vq_parser.add_argument(
+         "-zse",
+         "--zoom_save_every",
+         type=int,
+         help="Save zoom image iterations",
+         default=10,
+         dest="zoom_frequency",
+     )
+     vq_parser.add_argument(
+         "-zsc",
+         "--zoom_scale",
+         type=float,
+         help="Zoom scale %",
+         default=0.99,
+         dest="zoom_scale",
+     )
+     vq_parser.add_argument(
+         "-zsx",
+         "--zoom_shift_x",
+         type=int,
+         help="Zoom shift x (left/right) amount in pixels",
+         default=0,
+         dest="zoom_shift_x",
+     )
+     vq_parser.add_argument(
+         "-zsy",
+         "--zoom_shift_y",
+         type=int,
+         help="Zoom shift y (up/down) amount in pixels",
+         default=0,
+         dest="zoom_shift_y",
+     )
+     vq_parser.add_argument(
+         "-cpe",
+         "--change_prompt_every",
+         type=int,
+         help="Prompt change frequency",
+         default=0,
+         dest="prompt_frequency",
+     )
+     vq_parser.add_argument(
+         "-vl",
+         "--video_length",
+         type=float,
+         help="Video length in seconds (not interpolated)",
+         default=10,
+         dest="video_length",
+     )
+     vq_parser.add_argument(
+         "-ofps",
+         "--output_video_fps",
+         type=float,
+         help="Create an interpolated video (Nvidia GPU only) with this fps (min 10; best set to 30 or 60)",
+         default=30,
+         dest="output_video_fps",
+     )
+     vq_parser.add_argument(
+         "-ifps",
+         "--input_video_fps",
+         type=float,
+         help="When creating an interpolated video, use this as the input fps to interpolate from (>0 & <ofps)",
+         default=15,
+         dest="input_video_fps",
+     )
+     vq_parser.add_argument(
+         "-d",
+         "--deterministic",
+         action="store_true",
+         help="Enable cudnn.deterministic?",
+         dest="cudnn_determinism",
+     )
+     vq_parser.add_argument(
+         "-aug",
+         "--augments",
+         nargs="+",
+         action="append",
+         type=str,
+         choices=["Ji", "Sh", "Gn", "Pe", "Ro", "Af", "Et", "Ts", "Cr", "Er", "Re"],
+         help="Enabled augments (latest cut method only)",
+         default=[["Af", "Pe", "Ji", "Er"]],
+         dest="augments",
+     )
+     vq_parser.add_argument(
+         "-vsd",
+         "--video_style_dir",
+         type=str,
+         help="Directory with video frames to style",
+         default=None,
+         dest="video_style_dir",
+     )
+     vq_parser.add_argument(
+         "-cd",
+         "--cuda_device",
+         type=str,
+         help="Cuda device to use",
+         default="cuda:0",
+         dest="cuda_device",
+     )
+
+     # Parse an empty argument list so only the defaults above are used
+     # (inputs are supplied through Predictor.predict, not the command line)
+     args = vq_parser.parse_args([])
+     return args
+
+
+ def load_vqgan_model(config_path, checkpoint_path):
+     config = OmegaConf.load(config_path)
+     # assumes config.model.target == 'taming.models.vqgan.VQModel'
+     model = vqgan.VQModel(**config.model.params)
+     model.eval().requires_grad_(False)
+     model.init_from_ckpt(checkpoint_path)
+     del model.loss
+     return model
+
+
+ class MakeCutouts(nn.Module):
+     def __init__(self, cut_size, cutn, args, cut_pow=1.0):
+         super().__init__()
+         self.cut_size = cut_size
+         self.cutn = cutn
+         self.cut_pow = cut_pow  # not used with pooling
+
+         # Pick your own augments & their order
+         augment_list = []
+         for item in args.augments[0]:
+             if item == "Ji":
+                 augment_list.append(
+                     K.ColorJitter(
+                         brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=0.7
+                     )
+                 )
+             elif item == "Sh":
+                 augment_list.append(K.RandomSharpness(sharpness=0.3, p=0.5))
+             elif item == "Gn":
+                 augment_list.append(K.RandomGaussianNoise(mean=0.0, std=1.0, p=0.5))
+             elif item == "Pe":
+                 augment_list.append(K.RandomPerspective(distortion_scale=0.7, p=0.7))
+             elif item == "Ro":
+                 augment_list.append(K.RandomRotation(degrees=15, p=0.7))
+             elif item == "Af":
+                 augment_list.append(
+                     K.RandomAffine(
+                         degrees=15,
+                         translate=0.1,
+                         shear=5,
+                         p=0.7,
+                         padding_mode="zeros",
+                         keepdim=True,
+                     )
+                 )  # border, reflection, zeros
+             elif item == "Et":
+                 augment_list.append(K.RandomElasticTransform(p=0.7))
+             elif item == "Ts":
+                 augment_list.append(
+                     K.RandomThinPlateSpline(scale=0.8, same_on_batch=True, p=0.7)
+                 )
+             elif item == "Cr":
+                 augment_list.append(
+                     K.RandomCrop(
+                         size=(self.cut_size, self.cut_size),
+                         pad_if_needed=True,
+                         padding_mode="reflect",
+                         p=0.5,
+                     )
+                 )
+             elif item == "Er":
+                 augment_list.append(
+                     K.RandomErasing(
+                         scale=(0.1, 0.4),
+                         ratio=(0.3, 1 / 0.3),
+                         same_on_batch=True,
+                         p=0.7,
+                     )
+                 )
+             elif item == "Re":
+                 augment_list.append(
+                     K.RandomResizedCrop(
+                         size=(self.cut_size, self.cut_size),
+                         scale=(0.1, 1),
+                         ratio=(0.75, 1.333),
+                         cropping_mode="resample",
+                         p=0.5,
+                     )
+                 )
+
+         self.augs = nn.Sequential(*augment_list)
+         self.noise_fac = 0.1
+         # self.noise_fac = False
+
+         # Uncomment if you like seeing the list ;)
+         # print(augment_list)
+
+         # Pooling
+         self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
+         self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
+
+     def forward(self, input):
+         cutouts = []
+
+         for _ in range(self.cutn):
+             # Use Pooling
+             cutout = (self.av_pool(input) + self.max_pool(input)) / 2
+             cutouts.append(cutout)
+
+         batch = self.augs(torch.cat(cutouts, dim=0))
+
+         if self.noise_fac:
+             facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
+             batch = batch + facs * torch.randn_like(batch)
+         return batch
+
+
+ def get_opt(opt_name, opt_lr, z):
+     if opt_name == "Adam":
+         opt = optim.Adam([z], lr=opt_lr)  # LR=0.1 (Default)
+     elif opt_name == "AdamW":
+         opt = optim.AdamW([z], lr=opt_lr)
+     elif opt_name == "Adagrad":
+         opt = optim.Adagrad([z], lr=opt_lr)
+     elif opt_name == "Adamax":
+         opt = optim.Adamax([z], lr=opt_lr)
+     elif opt_name == "DiffGrad":
+         opt = DiffGrad(
+             [z], lr=opt_lr, eps=1e-9, weight_decay=1e-9
+         )  # NR: Playing for reasons
+     elif opt_name == "AdamP":
+         opt = AdamP([z], lr=opt_lr)
+     elif opt_name == "RAdam":
+         opt = RAdam([z], lr=opt_lr)
+     elif opt_name == "RMSprop":
+         opt = optim.RMSprop([z], lr=opt_lr)
+     else:
+         print("Unknown optimiser. Are choices broken?")
+         opt = optim.Adam([z], lr=opt_lr)
+     return opt
+
+
+ def ascend_txt(i, z, perceptor, args, model, make_cutouts, pMs):
+     normalize = transforms.Normalize(
+         mean=[0.48145466, 0.4578275, 0.40821073],
+         std=[0.26862954, 0.26130258, 0.27577711],
+     )
+     out = synth(z, model)
+     iii = perceptor.encode_image(normalize(make_cutouts(out))).float()
+
+     result = []
+
+     if args.init_weight:
+         # result.append(F.mse_loss(z, z_orig) * args.init_weight / 2)
+         result.append(
+             F.mse_loss(z, torch.zeros_like(z))  # z has the same shape as z_orig, which is not in scope here
+             * ((1 / torch.tensor(i * 2 + 1)) * args.init_weight)
+             / 2
+         )
+
+     for prompt in pMs:
+         result.append(prompt(iii))
+
+     if args.make_video:
+         img = np.array(
+             out.mul(255).clamp(0, 255)[0].cpu().detach().numpy().astype(np.uint8)
+         )[:, :, :]
+         img = np.transpose(img, (1, 2, 0))
+         imageio.imwrite("steps/" + str(i) + ".png", np.array(img))
+
+     return result
+
+
+ def synth(z, model):
+     # gumbel is False
+     z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim(
+         3, 1
+     )
+     return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1)
+
+
+ def vector_quantize(x, codebook):
+     # Squared Euclidean distance from each latent vector to every codebook entry
+     d = (
+         x.pow(2).sum(dim=-1, keepdim=True)
+         + codebook.pow(2).sum(dim=1)
+         - 2 * x @ codebook.T
+     )
+     indices = d.argmin(-1)
+     x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook
+     return replace_grad(x_q, x)
+
+
+ def split_prompt(prompt):
+     # "text:weight:stop" -> (text, weight, stop); weight defaults to 1, stop to -inf
+     vals = prompt.rsplit(":", 2)
+     vals = vals + ["", "1", "-inf"][len(vals) :]
+     return vals[0], float(vals[1]), float(vals[2])
+
+
+ class Prompt(nn.Module):
+     def __init__(self, embed, weight=1.0, stop=float("-inf")):
+         super().__init__()
+         self.register_buffer("embed", embed)
+         self.register_buffer("weight", torch.as_tensor(weight))
+         self.register_buffer("stop", torch.as_tensor(stop))
+
+     def forward(self, input):
+         input_normed = F.normalize(input.unsqueeze(1), dim=2)
+         embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
+         # Squared spherical distance between normalized image and prompt embeddings
+         dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
+         dists = dists * self.weight.sign()
+         return (
+             self.weight.abs()
+             * replace_grad(dists, torch.maximum(dists, self.stop)).mean()
+         )
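+
+
+ # Minimal local-usage sketch (an illustrative assumption, not exercised by cog itself):
+ # it drives the Predictor directly, assuming a CUDA device plus the cloned repos and
+ # checkpoints described at the top of this file, and that the @cog.input decorators
+ # only record metadata and leave predict() callable as a plain generator method.
+ # The iteration count here is arbitrary.
+ if __name__ == "__main__":
+     predictor = Predictor()
+     predictor.setup()
+     for image_path in predictor.predict(
+         image=None,
+         prompts="A cute, smiling, Nerdy Rodent",
+         iterations=60,
+         display_frequency=20,
+     ):
+         print("Image written to:", image_path)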