apolinario committed · Commit ddad699
1 Parent(s): b18ff48
Initial attempt

Files changed:
- .gitignore +7 -0
- app.py +282 -8
- packages.txt +1 -0
- requirements.txt +5 -1
.gitignore ADDED
@@ -0,0 +1,7 @@
+gradio_queue.db
+gradio_queue.db-journal
+stylegan_xl
+samples
+flagged
+*.pkl
+*.mp4
app.py CHANGED
@@ -1,11 +1,285 @@
 import gradio as gr
-import
+from git.repo.base import Repo
+from os.path import exists as path_exists
 
-
-
-
-
+if not (path_exists(f"stylegan_xl")):
+    Repo.clone_from("https://github.com/autonomousvision/stylegan_xl", "stylegan_xl")
+
+import sys
+sys.path.append('./CLIP')
+sys.path.append('./stylegan_xl')
+
+import io
+import os, time, glob
+import pickle
+import shutil
+import numpy as np
+from PIL import Image
+import torch
+import torch.nn.functional as F
+import requests
+import torchvision.transforms as transforms
+import torchvision.transforms.functional as TF
+import clip
+import unicodedata
+import re
+from tqdm.notebook import tqdm
+from torchvision.transforms import Compose, Resize, ToTensor, Normalize
+from IPython.display import display
+from einops import rearrange
+import dnnlib
+import legacy
+import subprocess
+
+torch.cuda.empty_cache()
+device = torch.device('cuda:0')
+print('Using device:', device, file=sys.stderr)
+
+def fetch(url_or_path):
+    if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
+        r = requests.get(url_or_path)
+        r.raise_for_status()
+        fd = io.BytesIO()
+        fd.write(r.content)
+        fd.seek(0)
+        return fd
+    return open(url_or_path, 'rb')
+
+def fetch_model(url_or_path, network_name):
+    torch.hub.download_url_to_file(f'{url_or_path}', f'./{network_name}')
+
+def slugify(value, allow_unicode=False):
+    """
+    Taken from https://github.com/django/django/blob/master/django/utils/text.py
+    Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
+    dashes to single dashes. Remove characters that aren't alphanumerics,
+    underscores, or hyphens. Convert to lowercase. Also strip leading and
+    trailing whitespace, dashes, and underscores.
+    """
+    value = str(value)
+    if allow_unicode:
+        value = unicodedata.normalize('NFKC', value)
     else:
-
-
-
+        value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
+    value = re.sub(r'[^\w\s-]', '', value.lower())
+    return re.sub(r'[-\s]+', '-', value).strip('-_')
+
+def norm1(prompt):
+    "Normalize to the unit sphere."
+    return prompt / prompt.square().sum(dim=-1, keepdim=True).sqrt()
+
+def spherical_dist_loss(x, y):
+    x = F.normalize(x, dim=-1)
+    y = F.normalize(y, dim=-1)
+    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
+
+def prompts_dist_loss(x, targets, loss):
+    if len(targets) == 1:  # Keeps consistent results vs previous method for single objective guidance
+        return loss(x, targets[0])
+    distances = [loss(x, target) for target in targets]
+    return torch.stack(distances, dim=-1).sum(dim=-1)
+
+class MakeCutouts(torch.nn.Module):
+    def __init__(self, cut_size, cutn, cut_pow=1.):
+        super().__init__()
+        self.cut_size = cut_size
+        self.cutn = cutn
+        self.cut_pow = cut_pow
+
+    def forward(self, input):
+        sideY, sideX = input.shape[2:4]
+        max_size = min(sideX, sideY)
+        min_size = min(sideX, sideY, self.cut_size)
+        cutouts = []
+        for _ in range(self.cutn):
+            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
+            offsetx = torch.randint(0, sideX - size + 1, ())
+            offsety = torch.randint(0, sideY - size + 1, ())
+            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
+            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
+        return torch.cat(cutouts)
+
+make_cutouts = MakeCutouts(224, 32, 0.5)
+
+def embed_image(image):
+    n = image.shape[0]
+    cutouts = make_cutouts(image)
+    embeds = clip_model.embed_cutout(cutouts)
+    embeds = rearrange(embeds, '(cc n) c -> cc n c', n=n)
+    return embeds
+
+def embed_url(url):
+    image = Image.open(fetch(url)).convert('RGB')
+    return embed_image(TF.to_tensor(image).to(device).unsqueeze(0)).mean(0).squeeze(0)
+
+class CLIP(object):
+    def __init__(self):
+        clip_model = "ViT-B/16"
+        self.model, _ = clip.load(clip_model)
+        self.model = self.model.requires_grad_(False)
+        self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
+                                              std=[0.26862954, 0.26130258, 0.27577711])
+
+    @torch.no_grad()
+    def embed_text(self, prompt):
+        "Normalized clip text embedding."
+        return norm1(self.model.encode_text(clip.tokenize(prompt).to(device)).float())
+
+    def embed_cutout(self, image):
+        "Normalized clip image embedding."
+        return norm1(self.model.encode_image(self.normalize(image)))
+
+clip_model = CLIP()
+
+#@markdown #**Model selection** 🎭
+
+Models = ["Imagenet256", "Imagenet512", "Imagenet1024", "Pokemon", "FFHQ"]
+
+#@markdown ---
+
+network_url = {
+    "Imagenet256": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet256.pkl",
+    "Imagenet512": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet512.pkl",
+    "Imagenet1024": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet1024.pkl",
+    "Pokemon": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon256.pkl",
+    "FFHQ": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/ffhq256.pkl"
+}
+
+for Model in Models:
+    network_name = network_url[Model].split("/")[-1]
+    if not (path_exists(network_name)):
+        fetch_model(network_url[Model], network_name)
+
+def load_current_model(current_model="imagenet256.pkl"):
+    with dnnlib.util.open_url(current_model) as f:
+        G = legacy.load_network_pkl(f)['G_ema'].to(device)
+
+    zs = torch.randn([10000, G.mapping.z_dim], device=device)
+    cs = torch.zeros([10000, G.mapping.c_dim], device=device)
+    for i in range(cs.shape[0]):
+        cs[i, i // 10] = 1
+    w_stds = G.mapping(zs, cs)
+    w_stds = w_stds.reshape(10, 1000, G.num_ws, -1)
+    w_stds = w_stds.std(0).mean(0)[0]
+    w_all_classes_avg = G.mapping.w_avg.mean(0)
+    return (G, w_stds, w_all_classes_avg)
+
+G, w_stds, w_all_classes_avg = load_current_model()
+print(w_stds)
+previousModel = 'imagenet256'
+def run(prompt, steps, model):
+    global G, w_stds, w_all_classes_avg, previousModel
+    if(model == 'imagenet256' and previousModel != 'imagenet256'):
+        G, w_stds, w_all_classes_avg = load_current_model('imagenet256.pkl')
+    elif(model == 'imagenet512' and previousModel != 'imagenet512'):
+        G, w_stds, w_all_classes_avg = load_current_model('imagenet512.pkl')
+    elif(model == 'imagenet1024' and previousModel != 'imagenet1024'):
+        G, w_stds, w_all_classes_avg = load_current_model('imagenet1024.pkl')
+    elif(model == 'pokemon256' and previousModel != 'pokemon256'):
+        G, w_stds, w_all_classes_avg = load_current_model('pokemon256.pkl')
+    elif(model == 'ffhq256' and previousModel != 'ffhq256'):
+        G, w_stds, w_all_classes_avg = load_current_model('ffhq256.pkl')
+    previousModel = model
+
+    texts = prompt
+    steps = steps
+    seed = -1  # @param {type:"number"}
+
+    # @markdown ---
+
+    if seed == -1:
+        seed = np.random.randint(0, 9e9)
+        print(f"Your random seed is: {seed}")
+
+    texts = [frase.strip() for frase in texts.split("|") if frase]
+
+    targets = [clip_model.embed_text(text) for text in texts]
+
+    tf = Compose(
+        [
+            # Resize(224),
+            lambda x: torch.clamp((x + 1) / 2, min=0, max=1),
+        ]
+    )
+
+    initial_batch = 4  # actually that will be multiplied by initial_image_steps
+    initial_image_steps = 32
+
+    def get_image(timestring):
+        os.makedirs(f"samples/{timestring}", exist_ok=True)
+        torch.manual_seed(seed)
+        with torch.no_grad():
+            qs = []
+            losses = []
+            for _ in range(initial_image_steps):
+                a = torch.randn([initial_batch, 512], device=device) * 0.4 + w_stds * 0.4
+                q = (a - w_all_classes_avg) / w_stds
+                images = G.synthesis(
+                    (q * w_stds + w_all_classes_avg).unsqueeze(1).repeat([1, G.num_ws, 1])
+                )
+                embeds = embed_image(images.add(1).div(2))
+                loss = prompts_dist_loss(embeds, targets, spherical_dist_loss).mean(0)
+                i = torch.argmin(loss)
+                qs.append(q[i])
+                losses.append(loss[i])
+            qs = torch.stack(qs)
+            losses = torch.stack(losses)
+            i = torch.argmin(losses)
+            q = qs[i].unsqueeze(0).repeat([G.num_ws, 1]).requires_grad_()
+
+        # Sampling loop
+        q_ema = q
+        print(q.shape)
+        opt = torch.optim.AdamW([q], lr=0.05, betas=(0.0, 0.999), weight_decay=0.025)
+        loop = tqdm(range(steps))
+        for i in loop:
+            opt.zero_grad()
+            w = q * w_stds
+            image = G.synthesis((q * w_stds + w_all_classes_avg)[None], noise_mode="const")
+            embed = embed_image(image.add(1).div(2))
+            loss = prompts_dist_loss(embed, targets, spherical_dist_loss).mean()
+            loss.backward()
+            opt.step()
+            loop.set_postfix(loss=loss.item(), q_magnitude=q.std().item())
+
+            q_ema = q_ema * 0.98 + q * 0.02
+            image = G.synthesis(
+                (q_ema * w_stds + w_all_classes_avg)[None], noise_mode="const"
+            )
+
+            pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0, 1))
+            pil_image.save(f"samples/{timestring}/{i:04}.jpg")
+
+            if (i + 1) % steps == 0:
+                #/usr/bin/
+                subprocess.call(['ffmpeg', '-r', '60', '-i', f'samples/{timestring}/%04d.jpg', '-vcodec', 'libx264', '-crf', '18', '-pix_fmt', 'yuv420p', f'{timestring}.mp4'])
+                shutil.rmtree(f"samples/{timestring}")
+                pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0, 1))
+                return (pil_image, f'{timestring}.mp4')
+
+    try:
+        timestring = time.strftime("%Y%m%d%H%M%S")
+        image, video = get_image(timestring)
+        return [image, video]
+    except KeyboardInterrupt:
+        pass
+
+image = gr.outputs.Image(type="pil", label="Your image")
+video = gr.outputs.Video(type="mp4", label="Your video")
+css = ".output-image{height: 528px !important},.output-video{height: 528px !important}"
+iface = gr.Interface(fn=run, inputs=[
+        gr.inputs.Textbox(label="Prompt", default="chalk pastel drawing of a dog wearing a funny hat"),
+        gr.inputs.Slider(label="Steps - more steps can increase quality but will take longer to generate", default=300, maximum=500, minimum=10, step=1),
+        #gr.inputs.Radio(label="Aspect Ratio", choices=["Square", "Horizontal", "Vertical"], default="Horizontal"),
+        gr.inputs.Dropdown(label="Model", choices=["imagenet256", "imagenet512", "imagenet1024", "pokemon256", "ffhq256"], default="imagenet256")
+        #gr.inputs.Radio(label="Height", choices=[32,64,128,256,512], default=256),
+        #gr.inputs.Slider(label="Images - How many images you wish to generate", default=2, step=1, minimum=1, maximum=4),
+        #gr.inputs.Slider(label="Diversity scale - How different from one another you wish the images to be", default=5.0, minimum=1.0, maximum=15.0),
+        #gr.inputs.Slider(label="ETA - between 0 and 1. Lower values can provide better quality, higher values can be more diverse", default=0.0, minimum=0.0, maximum=1.0, step=0.1),
+    ],
+    outputs=[image, video],
+    css=css,
+    title="Generate images from text with StyleGAN XL + CLIP",
+    description="<div>By typing a prompt and pressing submit you can generate images based on it. <a href='https://github.com/autonomousvision/stylegan_xl' target='_blank'>StyleGAN XL</a> is an open source image generation model, guided here by <a href='https://github.com/openai/CLIP' target='_blank'>CLIP</a>.<br>This UI to the model was assembled by <a style='color: rgb(245, 158, 11);font-weight:bold' href='https://twitter.com/multimodalart' target='_blank'>@multimodalart</a></div>",
+    article="<h4 style='font-size: 110%;margin-top:.5em'>Biases acknowledgment</h4><div>Despite how impressive being able to turn text into image is, be aware that this model may output content that reinforces or exacerbates societal biases. According to the <a href='https://arxiv.org/abs/2112.10752' target='_blank'>Latent Diffusion paper</a>: <i>\"Deep learning modules tend to reproduce or exacerbate biases that are already present in the data\"</i>. The models are meant to be used for research purposes, such as this one.</div><h4 style='font-size: 110%;margin-top:1em'>Who owns the images produced by this demo?</h4><div>Definitely not me! Probably you do. I say probably because the copyright discussion about AI-generated art is ongoing. So <a href='https://www.theverge.com/2022/2/21/22944335/us-copyright-office-reject-ai-generated-art-recent-entrance-to-paradise' target='_blank'>it may be the case that everything produced here falls automatically into the public domain</a>. But in any case it is either yours or in the public domain.</div>")
+iface.launch(enable_queue=True)
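
Note: the added app.py guides a StyleGAN XL latent code toward CLIP text embeddings using the spherical distance loss defined above. As a quick sanity check of that pipeline outside the Space, here is a minimal sketch (not part of the commit). It assumes the CLIP dependency pinned in requirements.txt is installed; the CPU fallback and the example prompts are illustrative choices, while app.py itself hard-codes cuda:0.

# Sketch only (not part of this commit): sanity-check the CLIP text encoder and the
# spherical distance loss that app.py minimizes over the generator's latent space.
import clip
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("ViT-B/16", device=device)  # same CLIP variant app.py loads

def spherical_dist_loss(x, y):
    # Squared great-circle distance between L2-normalized embeddings (copied from app.py).
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)

with torch.no_grad():
    a = model.encode_text(clip.tokenize("a drawing of a dog").to(device)).float()
    b = model.encode_text(clip.tokenize("a photo of a cat").to(device)).float()

print(spherical_dist_loss(a, b).item())  # lower means the prompts are closer in CLIP space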
packages.txt ADDED
@@ -0,0 +1 @@
+ffmpeg
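
packages.txt installs ffmpeg at the system level because app.py shells out to it to stitch the per-step JPEG frames into an MP4. For reference, a sketch of that same invocation in isolation; "samples/example" and "example.mp4" are placeholder paths, not files from the commit.

# Sketch only: the ffmpeg call app.py makes after the last optimization step.
import subprocess

subprocess.call([
    "ffmpeg", "-r", "60",              # treat the frames as 60 fps input
    "-i", "samples/example/%04d.jpg",  # zero-padded frames, as written by get_image()
    "-vcodec", "libx264",              # encode as H.264
    "-crf", "18",                      # high-quality constant rate factor
    "-pix_fmt", "yuv420p",             # pixel format browsers can play
    "example.mp4",
])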
requirements.txt CHANGED
@@ -1 +1,5 @@
-torch
+torch
+-e git+https://github.com/openai/CLIP.git#egg=CLIP
+einops
+ninja
+timm==0.4.12
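
The editable CLIP install provides the `clip` module imported by app.py, ninja is used to JIT-compile StyleGAN XL's custom ops, and timm is presumably pinned at 0.4.12 because that is the version the StyleGAN XL code expects. A minimal sketch (assuming the requirements above have been installed) to confirm the environment resolves before the Space runs app.py:

# Sketch only: confirm the pinned dependencies import cleanly.
import clip, einops, timm, torch

print("torch:", torch.__version__)
print("timm:", timm.__version__)                      # expected 0.4.12 per requirements.txt
print("CUDA available:", torch.cuda.is_available())   # app.py assumes a cuda:0 device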