"""Gradio demo that generates an image (and a GIF of the optimisation) with VQGAN + CLIP.

Importing/running this module has side effects: it upgrades gradio through pip,
downloads the VQGAN config/checkpoint and three example images, loads the VQGAN
and CLIP models, and finally launches the Gradio interface.
"""
import argparse
import math
import os
import sys
from base64 import b64encode
from pathlib import Path
from urllib.request import urlopen

# The gradio upgrade has to happen before gradio is imported further down.
os.system('pip install gradio --upgrade')
os.system('pip freeze')

# Make the local taming-transformers checkout importable as `taming`.
sys.path.insert(1, './taming-transformers')

import torch
from omegaconf import OmegaConf
from PIL import Image, ImageFile
from taming.models import cond_transformer, vqgan
import taming.modules
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm
from CLIP import clip
import kornia.augmentation as K
import numpy as np
import imageio
import gradio as gr
import nvidia_smi

ImageFile.LOAD_TRUNCATED_IMAGES = True

torch.hub.download_url_to_file('https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1', 'vqgan_imagenet_f16_16384.yaml')
torch.hub.download_url_to_file('https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fckpts%2Flast.ckpt&dl=1', 'vqgan_imagenet_f16_16384.ckpt')

# NVML is only needed for the utilisation print-out in `checkin`; skip it when
# CUDA is unavailable so the CPU fallback selected below keeps working.
handle = None
if torch.cuda.is_available():
    nvidia_smi.nvmlInit()
    # card id 0 hardcoded here, there is also a call to get all available
    # card ids, so we could iterate
    handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)

# Example init images referenced by the `examples` of the interface below.
torch.hub.download_url_to_file('https://images.pexels.com/photos/158028/bellingrath-gardens-alabama-landscape-scenic-158028.jpeg', 'garden.jpeg')
torch.hub.download_url_to_file('https://images.pexels.com/photos/68767/divers-underwater-ocean-swim-68767.jpeg', 'coralreef.jpeg')
torch.hub.download_url_to_file('https://images.pexels.com/photos/803975/pexels-photo-803975.jpeg', 'cabin.jpeg')


def sinc(x):
    """Return sin(pi * x) / (pi * x) element-wise, with the value 1 where x == 0."""
    return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))


def lanczos(x, a):
    """Return the Lanczos window sinc(x) * sinc(x / a) on (-a, a), normalised to sum to 1."""
    cond = torch.logical_and(-a < x, x < a)
    out = torch.where(cond, sinc(x) * sinc(x / a), x.new_zeros([]))
    return out / out.sum()


def ramp(ratio, width):
    """Return a symmetric 1-D tensor of sample positions spaced `ratio` apart around 0."""
    n = math.ceil(width / ratio + 1)
    out = torch.empty([n])
    cur = 0
    for i in range(out.shape[0]):
        out[i] = cur
        cur += ratio
    # Mirror the positive half, then drop the outermost sample on each side.
    return torch.cat([-out[1:].flip([0]), out])[1:-1]


def resample(input, size, align_corners=True):
    """Resize a 4-D (n, c, h, w) tensor to `size` = (dh, dw).

    When an axis is being shrunk it is first filtered with a Lanczos kernel
    (reflect-padded so the spatial size is kept); the final resize is bicubic
    interpolation.
    """
    n, c, h, w = input.shape
    dh, dw = size

    # Fold channels into the batch so a single-channel kernel filters all of them.
    input = input.view([n * c, 1, h, w])

    if dh < h:
        kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
        pad_h = (kernel_h.shape[0] - 1) // 2
        input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
        input = F.conv2d(input, kernel_h[None, None, :, None])

    if dw < w:
        kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
        pad_w = (kernel_w.shape[0] - 1) // 2
        input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
        input = F.conv2d(input, kernel_w[None, None, None, :])

    input = input.view([n, c, h, w])
    return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)


class ReplaceGrad(torch.autograd.Function):
    """Return `x_forward` in the forward pass but route the gradient to `x_backward`."""

    @staticmethod
    def forward(ctx, x_forward, x_backward):
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        # No gradient for x_forward; reduce the incoming gradient to x_backward's shape.
        return None, grad_in.sum_to_size(ctx.shape)


replace_grad = ReplaceGrad.apply


class ClampWithGrad(torch.autograd.Function):
    """Clamp to [min, max] in the forward pass with a selectively masked gradient."""

    @staticmethod
    def forward(ctx, input, min, max):
        ctx.min = min
        ctx.max = max
        ctx.save_for_backward(input)
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_in):
        input, = ctx.saved_tensors
        # (input - clamped) is 0 inside the range, so the gradient passes there.
        # Outside the range it passes only when its sign matches the overshoot,
        # i.e. when a descent step would move the value back towards the range.
        return grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0), None, None


