"""StyleGAN2-ADA + CLIP text-to-face Gradio demo with latent bootstrapping.

Setup (runs at import time):
  * clones NVlabs/stylegan2-ada-pytorch and installs the pinned dependencies;
  * downloads the FFHQ generator pickle and loads CLIP ViT-B/32;
  * optionally downloads ``CLIP_vecs.npy``, a table with one row per seed
    0..N-1 (presumably CLIP embeddings of the faces those seeds produce —
    confirm against the script that generated it).

``inference(text)`` turns a prompt into a face. It first picks an initial W+
latent, either a softmax-weighted blend of the seeds most similar to the
prompt or ``w_avg`` when no table is available. It then optimizes that latent
to maximize CLIP cosine similarity with the prompt.
"""
import itertools
import os
import pickle
import subprocess
import sys
import warnings

STYLEGAN_REPO_DIR = 'stylegan2-ada-pytorch'

# Setup steps are best-effort (like the os.system calls they replace): a
# failing command is not fatal here; later imports/loads report real problems.
if not os.path.isdir(STYLEGAN_REPO_DIR):
    subprocess.run(["git", "clone", "https://github.com/NVlabs/stylegan2-ada-pytorch"], check=False)
os.chdir(STYLEGAN_REPO_DIR)
print(os.getcwd())
# Put the repo on sys.path so `legacy` and the repo's support packages are
# importable even when this runs as a script (a notebook already has the cwd
# first on sys.path, which is what the original relied on).
sys.path.insert(0, os.getcwd())
subprocess.run(
    [
        sys.executable, "-m", "pip", "install", "--upgrade",
        "https://github.com/podgorskiy/dnnlib/releases/download/0.0.1/dnnlib-0.0.1-py3-none-any.whl",
        "numpy", "tqdm", "Pillow", "torch-utils==0.0.7",
        "torch==1.9.0+cu111", "torchvision==0.10.0+cu111",
        "-f", "https://download.pytorch.org/whl/torch_stable.html",
        "ftfy", "regex", "git+https://github.com/openai/CLIP.git", "ninja",
        "git+https://github.com/geoopt/geoopt.git", "gdown", "exrex", "torchtext==0.10.0",
    ],
    check=False,
)

import numpy as np
import PIL.Image  # `import PIL` alone does not load the Image submodule.
import torch
import dnnlib  # noqa: F401  imported for side effects; presumably needed when unpickling the generator
import clip
import exrex
import gradio as gr
from tqdm.auto import tqdm  # falls back to console bars outside Jupyter
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

network_pkl = "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl"
if not os.path.isfile(os.path.basename(network_pkl)):
    subprocess.run(["wget", network_pkl], check=False)

cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')

# Load StyleGAN. NOTE(security): unpickling runs code from the file; this is
# only acceptable because the pickle comes from NVIDIA's official CDN above.
with open(os.path.basename(network_pkl), 'rb') as f:
    try:
        G = pickle.load(f)['G_ema'].to(device)
    except ModuleNotFoundError:
        # If legacy pkl then convert before loading. Rewind first: the failed
        # pickle.load above has already consumed part of the file.
        import legacy
        f.seek(0)
        G = legacy.load_network_pkl(f)['G_ema'].to(device)
# Only the latent is optimized; skip gradient bookkeeping for network weights.
G.requires_grad_(False)

clip_model = "ViT-B/32"
model, preprocess = clip.load(clip_model, device=device)
model.requires_grad_(False)

os.chdir("..")

CLIP_VECS_FILE = 'CLIP_vecs.npy'
if not os.path.exists(CLIP_VECS_FILE):
    # Pass the URL as a single argv entry: through a shell, the unquoted '&'
    # backgrounded gdown, so the existence check below raced the download.
    subprocess.run(
        ["gdown", "https://drive.google.com/u/0/uc?id=1R2Ra6Bf7IKwM2eZMKwFvyjkWRyOulQaP&export=download"],
        check=False,
    )

