Ahsen Khaliq committed on
Commit 06db207
1 Parent(s): d33b0b8

Update app.py

Files changed (1)
  1. app.py +148 -59
app.py CHANGED
@@ -3,17 +3,20 @@ import sys
os.system('pip install gradio==2.3.0a0')
import gradio as gr
os.system('git clone https://github.com/openai/CLIP')
- os.system('git clone https://github.com/openai/guided-diffusion')
+ os.system('git clone https://github.com/crowsonkb/guided-diffusion')
os.system('pip install -e ./CLIP')
os.system('pip install -e ./guided-diffusion')
- os.system('pip install kornia')
- os.system("curl -OL 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt'")
+ os.system('pip install lpips')
+ os.system("!curl -OL 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt'")
# Imports
+ #import gc
+ import io
import math
import sys
#from IPython import display
+ import lpips
from PIL import Image
+ import requests
import torch
from torch import nn
from torch.nn import functional as F
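
Note: the curl step above re-downloads the roughly 2 GB 256x256 checkpoint on every start. A minimal guard (an assumption, not part of the commit) would skip the download when the file is already present:

    # Illustrative only: skip the download if the checkpoint already exists.
    if not os.path.exists('256x256_diffusion_uncond.pt'):
        os.system("curl -OL 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt'")
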
@@ -26,14 +29,64 @@ import clip
from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
import numpy as np
import imageio
+ # Define necessary functions
+ def fetch(url_or_path):
+     if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
+         r = requests.get(url_or_path)
+         r.raise_for_status()
+         fd = io.BytesIO()
+         fd.write(r.content)
+         fd.seek(0)
+         return fd
+     return open(url_or_path, 'rb')
+ def parse_prompt(prompt):
+     if prompt.startswith('http://') or prompt.startswith('https://'):
+         vals = prompt.rsplit(':', 2)
+         vals = [vals[0] + ':' + vals[1], *vals[2:]]
+     else:
+         vals = prompt.rsplit(':', 1)
+     vals = vals + ['', '1'][len(vals):]
+     return vals[0], float(vals[1])
+ class MakeCutouts(nn.Module):
+     def __init__(self, cut_size, cutn, cut_pow=1.):
+         super().__init__()
+         self.cut_size = cut_size
+         self.cutn = cutn
+         self.cut_pow = cut_pow
+     def forward(self, input):
+         sideY, sideX = input.shape[2:4]
+         max_size = min(sideX, sideY)
+         min_size = min(sideX, sideY, self.cut_size)
+         cutouts = []
+         for _ in range(self.cutn):
+             size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
+             offsetx = torch.randint(0, sideX - size + 1, ())
+             offsety = torch.randint(0, sideY - size + 1, ())
+             cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
+             cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
+         return torch.cat(cutouts)
+ def spherical_dist_loss(x, y):
+     x = F.normalize(x, dim=-1)
+     y = F.normalize(y, dim=-1)
+     return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
+ def tv_loss(input):
+     """L2 total variation loss, as in Mahendran et al."""
+     input = F.pad(input, (0, 1, 0, 1), 'replicate')
+     x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
+     y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
+     return (x_diff**2 + y_diff**2).mean([1, 2, 3])
+ def range_loss(input):
+     return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])
+
# Model settings
model_config = model_and_diffusion_defaults()
model_config.update({
    'attention_resolutions': '32, 16, 8',
    'class_cond': False,
    'diffusion_steps': 1000,
-     'rescale_timesteps': False,
-     'timestep_respacing': '500',
+     'rescale_timesteps': True,
+     'timestep_respacing': '1000',  # Modify this value to decrease the number of
+                                    # timesteps.
    'image_size': 256,
    'learn_sigma': True,
    'noise_schedule': 'linear',
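
Note: the parse_prompt helper added in this hunk splits an optional ':weight' suffix off a prompt and defaults the weight to 1, while leaving the colon of an http(s) URL intact. A quick usage sketch (the example strings are illustrative):

    parse_prompt('coral reef city')                # -> ('coral reef city', 1.0)
    parse_prompt('coral reef city:0.5')            # -> ('coral reef city', 0.5)
    parse_prompt('https://example.com/ref.png:2')  # -> ('https://example.com/ref.png', 2.0)
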
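
Note: MakeCutouts takes cutn random square crops of varying size and pools each one down to cut_size, so CLIP scores several views of the image at every step. A shape sketch (the dummy tensor and the 224-pixel ViT-B/16 input size are assumptions for illustration):

    dummy = torch.zeros([1, 3, 256, 256])
    cuts = MakeCutouts(cut_size=224, cutn=16)(dummy)
    print(cuts.shape)  # torch.Size([16, 3, 224, 224])
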
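
Note: spherical_dist_loss is half the squared great-circle distance between the normalized embeddings, so identical directions score 0 and orthogonal ones score (pi/2)**2 / 2; tv_loss and range_loss are the usual total-variation and out-of-range penalties. A sanity-check sketch with toy 2-D vectors (illustrative only, using the definitions above):

    a = F.normalize(torch.tensor([[1., 0.]]), dim=-1)
    b = F.normalize(torch.tensor([[0., 1.]]), dim=-1)
    print(spherical_dist_loss(a, a))  # tensor([0.])
    print(spherical_dist_loss(a, b))  # ~1.2337, i.e. (pi/2)**2 / 2
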
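
Note: timestep_respacing accepts either a plain step count or a 'ddimN' string, which is what the sample_fn selection inside inference() below checks for. Example values (illustrative, not from the commit; they would replace the '1000' in the settings block above, before the models are created):

    model_config.update({'timestep_respacing': '250'})      # 250 respaced DDPM steps
    # model_config.update({'timestep_respacing': 'ddim50'}) # 50 DDIM steps
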
@@ -44,76 +97,113 @@ model_config.update({
    'use_fp16': True,
    'use_scale_shift_norm': True,
})
- # Load models and define necessary functions
+ # Load models
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
model, diffusion = create_model_and_diffusion(**model_config)
model.load_state_dict(torch.load('256x256_diffusion_uncond.pt', map_location='cpu'))
- model.eval().requires_grad_(False).to(device)
+ model.requires_grad_(False).eval().to(device)
+ for name, param in model.named_parameters():
+     if 'qkv' in name or 'norm' in name or 'proj' in name:
+         param.requires_grad_()
if model_config['use_fp16']:
    model.convert_to_fp16()
clip_model = clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device)
clip_size = clip_model.visual.input_resolution
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])
- def spherical_dist_loss(x, y):
-     x = F.normalize(x, dim=-1)
-     y = F.normalize(y, dim=-1)
-     return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
+ lpips_model = lpips.LPIPS(net='vgg').to(device)

def inference(text):
    all_frames = []
-     prompt = text
+     prompts = [text]
+     image_prompts = []
    batch_size = 1
-     clip_guidance_scale = 2750
+     clip_guidance_scale = 1000  # Controls how much the image should look like the prompt.
+     tv_scale = 150              # Controls the smoothness of the final output.
+     range_scale = 50            # Controls how far out of range RGB values are allowed to be.
+     cutn = 16
+     n_batches = 1
+     init_image = None   # This can be an URL or Colab local path and must be in quotes.
+     skip_timesteps = 0  # This needs to be between approx. 200 and 500 when using an init image.
+                         # Higher values make the output look more like the init.
+     init_scale = 0      # This enhances the effect of the init image, a good value is 1000.
    seed = 0
-
+
    if seed is not None:
        torch.manual_seed(seed)
-
-     text_embed = clip_model.encode_text(clip.tokenize(prompt).to(device)).float()
-
-     translate_by = 8 / clip_size
-     if translate_by:
-         aug = augmentation.RandomAffine(0, (translate_by, translate_by),
-                                         padding_mode='border', p=1)
-     else:
-         aug = nn.Identity()
-
-     cur_t = diffusion.num_timesteps - 1
-
+     make_cutouts = MakeCutouts(clip_size, cutn)
+     side_x = side_y = model_config['image_size']
+     target_embeds, weights = [], []
+     for prompt in prompts:
+         txt, weight = parse_prompt(prompt)
+         target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
+         weights.append(weight)
+     for prompt in image_prompts:
+         path, weight = parse_prompt(prompt)
+         img = Image.open(fetch(path)).convert('RGB')
+         img = TF.resize(img, min(side_x, side_y, *img.size), transforms.InterpolationMode.LANCZOS)
+         batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
+         embed = clip_model.encode_image(normalize(batch)).float()
+         target_embeds.append(embed)
+         weights.extend([weight / cutn] * cutn)
+     target_embeds = torch.cat(target_embeds)
+     weights = torch.tensor(weights, device=device)
+     if weights.sum().abs() < 1e-3:
+         raise RuntimeError('The weights must not sum to 0.')
+     weights /= weights.sum().abs()
+     init = None
+     if init_image is not None:
+         init = Image.open(fetch(init_image)).convert('RGB')
+         init = init.resize((side_x, side_y), Image.LANCZOS)
+         init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)
+     cur_t = None
    def cond_fn(x, t, y=None):
        with torch.enable_grad():
-             x_in = x.detach().requires_grad_()
-             sigma = min(24, diffusion.sqrt_recipm1_alphas_cumprod[cur_t] / 4)
-             kernel_size = max(math.ceil((sigma * 6 + 1) / 2) * 2 - 1, 3)
-             x_blur = filters.gaussian_blur2d(x_in, (kernel_size, kernel_size), (sigma, sigma))
-             clip_in = F.interpolate(aug(x_blur.add(1).div(2)), (clip_size, clip_size),
-                                     mode='bilinear', align_corners=False)
-             image_embed = clip_model.encode_image(normalize(clip_in)).float()
-             losses = spherical_dist_loss(image_embed, text_embed)
-             grad = -torch.autograd.grad(losses.sum(), x_in)[0]
-             return grad * clip_guidance_scale
-
-     samples = diffusion.p_sample_loop_progressive(
-         model,
-         (batch_size, 3, model_config['image_size'], model_config['image_size']),
-         clip_denoised=True,
-         model_kwargs={},
-         cond_fn=cond_fn,
-         progress=True,
-     )
-
-     for i, sample in enumerate(samples):
-         cur_t -= 1
-         if i % 1 == 0 or cur_t == -1:
-             print()
-             for j, image in enumerate(sample['pred_xstart']):
-                 #filename = f'progress_{j:05}.png'
-                 img = TF.to_pil_image(image.add(1).div(2).clamp(0, 1))
-                 all_frames.append(img)
-                 tqdm.write(f'Step {i}, output {j}:')
-                 #display.display(display.Image(filename))
+             x = x.detach().requires_grad_()
+             n = x.shape[0]
+             my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t
+             out = diffusion.p_mean_variance(model, x, my_t, clip_denoised=False, model_kwargs={'y': y})
+             fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
+             x_in = out['pred_xstart'] * fac + x * (1 - fac)
+             clip_in = normalize(make_cutouts(x_in.add(1).div(2)))
+             image_embeds = clip_model.encode_image(clip_in).float()
+             dists = spherical_dist_loss(image_embeds.unsqueeze(1), target_embeds.unsqueeze(0))
+             dists = dists.view([cutn, n, -1])
+             losses = dists.mul(weights).sum(2).mean(0)
+             tv_losses = tv_loss(x_in)
+             range_losses = range_loss(out['pred_xstart'])
+             loss = losses.sum() * clip_guidance_scale + tv_losses.sum() * tv_scale + range_losses.sum() * range_scale
+             if init is not None and init_scale:
+                 init_losses = lpips_model(x_in, init)
+                 loss = loss + init_losses.sum() * init_scale
+             return -torch.autograd.grad(loss, x)[0]
+     if model_config['timestep_respacing'].startswith('ddim'):
+         sample_fn = diffusion.ddim_sample_loop_progressive
+     else:
+         sample_fn = diffusion.p_sample_loop_progressive
+     for i in range(n_batches):
+         cur_t = diffusion.num_timesteps - skip_timesteps - 1
+         samples = sample_fn(
+             model,
+             (batch_size, 3, side_y, side_x),
+             clip_denoised=False,
+             model_kwargs={},
+             cond_fn=cond_fn,
+             progress=True,
+             skip_timesteps=skip_timesteps,
+             init_image=init,
+             randomize_class=True,
+         )
+         for j, sample in enumerate(samples):
+             cur_t -= 1
+             if j % 1 == 0 or cur_t == -1:
+                 print()
+                 for k, image in enumerate(sample['pred_xstart']):
+                     #filename = f'progress_{i * batch_size + k:05}.png'
+                     img = TF.to_pil_image(image.add(1).div(2).clamp(0, 1))
+                     tqdm.write(f'Batch {i}, step {j}, output {k}:')
+                     #display.display(display.Image(filename))
    writer = imageio.get_writer('video.mp4', fps=20)
    for im in all_frames:
        writer.append_data(np.array(im))
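
Note: the new cond_fn scores a partially denoised image rather than the raw sample: x_in blends the model's pred_xstart with the noisy x using fac = sqrt(1 - alpha_bar_t), runs CLIP on cutouts of it, adds the TV, range and optional LPIPS penalties, and returns minus the gradient of the weighted sum with respect to x, which the sampler uses to shift the predicted mean. A toy analogue of that gradient step (a stand-in loss, not the real CLIP objective):

    x = torch.randn(1, 3, 8, 8, requires_grad=True)
    loss = (x - 0.5).pow(2).mean() * 1000     # pretend guidance loss with a scale factor
    grad = -torch.autograd.grad(loss, x)[0]   # what cond_fn hands back to the sampler
    print(grad.shape)                         # torch.Size([1, 3, 8, 8])
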
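
Note: the sampling loop passes skip_timesteps and init_image through to the progressive sampler, which the cloned crowsonkb/guided-diffusion fork accepts; together with init_scale they allow starting from an existing picture. Illustrative settings for the constants at the top of inference(), following the in-code comments (the URL is a placeholder):

    init_image = 'https://example.com/start.jpg'  # placeholder URL or local path
    skip_timesteps = 300                          # roughly 200-500 when an init image is used
    init_scale = 1000                             # strengthens the init via the LPIPS term
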
@@ -123,7 +213,6 @@ def inference(text):
title = "CLIP Guided Diffusion"
description = "Gradio demo for CLIP Guided Diffusion. To use it, simply add your text, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'>By Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). It uses OpenAI's 256x256 unconditional ImageNet diffusion model (https://github.com/openai/guided-diffusion) together with CLIP (https://github.com/openai/CLIP) to connect text prompts with images. | <a href='https://colab.research.google.com/drive/1ED6_MYVXTApBHzQObUPaaMolgf9hZOOF' target='_blank'>Colab</a></p>"
-
iface = gr.Interface(inference, inputs="text", outputs=["image","video"], title=title, description=description, article=article, examples=[["coral reef city by artistation artists"]],
                    enable_queue=True)
- iface.launch()
+ iface.launch()
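
Note: running the script launches the Gradio interface on the default local port (7860). A temporary public link would need share=True, which is an assumption here and not part of the commit:

    # python app.py              -> serves on http://127.0.0.1:7860 by default
    # iface.launch(share=True)   -> optional public share link (assumption, not in the commit)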