clamp_with_grad = ClampWithGrad.apply


def vector_quantize(x, codebook):
    """Replace each vector along the last dim of `x` by its nearest `codebook` row.

    Nearest is by squared Euclidean distance. The gradient is passed straight
    through to `x` via `replace_grad`.
    """
    d = x.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * x @ codebook.T
    indices = d.argmin(-1)
    x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook
    return replace_grad(x_q, x)


class Prompt(nn.Module):
    """Loss term comparing input embeddings against a fixed target embedding.

    Args:
        embed: target embedding(s), stored as a buffer.
        weight: scale of the term; its sign is applied to the distances and its
            absolute value to the final mean.
        stop: lower bound applied to the distances on the gradient path only.
    """

    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        self.register_buffer('embed', embed)
        self.register_buffer('weight', torch.as_tensor(weight))
        self.register_buffer('stop', torch.as_tensor(stop))

    def forward(self, input):
        # NOTE(review): presumably `input` is (cutouts, dim) and `embed` is
        # (prompts, dim), broadcast pairwise by the two unsqueezes - confirm.
        input_normed = F.normalize(input.unsqueeze(1), dim=2)
        embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
        # 2 * arcsin(||a - b|| / 2) ** 2 on the unit-normalised vectors.
        dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
        dists = dists * self.weight.sign()
        return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()


def parse_prompt(prompt):
    """Split 'text[:weight[:stop]]' into (text, weight, stop).

    Missing fields default to weight 1 and stop -inf.
    """
    vals = prompt.rsplit(':', 2)
    vals = vals + ['', '1', '-inf'][len(vals):]
    return vals[0], float(vals[1]), float(vals[2])


