Spaces: Runtime error
Ahsen Khaliq committed
Commit • 1dca94e
1 Parent(s): a7aebee
Update app.py

app.py CHANGED
@@ -1,135 +1,186 @@
  import os
- os.system("…
- os.system("…
- …
  import pickle
  import numpy as np
- import …
  import torch
- import …
  import clip
- import exrex
  from tqdm.notebook import tqdm
- from torchvision.transforms import Compose, Resize, …
-
- if not os.path.isfile(os.path.basename(network_pkl)):
-     os.system("wget https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl")

- cuda_available = torch.cuda.is_available()
- device = torch.device('cuda' if cuda_available else 'cpu')
- … [old lines 22-44 truncated in the diff view] …
- def …
- … [old lines 47-71 truncated in the diff view] …

  def inference(text):
- … [old lines 75-130 truncated in the diff view] …
-     img = G.synthesis(found_w+G.mapping.w_avg, noise_mode='const', force_fp32=not cuda_available)
-     return PIL.Image.fromarray((img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)[0].cpu().numpy(), 'RGB')
  import os
+ os.system("pip install --upgrade torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html")
+ os.system("git clone https://github.com/NVlabs/stylegan3")
+ os.system("git clone https://github.com/openai/CLIP")
+ os.system("pip install -e ./CLIP")
+ os.system("pip install einops ninja")
+
+ import sys
+ sys.path.append('./CLIP')
+ sys.path.append('./stylegan3')
+
+ import io
+ import os, time
  import pickle
+ import shutil
  import numpy as np
+ from PIL import Image
  import torch
+ import torch.nn.functional as F
+ import requests
+ import torchvision.transforms as transforms
+ import torchvision.transforms.functional as TF
  import clip
  from tqdm.notebook import tqdm
+ from torchvision.transforms import Compose, Resize, ToTensor, Normalize
+ from einops import rearrange
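+ # Note: tqdm.notebook is carried over from the notebook version; in a plain
+ # script, `from tqdm import tqdm` would avoid the ipywidgets requirement.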
+ device = torch.device('cuda:0')


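+ # Note: this hard-codes CUDA with no CPU fallback (the previous version checked
+ # torch.cuda.is_available()); the .to(device) calls below fail without a GPU.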
+ def fetch(url_or_path):
+     # Return a file-like object: a URL is downloaded into memory, a local path is opened.
+     if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
+         r = requests.get(url_or_path)
+         r.raise_for_status()
+         fd = io.BytesIO()
+         fd.write(r.content)
+         fd.seek(0)
+         return fd
+     return open(url_or_path, 'rb')

+ def fetch_model(url_or_path):
+     # Download the model file once and reuse the cached copy on later runs.
+     basename = os.path.basename(url_or_path)
+     if os.path.exists(basename):
+         return basename
+     else:
+         # The original line used IPython shell syntax (!wget -c '{url_or_path}'),
+         # which is a SyntaxError in a plain script; os.system is the equivalent here.
+         os.system(f"wget -c '{url_or_path}'")
+         return basename

+ def norm1(prompt):
+     "Normalize to the unit sphere."
+     return prompt / prompt.square().sum(dim=-1, keepdim=True).sqrt()
+
+ def spherical_dist_loss(x, y):
+     x = F.normalize(x, dim=-1)
+     y = F.normalize(y, dim=-1)
+     return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)

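+ # For unit vectors ||x - y|| = 2*sin(theta/2), so the expression above equals
+ # 2*(theta/2)^2 = theta^2/2: half the squared angle between the two embeddings.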
+ class MakeCutouts(torch.nn.Module):
+     def __init__(self, cut_size, cutn, cut_pow=1.):
+         super().__init__()
+         self.cut_size = cut_size
+         self.cutn = cutn
+         self.cut_pow = cut_pow
+
+     def forward(self, input):
+         sideY, sideX = input.shape[2:4]
+         max_size = min(sideX, sideY)
+         min_size = min(sideX, sideY, self.cut_size)
+         cutouts = []
+         for _ in range(self.cutn):
+             size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
+             offsetx = torch.randint(0, sideX - size + 1, ())
+             offsety = torch.randint(0, sideY - size + 1, ())
+             cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
+             cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
+         return torch.cat(cutouts)
+
+ make_cutouts = MakeCutouts(224, 32, 0.5)
+
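+ # Each image yields 32 random square crops, pooled to CLIP's 224x224 input;
+ # averaging the loss over crops makes the guidance less sensitive to framing.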
+ def embed_image(image):
+     n = image.shape[0]
+     cutouts = make_cutouts(image)
+     embeds = clip_model.embed_cutout(cutouts)
+     embeds = rearrange(embeds, '(cc n) c -> cc n c', n=n)
+     return embeds
+
+ def embed_url(url):
+     image = Image.open(fetch(url)).convert('RGB')
+     return embed_image(TF.to_tensor(image).to(device).unsqueeze(0)).mean(0).squeeze(0)
+
+ class CLIP(object):
+     def __init__(self):
+         clip_model = "ViT-B/32"
+         self.model, _ = clip.load(clip_model)
+         self.model = self.model.requires_grad_(False)
+         self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
+                                               std=[0.26862954, 0.26130258, 0.27577711])
+
+     @torch.no_grad()
+     def embed_text(self, prompt):
+         "Normalized clip text embedding."
+         return norm1(self.model.encode_text(clip.tokenize(prompt).to(device)).float())
+
+     def embed_cutout(self, image):
+         "Normalized clip image embedding."
+         return norm1(self.model.encode_image(self.normalize(image)))
+
+ clip_model = CLIP()
+
+ # Load stylegan model
+ base_url = "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/"
+ model_name = "stylegan3-t-ffhqu-1024x1024.pkl"
+ #model_name = "stylegan3-r-metfacesu-1024x1024.pkl"
+ #model_name = "stylegan3-t-afhqv2-512x512.pkl"
+ network_url = base_url + model_name
+
+ with open(fetch_model(network_url), 'rb') as fp:
+     G = pickle.load(fp)['G_ema'].to(device)
+
+ zs = torch.randn([10000, G.mapping.z_dim], device=device)
+ w_stds = G.mapping(zs, None).std(0)
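+ # w_stds holds the per-dimension standard deviation of W space, estimated from
+ # 10,000 mapped latents; the optimization below works in these whitened units (q).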
+
+

  def inference(text):
+     target = clip_model.embed_text(text)
+
+     steps = 600
+     seed = 2
+
+     tf = Compose([
+         Resize(224),
+         lambda x: torch.clamp((x+1)/2, min=0, max=1),
+     ])
+
+     torch.manual_seed(seed)
+     timestring = time.strftime('%Y%m%d%H%M%S')
+
+     # Sample 8 batches of 4 latents, keep the best (lowest CLIP loss) of each
+     # batch, then start the optimization from the best candidate overall.
+     with torch.no_grad():
+         qs = []
+         losses = []
+         for _ in range(8):
+             q = (G.mapping(torch.randn([4, G.mapping.z_dim], device=device), None, truncation_psi=0.7) - G.mapping.w_avg) / w_stds
+             images = G.synthesis(q * w_stds + G.mapping.w_avg)
+             embeds = embed_image(images.add(1).div(2))
+             loss = spherical_dist_loss(embeds, target).mean(0)
+             i = torch.argmin(loss)
+             qs.append(q[i])
+             losses.append(loss[i])
+         qs = torch.stack(qs)
+         losses = torch.stack(losses)
+         print(losses)
+         print(losses.shape, qs.shape)
+         i = torch.argmin(losses)
+         q = qs[i].unsqueeze(0)
+
+     q.requires_grad_()
+
+     q_ema = q
+     opt = torch.optim.AdamW([q], lr=0.03, betas=(0.0, 0.999))
+     loop = tqdm(range(steps))
+     for i in loop:
+         opt.zero_grad()
+         w = q * w_stds
+         image = G.synthesis(w + G.mapping.w_avg, noise_mode='const')
+         embed = embed_image(image.add(1).div(2))
+         loss = spherical_dist_loss(embed, target).mean()
+         loss.backward()
+         opt.step()
+         loop.set_postfix(loss=loss.item(), q_magnitude=q.std().item())
+
+         q_ema = q_ema * 0.9 + q * 0.1
+         image = G.synthesis(q_ema * w_stds + G.mapping.w_avg, noise_mode='const')
+
+         # display() is an IPython helper and is undefined in a plain script, so
+         # the periodic preview is disabled here:
+         # if i % 10 == 0:
+         #     display(TF.to_pil_image(tf(image)[0]))
+         pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0, 1))
+         #os.makedirs(f'samples/{timestring}', exist_ok=True)
+         #pil_image.save(f'samples/{timestring}/{i:04}.jpg')
+
+     return pil_image
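The hunk ends here, before any Gradio wiring, so the Space's interface code is not visible in this diff. A minimal sketch of how inference would typically be exposed in a Space (the input/output types and launch call below are assumptions, not part of this commit):

    import gradio as gr

    # Hypothetical wiring: expose inference(text) -> PIL image as a web demo.
    gr.Interface(fn=inference, inputs="text", outputs="image").launch()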