apolinario committed
Commit
077fc3d
1 Parent(s): 51df617

try vqgan on this space

Files changed (2)
  1. app.py +2352 -7
  2. requirements.txt +29 -1
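
The commit replaces the placeholder greet demo in app.py with a VQGAN+CLIP image-generation pipeline exposed through run_all(user_input, num_steps, flavor, template, width, height), which renders the prompt and saves the result to a PNG. The visible portion of the diff below is truncated before the new Gradio wiring; a minimal sketch of how run_all might be hooked up, assuming the signature above and a PIL image output (the widget names and choices here are illustrative, not taken from the commit):

import gradio as gr

# Hypothetical wiring for the new entry point; the actual Interface
# definition lies outside the visible part of this diff.
iface = gr.Interface(
    fn=run_all,
    inputs=[
        gr.inputs.Textbox(label="Prompt"),                        # user_input
        gr.inputs.Slider(minimum=1, maximum=500, label="Steps"),  # num_steps
        gr.inputs.Dropdown(["cumin", "holywater", "ginger", "zynth", "wyvern"], label="Flavor"),
        gr.inputs.Textbox(label="Template"),
        gr.inputs.Slider(minimum=64, maximum=512, label="Width"),
        gr.inputs.Slider(minimum=64, maximum=512, label="Height"),
    ],
    outputs=gr.outputs.Image(type="pil"),
)
iface.launch()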
app.py CHANGED
@@ -1,11 +1,2356 @@
  import gradio as gr
- import torch

- is_cuda = torch.cuda.is_available()
- def greet(name):
-     if is_cuda:
-         return "Hello cuda" + name + "!!"
      else:
-         return "Hello ooops" + name + "!!"
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
  iface.launch()
 
+ import sys
+ import argparse
+ import math
+ from pathlib import Path
+ import sys
+ import pandas as pd
+ from base64 import b64encode
+ from omegaconf import OmegaConf
+ from PIL import Image
+ from taming.models import cond_transformer, vqgan
+ import torch
+ from os.path import exists as path_exists
+
+ torch.cuda.empty_cache()
+ from torch import nn
+ import torch.optim as optim
+ from torch import optim
+ from torch.nn import functional as F
+ from torchvision import transforms
+ from torchvision.transforms import functional as TF
+ import torchvision.transforms as T
+ from git.repo.base import Repo
+
+ if not (path_exists(f"CLIP")):
+     Repo.clone_from("https://github.com/openai/CLIP", "CLIP")
+
+ from CLIP import clip
  import gradio as gr
+ import kornia.augmentation as K
+ import numpy as np
+ import subprocess
+ import imageio
+ from PIL import ImageFile, Image
+ import time
+ import base64
+
+ import hashlib
+ from PIL.PngImagePlugin import PngImageFile, PngInfo
+ import json
+ import urllib.request
+ from random import randint
+ from pathvalidate import sanitize_filename
+ from huggingface_hub import hf_hub_download
+ import shortuuid
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ print("Using device:", device)
+
+ vqgan_model = hf_hub_download(repo_id="boris/vqgan_f16_16384", filename="model.ckpt")
+ vqgan_config = hf_hub_download(repo_id="boris/vqgan_f16_16384", filename="config.yaml")
+
+ def load_vqgan_model(config_path, checkpoint_path):
+     config = OmegaConf.load(config_path)
+     if config.model.target == "taming.models.vqgan.VQModel":
+         model = vqgan.VQModel(**config.model.params)
+         model.eval().requires_grad_(False)
+         model.init_from_ckpt(checkpoint_path)
+     elif config.model.target == "taming.models.cond_transformer.Net2NetTransformer":
+         parent_model = cond_transformer.Net2NetTransformer(**config.model.params)
+         parent_model.eval().requires_grad_(False)
+         parent_model.init_from_ckpt(checkpoint_path)
+         model = parent_model.first_stage_model
+     elif config.model.target == "taming.models.vqgan.GumbelVQ":
+         model = vqgan.GumbelVQ(**config.model.params)
+         # print(config.model.params)
+         model.eval().requires_grad_(False)
+         model.init_from_ckpt(checkpoint_path)
+     else:
+         raise ValueError(f"unknown model type: {config.model.target}")
+     del model.loss
+     return model
+ model = load_vqgan_model(vqgan_config, vqgan_model).to(device)
+ perceptor = (
+     clip.load("ViT-B/32", jit=False)[0]
+     .eval()
+     .requires_grad_(False)
+     .to(device)
+ )
+ def run_all(user_input, num_steps, flavor, template, width, height):
+     import random
+     #if uploaded_file is not None:
+     #uploaded_folder = f"{DefaultPaths.root_path}/uploaded"
+     #if not path_exists(uploaded_folder):
+     #    os.makedirs(uploaded_folder)
+     #image_data = uploaded_file.read()
+     #f = open(f"{uploaded_folder}/{uploaded_file.name}", "wb")
+     #f.write(image_data)
+     #f.close()
+     #image_path = f"{uploaded_folder}/{uploaded_file.name}"
+     #pass
+     #else:
+     image_path = None
+     url = shortuuid.uuid()
+     args2 = argparse.Namespace(
+         prompt=user_input,
+         seed=int(random.randint(0, 2147483647)),
+         sizex=width,
+         sizey=height,
+         flavor=flavor,
+         iterations=num_steps,
+         mse=True,
+         update=100,
+         template=template,
+         vqgan_model='ImageNet 16384',
+         seed_image=image_path,
+         image_file=f"{url}.png",
+         #frame_dir=intermediary_folder,
+     )
+     if args2.seed is not None:
+         import torch
+
+         import numpy as np
+
+         np.random.seed(args2.seed)
+         import random
+
+         random.seed(args2.seed)
+         # next line forces deterministic random values, but causes other issues with resampling (uncomment to see)
+         torch.manual_seed(args2.seed)
+         torch.cuda.manual_seed(args2.seed)
+         torch.cuda.manual_seed_all(args2.seed)
+         torch.backends.cudnn.deterministic = True
+         torch.backends.cudnn.benchmark = False
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     print("Using device:", device)
+
+     def noise_gen(shape, octaves=5):
+         n, c, h, w = shape
+         noise = torch.zeros([n, c, 1, 1])
+         max_octaves = min(octaves, math.log(h) / math.log(2), math.log(w) / math.log(2))
+         for i in reversed(range(max_octaves)):
+             h_cur, w_cur = h // 2**i, w // 2**i
+             noise = F.interpolate(
+                 noise, (h_cur, w_cur), mode="bicubic", align_corners=False
+             )
+             noise += torch.randn([n, c, h_cur, w_cur]) / 5
+         return noise
+
+     def sinc(x):
+         return torch.where(
+             x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([])
+         )
+
+     def lanczos(x, a):
+         cond = torch.logical_and(-a < x, x < a)
+         out = torch.where(cond, sinc(x) * sinc(x / a), x.new_zeros([]))
+         return out / out.sum()
+
+     def ramp(ratio, width):
+         n = math.ceil(width / ratio + 1)
+         out = torch.empty([n])
+         cur = 0
+         for i in range(out.shape[0]):
+             out[i] = cur
+             cur += ratio
+         return torch.cat([-out[1:].flip([0]), out])[1:-1]
+
+     def resample(input, size, align_corners=True):
+         n, c, h, w = input.shape
+         dh, dw = size
+
+         input = input.view([n * c, 1, h, w])
+
+         if dh < h:
+             kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
+             pad_h = (kernel_h.shape[0] - 1) // 2
+             input = F.pad(input, (0, 0, pad_h, pad_h), "reflect")
+             input = F.conv2d(input, kernel_h[None, None, :, None])
+
+         if dw < w:
+             kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
+             pad_w = (kernel_w.shape[0] - 1) // 2
+             input = F.pad(input, (pad_w, pad_w, 0, 0), "reflect")
+             input = F.conv2d(input, kernel_w[None, None, None, :])
+
+         input = input.view([n, c, h, w])
+         return F.interpolate(input, size, mode="bicubic", align_corners=align_corners)
+
+     def lerp(a, b, f):
+         return (a * (1.0 - f)) + (b * f)
+
+     class ReplaceGrad(torch.autograd.Function):
+         @staticmethod
+         def forward(ctx, x_forward, x_backward):
+             ctx.shape = x_backward.shape
+             return x_forward
+
+         @staticmethod
+         def backward(ctx, grad_in):
+             return None, grad_in.sum_to_size(ctx.shape)
+
+     replace_grad = ReplaceGrad.apply
+
+     class ClampWithGrad(torch.autograd.Function):
+         @staticmethod
+         def forward(ctx, input, min, max):
+             ctx.min = min
+             ctx.max = max
+             ctx.save_for_backward(input)
+             return input.clamp(min, max)
+
+         @staticmethod
+         def backward(ctx, grad_in):
+             (input,) = ctx.saved_tensors
+             return (
+                 grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0),
+                 None,
+                 None,
+             )
+
+     clamp_with_grad = ClampWithGrad.apply
+
+     def vector_quantize(x, codebook):
+         d = (
+             x.pow(2).sum(dim=-1, keepdim=True)
+             + codebook.pow(2).sum(dim=1)
+             - 2 * x @ codebook.T
+         )
+         indices = d.argmin(-1)
+         x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook
+         return replace_grad(x_q, x)
+
+     class Prompt(nn.Module):
+         def __init__(self, embed, weight=1.0, stop=float("-inf")):
+             super().__init__()
+             self.register_buffer("embed", embed)
+             self.register_buffer("weight", torch.as_tensor(weight))
+             self.register_buffer("stop", torch.as_tensor(stop))
+
+         def forward(self, input):
+             input_normed = F.normalize(input.unsqueeze(1), dim=2)
+             embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
+             dists = (
+                 input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
+             )
+             dists = dists * self.weight.sign()
+             return (
+                 self.weight.abs()
+                 * replace_grad(dists, torch.maximum(dists, self.stop)).mean()
+             )
+
+     def parse_prompt(prompt):
+         if prompt.startswith("http://") or prompt.startswith("https://"):
+             vals = prompt.rsplit(":", 1)
+             vals = [vals[0] + ":" + vals[1], *vals[2:]]
+         else:
+             vals = prompt.rsplit(":", 1)
+         vals = vals + ["", "1", "-inf"][len(vals) :]
+         return vals[0], float(vals[1]), float(vals[2])
+
+     def one_sided_clip_loss(input, target, labels=None, logit_scale=100):
+         input_normed = F.normalize(input, dim=-1)
+         target_normed = F.normalize(target, dim=-1)
+         logits = input_normed @ target_normed.T * logit_scale
+         if labels is None:
+             labels = torch.arange(len(input), device=logits.device)
+         return F.cross_entropy(logits, labels)
+
+     class EMATensor(nn.Module):
+         """implemented by Katherine Crowson"""
+
+         def __init__(self, tensor, decay):
+             super().__init__()
+             self.tensor = nn.Parameter(tensor)
+             self.register_buffer("biased", torch.zeros_like(tensor))
+             self.register_buffer("average", torch.zeros_like(tensor))
+             self.decay = decay
+             self.register_buffer("accum", torch.tensor(1.0))
+             self.update()
+
+         @torch.no_grad()
+         def update(self):
+             if not self.training:
+                 raise RuntimeError("update() should only be called during training")
+
+             self.accum *= self.decay
+             self.biased.mul_(self.decay)
+             self.biased.add_((1 - self.decay) * self.tensor)
+             self.average.copy_(self.biased)
+             self.average.div_(1 - self.accum)
+
+         def forward(self):
+             if self.training:
+                 return self.tensor
+             return self.average
286
+
287
+ class MakeCutoutsCustom(nn.Module):
288
+ def __init__(self, cut_size, cutn, cut_pow, augs):
289
+ super().__init__()
290
+ self.cut_size = cut_size
291
+ # tqdm.write(f"cut size: {self.cut_size}")
292
+ self.cutn = cutn
293
+ self.cut_pow = cut_pow
294
+ self.noise_fac = 0.1
295
+ self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
296
+ self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
297
+ self.augs = nn.Sequential(
298
+ K.RandomHorizontalFlip(p=Random_Horizontal_Flip),
299
+ K.RandomSharpness(Random_Sharpness, p=Random_Sharpness_P),
300
+ K.RandomGaussianBlur(
301
+ (Random_Gaussian_Blur),
302
+ (Random_Gaussian_Blur_W, Random_Gaussian_Blur_W),
303
+ p=Random_Gaussian_Blur_P,
304
+ ),
305
+ K.RandomGaussianNoise(p=Random_Gaussian_Noise_P),
306
+ K.RandomElasticTransform(
307
+ kernel_size=(
308
+ Random_Elastic_Transform_Kernel_Size_W,
309
+ Random_Elastic_Transform_Kernel_Size_H,
310
+ ),
311
+ sigma=(Random_Elastic_Transform_Sigma),
312
+ p=Random_Elastic_Transform_P,
313
+ ),
314
+ K.RandomAffine(
315
+ degrees=Random_Affine_Degrees,
316
+ translate=Random_Affine_Translate,
317
+ p=Random_Affine_P,
318
+ padding_mode="border",
319
+ ),
320
+ K.RandomPerspective(Random_Perspective, p=Random_Perspective_P),
321
+ K.ColorJitter(
322
+ hue=Color_Jitter_Hue,
323
+ saturation=Color_Jitter_Saturation,
324
+ p=Color_Jitter_P,
325
+ ),
326
+ )
327
+ # K.RandomErasing((0.1, 0.7), (0.3, 1/0.4), same_on_batch=True, p=0.2),)
328
+
329
+ def set_cut_pow(self, cut_pow):
330
+ self.cut_pow = cut_pow
331
+
332
+ def forward(self, input):
333
+ sideY, sideX = input.shape[2:4]
334
+ max_size = min(sideX, sideY)
335
+ min_size = min(sideX, sideY, self.cut_size)
336
+ cutouts = []
337
+ cutouts_full = []
338
+ noise_fac = 0.1
339
+
340
+ min_size_width = min(sideX, sideY)
341
+ lower_bound = float(self.cut_size / min_size_width)
342
+
343
+ for ii in range(self.cutn):
344
+
345
+ # size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
346
+ randsize = (
347
+ torch.zeros(
348
+ 1,
349
+ )
350
+ .normal_(mean=0.8, std=0.3)
351
+ .clip(lower_bound, 1.0)
352
+ )
353
+ size_mult = randsize**self.cut_pow
354
+ size = int(
355
+ min_size_width * (size_mult.clip(lower_bound, 1.0))
356
+ ) # replace .5 with a result for 224 the default large size is .95
357
+ # size = int(min_size_width*torch.zeros(1,).normal_(mean=.9, std=.3).clip(lower_bound, .95)) # replace .5 with a result for 224 the default large size is .95
358
+
359
+ offsetx = torch.randint(0, sideX - size + 1, ())
360
+ offsety = torch.randint(0, sideY - size + 1, ())
361
+ cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
362
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
363
+
364
+ cutouts = torch.cat(cutouts, dim=0)
365
+ cutouts = clamp_with_grad(cutouts, 0, 1)
366
+
367
+ # if args.use_augs:
368
+ cutouts = self.augs(cutouts)
369
+ if self.noise_fac:
370
+ facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
371
+ 0, self.noise_fac
372
+ )
373
+ cutouts = cutouts + facs * torch.randn_like(cutouts)
374
+ return cutouts
375
+
376
+ class MakeCutoutsJuu(nn.Module):
377
+ def __init__(self, cut_size, cutn, cut_pow, augs):
378
+ super().__init__()
379
+ self.cut_size = cut_size
380
+ self.cutn = cutn
381
+ self.cut_pow = cut_pow
382
+ self.augs = nn.Sequential(
383
+ # K.RandomGaussianNoise(mean=0.0, std=0.5, p=0.1),
384
+ K.RandomHorizontalFlip(p=0.5),
385
+ K.RandomSharpness(0.3, p=0.4),
386
+ K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"),
387
+ K.RandomPerspective(0.2, p=0.4),
388
+ K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
389
+ K.RandomGrayscale(p=0.1),
390
+ )
391
+ self.noise_fac = 0.1
392
+
393
+ def forward(self, input):
394
+ sideY, sideX = input.shape[2:4]
395
+ max_size = min(sideX, sideY)
396
+ min_size = min(sideX, sideY, self.cut_size)
397
+ cutouts = []
398
+ for _ in range(self.cutn):
399
+ size = int(
400
+ torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size
401
+ )
402
+ offsetx = torch.randint(0, sideX - size + 1, ())
403
+ offsety = torch.randint(0, sideY - size + 1, ())
404
+ cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
405
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
406
+ batch = self.augs(torch.cat(cutouts, dim=0))
407
+ if self.noise_fac:
408
+ facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
409
+ batch = batch + facs * torch.randn_like(batch)
410
+ return batch
411
+
412
+ class MakeCutoutsMoth(nn.Module):
413
+ def __init__(self, cut_size, cutn, cut_pow, augs, skip_augs=False):
414
+ super().__init__()
415
+ self.cut_size = cut_size
416
+ self.cutn = cutn
417
+ self.cut_pow = cut_pow
418
+ self.skip_augs = skip_augs
419
+ self.augs = T.Compose(
420
+ [
421
+ T.RandomHorizontalFlip(p=0.5),
422
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
423
+ T.RandomAffine(degrees=15, translate=(0.1, 0.1)),
424
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
425
+ T.RandomPerspective(distortion_scale=0.4, p=0.7),
426
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
427
+ T.RandomGrayscale(p=0.15),
428
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
429
+ # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
430
+ ]
431
+ )
432
+
433
+ def forward(self, input):
434
+ input = T.Pad(input.shape[2] // 4, fill=0)(input)
435
+ sideY, sideX = input.shape[2:4]
436
+ max_size = min(sideX, sideY)
437
+
438
+ cutouts = []
439
+ for ch in range(cutn):
440
+ if ch > cutn - cutn // 4:
441
+ cutout = input.clone()
442
+ else:
443
+ size = int(
444
+ max_size
445
+ * torch.zeros(
446
+ 1,
447
+ )
448
+ .normal_(mean=0.8, std=0.3)
449
+ .clip(float(self.cut_size / max_size), 1.0)
450
+ )
451
+ offsetx = torch.randint(0, abs(sideX - size + 1), ())
452
+ offsety = torch.randint(0, abs(sideY - size + 1), ())
453
+ cutout = input[
454
+ :, :, offsety : offsety + size, offsetx : offsetx + size
455
+ ]
456
+
457
+ if not self.skip_augs:
458
+ cutout = self.augs(cutout)
459
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
460
+ del cutout
461
+
462
+ cutouts = torch.cat(cutouts, dim=0)
463
+ return cutouts
464
+
465
+ class MakeCutoutsAaron(nn.Module):
466
+ def __init__(self, cut_size, cutn, cut_pow, augs):
467
+ super().__init__()
468
+ self.cut_size = cut_size
469
+ self.cutn = cutn
470
+ self.cut_pow = cut_pow
471
+ self.augs = augs
472
+ self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
473
+ self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
474
+
475
+ def set_cut_pow(self, cut_pow):
476
+ self.cut_pow = cut_pow
477
+
478
+ def forward(self, input):
479
+ sideY, sideX = input.shape[2:4]
480
+ max_size = min(sideX, sideY)
481
+ min_size = min(sideX, sideY, self.cut_size)
482
+ cutouts = []
483
+ cutouts_full = []
484
+
485
+ min_size_width = min(sideX, sideY)
486
+ lower_bound = float(self.cut_size / min_size_width)
487
+
488
+ for ii in range(self.cutn):
489
+ size = int(
490
+ min_size_width
491
+ * torch.zeros(
492
+ 1,
493
+ )
494
+ .normal_(mean=0.8, std=0.3)
495
+ .clip(lower_bound, 1.0)
496
+ ) # replace .5 with a result for 224 the default large size is .95
497
+
498
+ offsetx = torch.randint(0, sideX - size + 1, ())
499
+ offsety = torch.randint(0, sideY - size + 1, ())
500
+ cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
501
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
502
+
503
+ cutouts = torch.cat(cutouts, dim=0)
504
+
505
+ return clamp_with_grad(cutouts, 0, 1)
506
+
507
+ class MakeCutoutsCumin(nn.Module):
508
+ # from https://colab.research.google.com/drive/1ZAus_gn2RhTZWzOWUpPERNC0Q8OhZRTZ
509
+ def __init__(self, cut_size, cutn, cut_pow, augs):
510
+ super().__init__()
511
+ self.cut_size = cut_size
512
+ # tqdm.write(f"cut size: {self.cut_size}")
513
+ self.cutn = cutn
514
+ self.cut_pow = cut_pow
515
+ self.noise_fac = 0.1
516
+ self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
517
+ self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
518
+ self.augs = nn.Sequential(
519
+ # K.RandomHorizontalFlip(p=0.5),
520
+ # K.RandomSharpness(0.3,p=0.4),
521
+ # K.RandomGaussianBlur((3,3),(10.5,10.5),p=0.2),
522
+ # K.RandomGaussianNoise(p=0.5),
523
+ # K.RandomElasticTransform(kernel_size=(33, 33), sigma=(7,7), p=0.2),
524
+ K.RandomAffine(degrees=15, translate=0.1, p=0.7, padding_mode="border"),
525
+ K.RandomPerspective(0.7, p=0.7),
526
+ K.ColorJitter(hue=0.1, saturation=0.1, p=0.7),
527
+ K.RandomErasing((0.1, 0.4), (0.3, 1 / 0.3), same_on_batch=True, p=0.7),
528
+ )
529
+
530
+ def set_cut_pow(self, cut_pow):
531
+ self.cut_pow = cut_pow
532
+
533
+ def forward(self, input):
534
+ sideY, sideX = input.shape[2:4]
535
+ max_size = min(sideX, sideY)
536
+ min_size = min(sideX, sideY, self.cut_size)
537
+ cutouts = []
538
+ cutouts_full = []
539
+ noise_fac = 0.1
540
+
541
+ min_size_width = min(sideX, sideY)
542
+ lower_bound = float(self.cut_size / min_size_width)
543
+
544
+ for ii in range(self.cutn):
545
+
546
+ # size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
547
+ randsize = (
548
+ torch.zeros(
549
+ 1,
550
+ )
551
+ .normal_(mean=0.8, std=0.3)
552
+ .clip(lower_bound, 1.0)
553
+ )
554
+ size_mult = randsize**self.cut_pow
555
+ size = int(
556
+ min_size_width * (size_mult.clip(lower_bound, 1.0))
557
+ ) # replace .5 with a result for 224 the default large size is .95
558
+ # size = int(min_size_width*torch.zeros(1,).normal_(mean=.9, std=.3).clip(lower_bound, .95)) # replace .5 with a result for 224 the default large size is .95
559
+
560
+ offsetx = torch.randint(0, sideX - size + 1, ())
561
+ offsety = torch.randint(0, sideY - size + 1, ())
562
+ cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
563
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
564
+
565
+ cutouts = torch.cat(cutouts, dim=0)
566
+ cutouts = clamp_with_grad(cutouts, 0, 1)
567
+
568
+ # if args.use_augs:
569
+ cutouts = self.augs(cutouts)
570
+ if self.noise_fac:
571
+ facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
572
+ 0, self.noise_fac
573
+ )
574
+ cutouts = cutouts + facs * torch.randn_like(cutouts)
575
+ return cutouts
576
+
577
+ class MakeCutoutsHolywater(nn.Module):
578
+ def __init__(self, cut_size, cutn, cut_pow, augs):
579
+ super().__init__()
580
+ self.cut_size = cut_size
581
+ # tqdm.write(f"cut size: {self.cut_size}")
582
+ self.cutn = cutn
583
+ self.cut_pow = cut_pow
584
+ self.noise_fac = 0.1
585
+ self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
586
+ self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
587
+ self.augs = nn.Sequential(
588
+ # K.RandomGaussianNoise(mean=0.0, std=0.5, p=0.1),
589
+ K.RandomHorizontalFlip(p=0.5),
590
+ K.RandomSharpness(0.3, p=0.4),
591
+ K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"),
592
+ K.RandomPerspective(0.2, p=0.4),
593
+ K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
594
+ K.RandomGrayscale(p=0.1),
595
+ )
596
+
597
+ def set_cut_pow(self, cut_pow):
598
+ self.cut_pow = cut_pow
599
+
600
+ def forward(self, input):
601
+ sideY, sideX = input.shape[2:4]
602
+ max_size = min(sideX, sideY)
603
+ min_size = min(sideX, sideY, self.cut_size)
604
+ cutouts = []
605
+ cutouts_full = []
606
+ noise_fac = 0.1
607
+ min_size_width = min(sideX, sideY)
608
+ lower_bound = float(self.cut_size / min_size_width)
609
+
610
+ for ii in range(self.cutn):
611
+ size = int(
612
+ torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size
613
+ )
614
+ randsize = (
615
+ torch.zeros(
616
+ 1,
617
+ )
618
+ .normal_(mean=0.8, std=0.3)
619
+ .clip(lower_bound, 1.0)
620
+ )
621
+ size_mult = randsize**self.cut_pow * ii + size
622
+ size1 = int(
623
+ (min_size_width) * (size_mult.clip(lower_bound, 1.0))
624
+ ) # replace .5 with a result for 224 the default large size is .95
625
+ size2 = int(
626
+ (min_size_width)
627
+ * torch.zeros(
628
+ 1,
629
+ )
630
+ .normal_(mean=0.9, std=0.3)
631
+ .clip(lower_bound, 0.95)
632
+ ) # replace .5 with a result for 224 the default large size is .95
633
+ offsetx = torch.randint(0, sideX - size1 + 1, ())
634
+ offsety = torch.randint(0, sideY - size2 + 1, ())
635
+ cutout = input[
636
+ :, :, offsety : offsety + size2 + ii, offsetx : offsetx + size1 + ii
637
+ ]
638
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
639
+
640
+ cutouts = torch.cat(cutouts, dim=0)
641
+ cutouts = clamp_with_grad(cutouts, 0, 1)
642
+ cutouts = self.augs(cutouts)
643
+ facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
644
+ 0, self.noise_fac
645
+ )
646
+ cutouts = cutouts + facs * torch.randn_like(cutouts)
647
+ return cutouts
648
+
649
+ class MakeCutoutsOldHolywater(nn.Module):
650
+ def __init__(self, cut_size, cutn, cut_pow, augs):
651
+ super().__init__()
652
+ self.cut_size = cut_size
653
+ # tqdm.write(f"cut size: {self.cut_size}")
654
+ self.cutn = cutn
655
+ self.cut_pow = cut_pow
656
+ self.noise_fac = 0.1
657
+ self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
658
+ self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
659
+ self.augs = nn.Sequential(
660
+ # K.RandomHorizontalFlip(p=0.5),
661
+ # K.RandomSharpness(0.3,p=0.4),
662
+ # K.RandomGaussianBlur((3,3),(10.5,10.5),p=0.2),
663
+ # K.RandomGaussianNoise(p=0.5),
664
+ # K.RandomElasticTransform(kernel_size=(33, 33), sigma=(7,7), p=0.2),
665
+ K.RandomAffine(
666
+ degrees=180, translate=0.5, p=0.2, padding_mode="border"
667
+ ),
668
+ K.RandomPerspective(0.6, p=0.9),
669
+ K.ColorJitter(hue=0.03, saturation=0.01, p=0.1),
670
+ K.RandomErasing((0.1, 0.7), (0.3, 1 / 0.4), same_on_batch=True, p=0.2),
671
+ )
672
+
673
+ def set_cut_pow(self, cut_pow):
674
+ self.cut_pow = cut_pow
675
+
676
+ def forward(self, input):
677
+ sideY, sideX = input.shape[2:4]
678
+ max_size = min(sideX, sideY)
679
+ min_size = min(sideX, sideY, self.cut_size)
680
+ cutouts = []
681
+ cutouts_full = []
682
+ noise_fac = 0.1
683
+
684
+ min_size_width = min(sideX, sideY)
685
+ lower_bound = float(self.cut_size / min_size_width)
686
+
687
+ for ii in range(self.cutn):
688
+
689
+ # size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
690
+ randsize = (
691
+ torch.zeros(
692
+ 1,
693
+ )
694
+ .normal_(mean=0.8, std=0.3)
695
+ .clip(lower_bound, 1.0)
696
+ )
697
+ size_mult = randsize**self.cut_pow
698
+ size = int(
699
+ min_size_width * (size_mult.clip(lower_bound, 1.0))
700
+ ) # replace .5 with a result for 224 the default large size is .95
701
+ # size = int(min_size_width*torch.zeros(1,).normal_(mean=.9, std=.3).clip(lower_bound, .95)) # replace .5 with a result for 224 the default large size is .95
702
+
703
+ offsetx = torch.randint(0, sideX - size + 1, ())
704
+ offsety = torch.randint(0, sideY - size + 1, ())
705
+ cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
706
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
707
+
708
+ cutouts = torch.cat(cutouts, dim=0)
709
+ cutouts = clamp_with_grad(cutouts, 0, 1)
710
+
711
+ # if args.use_augs:
712
+ cutouts = self.augs(cutouts)
713
+ if self.noise_fac:
714
+ facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
715
+ 0, self.noise_fac
716
+ )
717
+ cutouts = cutouts + facs * torch.randn_like(cutouts)
718
+ return cutouts
719
+
720
+ class MakeCutoutsGinger(nn.Module):
721
+ def __init__(self, cut_size, cutn, cut_pow, augs):
722
+ super().__init__()
723
+ self.cut_size = cut_size
724
+ # tqdm.write(f"cut size: {self.cut_size}")
725
+ self.cutn = cutn
726
+ self.cut_pow = cut_pow
727
+ self.noise_fac = 0.1
728
+ self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
729
+ self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
730
+ self.augs = augs
731
+ """
732
+ nn.Sequential(
733
+ K.RandomHorizontalFlip(p=0.5),
734
+ K.RandomSharpness(0.3,p=0.4),
735
+ K.RandomGaussianBlur((3,3),(10.5,10.5),p=0.2),
736
+ K.RandomGaussianNoise(p=0.5),
737
+ K.RandomElasticTransform(kernel_size=(33, 33), sigma=(7,7), p=0.2),
738
+ K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'), # padding_mode=2
739
+ K.RandomPerspective(0.2,p=0.4, ),
740
+ K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),)
741
+ """
742
+
743
+ def set_cut_pow(self, cut_pow):
744
+ self.cut_pow = cut_pow
745
+
746
+ def forward(self, input):
747
+ sideY, sideX = input.shape[2:4]
748
+ max_size = min(sideX, sideY)
749
+ min_size = min(sideX, sideY, self.cut_size)
750
+ cutouts = []
751
+ cutouts_full = []
752
+ noise_fac = 0.1
753
+
754
+ min_size_width = min(sideX, sideY)
755
+ lower_bound = float(self.cut_size / min_size_width)
756
+
757
+ for ii in range(self.cutn):
758
+
759
+ # size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
760
+ randsize = (
761
+ torch.zeros(
762
+ 1,
763
+ )
764
+ .normal_(mean=0.8, std=0.3)
765
+ .clip(lower_bound, 1.0)
766
+ )
767
+ size_mult = randsize**self.cut_pow
768
+ size = int(
769
+ min_size_width * (size_mult.clip(lower_bound, 1.0))
770
+ ) # replace .5 with a result for 224 the default large size is .95
771
+ # size = int(min_size_width*torch.zeros(1,).normal_(mean=.9, std=.3).clip(lower_bound, .95)) # replace .5 with a result for 224 the default large size is .95
772
+
773
+ offsetx = torch.randint(0, sideX - size + 1, ())
774
+ offsety = torch.randint(0, sideY - size + 1, ())
775
+ cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
776
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
777
+
778
+ cutouts = torch.cat(cutouts, dim=0)
779
+ cutouts = clamp_with_grad(cutouts, 0, 1)
780
+
781
+ # if args.use_augs:
782
+ cutouts = self.augs(cutouts)
783
+ if self.noise_fac:
784
+ facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
785
+ 0, self.noise_fac
786
+ )
787
+ cutouts = cutouts + facs * torch.randn_like(cutouts)
788
+ return cutouts
789
+
790
+ class MakeCutoutsZynth(nn.Module):
791
+ def __init__(self, cut_size, cutn, cut_pow, augs):
792
+ super().__init__()
793
+ self.cut_size = cut_size
794
+ # tqdm.write(f"cut size: {self.cut_size}")
795
+ self.cutn = cutn
796
+ self.cut_pow = cut_pow
797
+ self.noise_fac = 0.1
798
+ self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
799
+ self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
800
+ self.augs = nn.Sequential(
801
+ K.RandomHorizontalFlip(p=0.5),
802
+ # K.RandomSolarize(0.01, 0.01, p=0.7),
803
+ K.RandomSharpness(0.3, p=0.4),
804
+ K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"),
805
+ K.RandomPerspective(0.2, p=0.4),
806
+ K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
807
+ )
808
+
809
+ def set_cut_pow(self, cut_pow):
810
+ self.cut_pow = cut_pow
811
+
812
+ def forward(self, input):
813
+ sideY, sideX = input.shape[2:4]
814
+ max_size = min(sideX, sideY)
815
+ min_size = min(sideX, sideY, self.cut_size)
816
+ cutouts = []
817
+ cutouts_full = []
818
+ noise_fac = 0.1
819
+
820
+ min_size_width = min(sideX, sideY)
821
+ lower_bound = float(self.cut_size / min_size_width)
822
+
823
+ for ii in range(self.cutn):
824
+
825
+ # size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
826
+ randsize = (
827
+ torch.zeros(
828
+ 1,
829
+ )
830
+ .normal_(mean=0.8, std=0.3)
831
+ .clip(lower_bound, 1.0)
832
+ )
833
+ size_mult = randsize**self.cut_pow
834
+ size = int(
835
+ min_size_width * (size_mult.clip(lower_bound, 1.0))
836
+ ) # replace .5 with a result for 224 the default large size is .95
837
+ # size = int(min_size_width*torch.zeros(1,).normal_(mean=.9, std=.3).clip(lower_bound, .95)) # replace .5 with a result for 224 the default large size is .95
838
+
839
+ offsetx = torch.randint(0, sideX - size + 1, ())
840
+ offsety = torch.randint(0, sideY - size + 1, ())
841
+ cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
842
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
843
+
844
+ cutouts = torch.cat(cutouts, dim=0)
845
+ cutouts = clamp_with_grad(cutouts, 0, 1)
846
+
847
+ # if args.use_augs:
848
+ cutouts = self.augs(cutouts)
849
+ if self.noise_fac:
850
+ facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
851
+ 0, self.noise_fac
852
+ )
853
+ cutouts = cutouts + facs * torch.randn_like(cutouts)
854
+ return cutouts
855
+
856
+ class MakeCutoutsWyvern(nn.Module):
857
+ def __init__(self, cut_size, cutn, cut_pow, augs):
858
+ super().__init__()
859
+ self.cut_size = cut_size
860
+ # tqdm.write(f"cut size: {self.cut_size}")
861
+ self.cutn = cutn
862
+ self.cut_pow = cut_pow
863
+ self.noise_fac = 0.1
864
+ self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
865
+ self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
866
+ self.augs = augs
867
+
868
+ def forward(self, input):
869
+ sideY, sideX = input.shape[2:4]
870
+ max_size = min(sideX, sideY)
871
+ min_size = min(sideX, sideY, self.cut_size)
872
+ cutouts = []
873
+ for _ in range(self.cutn):
874
+ size = int(
875
+ torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size
876
+ )
877
+ offsetx = torch.randint(0, sideX - size + 1, ())
878
+ offsety = torch.randint(0, sideY - size + 1, ())
879
+ cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
880
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
881
+ return clamp_with_grad(torch.cat(cutouts, dim=0), 0, 1)
882
+
883
+
884
+ import PIL
885
+
886
+ def resize_image(image, out_size):
887
+ ratio = image.size[0] / image.size[1]
888
+ area = min(image.size[0] * image.size[1], out_size[0] * out_size[1])
889
+ size = round((area * ratio) ** 0.5), round((area / ratio) ** 0.5)
890
+ return image.resize(size, PIL.Image.LANCZOS)
891
+
892
+ class GaussianBlur2d(nn.Module):
893
+ def __init__(self, sigma, window=0, mode="reflect", value=0):
894
+ super().__init__()
895
+ self.mode = mode
896
+ self.value = value
897
+ if not window:
898
+ window = max(math.ceil((sigma * 6 + 1) / 2) * 2 - 1, 3)
899
+ if sigma:
900
+ kernel = torch.exp(
901
+ -((torch.arange(window) - window // 2) ** 2) / 2 / sigma**2
902
+ )
903
+ kernel /= kernel.sum()
904
+ else:
905
+ kernel = torch.ones([1])
906
+ self.register_buffer("kernel", kernel)
907
+
908
+ def forward(self, input):
909
+ n, c, h, w = input.shape
910
+ input = input.view([n * c, 1, h, w])
911
+ start_pad = (self.kernel.shape[0] - 1) // 2
912
+ end_pad = self.kernel.shape[0] // 2
913
+ input = F.pad(
914
+ input, (start_pad, end_pad, start_pad, end_pad), self.mode, self.value
915
+ )
916
+ input = F.conv2d(input, self.kernel[None, None, None, :])
917
+ input = F.conv2d(input, self.kernel[None, None, :, None])
918
+ return input.view([n, c, h, w])
919
+
920
+ BUF_SIZE = 65536
921
+
922
+ def get_digest(path, alg=hashlib.sha256):
923
+ hash = alg()
924
+ # print(path)
925
+ with open(path, "rb") as fp:
926
+ while True:
927
+ data = fp.read(BUF_SIZE)
928
+ if not data:
929
+ break
930
+ hash.update(data)
931
+ return b64encode(hash.digest()).decode("utf-8")
932
+
933
+ flavordict = {
934
+ "cumin": MakeCutoutsCumin,
935
+ "holywater": MakeCutoutsHolywater,
936
+ "old_holywater": MakeCutoutsOldHolywater,
937
+ "ginger": MakeCutoutsGinger,
938
+ "zynth": MakeCutoutsZynth,
939
+ "wyvern": MakeCutoutsWyvern,
940
+ "aaron": MakeCutoutsAaron,
941
+ "moth": MakeCutoutsMoth,
942
+ "juu": MakeCutoutsJuu,
943
+ "custom": MakeCutoutsCustom,
944
+ }
945
+
946
+ @torch.jit.script
947
+ def gelu_impl(x):
948
+ """OpenAI's gelu implementation."""
949
+ return (
950
+ 0.5
951
+ * x
952
+ * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
953
+ )
954
+
955
+ def gelu(x):
956
+ return gelu_impl(x)
957
+
958
+ class MSEDecayLoss(nn.Module):
959
+ def __init__(self, init_weight, mse_decay_rate, mse_epoches, mse_quantize):
960
+ super().__init__()
961
+
962
+ self.init_weight = init_weight
963
+ self.has_init_image = False
964
+ self.mse_decay = init_weight / mse_epoches if init_weight else 0
965
+ self.mse_decay_rate = mse_decay_rate
966
+ self.mse_weight = init_weight
967
+ self.mse_epoches = mse_epoches
968
+ self.mse_quantize = mse_quantize
969
+
970
+ @torch.no_grad()
971
+ def set_target(self, z_tensor, model):
972
+ z_tensor = z_tensor.detach().clone()
973
+ if self.mse_quantize:
974
+ z_tensor = vector_quantize(
975
+ z_tensor.movedim(1, 3), model.quantize.embedding.weight
976
+ ).movedim(
977
+ 3, 1
978
+ ) # z.average
979
+ self.z_orig = z_tensor
980
+
981
+ def forward(self, i, z):
982
+ if self.is_active(i):
983
+ return F.mse_loss(z, self.z_orig) * self.mse_weight / 2
984
+ return 0
985
+
986
+ def is_active(self, i):
987
+ if not self.init_weight:
988
+ return False
989
+ if i <= self.mse_decay_rate and not self.has_init_image:
990
+ return False
991
+ return True
992
+
993
+ @torch.no_grad()
994
+ def step(self, i):
995
+
996
+ if (
997
+ i % self.mse_decay_rate == 0
998
+ and i != 0
999
+ and i < self.mse_decay_rate * self.mse_epoches
1000
+ ):
1001
+
1002
+ if (
1003
+ self.mse_weight - self.mse_decay > 0
1004
+ and self.mse_weight - self.mse_decay >= self.mse_decay
1005
+ ):
1006
+ self.mse_weight -= self.mse_decay
1007
+ else:
1008
+ self.mse_weight = 0
1009
+ # print(f"updated mse weight: {self.mse_weight}")
1010
+
1011
+ return True
1012
+
1013
+ return False
1014
+
1015
+ class TVLoss(nn.Module):
1016
+ def forward(self, input):
1017
+ input = F.pad(input, (0, 1, 0, 1), "replicate")
1018
+ x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
1019
+ y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
1020
+ diff = x_diff**2 + y_diff**2 + 1e-8
1021
+ return diff.mean(dim=1).sqrt().mean()
1022
+
1023
+ class MultiClipLoss(nn.Module):
1024
+ def __init__(
1025
+ self, clip_models, text_prompt, cutn, cut_pow=1.0, clip_weight=1.0
1026
+ ):
1027
+ super().__init__()
1028
+
1029
+ # Load Clip
1030
+ self.perceptors = []
1031
+ for cm in clip_models:
1032
+ sys.stdout.write(f"Loading {cm[0]} ...\n")
1033
+ sys.stdout.flush()
1034
+ c = (
1035
+ clip.load(cm[0], jit=False)[0]
1036
+ .eval()
1037
+ .requires_grad_(False)
1038
+ .to(device)
1039
+ )
1040
+ self.perceptors.append(
1041
+ {
1042
+ "res": c.visual.input_resolution,
1043
+ "perceptor": c,
1044
+ "weight": cm[1],
1045
+ "prompts": [],
1046
+ }
1047
+ )
1048
+ self.perceptors.sort(key=lambda e: e["res"], reverse=True)
1049
+
1050
+ # Make Cutouts
1051
+ self.max_cut_size = self.perceptors[0]["res"]
1052
+ # self.make_cuts = flavordict[flavor](self.max_cut_size, cutn, cut_pow)
1053
+ # cutouts = flavordict[flavor](self.max_cut_size, cutn, cut_pow=cut_pow, augs=args.augs)
1054
+
1055
+ # Get Prompt Embedings
1056
+ # texts = [phrase.strip() for phrase in text_prompt.split("|")]
1057
+ # if text_prompt == ['']:
1058
+ # texts = []
1059
+ texts = text_prompt
1060
+ self.pMs = []
1061
+ for prompt in texts:
1062
+ txt, weight, stop = parse_prompt(prompt)
1063
+ clip_token = clip.tokenize(txt).to(device)
1064
+ for p in self.perceptors:
1065
+ embed = p["perceptor"].encode_text(clip_token).float()
1066
+ embed_normed = F.normalize(embed.unsqueeze(0), dim=2)
1067
+ p["prompts"].append(
1068
+ {
1069
+ "embed_normed": embed_normed,
1070
+ "weight": torch.as_tensor(weight, device=device),
1071
+ "stop": torch.as_tensor(stop, device=device),
1072
+ }
1073
+ )
1074
+
1075
+ # Prep Augments
1076
+ self.normalize = transforms.Normalize(
1077
+ mean=[0.48145466, 0.4578275, 0.40821073],
1078
+ std=[0.26862954, 0.26130258, 0.27577711],
1079
+ )
1080
+
1081
+ self.augs = nn.Sequential(
1082
+ K.RandomHorizontalFlip(p=0.5),
1083
+ K.RandomSharpness(0.3, p=0.1),
1084
+ K.RandomAffine(
1085
+ degrees=30, translate=0.1, p=0.8, padding_mode="border"
1086
+ ), # padding_mode=2
1087
+ K.RandomPerspective(
1088
+ 0.2,
1089
+ p=0.4,
1090
+ ),
1091
+ K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
1092
+ K.RandomGrayscale(p=0.15),
1093
+ )
1094
+ self.noise_fac = 0.1
1095
+
1096
+ self.clip_weight = clip_weight
1097
+
1098
+ def prepare_cuts(self, img):
1099
+ cutouts = self.make_cuts(img)
1100
+ cutouts = self.augs(cutouts)
1101
+ if self.noise_fac:
1102
+ facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
1103
+ 0, self.noise_fac
1104
+ )
1105
+ cutouts = cutouts + facs * torch.randn_like(cutouts)
1106
+ cutouts = self.normalize(cutouts)
1107
+ return cutouts
1108
+
1109
+ def forward(self, i, img):
1110
+ cutouts = checkpoint(self.prepare_cuts, img)
1111
+ loss = []
1112
+
1113
+ current_cuts = cutouts
1114
+ currentres = self.max_cut_size
1115
+ for p in self.perceptors:
1116
+ if currentres != p["res"]:
1117
+ current_cuts = resample(cutouts, (p["res"], p["res"]))
1118
+ currentres = p["res"]
1119
+
1120
+ iii = p["perceptor"].encode_image(current_cuts).float()
1121
+ input_normed = F.normalize(iii.unsqueeze(1), dim=2)
1122
+ for prompt in p["prompts"]:
1123
+ dists = (
1124
+ input_normed.sub(prompt["embed_normed"])
1125
+ .norm(dim=2)
1126
+ .div(2)
1127
+ .arcsin()
1128
+ .pow(2)
1129
+ .mul(2)
1130
+ )
1131
+ dists = dists * prompt["weight"].sign()
1132
+ l = (
1133
+ prompt["weight"].abs()
1134
+ * replace_grad(
1135
+ dists, torch.maximum(dists, prompt["stop"])
1136
+ ).mean()
1137
+ )
1138
+ loss.append(l * p["weight"])
1139
+
1140
+ return loss
1141
+
1142
+ class ModelHost:
1143
+ def __init__(self, args):
1144
+ self.args = args
1145
+ self.model, self.perceptor = None, None
1146
+ self.make_cutouts = None
1147
+ self.alt_make_cutouts = None
1148
+ self.imageSize = None
1149
+ self.prompts = None
1150
+ self.opt = None
1151
+ self.normalize = None
1152
+ self.z, self.z_orig, self.z_min, self.z_max = None, None, None, None
1153
+ self.metadata = None
1154
+ self.mse_weight = 0
1155
+ self.normal_flip_optim = None
1156
+ self.usealtprompts = False
1157
+
1158
+ def setup_metadata(self, seed):
1159
+ metadata = {k: v for k, v in vars(self.args).items()}
1160
+ del metadata["max_iterations"]
1161
+ del metadata["display_freq"]
1162
+ metadata["seed"] = seed
1163
+ if metadata["init_image"]:
1164
+ path = metadata["init_image"]
1165
+ digest = get_digest(path)
1166
+ metadata["init_image"] = (path, digest)
1167
+ if metadata["image_prompts"]:
1168
+ prompts = []
1169
+ for prompt in metadata["image_prompts"]:
1170
+ path = prompt
1171
+ digest = get_digest(path)
1172
+ prompts.append((path, digest))
1173
+ metadata["image_prompts"] = prompts
1174
+ self.metadata = metadata
1175
+
1176
+ def setup_model(self, x):
1177
+ i = x
1178
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1179
+
1180
+ #perceptor = (
1181
+ # clip.load(args.clip_model, jit=False)[0]
1182
+ # .eval()
1183
+ # .requires_grad_(False)
1184
+ # .to(device)
1185
+ #)
1186
 
1187
+ cut_size = perceptor.visual.input_resolution
1188
+
1189
+ if self.args.is_gumbel:
1190
+ e_dim = model.quantize.embedding_dim
1191
+ else:
1192
+ e_dim = model.quantize.e_dim
1193
+
1194
+ f = 2 ** (model.decoder.num_resolutions - 1)
1195
+
1196
+ make_cutouts = flavordict[flavor](
1197
+ cut_size, args.mse_cutn, cut_pow=args.mse_cut_pow, augs=args.augs
1198
+ )
1199
+
1200
+ # make_cutouts = MakeCutouts(cut_size, args.mse_cutn, cut_pow=args.mse_cut_pow,augs=args.augs)
1201
+ if args.altprompts:
1202
+ self.usealtprompts = True
1203
+ self.alt_make_cutouts = flavordict[flavor](
1204
+ cut_size,
1205
+ args.mse_cutn,
1206
+ cut_pow=args.alt_mse_cut_pow,
1207
+ augs=args.altaugs,
1208
+ )
1209
+ # self.alt_make_cutouts = MakeCutouts(cut_size, args.mse_cutn, cut_pow=args.alt_mse_cut_pow,augs=args.altaugs)
1210
+
1211
+ if self.args.is_gumbel:
1212
+ n_toks = model.quantize.n_embed
1213
+ else:
1214
+ n_toks = model.quantize.n_e
1215
+
1216
+ toksX, toksY = args.size[0] // f, args.size[1] // f
1217
+ sideX, sideY = toksX * f, toksY * f
1218
+
1219
+ if self.args.is_gumbel:
1220
+ z_min = model.quantize.embed.weight.min(dim=0).values[
1221
+ None, :, None, None
1222
+ ]
1223
+ z_max = model.quantize.embed.weight.max(dim=0).values[
1224
+ None, :, None, None
1225
+ ]
1226
+ else:
1227
+ z_min = model.quantize.embedding.weight.min(dim=0).values[
1228
+ None, :, None, None
1229
+ ]
1230
+ z_max = model.quantize.embedding.weight.max(dim=0).values[
1231
+ None, :, None, None
1232
+ ]
1233
+
1234
+ from PIL import Image
1235
+ import cv2
1236
+
1237
+ # -------
1238
+ working_dir = self.args.folder_name
1239
+
1240
+ if self.args.init_image != "":
1241
+ img_0 = cv2.imread(init_image)
1242
+ z, *_ = model.encode(
1243
+ TF.to_tensor(img_0).to(device).unsqueeze(0) * 2 - 1
1244
+ )
1245
+ elif not os.path.isfile(f"{working_dir}/steps/{i:04d}.png"):
1246
+ one_hot = F.one_hot(
1247
+ torch.randint(n_toks, [toksY * toksX], device=device), n_toks
1248
+ ).float()
1249
+ if self.args.is_gumbel:
1250
+ z = one_hot @ model.quantize.embed.weight
1251
+ else:
1252
+ z = one_hot @ model.quantize.embedding.weight
1253
+ z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
1254
+ else:
1255
+ center = (1 * img_0.shape[1] // 2, 1 * img_0.shape[0] // 2)
1256
+ trans_mat = np.float32([[1, 0, 10], [0, 1, 10]])
1257
+ rot_mat = cv2.getRotationMatrix2D(center, 10, 20)
1258
+
1259
+ trans_mat = np.vstack([trans_mat, [0, 0, 1]])
1260
+ rot_mat = np.vstack([rot_mat, [0, 0, 1]])
1261
+ transformation_matrix = np.matmul(rot_mat, trans_mat)
1262
+
1263
+ img_0 = cv2.warpPerspective(
1264
+ img_0,
1265
+ transformation_matrix,
1266
+ (img_0.shape[1], img_0.shape[0]),
1267
+ borderMode=cv2.BORDER_WRAP,
1268
+ )
1269
+ z, *_ = model.encode(
1270
+ TF.to_tensor(img_0).to(device).unsqueeze(0) * 2 - 1
1271
+ )
1272
+
1273
+ def save_output(i, img, suffix="zoomed"):
1274
+ filename = f"{working_dir}/steps/{i:04}{'_' + suffix if suffix else ''}.png"
1275
+ imageio.imwrite(filename, np.array(img))
1276
+
1277
+ save_output(i, img_0)
1278
+ # -------
1279
+ if args.init_image:
1280
+ pil_image = Image.open(args.init_image).convert("RGB")
1281
+ pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
1282
+ z, *_ = model.encode(
1283
+ TF.to_tensor(pil_image).to(device).unsqueeze(0) * 2 - 1
1284
+ )
1285
+ else:
1286
+ one_hot = F.one_hot(
1287
+ torch.randint(n_toks, [toksY * toksX], device=device), n_toks
1288
+ ).float()
1289
+ if self.args.is_gumbel:
1290
+ z = one_hot @ model.quantize.embed.weight
1291
+ else:
1292
+ z = one_hot @ model.quantize.embedding.weight
1293
+ z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
1294
+ z = EMATensor(z, args.ema_val)
1295
+
1296
+ if args.mse_with_zeros and not args.init_image:
1297
+ z_orig = torch.zeros_like(z.tensor)
1298
+ else:
1299
+ z_orig = z.tensor.clone()
1300
+ z.requires_grad_(True)
1301
+ # opt = optim.AdamW(z.parameters(), lr=args.mse_step_size, weight_decay=0.00000000)
1302
+ print("Step size inside:", args.step_size)
1303
+ if self.normal_flip_optim == True:
1304
+ if randint(1, 2) == 1:
1305
+ opt = torch.optim.AdamW(
1306
+ z.parameters(), lr=args.step_size, weight_decay=0.00000000
1307
+ )
1308
+ # opt = Ranger21(z.parameters(), lr=args.step_size, weight_decay=0.00000000)
1309
+ else:
1310
+ opt = optim.DiffGrad(
1311
+ z.parameters(), lr=args.step_size, weight_decay=0.00000000
1312
+ )
1313
+ else:
1314
+ opt = torch.optim.AdamW(
1315
+ z.parameters(), lr=args.step_size, weight_decay=0.00000000
1316
+ )
1317
+
1318
+ self.cur_step_size = args.mse_step_size
1319
+
1320
+ normalize = transforms.Normalize(
1321
+ mean=[0.48145466, 0.4578275, 0.40821073],
1322
+ std=[0.26862954, 0.26130258, 0.27577711],
1323
+ )
1324
+
1325
+ pMs = []
1326
+ altpMs = []
1327
+
1328
+ for prompt in args.prompts:
1329
+ txt, weight, stop = parse_prompt(prompt)
1330
+ embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
1331
+ pMs.append(Prompt(embed, weight, stop).to(device))
1332
+
1333
+ for prompt in args.altprompts:
1334
+ txt, weight, stop = parse_prompt(prompt)
1335
+ embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
1336
+ altpMs.append(Prompt(embed, weight, stop).to(device))
1337
+
1338
+ from PIL import Image
1339
+
1340
+ for prompt in args.image_prompts:
1341
+ path, weight, stop = parse_prompt(prompt)
1342
+ img = resize_image(Image.open(path).convert("RGB"), (sideX, sideY))
1343
+ batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
1344
+ embed = perceptor.encode_image(normalize(batch)).float()
1345
+ pMs.append(Prompt(embed, weight, stop).to(device))
1346
+
1347
+ for seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):
1348
+ gen = torch.Generator().manual_seed(seed)
1349
+ embed = torch.empty([1, perceptor.visual.output_dim]).normal_(
1350
+ generator=gen
1351
+ )
1352
+ pMs.append(Prompt(embed, weight).to(device))
1353
+ if self.usealtprompts:
1354
+ altpMs.append(Prompt(embed, weight).to(device))
1355
+
1356
+ self.model, self.perceptor = model, perceptor
1357
+ self.make_cutouts = make_cutouts
1358
+ self.imageSize = (sideX, sideY)
1359
+ self.prompts = pMs
1360
+ self.altprompts = altpMs
1361
+ self.opt = opt
1362
+ self.normalize = normalize
1363
+ self.z, self.z_orig, self.z_min, self.z_max = z, z_orig, z_min, z_max
1364
+ self.setup_metadata(args2.seed)
1365
+ self.mse_weight = self.args.init_weight
1366
+
1367
+ def synth(self, z):
1368
+ if self.args.is_gumbel:
1369
+ z_q = vector_quantize(
1370
+ z.movedim(1, 3), self.model.quantize.embed.weight
1371
+ ).movedim(3, 1)
1372
+ else:
1373
+ z_q = vector_quantize(
1374
+ z.movedim(1, 3), self.model.quantize.embedding.weight
1375
+ ).movedim(3, 1)
1376
+ return clamp_with_grad(self.model.decode(z_q).add(1).div(2), 0, 1)
1377
+
1378
+ def add_metadata(self, path, i):
1379
+ imfile = PngImageFile(path)
1380
+ meta = PngInfo()
1381
+ step_meta = {"iterations": i}
1382
+ step_meta.update(self.metadata)
1383
+ # meta.add_itxt('vqgan-params', json.dumps(step_meta), zip=True)
1384
+ imfile.save(path, pnginfo=meta)
1385
+ # Hey you. This one's for Glooperpogger#7353 on Discord (Gloop has a gun), they are a nice snek
1386
+
1387
+ @torch.no_grad()
1388
+ def checkin(self, i, losses, x):
1389
+ out = self.synth(self.z.average)
1390
+
1391
+ batchpath = "./"
1392
+ TF.to_pil_image(out[0].cpu()).save(args2.image_file)
1393
+
1394
+ def unique_index(self, batchpath):
1395
+ i = 0
1396
+ while i < 10000:
1397
+ if os.path.isfile(batchpath + "/" + str(i) + ".png"):
1398
+ i = i + 1
1399
+ else:
1400
+ return batchpath + "/" + str(i) + ".png"
1401
+
1402
+ def ascend_txt(self, i):
1403
+ out = self.synth(self.z.tensor)
1404
+ iii = self.perceptor.encode_image(
1405
+ self.normalize(self.make_cutouts(out))
1406
+ ).float()
1407
+
1408
+ result = []
1409
+ if self.args.init_weight and self.mse_weight > 0:
1410
+ result.append(
1411
+ F.mse_loss(self.z.tensor, self.z_orig) * self.mse_weight / 2
1412
+ )
1413
+
1414
+ for prompt in self.prompts:
1415
+ result.append(prompt(iii))
1416
+
1417
+ if self.usealtprompts:
1418
+ iii = self.perceptor.encode_image(
1419
+ self.normalize(self.alt_make_cutouts(out))
1420
+ ).float()
1421
+ for prompt in self.altprompts:
1422
+ result.append(prompt(iii))
1423
+
1424
+ return result
1425
+
1426
+ def train(self, i, x):
1427
+ self.opt.zero_grad()
1428
+ mse_decay = self.args.mse_decay
1429
+ mse_decay_rate = self.args.mse_decay_rate
1430
+ lossAll = self.ascend_txt(i)
1431
+
1432
+ sys.stdout.write("Iteration {}".format(i) + "\n")
1433
+ sys.stdout.flush()
1434
+ if i % (args2.iterations-2) == 0:
1435
+ self.checkin(i, lossAll, x)
1436
+
1437
+ loss = sum(lossAll)
1438
+ loss.backward()
1439
+ self.opt.step()
1440
+ with torch.no_grad():
1441
+ if (
1442
+ self.mse_weight > 0
1443
+ and self.args.init_weight
1444
+ and i > 0
1445
+ and i % mse_decay_rate == 0
1446
+ ):
1447
+ if self.args.is_gumbel:
1448
+ self.z_orig = vector_quantize(
1449
+ self.z.average.movedim(1, 3),
1450
+ self.model.quantize.embed.weight,
1451
+ ).movedim(3, 1)
1452
+ else:
1453
+ self.z_orig = vector_quantize(
1454
+ self.z.average.movedim(1, 3),
1455
+ self.model.quantize.embedding.weight,
1456
+ ).movedim(3, 1)
1457
+ if self.mse_weight - mse_decay > 0:
1458
+ self.mse_weight = self.mse_weight - mse_decay
1459
+ # print(f"updated mse weight: {self.mse_weight}")
1460
+ else:
1461
+ self.mse_weight = 0
1462
+ self.make_cutouts = flavordict[flavor](
1463
+ self.perceptor.visual.input_resolution,
1464
+ args.cutn,
1465
+ cut_pow=args.cut_pow,
1466
+ augs=args.augs,
1467
+ )
1468
+ if self.usealtprompts:
1469
+ self.alt_make_cutouts = flavordict[flavor](
1470
+ self.perceptor.visual.input_resolution,
1471
+ args.cutn,
1472
+ cut_pow=args.alt_cut_pow,
1473
+ augs=args.altaugs,
1474
+ )
1475
+ self.z = EMATensor(self.z.average, args.ema_val)
1476
+ self.new_step_size = args.step_size
1477
+ self.opt = torch.optim.AdamW(
1478
+ self.z.parameters(),
1479
+ lr=args.step_size,
1480
+ weight_decay=0.00000000,
1481
+ )
1482
+ # print(f"updated mse weight: {self.mse_weight}")
1483
+ if i > args.mse_end:
1484
+ if (
1485
+ args.step_size != args.final_step_size
1486
+ and args.max_iterations > 0
1487
+ ):
1488
+ progress = (i - args.mse_end) / (args.max_iterations)
1489
+ self.cur_step_size = lerp(step_size, final_step_size, progress)
1490
+ for g in self.opt.param_groups:
1491
+ g["lr"] = self.cur_step_size
1492
+
1493
+ def run(self, x):
1494
+ j = 0
1495
+ try:
1496
+ print("Step size: ", args.step_size)
1497
+ print("Step MSE size: ", args.mse_step_size)
1498
+ before_start_time = time.perf_counter()
1499
+ total_steps = int(args.max_iterations + args.mse_end) - 1
1500
+ for _ in range(total_steps):
1501
+ self.train(j, x)
1502
+ if j > 0 and j % args.mse_decay_rate == 0 and self.mse_weight > 0:
1503
+ self.z = EMATensor(self.z.average, args.ema_val)
1504
+ self.opt = torch.optim.AdamW(
1505
+ self.z.parameters(),
1506
+ lr=args.mse_step_size,
1507
+ weight_decay=0.00000000,
1508
+ )
1509
+ if j >= total_steps:
1510
+ break
1511
+ self.z.update()
1512
+ j += 1
1513
+ time_past_seconds = time.perf_counter() - before_start_time
1514
+ iterations_per_second = j / time_past_seconds
1515
+ time_left = (total_steps - j) / iterations_per_second
1516
+ percentage = round((j / (total_steps + 1)) * 100)
1517
+
1518
+ import shutil
1519
+ import os
1520
+
1521
+ #image_data = Image.open(args2.image_file)
1522
+ #os.remove(args2.image_file)
1523
+ #return(image_data)
1524
+
1525
+ except KeyboardInterrupt:
1526
+ pass
1527
+
1528
+ def add_noise(img):
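+ # Salt-and-pepper noise: randomly turn 300-10000 pixels white and another 300-10000 black (expects a single-channel image).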
1529
+
1530
+ # Getting the dimensions of the image
1531
+ row, col = img.shape
1532
+
1533
+ # Randomly pick some pixels in the
1534
+ # image for coloring them white
1535
+ # Pick a random number between 300 and 10000
1536
+ number_of_pixels = random.randint(300, 10000)
1537
+ for i in range(number_of_pixels):
1538
+
1539
+ # Pick a random y coordinate
1540
+ y_coord = random.randint(0, row - 1)
1541
+
1542
+ # Pick a random x coordinate
1543
+ x_coord = random.randint(0, col - 1)
1544
+
1545
+ # Color that pixel to white
1546
+ img[y_coord][x_coord] = 255
1547
+
1548
+ # Randomly pick some pixels in
1549
+ # the image for coloring them black
1550
+ # Pick a random number between 300 and 10000
1551
+ number_of_pixels = random.randint(300, 10000)
1552
+ for i in range(number_of_pixels):
1553
+
1554
+ # Pick a random y coordinate
1555
+ y_coord = random.randint(0, row - 1)
1556
+
1557
+ # Pick a random x coordinate
1558
+ x_coord = random.randint(0, col - 1)
1559
+
1560
+ # Color that pixel to black
1561
+ img[y_coord][x_coord] = 0
1562
+
1563
+ return img
1564
+
1565
+ import io
1566
+ import base64
1567
+
1568
+ def image_to_data_url(img, ext):
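+ # Serialize a PIL image to a base64 data URL such as "data:image/png;base64,...".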
1569
+ img_byte_arr = io.BytesIO()
1570
+ img.save(img_byte_arr, format=ext)
1571
+ img_byte_arr = img_byte_arr.getvalue()
1572
+ # ext = filename.split('.')[-1]
1573
+ prefix = f"data:image/{ext};base64,"
1574
+ return prefix + base64.b64encode(img_byte_arr).decode("utf-8")
1575
+
1576
+ import torch
1577
+ import math
1578
+
1579
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1580
+
1581
+ def rand_perlin_2d(
1582
+ shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3
1583
+ ):
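+ # 2D Perlin noise: res sets the number of gradient cells, shape the output resolution; fade is the standard quintic smoothstep.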
1584
+ delta = (res[0] / shape[0], res[1] / shape[1])
1585
+ d = (shape[0] // res[0], shape[1] // res[1])
1586
+
1587
+ grid = (
1588
+ torch.stack(
1589
+ torch.meshgrid(
1590
+ torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])
1591
+ ),
1592
+ dim=-1,
1593
+ )
1594
+ % 1
1595
+ )
1596
+ angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1)
1597
+ gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
1598
+
1599
+ tile_grads = (
1600
+ lambda slice1, slice2: gradients[
1601
+ slice1[0] : slice1[1], slice2[0] : slice2[1]
1602
+ ]
1603
+ .repeat_interleave(d[0], 0)
1604
+ .repeat_interleave(d[1], 1)
1605
+ )
1606
+ dot = lambda grad, shift: (
1607
+ torch.stack(
1608
+ (
1609
+ grid[: shape[0], : shape[1], 0] + shift[0],
1610
+ grid[: shape[0], : shape[1], 1] + shift[1],
1611
+ ),
1612
+ dim=-1,
1613
+ )
1614
+ * grad[: shape[0], : shape[1]]
1615
+ ).sum(dim=-1)
1616
+
1617
+ n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
1618
+ n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
1619
+ n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
1620
+ n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
1621
+ t = fade(grid[: shape[0], : shape[1]])
1622
+ return math.sqrt(2) * torch.lerp(
1623
+ torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]
1624
+ )
1625
+
1626
+ def rand_perlin_2d_octaves(desired_shape, octaves=1, persistence=0.5):
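+ # Fractal noise: sum Perlin octaves with doubling frequency and amplitude scaled by persistence, then crop to the requested shape.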
1627
+ shape = torch.tensor(desired_shape)
1628
+ shape = 2 ** torch.ceil(torch.log2(shape))
1629
+ shape = shape.type(torch.int)
1630
+
1631
+ max_octaves = int(
1632
+ min(
1633
+ octaves,
1634
+ math.log(shape[0]) / math.log(2),
1635
+ math.log(shape[1]) / math.log(2),
1636
+ )
1637
+ )
1638
+ res = torch.floor(shape / 2**max_octaves).type(torch.int)
1639
+
1640
+ noise = torch.zeros(list(shape))
1641
+ frequency = 1
1642
+ amplitude = 1
1643
+ for _ in range(max_octaves):
1644
+ noise += amplitude * rand_perlin_2d(
1645
+ shape, (frequency * res[0], frequency * res[1])
1646
+ )
1647
+ frequency *= 2
1648
+ amplitude *= persistence
1649
+
1650
+ return noise[: desired_shape[0], : desired_shape[1]]
1651
+
1652
+ def rand_perlin_rgb(desired_shape, amp=0.1, octaves=6):
1653
+ r = rand_perlin_2d_octaves(desired_shape, octaves)
1654
+ g = rand_perlin_2d_octaves(desired_shape, octaves)
1655
+ b = rand_perlin_2d_octaves(desired_shape, octaves)
1656
+ rgb = (torch.stack((r, g, b)) * amp + 1) * 0.5
1657
+ return rgb.unsqueeze(0).clip(0, 1).to(device)
1658
+
1659
+ def pyramid_noise_gen(shape, octaves=5, decay=1.0):
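+ # Pyramid noise: start from a 1x1 map and repeatedly upsample, adding Gaussian noise at each scale, attenuated by decay at finer resolutions.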
1660
+ n, c, h, w = shape
1661
+ noise = torch.zeros([n, c, 1, 1])
1662
+ max_octaves = int(min(math.log(h) / math.log(2), math.log(w) / math.log(2)))
1663
+ if octaves is not None and 0 < octaves:
1664
+ max_octaves = min(octaves, max_octaves)
1665
+ for i in reversed(range(max_octaves)):
1666
+ h_cur, w_cur = h // 2**i, w // 2**i
1667
+ noise = F.interpolate(
1668
+ noise, (h_cur, w_cur), mode="bicubic", align_corners=False
1669
+ )
1670
+ noise += (torch.randn([n, c, h_cur, w_cur]) / max_octaves) * decay ** (
1671
+ max_octaves - (i + 1)
1672
+ )
1673
+ return noise
1674
+
1675
+ def rand_z(model, toksX, toksY):
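+ # Random VQGAN latent: sample a one-hot codebook index per token and project it through the codebook embedding.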
1676
+ e_dim = model.quantize.e_dim
1677
+ n_toks = model.quantize.n_e
1678
+ z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]
1679
+ z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]
1680
+
1681
+ one_hot = F.one_hot(
1682
+ torch.randint(n_toks, [toksY * toksX], device=device), n_toks
1683
+ ).float()
1684
+ z = one_hot @ model.quantize.embedding.weight
1685
+ z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
1686
+
1687
+ return z
1688
+
1689
+ def make_rand_init(
1690
+ mode,
1691
+ model,
1692
+ perlin_octaves,
1693
+ perlin_weight,
1694
+ pyramid_octaves,
1695
+ pyramid_decay,
1696
+ toksX,
1697
+ toksY,
1698
+ f,
1699
+ ):
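+ # Build the initial latent for the chosen init mode; f is assumed to be the VQGAN downsampling factor (pixels per token).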
1700
+
1701
+ if mode == "VQGAN ZRand":
1702
+ return rand_z(model, toksX, toksY)
1703
+ elif mode == "Perlin Noise":
1704
+ rand_init = rand_perlin_rgb(
1705
+ (toksY * f, toksX * f), perlin_weight, perlin_octaves
1706
+ )
1707
+ z, *_ = model.encode(rand_init * 2 - 1)
1708
+ return z
1709
+ elif mode == "Pyramid Noise":
1710
+ rand_init = pyramid_noise_gen(
1711
+ (1, 3, toksY * f, toksX * f), pyramid_octaves, pyramid_decay
1712
+ ).to(device)
1713
+ rand_init = (rand_init * 0.5 + 0.5).clip(0, 1)
1714
+ z, *_ = model.encode(rand_init * 2 - 1)
1715
+ return z
1716
+
1717
+ ##################### JUICY MESS ###################################
1718
+ import os
1719
+
1720
+ imagenet_1024 = False # @param {type:"boolean"}
1721
+ imagenet_16384 = True # @param {type:"boolean"}
1722
+ gumbel_8192 = False # @param {type:"boolean"}
1723
+ sber_gumbel = False # @param {type:"boolean"}
1724
+ # imagenet_cin = False #@param {type:"boolean"}
1725
+ coco = False # @param {type:"boolean"}
1726
+ coco_1stage = False # @param {type:"boolean"}
1727
+ faceshq = False # @param {type:"boolean"}
1728
+ wikiart_1024 = False # @param {type:"boolean"}
1729
+ wikiart_16384 = False # @param {type:"boolean"}
1730
+ wikiart_7mil = False # @param {type:"boolean"}
1731
+ sflckr = False # @param {type:"boolean"}
1732
+
1733
+ ##@markdown Experimental models (these probably won't work; if you know how to make them work, go ahead :D):
1734
+ # celebahq = False #@param {type:"boolean"}
1735
+ # ade20k = False #@param {type:"boolean"}
1736
+ # drin = False #@param {type:"boolean"}
1737
+ # gumbel = False #@param {type:"boolean"}
1738
+ # gumbel_8192 = False #@param {type:"boolean"}
1739
+
1740
+ # Configure and run the model"""
1741
+
1742
+ # Commented out IPython magic to ensure Python compatibility.
1743
+ # @title <font color="lightgreen" size="+3">←</font> <font size="+2">🏃‍♂️</font> **Configure & Run** <font size="+2">🏃‍♂️</font>
1744
+
1745
+ import os
1746
+ import random
1747
+ import cv2
1748
+
1749
+ # from google.colab import drive
1750
+ from PIL import Image
1751
+ from importlib import reload
1752
+
1753
+ reload(PIL.TiffTags)
1754
+ # %cd /content/
1755
+ # @markdown >`prompts` is the list of prompts to give to the AI, separated by `|`. With more than one, it will attempt to mix them together. You can add weights to different parts of the prompt by adding a `p:x` at the end of a prompt (before a `|`) where `p` is the prompt and `x` is the weight.
1756
+
1757
+ # prompts = "A fantasy landscape, by Greg Rutkowski. A lush mountain.:1 | Trending on ArtStation, unreal engine. 4K HD, realism.:0.63" #@param {type:"string"}
1758
+
1759
+ prompts = args2.prompt
1760
+
1761
+ width = args2.sizex # @param {type:"number"}
1762
+ height = args2.sizey # @param {type:"number"}
1763
+
1764
+ # model = "ImageNet 16384" #@param ['ImageNet 16384', 'ImageNet 1024', "Gumbel 8192", "Sber Gumbel", 'WikiArt 1024', 'WikiArt 16384', 'WikiArt 7mil', 'COCO-Stuff', 'COCO 1 Stage', 'FacesHQ', 'S-FLCKR']
1765
+ #model = args2.vqgan_model
1766
+
1767
+ #if model == "Gumbel 8192" or model == "Sber Gumbel":
1768
+ # is_gumbel = True
1769
+ #else:
1770
+ # is_gumbel = False
1771
+ is_gumbel = False
1772
+ ##@markdown The flavor affects the output greatly. Each has its own characteristics and, depending on what you choose, you'll get a widely different result with the same prompt and seed. Ginger is the default, nothing special. Cumin results in more of a painting, while Holywater makes everything super funky and/or colorful. Custom is a custom flavor; use the utilities above.
1773
+ # Type "old_holywater" to use the old holywater flavor from Hypertron V1
1774
+ flavor = (
1775
+ args2.flavor
1776
+ ) #'ginger' #@param ["ginger", "cumin", "holywater", "zynth", "wyvern", "aaron", "moth", "juu", "custom"]
1777
+ template = (
1778
+ args2.template
1779
+ ) # @param ["none", "----------Parameter Tweaking----------", "Balanced", "Detailed", "Consistent Creativity", "Realistic", "Smooth", "Subtle MSE", "Hyper Fast Results", "----------Complete Overhaul----------", "flag", "planet", "creature", "human", "----------Sizes----------", "Size: Square", "Size: Landscape", "Size: Poster", "----------Prompt Modifiers----------", "Better - Fast", "Better - Slow", "Movie Poster", "Negative Prompt", "Better Quality"]
1780
+ ##@markdown To use initial or target images, upload them on the left in the file browser. You can also use a previous output by putting its path below, e.g. `batch_01/0.png`. If your previous output is saved to drive, you can use the checkbox so you don't have to type the whole path.
1781
+ init = "default noise" # @param ["default noise", "image", "random image", "salt and pepper noise", "salt and pepper noise on init image"]
1782
+
1783
+ if args2.seed_image is None:
1784
+ init_image = "" # args2.seed_image #""#@param {type:"string"}
1785
+ else:
1786
+ init_image = args2.seed_image # ""#@param {type:"string"}
1787
+
1788
+ if init == "random image":
1789
+ url = (
1790
+ "https://picsum.photos/"
1791
+ + str(width)
1792
+ + "/"
1793
+ + str(height)
1794
+ + "?blur="
1795
+ + str(random.randrange(5, 10))
1796
+ )
1797
+ urllib.request.urlretrieve(url, "Init_Img/Image.png")
1798
+ init_image = "Init_Img/Image.png"
1799
+ elif init == "random image clear":
1800
+ url = "https://source.unsplash.com/random/" + str(width) + "x" + str(height)
1801
+ urllib.request.urlretrieve(url, "Init_Img/Image.png")
1802
+ init_image = "Init_Img/Image.png"
1803
+ elif init == "random image clear 2":
1804
+ url = "https://loremflickr.com/" + str(width) + "/" + str(height)
1805
+ urllib.request.urlretrieve(url, "Init_Img/Image.png")
1806
+ init_image = "Init_Img/Image.png"
1807
+ elif init == "salt and pepper noise":
1808
+ urllib.request.urlretrieve(
1809
+ "https://i.stack.imgur.com/olrL8.png", "Init_Img/Image.png"
1810
+ )
1811
+ import cv2
1812
+
1813
+ img = cv2.imread("Init_Img/Image.png", 0)
1814
+ cv2.imwrite("Init_Img/Image.png", add_noise(img))
1815
+ init_image = "Init_Img/Image.png"
1816
+ elif init == "salt and pepper noise on init image":
1817
+ img = cv2.imread(init_image, 0)
1818
+ cv2.imwrite("Init_Img/Image.png", add_noise(img))
1819
+ init_image = "Init_Img/Image.png"
1820
+ elif init == "perlin noise":
1821
+ # For some reason Colab started crashing from this
1822
+ import noise
1823
+ import numpy as np
1824
+ from PIL import Image
1825
+
1826
+ shape = (width, height)
1827
+ scale = 100
1828
+ octaves = 6
1829
+ persistence = 0.5
1830
+ lacunarity = 2.0
1831
+ seed = np.random.randint(0, 100000)
1832
+ world = np.zeros(shape)
1833
+ for i in range(shape[0]):
1834
+ for j in range(shape[1]):
1835
+ world[i][j] = noise.pnoise2(
1836
+ i / scale,
1837
+ j / scale,
1838
+ octaves=octaves,
1839
+ persistence=persistence,
1840
+ lacunarity=lacunarity,
1841
+ repeatx=1024,
1842
+ repeaty=1024,
1843
+ base=seed,
1844
+ )
1845
+ Image.fromarray(prep_world(world)).convert("L").save("Init_Img/Image.png")
1846
+ init_image = "Init_Img/Image.png"
1847
+ elif init == "black and white":
1848
+ url = "https://www.random.org/bitmaps/?format=png&width=300&height=300&zoom=1"
1849
+ urllib.request.urlretrieve(url, "Init_Img/Image.png")
1850
+ init_image = "Init_Img/Image.png"
1851
+
1852
+ seed = args2.seed # @param {type:"number"}
1853
+ # @markdown >`iterations` excludes the iterations spent during the MSE phase, if it is used. The total number of iterations will be higher if `mse_decay_rate` is more than 0.
1854
+ iterations = args2.iterations # @param {type:"number"}
1855
+ transparent_png = False # @param {type:"boolean"}
1856
+
1857
+ # @markdown <font size="+3">⚠</font> **ADVANCED SETTINGS** <font size="+3">⚠</font>
1858
+ # @markdown ---
1859
+ # @markdown ---
1860
+
1861
+ # @markdown >If you want to make multiple images with different prompts, use this. Separate different prompts for different images with a `~` (example: `prompt1~prompt2~prompt3`). Iter is the number of iterations you want each image to run for. If you use MSE, I'd type a pretty low number (about 10).
1862
+ multiple_prompt_batches = False # @param {type:"boolean"}
1863
+ multiple_prompt_batches_iter = 300 # @param {type:"number"}
1864
+
1865
+ # @markdown >`folder_name` is the name of the folder you want to output your result(s) to. Previous outputs will NOT be overwritten. By default, it will be saved to the colab's root folder, but the `save_to_drive` checkbox will save it to `MyDrive\VQGAN_Output` instead.
1866
+ folder_name = "" # @param {type:"string"}
1867
+ save_to_drive = False # @param {type:"boolean"}
1868
+ prompt_experiment = "None" # @param ['None', 'Fever Dream', 'Philipuss’s Basement', 'Vivid Turmoil', 'Mad Dad', 'Platinum', 'Negative Energy']
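+ # Prompt experiments: wrap or rewrite the prompt with CLIP special tokens and character substitutions.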
1869
+ if prompt_experiment == "Fever Dream":
1870
+ prompts = "<|startoftext|>" + prompts + "<|endoftext|>"
1871
+ elif prompt_experiment == "Vivid Turmoil":
1872
+ prompts = prompts.replace(" ", "¡")
1873
+ prompts = "¬" + prompts + "®"
1874
+ elif prompt_experiment == "Mad Dad":
1875
+ prompts = prompts.replace(" ", "\\s+")
1876
+ elif prompt_experiment == "Platinum":
1877
+ prompts = "~!" + prompts + "!~"
1878
+ prompts = prompts.replace(" ", "</w>")
1879
+ elif prompt_experiment == "Philipuss’s Basement":
1880
+ prompts = "<|startoftext|>" + prompts
1881
+ prompts = prompts.replace(" ", "<|endoftext|><|startoftext|>")
1882
+ elif prompt_experiment == "Lowercase":
1883
+ prompts = prompts.lower()
1884
+
1885
+
1886
+ # @markdown >Target images work like prompts, write the name of the image. You can add multiple target images by seperating them with a `|`.
1887
+ target_images = "" # @param {type:"string"}
1888
+
1889
+ # @markdown ><font size="+2">☢</font> Advanced values. Values of cut_pow below 1 prioritize structure over detail, and vice versa for above 1. Step_size affects how wild the change between iterations is, and if final_step_size is not 0, step_size will interpolate towards it over time.
1890
+ # @markdown >Cutn affects 'creativity': fewer cutouts lead to more random/creative results, sometimes barely readable, while higher values (90+) lead to very stable, photo-like outputs
1891
+ cutn = 130 # @param {type:"number"}
1892
+ cut_pow = 1 # @param {type:"number"}
1893
+ # @markdown >Step_size is like weirdness. Lower: more accurate/realistic, slower; Higher: less accurate/more funky, faster.
1894
+ step_size = 0.1 # @param {type:"number"}
1895
+ # @markdown >Start_step_size is a temporary step_size that will be active only in the first 10 iterations. It (sometimes) helps with speed. If it's set to 0, it won't be used.
1896
+ start_step_size = 0 # @param {type:"number"}
1897
+ # @markdown >Final_step_size is a goal step_size which the AI will try and reach. If set to 0, it won't be used.
1898
+ final_step_size = 0 # @param {type:"number"}
1899
+ if start_step_size <= 0:
1900
+ start_step_size = step_size
1901
+ if final_step_size <= 0:
1902
+ final_step_size = step_size
1903
+
1904
+ # @markdown ---
1905
+
1906
+ # @markdown >EMA maintains a moving average of trained parameters. The number below is the rate of decay (higher means slower).
1907
+ ema_val = 0.98 # @param {type:"number"}
1908
+
1909
+ # @markdown >If you want to keep starting from the same point, set `gen_seed` to a positive number. `-1` will make it random every time.
1910
+ gen_seed = -1 # @param {type:'number'}
1911
+
1912
+ init_image_in_drive = False # @param {type:"boolean"}
1913
+ if init_image_in_drive and init_image:
1914
+ init_image = "/content/drive/MyDrive/VQGAN_Output/" + init_image
1915
+
1916
+ images_interval = args2.update # @param {type:"number"}
1917
+
1918
+ # I think you should give "Free Thoughts on the Proceedings of the Continental Congress" a read, really funny and actually well-written, Hamilton presented it in a bad light IMO.
1919
+
1920
+ batch_size = 1 # @param {type:"number"}
1921
+
1922
+ # @markdown ---
1923
+
1924
+ # @markdown <font size="+1">🔮</font> **MSE Regularization** <font size="+1">🔮</font>
1925
+ # Based off of this notebook: https://colab.research.google.com/drive/1gFn9u3oPOgsNzJWEFmdK-N9h_y65b8fj?usp=sharing - already in credits
1926
+ use_mse = args2.mse # @param {type:"boolean"}
1927
+ mse_images_interval = images_interval
1928
+ mse_init_weight = 0.2 # @param {type:"number"}
1929
+ mse_decay_rate = 160 # @param {type:"number"}
1930
+ mse_epoches = 10 # @param {type:"number"}
1931
+ ##@param {type:"number"}
1932
+
1933
+ # @markdown >Overwrites the usual values during the mse phase if included. If any value is 0, its normal counterpart is used instead.
1934
+ mse_with_zeros = True # @param {type:"boolean"}
1935
+ mse_step_size = 0.87 # @param {type:"number"}
1936
+ mse_cutn = 42 # @param {type:"number"}
1937
+ mse_cut_pow = 0.75 # @param {type:"number"}
1938
+
1939
+ # @markdown >normal_flip_optim flips between two optimizers during the normal (not MSE) phase. It can improve quality, but it's kind of experimental, use at your own risk.
1940
+ normal_flip_optim = True # @param {type:"boolean"}
1941
+ ##@markdown >Adding some TV may make the image blurrier but also helps to get rid of noise. A good value to try might be 0.1.
1942
+ # tv_weight = 0.1 #@param {type:'number'}
1943
+ # @markdown ---
1944
+
1945
+ # @markdown >`altprompts` is a set of prompts that take in a different augmentation pipeline, and can have their own cut_pow. At the moment, the default "alt augment" settings flip the picture cutouts upside down before evaluating. This can be good for optical illusion images. If either cut_pow value is 0, it will use the same value as the normal prompts.
1946
+ altprompts = "" # @param {type:"string"}
1947
+ altprompt_mode = "flipped"
1948
+ ##@param ["normal" , "flipped", "sideways"]
1949
+ alt_cut_pow = 0 # @param {type:"number"}
1950
+ alt_mse_cut_pow = 0 # @param {type:"number"}
1951
+ # altprompt_type = "upside-down" #@param ['upside-down', 'as']
1952
+
1953
+ ##@markdown ---
1954
+ ##@markdown <font size="+1">💫</font> **Zooming and Moving** <font size="+1">💫</font>
1955
+ zoom = False
1956
+ ##@param {type:"boolean"}
1957
+ zoom_speed = 100
1958
+ ##@param {type:"number"}
1959
+ zoom_frequency = 20
1960
+ ##@param {type:"number"}
1961
+
1962
+ # @markdown ---
1963
+ # @markdown On an unrelated note, if you get any errors while running this, restart the runtime and run the first cell again. If that doesn't work either, message me on Discord (Philipuss#4066).
1964
+
1965
+ model_names = {
1966
+ "vqgan_imagenet_f16_16384": "vqgan_imagenet_f16_16384",
1967
+ "ImageNet 1024": "vqgan_imagenet_f16_1024",
1968
+ "Gumbel 8192": "gumbel_8192",
1969
+ "Sber Gumbel": "sber_gumbel",
1970
+ "imagenet_cin": "imagenet_cin",
1971
+ "WikiArt 1024": "wikiart_1024",
1972
+ "WikiArt 16384": "wikiart_16384",
1973
+ "COCO-Stuff": "coco",
1974
+ "FacesHQ": "faceshq",
1975
+ "S-FLCKR": "sflckr",
1976
+ "WikiArt 7mil": "wikiart_7mil",
1977
+ "COCO 1 Stage": "coco_1stage",
1978
+ }
1979
+
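+ # Apply the selected template: presets that adjust the prompt text, canvas size, and optimization hyperparameters.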
1980
+ if template == "Better - Fast":
1981
+ prompts = prompts + ". Detailed artwork. ArtStationHQ. unreal engine. 4K HD."
1982
+ elif template == "Better - Slow":
1983
+ prompts = (
1984
+ prompts
1985
+ + ". Detailed artwork. Trending on ArtStation. unreal engine. | Rendered in Maya. "
1986
+ + prompts
1987
+ + ". 4K HD."
1988
+ )
1989
+ elif template == "Movie Poster":
1990
+ prompts = prompts + ". Movie poster. Rendered in unreal engine. ArtStationHQ."
1991
+ width = 400
1992
+ height = 592
1993
+ elif template == "flag":
1994
+ prompts = (
1995
+ "A photo of a flag of the country "
1996
+ + prompts
1997
+ + " | Flag of "
1998
+ + prompts
1999
+ + ". White background."
2000
+ )
2001
+ # import cv2
2002
+ # img = cv2.imread('templates/flag.png', 0)
2003
+ # cv2.imwrite('templates/final_flag.png', add_noise(img))
2004
+ init_image = "templates/flag.png"
2005
+ transparent_png = True
2006
+ elif template == "planet":
2007
+ import cv2
2008
+
2009
+ img = cv2.imread("templates/planet.png", 0)
2010
+ cv2.imwrite("templates/final_planet.png", add_noise(img))
2011
+ prompts = (
2012
+ "A photo of the planet "
2013
+ + prompts
2014
+ + ". Planet in the middle with black background. | The planet of "
2015
+ + prompts
2016
+ + ". Photo of a planet. Black background. Trending on ArtStation. | Colorful."
2017
+ )
2018
+ init_image = "templates/final_planet.png"
2019
+ elif template == "creature":
2020
+ # import cv2
2021
+ # img = cv2.imread('templates/planet.png', 0)
2022
+ # cv2.imwrite('templates/final_planet.png', add_noise(img))
2023
+ prompts = (
2024
+ "A photo of a creature with "
2025
+ + prompts
2026
+ + ". Animal in the middle with white background. | The creature has "
2027
+ + prompts
2028
+ + ". Photo of a creature/animal. White background. Detailed image of a creature. | White background."
2029
+ )
2030
+ init_image = "templates/creature.png"
2031
+ # transparent_png = True
2032
+ elif template == "Detailed":
2033
+ prompts = (
2034
+ prompts
2035
+ + ", by Puer Udger. Detailed artwork, trending on artstation. 4K HD, realism."
2036
+ )
2037
+ flavor = "cumin"
2038
+ elif template == "human":
2039
+ init_image = "/content/templates/human.png"
2040
+ elif template == "Realistic":
2041
+ cutn = 200
2042
+ step_size = 0.03
2043
+ cut_pow = 0.2
2044
+ flavor = "holywater"
2045
+ elif template == "Consistent Creativity":
2046
+ flavor = "cumin"
2047
+ cut_pow = 0.01
2048
+ cutn = 136
2049
+ step_size = 0.08
2050
+ mse_step_size = 0.41
2051
+ mse_cut_pow = 0.3
2052
+ ema_val = 0.99
2053
+ normal_flip_optim = False
2054
+ elif template == "Smooth":
2055
+ flavor = "wyvern"
2056
+ step_size = 0.10
2057
+ cutn = 120
2058
+ normal_flip_optim = False
2059
+ tv_weight = 10
2060
+ elif template == "Subtle MSE":
2061
+ mse_init_weight = 0.07
2062
+ mse_decay_rate = 130
2063
+ mse_step_size = 0.2
2064
+ mse_cutn = 100
2065
+ mse_cut_pow = 0.6
2066
+ elif template == "Balanced":
2067
+ cutn = 130
2068
+ cut_pow = 1
2069
+ step_size = 0.16
2070
+ final_step_size = 0
2071
+ ema_val = 0.98
2072
+ mse_init_weight = 0.2
2073
+ mse_decay_rate = 130
2074
+ mse_with_zeros = True
2075
+ mse_step_size = 0.9
2076
+ mse_cutn = 50
2077
+ mse_cut_pow = 0.8
2078
+ normal_flip_optim = True
2079
+ elif template == "Size: Square":
2080
+ width = 450
2081
+ height = 450
2082
+ elif template == "Size: Landscape":
2083
+ width = 480
2084
+ height = 336
2085
+ elif template == "Size: Poster":
2086
+ width = 336
2087
+ height = 480
2088
+ elif template == "Negative Prompt":
2089
+ prompts = prompts.replace(":", ":-")
2090
+ prompts = prompts.replace(":--", ":")
2091
+ elif template == "Hyper Fast Results":
2092
+ step_size = 1
2093
+ ema_val = 0.3
2094
+ cutn = 30
2095
+ elif template == "Better Quality":
2096
+ prompts = (
2097
+ prompts + ":1 | Watermark, blurry, cropped, confusing, cut, incoherent:-1"
2098
+ )
2099
+
2100
+ mse_decay = 0
2101
+
2102
+ if not use_mse:
2103
+ mse_init_weight = 0.0
2104
+ else:
2105
+ mse_decay = mse_init_weight / mse_epoches
2106
+
2107
+
2108
+ if seed == -1:
2109
+ seed = None
2110
+ if init_image == "None":
2111
+ init_image = None
2112
+ if target_images == "None" or not target_images:
2113
+ target_images = []
2114
+ else:
2115
+ target_images = target_images.split("|")
2116
+ target_images = [image.strip() for image in target_images]
2117
+
2118
+ prompts = [phrase.strip() for phrase in prompts.split("|")]
2119
+ if prompts == [""]:
2120
+ prompts = []
2121
+
2122
+ altprompts = [phrase.strip() for phrase in altprompts.split("|")]
2123
+ if altprompts == [""]:
2124
+ altprompts = []
2125
+
2126
+ if mse_images_interval == 0:
2127
+ mse_images_interval = images_interval
2128
+ if mse_step_size == 0:
2129
+ mse_step_size = step_size
2130
+ if mse_cutn == 0:
2131
+ mse_cutn = cutn
2132
+ if mse_cut_pow == 0:
2133
+ mse_cut_pow = cut_pow
2134
+ if alt_cut_pow == 0:
2135
+ alt_cut_pow = cut_pow
2136
+ if alt_mse_cut_pow == 0:
2137
+ alt_mse_cut_pow = mse_cut_pow
2138
+
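+ # Kornia augmentation pipeline applied to image cutouts before CLIP scoring.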
2139
+ augs = nn.Sequential(
2140
+ K.RandomHorizontalFlip(p=0.5),
2141
+ K.RandomSharpness(0.3, p=0.4),
2142
+ K.RandomGaussianBlur((3, 3), (4.5, 4.5), p=0.3),
2143
+ # K.RandomGaussianNoise(p=0.5),
2144
+ # K.RandomElasticTransform(kernel_size=(33, 33), sigma=(7,7), p=0.2),
2145
+ K.RandomAffine(
2146
+ degrees=30, translate=0.1, p=0.8, padding_mode="border"
2147
+ ), # padding_mode=2
2148
+ K.RandomPerspective(
2149
+ 0.2,
2150
+ p=0.4,
2151
+ ),
2152
+ K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
2153
+ K.RandomGrayscale(p=0.1),
2154
+ )
2155
+
2156
+ if altprompt_mode == "normal":
2157
+ altaugs = nn.Sequential(
2158
+ K.RandomRotation(degrees=90.0, return_transform=True),
2159
+ K.RandomHorizontalFlip(p=0.5),
2160
+ K.RandomSharpness(0.3, p=0.4),
2161
+ K.RandomGaussianBlur((3, 3), (4.5, 4.5), p=0.3),
2162
+ # K.RandomGaussianNoise(p=0.5),
2163
+ # K.RandomElasticTransform(kernel_size=(33, 33), sigma=(7,7), p=0.2),
2164
+ K.RandomAffine(
2165
+ degrees=30, translate=0.1, p=0.8, padding_mode="border"
2166
+ ), # padding_mode=2
2167
+ K.RandomPerspective(
2168
+ 0.2,
2169
+ p=0.4,
2170
+ ),
2171
+ K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
2172
+ K.RandomGrayscale(p=0.1),
2173
+ )
2174
+ elif altprompt_mode == "flipped":
2175
+ altaugs = nn.Sequential(
2176
+ K.RandomHorizontalFlip(p=0.5),
2177
+ # K.RandomRotation(degrees=90.0),
2178
+ K.RandomVerticalFlip(p=1),
2179
+ K.RandomSharpness(0.3, p=0.4),
2180
+ K.RandomGaussianBlur((3, 3), (4.5, 4.5), p=0.3),
2181
+ # K.RandomGaussianNoise(p=0.5),
2182
+ # K.RandomElasticTransform(kernel_size=(33, 33), sigma=(7,7), p=0.2),
2183
+ K.RandomAffine(
2184
+ degrees=30, translate=0.1, p=0.8, padding_mode="border"
2185
+ ), # padding_mode=2
2186
+ K.RandomPerspective(
2187
+ 0.2,
2188
+ p=0.4,
2189
+ ),
2190
+ K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
2191
+ K.RandomGrayscale(p=0.1),
2192
+ )
2193
+ elif altprompt_mode == "sideways":
2194
+ altaugs = nn.Sequential(
2195
+ K.RandomHorizontalFlip(p=0.5),
2196
+ # K.RandomRotation(degrees=90.0),
2197
+ K.RandomVerticalFlip(p=1),
2198
+ K.RandomSharpness(0.3, p=0.4),
2199
+ K.RandomGaussianBlur((3, 3), (4.5, 4.5), p=0.3),
2200
+ # K.RandomGaussianNoise(p=0.5),
2201
+ # K.RandomElasticTransform(kernel_size=(33, 33), sigma=(7,7), p=0.2),
2202
+ K.RandomAffine(
2203
+ degrees=30, translate=0.1, p=0.8, padding_mode="border"
2204
+ ), # padding_mode=2
2205
+ K.RandomPerspective(
2206
+ 0.2,
2207
+ p=0.4,
2208
+ ),
2209
+ K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
2210
+ K.RandomGrayscale(p=0.1),
2211
+ )
2212
+
2213
+ if multiple_prompt_batches:
2214
+ prompts_all = str(prompts).split("~")
2215
+ else:
2216
+ prompts_all = prompts
2217
+ multiple_prompt_batches_iter = iterations
2218
+
2219
+ if multiple_prompt_batches:
2220
+ mtpl_prmpts_btchs = len(prompts_all)
2221
  else:
2222
+ mtpl_prmpts_btchs = 1
2223
+
2224
+ # print(mtpl_prmpts_btchs)
2225
+
2226
+ steps_path = "./"
2227
+ zoom_path = "./"
2228
+
2229
+ path = "./"
2230
+
2231
+ iterations = multiple_prompt_batches_iter
2232
+
2233
+ for pr in range(0, mtpl_prmpts_btchs):
2234
+ # print(prompts_all[pr].replace('[\'', '').replace('\']', ''))
2235
+ if multiple_prompt_batches:
2236
+ prompts = prompts_all[pr].replace("['", "").replace("']", "")
2237
+
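+ # Zoom mode: split the run into chunks; after each chunk, crop and upscale the last frame to use as the next init image.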
2238
+ if zoom:
2239
+ mdf_iter = round(iterations / zoom_frequency)
2240
+ else:
2241
+ mdf_iter = 2
2242
+ zoom_frequency = iterations
2243
+
2244
+ for iter in range(1, mdf_iter):
2245
+ if zoom:
2246
+ if iter != 0:
2247
+ image = Image.open("progress.png")
2248
+ area = (0, 0, width - zoom_speed, height - zoom_speed)
2249
+ cropped_img = image.crop(area)
2250
+ cropped_img.show()
2251
+
2252
+ new_image = cropped_img.resize((width, height))
2253
+ new_image.save("zoom.png")
2254
+ init_image = "zoom.png"
2255
+
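+ # Collect every setting into the args namespace consumed by ModelHost for this run.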
2256
+ args = argparse.Namespace(
2257
+ prompts=prompts,
2258
+ altprompts=altprompts,
2259
+ image_prompts=target_images,
2260
+ noise_prompt_seeds=[],
2261
+ noise_prompt_weights=[],
2262
+ size=[width, height],
2263
+ init_image=init_image,
2264
+ png=transparent_png,
2265
+ init_weight=mse_init_weight,
2266
+ #vqgan_model=model_names[model],
2267
+ step_size=step_size,
2268
+ start_step_size=start_step_size,
2269
+ final_step_size=final_step_size,
2270
+ cutn=cutn,
2271
+ cut_pow=cut_pow,
2272
+ mse_cutn=mse_cutn,
2273
+ mse_cut_pow=mse_cut_pow,
2274
+ mse_step_size=mse_step_size,
2275
+ display_freq=images_interval,
2276
+ mse_display_freq=mse_images_interval,
2277
+ max_iterations=zoom_frequency,
2278
+ mse_end=0,
2279
+ seed=seed,
2280
+ folder_name=folder_name,
2281
+ save_to_drive=save_to_drive,
2282
+ mse_decay_rate=mse_decay_rate,
2283
+ mse_decay=mse_decay,
2284
+ mse_with_zeros=mse_with_zeros,
2285
+ normal_flip_optim=normal_flip_optim,
2286
+ ema_val=ema_val,
2287
+ augs=augs,
2288
+ altaugs=altaugs,
2289
+ alt_cut_pow=alt_cut_pow,
2290
+ alt_mse_cut_pow=alt_mse_cut_pow,
2291
+ is_gumbel=is_gumbel,
2292
+ gen_seed=gen_seed,
2293
+ )
2294
+ mh = ModelHost(args)
2295
+ x = 0
2296
+
2297
+ #for x in range(batch_size):
2298
+ mh.setup_model(x)
2299
+ mh.run(x)
2300
+ image_data = Image.open(args2.image_file)
2301
+ os.remove(args2.image_file)
2302
+ return(image_data)
2303
+ #return(last_iter)
2304
+ #x = x + 1
2305
+
2306
+ if zoom:
2307
+ files = os.listdir(steps_path)
2308
+ for index, file in enumerate(files):
2309
+ os.rename(
2310
+ os.path.join(steps_path, file),
2311
+ os.path.join(
2312
+ steps_path,
2313
+ "".join([str(index + 1 + zoom_frequency * iter), ".png"]),
2314
+ ),
2315
+ )
2316
+ index = index + 1
2317
+
2318
+ from pathlib import Path
2319
+ import shutil
2320
+
2321
+ src_path = steps_path
2322
+ trg_path = zoom_path
2323
+
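+ # Move the renamed zoom frames into the zoom output folder (assumes they were saved as "<n>.png" by the rename step above).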
2324
+ for src_file in range(1, mdf_iter):
2325
+ shutil.move(os.path.join(src_path, str(src_file) + ".png"), trg_path)
2326
+
2327
+ ##################### START GRADIO HERE ############################
2328
+ image = gr.outputs.Image(type="pil", label="Your result")
2329
+ #def cvt_2_base64(file_name):
2330
+ # with open(file_name , "rb") as image_file :
2331
+ # data = base64.b64encode(image_file.read())
2332
+ # return data.decode('utf-8')
2333
+ #base64image = "data:image/jpg;base64,"+cvt_2_base64('flavors.jpg')
2334
+ #markdown = gr.Markdown("<img src='"+base64image+"' />")
2335
+ #def test(raw_input):
2336
+ # pass
2337
+ #setattr(markdown, "requires_permissions", False)
2338
+ #setattr(markdown, "label", "Flavors")
2339
+ #setattr(markdown, "preprocess", test)
2340
+ iface = gr.Interface(
2341
+ fn=run_all,
2342
+ inputs=[
2343
+ gr.inputs.Textbox(label="Prompt - try adding increments to your prompt such as 'oil on canvas', 'a painting', 'a book cover'",default="chalk pastel drawing of a dog wearing a funny hat"),
2344
+ gr.inputs.Slider(label="Steps - more steps can increase quality but will take longer to generate",default=50,maximum=250,minimum=1,step=1),
2345
+ gr.inputs.Dropdown(label="Flavor",choices=["ginger", "cumin", "holywater", "zynth", "wyvern", "aaron", "moth", "juu", "custom"]),
2346
+ #markdown,
2347
+ gr.inputs.Dropdown(label="Style",choices=["Default","Balanced","Detailed","Consistent Creativity","Realistic","Smooth","Subtle MSE","Hyper Fast Results"],default="Hyper Fast Results"),
2348
+ gr.inputs.Radio(label="Width", choices=[32,64,128,256,512],default=512),
2349
+ gr.inputs.Radio(label="Height", choices=[32,64,128,256,512],default=512),
2350
+ ],
2351
+ outputs=image,
2352
+ title="Generate images from text with VQGAN+CLIP (Hypertron v2)",
2353
+ #description="<div>By typing a prompt and pressing submit you can generate images based on this prompt. <a href='https://github.com/CompVis/latent-diffusion' target='_blank'>Latent Diffusion</a> is a text-to-image model created by <a href='https://github.com/CompVis' target='_blank'>CompVis</a>, trained on the <a href='https://laion.ai/laion-400-open-dataset/'>LAION-400M dataset.</a><br>This UI to the model was assembled by <a style='color: rgb(245, 158, 11);font-weight:bold' href='https://twitter.com/multimodalart' target='_blank'>@multimodalart</a></div>",
2354
+ #article="<h4 style='font-size: 110%;margin-top:.5em'>Biases acknowledgment</h4><div>Despite how impressive being able to turn text into image is, beware to the fact that this model may output content that reinforces or exarcbates societal biases. According to the <a href='https://arxiv.org/abs/2112.10752' target='_blank'>Latent Diffusion paper</a>:<i> \"Deep learning modules tend to reproduce or exacerbate biases that are already present in the data\"</i>. The model was trained on an unfiltered version the LAION-400M dataset, which scrapped non-curated image-text-pairs from the internet (the exception being the the removal of illegal content) and is meant to be used for research purposes, such as this one. <a href='https://laion.ai/laion-400-open-dataset/' target='_blank'>You can read more on LAION's website</a></div><h4 style='font-size: 110%;margin-top:1em'>Who owns the images produced by this demo?</h4><div>Definetly not me! Probably you do. I say probably because the Copyright discussion about AI generated art is ongoing. So <a href='https://www.theverge.com/2022/2/21/22944335/us-copyright-office-reject-ai-generated-art-recent-entrance-to-paradise' target='_blank'>it may be the case that everything produced here falls automatically into the public domain</a>. But in any case it is either yours or is in the public domain.</div>"
2355
+ )
2356
  iface.launch()
requirements.txt CHANGED
@@ -1 +1,29 @@
1
- torch
1
+ -e git+https://github.com/CompVis/taming-transformers.git#egg=taming-transformers
2
+ -e git+https://github.com/openai/CLIP/#egg=CLIP
3
+ gitpython
4
+ ftfy
5
+ regex
6
+ pandas
7
+ omegaconf
8
+ pytorch-lightning
9
+ torch-fidelity
10
+ transformers
11
+ einops
12
+ noise
13
+ gputil
14
+ gradio
15
+ torch
16
+ numpy
17
+ tqdm
18
+ torchvision
19
+ Pillow
20
+ autokeras
21
+ huggingface_hub
22
+ kornia
23
+ imageio
24
+ pathvalidate
25
+ stegano
26
+ imgtag
27
+ timm
28
+ python-xmp-toolkit
29
+ shortuuid