# stylegan3_clip / app.py
# Author: Ahsen Khaliq
# Commit 9b12215: "Update app.py" (6.52 kB; raw / history / blame)
import os
os.system("pip install --upgrade torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html")
os.system("git clone https://github.com/NVlabs/stylegan3")
os.system("git clone https://github.com/openai/CLIP")
os.system("pip install -e ./CLIP")
os.system("pip install einops ninja")
import sys
sys.path.append('./CLIP')
sys.path.append('./stylegan3')
import io
import os, time
import pickle
import shutil
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import requests
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import clip
from tqdm.notebook import tqdm
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from einops import rearrange
# Every model and tensor in this app is placed on the first CUDA GPU; there is
# no CPU fallback.
device = torch.device('cuda', 0)
def fetch(url_or_path, timeout=60):
    """Open a local file or download a URL as a readable binary stream.

    Args:
        url_or_path: An ``http://`` / ``https://`` URL, or a filesystem path
            (``str`` or path-like).
        timeout: Seconds to wait on the HTTP server before giving up. This only
            applies to URLs; without it a stalled server would block forever.

    Returns:
        For URLs, an ``io.BytesIO`` positioned at the start of the downloaded
        bytes; otherwise the file opened in ``'rb'`` mode. The caller owns the
        returned object and should close it.

    Raises:
        requests.HTTPError: If the server responds with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    if str(url_or_path).startswith(('http://', 'https://')):
        response = requests.get(url_or_path, timeout=timeout)
        response.raise_for_status()
        # A BytesIO built from existing bytes already starts at position 0.
        return io.BytesIO(response.content)
    return open(url_or_path, 'rb')
def fetch_model(url_or_path):
    """Make sure a model file exists in the working directory; return its name.

    If a file named after the last path component of ``url_or_path`` is
    already present, nothing is downloaded. Otherwise ``wget -c`` fetches it
    into the working directory.

    Args:
        url_or_path: URL (or path) whose basename is the expected local file.

    Returns:
        The local file name (``os.path.basename(url_or_path)``).

    Raises:
        subprocess.CalledProcessError: If ``wget`` exits with a non-zero status.
        FileNotFoundError: If ``wget`` is not installed.
    """
    import subprocess

    basename = os.path.basename(url_or_path)
    if not os.path.exists(basename):
        # Pass the URL as its own argv entry (no shell), so it is substituted
        # for real and needs no quoting. check=True makes a failed download
        # raise instead of returning the name of a file that does not exist.
        subprocess.run(['wget', '-c', url_or_path], check=True)
    return basename
def norm1(prompt):
    """Scale each vector along the last axis to unit L2 length.

    No epsilon is added, so an all-zero vector yields NaNs.
    """
    magnitude = torch.sqrt(torch.sum(prompt ** 2, dim=-1, keepdim=True))
    return prompt.div(magnitude)
def spherical_dist_loss(x, y):
    """Return half the squared angle between x and y along the last axis.

    Both inputs are L2-normalized along the last dimension and broadcast
    against each other. For unit vectors the chord length c and the angle t
    satisfy c = 2 * sin(t / 2), so 2 * arcsin(c / 2) ** 2 == t ** 2 / 2.
    """
    u = F.normalize(x, dim=-1)
    v = F.normalize(y, dim=-1)
    chord = (u - v).norm(dim=-1)
    half_angle = torch.arcsin(chord / 2)
    return 2 * half_angle ** 2
class MakeCutouts(torch.nn.Module):
    """Cut random square crops from an image batch and resize them to one size.

    Args:
        cut_size: Side length (pixels) every crop is average-pooled to.
        cutn: Number of crops drawn per forward pass.
        cut_pow: Exponent applied to a uniform sample when choosing crop size.
            Values below 1 favour larger crops; values above 1 favour smaller ones.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def _random_crop(self, batch, height, width, smallest, largest):
        # Draw order (size, then x offset, then y offset) is kept fixed so
        # results for a given RNG seed stay the same.
        fraction = torch.rand([]) ** self.cut_pow
        side = int(fraction * (largest - smallest) + smallest)
        left = torch.randint(0, width - side + 1, ())
        top = torch.randint(0, height - side + 1, ())
        crop = batch[:, :, top:top + side, left:left + side]
        return F.adaptive_avg_pool2d(crop, self.cut_size)

    def forward(self, input):
        """Return ``cutn`` crops of every image, concatenated along dim 0.

        Assumes ``input`` is (batch, channels, height, width) (TODO confirm for
        callers). The output is (cutn * batch, channels, cut_size, cut_size),
        with crop index as the outer grouping.
        """
        height, width = input.shape[2:4]
        largest = min(width, height)
        # The crop never goes below cut_size unless the image itself is smaller.
        smallest = min(width, height, self.cut_size)
        crops = [
            self._random_crop(input, height, width, smallest, largest)
            for _ in range(self.cutn)
        ]
        return torch.cat(crops)
