AlexKM committed on
Commit 711d859
1 Parent(s): d7da47d

Upload generate.py

Files changed (1)
  1. generate.py +990 -0
generate.py ADDED
@@ -0,0 +1,990 @@
1
+ # Originally made by Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings)
2
+ # The original BigGAN+CLIP method was by https://twitter.com/advadnoun
3
+
4
+ import argparse
5
+ import math
6
+ import random
7
+ # from email.policy import default
8
+ from urllib.request import urlopen
9
+ from tqdm import tqdm
10
+ import sys
11
+ import os
12
+
13
+ # pip install taming-transformers doesn't work with Gumbel, and does not yet work with coco etc
14
+ # appending the path does work with Gumbel, but gives ModuleNotFoundError: No module named 'transformers' for coco etc
15
+ sys.path.append('taming-transformers')
16
+
17
+ from omegaconf import OmegaConf
18
+ from taming.models import cond_transformer, vqgan
19
+ #import taming.modules
20
+
21
+ import torch
22
+ from torch import nn, optim
23
+ from torch.nn import functional as F
24
+ from torchvision import transforms
25
+ from torchvision.transforms import functional as TF
26
+ from torch.cuda import get_device_properties
27
+ torch.backends.cudnn.benchmark = False # NR: True is a bit faster, but can lead to OOM. False is more deterministic.
28
+ #torch.use_deterministic_algorithms(True) # NR: grid_sampler_2d_backward_cuda does not have a deterministic implementation
29
+
30
+ from torch_optimizer import DiffGrad, AdamP
31
+
32
+ from CLIP import clip
33
+ import kornia.augmentation as K
34
+ import numpy as np
35
+ import imageio
36
+
37
+ from PIL import ImageFile, Image, PngImagePlugin, ImageChops
38
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
39
+
40
+ from subprocess import Popen, PIPE
41
+ import re
42
+
43
+ # Suppress warnings
44
+ import warnings
45
+ warnings.filterwarnings('ignore')
46
+
47
+
48
+ # Check for GPU and reduce the default image size if low VRAM
49
+ default_image_size = 512 # >8GB VRAM
50
+ if not torch.cuda.is_available():
51
+ default_image_size = 256 # no GPU found
52
+ elif get_device_properties(0).total_memory <= 2 ** 33: # 2 ** 33 = 8,589,934,592 bytes = 8 GB
53
+ default_image_size = 304 # <8GB VRAM
54
+
55
+ # Create the parser
56
+ vq_parser = argparse.ArgumentParser(description='Image generation using VQGAN+CLIP')
57
+
58
+ # Add the arguments
59
+ vq_parser.add_argument("-p", "--prompts", type=str, help="Text prompts", default=None, dest='prompts')
60
+ vq_parser.add_argument("-ip", "--image_prompts", type=str, help="Image prompts / target image", default=[], dest='image_prompts')
61
+ vq_parser.add_argument("-i", "--iterations", type=int, help="Number of iterations", default=500, dest='max_iterations')
62
+ vq_parser.add_argument("-se", "--save_every", type=int, help="Save image iterations", default=50, dest='display_freq')
63
+ vq_parser.add_argument("-s", "--size", nargs=2, type=int, help="Image size (width height) (default: %(default)s)", default=[default_image_size,default_image_size], dest='size')
64
+ vq_parser.add_argument("-ii", "--init_image", type=str, help="Initial image", default=None, dest='init_image')
65
+ vq_parser.add_argument("-in", "--init_noise", type=str, help="Initial noise image (pixels or gradient)", default=None, dest='init_noise')
66
+ vq_parser.add_argument("-iw", "--init_weight", type=float, help="Initial weight", default=0., dest='init_weight')
67
+ vq_parser.add_argument("-m", "--clip_model", type=str, help="CLIP model (e.g. ViT-B/32, ViT-B/16)", default='ViT-B/32', dest='clip_model')
68
+ vq_parser.add_argument("-conf", "--vqgan_config", type=str, help="VQGAN config", default=f'checkpoints/vqgan_imagenet_f16_16384.yaml', dest='vqgan_config')
69
+ vq_parser.add_argument("-ckpt", "--vqgan_checkpoint", type=str, help="VQGAN checkpoint", default=f'checkpoints/vqgan_imagenet_f16_16384.ckpt', dest='vqgan_checkpoint')
70
+ vq_parser.add_argument("-nps", "--noise_prompt_seeds", nargs="*", type=int, help="Noise prompt seeds", default=[], dest='noise_prompt_seeds')
71
+ vq_parser.add_argument("-npw", "--noise_prompt_weights", nargs="*", type=float, help="Noise prompt weights", default=[], dest='noise_prompt_weights')
72
+ vq_parser.add_argument("-lr", "--learning_rate", type=float, help="Learning rate", default=0.1, dest='step_size')
73
+ vq_parser.add_argument("-cutm", "--cut_method", type=str, help="Cut method", choices=['original','updated','nrupdated','updatedpooling','latest'], default='latest', dest='cut_method')
74
+ vq_parser.add_argument("-cuts", "--num_cuts", type=int, help="Number of cuts", default=32, dest='cutn')
75
+ vq_parser.add_argument("-cutp", "--cut_power", type=float, help="Cut power", default=1., dest='cut_pow')
76
+ vq_parser.add_argument("-sd", "--seed", type=int, help="Seed", default=None, dest='seed')
77
+ vq_parser.add_argument("-opt", "--optimiser", type=str, help="Optimiser", choices=['Adam','AdamW','Adagrad','Adamax','DiffGrad','AdamP','RAdam','RMSprop'], default='Adam', dest='optimiser')
78
+ vq_parser.add_argument("-o", "--output", type=str, help="Output image filename", default="output.png", dest='output')
79
+ vq_parser.add_argument("-vid", "--video", action='store_true', help="Create video frames?", dest='make_video')
80
+ vq_parser.add_argument("-zvid", "--zoom_video", action='store_true', help="Create zoom video?", dest='make_zoom_video')
81
+ vq_parser.add_argument("-zs", "--zoom_start", type=int, help="Zoom start iteration", default=0, dest='zoom_start')
82
+ vq_parser.add_argument("-zse", "--zoom_save_every", type=int, help="Save zoom image iterations", default=10, dest='zoom_frequency')
83
+ vq_parser.add_argument("-zsc", "--zoom_scale", type=float, help="Zoom scale %%", default=0.99, dest='zoom_scale')
84
+ vq_parser.add_argument("-zsx", "--zoom_shift_x", type=int, help="Zoom shift x (left/right) amount in pixels", default=0, dest='zoom_shift_x')
85
+ vq_parser.add_argument("-zsy", "--zoom_shift_y", type=int, help="Zoom shift y (up/down) amount in pixels", default=0, dest='zoom_shift_y')
86
+ vq_parser.add_argument("-cpe", "--change_prompt_every", type=int, help="Prompt change frequency", default=0, dest='prompt_frequency')
87
+ vq_parser.add_argument("-vl", "--video_length", type=float, help="Video length in seconds (not interpolated)", default=10, dest='video_length')
88
+ vq_parser.add_argument("-ofps", "--output_video_fps", type=float, help="Create an interpolated video (Nvidia GPU only) with this fps (min 10. best set to 30 or 60)", default=0, dest='output_video_fps')
89
+ vq_parser.add_argument("-ifps", "--input_video_fps", type=float, help="When creating an interpolated video, use this as the input fps to interpolate from (>0 & <ofps)", default=15, dest='input_video_fps')
90
+ vq_parser.add_argument("-d", "--deterministic", action='store_true', help="Enable cudnn.deterministic?", dest='cudnn_determinism')
91
+ vq_parser.add_argument("-aug", "--augments", nargs='+', action='append', type=str, choices=['Ji','Sh','Gn','Pe','Ro','Af','Et','Ts','Cr','Er','Re'], help="Enabled augments (latest vut method only)", default=[], dest='augments')
92
+ vq_parser.add_argument("-vsd", "--video_style_dir", type=str, help="Directory with video frames to style", default=None, dest='video_style_dir')
93
+ vq_parser.add_argument("-cd", "--cuda_device", type=str, help="Cuda device to use", default="cuda:0", dest='cuda_device')
94
+
95
+
96
+ # Execute the parse_args() method
97
+ args = vq_parser.parse_args()
98
+
99
+ if not args.prompts and not args.image_prompts:
100
+ args.prompts = "A cute, smiling, Nerdy Rodent"
101
+
102
+ if args.cudnn_determinism:
103
+ torch.backends.cudnn.deterministic = True
104
+
105
+ if not args.augments:
106
+ args.augments = [['Af', 'Pe', 'Ji', 'Er']]
107
+
108
+ # Split text prompts using the pipe character (weights are split later)
109
+ if args.prompts:
110
+ # For stories, there will be many phrases
111
+ story_phrases = [phrase.strip() for phrase in args.prompts.split("^")]
112
+
113
+ # Make a list of all phrases
114
+ all_phrases = []
115
+ for phrase in story_phrases:
116
+ all_phrases.append(phrase.split("|"))
117
+
118
+ # First phrase
119
+ args.prompts = all_phrases[0]
120
+
121
+ # Split target images using the pipe character (weights are split later)
122
+ if args.image_prompts:
123
+ args.image_prompts = args.image_prompts.split("|")
124
+ args.image_prompts = [image.strip() for image in args.image_prompts]
125
+
126
+ if args.make_video and args.make_zoom_video:
127
+ print("Warning: Make video and make zoom video are mutually exclusive.")
128
+ args.make_video = False
129
+
130
+ # Make video steps directory
131
+ if args.make_video or args.make_zoom_video:
132
+ if not os.path.exists('steps'):
133
+ os.mkdir('steps')
134
+
135
+ # Fallback to CPU if CUDA is not found and make sure GPU video rendering is also disabled
136
+ # NB. May not work for AMD cards?
137
+ if not args.cuda_device == 'cpu' and not torch.cuda.is_available():
138
+ args.cuda_device = 'cpu'
139
+ args.output_video_fps = 0 # also disable the GPU-interpolated video path
140
+ print("Warning: No GPU found! Using the CPU instead. The iterations will be slow.")
141
+ print("Perhaps CUDA/ROCm or the right pytorch version is not properly installed?")
142
+
143
+ # If a video_style_dir has been given, then create a list of all the images
144
+ if args.video_style_dir:
145
+ print("Locating video frames...")
146
+ video_frame_list = []
147
+ for entry in os.scandir(args.video_style_dir):
148
+ if (entry.path.endswith(".jpg")
149
+ or entry.path.endswith(".png")) and entry.is_file():
150
+ video_frame_list.append(entry.path)
151
+
152
+ # Reset a few options - same filename, different directory
153
+ if not os.path.exists('steps'):
154
+ os.mkdir('steps')
155
+
156
+ args.init_image = video_frame_list[0]
157
+ filename = os.path.basename(args.init_image)
158
+ cwd = os.getcwd()
159
+ args.output = os.path.join(cwd, "steps", filename)
160
+ num_video_frames = len(video_frame_list) # for video styling
161
+
162
+
163
+ # Various functions and classes
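+ # sinc(), lanczos() and ramp() below build a Lanczos resampling kernel; resample() uses it
+ # for antialiased downscaling of cutouts to the CLIP input resolution.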
164
+ def sinc(x):
165
+ return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))
166
+
167
+
168
+ def lanczos(x, a):
169
+ cond = torch.logical_and(-a < x, x < a)
170
+ out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([]))
171
+ return out / out.sum()
172
+
173
+
174
+ def ramp(ratio, width):
175
+ n = math.ceil(width / ratio + 1)
176
+ out = torch.empty([n])
177
+ cur = 0
178
+ for i in range(out.shape[0]):
179
+ out[i] = cur
180
+ cur += ratio
181
+ return torch.cat([-out[1:].flip([0]), out])[1:-1]
182
+
183
+
184
+ # For zoom video
185
+ def zoom_at(img, x, y, zoom):
186
+ w, h = img.size
187
+ zoom2 = zoom * 2
188
+ img = img.crop((x - w / zoom2, y - h / zoom2,
189
+ x + w / zoom2, y + h / zoom2))
190
+ return img.resize((w, h), Image.LANCZOS)
191
+
192
+
193
+ # NR: Testing with different initial images
194
+ def random_noise_image(w,h):
195
+ random_image = Image.fromarray(np.random.randint(0,255,(w,h,3),dtype=np.dtype('uint8')))
196
+ return random_image
197
+
198
+
199
+ # create initial gradient image
200
+ def gradient_2d(start, stop, width, height, is_horizontal):
201
+ if is_horizontal:
202
+ return np.tile(np.linspace(start, stop, width), (height, 1))
203
+ else:
204
+ return np.tile(np.linspace(start, stop, height), (width, 1)).T
205
+
206
+
207
+ def gradient_3d(width, height, start_list, stop_list, is_horizontal_list):
208
+ result = np.zeros((height, width, len(start_list)), dtype=float)
209
+
210
+ for i, (start, stop, is_horizontal) in enumerate(zip(start_list, stop_list, is_horizontal_list)):
211
+ result[:, :, i] = gradient_2d(start, stop, width, height, is_horizontal)
212
+
213
+ return result
214
+
215
+
216
+ def random_gradient_image(w,h):
217
+ array = gradient_3d(w, h, (0, 0, np.random.randint(0,255)), (np.random.randint(1,255), np.random.randint(2,255), np.random.randint(3,128)), (True, False, False))
218
+ random_image = Image.fromarray(np.uint8(array))
219
+ return random_image
220
+
221
+
222
+ # Used in older MakeCutouts
223
+ def resample(input, size, align_corners=True):
224
+ n, c, h, w = input.shape
225
+ dh, dw = size
226
+
227
+ input = input.view([n * c, 1, h, w])
228
+
229
+ if dh < h:
230
+ kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
231
+ pad_h = (kernel_h.shape[0] - 1) // 2
232
+ input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
233
+ input = F.conv2d(input, kernel_h[None, None, :, None])
234
+
235
+ if dw < w:
236
+ kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
237
+ pad_w = (kernel_w.shape[0] - 1) // 2
238
+ input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
239
+ input = F.conv2d(input, kernel_w[None, None, None, :])
240
+
241
+ input = input.view([n, c, h, w])
242
+ return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)
243
+
244
+
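+ # Straight-through gradient helper: the forward pass returns x_forward, while the
+ # backward pass routes the gradient to x_backward (used by vector_quantize below).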
245
+ class ReplaceGrad(torch.autograd.Function):
246
+ @staticmethod
247
+ def forward(ctx, x_forward, x_backward):
248
+ ctx.shape = x_backward.shape
249
+ return x_forward
250
+
251
+ @staticmethod
252
+ def backward(ctx, grad_in):
253
+ return None, grad_in.sum_to_size(ctx.shape)
254
+
255
+ replace_grad = ReplaceGrad.apply
256
+
257
+
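+ # Differentiable clamp: values are clamped in the forward pass, and gradients that would
+ # push a value further outside [min, max] are zeroed in the backward pass.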
258
+ class ClampWithGrad(torch.autograd.Function):
259
+ @staticmethod
260
+ def forward(ctx, input, min, max):
261
+ ctx.min = min
262
+ ctx.max = max
263
+ ctx.save_for_backward(input)
264
+ return input.clamp(min, max)
265
+
266
+ @staticmethod
267
+ def backward(ctx, grad_in):
268
+ input, = ctx.saved_tensors
269
+ return grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0), None, None
270
+
271
+ clamp_with_grad = ClampWithGrad.apply
272
+
273
+
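+ # Snap each latent vector to its nearest codebook entry (by squared Euclidean distance),
+ # while replace_grad passes the gradient straight through to the unquantized latents.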
274
+ def vector_quantize(x, codebook):
275
+ d = x.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * x @ codebook.T
276
+ indices = d.argmin(-1)
277
+ x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook
278
+ return replace_grad(x_q, x)
279
+
280
+
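+ # One text/image prompt: the loss is (half) the squared great-circle distance between the
+ # normalised CLIP embedding of each cutout and the prompt embedding, scaled by the prompt weight.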
281
+ class Prompt(nn.Module):
282
+ def __init__(self, embed, weight=1., stop=float('-inf')):
283
+ super().__init__()
284
+ self.register_buffer('embed', embed)
285
+ self.register_buffer('weight', torch.as_tensor(weight))
286
+ self.register_buffer('stop', torch.as_tensor(stop))
287
+
288
+ def forward(self, input):
289
+ input_normed = F.normalize(input.unsqueeze(1), dim=2)
290
+ embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
291
+ dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
292
+ dists = dists * self.weight.sign()
293
+ return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()
294
+
295
+
296
+ #NR: Split prompts and weights
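+ # Prompts use the syntax "text:weight:stop"; missing fields default to weight 1 and stop -inf.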
297
+ def split_prompt(prompt):
298
+ vals = prompt.rsplit(':', 2)
299
+ vals = vals + ['', '1', '-inf'][len(vals):]
300
+ return vals[0], float(vals[1]), float(vals[2])
301
+
302
+
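+ # 'latest' cut method: each cutout is an average/max pooled copy of the whole image,
+ # then the selected Kornia augments (and optional noise) are applied to the batch.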
303
+ class MakeCutouts(nn.Module):
304
+ def __init__(self, cut_size, cutn, cut_pow=1.):
305
+ super().__init__()
306
+ self.cut_size = cut_size
307
+ self.cutn = cutn
308
+ self.cut_pow = cut_pow # not used with pooling
309
+
310
+ # Pick your own augments & their order
311
+ augment_list = []
312
+ for item in args.augments[0]:
313
+ if item == 'Ji':
314
+ augment_list.append(K.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=0.7))
315
+ elif item == 'Sh':
316
+ augment_list.append(K.RandomSharpness(sharpness=0.3, p=0.5))
317
+ elif item == 'Gn':
318
+ augment_list.append(K.RandomGaussianNoise(mean=0.0, std=1., p=0.5))
319
+ elif item == 'Pe':
320
+ augment_list.append(K.RandomPerspective(distortion_scale=0.7, p=0.7))
321
+ elif item == 'Ro':
322
+ augment_list.append(K.RandomRotation(degrees=15, p=0.7))
323
+ elif item == 'Af':
324
+ augment_list.append(K.RandomAffine(degrees=15, translate=0.1, shear=5, p=0.7, padding_mode='zeros', keepdim=True)) # border, reflection, zeros
325
+ elif item == 'Et':
326
+ augment_list.append(K.RandomElasticTransform(p=0.7))
327
+ elif item == 'Ts':
328
+ augment_list.append(K.RandomThinPlateSpline(scale=0.8, same_on_batch=True, p=0.7))
329
+ elif item == 'Cr':
330
+ augment_list.append(K.RandomCrop(size=(self.cut_size,self.cut_size), pad_if_needed=True, padding_mode='reflect', p=0.5))
331
+ elif item == 'Er':
332
+ augment_list.append(K.RandomErasing(scale=(.1, .4), ratio=(.3, 1/.3), same_on_batch=True, p=0.7))
333
+ elif item == 'Re':
334
+ augment_list.append(K.RandomResizedCrop(size=(self.cut_size,self.cut_size), scale=(0.1,1), ratio=(0.75,1.333), cropping_mode='resample', p=0.5))
335
+
336
+ self.augs = nn.Sequential(*augment_list)
337
+ self.noise_fac = 0.1
338
+ # self.noise_fac = False
339
+
340
+ # Uncomment if you like seeing the list ;)
341
+ # print(augment_list)
342
+
343
+ # Pooling
344
+ self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
345
+ self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
346
+
347
+ def forward(self, input):
348
+ cutouts = []
349
+
350
+ for _ in range(self.cutn):
351
+ # Use Pooling
352
+ cutout = (self.av_pool(input) + self.max_pool(input))/2
353
+ cutouts.append(cutout)
354
+
355
+ batch = self.augs(torch.cat(cutouts, dim=0))
356
+
357
+ if self.noise_fac:
358
+ facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
359
+ batch = batch + facs * torch.randn_like(batch)
360
+ return batch
361
+
362
+
363
+ # An updated version with Kornia augments and pooling (where my version started):
364
+ class MakeCutoutsPoolingUpdate(nn.Module):
365
+ def __init__(self, cut_size, cutn, cut_pow=1.):
366
+ super().__init__()
367
+ self.cut_size = cut_size
368
+ self.cutn = cutn
369
+ self.cut_pow = cut_pow # Not used with pooling
370
+
371
+ self.augs = nn.Sequential(
372
+ K.RandomAffine(degrees=15, translate=0.1, p=0.7, padding_mode='border'),
373
+ K.RandomPerspective(0.7,p=0.7),
374
+ K.ColorJitter(hue=0.1, saturation=0.1, p=0.7),
375
+ K.RandomErasing((.1, .4), (.3, 1/.3), same_on_batch=True, p=0.7),
376
+ )
377
+
378
+ self.noise_fac = 0.1
379
+ self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
380
+ self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
381
+
382
+ def forward(self, input):
383
+ sideY, sideX = input.shape[2:4]
384
+ max_size = min(sideX, sideY)
385
+ min_size = min(sideX, sideY, self.cut_size)
386
+ cutouts = []
387
+
388
+ for _ in range(self.cutn):
389
+ cutout = (self.av_pool(input) + self.max_pool(input))/2
390
+ cutouts.append(cutout)
391
+
392
+ batch = self.augs(torch.cat(cutouts, dim=0))
393
+
394
+ if self.noise_fac:
395
+ facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
396
+ batch = batch + facs * torch.randn_like(batch)
397
+ return batch
398
+
399
+
400
+ # A Nerdy updated version with selectable Kornia augments, but no pooling:
401
+ class MakeCutoutsNRUpdate(nn.Module):
402
+ def __init__(self, cut_size, cutn, cut_pow=1.):
403
+ super().__init__()
404
+ self.cut_size = cut_size
405
+ self.cutn = cutn
406
+ self.cut_pow = cut_pow
407
+ self.noise_fac = 0.1
408
+
409
+ # Pick your own augments & their order
410
+ augment_list = []
411
+ for item in args.augments[0]:
412
+ if item == 'Ji':
413
+ augment_list.append(K.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=0.7))
414
+ elif item == 'Sh':
415
+ augment_list.append(K.RandomSharpness(sharpness=0.3, p=0.5))
416
+ elif item == 'Gn':
417
+ augment_list.append(K.RandomGaussianNoise(mean=0.0, std=1., p=0.5))
418
+ elif item == 'Pe':
419
+ augment_list.append(K.RandomPerspective(distortion_scale=0.5, p=0.7))
420
+ elif item == 'Ro':
421
+ augment_list.append(K.RandomRotation(degrees=15, p=0.7))
422
+ elif item == 'Af':
423
+ augment_list.append(K.RandomAffine(degrees=30, translate=0.1, shear=5, p=0.7, padding_mode='zeros', keepdim=True)) # border, reflection, zeros
424
+ elif item == 'Et':
425
+ augment_list.append(K.RandomElasticTransform(p=0.7))
426
+ elif item == 'Ts':
427
+ augment_list.append(K.RandomThinPlateSpline(scale=0.8, same_on_batch=True, p=0.7))
428
+ elif item == 'Cr':
429
+ augment_list.append(K.RandomCrop(size=(self.cut_size,self.cut_size), pad_if_needed=True, padding_mode='reflect', p=0.5))
430
+ elif item == 'Er':
431
+ augment_list.append(K.RandomErasing(scale=(.1, .4), ratio=(.3, 1/.3), same_on_batch=True, p=0.7))
432
+ elif item == 'Re':
433
+ augment_list.append(K.RandomResizedCrop(size=(self.cut_size,self.cut_size), scale=(0.1,1), ratio=(0.75,1.333), cropping_mode='resample', p=0.5))
434
+
435
+ self.augs = nn.Sequential(*augment_list)
436
+
437
+
438
+ def forward(self, input):
439
+ sideY, sideX = input.shape[2:4]
440
+ max_size = min(sideX, sideY)
441
+ min_size = min(sideX, sideY, self.cut_size)
442
+ cutouts = []
443
+ for _ in range(self.cutn):
444
+ size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
445
+ offsetx = torch.randint(0, sideX - size + 1, ())
446
+ offsety = torch.randint(0, sideY - size + 1, ())
447
+ cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
448
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
449
+ batch = self.augs(torch.cat(cutouts, dim=0))
450
+ if self.noise_fac:
451
+ facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
452
+ batch = batch + facs * torch.randn_like(batch)
453
+ return batch
454
+
455
+
456
+ # An updated version with Kornia augments, but no pooling:
457
+ class MakeCutoutsUpdate(nn.Module):
458
+ def __init__(self, cut_size, cutn, cut_pow=1.):
459
+ super().__init__()
460
+ self.cut_size = cut_size
461
+ self.cutn = cutn
462
+ self.cut_pow = cut_pow
463
+ self.augs = nn.Sequential(
464
+ K.RandomHorizontalFlip(p=0.5),
465
+ K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
466
+ # K.RandomSolarize(0.01, 0.01, p=0.7),
467
+ K.RandomSharpness(0.3,p=0.4),
468
+ K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'),
469
+ K.RandomPerspective(0.2,p=0.4),)
470
+ self.noise_fac = 0.1
471
+
472
+
473
+ def forward(self, input):
474
+ sideY, sideX = input.shape[2:4]
475
+ max_size = min(sideX, sideY)
476
+ min_size = min(sideX, sideY, self.cut_size)
477
+ cutouts = []
478
+ for _ in range(self.cutn):
479
+ size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
480
+ offsetx = torch.randint(0, sideX - size + 1, ())
481
+ offsety = torch.randint(0, sideY - size + 1, ())
482
+ cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
483
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
484
+ batch = self.augs(torch.cat(cutouts, dim=0))
485
+ if self.noise_fac:
486
+ facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
487
+ batch = batch + facs * torch.randn_like(batch)
488
+ return batch
489
+
490
+
491
+ # This is the original version (No pooling)
492
+ class MakeCutoutsOrig(nn.Module):
493
+ def __init__(self, cut_size, cutn, cut_pow=1.):
494
+ super().__init__()
495
+ self.cut_size = cut_size
496
+ self.cutn = cutn
497
+ self.cut_pow = cut_pow
498
+
499
+ def forward(self, input):
500
+ sideY, sideX = input.shape[2:4]
501
+ max_size = min(sideX, sideY)
502
+ min_size = min(sideX, sideY, self.cut_size)
503
+ cutouts = []
504
+ for _ in range(self.cutn):
505
+ size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
506
+ offsetx = torch.randint(0, sideX - size + 1, ())
507
+ offsety = torch.randint(0, sideY - size + 1, ())
508
+ cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
509
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
510
+ return clamp_with_grad(torch.cat(cutouts, dim=0), 0, 1)
511
+
512
+
513
+ def load_vqgan_model(config_path, checkpoint_path):
514
+ global gumbel
515
+ gumbel = False
516
+ config = OmegaConf.load(config_path)
517
+ if config.model.target == 'taming.models.vqgan.VQModel':
518
+ model = vqgan.VQModel(**config.model.params)
519
+ model.eval().requires_grad_(False)
520
+ model.init_from_ckpt(checkpoint_path)
521
+ elif config.model.target == 'taming.models.vqgan.GumbelVQ':
522
+ model = vqgan.GumbelVQ(**config.model.params)
523
+ model.eval().requires_grad_(False)
524
+ model.init_from_ckpt(checkpoint_path)
525
+ gumbel = True
526
+ elif config.model.target == 'taming.models.cond_transformer.Net2NetTransformer':
527
+ parent_model = cond_transformer.Net2NetTransformer(**config.model.params)
528
+ parent_model.eval().requires_grad_(False)
529
+ parent_model.init_from_ckpt(checkpoint_path)
530
+ model = parent_model.first_stage_model
531
+ else:
532
+ raise ValueError(f'unknown model type: {config.model.target}')
533
+ del model.loss
534
+ return model
535
+
536
+
537
+ def resize_image(image, out_size):
538
+ ratio = image.size[0] / image.size[1]
539
+ area = min(image.size[0] * image.size[1], out_size[0] * out_size[1])
540
+ size = round((area * ratio)**0.5), round((area / ratio)**0.5)
541
+ return image.resize(size, Image.LANCZOS)
542
+
543
+
544
+ # Do it
545
+ device = torch.device(args.cuda_device)
546
+ model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)
547
+ jit = True if "1.7.1" in torch.__version__ else False
548
+ perceptor = clip.load(args.clip_model, jit=jit)[0].eval().requires_grad_(False).to(device)
549
+
550
+ # clock=deepcopy(perceptor.visual.positional_embedding.data)
551
+ # perceptor.visual.positional_embedding.data = clock/clock.max()
552
+ # perceptor.visual.positional_embedding.data=clamp_with_grad(clock,0,1)
553
+
554
+ cut_size = perceptor.visual.input_resolution
555
+ f = 2**(model.decoder.num_resolutions - 1)
556
+
557
+ # Cutout class options:
558
+ # 'latest','original','updated' or 'updatedpooling'
559
+ if args.cut_method == 'latest':
560
+ make_cutouts = MakeCutouts(cut_size, args.cutn, cut_pow=args.cut_pow)
561
+ elif args.cut_method == 'original':
562
+ make_cutouts = MakeCutoutsOrig(cut_size, args.cutn, cut_pow=args.cut_pow)
563
+ elif args.cut_method == 'updated':
564
+ make_cutouts = MakeCutoutsUpdate(cut_size, args.cutn, cut_pow=args.cut_pow)
565
+ elif args.cut_method == 'nrupdated':
566
+ make_cutouts = MakeCutoutsNRUpdate(cut_size, args.cutn, cut_pow=args.cut_pow)
567
+ else:
568
+ make_cutouts = MakeCutoutsPoolingUpdate(cut_size, args.cutn, cut_pow=args.cut_pow)
569
+
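+ # The VQGAN latent grid is the image downsampled by a factor of f, so the canvas is
+ # (width // f) x (height // f) tokens and the output size is rounded down to a multiple of f.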
570
+ toksX, toksY = args.size[0] // f, args.size[1] // f
571
+ sideX, sideY = toksX * f, toksY * f
572
+
573
+ # Gumbel or not?
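+ # GumbelVQ models expose the codebook as quantize.embed (n_embed entries); standard VQ
+ # models use quantize.embedding (n_e entries), hence the two branches below.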
574
+ if gumbel:
575
+ e_dim = 256
576
+ n_toks = model.quantize.n_embed
577
+ z_min = model.quantize.embed.weight.min(dim=0).values[None, :, None, None]
578
+ z_max = model.quantize.embed.weight.max(dim=0).values[None, :, None, None]
579
+ else:
580
+ e_dim = model.quantize.e_dim
581
+ n_toks = model.quantize.n_e
582
+ z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]
583
+ z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]
584
+
585
+
586
+ if args.init_image:
587
+ if 'http' in args.init_image:
588
+ img = Image.open(urlopen(args.init_image))
589
+ else:
590
+ img = Image.open(args.init_image)
591
+ pil_image = img.convert('RGB')
592
+ pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
593
+ pil_tensor = TF.to_tensor(pil_image)
594
+ z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1)
595
+ elif args.init_noise == 'pixels':
596
+ img = random_noise_image(args.size[0], args.size[1])
597
+ pil_image = img.convert('RGB')
598
+ pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
599
+ pil_tensor = TF.to_tensor(pil_image)
600
+ z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1)
601
+ elif args.init_noise == 'gradient':
602
+ img = random_gradient_image(args.size[0], args.size[1])
603
+ pil_image = img.convert('RGB')
604
+ pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
605
+ pil_tensor = TF.to_tensor(pil_image)
606
+ z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1)
607
+ else:
608
+ one_hot = F.one_hot(torch.randint(n_toks, [toksY * toksX], device=device), n_toks).float()
609
+ # z = one_hot @ model.quantize.embedding.weight
610
+ if gumbel:
611
+ z = one_hot @ model.quantize.embed.weight
612
+ else:
613
+ z = one_hot @ model.quantize.embedding.weight
614
+
615
+ z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
616
+ #z = torch.rand_like(z)*2 # NR: check
617
+
618
+ z_orig = z.clone()
619
+ z.requires_grad_(True)
620
+
621
+ pMs = []
622
+ normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
623
+ std=[0.26862954, 0.26130258, 0.27577711])
624
+
625
+ # From imagenet - Which is better?
626
+ #normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
627
+ # std=[0.229, 0.224, 0.225])
628
+
629
+ # CLIP tokenize/encode
630
+ if args.prompts:
631
+ for prompt in args.prompts:
632
+ txt, weight, stop = split_prompt(prompt)
633
+ embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
634
+ pMs.append(Prompt(embed, weight, stop).to(device))
635
+
636
+ for prompt in args.image_prompts:
637
+ path, weight, stop = split_prompt(prompt)
638
+ img = Image.open(path)
639
+ pil_image = img.convert('RGB')
640
+ img = resize_image(pil_image, (sideX, sideY))
641
+ batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
642
+ embed = perceptor.encode_image(normalize(batch)).float()
643
+ pMs.append(Prompt(embed, weight, stop).to(device))
644
+
645
+ for seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):
646
+ gen = torch.Generator().manual_seed(seed)
647
+ embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)
648
+ pMs.append(Prompt(embed, weight).to(device))
649
+
650
+
651
+ # Set the optimiser
652
+ def get_opt(opt_name, opt_lr):
653
+ if opt_name == "Adam":
654
+ opt = optim.Adam([z], lr=opt_lr) # LR=0.1 (Default)
655
+ elif opt_name == "AdamW":
656
+ opt = optim.AdamW([z], lr=opt_lr)
657
+ elif opt_name == "Adagrad":
658
+ opt = optim.Adagrad([z], lr=opt_lr)
659
+ elif opt_name == "Adamax":
660
+ opt = optim.Adamax([z], lr=opt_lr)
661
+ elif opt_name == "DiffGrad":
662
+ opt = DiffGrad([z], lr=opt_lr, eps=1e-9, weight_decay=1e-9) # NR: Playing for reasons
663
+ elif opt_name == "AdamP":
664
+ opt = AdamP([z], lr=opt_lr)
665
+ elif opt_name == "RAdam":
666
+ opt = optim.RAdam([z], lr=opt_lr)
667
+ elif opt_name == "RMSprop":
668
+ opt = optim.RMSprop([z], lr=opt_lr)
669
+ else:
670
+ print("Unknown optimiser. Are choices broken?")
671
+ opt = optim.Adam([z], lr=opt_lr)
672
+ return opt
673
+
674
+ opt = get_opt(args.optimiser, args.step_size)
675
+
676
+
677
+ # Output for the user
678
+ print('Using device:', device)
679
+ print('Optimising using:', args.optimiser)
680
+
681
+ if args.prompts:
682
+ print('Using text prompts:', args.prompts)
683
+ if args.image_prompts:
684
+ print('Using image prompts:', args.image_prompts)
685
+ if args.init_image:
686
+ print('Using initial image:', args.init_image)
687
+ if args.noise_prompt_weights:
688
+ print('Noise prompt weights:', args.noise_prompt_weights)
689
+
690
+
691
+ if args.seed is None:
692
+ seed = torch.seed()
693
+ else:
694
+ seed = args.seed
695
+ torch.manual_seed(seed)
696
+ print('Using seed:', seed)
697
+
698
+
699
+ # Vector quantize
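+ # Decode the latent z into an RGB image: quantize against the codebook, run the VQGAN
+ # decoder and rescale its [-1, 1] output to [0, 1].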
700
+ def synth(z):
701
+ if gumbel:
702
+ z_q = vector_quantize(z.movedim(1, 3), model.quantize.embed.weight).movedim(3, 1)
703
+ else:
704
+ z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim(3, 1)
705
+ return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1)
706
+
707
+
708
+ #@torch.no_grad()
709
+ @torch.inference_mode()
710
+ def checkin(i, losses):
711
+ losses_str = ', '.join(f'{loss.item():g}' for loss in losses)
712
+ tqdm.write(f'i: {i}, loss: {sum(losses).item():g}, losses: {losses_str}')
713
+ out = synth(z)
714
+ info = PngImagePlugin.PngInfo()
715
+ info.add_text('comment', f'{args.prompts}')
716
+ TF.to_pil_image(out[0].cpu()).save(args.output, pnginfo=info)
717
+
718
+
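+ # Build the per-iteration loss list: an optional init_weight regularisation term on z
+ # (decaying as i grows), plus one CLIP loss per prompt; optionally save a video frame.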
719
+ def ascend_txt():
720
+ global i
721
+ out = synth(z)
722
+ iii = perceptor.encode_image(normalize(make_cutouts(out))).float()
723
+
724
+ result = []
725
+
726
+ if args.init_weight:
727
+ # result.append(F.mse_loss(z, z_orig) * args.init_weight / 2)
728
+ result.append(F.mse_loss(z, torch.zeros_like(z_orig)) * ((1/torch.tensor(i*2 + 1))*args.init_weight) / 2)
729
+
730
+ for prompt in pMs:
731
+ result.append(prompt(iii))
732
+
733
+ if args.make_video:
734
+ img = np.array(out.mul(255).clamp(0, 255)[0].cpu().detach().numpy().astype(np.uint8))[:,:,:]
735
+ img = np.transpose(img, (1, 2, 0))
736
+ imageio.imwrite('./steps/' + str(i) + '.png', np.array(img))
737
+
738
+ return result # return loss
739
+
740
+
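+ # One optimisation step: compute the losses, periodically checkin(), backpropagate, step
+ # the optimiser, then clamp z back into the per-channel range of the codebook embeddings.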
741
+ def train(i):
742
+ opt.zero_grad(set_to_none=True)
743
+ lossAll = ascend_txt()
744
+
745
+ if i % args.display_freq == 0:
746
+ checkin(i, lossAll)
747
+
748
+ loss = sum(lossAll)
749
+ loss.backward()
750
+ opt.step()
751
+
752
+ #with torch.no_grad():
753
+ with torch.inference_mode():
754
+ z.copy_(z.maximum(z_min).minimum(z_max))
755
+
756
+
757
+
758
+ i = 0 # Iteration counter
759
+ j = 0 # Zoom video frame counter
760
+ p = 1 # Phrase counter
761
+ smoother = 0 # Smoother counter
762
+ this_video_frame = 0 # for video styling
763
+
764
+ # Messing with learning rate / optimisers
765
+ #variable_lr = args.step_size
766
+ #optimiser_list = [['Adam',0.075],['AdamW',0.125],['Adagrad',0.2],['Adamax',0.125],['DiffGrad',0.075],['RAdam',0.125],['RMSprop',0.02]]
767
+
768
+ # Do it
769
+ try:
770
+ with tqdm() as pbar:
771
+ while True:
772
+ # Change generated image
773
+ if args.make_zoom_video:
774
+ if i % args.zoom_frequency == 0:
775
+ out = synth(z)
776
+
777
+ # Save image
778
+ img = np.array(out.mul(255).clamp(0, 255)[0].cpu().detach().numpy().astype(np.uint8))[:,:,:]
779
+ img = np.transpose(img, (1, 2, 0))
780
+ imageio.imwrite('./steps/' + str(j) + '.png', np.array(img))
781
+
782
+ # Time to start zooming?
783
+ if args.zoom_start <= i:
784
+ # Convert z back into a Pil image
785
+ #pil_image = TF.to_pil_image(out[0].cpu())
786
+
787
+ # Convert NP to Pil image
788
+ pil_image = Image.fromarray(np.array(img).astype('uint8'), 'RGB')
789
+
790
+ # Zoom
791
+ if args.zoom_scale != 1:
792
+ pil_image_zoom = zoom_at(pil_image, sideX/2, sideY/2, args.zoom_scale)
793
+ else:
794
+ pil_image_zoom = pil_image
795
+
796
+ # Shift - https://pillow.readthedocs.io/en/latest/reference/ImageChops.html
797
+ if args.zoom_shift_x or args.zoom_shift_y:
798
+ # This one wraps the image
799
+ pil_image_zoom = ImageChops.offset(pil_image_zoom, args.zoom_shift_x, args.zoom_shift_y)
800
+
801
+ # Convert image back to a tensor again
802
+ pil_tensor = TF.to_tensor(pil_image_zoom)
803
+
804
+ # Re-encode
805
+ z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1)
806
+ z_orig = z.clone()
807
+ z.requires_grad_(True)
808
+
809
+ # Re-create optimiser
810
+ opt = get_opt(args.optimiser, args.step_size)
811
+
812
+ # Next
813
+ j += 1
814
+
815
+ # Change text prompt
816
+ if args.prompt_frequency > 0:
817
+ if i % args.prompt_frequency == 0 and i > 0:
818
+ # In case there aren't enough phrases, just loop
819
+ if p >= len(all_phrases):
820
+ p = 0
821
+
822
+ pMs = []
823
+ args.prompts = all_phrases[p]
824
+
825
+ # Show user we're changing prompt
826
+ print(args.prompts)
827
+
828
+ for prompt in args.prompts:
829
+ txt, weight, stop = split_prompt(prompt)
830
+ embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
831
+ pMs.append(Prompt(embed, weight, stop).to(device))
832
+
833
+ '''
834
+ # Smooth test
835
+ smoother = args.zoom_frequency * 15 # smoothing over x frames
836
+ variable_lr = args.step_size * 0.25
837
+ opt = get_opt(args.optimiser, variable_lr)
838
+ '''
839
+
840
+ p += 1
841
+
842
+ '''
843
+ if smoother > 0:
844
+ if smoother == 1:
845
+ opt = get_opt(args.optimiser, args.step_size)
846
+ smoother -= 1
847
+ '''
848
+
849
+ '''
850
+ # Messing with learning rate / optimisers
851
+ if i % 225 == 0 and i > 0:
852
+ variable_optimiser_item = random.choice(optimiser_list)
853
+ variable_optimiser = variable_optimiser_item[0]
854
+ variable_lr = variable_optimiser_item[1]
855
+
856
+ opt = get_opt(variable_optimiser, variable_lr)
857
+ print("New opt: %s, lr= %f" %(variable_optimiser,variable_lr))
858
+ '''
859
+
860
+
861
+ # Training time
862
+ train(i)
863
+
864
+ # Ready to stop yet?
865
+ if i == args.max_iterations:
866
+ if not args.video_style_dir:
867
+ # we're done
868
+ break
869
+ else:
870
+ if this_video_frame == (num_video_frames - 1):
871
+ # we're done
872
+ make_styled_video = True
873
+ break
874
+ else:
875
+ # Next video frame
876
+ this_video_frame += 1
877
+
878
+ # Reset the iteration count
879
+ i = -1
880
+ pbar.reset()
881
+
882
+ # Load the next frame, reset a few options - same filename, different directory
883
+ args.init_image = video_frame_list[this_video_frame]
884
+ print("Next frame: ", args.init_image)
885
+
886
+ if args.seed is None:
887
+ seed = torch.seed()
888
+ else:
889
+ seed = args.seed
890
+ torch.manual_seed(seed)
891
+ print("Seed: ", seed)
892
+
893
+ filename = os.path.basename(args.init_image)
894
+ args.output = os.path.join(cwd, "steps", filename)
895
+
896
+ # Load and resize image
897
+ img = Image.open(args.init_image)
898
+ pil_image = img.convert('RGB')
899
+ pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
900
+ pil_tensor = TF.to_tensor(pil_image)
901
+
902
+ # Re-encode
903
+ z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1)
904
+ z_orig = z.clone()
905
+ z.requires_grad_(True)
906
+
907
+ # Re-create optimiser
908
+ opt = get_opt(args.optimiser, args.step_size)
909
+
910
+ i += 1
911
+ pbar.update()
912
+ except KeyboardInterrupt:
913
+ pass
914
+
915
+ # All done :)
916
+
917
+ # Video generation
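+ # Frames saved in ./steps are piped into ffmpeg as PNGs. With -ofps > 9 an interpolated
+ # video is encoded on the GPU (h264_nvenc + minterpolate); otherwise libx264 on the CPU.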
918
+ if args.make_video or args.make_zoom_video:
919
+ init_frame = 1 # Initial video frame
920
+ if args.make_zoom_video:
921
+ last_frame = j
922
+ else:
923
+ last_frame = i # This will raise an error if that number of frames does not exist.
924
+
925
+ length = args.video_length # Desired time of the video in seconds
926
+
927
+ min_fps = 10
928
+ max_fps = 60
929
+
930
+ total_frames = last_frame-init_frame
931
+
932
+ frames = []
933
+ tqdm.write('Generating video...')
934
+ for i in range(init_frame,last_frame):
935
+ temp = Image.open("./steps/"+ str(i) +'.png')
936
+ keep = temp.copy()
937
+ frames.append(keep)
938
+ temp.close()
939
+
940
+ if args.output_video_fps > 9:
941
+ # Hardware encoding and video frame interpolation
942
+ print("Creating interpolated frames...")
943
+ ffmpeg_filter = f"minterpolate='mi_mode=mci:me=hexbs:me_mode=bidir:mc_mode=aobmc:vsbmc=1:mb_size=8:search_param=32:fps={args.output_video_fps}'"
944
+ output_file = re.compile(r'\.png$').sub('.mp4', args.output)
945
+ try:
946
+ p = Popen(['ffmpeg',
947
+ '-y',
948
+ '-f', 'image2pipe',
949
+ '-vcodec', 'png',
950
+ '-r', str(args.input_video_fps),
951
+ '-i',
952
+ '-',
953
+ '-b:v', '10M',
954
+ '-vcodec', 'h264_nvenc',
955
+ '-pix_fmt', 'yuv420p',
956
+ '-strict', '-2',
957
+ '-filter:v', f'{ffmpeg_filter}',
958
+ '-metadata', f'comment={args.prompts}',
959
+ output_file], stdin=PIPE)
960
+ except FileNotFoundError:
961
+ print("ffmpeg command failed - check your installation")
962
+ for im in tqdm(frames):
963
+ im.save(p.stdin, 'PNG')
964
+ p.stdin.close()
965
+ p.wait()
966
+ else:
967
+ # CPU
968
+ fps = np.clip(total_frames/length,min_fps,max_fps)
969
+ output_file = re.compile(r'\.png$').sub('.mp4', args.output)
970
+ try:
971
+ p = Popen(['ffmpeg',
972
+ '-y',
973
+ '-f', 'image2pipe',
974
+ '-vcodec', 'png',
975
+ '-r', str(fps),
976
+ '-i',
977
+ '-',
978
+ '-vcodec', 'libx264',
979
+ '-r', str(fps),
980
+ '-pix_fmt', 'yuv420p',
981
+ '-crf', '17',
982
+ '-preset', 'veryslow',
983
+ '-metadata', f'comment={args.prompts}',
984
+ output_file], stdin=PIPE)
985
+ except FileNotFoundError:
986
+ print("ffmpeg command failed - check your installation")
987
+ for im in tqdm(frames):
988
+ im.save(p.stdin, 'PNG')
989
+ p.stdin.close()
990
+ p.wait()