# Page header captured with the file (not part of the program):
# akhaliq's picture
# akhaliq HF staff
# cache and queue enabled in spaces by default
# ab18492
# raw history blame
# No virus
# 11.5 kB
import os
import sys
import gradio as gr
# Runtime setup, executed on import of this script:
#   1. clone the CLIP and guided-diffusion repositories into the working directory,
#   2. install both in editable mode, plus the lpips package,
#   3. download the 256x256 unconditional diffusion checkpoint that
#      inference() later loads from '256x256_diffusion_uncond.pt'.
# Presumably these run before the lpips / clip / guided_diffusion imports
# below so that those imports can succeed on a fresh machine.
# NOTE(review): os.system() return codes are ignored, so a failed clone,
# install or download only surfaces later as an import or load error.
os.system('git clone https://github.com/openai/CLIP')
os.system('git clone https://github.com/crowsonkb/guided-diffusion')
os.system('pip install -e ./CLIP')
os.system('pip install -e ./guided-diffusion')
os.system('pip install lpips')
os.system("curl -OL 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt'")
import io
import math
import sys
import lpips
from PIL import Image
import requests
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm
sys.path.append('./CLIP')
sys.path.append('./guided-diffusion')
import clip
from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
import numpy as np
import imageio
# Download the sample picture and store it as 'coralreef.jpeg'; that file name
# is what the Gradio example row uses for its initial image and image prompt.
torch.hub.download_url_to_file('https://images.pexels.com/photos/68767/divers-underwater-ocean-swim-68767.jpeg', 'coralreef.jpeg')
def fetch(url_or_path, timeout=30):
    """Open an image source for binary reading.

    Args:
        url_or_path: Either an ``http://`` / ``https://`` URL or a local
            filesystem path. The value is passed through ``str()`` only to
            decide which of the two it is.
        timeout: Seconds to wait on the HTTP request before giving up. Only
            used for URLs. Without a timeout a stalled server would block
            the caller indefinitely.

    Returns:
        A binary file-like object positioned at the start of the data: an
        in-memory ``io.BytesIO`` holding the response body for URLs, or the
        opened file for local paths. The caller is responsible for closing it.

    Raises:
        requests.HTTPError: The server answered with an error status.
        requests.Timeout: The server did not respond within ``timeout``.
        OSError: The local path could not be opened.
    """
    if str(url_or_path).startswith(('http://', 'https://')):
        response = requests.get(url_or_path, timeout=timeout)
        response.raise_for_status()
        # A BytesIO built from bytes already starts at offset 0.
        return io.BytesIO(response.content)
    return open(url_or_path, 'rb')
def parse_prompt(prompt):
    """Split a prompt of the form ``"<text>:<weight>"`` into text and weight.

    The weight is whatever follows the last colon, provided it parses as a
    float. If there is no colon, or the text after the last colon is not a
    number (a URL scheme or port, or an ordinary colon inside free text such
    as ``"a painting: sunset"``), the whole string is the prompt and the
    weight defaults to 1.

    Args:
        prompt: Text, local path or URL, optionally suffixed with
            ``:<weight>``.

    Returns:
        A ``(text, weight)`` tuple where ``weight`` is a float.
    """
    body, colon, suffix = prompt.rpartition(':')
    if colon:
        try:
            return body, float(suffix)
        except ValueError:
            # The last colon does not introduce a weight; fall through and
            # keep the prompt intact instead of failing the whole request.
            pass
    return prompt, 1.0
