sxela committed on
Commit
732b8c4
•
1 Parent(s): e2a7aa6
Files changed (3)
  1. app.py +227 -0
  2. packages.txt +1 -0
  3. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,227 @@
+ import os
+ import sys
+ import gradio as gr
+ os.system('git clone https://github.com/openai/CLIP')
+ os.system('git clone https://github.com/crowsonkb/guided-diffusion')
+ os.system('pip install -e ./CLIP')
+ os.system('pip install -e ./guided-diffusion')
+ os.system('pip install lpips')
+ os.system("curl -OL 'https://github.com/Sxela/DiscoDiffusion-Warp/releases/download/v0.1.0/256x256_openai_comics_faces_by_alex_spirin_084000.pt'")
+
+
+
+
+ import io
+ import math
+ import sys
+ import lpips
+ from PIL import Image
+ import requests
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from torchvision import transforms
+ from torchvision.transforms import functional as TF
+ from tqdm.notebook import tqdm
+ sys.path.append('./CLIP')
+ sys.path.append('./guided-diffusion')
+ import clip
+ from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
+ import numpy as np
+ import imageio
+
+ torch.hub.download_url_to_file('https://images.pexels.com/photos/68767/divers-underwater-ocean-swim-68767.jpeg', 'face.jpeg')
+
+ def fetch(url_or_path):
+     if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
+         r = requests.get(url_or_path)
+         r.raise_for_status()
+         fd = io.BytesIO()
+         fd.write(r.content)
+         fd.seek(0)
+         return fd
+     return open(url_or_path, 'rb')
+ def parse_prompt(prompt):
+     if prompt.startswith('http://') or prompt.startswith('https://'):
+         vals = prompt.rsplit(':', 2)
+         vals = [vals[0] + ':' + vals[1], *vals[2:]]
+     else:
+         vals = prompt.rsplit(':', 1)
+     vals = vals + ['', '1'][len(vals):]
+     return vals[0], float(vals[1])
+ class MakeCutouts(nn.Module):
+     def __init__(self, cut_size, cutn, cut_pow=1.):
+         super().__init__()
+         self.cut_size = cut_size
+         self.cutn = cutn
+         self.cut_pow = cut_pow
+     def forward(self, input):
+         sideY, sideX = input.shape[2:4]
+         max_size = min(sideX, sideY)
+         min_size = min(sideX, sideY, self.cut_size)
+         cutouts = []
+         for _ in range(self.cutn):
+             size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
+             offsetx = torch.randint(0, sideX - size + 1, ())
+             offsety = torch.randint(0, sideY - size + 1, ())
+             cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
+             cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
+         return torch.cat(cutouts)
+ def spherical_dist_loss(x, y):
+     x = F.normalize(x, dim=-1)
+     y = F.normalize(y, dim=-1)
+     return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
+ def tv_loss(input):
+     """L2 total variation loss, as in Mahendran et al."""
+     input = F.pad(input, (0, 1, 0, 1), 'replicate')
+     x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
+     y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
+     return (x_diff**2 + y_diff**2).mean([1, 2, 3])
+ def range_loss(input):
+     return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])
+
+ def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, range_scale, init_scale, seed, image_prompts, timestep_respacing, cutn):
+     # Model settings
+     model_config = model_and_diffusion_defaults()
+     model_config.update({
+         'attention_resolutions': '16',
+         'class_cond': False,
+         'diffusion_steps': 1000,
+         'rescale_timesteps': True,
+         'timestep_respacing': str(timestep_respacing),
+         'image_size': 256,
+         'learn_sigma': True,
+         'noise_schedule': 'linear',
+         'num_channels': 128,
+         'num_heads': 1,
+         'num_res_blocks': 2,
+         'use_checkpoint': True,
+         'use_fp16': True,
+         'use_scale_shift_norm': False,
+     })
+
+     # Load models
+     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+     print('Using device:', device)
+     model, diffusion = create_model_and_diffusion(**model_config)
+     model.load_state_dict(torch.load('256x256_openai_comics_faces_by_alex_spirin_084000.pt', map_location='cpu'))
+     model.requires_grad_(False).eval().to(device)
+     for name, param in model.named_parameters():
+         if 'qkv' in name or 'norm' in name or 'proj' in name:
+             param.requires_grad_()
+     if model_config['use_fp16']:
+         model.convert_to_fp16()
+     clip_model = clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device)
+     clip_size = clip_model.visual.input_resolution
+     normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
+                                      std=[0.26862954, 0.26130258, 0.27577711])
+     lpips_model = lpips.LPIPS(net='vgg').to(device)
+
+     #def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, range_scale, init_scale, seed, image_prompt):
+     all_frames = []
+     prompts = [text]
+     if image_prompts:
+         image_prompts = [image_prompts.name]
+     else:
+         image_prompts = []
+     batch_size = 1
+     clip_guidance_scale = clip_guidance_scale  # Controls how much the image should look like the prompt.
+     tv_scale = tv_scale  # Controls the smoothness of the final output.
+     range_scale = range_scale  # Controls how far out of range RGB values are allowed to be.
+     cutn = cutn
+     n_batches = 1
+     if init_image:
+         init_image = init_image.name
+     else:
+         init_image = None  # This can be an URL or Colab local path and must be in quotes.
+     skip_timesteps = skip_timesteps  # This needs to be between approx. 200 and 500 when using an init image.
+     # Higher values make the output look more like the init.
+     init_scale = init_scale  # This enhances the effect of the init image, a good value is 1000.
+     seed = seed
+
+     if seed is not None:
+         torch.manual_seed(seed)
+     make_cutouts = MakeCutouts(clip_size, cutn)
+     side_x = side_y = model_config['image_size']
+     target_embeds, weights = [], []
+     for prompt in prompts:
+         txt, weight = parse_prompt(prompt)
+         target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
+         weights.append(weight)
+     for prompt in image_prompts:
+         path, weight = parse_prompt(prompt)
+         img = Image.open(fetch(path)).convert('RGB')
+         img = TF.resize(img, min(side_x, side_y, *img.size), transforms.InterpolationMode.LANCZOS)
+         batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
+         embed = clip_model.encode_image(normalize(batch)).float()
+         target_embeds.append(embed)
+         weights.extend([weight / cutn] * cutn)
+     target_embeds = torch.cat(target_embeds)
+     weights = torch.tensor(weights, device=device)
+     if weights.sum().abs() < 1e-3:
+         raise RuntimeError('The weights must not sum to 0.')
+     weights /= weights.sum().abs()
+     init = None
+     if init_image is not None:
+         init = Image.open(fetch(init_image)).convert('RGB')
+         init = init.resize((side_x, side_y), Image.LANCZOS)
+         init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)
+     cur_t = None
+     def cond_fn(x, t, y=None):
+         with torch.enable_grad():
+             x = x.detach().requires_grad_()
+             n = x.shape[0]
+             my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t
+             out = diffusion.p_mean_variance(model, x, my_t, clip_denoised=False, model_kwargs={'y': y})
+             fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
+             x_in = out['pred_xstart'] * fac + x * (1 - fac)
+             clip_in = normalize(make_cutouts(x_in.add(1).div(2)))
+             image_embeds = clip_model.encode_image(clip_in).float()
+             dists = spherical_dist_loss(image_embeds.unsqueeze(1), target_embeds.unsqueeze(0))
+             dists = dists.view([cutn, n, -1])
+             losses = dists.mul(weights).sum(2).mean(0)
+             tv_losses = tv_loss(x_in)
+             range_losses = range_loss(out['pred_xstart'])
+             loss = losses.sum() * clip_guidance_scale + tv_losses.sum() * tv_scale + range_losses.sum() * range_scale
+             if init is not None and init_scale:
+                 init_losses = lpips_model(x_in, init)
+                 loss = loss + init_losses.sum() * init_scale
+             return -torch.autograd.grad(loss, x)[0]
+     if model_config['timestep_respacing'].startswith('ddim'):
+         sample_fn = diffusion.ddim_sample_loop_progressive
+     else:
+         sample_fn = diffusion.p_sample_loop_progressive
+     for i in range(n_batches):
+         cur_t = diffusion.num_timesteps - skip_timesteps - 1
+         samples = sample_fn(
+             model,
+             (batch_size, 3, side_y, side_x),
+             clip_denoised=False,
+             model_kwargs={},
+             cond_fn=cond_fn,
+             progress=True,
+             skip_timesteps=skip_timesteps,
+             init_image=init,
+             randomize_class=True,
+         )
+         for j, sample in enumerate(samples):
+             cur_t -= 1
+             if j % 1 == 0 or cur_t == -1:
+                 print()
+                 for k, image in enumerate(sample['pred_xstart']):
+                     #filename = f'progress_{i * batch_size + k:05}.png'
+                     img = TF.to_pil_image(image.add(1).div(2).clamp(0, 1))
+                     all_frames.append(img)
+                     tqdm.write(f'Batch {i}, step {j}, output {k}:')
+                     #display.display(display.Image(filename))
+     writer = imageio.get_writer('video.mp4', fps=5)
+     for im in all_frames:
+         writer.append_data(np.array(im))
+     writer.close()
+     return img, 'video.mp4'
+
+ title = "CLIP Guided Diffusion Faces Model"
+ description = "Gradio demo for CLIP Guided Diffusion. To use it, simply add your text, or click one of the examples to load them. Read more at the links below."
+ article = "<p style='text-align: center'> Comics faces model by <a href='https://linktr.ee/devdef'>Alex Spirin</a>. Based on the original <a href='https://huggingface.co/spaces/EleutherAI/clip-guided-diffusion'>CLIP Guided Diffusion Space</a> by akhaliq / Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings) | <a href='https://github.com/Sxela/DiscoDiffusion-Warp/blob/main/Disco_Diffusion_v5_2_Warp_custom_model.ipynb' target='_blank'>Colab</a></p>"
+ iface = gr.Interface(inference, inputs=["text",gr.inputs.Image(type="file", label='initial image (optional)', optional=True),gr.inputs.Slider(minimum=0, maximum=45, step=1, default=10, label="skip_timesteps"), gr.inputs.Slider(minimum=0, maximum=3000, step=1, default=600, label="clip guidance scale (Controls how much the image should look like the prompt)"), gr.inputs.Slider(minimum=0, maximum=1000, step=1, default=0, label="tv_scale (Controls the smoothness of the final output)"), gr.inputs.Slider(minimum=0, maximum=1000, step=1, default=0, label="range_scale (Controls how far out of range RGB values are allowed to be)"), gr.inputs.Slider(minimum=0, maximum=1000, step=1, default=0, label="init_scale (This enhances the effect of the init image)"), gr.inputs.Number(default=0, label="Seed"), gr.inputs.Image(type="file", label='image prompt (optional)', optional=True), gr.inputs.Slider(minimum=50, maximum=500, step=1, default=50, label="timestep respacing"),gr.inputs.Slider(minimum=1, maximum=64, step=1, default=32, label="cutn")], outputs=["image","video"], title=title, description=description, article=article, examples=[["Brad Pitt", "face.jpeg", 0, 1000, 150, 50, 0, 0, "face.jpeg", 90, 32]])
+ iface.launch()
packages.txt ADDED
@@ -0,0 +1 @@
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch
+ torchvision
+ kornia
+ tqdm
+ clip-anytorch
+ requests
+ lpips
+ numpy
+ imageio
+ imageio-ffmpeg