# Row i of CLIP_vecs corresponds to seeded_z[i], the z drawn from seed i.
# Both stay None when the table is unavailable; inference then starts from w_avg.
CLIP_vecs = None
seeded_z = None
if os.path.exists(CLIP_VECS_FILE):
    CLIP_vecs = torch.from_numpy(np.load(CLIP_VECS_FILE)).to(torch.float32)
    seeded_z = torch.from_numpy(np.stack([
        np.random.RandomState(seed).randn(G.z_dim) for seed in range(CLIP_vecs.shape[0])
    ]))

MAX_PROMPTS = 32            # upper bound on regex expansions of one prompt
TOP_K = 18                  # number of table seeds blended into the initial latent
SOFTMAX_TEMPERATURE = 0.01  # sharpness of the seed-blending weights
ITERATIONS = 20             # Adam steps per request
LEARNING_RATE = 0.01

# Adapted preprocessing routine for connecting StyleGAN to CLIP: resize the
# generator output to CLIP's 224px input, map [-1, 1] to [0, 1], and apply
# CLIP's normalization constants.
stylegan_transform = Compose([
    Resize(224),
    lambda x: torch.clamp((x + 1) / 2, min=0, max=1),
    Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])


def spherical_avg(p, w=None, tol=1e-6, max_iter=1000):
    """Apply a weighted spherical average to the rows of ``p``.

    Follows Buss & Fillmore, "Spherical Averages and Applications to Spherical
    Splines and Interpolation" (ACM TOG, 2001). The rows are projected onto the
    unit sphere, and the iteration starts from the normalized weighted
    Euclidean mean. It then repeatedly steps along the weighted sum of the
    log-map vectors until the step is at most ``tol``.

    Args:
        p (torch.Tensor): Input vectors, shape ``(n, d)``.
        w (torch.Tensor, optional): Weights for averaging, shape ``(n,)``;
            normalized to sum to 1. Default: uniform.
        tol (float, optional): The desired tolerance of the output. Default: 1e-6
        max_iter (int, optional): Maximum number of iterations. Default: 1000

    Returns:
        torch.Tensor: Unit vector of shape ``(d,)``.

    Raises:
        ValueError: If the shapes are inconsistent, or if the average is
            undefined (e.g. the points cancel out, as ``v`` and ``-v`` do),
            which previously made the loop spin forever on NaN.
    """
    from geoopt import Sphere

    if w is None:
        w = p.new_ones([p.shape[0]])
    if not (p.ndim == 2 and w.ndim == 1 and len(p) == len(w)):
        raise ValueError("expected p of shape (n, d) and w of shape (n,)")

    sphere = Sphere()
    weights = (w / w.sum()).unsqueeze(1)
    points = sphere.projx(p)
    q = sphere.projx(points.mul(weights).sum(dim=0))
    for _ in range(max_iter):
        q_new = sphere.retr(q, sphere.logmap(q, points).mul(weights).sum(dim=0))
        step = torch.linalg.vector_norm(q.sub(q_new))
        q = q_new
        if not torch.isfinite(step):
            raise ValueError("spherical average is undefined for these points")
        if step <= tol:
            return q
    warnings.warn(f"spherical_avg did not converge within {max_iter} iterations")
    return q


def _encode_prompt(prompt):
    """Encode ``prompt`` into one CLIP text feature of shape ``(1, D)`` (CPU, float32).

    ``prompt`` is expanded as an exrex regex into every matching string. A
    string containing ``~`` has the ``~`` removed and its feature negated.
    Multiple features are merged with :func:`spherical_avg`.

    Raises:
        ValueError: If the prompt matches nothing or expands to more than
            ``MAX_PROMPTS`` strings.
    """
    # islice bounds the work on user input: a pattern such as "[a-z]{8}" would
    # otherwise be enumerated in full before being rejected.
    expansions = list(itertools.islice(exrex.generate(prompt), MAX_PROMPTS + 1))
    if not expansions:
        raise ValueError("prompt does not match any string")
    if len(expansions) > MAX_PROMPTS:
        raise ValueError(f"prompt expands to more than {MAX_PROMPTS} variants")

    negated = ['~' in s for s in expansions]
    cleaned = [s.replace('~', '') for s in expansions]
    with torch.no_grad():
        features = model.encode_text(clip.tokenize(cleaned).to(device)).cpu().to(torch.float32)
        signs = torch.tensor([-1.0 if neg else 1.0 for neg in negated]).unsqueeze(1)
        features = features * signs
        # If we have more than one feature vector use their spherical average instead
        if features.shape[0] > 1:
            features = spherical_avg(features).unsqueeze(0)
    return features