class MakeCutouts(nn.Module):
    """Build a batch of `cutn` augmented, pooled views of an image for CLIP.

    Each view is the mean of an adaptive average pool and an adaptive max pool
    of the whole image down to `cut_size` x `cut_size`; the views only differ
    through the random augmentations and the added noise.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow  # stored for interface compatibility; not read here

        self.augs = nn.Sequential(
            K.RandomAffine(degrees=15, translate=0.1, p=0.7, padding_mode='border'),
            K.RandomPerspective(0.7, p=0.7),
            K.ColorJitter(hue=0.1, saturation=0.1, p=0.7),
            K.RandomErasing((.1, .4), (.3, 1/.3), same_on_batch=True, p=0.7),
        )
        self.noise_fac = 0.1
        self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
        self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))

    def forward(self, input):
        # The pooled view is identical for every cutout, so compute it once.
        pooled = (self.av_pool(input) + self.max_pool(input)) / 2
        batch = self.augs(torch.cat([pooled] * self.cutn, dim=0))

        if self.noise_fac:
            # One noise scale per cutout, drawn uniformly from [0, noise_fac].
            facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
            batch = batch + facs * torch.randn_like(batch)
        return batch


def load_vqgan_model(config_path, checkpoint_path):
    """Load a frozen VQGAN in eval mode from an OmegaConf config and a checkpoint.

    Supports the `VQModel`, `GumbelVQ` and `Net2NetTransformer` targets; for the
    latter the transformer's first-stage model is returned.

    Raises:
        ValueError: if `config.model.target` is not one of the supported targets.
    """
    config = OmegaConf.load(config_path)
    if config.model.target == 'taming.models.vqgan.VQModel':
        model = vqgan.VQModel(**config.model.params)
        model.eval().requires_grad_(False)
        model.init_from_ckpt(checkpoint_path)
    elif config.model.target == 'taming.models.vqgan.GumbelVQ':
        model = vqgan.GumbelVQ(**config.model.params)
        model.eval().requires_grad_(False)
        model.init_from_ckpt(checkpoint_path)
    elif config.model.target == 'taming.models.cond_transformer.Net2NetTransformer':
        parent_model = cond_transformer.Net2NetTransformer(**config.model.params)
        parent_model.eval().requires_grad_(False)
        parent_model.init_from_ckpt(checkpoint_path)
        model = parent_model.first_stage_model
    else:
        raise ValueError(f'unknown model type: {config.model.target}')
    del model.loss
    return model


def resize_image(image, out_size):
    """Resize a PIL image, keeping its aspect ratio, to at most the area of `out_size`."""
    ratio = image.size[0] / image.size[1]
    area = min(image.size[0] * image.size[1], out_size[0] * out_size[1])
    size = round((area * ratio)**0.5), round((area / ratio)**0.5)
    return image.resize(size, Image.LANCZOS)


model_name = "vqgan_imagenet_f16_16384"
images_interval = 50
width = 280
height = 280
init_image = ""
seed = 42
args = argparse.Namespace(
    noise_prompt_seeds=[],
    noise_prompt_weights=[],
    size=[width, height],
    init_image=init_image,
    init_weight=0.,
    clip_model='ViT-B/32',
    vqgan_config=f'{model_name}.yaml',
    vqgan_checkpoint=f'{model_name}.ckpt',
    step_size=0.15,
    cutn=4,
    cut_pow=1.,
    display_freq=images_interval,
    seed=seed,
)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)
perceptor = clip.load(args.clip_model, jit=False)[0].eval().requires_grad_(False).to(device)


def inference(text, seed, step_size, max_iterations, width, height, init_image, init_weight):
    """Optimise a VQGAN latent so that CLIP matches the decoded image to `text`.

    Args:
        text: one or more text prompts separated by '|'; each may carry
            ':weight' and ':stop' suffixes (see `parse_prompt`).
        seed: RNG seed; None or -1 draws a fresh seed.
        step_size: learning rate of the Adam optimiser.
        max_iterations: index of the last optimisation step (steps 0..max run).
        width, height: requested output size in pixels, rounded down to a
            multiple of the VQGAN downsampling factor.
        init_image: optional start image: an object with a `.name` path (as
            delivered by the Gradio file input), a path/URL string, or None.
        init_weight: weight of the MSE term pulling the latent to its start value.

    Returns:
        (image, 'out.gif'): the last frame as a numpy array and the path of the
        GIF holding every intermediate frame.

    Raises:
        ValueError: if there is no prompt and `init_weight` is zero, i.e.
            nothing to optimise.
    """
    all_frames = []
    size = [int(width), int(height)]

    if init_image is None:
        init_image = ""
    elif not isinstance(init_image, str):
        init_image = init_image.name

    # Image prompts are not exposed in the UI; the loop below is kept for them.
    target_images = []

    texts = [phrase.strip() for phrase in text.split("|")]
    if texts == ['']:
        texts = []

    if texts:
        print('Using texts:', texts)
    if target_images:
        print('Using image prompts:', target_images)

    if seed is None or seed == -1:
        seed = torch.seed()
    else:
        seed = int(seed)  # the Gradio number input delivers a float
    torch.manual_seed(seed)
    print('Using seed:', seed)

    cut_size = perceptor.visual.input_resolution
    f = 2**(model.decoder.num_resolutions - 1)
    make_cutouts = MakeCutouts(cut_size, args.cutn, cut_pow=args.cut_pow)
    toksX, toksY = size[0] // f, size[1] // f
    sideX, sideY = toksX * f, toksY * f

    # The OpenImages checkpoint names its codebook attributes differently.
    if args.vqgan_checkpoint == 'vqgan_openimages_f16_8192.ckpt':
        e_dim = 256
        n_toks = model.quantize.n_embed
        codebook = model.quantize.embed.weight
    else:
        e_dim = model.quantize.e_dim
        n_toks = model.quantize.n_e
        codebook = model.quantize.embedding.weight
    z_min = codebook.min(dim=0).values[None, :, None, None]
    z_max = codebook.max(dim=0).values[None, :, None, None]

    if init_image:
        if init_image.startswith(('http://', 'https://')):
            img = Image.open(urlopen(init_image))
        else:
            img = Image.open(init_image)
        pil_image = img.convert('RGB')
        pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
        pil_tensor = TF.to_tensor(pil_image)
        z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1)
    else:
        # The random codebook draw only fixes the latent's shape/dtype/device,
        # but it is kept so the RNG stream (and thus seeded results) is unchanged.
        one_hot = F.one_hot(torch.randint(n_toks, [toksY * toksX], device=device), n_toks).float()
        z = one_hot @ codebook
        z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
        z = torch.rand_like(z) * 2
    z_orig = z.clone()
    z.requires_grad_(True)
    opt = optim.Adam([z], lr=step_size)

    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])

    pMs = []

    for prompt in texts:
        txt, weight, stop = parse_prompt(prompt)
        embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
        pMs.append(Prompt(embed, weight, stop).to(device))

    for prompt in target_images:
        path, weight, stop = parse_prompt(prompt)
        img = Image.open(path)
        pil_image = img.convert('RGB')
        img = resize_image(pil_image, (sideX, sideY))
        batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
        embed = perceptor.encode_image(normalize(batch)).float()
        pMs.append(Prompt(embed, weight, stop).to(device))

    for noise_seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):
        gen = torch.Generator().manual_seed(noise_seed)
        embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)
        pMs.append(Prompt(embed, weight).to(device))

    if not pMs and not init_weight:
        raise ValueError('nothing to optimise: enter a text prompt or set an initial weight')

    def synth(z):
        """Quantise the latent against the codebook and decode it to an image in [0, 1]."""
        z_q = vector_quantize(z.movedim(1, 3), codebook).movedim(3, 1)
        return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1)

    @torch.no_grad()
    def checkin(i, losses):
        """Log the losses of step `i` and, when NVML is initialised, the GPU utilisation."""
        losses_str = ', '.join(f'{loss.item():g}' for loss in losses)
        tqdm.write(f'i: {i}, loss: {sum(losses).item():g}, losses: {losses_str}')
        if handle is not None:
            res = nvidia_smi.nvmlDeviceGetUtilizationRates(handle)
            print(f'gpu: {res.gpu}%, gpu-mem: {res.memory}%')

    def ascend_txt():
        """Decode the latent, record the frame and return (loss terms, frame array)."""
        out = synth(z)
        iii = perceptor.encode_image(normalize(make_cutouts(out))).float()

        result = []

        if init_weight:
            result.append(F.mse_loss(z, z_orig) * init_weight / 2)
        for prompt in pMs:
            result.append(prompt(iii))

        # (c, h, w) float image in [0, 1] -> (h, w, c) uint8 frame.
        frame = out[0].detach().mul(255).clamp(0, 255).cpu().numpy().astype(np.uint8)
        frame = np.transpose(frame, (1, 2, 0))
        img = Image.fromarray(frame).convert('RGB')
        all_frames.append(img)
        return result, np.array(img)

    def train(i):
        """Run one optimiser step and return the frame rendered before the step."""
        opt.zero_grad()
        lossAll, image = ascend_txt()
        if i % args.display_freq == 0:
            checkin(i, lossAll)

        loss = sum(lossAll)
        loss.backward()
        opt.step()
        # Keep every latent channel inside the range spanned by the codebook.
        with torch.no_grad():
            z.copy_(z.maximum(z_min).minimum(z_max))
        return image

    i = 0
    try:
        with tqdm() as pbar:
            while True:
                image = train(i)
                # `>=` rather than `==` so a non-integer limit still terminates.
                if i >= max_iterations:
                    break
                i += 1
                pbar.update()
    except KeyboardInterrupt:
        pass

    all_frames[0].save('out.gif',
                       save_all=True, append_images=all_frames[1:], optimize=False, duration=80, loop=0)
    return image, 'out.gif'


def load_image(infilename):
    """Load the image file at `infilename` and return its pixels as an int32 numpy array."""
    with Image.open(infilename) as img:
        img.load()
        return np.asarray(img, dtype="int32")


title = "VQGAN + CLIP"
description = "Gradio demo for VQGAN + CLIP. To use it, simply add your text, or click one of the examples to load them. Read more at the links below. Please click submit only once. Results will show up in under a minute."
article = """

