Spaces:
Runtime error
Runtime error
File size: 18,637 Bytes
eea2605 e829f30 eea2605 e9cd1e9 e984b5c 799297b e984b5c a87d0e8 2a759eb e984b5c d1adddc 6e45c49 d1adddc 2a2b7e0 a528f1f 1d40feb d1adddc e984b5c 2a759eb ec77de0 1880628 ec77de0 3c93d7b 821b58d e984b5c d1adddc 9a34aef e984b5c c62bf4d e984b5c c62bf4d e984b5c 2a2b7e0 408fc5d e984b5c 408fc5d e984b5c 408fc5d e984b5c c62bf4d e984b5c 9a34aef e984b5c 9a34aef e984b5c 2a759eb a87d0e8 e984b5c 2a759eb 1de7080 e984b5c 8828289 2041e1f e984b5c 2041e1f e984b5c 2041e1f e984b5c 2041e1f e984b5c 73529df df3c3ca 73529df 1d19779 2397336 e984b5c dabf774 c62bf4d e984b5c eea2605 fe3881c c62bf4d 0f59c7a a4a0791 24c84ca 2a759eb 2a2b7e0 c3250ef c62bf4d 73529df e984b5c b8fa7a6 562bf67 b8fa7a6 c9e86ad 10b0dff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 |
# Start-up section: imports are interleaved with side effects (package install,
# file downloads, NVML initialisation), so statement order matters here.
import os
# Installs a pinned gradio pre-release through the shell every time the module
# is loaded, then prints the resulting package list.
os.system("pip install gradio==2.8.0b5")
os.system('pip freeze')
import torch
# Download the VQGAN config and checkpoint into the working directory under the
# local file names given as the second argument.
torch.hub.download_url_to_file('https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1', 'vqgan_imagenet_f16_16384.yaml')
torch.hub.download_url_to_file('https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fckpts%2Flast.ckpt&dl=1', 'vqgan_imagenet_f16_16384.ckpt')
import argparse
import math
from pathlib import Path
import sys
# Makes the `taming` package importable from a local ./taming-transformers
# directory (presumably a checkout of the repository — confirm it is present).
sys.path.insert(1, './taming-transformers')
from base64 import b64encode
from omegaconf import OmegaConf
from PIL import Image
from taming.models import cond_transformer, vqgan
import taming.modules
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm
from CLIP import clip
import kornia.augmentation as K
import numpy as np
import imageio
from PIL import ImageFile, Image
# Let PIL load image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
import gradio as gr
import nvidia_smi
# Initialise NVML and keep a handle to one GPU for utilisation reporting.
nvidia_smi.nvmlInit()
handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
# NOTE: GPU index 0 is hardcoded above. NVML can also enumerate every available
# device, so this could be extended to iterate over all of them.
# Download the three images referenced by the example rows of the web UI.
torch.hub.download_url_to_file('https://images.pexels.com/photos/158028/bellingrath-gardens-alabama-landscape-scenic-158028.jpeg', 'garden.jpeg')
torch.hub.download_url_to_file('https://images.pexels.com/photos/68767/divers-underwater-ocean-swim-68767.jpeg', 'coralreef.jpeg')
torch.hub.download_url_to_file('https://images.pexels.com/photos/803975/pexels-photo-803975.jpeg', 'cabin.jpeg')
def sinc(x):
    """Evaluate the normalised sinc function elementwise.

    Returns ``sin(pi * x) / (pi * x)`` for every non-zero entry of ``x`` and
    exactly 1 where ``x`` is zero, so the removable singularity at the origin
    does not surface as NaN in the result.
    """
    scaled = math.pi * x
    quotient = torch.sin(scaled) / scaled
    # 0-dim tensor of ones with the dtype/device of x; broadcast by where().
    unity = x.new_ones([])
    return torch.where(x != 0, quotient, unity)
def lanczos(x, a):
    """Build a Lanczos kernel sampled at positions ``x`` with support ``a``.

    Taps are ``sinc(x) * sinc(x / a)`` strictly inside ``(-a, a)`` and zero
    elsewhere. The kernel is divided by its own sum so the taps add up to 1.
    """
    inside_support = torch.logical_and(-a < x, x < a)
    windowed = sinc(x) * sinc(x / a)
    taps = torch.where(inside_support, windowed, x.new_zeros([]))
    return taps / taps.sum()