# 32 random square crops per image, each resized to 224x224 (the input size
# CLIP ViT-B/32 expects). cut_pow=0.5 skews crop sizes toward the full image.
make_cutouts = MakeCutouts(cut_size=224, cutn=32, cut_pow=0.5)
def embed_image(image):
    """Embed random cutouts of an image batch with CLIP.

    The batch is expected in [0, 1] range (callers map generator output with
    ``.add(1).div(2)``; TODO confirm for other callers).

    Returns:
        A tensor of shape (num_cutouts, batch, embed_dim). Cutouts are grouped
        on the outer axis, so averaging over dim 0 pools across crops.
    """
    batch = image.shape[0]
    flat = clip_model.embed_cutout(make_cutouts(image))
    return rearrange(flat, '(cc n) c -> cc n c', n=batch)
def embed_url(url):
    """Return the CLIP embedding of an image at a URL or local path.

    The image is converted to RGB, embedded over random cutouts, and averaged
    across cutouts. The result has shape (embed_dim,).
    """
    # Close the stream fetch() returns: for local paths it is a real open file.
    # convert() loads the pixel data, so the image stays usable after closing.
    with fetch(url) as fd:
        image = Image.open(fd).convert('RGB')
    batch = TF.to_tensor(image).to(device).unsqueeze(0)
    return embed_image(batch).mean(0).squeeze(0)
class CLIP(object):
    """Frozen OpenAI CLIP ViT-B/32 that returns unit-norm float32 embeddings."""

    def __init__(self):
        model_name = "ViT-B/32"
        # Load straight onto the module-level device so it matches the tensors
        # the rest of the app creates with `device`.
        self.model, _ = clip.load(model_name, device=device)
        # Weights are frozen. Gradients still flow through the model to the
        # input image in embed_cutout, which is what the optimizer needs.
        self.model = self.model.requires_grad_(False)
        # CLIP's preprocessing statistics, applied to [0, 1] RGB input.
        self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                              std=[0.26862954, 0.26130258, 0.27577711])

    @torch.no_grad()
    def embed_text(self, prompt):
        """Return the normalized CLIP text embedding(s) of ``prompt`` as float32."""
        return norm1(self.model.encode_text(clip.tokenize(prompt).to(device)).float())

    def embed_cutout(self, image):
        """Return normalized CLIP image embeddings as float32 (differentiable).

        Cast to float32 before normalizing, as embed_text does, so image and
        text embeddings share a dtype. The model may run in half precision.
        """
        return norm1(self.model.encode_image(self.normalize(image)).float())
clip_model = CLIP()

# Load the StyleGAN3 generator from NVIDIA's NGC catalog. Change `model_name`
# to switch networks, e.g. "stylegan3-r-metfacesu-1024x1024.pkl" or
# "stylegan3-t-afhqv2-512x512.pkl". The download and load below use it, so no
# other line needs editing.
base_url = "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/"
model_name = "stylegan3-t-ffhqu-1024x1024.pkl"
network_url = base_url + model_name
# `wget -c` resumes an interrupted download, saving to ./<model_name>.
os.system(f"wget -c '{network_url}'")
# NOTE(review): pickle.load executes code embedded in the file; only load
# checkpoints from a trusted source.
with open(model_name, 'rb') as fp:
    G = pickle.load(fp)['G_ema'].to(device)

# Per-element spread of W over 10k random z draws. The optimizer works in the
# normalized latent (w - w_avg) / w_stds. Only statistics are needed here, so
# no autograd graph is recorded.
with torch.no_grad():
    zs = torch.randn([10000, G.mapping.z_dim], device=device)
    w_stds = G.mapping(zs, None).std(0)
