Runtime error
Runtime error
File size: 11,592 Bytes
ca389f6 cde81bb 0958aff d61c863 cde81bb 06db207 cde81bb 06db207 3576c12 14e4fc4 9507462 593d5d0 06db207 cde81bb 06db207 cde81bb 06db207 cde81bb 3ab7435 c7307af 06db207 641acff f5dff55 cde81bb f5dff55 3ab7435 06db207 f5dff55 45c31f1 f5dff55 cde81bb f619e7d f829d0d 641acff 06db207 c6b75cf d30b2a2 06db207 f829d0d 06db207 cde81bb 06db207 cde81bb 06db207 1f3e0d6 06db207 f619e7d 3ab7435 075cc29 369e3bd bd66b34 1f3e0d6 c7307af 2c83e65 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 |
import os
import sys
# Runtime dependency setup done at import time. Each os.system call shells out
# and its exit status is ignored, so a failed install/clone is not reported here.
os.system("pip install gradio==2.8.0b5")
import gradio as gr
# NOTE(review): both clone commands have no repository URL in this copy, so they
# cannot succeed as written. The `pip install -e ./CLIP` and
# `pip install -e ./guided-diffusion` lines below suggest they were meant to
# create those two directories — TODO restore the URLs.
os.system('git clone')
os.system('git clone')
os.system('pip install -e ./CLIP')
os.system('pip install -e ./guided-diffusion')
os.system('pip install lpips')
# NOTE(review): empty URL. Presumably this fetched the diffusion checkpoint that
# inference() later hands to torch.load — TODO restore the URL.
os.system("curl -OL ''")
import io
import math  # NOTE(review): appears unused in the visible code.
import sys  # NOTE(review): duplicate of the `import sys` above.
import lpips
from PIL import Image
import requests
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
# NOTE(review): tqdm.notebook targets Jupyter; this app runs as a script and only
# uses tqdm.write — consider plain `tqdm` (TODO confirm).
from tqdm.notebook import tqdm
import clip
from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
import numpy as np
import imageio
# Downloads the image used as the Gradio example input ('coralreef.jpeg').
# NOTE(review): the source URL is empty in this copy — TODO restore.
torch.hub.download_url_to_file('', 'coralreef.jpeg')
def fetch(url_or_path):
    """Open a URL or local path for binary reading.

    HTTP(S) URLs are downloaded with ``requests`` and returned as an in-memory
    ``io.BytesIO`` positioned at offset 0. Anything else is treated as a local
    filesystem path and opened with ``open(..., 'rb')``.

    Args:
        url_or_path: A URL string or a filesystem path (any object whose
            ``str()`` is a URL or path, e.g. ``pathlib.Path``).

    Returns:
        A readable binary file object; the caller is responsible for closing it.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        OSError: If the local file cannot be opened.
    """
    if str(url_or_path).startswith(('http://', 'https://')):
        response = requests.get(url_or_path)
        # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
        response.raise_for_status()
        # BytesIO(initial_bytes) holds the payload and starts at position 0.
        return io.BytesIO(response.content)
    return open(url_or_path, 'rb')
def parse_prompt(prompt):
    """Split a ``"text:weight"`` prompt into its text and float weight.

    The weight is whatever follows the last ``:`` and defaults to 1 when there
    is none. For ``http://`` / ``https://`` prompts, the colon after the scheme
    belongs to the URL, so it is rejoined before the weight is read.

    Examples:
        >>> parse_prompt('a red apple:2')
        ('a red apple', 2.0)
        >>> parse_prompt('https://example.com/cat.png')
        ('https://example.com/cat.png', 1.0)

    Returns:
        ``(text, weight)`` with ``weight`` as a float.

    Raises:
        ValueError: If the part after the last colon is not a number.
    """
    if prompt.startswith(('http://', 'https://')):
        vals = prompt.rsplit(':', 2)
        # Glue the scheme back onto the rest of the URL.
        vals = [vals[0] + ':' + vals[1], *vals[2:]]
    else:
        vals = prompt.rsplit(':', 1)
    # Pad to two elements; a missing weight becomes '1'.
    vals = vals + ['', '1'][len(vals):]
    return vals[0], float(vals[1])