def ramp(ratio, width):
    """Return a symmetric, evenly spaced 1-D grid of sample positions.

    Non-negative positions ``0, ratio, 2*ratio, ...`` are generated by repeated
    addition (``ceil(width / ratio + 1)`` of them), mirrored onto the negative
    side, and the outermost sample on each end is dropped.
    """
    count = math.ceil(width / ratio + 1)
    positions = []
    running = 0
    for _ in range(count):
        positions.append(running)
        running += ratio
    # Explicit default dtype so the tensor matches one created by torch.empty.
    half = torch.tensor(positions, dtype=torch.get_default_dtype())
    mirrored = -half[1:].flip([0])
    return torch.cat([mirrored, half])[1:-1]
def resample(input, size, align_corners=True):
    """Resize a 4-D image batch, low-pass filtering first when shrinking.

    Args:
        input: tensor unpacked as ``(n, c, h, w)``.
        size: ``(dh, dw)`` target height and width.
        align_corners: forwarded to ``F.interpolate``.

    Returns:
        The batch resized to ``size`` with bicubic interpolation. Along every
        axis that gets smaller, a separable Lanczos kernel (support 2) is
        convolved over reflect-padded data beforehand.
    """
    n, c, h, w = input.shape
    dh, dw = size

    # Fold channels into the batch axis so one single-channel kernel filters
    # every plane. reshape (not view) also accepts inputs whose memory layout
    # cannot be re-viewed without a copy.
    input = input.reshape([n * c, 1, h, w])

    if dh < h:
        kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
        pad_h = (kernel_h.shape[0] - 1) // 2
        input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
        input = F.conv2d(input, kernel_h[None, None, :, None])

    if dw < w:
        kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
        pad_w = (kernel_w.shape[0] - 1) // 2
        input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
        input = F.conv2d(input, kernel_w[None, None, None, :])

    # Padding compensates for the kernel extent, so h and w are unchanged here.
    input = input.reshape([n, c, h, w])
    return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)
class ReplaceGrad(torch.autograd.Function):
    """Autograd function whose value comes from one input and whose gradient goes to the other.

    The forward pass returns ``x_forward`` untouched. The backward pass gives
    ``x_forward`` no gradient and routes the incoming gradient to
    ``x_backward``, summed down to the shape of ``x_backward``.
    """

    @staticmethod
    def forward(ctx, x_forward, x_backward):
        # Only the shape of the gradient target is needed later.
        ctx.target_shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_output):
        routed = grad_output.sum_to_size(ctx.target_shape)
        return None, routed


replace_grad = ReplaceGrad.apply
class ClampWithGrad(torch.autograd.Function):
    """Clamp to ``[min, max]`` while keeping a usable gradient outside the range.

    Forward is a plain ``clamp``. Backward passes the gradient through wherever
    ``grad * (input - clamped_input) >= 0`` and zeroes it elsewhere; elements
    already inside the range always receive the full gradient because their
    difference term is zero.
    """

    @staticmethod
    def forward(ctx, input, min, max):
        ctx.bounds = (min, max)
        ctx.save_for_backward(input)
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_in):
        (saved,) = ctx.saved_tensors
        low, high = ctx.bounds
        overshoot = saved - saved.clamp(low, high)
        keep = grad_in * overshoot >= 0
        # The two bounds are not differentiated.
        return grad_in * keep, None, None


clamp_with_grad = ClampWithGrad.apply
def vector_quantize(x, codebook):
    """Snap each vector along the last axis of ``x`` to its nearest codebook row.

    Squared Euclidean distances are expanded as ``|x|^2 + |c|^2 - 2 x.c``. The
    chosen rows are gathered with a one-hot matrix product, and the result is
    wrapped in ``replace_grad`` so gradients are routed to ``x``.
    """
    x_sq = x.pow(2).sum(dim=-1, keepdim=True)
    code_sq = codebook.pow(2).sum(dim=1)
    cross = (2 * x) @ codebook.T
    distances = x_sq + code_sq - cross
    nearest = distances.argmin(-1)
    selector = F.one_hot(nearest, codebook.shape[0]).to(distances.dtype)
    quantized = selector @ codebook
    return replace_grad(quantized, x)
