"""Gradio demo: StyleGAN3 + CLIP text-guided image generation with latent bootstrapping.

At import time this script installs its dependencies, clones the StyleGAN3 and
CLIP repositories, downloads a pretrained StyleGAN3 generator and launches a
Gradio interface. ``inference`` first keeps the best of 32 random truncated
latents by CLIP loss ("bootstrapping"), then optimizes that latent with AdamW so
that random CLIP cutouts of the generated image match the text prompt.
"""
import os
import subprocess
import sys


def _run(cmd):
    """Run *cmd* (an argument list, no shell); a non-zero exit is not fatal."""
    subprocess.run(cmd, check=False)


# Install into the interpreter that is running this script.
_run([sys.executable, "-m", "pip", "install", "--upgrade",
      "torch==1.9.1+cu111", "torchvision==0.10.1+cu111",
      "-f", "https://download.pytorch.org/whl/torch_stable.html"])
for _repo in ("https://github.com/NVlabs/stylegan3", "https://github.com/openai/CLIP"):
    # `git clone` fails if the directory already exists (e.g. on a restart).
    if not os.path.isdir(os.path.basename(_repo)):
        _run(["git", "clone", _repo])
_run([sys.executable, "-m", "pip", "install", "-e", "./CLIP"])
_run([sys.executable, "-m", "pip", "install", "einops", "ninja", "scipy", "numpy", "Pillow", "tqdm"])

# An editable install is not picked up by an already-running interpreter, so
# put the clones on sys.path directly. stylegan3 presumably has to be importable
# for pickle.load below to rebuild the generator classes — confirm if changed.
sys.path.append('./CLIP')
sys.path.append('./stylegan3')

import io
import pickle
import shutil
import time

import numpy as np
import requests
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from einops import rearrange
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
# tqdm.auto uses the notebook widget only when running inside Jupyter and falls
# back to the plain text bar otherwise (tqdm.notebook fails without ipywidgets).
from tqdm.auto import tqdm

import clip
# NOTE(review): gr.outputs.Image / enable_queue are the Gradio 2.x API — pin a
# compatible gradio version.
import gradio as gr

try:
    from IPython.display import display
except ImportError:
    def display(*_args, **_kwargs):
        """No-op stand-in: intermediate previews are only shown under IPython."""

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def fetch(url_or_path):
    """Return a readable binary file object for an http(s) URL or a local path.

    URLs are downloaded completely into an in-memory buffer positioned at 0;
    local paths are opened with ``open(..., 'rb')`` and the caller owns the handle.
    """
    if str(url_or_path).startswith(('http://', 'https://')):
        r = requests.get(url_or_path, timeout=60)
        r.raise_for_status()
        return io.BytesIO(r.content)
    return open(url_or_path, 'rb')


def fetch_model(url_or_path):
    """Ensure ``basename(url_or_path)`` exists in the working directory; return it.

    The file is downloaded with ``wget -c`` only when it is not already present.

    Raises:
        subprocess.CalledProcessError: if wget exits with a non-zero status.
    """
    basename = os.path.basename(url_or_path)
    if not os.path.exists(basename):
        # Argument list, no shell: the URL is passed to wget verbatim.
        subprocess.run(['wget', '-c', url_or_path], check=True)
    return basename


def norm1(prompt):
    "Normalize to the unit sphere (divide by the L2 norm over the last dimension)."
    return prompt / prompt.square().sum(dim=-1, keepdim=True).sqrt()


def spherical_dist_loss(x, y):
    """Return 2 * arcsin(|x^ - y^| / 2) ** 2 over the last dim (x^, y^ unit-normalized).

    For unit vectors at angle theta this equals theta**2 / 2. Inputs broadcast.
    """
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)


