Axolotlily committed on
Commit 2a5b13c · 1 Parent(s): d3dc3d3

Upload main.py

Files changed (1)
  main.py +395 -0
main.py ADDED
@@ -0,0 +1,395 @@
import argparse
import math
import random

from vqgan_clip.grad import *
from vqgan_clip.helpers import *
from vqgan_clip.inits import *
from vqgan_clip.masking import *
from vqgan_clip.optimizers import *

from urllib.request import urlopen
from tqdm import tqdm
import sys
import os

from omegaconf import OmegaConf

from taming.models import cond_transformer, vqgan

import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from torch.cuda import get_device_properties
torch.backends.cudnn.benchmark = False

from torch_optimizer import DiffGrad, AdamP, RAdam

import clip
import kornia.augmentation as K
import numpy as np
import imageio

from PIL import ImageFile, Image, PngImagePlugin, ImageChops
ImageFile.LOAD_TRUNCATED_IMAGES = True

from subprocess import Popen, PIPE
import re
from packaging import version

# Suppress warnings
import warnings
warnings.filterwarnings('ignore')

# Check for GPU and reduce the default image size if low VRAM
default_image_size = 512  # >8GB VRAM
if not torch.cuda.is_available():
    default_image_size = 256  # no GPU found
elif get_device_properties(0).total_memory <= 2 ** 33:  # 2 ** 33 = 8,589,934,592 bytes = 8 GB
    default_image_size = 318  # <8GB VRAM

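# Illustrative usage (the prompt, iteration count and output name are
# placeholders; assumes the default VQGAN checkpoint/config referenced below
# have been downloaded into checkpoints/):
#   python main.py -p "a watercolor painting of a fox" -i 300 -o fox.png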
def parse():
    global all_phrases  # read again in the main loop when prompts are rotated via -cpe

    vq_parser = argparse.ArgumentParser(description='Image generation using VQGAN+CLIP')

    vq_parser.add_argument("-aug", "--augments", nargs='+', action='append', type=str, choices=['Hf','Ji','Sh','Pe','Ro','Af','Et','Ts','Er'],
                           help="Enabled augments (latest cut method only)", default=[['Hf','Af', 'Pe', 'Ji', 'Er']], dest='augments')
    vq_parser.add_argument("-cd", "--cuda_device", type=str, help="Cuda device to use", default="cuda:0", dest='cuda_device')
    vq_parser.add_argument("-ckpt", "--vqgan_checkpoint", type=str, help="VQGAN checkpoint", default='checkpoints/vqgan_imagenet_f16_16384.ckpt',
                           dest='vqgan_checkpoint')
    vq_parser.add_argument("-conf", "--vqgan_config", type=str, help="VQGAN config", default='checkpoints/vqgan_imagenet_f16_16384.yaml', dest='vqgan_config')
    vq_parser.add_argument("-cpe", "--change_prompt_every", type=int, help="Prompt change frequency", default=0, dest='prompt_frequency')
    vq_parser.add_argument("-cutm", "--cut_method", type=str, help="Cut method", choices=['original','latest'],
                           default='latest', dest='cut_method')
    vq_parser.add_argument("-cutp", "--cut_power", type=float, help="Cut power", default=1., dest='cut_pow')
    vq_parser.add_argument("-cuts", "--num_cuts", type=int, help="Number of cuts", default=32, dest='cutn')
    vq_parser.add_argument("-d", "--deterministic", action='store_true', help="Enable cudnn.deterministic?", dest='cudnn_determinism')
    vq_parser.add_argument("-i", "--iterations", type=int, help="Number of iterations", default=500, dest='max_iterations')
    vq_parser.add_argument("-ifps", "--input_video_fps", type=float,
                           help="When creating an interpolated video, use this as the input fps to interpolate from (>0 & <ofps)", default=15,
                           dest='input_video_fps')
    vq_parser.add_argument("-ii", "--init_image", type=str, help="Initial image", default=None, dest='init_image')
    vq_parser.add_argument("-in", "--init_noise", type=str, help="Initial noise image (pixels or gradient)", default=None, dest='init_noise')
    vq_parser.add_argument("-ip", "--image_prompts", type=str, help="Image prompts / target image", default=[], dest='image_prompts')
    vq_parser.add_argument("-iw", "--init_weight", type=float, help="Initial weight", default=0., dest='init_weight')
    vq_parser.add_argument("-lr", "--learning_rate", type=float, help="Learning rate", default=0.1, dest='step_size')
    vq_parser.add_argument("-m", "--clip_model", type=str, help="CLIP model (e.g. ViT-B/32, ViT-B/16)", default='ViT-B/32', dest='clip_model')
    vq_parser.add_argument("-nps", "--noise_prompt_seeds", nargs="*", type=int, help="Noise prompt seeds", default=[], dest='noise_prompt_seeds')
    vq_parser.add_argument("-npw", "--noise_prompt_weights", nargs="*", type=float, help="Noise prompt weights", default=[], dest='noise_prompt_weights')
    vq_parser.add_argument("-o", "--output", type=str, help="Output filename", default="output.png", dest='output')
    vq_parser.add_argument("-ofps", "--output_video_fps", type=float,
                           help="Create an interpolated video (Nvidia GPU only) with this fps (min 10. best set to 30 or 60)", default=0, dest='output_video_fps')
    vq_parser.add_argument("-opt", "--optimiser", type=str, help="Optimiser", choices=['Adam','AdamW','Adagrad','Adamax','DiffGrad','AdamP','RAdam','RMSprop'],
                           default='Adam', dest='optimiser')
    vq_parser.add_argument("-p", "--prompts", type=str, help="Text prompts", default=None, dest='prompts')
    vq_parser.add_argument("-s", "--size", nargs=2, type=int, help="Image size (width height) (default: %(default)s)",
                           default=[default_image_size, default_image_size], dest='size')
    vq_parser.add_argument("-sd", "--seed", type=int, help="Seed", default=None, dest='seed')
    vq_parser.add_argument("-se", "--save_every", type=int, help="Save image iterations", default=50, dest='display_freq')
    vq_parser.add_argument("-vid", "--video", action='store_true', help="Create video frames?", dest='make_video')
    vq_parser.add_argument("-vl", "--video_length", type=float, help="Video length in seconds (not interpolated)", default=10, dest='video_length')
    vq_parser.add_argument("-vsd", "--video_style_dir", type=str, help="Directory with video frames to style", default=None, dest='video_style_dir')
    vq_parser.add_argument("-zs", "--zoom_start", type=int, help="Zoom start iteration", default=0, dest='zoom_start')
    vq_parser.add_argument("-zsc", "--zoom_scale", type=float, help="Zoom scale %%", default=0.99, dest='zoom_scale')
    vq_parser.add_argument("-zse", "--zoom_save_every", type=int, help="Save zoom image iterations", default=10, dest='zoom_frequency')
    vq_parser.add_argument("-zsx", "--zoom_shift_x", type=int, help="Zoom shift x (left/right) amount in pixels", default=0, dest='zoom_shift_x')
    vq_parser.add_argument("-zsy", "--zoom_shift_y", type=int, help="Zoom shift y (up/down) amount in pixels", default=0, dest='zoom_shift_y')
    vq_parser.add_argument("-zvid", "--zoom_video", action='store_true', help="Create zoom video?", dest='make_zoom_video')

    args = vq_parser.parse_args()

    if not args.prompts and not args.image_prompts:
        raise Exception("You must supply a text or image prompt")

    torch.backends.cudnn.deterministic = args.cudnn_determinism

    # Split text prompts using the pipe character (weights are split later)
    if args.prompts:
        # For stories, there will be many phrases
        story_phrases = [phrase.strip() for phrase in args.prompts.split("^")]

        # Make a list of all phrases
        all_phrases = []
        for phrase in story_phrases:
            all_phrases.append(phrase.split("|"))

        # First phrase
        args.prompts = all_phrases[0]

    # Split target images using the pipe character (weights are split later)
    if args.image_prompts:
        args.image_prompts = args.image_prompts.split("|")
        args.image_prompts = [image.strip() for image in args.image_prompts]

    if args.make_video and args.make_zoom_video:
        print("Warning: Make video and make zoom video are mutually exclusive.")
        args.make_video = False

    # Make video steps directory
    if args.make_video or args.make_zoom_video:
        if not os.path.exists('steps'):
            os.mkdir('steps')

    return args

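# Prompt wraps a target CLIP embedding. forward() computes the squared
# spherical distance between each cutout embedding and the target; a negative
# weight turns the prompt into a repulsor, and `stop` is a cutoff below which
# gradients are suppressed via the imported replace_grad helper.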
class Prompt(nn.Module):
    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        self.register_buffer('embed', embed)
        self.register_buffer('weight', torch.as_tensor(weight))
        self.register_buffer('stop', torch.as_tensor(stop))

    def forward(self, input):
        input_normed = F.normalize(input.unsqueeze(1), dim=2)
        embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
        dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
        dists = dists * self.weight.sign()
        return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()

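# Prompt strings use the form "text:weight:stop"; omitted fields default to
# weight 1 and stop -inf (e.g. "a red boat", "a red boat:0.5", "a red boat:1:-0.5").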
# NR: Split prompts and weights
def split_prompt(prompt):
    vals = prompt.rsplit(':', 2)
    vals = vals + ['', '1', '-inf'][len(vals):]
    return vals[0], float(vals[1]), float(vals[2])

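# Build the VQGAN (or the first stage of a Net2NetTransformer) from a
# taming-transformers config/checkpoint pair. The global `gumbel` flag records
# whether the GumbelVQ variant is loaded, since it names its codebook
# attribute differently (quantize.embed vs quantize.embedding).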
def load_vqgan_model(config_path, checkpoint_path):
    global gumbel
    gumbel = False
    config = OmegaConf.load(config_path)
    if config.model.target == 'taming.models.vqgan.VQModel':
        model = vqgan.VQModel(**config.model.params)
        model.eval().requires_grad_(False)
        model.init_from_ckpt(checkpoint_path)
    elif config.model.target == 'taming.models.vqgan.GumbelVQ':
        model = vqgan.GumbelVQ(**config.model.params)
        model.eval().requires_grad_(False)
        model.init_from_ckpt(checkpoint_path)
        gumbel = True
    elif config.model.target == 'taming.models.cond_transformer.Net2NetTransformer':
        parent_model = cond_transformer.Net2NetTransformer(**config.model.params)
        parent_model.eval().requires_grad_(False)
        parent_model.init_from_ckpt(checkpoint_path)
        model = parent_model.first_stage_model
    else:
        raise ValueError(f'unknown model type: {config.model.target}')
    del model.loss
    return model

# Vector quantize
def synth(z):
    if gumbel:
        z_q = vector_quantize(z.movedim(1, 3), model.quantize.embed.weight).movedim(3, 1)
    else:
        z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim(3, 1)
    return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1)

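# checkin(): log the per-prompt losses and save the current image to
# args.output, with the prompt text stored in the PNG metadata.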
@torch.inference_mode()
def checkin(i, losses):
    losses_str = ', '.join(f'{loss.item():g}' for loss in losses)
    tqdm.write(f'i: {i}, loss: {sum(losses).item():g}, losses: {losses_str}')
    out = synth(z)
    info = PngImagePlugin.PngInfo()
    info.add_text('comment', f'{args.prompts}')
    TF.to_pil_image(out[0].cpu()).save(args.output, pnginfo=info)

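# ascend_txt(): decode the current latent, CLIP-encode a batch of augmented
# cutouts, and return one loss per prompt, plus (when --init_weight is set) an
# MSE term pulling the latent towards zero that decays as 1/(2*i + 1). When
# --video is set, the current frame is also written to steps/.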
def ascend_txt():
    global i
    out = synth(z)
    iii = perceptor.encode_image(normalize(make_cutouts(out))).float()

    result = []

    if args.init_weight:
        # result.append(F.mse_loss(z, z_orig) * args.init_weight / 2)
        result.append(F.mse_loss(z, torch.zeros_like(z_orig)) * ((1/torch.tensor(i*2 + 1))*args.init_weight) / 2)

    for prompt in pMs:
        result.append(prompt(iii))

    if args.make_video:
        img = np.array(out.mul(255).clamp(0, 255)[0].cpu().detach().numpy().astype(np.uint8))[:,:,:]
        img = np.transpose(img, (1, 2, 0))
        imageio.imwrite('./steps/' + str(i) + '.png', np.array(img))

    return result  # return loss

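# train(): one optimisation step - zero gradients, sum the losses from
# ascend_txt(), backpropagate, step the optimiser, then clamp the latent back
# into the range spanned by the codebook (z_min / z_max).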
def train(i):
    opt.zero_grad(set_to_none=True)
    lossAll = ascend_txt()

    if i % args.display_freq == 0:
        checkin(i, lossAll)

    loss = sum(lossAll)
    loss.backward()
    opt.step()

    # with torch.no_grad():
    with torch.inference_mode():
        z.copy_(z.maximum(z_min).minimum(z_max))

if __name__ == '__main__':

    args = parse()

    # Do it
    device = torch.device(args.cuda_device)
    model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)
    jit = True if version.parse(torch.__version__) < version.parse('1.8.0') else False
    perceptor = clip.load(args.clip_model, jit=jit)[0].eval().requires_grad_(False).to(device)

    cut_size = perceptor.visual.input_resolution
    f = 2**(model.decoder.num_resolutions - 1)

    # Cutout class options:
    # 'latest','original','updated' or 'updatedpooling'
    if args.cut_method == 'latest':
        make_cutouts = MakeCutouts(args, cut_size, args.cutn)
    elif args.cut_method == 'original':
        make_cutouts = MakeCutoutsOrig(args, cut_size, args.cutn)

    toksX, toksY = args.size[0] // f, args.size[1] // f
    sideX, sideY = toksX * f, toksY * f

    # Gumbel or not?
    if gumbel:
        e_dim = 256
        n_toks = model.quantize.n_embed
        z_min = model.quantize.embed.weight.min(dim=0).values[None, :, None, None]
        z_max = model.quantize.embed.weight.max(dim=0).values[None, :, None, None]
    else:
        e_dim = model.quantize.e_dim
        n_toks = model.quantize.n_e
        z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]
        z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]

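    # Initialise the latent z: from an image file or URL (-ii), from random
    # pixel or gradient noise (-in pixels / -in gradient), or, by default,
    # from a random one-hot sample of the VQGAN codebook.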
    if args.init_image:
        if 'http' in args.init_image:
            img = Image.open(urlopen(args.init_image))
        else:
            img = Image.open(args.init_image)
        pil_image = img.convert('RGB')
        pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
        pil_tensor = TF.to_tensor(pil_image)
        z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1)
    elif args.init_noise == 'pixels':
        img = random_noise_image(args.size[0], args.size[1])
        pil_image = img.convert('RGB')
        pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
        pil_tensor = TF.to_tensor(pil_image)
        z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1)
    elif args.init_noise == 'gradient':
        img = random_gradient_image(args.size[0], args.size[1])
        pil_image = img.convert('RGB')
        pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
        pil_tensor = TF.to_tensor(pil_image)
        z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1)
    else:
        one_hot = F.one_hot(torch.randint(n_toks, [toksY * toksX], device=device), n_toks).float()
        # z = one_hot @ model.quantize.embedding.weight
        if gumbel:
            z = one_hot @ model.quantize.embed.weight
        else:
            z = one_hot @ model.quantize.embedding.weight

        z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
        # z = torch.rand_like(z)*2  # NR: check

    z_orig = z.clone()
    z.requires_grad_(True)
    pMs = []
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])

    # CLIP tokenize/encode
    if args.prompts:
        for prompt in args.prompts:
            txt, weight, stop = split_prompt(prompt)
            embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
            pMs.append(Prompt(embed, weight, stop).to(device))

    for prompt in args.image_prompts:
        path, weight, stop = split_prompt(prompt)
        img = Image.open(path)
        pil_image = img.convert('RGB')
        img = resize_image(pil_image, (sideX, sideY))
        batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
        embed = perceptor.encode_image(normalize(batch)).float()
        pMs.append(Prompt(embed, weight, stop).to(device))

    for seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):
        gen = torch.Generator().manual_seed(seed)
        embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)
        pMs.append(Prompt(embed, weight).to(device))

    # Set the optimiser
    opt, z = get_opt(args.optimiser, z, args.step_size)

    # Output for the user
    print('Using device:', device)
    print('Optimising using:', args.optimiser)

    if args.prompts:
        print('Using text prompts:', args.prompts)
    if args.image_prompts:
        print('Using image prompts:', args.image_prompts)
    if args.init_image:
        print('Using initial image:', args.init_image)
    if args.noise_prompt_weights:
        print('Noise prompt weights:', args.noise_prompt_weights)

    if args.seed is None:
        seed = torch.seed()
    else:
        seed = args.seed
    torch.manual_seed(seed)
    print('Using seed:', seed)

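    # Main optimisation loop. With --change_prompt_every set, the text prompt
    # rotates through the '^'-separated phrases every prompt_frequency
    # iterations; otherwise the same prompts are optimised for max_iterations.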
    i = 0  # Iteration counter
    j = 0  # Zoom video frame counter
    p = 1  # Phrase counter
    smoother = 0  # Smoother counter
    this_video_frame = 0  # for video styling

    with tqdm() as pbar:
        while i < args.max_iterations:
            # Change text prompt
            if args.prompt_frequency > 0:
                if i % args.prompt_frequency == 0 and i > 0:
                    # In case there aren't enough phrases, just loop
                    if p >= len(all_phrases):
                        p = 0

                    pMs = []
                    args.prompts = all_phrases[p]

                    # Show user we're changing prompt
                    print(args.prompts)

                    for prompt in args.prompts:
                        txt, weight, stop = split_prompt(prompt)
                        embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
                        pMs.append(Prompt(embed, weight, stop).to(device))
                    p += 1
            train(i)
            i += 1
            pbar.update()

    print("done")