class Prompt(nn.Module):
    """Loss term measuring how far input embeddings are from a target embedding.

    ``embed``, ``weight`` and ``stop`` are stored as buffers. The sign of
    ``weight`` flips the direction of the distance; its magnitude scales the
    returned loss. ``stop`` is a lower bound applied to the distance on the
    gradient path only.
    """

    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        buffers = (
            ('embed', embed),
            ('weight', torch.as_tensor(weight)),
            ('stop', torch.as_tensor(stop)),
        )
        for buffer_name, value in buffers:
            self.register_buffer(buffer_name, value)

    def forward(self, input):
        """Return the weighted mean distance between ``input`` rows and the target rows."""
        queries = F.normalize(input.unsqueeze(1), dim=2)
        targets = F.normalize(self.embed.unsqueeze(0), dim=2)
        # For unit vectors the chord length is 2*sin(theta/2), so this
        # evaluates to theta**2 / 2 with theta the angle between the vectors.
        chord = queries.sub(targets).norm(dim=2)
        dists = chord.div(2).arcsin().pow(2).mul(2)
        dists = dists * self.weight.sign()
        # Value comes from dists; the gradient flows through the floored copy.
        floored = torch.maximum(dists, self.stop)
        return self.weight.abs() * replace_grad(dists, floored).mean()
def _as_float(field):
    """Return ``float(field)``, or None when the text is not a valid number."""
    try:
        return float(field)
    except ValueError:
        return None


def parse_prompt(prompt):
    """Split ``"text[:weight[:stop]]"`` into ``(text, weight, stop)``.

    Missing fields default to ``weight = 1.0`` and ``stop = -inf``. Trailing
    colon-separated fields are only treated as weight/stop when they parse as
    numbers, so prompts that contain colons for other reasons (URLs such as
    ``https://...``, or phrases like ``"title: subtitle"``) are kept intact
    instead of raising ``ValueError``.

    Args:
        prompt: the raw prompt string.

    Returns:
        A ``(str, float, float)`` tuple.
    """
    default_weight, default_stop = 1.0, float('-inf')

    parts = prompt.rsplit(':', 2)
    if len(parts) == 3:
        weight, stop = _as_float(parts[1]), _as_float(parts[2])
        if weight is not None and stop is not None:
            return parts[0], weight, stop

    if len(parts) >= 2:
        # Fall back to a single trailing weight field.
        head, _, tail = prompt.rpartition(':')
        weight = _as_float(tail)
        if weight is not None:
            return head, weight, default_stop

    return prompt, default_weight, default_stop
