Phil Sarkis committed
Commit 65cbf6d
1 Parent(s): c7879ca
Files changed (1)
  1. app.py +30 -24
app.py CHANGED
@@ -70,10 +70,18 @@ def tv_loss(input):
     x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
     y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
     return (x_diff**2 + y_diff**2).mean([1, 2, 3])
+
+def l1_loss(input):
+    """L1 total variation loss, as in Mahendran et al."""
+    input = F.pad(input, (0, 1, 0, 1), 'replicate')
+    x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
+    y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
+    return (torch.abs(x_diff**1) + torch.abs(y_diff**1)).mean([1, 2, 3])
+
 def range_loss(input):
     return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])
 
-def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, range_scale, init_scale, seed, image_prompts,timestep_respacing, cutn):
+def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, l1_scale, range_scale, init_scale, seed, image_prompts,timestep_respacing, cutn):
     # Model settings
     model_config = model_and_diffusion_defaults()
     model_config.update({
@@ -120,6 +128,7 @@ def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, r
     batch_size = 1
     clip_guidance_scale = clip_guidance_scale # Controls how much the image should look like the prompt.
     tv_scale = tv_scale # Controls the smoothness of the final output.
+    l1_scale = l1_scale
     range_scale = range_scale # Controls how far out of range RGB values are allowed to be.
     cutn = cutn
     n_batches = 1
@@ -160,26 +169,24 @@ def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, r
         init = init.resize((side_x, side_y), Image.LANCZOS)
         init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)
     cur_t = None
-    def cond_fn(x, t, y=None):
-        with torch.enable_grad():
-            x = x.detach().requires_grad_()
-            n = x.shape[0]
-            my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t
-            out = diffusion.p_mean_variance(model, x, my_t, clip_denoised=False, model_kwargs={'y': y})
-            fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
-            x_in = out['pred_xstart'] * fac + x * (1 - fac)
-            clip_in = normalize(make_cutouts(x_in.add(1).div(2)))
-            image_embeds = clip_model.encode_image(clip_in).float()
-            dists = spherical_dist_loss(image_embeds.unsqueeze(1), target_embeds.unsqueeze(0))
-            dists = dists.view([cutn, n, -1])
-            losses = dists.mul(weights).sum(2).mean(0)
-            tv_losses = tv_loss(x_in)
-            range_losses = range_loss(out['pred_xstart'])
-            loss = losses.sum() * clip_guidance_scale + tv_losses.sum() * tv_scale + range_losses.sum() * range_scale
-            if init is not None and init_scale:
-                init_losses = lpips_model(x_in, init)
-                loss = loss + init_losses.sum() * init_scale
-            return -torch.autograd.grad(loss, x)[0]
+
+    def cond_fn(x, t, out, y=None):
+        n = x.shape[0]
+        fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
+        x_in = out['pred_xstart'] * fac + x * (1 - fac)
+        clip_in = normalize(make_cutouts(x_in.add(1).div(2)))
+        image_embeds = clip_model.encode_image(clip_in).float()
+        dists = spherical_dist_loss(image_embeds.unsqueeze(1), target_embeds.unsqueeze(0))
+        dists = dists.view([cutn, n, -1])
+        losses = dists.mul(weights).sum(2).mean(0)
+        tv_losses = tv_loss(x_in)
+        range_losses = range_loss(out['pred_xstart'])
+        l1_losses = l1_loss(x_in)
+        loss = losses.sum() * clip_guidance_scale + tv_losses.sum() * tv_scale + range_losses.sum() * range_scale + l1_losses.sum() * l1_scale
+        if init is not None and init_scale:
+            init_losses = lpips_model(x_in, init)
+            loss = loss + init_losses.sum() * init_scale
+        return -torch.autograd.grad(loss, x)[0]
     if model_config['timestep_respacing'].startswith('ddim'):
         sample_fn = diffusion.ddim_sample_loop_progressive
     else:
@@ -212,11 +219,10 @@ def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, r
         writer.append_data(np.array(im))
     writer.close()
     return img, 'video.mp4'
-
-
+
 title = "CLIP Guided Diffusion HQ"
 description = "Gradio demo for CLIP Guided Diffusion. To use it, simply add your text, or click one of the examples to load them. Read more at the links below."
 article = "<p style='text-align: center'> By Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). It uses OpenAI's 256x256 unconditional ImageNet diffusion model (https://github.com/openai/guided-diffusion) together with CLIP (https://github.com/openai/CLIP) to connect text prompts with images. | <a href='https://colab.research.google.com/drive/12a_Wrfi2_gwwAuN3VvMTwVMz9TfqctNj' target='_blank'>Colab</a></p>"
-iface = gr.Interface(inference, inputs=["text",gr.inputs.Image(type="file", label='initial image (optional)', optional=True),gr.inputs.Slider(minimum=0, maximum=45, step=1, default=10, label="skip_timesteps"), gr.inputs.Slider(minimum=0, maximum=3000, step=1, default=600, label="clip guidance scale (Controls how much the image should look like the prompt)"), gr.inputs.Slider(minimum=0, maximum=1000, step=1, default=0, label="tv_scale (Controls the smoothness of the final output)"), gr.inputs.Slider(minimum=0, maximum=1000, step=1, default=0, label="range_scale (Controls how far out of range RGB values are allowed to be)"), gr.inputs.Slider(minimum=0, maximum=1000, step=1, default=0, label="init_scale (This enhances the effect of the init image)"), gr.inputs.Number(default=0, label="Seed"), gr.inputs.Image(type="file", label='image prompt (optional)', optional=True), gr.inputs.Slider(minimum=50, maximum=500, step=1, default=50, label="timestep respacing"),gr.inputs.Slider(minimum=1, maximum=64, step=1, default=32, label="cutn")], outputs=["image","video"], title=title, description=description, article=article, examples=[["coral reef city by artistation artists", None, 0, 1000, 150, 50, 0, 0, None, 90, 32]],
+iface = gr.Interface(inference, inputs=["text",gr.inputs.Image(type="file", label='initial image (optional)', optional=True),gr.inputs.Slider(minimum=0, maximum=500, step=1, default=10, label="skip_timesteps"), gr.inputs.Slider(minimum=0, maximum=3000, step=1, default=600, label="clip guidance scale (Controls how much the image should look like the prompt)"), gr.inputs.Slider(minimum=0, maximum=1000, step=1, default=0, label="tv_scale (Controls the smoothness of the final output)"),gr.inputs.Slider(minimum=0, maximum=500, step=1, default=0, label="l1_scale (How much to punish for straying from init_image)"), gr.inputs.Slider(minimum=0, maximum=1000, step=1, default=0, label="range_scale (Controls how far out of range RGB values are allowed to be)"), gr.inputs.Slider(minimum=0, maximum=1000, step=1, default=0, label="init_scale (This enhances the effect of the init image)"), gr.inputs.Number(default=0, label="Seed"), gr.inputs.Image(type="file", label='image prompt (optional)', optional=True), gr.inputs.Slider(minimum=50, maximum=500, step=1, default=50, label="timestep respacing"),gr.inputs.Slider(minimum=1, maximum=64, step=1, default=32, label="cutn")], outputs=["image","video"], title=title, description=description, article=article, examples=[["coral reef city by artistation artists", None, 0, 1000, 150, 50, 0, 0, None, 90, 32]],
     enable_queue=True)
 iface.launch()
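
For reference, here is a minimal standalone sketch of the L1 total-variation penalty this commit introduces and of how it is weighted before being added to the guidance loss in cond_fn. The snippet is illustrative only and is not part of app.py; the tensor shape and the l1_scale value are assumptions chosen just to make it runnable.

# Illustrative sketch, not part of the commit: the L1 TV term added by this change.
import torch
import torch.nn.functional as F

def l1_loss(input):
    """L1 total variation loss, as in Mahendran et al. (added in this commit)."""
    input = F.pad(input, (0, 1, 0, 1), 'replicate')
    x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
    y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
    return (torch.abs(x_diff) + torch.abs(y_diff)).mean([1, 2, 3])

x_in = torch.rand(1, 3, 256, 256).mul(2).sub(1)  # assumed: one RGB image in [-1, 1], matching what cond_fn feeds in
l1_scale = 50                                    # assumed slider value; the interface exposes 0-500
l1_term = l1_loss(x_in).sum() * l1_scale         # this product is what gets summed into the guidance loss
print(float(l1_term))

Inside cond_fn the same product (l1_losses.sum() * l1_scale) is added alongside the CLIP, tv_scale, and range_scale terms, so leaving the new l1_scale slider at 0 reproduces the previous behaviour.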