class MakeCutouts(torch.nn.Module):
    """Take ``cutn`` random square crops of a 4-D image batch, each pooled to cut_size.

    Crop side lengths are drawn as ``rand() ** cut_pow`` scaled between
    ``min(H, W, cut_size)`` and ``min(H, W)``; ``cut_pow < 1`` favours larger crops.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def forward(self, input):
        """Return all crops concatenated along dim 0: ``cutn * input.shape[0]`` items."""
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
        return torch.cat(cutouts)


make_cutouts = MakeCutouts(224, 32, 0.5)


def embed_image(image):
    """CLIP-embed random cutouts of ``image``; returns shape (cutn, batch, dim).

    ``image`` is expected in [0, 1] (callers map generator output via add(1).div(2)).
    """
    n = image.shape[0]
    cutouts = make_cutouts(image)
    embeds = clip_model.embed_cutout(cutouts)
    # torch.cat in MakeCutouts stacks cutout-major, so split that axis first.
    embeds = rearrange(embeds, '(cc n) c -> cc n c', n=n)
    return embeds


def embed_url(url):
    """Average CLIP cutout embedding of the RGB image at ``url`` (URL or path)."""
    image = Image.open(fetch(url)).convert('RGB')
    return embed_image(TF.to_tensor(image).to(device).unsqueeze(0)).mean(0).squeeze(0)


class CLIP(object):
    """Frozen CLIP ViT-B/32 wrapper returning unit-norm float32 embeddings."""

    def __init__(self):
        clip_model = "ViT-B/32"
        # Load on the same device the rest of the script uses.
        self.model, _ = clip.load(clip_model, device=device)
        self.model = self.model.requires_grad_(False)
        # CLIP's image normalization statistics.
        self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                              std=[0.26862954, 0.26130258, 0.27577711])

    @torch.no_grad()
    def embed_text(self, prompt):
        "Normalized clip text embedding."
        return norm1(self.model.encode_text(clip.tokenize(prompt).to(device)).float())

    def embed_cutout(self, image):
        "Normalized clip image embedding (differentiable w.r.t. ``image``)."
        # .float() matches embed_text, so the norm is taken in float32 even when
        # the model runs in half precision.
        return norm1(self.model.encode_image(self.normalize(image)).float())


clip_model = CLIP()

# Load the StyleGAN3 generator.
base_url = "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/"
model_name = "stylegan3-t-ffhqu-1024x1024.pkl"
# Alternatives: "stylegan3-r-metfacesu-1024x1024.pkl", "stylegan3-t-afhqv2-512x512.pkl"
network_url = base_url + model_name

# NOTE: pickle.load can execute arbitrary code; only acceptable because the
# checkpoint comes from NVIDIA's official NGC URL above.
with open(fetch_model(network_url), 'rb') as fp:
    G = pickle.load(fp)['G_ema'].to(device)

# Per-dimension std of W from 10k random z. inference() optimizes q, where
# w = q * w_stds + w_avg, i.e. the latent is measured in units of w_stds.
with torch.no_grad():
    zs = torch.randn([10000, G.mapping.z_dim], device=device)
    w_stds = G.mapping(zs, None).std(0)


def inference(text, steps=600, seed=2):
    """Generate an image for ``text`` by CLIP-guided StyleGAN3 latent optimization.

    1. Bootstrap: 8 batches of 4 truncated (psi=0.7) latents; keep the lowest
       CLIP loss of each batch, then the lowest of those 8.
    2. Optimize that latent with AdamW for ``steps`` iterations on the mean
       spherical distance between CLIP cutout embeddings and the text
       embedding, while tracking an EMA (decay 0.9) of the latent.

    Args:
        text: Prompt (the Gradio interface passes only this argument).
        steps: Number of optimization iterations; must be >= 1.
        seed: Value passed to ``torch.manual_seed`` before sampling.

    Returns:
        PIL image rendered from the EMA latent after the last step.

    Raises:
        ValueError: if ``steps < 1``.
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    target = clip_model.embed_text(text)
    # Preview transform: generator output is mapped from [-1, 1] to [0, 1].
    tf = Compose([
        Resize(224),
        lambda x: torch.clamp((x + 1) / 2, min=0, max=1),
    ])
    torch.manual_seed(seed)

    with torch.no_grad():
        qs = []
        losses = []
        for _ in range(8):
            q = (G.mapping(torch.randn([4, G.mapping.z_dim], device=device), None,
                           truncation_psi=0.7) - G.mapping.w_avg) / w_stds
            images = G.synthesis(q * w_stds + G.mapping.w_avg)
            embeds = embed_image(images.add(1).div(2))
            loss = spherical_dist_loss(embeds, target).mean(0)
            best = torch.argmin(loss)
            qs.append(q[best])
            losses.append(loss[best])
        qs = torch.stack(qs)
        losses = torch.stack(losses)
        print(losses)
        print(losses.shape, qs.shape)
        best = torch.argmin(losses)
        q = qs[best].unsqueeze(0)

    q.requires_grad_()
    q_ema = q
    opt = torch.optim.AdamW([q], lr=0.03, betas=(0.0, 0.999))
    loop = tqdm(range(steps))
    for step in loop:
        opt.zero_grad()
        w = q * w_stds
        image = G.synthesis(w + G.mapping.w_avg, noise_mode='const')
        embed = embed_image(image.add(1).div(2))
        loss = spherical_dist_loss(embed, target).mean()
        loss.backward()
        opt.step()
        loop.set_postfix(loss=loss.item(), q_magnitude=q.std().item())

        # The EMA latent/image are only displayed and returned, never
        # backpropagated. Without autograd, no second full synthesis graph is
        # kept alive next to the following step, and the EMA chain stays flat.
        with torch.no_grad():
            q_ema = q_ema * 0.9 + q * 0.1
            image = G.synthesis(q_ema * w_stds + G.mapping.w_avg, noise_mode='const')
        if step % 10 == 0:
            display(TF.to_pil_image(tf(image)[0]))

    return TF.to_pil_image(image[0].add(1).div(2).clamp(0, 1))


title = "StyleGAN+CLIP_with_Latent_Bootstraping"
description = "Gradio demo for StyleGAN+CLIP_with_Latent_Bootstraping. To use it, simply add your text, or click one of the examples to load them. Read more at the links below."
# NOTE(review): the Colab link target appears to have been lost from this text.
article = "\n\ncolab by https://twitter.com/EricHallahan Colab\n\n"
examples = [['elon musk']]
gr.Interface(
    inference,
    "text",
    gr.outputs.Image(type="pil", label="Output"),
    title=title,
    description=description,
    article=article,
    enable_queue=True,
    examples=examples,
).launch(debug=True)