multimodalart HF staff commited on
Commit
9cd412c
1 Parent(s): cff8aa8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -87,9 +87,9 @@ def run_all(prompt, steps, n_images, weight, clip_guided):
87
  if min(pred.shape[2:4]) < 256:
88
  pred = F.interpolate(pred, scale_factor=2, mode='bilinear', align_corners=False)
89
  clip_in = normalize(make_cutouts((pred + 1) / 2))
90
- image_embeds = clip_model.encode_image(clip_in).view([args.cutn, x.shape[0], -1])
91
  losses = spherical_dist_loss(image_embeds, clip_embed[None])
92
- loss = losses.mean(0).sum() * args.clip_guidance_scale
93
  grad = -torch.autograd.grad(loss, x)[0]
94
  return grad
95
 
 
87
  if min(pred.shape[2:4]) < 256:
88
  pred = F.interpolate(pred, scale_factor=2, mode='bilinear', align_corners=False)
89
  clip_in = normalize(make_cutouts((pred + 1) / 2))
90
+ image_embeds = clip_model.encode_image(clip_in).view([16, x.shape[0], -1])
91
  losses = spherical_dist_loss(image_embeds, clip_embed[None])
92
+ loss = losses.mean(0).sum() * 500.
93
  grad = -torch.autograd.grad(loss, x)[0]
94
  return grad
95