Upload app.py
app.py
ADDED
import io
import math
import os
import sys

import gradio as gr

# Clone CLIP and the guided-diffusion fork, install them, and download the
# 256x256 unconditional ImageNet checkpoint at startup.
os.system('git clone https://github.com/openai/CLIP')
os.system('git clone https://github.com/crowsonkb/guided-diffusion')
os.system('pip install -e ./CLIP')
os.system('pip install -e ./guided-diffusion')
os.system('pip install lpips')
os.system('pip install imageio-ffmpeg')  # imageio needs the ffmpeg plugin to write mp4
os.system("curl -OL 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt'")

import lpips
from PIL import Image
import requests
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm import tqdm  # plain tqdm; tqdm.notebook requires a Jupyter frontend

sys.path.append('./CLIP')
sys.path.append('./guided-diffusion')
import clip
from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
import numpy as np
import imageio
def fetch(url_or_path):
    if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
        r = requests.get(url_or_path)
        r.raise_for_status()
        fd = io.BytesIO()
        fd.write(r.content)
        fd.seek(0)
        return fd
    return open(url_or_path, 'rb')
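# Illustrative behavior (example paths are assumptions): fetch('https://example.com/dog.jpg')
# downloads the file into an in-memory buffer, while fetch('input.png') simply opens the
# local file; either way the caller gets a binary file object suitable for PIL.Image.open.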
def parse_prompt(prompt):
    if prompt.startswith('http://') or prompt.startswith('https://'):
        vals = prompt.rsplit(':', 2)
        vals = [vals[0] + ':' + vals[1], *vals[2:]]
    else:
        vals = prompt.rsplit(':', 1)
    vals = vals + ['', '1'][len(vals):]
    return vals[0], float(vals[1])
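# Illustrative behavior (inputs are assumptions): parse_prompt('sunset over the sea:2')
# returns ('sunset over the sea', 2.0); parse_prompt('https://example.com/cat.png:0.5')
# keeps the URL intact and returns weight 0.5; a prompt with no ':<weight>' suffix
# defaults to weight 1.0.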
class MakeCutouts(nn.Module):
    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
        return torch.cat(cutouts)
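# Sketch of how the module is used below (parameter values are illustrative):
# MakeCutouts(cut_size=224, cutn=32) maps a (1, 3, H, W) image in [0, 1] to a
# (32, 3, 224, 224) batch of random square crops resized to CLIP's input resolution,
# so a single CLIP forward pass scores many views of the same image.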
def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
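# Note (added for clarity): after normalization this equals half the squared geodesic
# (great-circle) distance between the two unit vectors, a smooth way to compare CLIP
# embeddings that behaves like an angle-based distance.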
def tv_loss(input):
    """L2 total variation loss, as in Mahendran et al."""
    input = F.pad(input, (0, 1, 0, 1), 'replicate')
    x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
    y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
    return (x_diff**2 + y_diff**2).mean([1, 2, 3])


def l1_loss(input):
    """L1 total variation loss, as in Mahendran et al."""
    input = F.pad(input, (0, 1, 0, 1), 'replicate')
    x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
    y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
    return (x_diff.abs() + y_diff.abs()).mean([1, 2, 3])


def range_loss(input):
    return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])

def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, l1_scale, range_scale, init_scale, seed, image_prompts, timestep_respacing, cutn):
    # Model settings
    model_config = model_and_diffusion_defaults()
    model_config.update({
        'attention_resolutions': '32, 16, 8',
        'class_cond': False,
        'diffusion_steps': 1000,
        'rescale_timesteps': True,
        'timestep_respacing': str(timestep_respacing),  # Modify this value to decrease the number of timesteps.
        'image_size': 256,
        'learn_sigma': True,
        'noise_schedule': 'linear',
        'num_channels': 256,
        'num_head_channels': 64,
        'num_res_blocks': 2,
        'resblock_updown': True,
        'use_fp16': torch.cuda.is_available(),  # fp16 needs a GPU; fall back to fp32 on CPU
        'use_scale_shift_norm': True,
    })
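    # Note (added for clarity): guided-diffusion accepts respacing strings such as '25',
    # '50' or '250' to sample with fewer DDPM steps, and 'ddim25'/'ddim50' to sample with
    # DDIM; the slider below passes a plain step count.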
    # Load models
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)
    model, diffusion = create_model_and_diffusion(**model_config)
    model.load_state_dict(torch.load('256x256_diffusion_uncond.pt', map_location='cpu'))
    model.requires_grad_(False).eval().to(device)
    for name, param in model.named_parameters():
        if 'qkv' in name or 'norm' in name or 'proj' in name:
            param.requires_grad_()
    if model_config['use_fp16']:
        model.convert_to_fp16()
    clip_model = clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device)
    clip_size = clip_model.visual.input_resolution
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])
    lpips_model = lpips.LPIPS(net='vgg').to(device)
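    # Note (added for clarity): the LPIPS VGG metric is only used when an init image is
    # supplied; it penalizes perceptual drift away from that image during sampling.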
    all_frames = []
    prompts = [text]
    if image_prompts:
        image_prompts = [image_prompts.name]
    else:
        image_prompts = []
    batch_size = 1
    n_batches = 1
    # clip_guidance_scale controls how much the image should look like the prompt.
    # tv_scale controls the smoothness of the final output.
    # l1_scale weights the extra L1 total-variation penalty.
    # range_scale controls how far out of range RGB values are allowed to be.
    # skip_timesteps needs to be between approx. 200 and 500 when using an init image;
    # higher values make the output look more like the init.
    # init_scale enhances the effect of the init image; a good value is 1000.
    if init_image:
        init_image = init_image.name
    else:
        init_image = None
    if seed is not None:
        torch.manual_seed(seed)

    make_cutouts = MakeCutouts(clip_size, cutn)
    side_x = side_y = model_config['image_size']

    target_embeds, weights = [], []
    for prompt in prompts:
        txt, weight = parse_prompt(prompt)
        target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
        weights.append(weight)
    for prompt in image_prompts:
        path, weight = parse_prompt(prompt)
        img = Image.open(fetch(path)).convert('RGB')
        img = TF.resize(img, min(side_x, side_y, *img.size), transforms.InterpolationMode.LANCZOS)
        batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
        embed = clip_model.encode_image(normalize(batch)).float()
        target_embeds.append(embed)
        weights.extend([weight / cutn] * cutn)
    target_embeds = torch.cat(target_embeds)
    weights = torch.tensor(weights, device=device)
    if weights.sum().abs() < 1e-3:
        raise RuntimeError('The weights must not sum to 0.')
    weights /= weights.sum().abs()

    init = None
    if init_image is not None:
        init = Image.open(fetch(init_image)).convert('RGB')
        init = init.resize((side_x, side_y), Image.LANCZOS)
        init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)

    cur_t = None
    def cond_fn(x, t, out, y=None):
        n = x.shape[0]
        fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
        x_in = out['pred_xstart'] * fac + x * (1 - fac)
        clip_in = normalize(make_cutouts(x_in.add(1).div(2)))
        image_embeds = clip_model.encode_image(clip_in).float()
        dists = spherical_dist_loss(image_embeds.unsqueeze(1), target_embeds.unsqueeze(0))
        dists = dists.view([cutn, n, -1])
        losses = dists.mul(weights).sum(2).mean(0)
        tv_losses = tv_loss(x_in)
        range_losses = range_loss(out['pred_xstart'])
        l1_losses = l1_loss(x_in)
        loss = losses.sum() * clip_guidance_scale + tv_losses.sum() * tv_scale + range_losses.sum() * range_scale + l1_losses.sum() * l1_scale
        if init is not None and init_scale:
            init_losses = lpips_model(x_in, init)
            loss = loss + init_losses.sum() * init_scale
        return -torch.autograd.grad(loss, x)[0]
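    # Note (added for clarity): cond_fn returns the negative gradient of the combined
    # CLIP-similarity and regularization loss with respect to the noisy sample x; the
    # sampling loop in crowsonkb's guided-diffusion fork adds this (scaled) gradient to
    # the model's predicted mean at every step, steering the unconditional model toward
    # the prompt.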
    if model_config['timestep_respacing'].startswith('ddim'):
        sample_fn = diffusion.ddim_sample_loop_progressive
    else:
        sample_fn = diffusion.p_sample_loop_progressive

    for i in range(n_batches):
        cur_t = diffusion.num_timesteps - skip_timesteps - 1
        samples = sample_fn(
            model,
            (batch_size, 3, side_y, side_x),
            clip_denoised=False,
            model_kwargs={},
            cond_fn=cond_fn,
            progress=True,
            skip_timesteps=skip_timesteps,
            init_image=init,
            randomize_class=True,
        )
        for j, sample in enumerate(samples):
            cur_t -= 1
            if j % 1 == 0 or cur_t == -1:  # record every step
                print()
                for k, image in enumerate(sample['pred_xstart']):
                    img = TF.to_pil_image(image.add(1).div(2).clamp(0, 1))
                    all_frames.append(img)
                    tqdm.write(f'Batch {i}, step {j}, output {k}:')

    writer = imageio.get_writer('video.mp4', fps=5)
    for im in all_frames:
        writer.append_data(np.array(im))
    writer.close()
    return img, 'video.mp4'
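# Illustrative direct call (argument values are assumptions, not defaults from the demo):
#   final_image, video_path = inference('a watercolor lighthouse at dawn', None, 0, 600,
#                                       0, 0, 0, 0, 0, None, 50, 16)
# returns the last predicted frame as a PIL image plus the path of the saved mp4.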

title = "CLIP Guided Diffusion HQ"
description = "Gradio demo for CLIP Guided Diffusion. To use it, simply add your text, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'>By Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). It uses OpenAI's 256x256 unconditional ImageNet diffusion model (https://github.com/openai/guided-diffusion) together with CLIP (https://github.com/openai/CLIP) to connect text prompts with images. | <a href='https://colab.research.google.com/drive/12a_Wrfi2_gwwAuN3VvMTwVMz9TfqctNj' target='_blank'>Colab</a></p>"

iface = gr.Interface(
    inference,
    inputs=[
        "text",
        gr.inputs.Image(type="file", label='initial image (optional)', optional=True),
        gr.inputs.Slider(minimum=0, maximum=45, step=1, default=10, label="skip_timesteps"),
        gr.inputs.Slider(minimum=0, maximum=3000, step=1, default=600, label="clip guidance scale (Controls how much the image should look like the prompt)"),
        gr.inputs.Slider(minimum=0, maximum=1000, step=1, default=0, label="tv_scale (Controls the smoothness of the final output)"),
        gr.inputs.Slider(minimum=0, maximum=500, step=1, default=0, label="l1_scale (Additional L1 total-variation smoothing penalty)"),
        gr.inputs.Slider(minimum=0, maximum=1000, step=1, default=0, label="range_scale (Controls how far out of range RGB values are allowed to be)"),
        gr.inputs.Slider(minimum=0, maximum=1000, step=1, default=0, label="init_scale (This enhances the effect of the init image)"),
        gr.inputs.Number(default=0, label="Seed"),
        gr.inputs.Image(type="file", label='image prompt (optional)', optional=True),
        gr.inputs.Slider(minimum=50, maximum=500, step=1, default=50, label="timestep respacing"),
        gr.inputs.Slider(minimum=1, maximum=64, step=1, default=32, label="cutn"),
    ],
    outputs=["image", "video"],
    title=title,
    description=description,
    article=article,
    examples=[["coral reef city by artstation artists", None, 0, 1000, 150, 50, 0, 0, 0, None, 90, 32]],
    enable_queue=True)
iface.launch()