def _bootstrap_latent(target, rounds=8, batch_size=4, truncation_psi=0.7):
    """Pick a starting latent by random search against a CLIP target.

    Runs ``rounds`` batches of ``batch_size`` random z samples through the
    mapping network (with truncation) and renders them. Each sample is scored
    by ``spherical_dist_loss`` to ``target``, averaged over CLIP cutouts. The
    best sample of each batch is kept, then the best of those.

    Returns:
        The winning latent in normalized form ``(w - w_avg) / w_stds`` with a
        leading batch axis of size 1.
    """
    best_qs = []
    best_losses = []
    with torch.no_grad():
        for _ in range(rounds):
            z = torch.randn([batch_size, G.mapping.z_dim], device=device)
            w = G.mapping(z, None, truncation_psi=truncation_psi)
            q = (w - G.mapping.w_avg) / w_stds
            images = G.synthesis(q * w_stds + G.mapping.w_avg)
            # Generator output is presumably in [-1, 1]; shift it to [0, 1] for CLIP.
            embeds = embed_image(images.add(1).div(2))
            loss = spherical_dist_loss(embeds, target).mean(0)  # one score per sample
            best = torch.argmin(loss)
            best_qs.append(q[best])
            best_losses.append(loss[best])
    winner = torch.argmin(torch.stack(best_losses))
    return torch.stack(best_qs)[winner].unsqueeze(0)


def inference(text, steps=600, seed=2):
    """Generate an image for ``text`` using StyleGAN3 steered by CLIP.

    A starting latent is chosen by random search. It is then refined with AdamW
    for ``steps`` iterations, minimizing the spherical distance between CLIP
    embeddings of random cutouts and the text embedding. The returned image is
    rendered from an exponential moving average (decay 0.9) of the latent.

    Args:
        text: Prompt to embed with CLIP.
        steps: Number of optimization steps.
        seed: Passed to ``torch.manual_seed`` before any sampling.

    Returns:
        A ``PIL.Image`` of the generated image.
    """
    # Plain console progress bar. tqdm.notebook needs a Jupyter/ipywidgets
    # frontend, which a Gradio server does not provide.
    from tqdm import tqdm as progress_bar

    target = clip_model.embed_text(text)
    torch.manual_seed(seed)
    q = _bootstrap_latent(target)
    q.requires_grad_()

    # A detached alias of q: it shares storage, so the first EMA update
    # (after the first opt.step) starts from the stepped latent. Being
    # detached, the EMA records no autograd history.
    q_ema = q.detach()
    opt = torch.optim.AdamW([q], lr=0.03, betas=(0.0, 0.999))
    loop = progress_bar(range(steps))
    for _ in loop:
        opt.zero_grad()
        image = G.synthesis(q * w_stds + G.mapping.w_avg, noise_mode='const')
        embed = embed_image(image.add(1).div(2))
        loss = spherical_dist_loss(embed, target).mean()
        loss.backward()
        opt.step()
        loop.set_postfix(loss=loss.item(), q_magnitude=q.std().item())
        with torch.no_grad():
            q_ema = q_ema * 0.9 + q * 0.1

    # Render only once, from the final averaged latent. No autograd graph is needed.
    with torch.no_grad():
        image = G.synthesis(q_ema * w_stds + G.mapping.w_avg, noise_mode='const')
    return TF.to_pil_image(image[0].add(1).div(2).clamp(0, 1))
# Gradio is needed only for the UI below; import it here so `gr` is defined.
import gradio as gr

title = "StyleGAN+CLIP_with_Latent_Bootstraping"
description = "Gradio demo for StyleGAN+CLIP_with_Latent_Bootstraping. To use it, simply add your text, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'>colab by https://twitter.com/EricHallahan <a href='https://colab.research.google.com/drive/1br7GP_D6XCgulxPTAFhwGaV-ijFe084X' target='_blank'>Colab</a></p>"
examples = [['elon musk']]

# One text box in, one PIL image out. Queueing is enabled because a single
# generation runs hundreds of optimization steps.
gr.Interface(
    inference,
    "text",
    gr.outputs.Image(type="pil", label="Output"),
    title=title,
    description=description,
    article=article,
    enable_queue=True,
    examples=examples,
).launch(debug=True)