class MakeCutouts(nn.Module):
    """Take ``cutn`` random square crops of an image batch, each resized to ``cut_size``.

    Crop side lengths are drawn between ``min(H, W, cut_size)`` and
    ``min(H, W)`` as ``rand() ** cut_pow`` scaled into that range. Because
    ``rand()`` lies in [0, 1), ``cut_pow`` > 1 biases crops toward the small end
    and ``cut_pow`` < 1 toward the large end.

    Args:
        cut_size: Side length (int) of every output cutout.
        cutn: Number of cutouts taken per call.
        cut_pow: Exponent shaping the crop-size distribution.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        # nn.Module.__init__ must run first: it sets up the internal state that
        # Module.__call__ relies on when the module is invoked.
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def forward(self, input):
        """Return ``cutn`` random cutouts of ``input`` concatenated along dim 0.

        Args:
            input: Image batch indexed as ``(N, C, H, W)``.

        Returns:
            Tensor of shape ``(cutn * N, C, cut_size, cut_size)``. Rows are
            cutout-major: rows ``[i*N, (i+1)*N)`` hold cutout ``i`` of every image.
        """
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
        return torch.cat(cutouts)
def spherical_dist_loss(x, y):
    """Squared spherical distance between x and y after unit-normalizing both.

    Both inputs are L2-normalized along the last dimension. The chord length
    ``c = |x - y|`` between the unit vectors maps to the angle
    ``theta = 2 * arcsin(c / 2)``. The value returned is
    ``2 * arcsin(c / 2) ** 2``, i.e. ``theta ** 2 / 2``. Inputs broadcast like
    ordinary tensor subtraction, and only the last dimension is reduced.
    """
    unit_x = F.normalize(x, dim=-1)
    unit_y = F.normalize(y, dim=-1)
    half_chord = (unit_x - unit_y).norm(dim=-1) / 2
    return 2 * half_chord.arcsin() ** 2
def tv_loss(input):
    """L2 total variation loss, as in Mahendran et al.

    The last row and column are replicated so that every pixel has a right and
    a lower neighbour. The squared horizontal and vertical forward differences
    are then averaged over every dimension except dim 0, giving one loss value
    per batch element.
    """
    padded = F.pad(input, (0, 1, 0, 1), 'replicate')
    base = padded[..., :-1, :-1]
    dx = padded[..., :-1, 1:] - base
    dy = padded[..., 1:, :-1] - base
    return (dx ** 2 + dy ** 2).mean([1, 2, 3])
def range_loss(input):
    """Mean squared overshoot of values outside [-1, 1], one value per batch element.

    Values inside the range contribute 0. Values outside contribute the squared
    distance to the nearest bound. The mean is taken over dims 1-3.
    """
    excess = input - input.clamp(-1, 1)
    return (excess ** 2).mean([1, 2, 3])
def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, range_scale, init_scale, seed, image_prompts, timestep_respacing, cutn, checkpoint_path='256x256_diffusion_uncond.pt'):
    """Generate an image with CLIP-guided diffusion and record its progress as a video.

    Builds a 256x256 unconditional (``class_cond=False``) guided-diffusion model
    and CLIP ViT-B/16. It then samples while steering every step with the
    gradient of a weighted sum of four terms:
    - CLIP spherical distance to the prompts,
    - total variation,
    - a penalty on values outside [-1, 1],
    - optionally, LPIPS distance to an init image.

    Args:
        text: Text prompt; an optional ``":weight"`` suffix is parsed by
            ``parse_prompt``.
        init_image: Starting image, given either as the temp-file object produced
            by ``gr.inputs.Image(type="file")`` or as a path/URL string; falsy
            for none.
        skip_timesteps: Number of leading diffusion steps to skip, so sampling
            starts part-way from ``init_image``.
        clip_guidance_scale: Weight of the CLIP prompt loss.
        tv_scale: Weight of the total-variation (smoothness) loss.
        range_scale: Weight of the out-of-range penalty.
        init_scale: Weight of the LPIPS loss toward ``init_image``; 0 disables it.
        seed: Seed passed to ``torch.manual_seed``, or None to leave the RNG alone.
        image_prompts: Optional image prompt, in the same forms as ``init_image``.
        timestep_respacing: Sampler respacing spec, e.g. ``50`` or ``'ddim50'``.
            A ``'ddim'`` prefix selects DDIM sampling.
        cutn: Number of random cutouts fed to CLIP per guidance step.
        checkpoint_path: Diffusion model weights to load. NOTE(review): must
            match the file downloaded at import time — TODO confirm the name.

    Returns:
        ``(image, 'video.mp4')``: the final ``pred_xstart`` as a PIL image, and
        the path of an MP4 (5 fps) with one frame per sampling step.

    Raises:
        ValueError: If ``skip_timesteps`` leaves no sampling steps.
        RuntimeError: If the prompt weights sum to (nearly) zero.
    """
    # Gradio sliders/number boxes may deliver floats. These values are used
    # below as tensor indices and loop bounds, which require ints.
    skip_timesteps = int(skip_timesteps)
    cutn = int(cutn)
    respacing = str(timestep_respacing)
    if isinstance(timestep_respacing, float) and timestep_respacing.is_integer():
        # Avoid strings like '50.0'. guided-diffusion presumably parses
        # count-based respacing with int() — TODO confirm.
        respacing = str(int(timestep_respacing))

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    # Model settings
    model_config = model_and_diffusion_defaults()
    model_config.update({
        'attention_resolutions': '32, 16, 8',
        'class_cond': False,
        'diffusion_steps': 1000,
        'rescale_timesteps': True,
        'timestep_respacing': respacing,  # Lower this to decrease the number of timesteps.
        'image_size': 256,
        'learn_sigma': True,
        'noise_schedule': 'linear',
        'num_channels': 256,
        'num_head_channels': 64,
        'num_res_blocks': 2,
        'resblock_updown': True,
        # Half precision only on GPU; the CPU fallback may lack fp16 kernels.
        'use_fp16': device.type == 'cuda',
        'use_scale_shift_norm': True,
    })

    # Load models.
    # NOTE: models are rebuilt and reloaded on every call.
    model, diffusion = create_model_and_diffusion(**model_config)
    if not 0 <= skip_timesteps < diffusion.num_timesteps:
        raise ValueError(
            f'skip_timesteps must be in [0, {diffusion.num_timesteps}), got {skip_timesteps}')
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    model.requires_grad_(False).eval().to(device)
    for name, param in model.named_parameters():
        if 'qkv' in name or 'norm' in name or 'proj' in name:
            # Re-enable grads on these layers. Presumably this is needed so that
            # guided-diffusion's gradient checkpointing lets cond_fn backpropagate
            # to x — TODO confirm.
            param.requires_grad_()
    if model_config['use_fp16']:
        model.convert_to_fp16()

    clip_model = clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device)
    clip_size = clip_model.visual.input_resolution
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])
    lpips_model = lpips.LPIPS(net='vgg').to(device)

    prompts = [text]
    # gr.inputs.Image(type="file") passes a temp-file object whose .name is its
    # path; plain path/URL strings are used unchanged.
    image_prompts = [getattr(image_prompts, 'name', image_prompts)] if image_prompts else []
    init_image = getattr(init_image, 'name', init_image) if init_image else None
    batch_size = 1
    n_batches = 1

    if seed is not None:
        torch.manual_seed(int(seed))

    make_cutouts = MakeCutouts(clip_size, cutn)
    side_x = side_y = model_config['image_size']

    # Collect one target embedding per text prompt and cutn per image prompt,
    # with matching weights.
    target_embeds, weights = [], []
    for prompt in prompts:
        txt, weight = parse_prompt(prompt)
        target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
        weights.append(weight)
    for prompt in image_prompts:
        path, weight = parse_prompt(prompt)
        with fetch(path) as fd:
            img = Image.open(fd).convert('RGB')
        img = TF.resize(img, min(side_x, side_y, *img.size), transforms.InterpolationMode.LANCZOS)
        batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
        embed = clip_model.encode_image(normalize(batch)).float()
        target_embeds.append(embed)
        # Spread the prompt's weight evenly over its cutn cutout embeddings.
        weights.extend([weight / cutn] * cutn)
    target_embeds = torch.cat(target_embeds)
    weights = torch.tensor(weights, device=device)
    if weights.sum().abs() < 1e-3:
        raise RuntimeError('The weights must not sum to 0.')
    weights /= weights.sum().abs()

    init = None
    if init_image is not None:
        with fetch(init_image) as fd:
            init = Image.open(fd).convert('RGB')
        init = init.resize((side_x, side_y), Image.LANCZOS)
        # Map [0, 1] pixels to the [-1, 1] range used by the diffusion model.
        init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)

    # Index of the current (respaced) step. It is read by cond_fn and
    # decremented by the sampling loop below.
    cur_t = None

    def cond_fn(x, t, y=None):
        # `t` is unused; cur_t is used instead. Presumably this is because the
        # sampler may pass rescaled timesteps (rescale_timesteps=True) — confirm.
        with torch.enable_grad():
            x = x.detach().requires_grad_()
            n = x.shape[0]
            my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t
            out = diffusion.p_mean_variance(model, x, my_t, clip_denoised=False, model_kwargs={'y': y})
            fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
            x_in = out['pred_xstart'] * fac + x * (1 - fac)
            clip_in = normalize(make_cutouts(x_in.add(1).div(2)))
            image_embeds = clip_model.encode_image(clip_in).float()
            dists = spherical_dist_loss(image_embeds.unsqueeze(1), target_embeds.unsqueeze(0))
            # make_cutouts output is cutout-major, hence the (cutn, n, ...) view.
            dists = dists.view([cutn, n, -1])
            losses = dists.mul(weights).sum(2).mean(0)
            tv_losses = tv_loss(x_in)
            range_losses = range_loss(out['pred_xstart'])
            loss = losses.sum() * clip_guidance_scale + tv_losses.sum() * tv_scale + range_losses.sum() * range_scale
            if init is not None and init_scale:
                init_losses = lpips_model(x_in, init)
                loss = loss + init_losses.sum() * init_scale
            return -torch.autograd.grad(loss, x)[0]

    if model_config['timestep_respacing'].startswith('ddim'):
        sample_fn = diffusion.ddim_sample_loop_progressive
    else:
        sample_fn = diffusion.p_sample_loop_progressive

    all_frames = []
    img = None
    for i in range(n_batches):
        cur_t = diffusion.num_timesteps - skip_timesteps - 1
        # NOTE(review): skip_timesteps / init_image / randomize_class assume the
        # guided-diffusion checkout installed at import time accepts them
        # — TODO confirm.
        samples = sample_fn(
            model,
            (batch_size, 3, side_y, side_x),
            clip_denoised=False,
            model_kwargs={},
            cond_fn=cond_fn,
            progress=True,
            skip_timesteps=skip_timesteps,
            init_image=init,
            randomize_class=True,
        )
        for j, sample in enumerate(samples):
            cur_t -= 1
            # Record every step's prediction as a video frame.
            for k, image in enumerate(sample['pred_xstart']):
                img = TF.to_pil_image(image.add(1).div(2).clamp(0, 1))
                all_frames.append(img)
                tqdm.write(f'Batch {i}, step {j}, output {k}:')

    # NOTE: fixed output path; concurrent calls would overwrite each other's video.
    with imageio.get_writer('video.mp4', fps=5) as writer:
        for frame in all_frames:
            writer.append_data(np.array(frame))
    return img, 'video.mp4'
title = "CLIP Guided Diffusion HQ"
description = "Gradio demo for CLIP Guided Diffusion. To use it, simply add your text, or click one of the examples to load them. Read more at the links below."
# NOTE(review): the link targets in this HTML are missing (dangling "(" groups
# and href=''); restore the author / model / CLIP / Colab URLs.
article = "<p style='text-align: center'> By Katherine Crowson (, It uses OpenAI's 256x256 unconditional ImageNet diffusion model ( together with CLIP ( to connect text prompts with images. | <a href='' target='_blank'>Colab</a></p>"

# One input component per positional parameter of inference(), in order.
inference_inputs = [
    "text",
    gr.inputs.Image(type="file", label='initial image (optional)', optional=True),
    gr.inputs.Slider(minimum=0, maximum=45, step=1, default=10, label="skip_timesteps"),
    gr.inputs.Slider(minimum=0, maximum=3000, step=1, default=600, label="clip guidance scale (Controls how much the image should look like the prompt)"),
    gr.inputs.Slider(minimum=0, maximum=1000, step=1, default=0, label="tv_scale (Controls the smoothness of the final output)"),
    gr.inputs.Slider(minimum=0, maximum=1000, step=1, default=0, label="range_scale (Controls how far out of range RGB values are allowed to be)"),
    gr.inputs.Slider(minimum=0, maximum=1000, step=1, default=0, label="init_scale (This enhances the effect of the init image)"),
    gr.inputs.Number(default=0, label="Seed"),
    gr.inputs.Image(type="file", label='image prompt (optional)', optional=True),
    gr.inputs.Slider(minimum=50, maximum=500, step=1, default=50, label="timestep respacing"),
    gr.inputs.Slider(minimum=1, maximum=64, step=1, default=32, label="cutn"),
]
# Each example row lists values in the same order as inference_inputs.
inference_examples = [
    ["coral reef city by artistation artists", "coralreef.jpeg", 0, 1000, 150, 50, 0, 0, "coralreef.jpeg", 90, 32],
]

iface = gr.Interface(
    inference,
    inputs=inference_inputs,
    outputs=["image", "video"],
    title=title,
    description=description,
    article=article,
    examples=inference_examples,
)
iface.launch(enable_queue=True, cache_examples=True)