# VQGAN_CLIP / app.py
# Ahsen Khaliq
# 2 jobs
# 67bd1fa
# raw
# history blame
# 16.8 kB
# Module set-up: imports plus three downloads that run at import time.
import torch
# Fetch the VQGAN config and checkpoint into the working directory; the same
# file names are later built from ``model_name`` when the model is loaded.
torch.hub.download_url_to_file('http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.yaml', 'vqgan_imagenet_f16_16384.yaml')
torch.hub.download_url_to_file('http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.ckpt', 'vqgan_imagenet_f16_16384.ckpt')
import argparse
import math
from pathlib import Path
import sys
# Must run before the ``taming`` imports below so the local checkout is found.
sys.path.insert(1, './taming-transformers')
#from IPython import display
from base64 import b64encode
from omegaconf import OmegaConf
from PIL import Image
from taming.models import cond_transformer, vqgan
import taming.modules
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
# NOTE(review): this is the notebook flavour of tqdm although the file runs
# as a Gradio app - confirm it behaves as intended outside a notebook.
from tqdm.notebook import tqdm
from CLIP import clip
import kornia.augmentation as K
import numpy as np
import imageio
from PIL import ImageFile, Image
# Let PIL decode images whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
import gradio as gr
# Placeholder picture returned by ``throttled_inference`` when a request is rejected.
torch.hub.download_url_to_file('https://i.imgur.com/WEHmKef.jpg', 'gpu.jpg')
def sinc(x):
    """Evaluate ``sin(pi * x) / (pi * x)`` element-wise on a tensor.

    Entries of ``x`` that are exactly zero map to 1 instead of the
    indeterminate 0/0.
    """
    ratio = torch.sin(math.pi * x) / (math.pi * x)
    return torch.where(x == 0, x.new_ones([]), ratio)
def lanczos(x, a):
    """Build a Lanczos kernel sampled at positions ``x`` and scaled to sum to 1.

    The kernel is ``sinc(x) * sinc(x / a)`` for positions strictly inside
    ``(-a, a)`` and zero elsewhere.
    """
    inside_window = (x > -a) & (x < a)
    window = sinc(x) * sinc(x / a)
    kernel = torch.where(inside_window, window, x.new_zeros([]))
    return kernel / kernel.sum()
def ramp(ratio, width):
    """Return a symmetric 1-D tensor of sample positions spaced ``ratio`` apart.

    ``ceil(width / ratio + 1)`` non-negative positions ``0, ratio, 2*ratio, ...``
    are generated by repeated addition, mirrored around zero, and the two
    outermost entries are dropped.
    """
    n = math.ceil(width / ratio + 1)
    positions = []
    cur = 0
    for _ in range(n):
        positions.append(cur)
        cur += ratio
    # Explicit dtype so integer ratios still yield a floating-point tensor.
    half = torch.tensor(positions, dtype=torch.get_default_dtype())
    mirrored = torch.cat([-half[1:].flip([0]), half])
    return mirrored[1:-1]
def resample(input, size, align_corners=True):
    """Resize a 4-D tensor to ``size``, low-pass filtering first when shrinking.

    Args:
        input: Tensor whose shape unpacks as ``(n, c, h, w)``.
        size: Target ``(height, width)``.
        align_corners: Forwarded to ``F.interpolate``.

    Returns:
        The bicubically interpolated tensor of spatial size ``size``.
    """
    batch, channels, src_h, src_w = input.shape
    dst_h, dst_w = size
    # Treat every channel of every sample as its own single-channel image so
    # one 1-D kernel can be applied to all of them with a plain conv2d.
    planes = input.view([batch * channels, 1, src_h, src_w])

    if dst_h < src_h:
        taps_h = lanczos(ramp(dst_h / src_h, 2), 2).to(planes.device, planes.dtype)
        margin_h = (taps_h.shape[0] - 1) // 2
        padded = F.pad(planes, (0, 0, margin_h, margin_h), 'reflect')
        planes = F.conv2d(padded, taps_h[None, None, :, None])

    if dst_w < src_w:
        taps_w = lanczos(ramp(dst_w / src_w, 2), 2).to(planes.device, planes.dtype)
        margin_w = (taps_w.shape[0] - 1) // 2
        padded = F.pad(planes, (margin_w, margin_w, 0, 0), 'reflect')
        planes = F.conv2d(padded, taps_w[None, None, None, :])

    restored = planes.view([batch, channels, src_h, src_w])
    return F.interpolate(restored, size, mode='bicubic', align_corners=align_corners)
