import torch

torch.hub.download_url_to_file('http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.yaml', 'vqgan_imagenet_f16_16384.yaml')
torch.hub.download_url_to_file('http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.ckpt', 'vqgan_imagenet_f16_16384.ckpt')

import argparse
import math
from pathlib import Path
import sys

sys.path.insert(1, './taming-transformers')

# from IPython import display
from base64 import b64encode
from omegaconf import OmegaConf
from PIL import Image
from taming.models import cond_transformer, vqgan
import taming.modules
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm

from CLIP import clip
import kornia.augmentation as K
import numpy as np
import imageio
from PIL import ImageFile, Image
ImageFile.LOAD_TRUNCATED_IMAGES = True

import gradio as gr

torch.hub.download_url_to_file('https://i.imgur.com/WEHmKef.jpg', 'gpu.jpg')
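
# --- Lanczos resampling helpers ---------------------------------------------
# sinc/lanczos/ramp build a separable Lanczos kernel; resample() low-pass
# filters with it when shrinking, then bicubic-interpolates to the target
# size. It is currently only referenced by the commented-out random-crop
# path inside MakeCutouts below.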
def sinc(x):
    return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))


def lanczos(x, a):
    cond = torch.logical_and(-a < x, x < a)
    out = torch.where(cond, sinc(x) * sinc(x / a), x.new_zeros([]))
    return out / out.sum()


def ramp(ratio, width):
    n = math.ceil(width / ratio + 1)
    out = torch.empty([n])
    cur = 0
    for i in range(out.shape[0]):
        out[i] = cur
        cur += ratio
    return torch.cat([-out[1:].flip([0]), out])[1:-1]
def resample(input, size, align_corners=True):
    n, c, h, w = input.shape
    dh, dw = size

    input = input.view([n * c, 1, h, w])

    if dh < h:
        kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
        pad_h = (kernel_h.shape[0] - 1) // 2
        input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
        input = F.conv2d(input, kernel_h[None, None, :, None])

    if dw < w:
        kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
        pad_w = (kernel_w.shape[0] - 1) // 2
        input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
        input = F.conv2d(input, kernel_w[None, None, None, :])

    input = input.view([n, c, h, w])
    return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)
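
# --- Straight-through gradient helpers --------------------------------------
# ReplaceGrad returns x_forward in the forward pass but routes gradients to
# x_backward; vector_quantize uses it so the non-differentiable codebook
# lookup still passes gradients to the latent z. ClampWithGrad clamps in the
# forward pass and, in backward, keeps only the gradient components that move
# out-of-range values back toward [min, max].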
class ReplaceGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_forward, x_backward):
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        return None, grad_in.sum_to_size(ctx.shape)


replace_grad = ReplaceGrad.apply


class ClampWithGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, min, max):
        ctx.min = min
        ctx.max = max
        ctx.save_for_backward(input)
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_in):
        input, = ctx.saved_tensors
        return grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0), None, None


clamp_with_grad = ClampWithGrad.apply
def vector_quantize(x, codebook):
    d = x.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * x @ codebook.T
    indices = d.argmin(-1)
    x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook
    return replace_grad(x_q, x)
class Prompt(nn.Module):
    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        self.register_buffer('embed', embed)
        self.register_buffer('weight', torch.as_tensor(weight))
        self.register_buffer('stop', torch.as_tensor(stop))

    def forward(self, input):
        input_normed = F.normalize(input.unsqueeze(1), dim=2)
        embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
        dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
        dists = dists * self.weight.sign()
        return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()
def parse_prompt(prompt):
    vals = prompt.rsplit(':', 2)
    vals = vals + ['', '1', '-inf'][len(vals):]
    return vals[0], float(vals[1]), float(vals[2])
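
# parse_prompt splits "text:weight:stop" with defaults for missing fields,
# e.g. parse_prompt('a castle:0.5:-2') -> ('a castle', 0.5, -2.0) and
# parse_prompt('sunset') -> ('sunset', 1.0, -inf).

# MakeCutouts produces `cutn` augmented views of the decoded image so CLIP
# scores several crops per step; the random-crop variants are kept as
# comments and a pooled cutout (mean of adaptive avg/max pooling) is used
# instead, followed by Kornia augmentations and additive noise.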
class MakeCutouts(nn.Module):
    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.augs = nn.Sequential(
            # K.RandomHorizontalFlip(p=0.5),
            # K.RandomVerticalFlip(p=0.5),
            # K.RandomSolarize(0.01, 0.01, p=0.7),
            # K.RandomSharpness(0.3, p=0.4),
            # K.RandomResizedCrop(size=(self.cut_size, self.cut_size), scale=(0.1, 1), ratio=(0.75, 1.333), cropping_mode='resample', p=0.5),
            # K.RandomCrop(size=(self.cut_size, self.cut_size), p=0.5),
            K.RandomAffine(degrees=15, translate=0.1, p=0.7, padding_mode='border'),
            K.RandomPerspective(0.7, p=0.7),
            K.ColorJitter(hue=0.1, saturation=0.1, p=0.7),
            K.RandomErasing((.1, .4), (.3, 1/.3), same_on_batch=True, p=0.7),
        )
        self.noise_fac = 0.1
        self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
        self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            # size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            # offsetx = torch.randint(0, sideX - size + 1, ())
            # offsety = torch.randint(0, sideY - size + 1, ())
            # cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            # cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
            # cutout = transforms.Resize(size=(self.cut_size, self.cut_size))(input)
            cutout = (self.av_pool(input) + self.max_pool(input)) / 2
            cutouts.append(cutout)
        batch = self.augs(torch.cat(cutouts, dim=0))
        if self.noise_fac:
            facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
            batch = batch + facs * torch.randn_like(batch)
        return batch
def load_vqgan_model(config_path, checkpoint_path):
    config = OmegaConf.load(config_path)
    if config.model.target == 'taming.models.vqgan.VQModel':
        model = vqgan.VQModel(**config.model.params)
        model.eval().requires_grad_(False)
        model.init_from_ckpt(checkpoint_path)
    elif config.model.target == 'taming.models.vqgan.GumbelVQ':
        model = vqgan.GumbelVQ(**config.model.params)
        model.eval().requires_grad_(False)
        model.init_from_ckpt(checkpoint_path)
    elif config.model.target == 'taming.models.cond_transformer.Net2NetTransformer':
        parent_model = cond_transformer.Net2NetTransformer(**config.model.params)
        parent_model.eval().requires_grad_(False)
        parent_model.init_from_ckpt(checkpoint_path)
        model = parent_model.first_stage_model
    else:
        raise ValueError(f'unknown model type: {config.model.target}')
    del model.loss
    return model
def resize_image(image, out_size):
    ratio = image.size[0] / image.size[1]
    area = min(image.size[0] * image.size[1], out_size[0] * out_size[1])
    size = round((area * ratio)**0.5), round((area / ratio)**0.5)
    return image.resize(size, Image.LANCZOS)
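
# --- Global configuration ----------------------------------------------------
# The VQGAN checkpoint and the CLIP model are loaded once at import time and
# reused by every request handled by the Gradio app below.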
model_name = "vqgan_imagenet_f16_16384" | |
images_interval = 50 | |
width = 280 | |
height = 280 | |
init_image = "" | |
seed = 42 | |
args = argparse.Namespace( | |
noise_prompt_seeds=[], | |
noise_prompt_weights=[], | |
size=[width, height], | |
init_image=init_image, | |
init_weight=0., | |
clip_model='ViT-B/32', | |
vqgan_config=f'{model_name}.yaml', | |
vqgan_checkpoint=f'{model_name}.ckpt', | |
step_size=0.15, | |
cutn=4, | |
cut_pow=1., | |
display_freq=images_interval, | |
seed=seed, | |
) | |
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)
perceptor = clip.load(args.clip_model, jit=False)[0].eval().requires_grad_(False).to(device)
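
# inference() runs the VQGAN+CLIP optimization loop: it encodes the text
# prompt(s) with CLIP, optimizes the VQGAN latent z with Adam so that CLIP
# embeddings of cutouts from the decoded image match the prompt embeddings,
# and returns the final decoded image as a numpy array.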
def inference(text, seed, step_size):
    texts = text
    target_images = ""
    max_iterations = 100

    model_names = {"vqgan_imagenet_f16_16384": "ImageNet 16384", "vqgan_imagenet_f16_1024": "ImageNet 1024",
                   "vqgan_openimages_f16_8192": "OpenImages 8192",
                   "wikiart_1024": "WikiArt 1024", "wikiart_16384": "WikiArt 16384", "coco": "COCO-Stuff",
                   "faceshq": "FacesHQ", "sflckr": "S-FLCKR"}
    name_model = model_names[model_name]

    if target_images == "None" or not target_images:
        target_images = []
    else:
        target_images = target_images.split("|")
        target_images = [image.strip() for image in target_images]

    texts = [phrase.strip() for phrase in texts.split("|")]
    if texts == ['']:
        texts = []

    from urllib.request import urlopen

    if texts:
        print('Using texts:', texts)
    if target_images:
        print('Using image prompts:', target_images)
    if seed is None or seed == -1:
        seed = torch.seed()
    torch.manual_seed(seed)
    print('Using seed:', seed)
    # clock = deepcopy(perceptor.visual.positional_embedding.data)
    # perceptor.visual.positional_embedding.data = clock / clock.max()
    # perceptor.visual.positional_embedding.data = clamp_with_grad(clock, 0, 1)

    cut_size = perceptor.visual.input_resolution
    f = 2**(model.decoder.num_resolutions - 1)
    make_cutouts = MakeCutouts(cut_size, args.cutn, cut_pow=args.cut_pow)

    toksX, toksY = args.size[0] // f, args.size[1] // f
    sideX, sideY = toksX * f, toksY * f

    if args.vqgan_checkpoint == 'vqgan_openimages_f16_8192.ckpt':
        e_dim = 256
        n_toks = model.quantize.n_embed
        z_min = model.quantize.embed.weight.min(dim=0).values[None, :, None, None]
        z_max = model.quantize.embed.weight.max(dim=0).values[None, :, None, None]
    else:
        e_dim = model.quantize.e_dim
        n_toks = model.quantize.n_e
        z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]
        z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]
    # z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]
    # z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]

    # normalize_imagenet = transforms.Normalize(mean=[0.485, 0.456, 0.406],
    #                                           std=[0.229, 0.224, 0.225])
    if args.init_image:
        if 'http' in args.init_image:
            img = Image.open(urlopen(args.init_image))
        else:
            img = Image.open(args.init_image)
        pil_image = img.convert('RGB')
        pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
        pil_tensor = TF.to_tensor(pil_image)
        z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1)
    else:
        one_hot = F.one_hot(torch.randint(n_toks, [toksY * toksX], device=device), n_toks).float()
        # z = one_hot @ model.quantize.embedding.weight
        if args.vqgan_checkpoint == 'vqgan_openimages_f16_8192.ckpt':
            z = one_hot @ model.quantize.embed.weight
        else:
            z = one_hot @ model.quantize.embedding.weight
        z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
        z = torch.rand_like(z) * 2  # overwrite the codebook init with uniform noise in [0, 2)
    z_orig = z.clone()
    z.requires_grad_(True)
    opt = optim.Adam([z], lr=step_size)

    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])
    pMs = []
    for prompt in texts:
        txt, weight, stop = parse_prompt(prompt)
        embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
        pMs.append(Prompt(embed, weight, stop).to(device))

    for prompt in target_images:
        path, weight, stop = parse_prompt(prompt)
        img = Image.open(path)
        pil_image = img.convert('RGB')
        img = resize_image(pil_image, (sideX, sideY))
        batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
        embed = perceptor.encode_image(normalize(batch)).float()
        pMs.append(Prompt(embed, weight, stop).to(device))

    for seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):
        gen = torch.Generator().manual_seed(seed)
        embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)
        pMs.append(Prompt(embed, weight).to(device))
    def synth(z):
        if args.vqgan_checkpoint == 'vqgan_openimages_f16_8192.ckpt':
            z_q = vector_quantize(z.movedim(1, 3), model.quantize.embed.weight).movedim(3, 1)
        else:
            z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim(3, 1)
        return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1)

    def checkin(i, losses):
        losses_str = ', '.join(f'{loss.item():g}' for loss in losses)
        tqdm.write(f'i: {i}, loss: {sum(losses).item():g}, losses: {losses_str}')
        out = synth(z)
        # TF.to_pil_image(out[0].cpu()).save('progress.png')
        # display.display(display.Image('progress.png'))
    def ascend_txt():
        # global i
        out = synth(z)
        iii = perceptor.encode_image(normalize(make_cutouts(out))).float()

        result = []
        if args.init_weight:
            # result.append(F.mse_loss(z, z_orig) * args.init_weight / 2)
            result.append(F.mse_loss(z, torch.zeros_like(z_orig)) * ((1 / torch.tensor(i * 2 + 1)) * args.init_weight) / 2)
        for prompt in pMs:
            result.append(prompt(iii))

        img = out.mul(255).clamp(0, 255)[0].detach().cpu().numpy().astype(np.uint8)
        img = np.transpose(img, (1, 2, 0))
        # imageio.imwrite('./steps/' + str(i) + '.png', np.array(img))
        return result, img
    def train(i):
        opt.zero_grad()
        lossAll, image = ascend_txt()
        if i % args.display_freq == 0:
            checkin(i, lossAll)
        loss = sum(lossAll)
        loss.backward()
        opt.step()
        with torch.no_grad():
            z.copy_(z.maximum(z_min).minimum(z_max))
        return image

    i = 0
    try:
        with tqdm() as pbar:
            while True:
                image = train(i)
                if i == max_iterations:
                    break
                i += 1
                pbar.update()
    except KeyboardInterrupt:
        pass
    return image
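
# --- Concurrency throttle ----------------------------------------------------
# At most two inferences run at a time; extra requests get a static "GPU busy"
# image instead. The counter is a plain global rather than a lock, so it can
# miscount briefly under truly concurrent requests.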
inferences_running = 0


def load_image(infilename):
    img = Image.open(infilename)
    img.load()
    data = np.asarray(img, dtype="int32")
    return data


def throttled_inference(text, seed, step_size):
    global inferences_running
    current = inferences_running
    if current >= 2:
        print(f"Rejected inference when we already had {current} running")
        return load_image("./gpu.jpg")
    print(f"Inference starting when we already had {current} running")
    inferences_running += 1
    try:
        return inference(text, seed, step_size)
    finally:
        print("Inference finished")
        inferences_running -= 1
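
# --- Gradio UI ---------------------------------------------------------------
# The text prompt, seed, and Adam step size are exposed as inputs; the output
# is the final image returned by throttled_inference.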
title = "VQGAN + CLIP" | |
description = "Gradio demo for VQGAN + CLIP. To use it, simply add your text, or click one of the examples to load them. Read more at the links below. Please click submit only once. Results will show up in under a minute." | |
article = "<p style='text-align: center'>Originally made by Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). The original BigGAN+CLIP method was by https://twitter.com/advadnoun. Added some explanations and modifications by Eleiber#8347, pooling trick by Crimeacs#8222 (https://twitter.com/EarthML1) and the GUI was made with the help of Abulafia#3734. | <a href='https://colab.research.google.com/drive/1ZAus_gn2RhTZWzOWUpPERNC0Q8OhZRTZ'>Colab</a> | <a href='https://github.com/CompVis/taming-transformers'>Taming Transformers Github Repo</a> | <a href='https://github.com/openai/CLIP'>CLIP Github Repo</a> | Special thanks to BoneAmputee (https://twitter.com/BoneAmputee) for suggestions and advice</p>" | |
gr.Interface(
    throttled_inference,
    [
        gr.inputs.Textbox(label="Input"),
        gr.inputs.Number(default=42, label="seed"),
        gr.inputs.Slider(minimum=0.1, maximum=0.9, default=0.23, label='step size'),
    ],
    gr.outputs.Image(type="numpy", label="Output"),
    title=title,
    description=description,
    article=article,
    examples=[
        ['a garden by james gurney', 42, 0.23],
        ['coral reef city artstationHQ', 1000, 0.6],
        ['a cabin in the mountains unreal engine', 98, 0.3],
    ],
).launch(debug=True)