class MakeCutouts(nn.Module):
    """Produce a batch of augmented, noise-perturbed views of an image.

    Every view starts from the same pooled image: the average of an adaptive
    average pool and an adaptive max pool to ``cut_size`` x ``cut_size``. The
    stacked copies are then run through random augmentations, after which
    Gaussian noise with a random per-view strength is added.

    Args:
        cut_size: side length of each output view.
        cutn: number of views produced per input image.
        cut_pow: stored on the module; not used by the pooling-based forward
            pass below.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        # Active augmentation pipeline. Other augmentations (flips, solarize,
        # sharpness, resized/random crops) are not enabled.
        self.augs = nn.Sequential(
            K.RandomAffine(degrees=15, translate=0.1, p=0.7, padding_mode='border'),
            K.RandomPerspective(0.7, p=0.7),
            K.ColorJitter(hue=0.1, saturation=0.1, p=0.7),
            K.RandomErasing((.1, .4), (.3, 1/.3), same_on_batch=True, p=0.7),
        )
        # Upper bound of the per-view noise scale; a falsy value disables noise.
        self.noise_fac = 0.1
        self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
        self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))

    def forward(self, input):
        """Return ``cutn`` augmented views per input image, stacked on dim 0."""
        # The pooled image is the same for every view, so compute it once
        # instead of once per cutout.
        pooled = (self.av_pool(input) + self.max_pool(input)) / 2
        batch = self.augs(torch.cat([pooled] * self.cutn, dim=0))
        if self.noise_fac:
            # One noise scale per row of the stacked batch, so inputs with a
            # batch dimension larger than 1 broadcast correctly as well.
            facs = batch.new_empty([batch.shape[0], 1, 1, 1]).uniform_(0, self.noise_fac)
            batch = batch + facs * torch.randn_like(batch)
        return batch
def load_vqgan_model(config_path, checkpoint_path):
    """Instantiate the VQGAN described by a config file and load its weights.

    The ``model.target`` entry of the config selects the class. For a
    ``Net2NetTransformer`` the returned model is its ``first_stage_model``.
    In every case the model is put in eval mode with gradients disabled before
    the checkpoint is loaded, and its ``loss`` attribute is deleted.

    Raises:
        ValueError: if ``model.target`` names an unsupported class.
    """
    config = OmegaConf.load(config_path)
    target = config.model.target

    wraps_first_stage = False
    if target == 'taming.models.vqgan.VQModel':
        model_cls = vqgan.VQModel
    elif target == 'taming.models.vqgan.GumbelVQ':
        model_cls = vqgan.GumbelVQ
    elif target == 'taming.models.cond_transformer.Net2NetTransformer':
        model_cls = cond_transformer.Net2NetTransformer
        wraps_first_stage = True
    else:
        raise ValueError(f'unknown model type: {config.model.target}')

    loaded = model_cls(**config.model.params)
    loaded.eval().requires_grad_(False)
    loaded.init_from_ckpt(checkpoint_path)

    model = loaded.first_stage_model if wraps_first_stage else loaded
    del model.loss
    return model
def resize_image(image, out_size):
    """Resize a PIL image to fit a pixel budget while keeping its aspect ratio.

    The target area is the smaller of the image's own area and
    ``out_size[0] * out_size[1]``. The new width and height are derived from
    that area and the original aspect ratio, then rounded. Resampling uses
    the Lanczos filter.
    """
    src_w, src_h = image.size[0], image.size[1]
    aspect = src_w / src_h
    budget = min(src_w * src_h, out_size[0] * out_size[1])
    new_w = round((budget * aspect) ** 0.5)
    new_h = round((budget / aspect) ** 0.5)
    return image.resize((new_w, new_h), Image.LANCZOS)
# Base name shared by the VQGAN config (.yaml) and checkpoint (.ckpt) files.
model_name = "vqgan_imagenet_f16_16384"
# Stored as args.display_freq: the step interval at which progress is logged.
images_interval = 50
# Module-level defaults used only to populate `args` below; inference()
# receives its own width/height/init_image/seed as parameters.
width = 280
height = 280
init_image = ""
seed = 42
# Static settings. Of these, the code in this file reads vqgan_config,
# vqgan_checkpoint and clip_model (model loading below) plus
# vqgan_checkpoint, display_freq and the noise_prompt_* lists (inside
# inference()); the remaining fields are set but not read in this file.
args = argparse.Namespace(
noise_prompt_seeds=[],
noise_prompt_weights=[],
size=[width, height],
init_image=init_image,
init_weight=0.,
clip_model='ViT-B/32',
vqgan_config=f'{model_name}.yaml',
vqgan_checkpoint=f'{model_name}.ckpt',
step_size=0.15,
cutn=4,
cut_pow=1.,
display_freq=images_interval,
seed=seed,
)
# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
# Both models are loaded once at import time and shared by every request;
# the CLIP model is frozen (eval mode, no gradients).
model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)
perceptor = clip.load(args.clip_model, jit=False)[0].eval().requires_grad_(False).to(device)
def inference(text, seed, step_size, max_iterations, width, height, init_image, init_weight, target_images, cutn, cut_pow):
    """Optimise a VQGAN latent against CLIP prompts and render the result.

    Args:
        text: text prompts separated by ``|``; each may carry ``:weight:stop``
            suffixes understood by ``parse_prompt``.
        seed: RNG seed; ``None`` or ``-1`` draws a fresh seed.
        step_size: learning rate of the Adam optimiser acting on the latent.
        max_iterations: index of the last optimisation step; steps
            ``0..max_iterations`` inclusive are executed.
        width, height: requested output size in pixels; each is rounded down
            to a multiple of the VQGAN downsampling factor.
        init_image: optional starting image, either an object with a ``name``
            attribute holding a path, or a path/URL string.
        init_weight: weight of the MSE penalty that keeps the latent close to
            its starting value; a falsy value disables the penalty.
        target_images: optional image prompt(s), given like ``init_image``;
            several paths may be separated by ``|``.
        cutn: number of cutouts fed to CLIP per step.
        cut_pow: forwarded to ``MakeCutouts``.

    Returns:
        ``(image, 'test.mp4')``: the frame decoded during the last executed
        step as a uint8 array (``None`` if no step finished), and the path of
        a video containing every frame.
    """
    from urllib.request import urlopen

    torch.cuda.empty_cache()

    # Numeric widgets may hand over whole numbers as floats (assumption —
    # confirm against the gradio version in use); these values feed integer
    # division, tensor shapes and repeat counts, so make them ints up front.
    width, height = int(width), int(height)
    cutn = int(cutn)
    max_iterations = max(int(max_iterations), 0)

    all_frames = []

    # Accept either an uploaded-file object (path in .name) or a plain string.
    init_image = getattr(init_image, 'name', init_image) if init_image else ""
    target_images = getattr(target_images, 'name', target_images) if target_images else ""
    if target_images == "None" or not target_images:
        target_images = []
    else:
        target_images = [image.strip() for image in target_images.split("|")]

    texts = [phrase.strip() for phrase in text.split("|")]
    if texts == ['']:
        texts = []

    if texts:
        print('Using texts:', texts)
    if target_images:
        print('Using image prompts:', target_images)

    if seed is None or seed == -1:
        seed = torch.seed()
    else:
        seed = int(seed)
    torch.manual_seed(seed)
    print('Using seed:', seed)

    cut_size = perceptor.visual.input_resolution
    # Pixels per latent token along each spatial axis.
    f = 2**(model.decoder.num_resolutions - 1)
    make_cutouts = MakeCutouts(cut_size, cutn, cut_pow=cut_pow)
    toksX, toksY = width // f, height // f
    sideX, sideY = toksX * f, toksY * f

    # The OpenImages checkpoint exposes its codebook under different names.
    if args.vqgan_checkpoint == 'vqgan_openimages_f16_8192.ckpt':
        e_dim = 256
        n_toks = model.quantize.n_embed
        codebook = model.quantize.embed.weight
    else:
        e_dim = model.quantize.e_dim
        n_toks = model.quantize.n_e
        codebook = model.quantize.embedding.weight
    # Per-channel codebook range, used to clamp the latent after each step.
    z_min = codebook.min(dim=0).values[None, :, None, None]
    z_max = codebook.max(dim=0).values[None, :, None, None]

    if init_image:
        if init_image.startswith(('http://', 'https://')):
            img = Image.open(urlopen(init_image))
        else:
            img = Image.open(init_image)
        pil_image = img.convert('RGB').resize((sideX, sideY), Image.LANCZOS)
        pil_tensor = TF.to_tensor(pil_image)
        z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1)
    else:
        # The random-token latent only supplies shape/dtype/device: it is
        # replaced by uniform noise right away. The randint draw is kept so
        # the RNG stream for a given seed stays the same.
        one_hot = F.one_hot(torch.randint(n_toks, [toksY * toksX], device=device), n_toks).float()
        z = one_hot @ codebook
        z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
        z = torch.rand_like(z) * 2
    z_orig = z.clone()
    z.requires_grad_(True)
    opt = optim.Adam([z], lr=step_size)

    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])

    # One Prompt loss module per text prompt, image prompt and noise prompt.
    pMs = []
    for prompt in texts:
        txt, weight, stop = parse_prompt(prompt)
        embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
        pMs.append(Prompt(embed, weight, stop).to(device))

    for prompt in target_images:
        path, weight, stop = parse_prompt(prompt)
        img = resize_image(Image.open(path).convert('RGB'), (sideX, sideY))
        batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
        embed = perceptor.encode_image(normalize(batch)).float()
        pMs.append(Prompt(embed, weight, stop).to(device))

    for noise_seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):
        gen = torch.Generator().manual_seed(noise_seed)
        embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)
        pMs.append(Prompt(embed, weight).to(device))

    def synth(z):
        """Quantise the latent and decode it to an image with values in [0, 1]."""
        z_q = vector_quantize(z.movedim(1, 3), codebook).movedim(3, 1)
        return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1)

    @torch.no_grad()
    def checkin(i, losses):
        """Log the current losses and GPU utilisation."""
        losses_str = ', '.join(f'{loss.item():g}' for loss in losses)
        tqdm.write(f'i: {i}, loss: {sum(losses).item():g}, losses: {losses_str}')
        res = nvidia_smi.nvmlDeviceGetUtilizationRates(handle)
        print(f'gpu: {res.gpu}%, gpu-mem: {res.memory}%')

    def ascend_txt():
        """Decode the latent, score it against all prompts and record the frame."""
        out = synth(z)
        iii = perceptor.encode_image(normalize(make_cutouts(out))).float()

        result = []
        if init_weight:
            result.append(F.mse_loss(z, z_orig) * init_weight / 2)
        for prompt in pMs:
            result.append(prompt(iii))

        # CHW float image in [0, 1] -> HWC uint8 frame for the video.
        frame = out.mul(255).clamp(0, 255)[0].detach().cpu().numpy().astype(np.uint8)
        frame = np.transpose(frame, (1, 2, 0))
        img = Image.fromarray(frame).convert('RGB')
        all_frames.append(img)
        return result, np.array(img)

    def train(i):
        """Run one optimisation step and return the frame decoded before it."""
        opt.zero_grad()
        lossAll, image = ascend_txt()
        if i % args.display_freq == 0:
            checkin(i, lossAll)
        loss = sum(lossAll)
        loss.backward()
        opt.step()
        with torch.no_grad():
            # Keep the latent inside the range spanned by the codebook.
            z.copy_(z.maximum(z_min).minimum(z_max))
        return image

    # Defined up front so an interrupt before the first step completes still
    # reaches the return statement.
    image = None
    i = 0
    try:
        with tqdm() as pbar:
            while True:
                image = train(i)
                if i == max_iterations:
                    break
                i += 1
                pbar.update()
    except KeyboardInterrupt:
        pass

    writer = imageio.get_writer('test.mp4', fps=20)
    try:
        for frame in all_frames:
            writer.append_data(np.array(frame))
    finally:
        # Always release the video file, even if encoding a frame fails.
        writer.close()
    return image, 'test.mp4'
def load_image(infilename):
    """Read an image file and return its pixel data as an int32 NumPy array.

    The file is opened in a ``with`` block so the underlying handle is
    released as soon as the pixels have been copied out.
    """
    with Image.open(infilename) as img:
        img.load()
        data = np.asarray(img, dtype="int32")
    return data
# Texts shown on the demo page: heading, short description and credits footer.
title = "VQGAN + CLIP"
description = "Gradio demo for VQGAN + CLIP. To use it, simply add your text, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'>Originally made by Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). The original BigGAN+CLIP method was by https://twitter.com/advadnoun. Added some explanations and modifications by Eleiber#8347, pooling trick by Crimeacs#8222 (https://twitter.com/EarthML1) and the GUI was made with the help of Abulafia#3734. | <a href='https://colab.research.google.com/drive/1ZAus_gn2RhTZWzOWUpPERNC0Q8OhZRTZ'>Colab</a> | <a href='https://github.com/CompVis/taming-transformers'>Taming Transformers Github Repo</a> | <a href='https://github.com/openai/CLIP'>CLIP Github Repo</a> | Special thanks to BoneAmputee (https://twitter.com/BoneAmputee) for suggestions and advice</p>"
# Build and launch the web UI. The input widgets are listed in the same order
# as the parameters of inference(): text, seed, step_size, max_iterations,
# width, height, init_image, init_weight, target_images, cutn, cut_pow.
gr.Interface(
inference,
[gr.inputs.Textbox(label="Text Input"),
gr.inputs.Number(default=42, label="seed"),
gr.inputs.Slider(minimum=0.1, maximum=0.9, default=0.6, label='step size'),
gr.inputs.Slider(minimum=1, maximum=500, default=100, label='max iterations', step=1),
gr.inputs.Slider(minimum=200, maximum=600, default=256, label='width', step=1),
gr.inputs.Slider(minimum=200, maximum=600, default=256, label='height', step=1),
gr.inputs.Image(type="file", label="Initial Image (Optional)", optional=True),
gr.inputs.Slider(minimum=0.0, maximum=15.0, default=0.0, label='Initial Weight', step=1.0),
gr.inputs.Image(type="file", label="Target Image (Optional)", optional=True),
gr.inputs.Slider(minimum=1, maximum=40, default=1, label='cutn', step=1),
gr.inputs.Slider(minimum=1.0, maximum=40.0, default=1.0, label='cut_pow', step=1.0)
],
# Outputs match the (image, video path) pair returned by inference().
[gr.outputs.Image(type="numpy", label="Output Image"),gr.outputs.Video(label="Output Video")],
title=title,
description=description,
article=article,
# Each example row supplies one value per input widget, in widget order; the
# image files named here are downloaded at module start-up.
examples=[
['a garden by james gurney',42,0.6, 100, 256, 256, 'garden.jpeg', 0.0, 'garden.jpeg',1,1.0],
['coral reef city artstationHQ',1000,0.6, 110, 200, 200, 'coralreef.jpeg', 0.0, 'coralreef.jpeg',1,1.0],
['a cabin in the mountains unreal engine',98,0.6, 120, 280, 280, 'cabin.jpeg', 0.0, 'cabin.jpeg',1,1.0]
]
).launch(enable_queue=True)
|