class ReplaceGrad(torch.autograd.Function):
    """Autograd function that outputs one tensor but routes gradient to another.

    The forward pass returns ``x_forward`` unchanged; the backward pass gives
    ``x_forward`` no gradient and hands the incoming gradient, reduced to the
    shape of ``x_backward``, to ``x_backward``.
    """

    @staticmethod
    def forward(ctx, x_forward, x_backward):
        # Only the shape is needed later, to undo any broadcasting.
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        grad_for_backward = grad_in.sum_to_size(ctx.shape)
        return None, grad_for_backward


replace_grad = ReplaceGrad.apply
class ClampWithGrad(torch.autograd.Function):
    """Clamp whose backward pass keeps some gradient for out-of-range inputs.

    Forward is ``input.clamp(min, max)``. Backward lets the incoming gradient
    through wherever ``grad_in * (input - clamped) >= 0`` and zeroes it
    elsewhere; ``min`` and ``max`` receive no gradient.
    """

    @staticmethod
    def forward(ctx, input, min, max):
        ctx.save_for_backward(input)
        ctx.min, ctx.max = min, max
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_in):
        input, = ctx.saved_tensors
        # Zero inside the range, signed distance past the bound outside it.
        overshoot = input - input.clamp(ctx.min, ctx.max)
        keep = grad_in * overshoot >= 0
        return grad_in * keep, None, None


clamp_with_grad = ClampWithGrad.apply
def vector_quantize(x, codebook):
    """Snap each vector along the last axis of ``x`` to its nearest codebook row.

    Nearness is squared Euclidean distance, expanded as
    ``|x|^2 + |c|^2 - 2 x.c``. The result carries the quantized values in the
    forward pass while gradients are routed to ``x`` through ``replace_grad``.
    """
    x_sq = x.pow(2).sum(dim=-1, keepdim=True)
    code_sq = codebook.pow(2).sum(dim=1)
    cross = (2 * x) @ codebook.T
    distances = x_sq + code_sq - cross
    nearest = distances.argmin(-1)
    selector = F.one_hot(nearest, codebook.shape[0]).to(distances.dtype)
    quantized = selector @ codebook
    return replace_grad(quantized, x)
class Prompt(nn.Module):
    """Loss term comparing input embeddings against a stored target embedding.

    ``embed``, ``weight`` and ``stop`` are kept as buffers, so they follow the
    module across ``.to(device)`` calls.
    """

    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        buffers = {
            'embed': embed,
            'weight': torch.as_tensor(weight),
            'stop': torch.as_tensor(stop),
        }
        for name, value in buffers.items():
            self.register_buffer(name, value)

    def forward(self, input):
        """Return the weighted mean distance between ``input`` and the target.

        Both sides are L2-normalised, then each pair is scored as
        ``2 * arcsin(|a - b| / 2) ** 2``. The score is multiplied by the sign
        of ``weight``; the forward value is the mean of those scores, while
        the gradient is taken through ``maximum(score, stop)``.
        """
        queries = F.normalize(input.unsqueeze(1), dim=2)
        targets = F.normalize(self.embed.unsqueeze(0), dim=2)
        chord = queries.sub(targets).norm(dim=2)
        dists = chord.div(2).arcsin().pow(2).mul(2)
        dists = dists * self.weight.sign()
        floored = torch.maximum(dists, self.stop)
        return self.weight.abs() * replace_grad(dists, floored).mean()
