apolinario committed on
Commit ddad699
1 Parent(s): b18ff48

Initial attempt

Files changed (4)
  1. .gitignore +7 -0
  2. app.py +282 -8
  3. packages.txt +1 -0
  4. requirements.txt +5 -1
.gitignore ADDED
@@ -0,0 +1,7 @@
+ gradio_queue.db
+ gradio_queue.db-journal
+ stylegan_xl
+ samples
+ flagged
+ *.pkl
+ *.mp4
app.py CHANGED
@@ -1,11 +1,285 @@
  import gradio as gr
- import torch
-
- is_cuda = torch.cuda.is_available()
- def greet(name):
-     if is_cuda:
-         return "Hello cuda" + name + "!!"
      else:
-         return "Hello ooops" + name + "!!"
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()

  import gradio as gr
+ from git.repo.base import Repo
+ from os.path import exists as path_exists
+
+ if not (path_exists(f"stylegan_xl")):
+     Repo.clone_from("https://github.com/autonomousvision/stylegan_xl", "stylegan_xl")
+
+ import sys
+ sys.path.append('./CLIP')
+ sys.path.append('./stylegan_xl')
+
+ import io
+ import os, time, glob
+ import pickle
+ import shutil
+ import numpy as np
+ from PIL import Image
+ import torch
+ import torch.nn.functional as F
+ import requests
+ import torchvision.transforms as transforms
+ import torchvision.transforms.functional as TF
+ import clip
+ import unicodedata
+ import re
+ from tqdm.notebook import tqdm
+ from torchvision.transforms import Compose, Resize, ToTensor, Normalize
+ from IPython.display import display
+ from einops import rearrange
+ import dnnlib
+ import legacy
+ import subprocess
+
+ torch.cuda.empty_cache()
+ device = torch.device('cuda:0')
+ print('Using device:', device, file=sys.stderr)
+
+ def fetch(url_or_path):
+     if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
+         r = requests.get(url_or_path)
+         r.raise_for_status()
+         fd = io.BytesIO()
+         fd.write(r.content)
+         fd.seek(0)
+         return fd
+     return open(url_or_path, 'rb')
+
+ def fetch_model(url_or_path,network_name):
+     torch.hub.download_url_to_file(f'{url_or_path}',f'./{network_name}')
+
+ def slugify(value, allow_unicode=False):
+     """
+     Taken from https://github.com/django/django/blob/master/django/utils/text.py
+     Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
+     dashes to single dashes. Remove characters that aren't alphanumerics,
+     underscores, or hyphens. Convert to lowercase. Also strip leading and
+     trailing whitespace, dashes, and underscores.
+     """
+     value = str(value)
+     if allow_unicode:
+         value = unicodedata.normalize('NFKC', value)
      else:
+         value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
+     value = re.sub(r'[^\w\s-]', '', value.lower())
+     return re.sub(r'[-\s]+', '-', value).strip('-_')
+
+ def norm1(prompt):
+     "Normalize to the unit sphere."
+     return prompt / prompt.square().sum(dim=-1,keepdim=True).sqrt()
+
+ def spherical_dist_loss(x, y):
+     x = F.normalize(x, dim=-1)
+     y = F.normalize(y, dim=-1)
+     return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
+
+ def prompts_dist_loss(x, targets, loss):
+     if len(targets) == 1: # Keeps consistent results vs previous method for single objective guidance
+         return loss(x, targets[0])
+     distances = [loss(x, target) for target in targets]
+     return torch.stack(distances, dim=-1).sum(dim=-1)
+
+ class MakeCutouts(torch.nn.Module):
+     def __init__(self, cut_size, cutn, cut_pow=1.):
+         super().__init__()
+         self.cut_size = cut_size
+         self.cutn = cutn
+         self.cut_pow = cut_pow
+
+     def forward(self, input):
+         sideY, sideX = input.shape[2:4]
+         max_size = min(sideX, sideY)
+         min_size = min(sideX, sideY, self.cut_size)
+         cutouts = []
+         for _ in range(self.cutn):
+             size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
+             offsetx = torch.randint(0, sideX - size + 1, ())
+             offsety = torch.randint(0, sideY - size + 1, ())
+             cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
+             cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
+         return torch.cat(cutouts)
+
+ make_cutouts = MakeCutouts(224, 32, 0.5)
+
+ def embed_image(image):
+     n = image.shape[0]
+     cutouts = make_cutouts(image)
+     embeds = clip_model.embed_cutout(cutouts)
+     embeds = rearrange(embeds, '(cc n) c -> cc n c', n=n)
+     return embeds
+
+ def embed_url(url):
+     image = Image.open(fetch(url)).convert('RGB')
+     return embed_image(TF.to_tensor(image).to(device).unsqueeze(0)).mean(0).squeeze(0)
+
+ class CLIP(object):
+     def __init__(self):
+         clip_model = "ViT-B/16"
+         self.model, _ = clip.load(clip_model)
+         self.model = self.model.requires_grad_(False)
+         self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
+                                               std=[0.26862954, 0.26130258, 0.27577711])
+
+     @torch.no_grad()
+     def embed_text(self, prompt):
+         "Normalized clip text embedding."
+         return norm1(self.model.encode_text(clip.tokenize(prompt).to(device)).float())
+
+     def embed_cutout(self, image):
+         "Normalized clip image embedding."
+         return norm1(self.model.encode_image(self.normalize(image)))
+
+ clip_model = CLIP()
+
+ #@markdown #**Model selection** 🎭
+
+ Models = ["Imagenet256", "Imagenet512", "Imagenet1024", "Pokemon", "FFHQ"]
+
+ #@markdown ---
+
+ network_url = {
+     "Imagenet256": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet256.pkl",
+     "Imagenet512": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet512.pkl",
+     "Imagenet1024": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet1024.pkl",
+     "Pokemon": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon256.pkl",
+     "FFHQ": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/ffhq256.pkl"
+ }
+
+ for Model in Models:
+     network_name = network_url[Model].split("/")[-1]
+     if not (path_exists(network_name)):
+         fetch_model(network_url[Model],network_name)
+
+ def load_current_model(current_model="imagenet256.pkl"):
+     with dnnlib.util.open_url(current_model) as f:
+         G = legacy.load_network_pkl(f)['G_ema'].to(device)
+
+     zs = torch.randn([10000, G.mapping.z_dim], device=device)
+     cs = torch.zeros([10000, G.mapping.c_dim], device=device)
+     for i in range(cs.shape[0]):
+         cs[i,i//10]=1
+     w_stds = G.mapping(zs, cs)
+     w_stds = w_stds.reshape(10, 1000, G.num_ws, -1)
+     w_stds = w_stds.std(0).mean(0)[0]
+     w_all_classes_avg = G.mapping.w_avg.mean(0)
+     return(G, w_stds, w_all_classes_avg)
+
+ G, w_stds, w_all_classes_avg = load_current_model()
+ print(w_stds)
+ previousModel = 'imagenet256'
+ def run(prompt,steps,model):
+     global G, w_stds, w_all_classes_avg, previousModel
+     if(model == 'imagenet256' and previousModel != 'imagenet256'):
+         G, w_stds, w_all_classes_avg = load_current_model('imagenet256.pkl')
+     if(model == 'imagenet512' and previousModel != 'imagenet512'):
+         G, w_stds, w_all_classes_avg = load_current_model('imagenet512.pkl')
+     elif(model=='imagenet1024' and previousModel != 'imagenet1024'):
+         G, w_stds, w_all_classes_avg = load_current_model('imagenet1024.pkl')
+     elif(model=='pokemon256' and previousModel != 'pokemon256'):
+         G, w_stds, w_all_classes_avg = load_current_model('pokemon256.pkl')
+     elif(model=='ffhq256' and previousModel != 'ffhq256'):
+         G, w_stds, w_all_classes_avg = load_current_model('ffhq256.pkl')
+     previousModel = model
+
+     texts = prompt
+     steps = steps
+     seed = -1 # @param {type:"number"}
+
+     # @markdown ---
+
+     if seed == -1:
+         seed = np.random.randint(0, 9e9)
+         print(f"Your random seed is: {seed}")
+
+     texts = [frase.strip() for frase in texts.split("|") if frase]
+
+     targets = [clip_model.embed_text(text) for text in texts]
+
+     tf = Compose(
+         [
+             # Resize(224),
+             lambda x: torch.clamp((x + 1) / 2, min=0, max=1),
+         ]
+     )
+
+     initial_batch = 4  # actually that will be multiplied by initial_image_steps
+     initial_image_steps = 32
+
+     def get_image(timestring):
+         os.makedirs(f"samples/{timestring}", exist_ok=True)
+         torch.manual_seed(seed)
+         with torch.no_grad():
+             qs = []
+             losses = []
+             for _ in range(initial_image_steps):
+                 a = torch.randn([initial_batch, 512], device=device) * 0.4 + w_stds * 0.4
+                 q = (a - w_all_classes_avg) / w_stds
+                 images = G.synthesis(
+                     (q * w_stds + w_all_classes_avg).unsqueeze(1).repeat([1, G.num_ws, 1])
+                 )
+                 embeds = embed_image(images.add(1).div(2))
+                 loss = prompts_dist_loss(embeds, targets, spherical_dist_loss).mean(0)
+                 i = torch.argmin(loss)
+                 qs.append(q[i])
+                 losses.append(loss[i])
+             qs = torch.stack(qs)
+             losses = torch.stack(losses)
+             i = torch.argmin(losses)
+             q = qs[i].unsqueeze(0).repeat([G.num_ws, 1]).requires_grad_()
+
+         # Sampling loop
+         q_ema = q
+         print(q.shape)
+         opt = torch.optim.AdamW([q], lr=0.05, betas=(0.0, 0.999), weight_decay=0.025)
+         loop = tqdm(range(steps))
+         for i in loop:
+             opt.zero_grad()
+             w = q * w_stds
+             image = G.synthesis((q * w_stds + w_all_classes_avg)[None], noise_mode="const")
+             embed = embed_image(image.add(1).div(2))
+             loss = prompts_dist_loss(embed, targets, spherical_dist_loss).mean()
+             loss.backward()
+             opt.step()
+             loop.set_postfix(loss=loss.item(), q_magnitude=q.std().item())
+
+             q_ema = q_ema * 0.98 + q * 0.02
+             image = G.synthesis(
+                 (q_ema * w_stds + w_all_classes_avg)[None], noise_mode="const"
+             )
+
+             pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0, 1))
+             pil_image.save(f"samples/{timestring}/{i:04}.jpg")
+
+             if (i+1) % steps == 0:
+                 #/usr/bin/
+                 subprocess.call(['ffmpeg', '-r', '60', '-i', f'samples/{timestring}/%04d.jpg', '-vcodec', 'libx264', '-crf','18','-pix_fmt','yuv420p', f'{timestring}.mp4'])
+                 shutil.rmtree(f"samples/{timestring}")
+                 pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0, 1))
+                 return(pil_image, f'{timestring}.mp4')
+
+     try:
+         timestring = time.strftime("%Y%m%d%H%M%S")
+         image,video = get_image(timestring)
+         return([image,video])
+     except KeyboardInterrupt:
+         pass
+
+ image = gr.outputs.Image(type="pil", label="Your image")
+ video = gr.outputs.Video(type="mp4", label="Your video")
+ css = ".output-image, .output-video {height: 528px !important}"
+ iface = gr.Interface(fn=run, inputs=[
+     gr.inputs.Textbox(label="Prompt",default="chalk pastel drawing of a dog wearing a funny hat"),
+     gr.inputs.Slider(label="Steps - more steps can increase quality but will take longer to generate",default=300,maximum=500,minimum=10,step=1),
+     #gr.inputs.Radio(label="Aspect Ratio", choices=["Square", "Horizontal", "Vertical"],default="Horizontal"),
+     gr.inputs.Dropdown(label="Model", choices=["imagenet256","imagenet512","imagenet1024","pokemon256", "ffhq256"], default="imagenet256")
+     #gr.inputs.Radio(label="Height", choices=[32,64,128,256,512],default=256),
+     #gr.inputs.Slider(label="Images - How many images you wish to generate", default=2, step=1, minimum=1, maximum=4),
+     #gr.inputs.Slider(label="Diversity scale - How different from one another you wish the images to be",default=5.0, minimum=1.0, maximum=15.0),
+     #gr.inputs.Slider(label="ETA - between 0 and 1. Lower values can provide better quality, higher values can be more diverse",default=0.0,minimum=0.0, maximum=1.0,step=0.1),
+     ],
+     outputs=[image,video],
+     css=css,
+     title="Generate images from text with StyleGAN XL + CLIP",
+     description="<div>By typing a prompt and pressing submit you can generate images based on that prompt. <a href='https://github.com/autonomousvision/stylegan_xl' target='_blank'>StyleGAN XL</a> is an open source image generation model and <a href='https://github.com/openai/CLIP' target='_blank'>CLIP</a> is an open source model that connects text and images; here CLIP guides StyleGAN XL towards your prompt.<br>This UI to the model was assembled by <a style='color: rgb(245, 158, 11);font-weight:bold' href='https://twitter.com/multimodalart' target='_blank'>@multimodalart</a></div>",
+     article="<h4 style='font-size: 110%;margin-top:.5em'>Biases acknowledgment</h4><div>Despite how impressive being able to turn text into images is, beware of the fact that this model may output content that reinforces or exacerbates societal biases. According to the <a href='https://arxiv.org/abs/2112.10752' target='_blank'>Latent Diffusion paper</a>: <i>\"Deep learning modules tend to reproduce or exacerbate biases that are already present in the data\"</i>. The models are meant to be used for research purposes, such as this one.</div><h4 style='font-size: 110%;margin-top:1em'>Who owns the images produced by this demo?</h4><div>Definitely not me! Probably you do. I say probably because the copyright discussion about AI-generated art is ongoing. So <a href='https://www.theverge.com/2022/2/21/22944335/us-copyright-office-reject-ai-generated-art-recent-entrance-to-paradise' target='_blank'>it may be the case that everything produced here falls automatically into the public domain</a>. But in any case it is either yours or in the public domain.</div>")
+ iface.launch(enable_queue=True)
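
Reviewer note: the guidance objective introduced above (spherical_dist_loss combined through prompts_dist_loss) can be sanity-checked without downloading any StyleGAN XL checkpoint or CLIP weights. The sketch below copies those two functions verbatim from this commit and feeds them random stand-in embeddings; the shapes (32 cutouts, 512-dimensional embeddings, two prompts) are illustrative assumptions, not values taken from the code.

import torch
import torch.nn.functional as F

def spherical_dist_loss(x, y):
    # Squared great-circle distance between L2-normalized embeddings (as in app.py).
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)

def prompts_dist_loss(x, targets, loss):
    # Single prompt: plain loss; several prompts: sum of per-prompt losses (as in app.py).
    if len(targets) == 1:
        return loss(x, targets[0])
    distances = [loss(x, target) for target in targets]
    return torch.stack(distances, dim=-1).sum(dim=-1)

# Illustrative stand-ins for CLIP embeddings: 32 cutout embeddings, 2 prompt embeddings.
image_embeds = torch.randn(32, 512)
prompt_targets = [torch.randn(512), torch.randn(512)]

guidance = prompts_dist_loss(image_embeds, prompt_targets, spherical_dist_loss).mean()
print(guidance)  # scalar of the kind get_image() backpropagates through the latent q

In the app itself this same quantity is minimized with AdamW over the latent q, which is all the sampling loop in get_image() does.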
packages.txt ADDED
@@ -0,0 +1 @@
+ ffmpeg
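
packages.txt installs ffmpeg at the system level because get_image() in app.py shells out to the ffmpeg binary to stitch the saved frames into an mp4. A quick, purely illustrative way to confirm the binary is on PATH inside the Space (not part of this commit):

import shutil, subprocess

assert shutil.which("ffmpeg") is not None, "ffmpeg not found on PATH"
print(subprocess.run(["ffmpeg", "-version"], capture_output=True, text=True).stdout.splitlines()[0])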
requirements.txt CHANGED
@@ -1 +1,5 @@
- torch
+ torch
+ -e git+https://github.com/openai/CLIP.git#egg=CLIP
+ einops
+ ninja
+ timm==0.4.12
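
Since requirements.txt now pulls CLIP in as an editable Git install, a minimal smoke test after `pip install -r requirements.txt` could look like the following; the prompt string is arbitrary and CUDA is optional for this check (app.py itself assumes cuda:0).

import torch
import clip  # provided by the -e git+https://github.com/openai/CLIP.git install above

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)  # same backbone the CLIP wrapper in app.py loads

with torch.no_grad():
    tokens = clip.tokenize(["chalk pastel drawing of a dog wearing a funny hat"]).to(device)
    text_features = model.encode_text(tokens)

print(text_features.shape)  # expected: torch.Size([1, 512])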