Ahsen Khaliq committed
Commit 1dca94e
1 Parent(s): a7aebee

Update app.py

Files changed (1)
  1. app.py +166 -115
app.py CHANGED
@@ -1,135 +1,186 @@
 import os
-os.system("""pip install --upgrade https://github.com/podgorskiy/dnnlib/releases/download/0.0.1/dnnlib-0.0.1-py3-none-any.whl numpy tqdm Pillow torch-utils==0.0.7 torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html ftfy regex git+https://github.com/openai/CLIP.git ninja git+https://github.com/geoopt/geoopt.git gdown exrex torchtext==0.10.0""")
-os.system("nvidia-smi")
-import gradio as gr
 import pickle
 import numpy as np
-import PIL
 import torch
-import dnnlib
 import clip
-import exrex
 from tqdm.notebook import tqdm
-from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

-network_pkl = "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl"
-if not os.path.isfile(os.path.basename(network_pkl)):
-    os.system("wget https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl")

-cuda_available = torch.cuda.is_available()
-device = torch.device('cuda' if cuda_available else 'cpu')

-# Load StyleGAN
-with open(os.path.basename(network_pkl), 'rb') as f:
-    # If legacy pkl then convert before loading.
-    try:
-        G = pickle.load(f)['G_ema'].to(device)
-    except ModuleNotFoundError:
-        import legacy
-        G = legacy.load_network_pkl(f)['G_ema'].to(device)

-clip_model = "ViT-B/32"
-model, preprocess = clip.load(clip_model)

-os.system("pwd")

-if not os.path.exists('CLIP_vecs.npy'):
-    os.system("wget https://www.dropbox.com/s/seqev3lvy6e6dz6/CLIP_vecs.npy")

-os.system("ls")
-if os.path.exists('CLIP_vecs.npy'):
-    CLIP_vecs = torch.from_numpy(np.load('CLIP_vecs.npy'))
-    seeded_z = torch.from_numpy(np.stack([np.random.RandomState(seed).randn(G.w_dim) for seed in range(CLIP_vecs.shape[0])]))
-
-os.system("ls")

-def spherical_avg(p, w=None, tol=1e-6):
-    """Applies a weighted spherical average as described in the paper
-    `Spherical Averages and Applications to Spherical Splines and
-    Interpolation <http://math.ucsd.edu/~sbuss/ResearchWeb/spheremean>`__ .
-
-    Args:
-        p (torch.Tensor): Input vectors
-        w (torch.Tensor, optional): Weights for averaging.
-        tol (float, optional): The desired tolerance of the output.
-            Default: 1e-6
-    """
-    from geoopt import Sphere
-    sphere = Sphere()
-    if w is None:
-        w = p.new_ones([p.shape[0]])
-    assert p.ndim == 2 and w.ndim == 1 and len(p) == len(w)
-    w = w / w.sum()
-    p = sphere.projx(p)
-    q = sphere.projx(p.mul(w.unsqueeze(1)).sum(dim=0))
-    while True:
-        q_new = sphere.retr(q, sphere.logmap(q, p).mul(w.unsqueeze(1)).sum(dim=0))
-        norm = torch.linalg.vector_norm(q.sub(q_new))
-        q = q_new
-        if norm <= tol:
-            break
-    return q


 def inference(text):
-    prompt = text
-    prompt_preview = False
-    continue_opt = False
-    iterations = 20
-    k = 18
-
-    if not continue_opt:
-        augmented_prompts = list(exrex.generate(prompt))
-        if not len(augmented_prompts)<=32:
-            return PIL.Image.new(mode="RGB", size=(200, 200),color = (255,255,255))
-        augmented_prompts, polarities = list(map(lambda x: x.replace('~',''), augmented_prompts)), list(map(lambda x: x.__contains__('~'), augmented_prompts))
-
-    with torch.no_grad():
-        # Encode strings to features
-        text_features = model.encode_text(clip.tokenize(augmented_prompts).to(device)).cpu().to(torch.float32)*torch.tensor(list(map(lambda x: -1 if x else 1,polarities))).unsqueeze(1).expand(-1,512)
-
-        # If we have more than one feature vector use their spherical average instead
-        if text_features.shape[0]>1:
-            text_features = spherical_avg(text_features).unsqueeze(0)
-
-        # Use the vector table if it exists, fallback on w_avg if not
-        if os.path.exists('CLIP_vecs.npy'):
-            tmp = torch.nn.functional.cosine_similarity(CLIP_vecs,text_features.cpu())
-            tmp, indexes = torch.topk(tmp,k,dim=0)
-            tmp = torch.softmax(tmp/0.01,dim=-1)
-            ws = G.mapping((seeded_z[indexes]).reshape(-1,G.w_dim).to(device), c=None).cpu()
-            found_w = torch.sum(ws*tmp.unsqueeze(1).unsqueeze(2),dim=0).unsqueeze(0)
-            found_w = found_w.to(device)-G.mapping.w_avg
-        else:
-            found_w = torch.zeros(1,18,512, device=device)
-
-    # Prepare for gradient decent
-    text_features = text_features.to(device)
-    found_w.requires_grad = True
-
-    # Adapted preprocessing routine for connecting StyleGAN to CLIP
-    stylegan_transform = Compose([
-        Resize(224),
-        lambda x: torch.clamp((x+1)/2,min=0,max=1),
-        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
-    ])
-
-    if not continue_opt:
-        optimizer = torch.optim.Adam((found_w,),0.01,betas=(0,0.999))
-
-    for i in tqdm(range(iterations)):
-        optimizer.zero_grad()
-        img = G.synthesis(found_w+G.mapping.w_avg, noise_mode='const', force_fp32=not cuda_available)
-        img = stylegan_transform(img)
-        image_features = model.encode_image(img)
-        loss = -torch.nn.functional.cosine_similarity(image_features,text_features)
-        loss.backward()
-        optimizer.step()
-
-    img = G.synthesis(found_w+G.mapping.w_avg, noise_mode='const', force_fp32=not cuda_available)
-    return PIL.Image.fromarray((img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)[0].cpu().numpy(), 'RGB')
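
For context on the removed version: inference treated the prompt as a regular expression, expanded it with exrex, and used a leading '~' to mark text to steer away from. A minimal sketch of that expansion, with a hypothetical prompt and assuming the exrex package is installed:

import exrex

# Hypothetical prompt; exrex.generate yields one string per regex alternative.
prompts = list(exrex.generate("a photo of a (smiling|~frowning) face"))
# e.g. ['a photo of a smiling face', 'a photo of a ~frowning face']

# The removed code strips the '~' and flips the sign of that prompt's CLIP
# text feature, so '~'-marked text pushes the optimization away from it.
polarities = ['~' in p for p in prompts]
prompts = [p.replace('~', '') for p in prompts]
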
 import os
+os.system("pip install --upgrade torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html")
+os.system("git clone https://github.com/NVlabs/stylegan3")
+os.system("git clone https://github.com/openai/CLIP")
+os.system("pip install -e ./CLIP")
+os.system("pip install einops ninja")
+
+import sys
+sys.path.append('./CLIP')
+sys.path.append('./stylegan3')
+
+import io
+import os, time
 import pickle
+import shutil
 import numpy as np
+from PIL import Image
 import torch
+import torch.nn.functional as F
+import requests
+import torchvision.transforms as transforms
+import torchvision.transforms.functional as TF
 import clip
 from tqdm.notebook import tqdm
+from torchvision.transforms import Compose, Resize, ToTensor, Normalize
+from einops import rearrange

+device = torch.device('cuda:0')


+def fetch(url_or_path):
+    if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
+        r = requests.get(url_or_path)
+        r.raise_for_status()
+        fd = io.BytesIO()
+        fd.write(r.content)
+        fd.seek(0)
+        return fd
+    return open(url_or_path, 'rb')

+def fetch_model(url_or_path):
+    basename = os.path.basename(url_or_path)
+    if os.path.exists(basename):
+        return basename
+    else:
+        os.system(f"wget -c '{url_or_path}'")  # was the notebook-only `!wget -c '{url_or_path}'`, a syntax error in a plain script
+        return basename

+def norm1(prompt):
+    "Normalize to the unit sphere."
+    return prompt / prompt.square().sum(dim=-1,keepdim=True).sqrt()

+def spherical_dist_loss(x, y):
+    x = F.normalize(x, dim=-1)
+    y = F.normalize(y, dim=-1)
+    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)

+class MakeCutouts(torch.nn.Module):
+    def __init__(self, cut_size, cutn, cut_pow=1.):
+        super().__init__()
+        self.cut_size = cut_size
+        self.cutn = cutn
+        self.cut_pow = cut_pow
+
+    def forward(self, input):
+        sideY, sideX = input.shape[2:4]
+        max_size = min(sideX, sideY)
+        min_size = min(sideX, sideY, self.cut_size)
+        cutouts = []
+        for _ in range(self.cutn):
+            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
+            offsetx = torch.randint(0, sideX - size + 1, ())
+            offsety = torch.randint(0, sideY - size + 1, ())
+            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
+            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
+        return torch.cat(cutouts)
+
+make_cutouts = MakeCutouts(224, 32, 0.5)
+
+def embed_image(image):
+    n = image.shape[0]
+    cutouts = make_cutouts(image)
+    embeds = clip_model.embed_cutout(cutouts)
+    embeds = rearrange(embeds, '(cc n) c -> cc n c', n=n)
+    return embeds
+
+def embed_url(url):
+    image = Image.open(fetch(url)).convert('RGB')
+    return embed_image(TF.to_tensor(image).to(device).unsqueeze(0)).mean(0).squeeze(0)
+
+class CLIP(object):
+    def __init__(self):
+        clip_model = "ViT-B/32"
+        self.model, _ = clip.load(clip_model)
+        self.model = self.model.requires_grad_(False)
+        self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
+                                              std=[0.26862954, 0.26130258, 0.27577711])
+
+    @torch.no_grad()
+    def embed_text(self, prompt):
+        "Normalized clip text embedding."
+        return norm1(self.model.encode_text(clip.tokenize(prompt).to(device)).float())
+
+    def embed_cutout(self, image):
+        "Normalized clip image embedding."
+        return norm1(self.model.encode_image(self.normalize(image)))
+
+clip_model = CLIP()
+
+# Load stylegan model
+
+base_url = "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/"
+model_name = "stylegan3-t-ffhqu-1024x1024.pkl"
+#model_name = "stylegan3-r-metfacesu-1024x1024.pkl"
+#model_name = "stylegan3-t-afhqv2-512x512.pkl"
+network_url = base_url + model_name
+
+with open(fetch_model(network_url), 'rb') as fp:
+    G = pickle.load(fp)['G_ema'].to(device)
+
+zs = torch.randn([10000, G.mapping.z_dim], device=device)
+w_stds = G.mapping(zs, None).std(0)


 def inference(text):
+    target = clip_model.embed_text(text)
+
+    steps = 600
+    seed = 2
+
+    tf = Compose([
+        Resize(224),
+        lambda x: torch.clamp((x+1)/2,min=0,max=1),
+    ])
+
+    torch.manual_seed(seed)
+    timestring = time.strftime('%Y%m%d%H%M%S')
+
+    with torch.no_grad():
+        qs = []
+        losses = []
+        for _ in range(8):
+            q = (G.mapping(torch.randn([4,G.mapping.z_dim], device=device), None, truncation_psi=0.7) - G.mapping.w_avg) / w_stds
+            images = G.synthesis(q * w_stds + G.mapping.w_avg)
+            embeds = embed_image(images.add(1).div(2))
+            loss = spherical_dist_loss(embeds, target).mean(0)
+            i = torch.argmin(loss)
+            qs.append(q[i])
+            losses.append(loss[i])
+        qs = torch.stack(qs)
+        losses = torch.stack(losses)
+        print(losses)
+        print(losses.shape, qs.shape)
+        i = torch.argmin(losses)
+        q = qs[i].unsqueeze(0)
+
+    q.requires_grad_()
+
+    q_ema = q
+    opt = torch.optim.AdamW([q], lr=0.03, betas=(0.0,0.999))
+    loop = tqdm(range(steps))
+    for i in loop:
+        opt.zero_grad()
+        w = q * w_stds
+        image = G.synthesis(w + G.mapping.w_avg, noise_mode='const')
+        embed = embed_image(image.add(1).div(2))
+        loss = spherical_dist_loss(embed, target).mean()
+        loss.backward()
+        opt.step()
+        loop.set_postfix(loss=loss.item(), q_magnitude=q.std().item())
+
+        q_ema = q_ema * 0.9 + q * 0.1
+        image = G.synthesis(q_ema * w_stds + G.mapping.w_avg, noise_mode='const')
+
+        if i % 10 == 0:
+            # display(TF.to_pil_image(tf(image)[0]))  # notebook-only preview; `display` is undefined in a script
+            pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0,1))
+            #os.makedirs(f'samples/{timestring}', exist_ok=True)
+            #pil_image.save(f'samples/{timestring}/{i:04}.jpg')
+
+    return pil_image
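
A note on the added objective: inference first picks the best of 32 random starting latents (8 batches of 4, keeping each batch's argmin), then optimizes q so that CLIP embeddings of random cutouts of the synthesized image stay close to the text embedding under spherical_dist_loss. For unit vectors separated by an angle theta, the chordal distance is 2*sin(theta/2), so the expression above evaluates to theta**2 / 2. A quick standalone check, not part of the app:

import math
import torch
import torch.nn.functional as F

def spherical_dist_loss(x, y):  # same formula as in app.py above
    x, y = F.normalize(x, dim=-1), F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)

theta = math.pi / 3
x = torch.tensor([[1.0, 0.0]])
y = torch.tensor([[math.cos(theta), math.sin(theta)]])
print(spherical_dist_loss(x, y).item())  # ~0.5483
print(theta ** 2 / 2)                    # 0.5483..., i.e. theta**2 / 2
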
 
 
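The hunk does not show how inference is exposed as the Space's UI; the old version imported gradio and that import is dropped here. If the app still runs as a Gradio demo, the wiring would look roughly like the sketch below (hypothetical, not part of this commit; the "text"/"image" shortcuts are standard Gradio component names):

import gradio as gr

# Hypothetical wiring: expose inference() as a text-to-image demo.
gr.Interface(fn=inference, inputs="text", outputs="image").launch()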