def parse_prompt(prompt):
    """Split a prompt of the form ``text[:weight[:stop]]`` into its parts.

    The suffixes are read from the right. If a would-be suffix is not a
    number, the colon is treated as part of the text instead of raising, so
    free-form prompts such as ``"Star Wars: A New Hope"`` (or the same with a
    trailing ``:2`` weight) are accepted.

    Args:
        prompt: The raw prompt string.

    Returns:
        tuple: ``(text, weight, stop)`` with ``weight`` defaulting to ``1.0``
        and ``stop`` to ``-inf``.
    """
    defaults = ['', '1', '-inf']
    # Try the richest form first (text:weight:stop), then text:weight.
    for max_split in (2, 1):
        vals = prompt.rsplit(':', max_split)
        vals = vals + defaults[len(vals):]
        try:
            return vals[0], float(vals[1]), float(vals[2])
        except ValueError:
            continue
    # No numeric suffix at all: every colon belongs to the text.
    return prompt, float(defaults[1]), float(defaults[2])
class MakeCutouts(nn.Module):
    """Produce ``cutn`` augmented, noised copies of an image batch for CLIP.

    Rather than taking random crops, the whole input is pooled down to
    ``cut_size`` x ``cut_size`` (mean of adaptive average and max pooling),
    replicated ``cutn`` times, and the replicas are then diversified by the
    kornia augmentation pipeline and additive Gaussian noise.

    Args:
        cut_size: Side length of every cutout.
        cutn: Number of replicas generated per input image.
        cut_pow: Stored for interface compatibility; not used by ``forward``.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.augs = nn.Sequential(
            K.RandomAffine(degrees=15, translate=0.1, p=0.7, padding_mode='border'),
            K.RandomPerspective(0.7,p=0.7),
            K.ColorJitter(hue=0.1, saturation=0.1, p=0.7),
            K.RandomErasing((.1, .4), (.3, 1/.3), same_on_batch=True, p=0.7),
        )
        # Upper bound of the per-cutout noise scale; a falsy value disables noise.
        self.noise_fac = 0.1
        self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
        self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))

    def forward(self, input):
        """Return the augmented cutouts stacked along the batch dimension.

        The output has ``cutn * input.shape[0]`` entries along dim 0.
        """
        # The pooled image is the same for every cutout, so compute it once
        # and replicate it instead of re-pooling inside a loop.
        pooled = (self.av_pool(input) + self.max_pool(input)) / 2
        batch = self.augs(torch.cat([pooled] * self.cutn, dim=0))
        if self.noise_fac:
            # One noise scale per row of the stacked batch; sizing by the
            # actual batch keeps this valid for inputs with more than one image.
            facs = batch.new_empty([batch.shape[0], 1, 1, 1]).uniform_(0, self.noise_fac)
            batch = batch + facs * torch.randn_like(batch)
        return batch
def load_vqgan_model(config_path, checkpoint_path):
    """Instantiate the VQGAN described by a config file and load its weights.

    Args:
        config_path: Path to the OmegaConf-readable model config.
        checkpoint_path: Path to the checkpoint passed to ``init_from_ckpt``.

    Returns:
        The model in eval mode with gradients disabled and its ``loss``
        attribute removed. For a ``Net2NetTransformer`` config this is the
        transformer's ``first_stage_model``.

    Raises:
        ValueError: If ``config.model.target`` is not one of the three
            supported class paths.
    """
    config = OmegaConf.load(config_path)
    target = config.model.target

    # Pick the class lazily so only the requested one has to exist.
    use_first_stage = False
    if target == 'taming.models.vqgan.VQModel':
        model_cls = vqgan.VQModel
    elif target == 'taming.models.vqgan.GumbelVQ':
        model_cls = vqgan.GumbelVQ
    elif target == 'taming.models.cond_transformer.Net2NetTransformer':
        model_cls = cond_transformer.Net2NetTransformer
        use_first_stage = True
    else:
        raise ValueError(f'unknown model type: {config.model.target}')

    loaded = model_cls(**config.model.params)
    loaded.eval().requires_grad_(False)
    loaded.init_from_ckpt(checkpoint_path)

    model = loaded.first_stage_model if use_first_stage else loaded
    del model.loss
    return model
def resize_image(image, out_size):
    """Resize ``image`` to at most the area of ``out_size``, keeping its aspect ratio.

    The target area is the smaller of the image's own area and
    ``out_size[0] * out_size[1]``; the new width and height are derived from
    that area and the original width/height ratio, then rounded.
    """
    src_w, src_h = image.size[0], image.size[1]
    aspect = src_w / src_h
    pixels = min(src_w * src_h, out_size[0] * out_size[1])
    new_w = round((pixels * aspect)**0.5)
    new_h = round((pixels / aspect)**0.5)
    return image.resize((new_w, new_h), Image.LANCZOS)
# Run configuration shared by every request, followed by one-time model loading.
# Base name of the downloaded VQGAN config/checkpoint pair.
model_name = "vqgan_imagenet_f16_16384"
images_interval = 50
# Requested output size in pixels; ``inference`` floors it to a multiple of
# the factor it derives from the decoder.
width = 280
height = 280
# Empty string means "start from a random latent" in ``inference``.
init_image = ""
seed = 42
args = argparse.Namespace(
# Seeds/weights for random-embedding prompts; empty, so none are added.
noise_prompt_seeds=[],
noise_prompt_weights=[],
size=[width, height],
init_image=init_image,
# Zero disables the extra MSE term in ``inference``.
init_weight=0.,
clip_model='ViT-B/32',
vqgan_config=f'{model_name}.yaml',
vqgan_checkpoint=f'{model_name}.ckpt',
# NOTE: ``inference`` takes its learning rate from its own ``step_size``
# argument; this value is not read elsewhere in this file.
step_size=0.15,
# Number of cutouts fed to CLIP per optimisation step.
cutn=4,
cut_pow=1.,
# Losses are logged every ``display_freq`` steps.
display_freq=images_interval,
# NOTE: ``inference`` seeds from its own ``seed`` argument, not from this.
seed=seed,
)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
# Both networks are loaded once at import time and shared by all requests.
model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)
perceptor = clip.load(args.clip_model, jit=False)[0].eval().requires_grad_(False).to(device)
def inference(text, seed, step_size):
    """Optimise a VQGAN latent so its decoded image matches ``text`` under CLIP.

    Uses the module-level ``args``, ``model``, ``perceptor`` and ``device``.

    Args:
        text: Prompt string. ``|`` separates several prompts; each prompt is
            split into text, weight and stop value by ``parse_prompt``.
        seed: RNG seed. ``None`` or ``-1`` draws a fresh one via ``torch.seed()``.
        step_size: Learning rate handed to the Adam optimiser.

    Returns:
        numpy.ndarray: ``uint8`` image in height x width x channel order,
        decoded from the latent as it stood at the start of the last
        optimisation step that ran.

    Raises:
        ValueError: If there is nothing to optimise towards (no prompt of any
            kind and ``args.init_weight`` is zero).
    """
    from urllib.request import urlopen  # only used for a remote ``args.init_image``

    # The loop below runs steps 0..max_iterations inclusive.
    max_iterations = 100

    # "|"-separated image prompt paths. Nothing in the UI sets this, so the
    # list is always empty here; the handling is kept for completeness.
    target_images = ""
    if target_images == "None" or not target_images:
        target_images = []
    else:
        target_images = [image.strip() for image in target_images.split("|")]

    texts = [phrase.strip() for phrase in text.split("|")]
    if texts == ['']:
        texts = []

    if texts:
        print('Using texts:', texts)
    if target_images:
        print('Using image prompts:', target_images)

    if seed is None or seed == -1:
        seed = torch.seed()
    torch.manual_seed(seed)
    print('Using seed:', seed)

    cut_size = perceptor.visual.input_resolution
    # Pixels per latent token along each side, as derived from the decoder.
    f = 2**(model.decoder.num_resolutions - 1)
    make_cutouts = MakeCutouts(cut_size, args.cutn, cut_pow=args.cut_pow)
    toksX, toksY = args.size[0] // f, args.size[1] // f
    sideX, sideY = toksX * f, toksY * f

    # The OpenImages checkpoint exposes its codebook under different
    # attribute names than the other checkpoints.
    if args.vqgan_checkpoint == 'vqgan_openimages_f16_8192.ckpt':
        e_dim = 256
        n_toks = model.quantize.n_embed
        codebook = model.quantize.embed.weight
    else:
        e_dim = model.quantize.e_dim
        n_toks = model.quantize.n_e
        codebook = model.quantize.embedding.weight
    # Per-channel bounds of the codebook; the latent is clipped to them
    # after every optimiser step.
    z_min = codebook.min(dim=0).values[None, :, None, None]
    z_max = codebook.max(dim=0).values[None, :, None, None]

    if args.init_image:
        if 'http' in args.init_image:
            img = Image.open(urlopen(args.init_image))
        else:
            img = Image.open(args.init_image)
        pil_image = img.convert('RGB')
        pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
        pil_tensor = TF.to_tensor(pil_image)
        z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1)
    else:
        # The random-token latent is only used for its shape/layout (and to
        # keep RNG consumption, hence seeded results, unchanged); the actual
        # starting point is uniform noise in [0, 2).
        one_hot = F.one_hot(torch.randint(n_toks, [toksY * toksX], device=device), n_toks).float()
        z = one_hot @ codebook
        z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
        z = torch.rand_like(z)*2
    z_orig = z.clone()
    z.requires_grad_(True)
    opt = optim.Adam([z], lr=step_size)

    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])

    # Build one loss module per prompt.
    pMs = []
    for prompt in texts:
        txt, weight, stop = parse_prompt(prompt)
        embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
        pMs.append(Prompt(embed, weight, stop).to(device))

    for prompt in target_images:
        path, weight, stop = parse_prompt(prompt)
        img = Image.open(path)
        pil_image = img.convert('RGB')
        img = resize_image(pil_image, (sideX, sideY))
        batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
        embed = perceptor.encode_image(normalize(batch)).float()
        pMs.append(Prompt(embed, weight, stop).to(device))

    for noise_seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):
        gen = torch.Generator().manual_seed(noise_seed)
        embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)
        pMs.append(Prompt(embed, weight).to(device))

    # Without any loss term ``sum([])`` is the int 0 and the loop would die
    # on ``.item()`` / ``.backward()``; fail early with a readable message.
    if not pMs and not args.init_weight:
        raise ValueError('no prompt given: enter some text to generate an image')

    def synth(z):
        """Decode latent ``z`` to an image tensor with values in [0, 1]."""
        z_q = vector_quantize(z.movedim(1, 3), codebook).movedim(3, 1)
        return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1)

    def to_image(out):
        """Convert the first decoded image of ``out`` to an HWC uint8 array."""
        img = out.mul(255).clamp(0, 255)[0].cpu().detach().numpy().astype(np.uint8)
        return np.array(np.transpose(img, (1, 2, 0)))

    @torch.no_grad()
    def checkin(i, losses):
        """Log the total and the individual losses for step ``i``."""
        losses_str = ', '.join(f'{loss.item():g}' for loss in losses)
        tqdm.write(f'i: {i}, loss: {sum(losses).item():g}, losses: {losses_str}')

    def ascend_txt(i):
        """Return the list of loss terms and the decoded image for step ``i``."""
        out = synth(z)
        iii = perceptor.encode_image(normalize(make_cutouts(out))).float()

        result = []
        if args.init_weight:
            # Pull towards zero with a strength that decays as steps go by.
            result.append(F.mse_loss(z, torch.zeros_like(z_orig)) * ((1/torch.tensor(i*2 + 1))*args.init_weight) / 2)
        for prompt in pMs:
            result.append(prompt(iii))
        return result, out.detach()

    def train(i):
        """Run one optimiser step; return the image decoded before the update."""
        opt.zero_grad()
        lossAll, out = ascend_txt(i)
        if i % args.display_freq == 0:
            checkin(i, lossAll)
        loss = sum(lossAll)
        loss.backward()
        opt.step()
        with torch.no_grad():
            z.copy_(z.maximum(z_min).minimum(z_max))
        return out

    last_out = None
    try:
        with tqdm() as pbar:
            for i in range(max_iterations + 1):
                last_out = train(i)
                if i < max_iterations:
                    pbar.update()
    except KeyboardInterrupt:
        pass

    if last_out is None:
        # Interrupted before the first step finished: show the initial latent.
        with torch.no_grad():
            last_out = synth(z)
    # Only the final frame is returned, so convert to numpy once, here,
    # instead of on every step.
    return to_image(last_out)
# Number of ``inference`` calls currently in flight; read and updated by
# ``throttled_inference`` to cap concurrent runs.
inferences_running = 0
def load_image( infilename ) :
    """Read an image file and return its pixels as an ``int32`` numpy array.

    The file is opened through a context manager so the underlying handle is
    released as soon as the pixel data has been copied out.
    """
    with Image.open( infilename ) as img:
        img.load()
        return np.asarray( img, dtype="int32" )
import threading

# Makes the "check the counter, then bump it" step atomic, so concurrent
# requests cannot both slip past the limit.
_inferences_lock = threading.Lock()


def throttled_inference(text, seed, step_size):
    """Run ``inference`` unless two runs are already in progress.

    Args:
        text, seed, step_size: Forwarded unchanged to ``inference``.

    Returns:
        The image produced by ``inference``, or the contents of
        ``./gpu.jpg`` when the request is rejected because the limit of two
        concurrent runs has been reached.
    """
    global inferences_running
    with _inferences_lock:
        current = inferences_running
        admitted = current < 2
        if admitted:
            inferences_running += 1
    if not admitted:
        print(f"Rejected inference when we already had {current} running")
        return load_image("./gpu.jpg")
    try:
        print(f"Inference starting when we already had {current} running")
        return inference(text, seed, step_size)
    finally:
        # Always give the slot back, even if inference raised.
        print("Inference finished")
        with _inferences_lock:
            inferences_running -= 1
# Text shown on the Gradio page: heading, intro blurb and HTML footer.
title = "VQGAN + CLIP"
description = "Gradio demo for VQGAN + CLIP. To use it, simply add your text, or click one of the examples to load them. Read more at the links below. Please click submit only once. Results will show up in under a minute."
article = "<p style='text-align: center'>Originally made by Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). The original BigGAN+CLIP method was by https://twitter.com/advadnoun. Added some explanations and modifications by Eleiber#8347, pooling trick by Crimeacs#8222 (https://twitter.com/EarthML1) and the GUI was made with the help of Abulafia#3734. | <a href='https://colab.research.google.com/drive/1ZAus_gn2RhTZWzOWUpPERNC0Q8OhZRTZ'>Colab</a> | <a href='https://github.com/CompVis/taming-transformers'>Taming Transformers Github Repo</a> | <a href='https://github.com/openai/CLIP'>CLIP Github Repo</a> | Special thanks to BoneAmputee (https://twitter.com/BoneAmputee) for suggestions and advice</p>"
# Build the UI and start serving. The three inputs map positionally onto
# ``throttled_inference(text, seed, step_size)``; each example row supplies
# values in that same order.
gr.Interface(
throttled_inference,
[gr.inputs.Textbox(label="Input"),
gr.inputs.Number(default=42, label="seed"),
gr.inputs.Slider(minimum=0.1, maximum=0.9, default=0.23, label='step size')
],
gr.outputs.Image(type="numpy", label="Output"),
title=title,
description=description,
article=article,
examples=[
['a garden by james gurney',42,0.23],
['coral reef city artstationHQ',1000,0.6],
['a cabin in the mountains unreal engine',98,0.3]
]
).launch(debug=True)