class MakeCutouts(nn.Module):
    """Produce ``cutn`` random square crops of an image batch.

    Each crop is resized to ``cut_size`` x ``cut_size`` by adaptive average
    pooling, and all crops are concatenated along the batch dimension.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        # cut_size: output side length of every crop.
        # cutn:     number of crops drawn per forward() call.
        # cut_pow:  exponent applied to the uniform sample that picks the
        #           crop size.
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def forward(self, input):
        """Return the pooled crops of ``input`` concatenated on dim 0.

        Dims 2 and 3 of ``input`` are treated as height and width.
        """
        height, width = input.shape[2:4]
        largest = min(width, height)
        smallest = min(largest, self.cut_size)
        span = largest - smallest

        pooled = []
        for _ in range(self.cutn):
            # Draw order (size, x offset, y offset) is kept fixed so a given
            # seed always yields the same crops.
            scale = torch.rand([]) ** self.cut_pow
            side = int(scale * span + smallest)
            left = torch.randint(0, width - side + 1, ())
            top = torch.randint(0, height - side + 1, ())
            crop = input[:, :, top:top + side, left:left + side]
            pooled.append(F.adaptive_avg_pool2d(crop, self.cut_size))
        return torch.cat(pooled)
def spherical_dist_loss(x, y):
    """Return ``2 * arcsin(|x̂ - ŷ| / 2) ** 2`` along the last dimension.

    ``x̂`` and ``ŷ`` are ``x`` and ``y`` L2-normalized over their last
    dimension; the inputs broadcast against each other.
    """
    unit_x = F.normalize(x, dim=-1)
    unit_y = F.normalize(y, dim=-1)
    half_chord = (unit_x - unit_y).norm(dim=-1).div(2)
    return half_chord.arcsin().pow(2).mul(2)
def tv_loss(input):
    """L2 total variation loss, as in Mahendran et al.

    The input is replicate-padded by one pixel on the right and bottom, so
    every pixel has a right and a lower neighbour. The squared horizontal and
    vertical differences are summed and averaged over dims 1, 2 and 3,
    leaving one value per batch element.
    """
    padded = F.pad(input, (0, 1, 0, 1), 'replicate')
    anchor = padded[..., :-1, :-1]
    horizontal = padded[..., :-1, 1:] - anchor
    vertical = padded[..., 1:, :-1] - anchor
    return (horizontal.pow(2) + vertical.pow(2)).mean([1, 2, 3])
def range_loss(input):
    """Penalize values outside ``[-1, 1]``.

    Returns the mean over dims 1, 2 and 3 of the squared distance between
    each value and its clamp to ``[-1, 1]`` (zero for in-range values).
    """
    overshoot = input - torch.clamp(input, -1, 1)
    return overshoot.pow(2).mean([1, 2, 3])
def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, range_scale, init_scale, seed, image_prompts, timestep_respacing, cutn):
    """Run CLIP-guided diffusion and return the final image plus a progress video.

    The diffusion model, CLIP and LPIPS are (re)loaded on every call.

    Args:
        text: Text prompt, optionally suffixed with ``:<weight>``.
        init_image: Falsy, or an uploaded-file object whose ``.name`` is the
            path of the image to start sampling from.
        skip_timesteps: Number of timesteps passed to the sampler as
            ``skip_timesteps``.
        clip_guidance_scale: Controls how much the image should look like
            the prompt.
        tv_scale: Controls the smoothness of the final output.
        range_scale: Controls how far out of range RGB values are allowed
            to be.
        init_scale: Weight of the LPIPS loss against the init image; only
            used when an init image is given and the value is non-zero.
        seed: Seed for ``torch.manual_seed``, or ``None`` to leave the RNG
            untouched.
        image_prompts: Falsy, or an uploaded-file object whose ``.name`` is
            the path of an image used as an additional prompt.
        timestep_respacing: Respacing spec for the diffusion schedule; a
            number of steps, or a string (values starting with ``'ddim'``
            select the DDIM sampler).
        cutn: Number of random cutouts fed to CLIP per guidance step.

    Returns:
        ``(image, 'video.mp4')``: the last predicted frame as a PIL image and
        the path of the video assembled from every intermediate frame.

    Raises:
        RuntimeError: The prompt weights sum to (almost) zero, or the sampler
            produced no frames.
    """
    # UI widgets can hand over whole numbers as floats (gr.inputs.Number
    # does); these values are later used in range(), view(), list repetition
    # and the respacing spec, all of which need real ints.
    skip_timesteps = int(skip_timesteps)
    cutn = int(cutn)
    if not isinstance(timestep_respacing, str):
        timestep_respacing = str(int(timestep_respacing))

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    # Model settings
    model_config = model_and_diffusion_defaults()
    model_config.update({
        'attention_resolutions': '32, 16, 8',
        'class_cond': False,
        'diffusion_steps': 1000,
        'rescale_timesteps': True,
        # Modify this value to decrease the number of timesteps.
        'timestep_respacing': timestep_respacing,
        'image_size': 256,
        'learn_sigma': True,
        'noise_schedule': 'linear',
        'num_channels': 256,
        'num_head_channels': 64,
        'num_res_blocks': 2,
        'resblock_updown': True,
        # Half precision only on the GPU; the CPU fallback stays in fp32.
        'use_fp16': device.type == 'cuda',
        'use_scale_shift_norm': True,
    })

    # Load models
    model, diffusion = create_model_and_diffusion(**model_config)
    model.load_state_dict(torch.load('256x256_diffusion_uncond.pt', map_location='cpu'))
    model.requires_grad_(False).eval().to(device)
    for name, param in model.named_parameters():
        if 'qkv' in name or 'norm' in name or 'proj' in name:
            param.requires_grad_()
    if model_config['use_fp16']:
        model.convert_to_fp16()
    clip_model = clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device)
    clip_size = clip_model.visual.input_resolution
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])
    lpips_model = lpips.LPIPS(net='vgg').to(device)

    # Normalize the optional uploads to plain paths.
    prompts = [text]
    image_prompts = [image_prompts.name] if image_prompts else []
    init_image = init_image.name if init_image else None
    batch_size = 1
    n_batches = 1

    if seed is not None:
        torch.manual_seed(int(seed))

    make_cutouts = MakeCutouts(clip_size, cutn)
    side_x = side_y = model_config['image_size']

    # Build the CLIP targets: one embedding per text prompt, `cutn`
    # embeddings per image prompt (its weight is spread over the cutouts).
    target_embeds, weights = [], []
    for prompt in prompts:
        txt, weight = parse_prompt(prompt)
        target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
        weights.append(weight)
    for prompt in image_prompts:
        path, weight = parse_prompt(prompt)
        # convert() reads the pixel data, so the handle can be closed at once.
        with fetch(path) as source:
            prompt_img = Image.open(source).convert('RGB')
        prompt_img = TF.resize(prompt_img, min(side_x, side_y, *prompt_img.size), transforms.InterpolationMode.LANCZOS)
        batch = make_cutouts(TF.to_tensor(prompt_img).unsqueeze(0).to(device))
        embed = clip_model.encode_image(normalize(batch)).float()
        target_embeds.append(embed)
        weights.extend([weight / cutn] * cutn)
    target_embeds = torch.cat(target_embeds)
    weights = torch.tensor(weights, device=device)
    if weights.sum().abs() < 1e-3:
        raise RuntimeError('The weights must not sum to 0.')
    weights /= weights.sum().abs()

    init = None
    if init_image is not None:
        with fetch(init_image) as source:
            init = Image.open(source).convert('RGB')
        init = init.resize((side_x, side_y), Image.LANCZOS)
        # Map [0, 1] pixel values to the [-1, 1] range the model works in.
        init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)

    # Current timestep index; updated by the sampling loop below and read by
    # cond_fn through the closure.
    cur_t = None

    def cond_fn(x, t, y=None):
        """Return the negative gradient of the guidance loss w.r.t. ``x``."""
        with torch.enable_grad():
            x = x.detach().requires_grad_()
            n = x.shape[0]
            my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t
            out = diffusion.p_mean_variance(model, x, my_t, clip_denoised=False, model_kwargs={'y': y})
            # Blend the predicted clean image with the current noisy sample.
            fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
            x_in = out['pred_xstart'] * fac + x * (1 - fac)
            clip_in = normalize(make_cutouts(x_in.add(1).div(2)))
            image_embeds = clip_model.encode_image(clip_in).float()
            dists = spherical_dist_loss(image_embeds.unsqueeze(1), target_embeds.unsqueeze(0))
            dists = dists.view([cutn, n, -1])
            losses = dists.mul(weights).sum(2).mean(0)
            tv_losses = tv_loss(x_in)
            range_losses = range_loss(out['pred_xstart'])
            loss = losses.sum() * clip_guidance_scale + tv_losses.sum() * tv_scale + range_losses.sum() * range_scale
            if init is not None and init_scale:
                init_losses = lpips_model(x_in, init)
                loss = loss + init_losses.sum() * init_scale
            return -torch.autograd.grad(loss, x)[0]

    if model_config['timestep_respacing'].startswith('ddim'):
        sample_fn = diffusion.ddim_sample_loop_progressive
    else:
        sample_fn = diffusion.p_sample_loop_progressive

    all_frames = []
    for i in range(n_batches):
        cur_t = diffusion.num_timesteps - skip_timesteps - 1
        samples = sample_fn(
            model,
            (batch_size, 3, side_y, side_x),
            clip_denoised=False,
            model_kwargs={},
            cond_fn=cond_fn,
            progress=True,
            skip_timesteps=skip_timesteps,
            init_image=init,
            randomize_class=True,
        )
        for j, sample in enumerate(samples):
            cur_t -= 1
            print()
            for k, image in enumerate(sample['pred_xstart']):
                # Back from [-1, 1] to a displayable [0, 1] image.
                all_frames.append(TF.to_pil_image(image.add(1).div(2).clamp(0, 1)))
                tqdm.write(f'Batch {i}, step {j}, output {k}:')

    if not all_frames:
        raise RuntimeError('The sampler produced no frames.')

    writer = imageio.get_writer('video.mp4', fps=5)
    try:
        for frame in all_frames:
            writer.append_data(np.array(frame))
    finally:
        # Always release the video file, even if encoding a frame fails.
        writer.close()
    return all_frames[-1], 'video.mp4'
# Texts shown on the demo page.
title = "CLIP Guided Diffusion HQ"
description = "Gradio demo for CLIP Guided Diffusion. To use it, simply add your text, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'> By Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). It uses OpenAI's 256x256 unconditional ImageNet diffusion model (https://github.com/openai/guided-diffusion) together with CLIP (https://github.com/openai/CLIP) to connect text prompts with images. | <a href='https://colab.research.google.com/drive/12a_Wrfi2_gwwAuN3VvMTwVMz9TfqctNj' target='_blank'>Colab</a></p>"

# The input widgets are listed in the same order as inference()'s parameters,
# and the single example row supplies one value per widget in that order.
iface = gr.Interface(
    inference,
    inputs=[
        "text",
        gr.inputs.Image(type="file", label='initial image (optional)', optional=True),
        gr.inputs.Slider(minimum=0, maximum=45, step=1, default=10, label="skip_timesteps"),
        gr.inputs.Slider(minimum=0, maximum=3000, step=1, default=600, label="clip guidance scale (Controls how much the image should look like the prompt)"),
        gr.inputs.Slider(minimum=0, maximum=1000, step=1, default=0, label="tv_scale (Controls the smoothness of the final output)"),
        gr.inputs.Slider(minimum=0, maximum=1000, step=1, default=0, label="range_scale (Controls how far out of range RGB values are allowed to be)"),
        gr.inputs.Slider(minimum=0, maximum=1000, step=1, default=0, label="init_scale (This enhances the effect of the init image)"),
        gr.inputs.Number(default=0, label="Seed"),
        gr.inputs.Image(type="file", label='image prompt (optional)', optional=True),
        gr.inputs.Slider(minimum=50, maximum=500, step=1, default=50, label="timestep respacing"),
        gr.inputs.Slider(minimum=1, maximum=64, step=1, default=32, label="cutn"),
    ],
    outputs=["image", "video"],
    title=title,
    description=description,
    article=article,
    examples=[
        ["coral reef city by artistation artists", "coralreef.jpeg", 0, 1000, 150, 50, 0, 0, "coralreef.jpeg", 90, 32],
    ],
)
iface.launch()