def _bootstrap_w(text_features, k=TOP_K):
    """Return an initial W+ latent of shape ``(1, num_ws, w_dim)`` on the CPU.

    With the vector table, the ``k`` seeds whose table rows are most similar to
    ``text_features`` are mapped to W+ and blended with softmax weights.
    Without the table, ``w_avg`` is used for every layer.
    """
    with torch.no_grad():
        if CLIP_vecs is None:
            return G.mapping.w_avg.detach().cpu().reshape(1, 1, -1).repeat(1, G.num_ws, 1)
        sims = torch.nn.functional.cosine_similarity(CLIP_vecs, text_features.cpu())
        top_sims, indexes = torch.topk(sims, min(k, sims.shape[0]), dim=0)
        blend = torch.softmax(top_sims / SOFTMAX_TEMPERATURE, dim=-1)
        ws = G.mapping(seeded_z[indexes].reshape(-1, G.z_dim).to(device), c=None).cpu()
        return torch.sum(ws * blend[:, None, None], dim=0, keepdim=True)


def inference(text):
    """Generate a face matching ``text`` and return it as an RGB ``PIL.Image``.

    Args:
        text: Prompt, interpreted as an exrex regex. Its expansions are
            averaged, and variants containing ``~`` are negated.

    Raises:
        ValueError: If the prompt expands to nothing, to too many variants, or
            to features whose spherical average is undefined.
    """
    text_features = _encode_prompt(text).to(device)
    w_avg = G.mapping.w_avg

    # Prepare for gradient descent: optimize the offset from w_avg, not w itself.
    found_w = (_bootstrap_w(text_features).to(device) - w_avg).detach().requires_grad_(True)
    optimizer = torch.optim.Adam((found_w,), LEARNING_RATE, betas=(0, 0.999))
    for _ in tqdm(range(ITERATIONS)):
        optimizer.zero_grad()
        img = G.synthesis(found_w + w_avg, noise_mode='const', force_fp32=not cuda_available)
        # On GPU, CLIP runs in fp16; compare in fp32 like the text features.
        image_features = model.encode_image(stylegan_transform(img)).to(torch.float32)
        loss = -torch.nn.functional.cosine_similarity(image_features, text_features).mean()
        loss.backward()
        optimizer.step()

    with torch.no_grad():
        img = G.synthesis(found_w + w_avg, noise_mode='const', force_fp32=not cuda_available)
        # Map [-1, 1] to [0, 255] uint8, NCHW -> HWC for the first (only) image.
        pixels = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
    return PIL.Image.fromarray(pixels, 'RGB')


title = "StyleGAN+CLIP_with_Latent_Bootstrapping"
description = (
    "Demo for StyleGAN+CLIP with latent bootstrapping. To use it, type a text prompt "
    "describing a face. Regex alternatives such as (smiling|happy) are averaged, and "
    "a variant containing ~ is negated."
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://github.com/NVlabs/stylegan2-ada-pytorch'>StyleGAN2-ADA (PyTorch)</a> | "
    "<a href='https://github.com/openai/CLIP'>CLIP</a>"
    "</p>"
)

gr.Interface(
    inference,
    "text",
    gr.outputs.Image(type="pil", label="Output"),
    title=title,
    description=description,
    article=article,
    enable_queue=True,
).launch(debug=True)