aaronb commited on
Commit
13bbe81
1 Parent(s): 0f4611e
Files changed (1) hide show
  1. drag_gan.py +2 -6
drag_gan.py CHANGED
@@ -168,7 +168,7 @@ def drag_gan(g_ema, latent: torch.Tensor, noise, F, handle_points, target_points
168
  points = []
169
  for i in range(x - d, x + d):
170
  for j in range(y - d, y + d):
171
- points.append(torch.tensor([i, j]).float().cuda())
172
  return points
173
 
174
  F0 = F.detach().clone()
@@ -213,11 +213,7 @@ def drag_gan(g_ema, latent: torch.Tensor, noise, F, handle_points, target_points
213
  miny = 1e9
214
  for qi in neighbor(int(handle_points[i][0]), int(handle_points[i][1]), r2):
215
  # f2 = F2[..., int(qi[0]), int(qi[1])]
216
- try:
217
- f2 = bilinear_interpolate_torch(F2, qi[0], qi[1])
218
- except:
219
- import ipdb
220
- ipdb.set_trace()
221
  v = torch.norm(f2 - f0, p=1)
222
  if v < minv:
223
  minv = v
168
  points = []
169
  for i in range(x - d, x + d):
170
  for j in range(y - d, y + d):
171
+ points.append(torch.tensor([i, j]).float().to(latent.device))
172
  return points
173
 
174
  F0 = F.detach().clone()
213
  miny = 1e9
214
  for qi in neighbor(int(handle_points[i][0]), int(handle_points[i][1]), r2):
215
  # f2 = F2[..., int(qi[0]), int(qi[1])]
216
+ f2 = bilinear_interpolate_torch(F2, qi[0], qi[1])
 
 
 
 
217
  v = torch.norm(f2 - f0, p=1)
218
  if v < minv:
219
  minv = v