Originally made by Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). The original BigGAN+CLIP method was by https://twitter.com/advadnoun. Added some explanations and modifications by Eleiber#8347, pooling trick by Crimeacs#8222 (https://twitter.com/EarthML1) and the GUI was made with the help of Abulafia#3734. | Colab | Taming Transformers Github Repo | CLIP Github Repo | Special thanks to BoneAmputee (https://twitter.com/BoneAmputee) for suggestions and advice

"""

gr.Interface(
    inference,
    [
        gr.inputs.Textbox(label="Text Input"),
        gr.inputs.Number(default=42, label="seed"),
        gr.inputs.Slider(minimum=0.1, maximum=0.9, default=0.17, label='step size'),
        gr.inputs.Slider(minimum=25, maximum=400, default=80, label='max iterations', step=1),
        gr.inputs.Slider(minimum=200, maximum=500, default=256, label='width', step=1),
        gr.inputs.Slider(minimum=200, maximum=500, default=256, label='height', step=1),
        gr.inputs.Image(type="file", label="Initial Image", optional=True),
        gr.inputs.Slider(minimum=0.0, maximum=15.0, default=0.0, label='Initial Weight', step=1.0),
    ],
    [
        gr.outputs.Image(type="numpy", label="Output Image"),
        gr.outputs.Image(type="file", label="Output GIF"),
    ],
    title=title,
    description=description,
    article=article,
    examples=[
        ['a garden by james gurney', 42, 0.16, 100, 256, 256, 'garden.jpeg', 0.0],
        ['coral reef city artstationHQ', 1000, 0.6, 110, 200, 200, 'coralreef.jpeg', 0.0],
        ['a cabin in the mountains unreal engine', 98, 0.3, 120, 280, 280, 'cabin.jpeg', 0.0],
    ],
    enable_queue=True,
).launch(debug=True)