myniu committed on
Commit 2935911
1 Parent(s): 69f142b
Files changed (1)
  1. app.py +28 -30
app.py CHANGED
@@ -216,20 +216,18 @@ def visualize_drag_v2(background_image_path, splited_tracks, width, height):
 
 class Drag:
     @spaces.GPU(duration=200)
-    def __init__(self, device, height, width):
-        self.device = device
+    def __init__(self, height, width):
 
         svd_ckpt = "ckpts/stable-video-diffusion-img2vid-xt-1-1"
         mofa_ckpt = "ckpts/controlnet"
 
-        self.device = 'cuda'
         self.weight_dtype = torch.float16
 
         self.pipeline, self.cmp = init_models(
             svd_ckpt,
             mofa_ckpt,
             weight_dtype=self.weight_dtype,
-            device=self.device
+            device='cuda'
         )
 
         self.height = height
@@ -304,12 +302,12 @@ class Drag:
 
         print('start diffusion process...')
 
-        input_drag_384_inmask = input_drag_384_inmask.to(self.device, dtype=self.weight_dtype)
-        mask_384_inmask = mask_384_inmask.to(self.device, dtype=self.weight_dtype)
-        input_drag_384_outmask = input_drag_384_outmask.to(self.device, dtype=self.weight_dtype)
-        mask_384_outmask = mask_384_outmask.to(self.device, dtype=self.weight_dtype)
+        input_drag_384_inmask = input_drag_384_inmask.to('cuda', dtype=self.weight_dtype)
+        mask_384_inmask = mask_384_inmask.to('cuda', dtype=self.weight_dtype)
+        input_drag_384_outmask = input_drag_384_outmask.to('cuda', dtype=self.weight_dtype)
+        mask_384_outmask = mask_384_outmask.to('cuda', dtype=self.weight_dtype)
 
-        input_first_frame_384 = input_first_frame_384.to(self.device, dtype=self.weight_dtype)
+        input_first_frame_384 = input_first_frame_384.to('cuda', dtype=self.weight_dtype)
 
         if in_mask_flag:
             flow_inmask = self.get_flow(
@@ -318,7 +316,7 @@ class Drag:
             )
         else:
             fb, fl = mask_384_inmask.shape[:2]
-            flow_inmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype)
+            flow_inmask = torch.zeros(fb, fl, 2, self.height, self.width).to('cuda', dtype=self.weight_dtype)
 
         if out_mask_flag:
             flow_outmask = self.get_flow(
@@ -327,7 +325,7 @@ class Drag:
             )
         else:
             fb, fl = mask_384_outmask.shape[:2]
-            flow_outmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype)
+            flow_outmask = torch.zeros(fb, fl, 2, self.height, self.width).to('cuda', dtype=self.weight_dtype)
 
         inmask_no_zero = (flow_inmask != 0).all(dim=2)
         inmask_no_zero = inmask_no_zero.unsqueeze(2).expand_as(flow_inmask)
@@ -426,10 +424,10 @@ class Drag:
                 np.zeros((25 - 1, 384, 384, 2)), \
                 np.zeros((25 - 1, 384, 384))
 
-        input_drag_384_inmask = torch.from_numpy(input_drag_384_inmask).unsqueeze(0).to(self.device) # [1, 13, h, w, 2]
-        input_mask_384_inmask = torch.from_numpy(input_mask_384_inmask).unsqueeze(0).to(self.device) # [1, 13, h, w]
-        input_drag_384_outmask = torch.from_numpy(input_drag_384_outmask).unsqueeze(0).to(self.device) # [1, 13, h, w, 2]
-        input_mask_384_outmask = torch.from_numpy(input_mask_384_outmask).unsqueeze(0).to(self.device) # [1, 13, h, w]
+        input_drag_384_inmask = torch.from_numpy(input_drag_384_inmask).unsqueeze(0).to('cuda') # [1, 13, h, w, 2]
+        input_mask_384_inmask = torch.from_numpy(input_mask_384_inmask).unsqueeze(0).to('cuda') # [1, 13, h, w]
+        input_drag_384_outmask = torch.from_numpy(input_drag_384_outmask).unsqueeze(0).to('cuda') # [1, 13, h, w, 2]
+        input_mask_384_outmask = torch.from_numpy(input_mask_384_outmask).unsqueeze(0).to('cuda') # [1, 13, h, w]
 
         first_frames_transform = transforms.Compose([
             lambda x: Image.fromarray(x),
@@ -437,7 +435,7 @@ class Drag:
         ])
 
         input_first_frame = image2arr(first_frame_path)
-        input_first_frame = repeat(first_frames_transform(input_first_frame), 'c h w -> b c h w', b=1).to(self.device)
+        input_first_frame = repeat(first_frames_transform(input_first_frame), 'c h w -> b c h w', b=1).to('cuda')
 
         seed = 42
         num_frames = 25
@@ -452,12 +450,12 @@ class Drag:
         input_drag_384_outmask = input_drag_384_outmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384]
         mask_384_outmask = input_mask_384_outmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384]
 
-        input_drag_384_inmask = input_drag_384_inmask.to(self.device, dtype=self.weight_dtype)
-        mask_384_inmask = mask_384_inmask.to(self.device, dtype=self.weight_dtype)
-        input_drag_384_outmask = input_drag_384_outmask.to(self.device, dtype=self.weight_dtype)
-        mask_384_outmask = mask_384_outmask.to(self.device, dtype=self.weight_dtype)
+        input_drag_384_inmask = input_drag_384_inmask.to('cuda', dtype=self.weight_dtype)
+        mask_384_inmask = mask_384_inmask.to('cuda', dtype=self.weight_dtype)
+        input_drag_384_outmask = input_drag_384_outmask.to('cuda', dtype=self.weight_dtype)
+        mask_384_outmask = mask_384_outmask.to('cuda', dtype=self.weight_dtype)
 
-        input_first_frame_384 = input_first_frame_384.to(self.device, dtype=self.weight_dtype)
+        input_first_frame_384 = input_first_frame_384.to('cuda', dtype=self.weight_dtype)
 
         if in_mask_flag:
             flow_inmask = self.get_flow(
@@ -466,7 +464,7 @@ class Drag:
             )
         else:
             fb, fl = mask_384_inmask.shape[:2]
-            flow_inmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype)
+            flow_inmask = torch.zeros(fb, fl, 2, self.height, self.width).to('cuda', dtype=self.weight_dtype)
 
         if out_mask_flag:
             flow_outmask = self.get_flow(
@@ -475,7 +473,7 @@ class Drag:
             )
         else:
             fb, fl = mask_384_outmask.shape[:2]
-            flow_outmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype)
+            flow_outmask = torch.zeros(fb, fl, 2, self.height, self.width).to('cuda', dtype=self.weight_dtype)
 
         inmask_no_zero = (flow_inmask != 0).all(dim=2)
         inmask_no_zero = inmask_no_zero.unsqueeze(2).expand_as(flow_inmask)
@@ -566,17 +564,17 @@ class Drag:
         for i in tqdm(range(num_inference)):
             if not outputs:
                 first_frames = image2arr(first_frame_path)
-                first_frames = repeat(first_frames_transform(first_frames), 'c h w -> b c h w', b=inference_batch_size).to(self.device)
+                first_frames = repeat(first_frames_transform(first_frames), 'c h w -> b c h w', b=inference_batch_size).to('cuda')
             else:
                 first_frames = outputs['logits_imgs'][:, -1]
 
 
             outputs = self.forward_sample(
-                input_drag_384_inmask.to(self.device),
-                input_drag_384_outmask.to(self.device),
-                first_frames.to(self.device),
-                input_mask_384_inmask.to(self.device),
-                input_mask_384_outmask.to(self.device),
+                input_drag_384_inmask.to('cuda'),
+                input_drag_384_outmask.to('cuda'),
+                first_frames.to('cuda'),
+                input_mask_384_inmask.to('cuda'),
+                input_mask_384_outmask.to('cuda'),
                 in_mask_flag,
                 out_mask_flag,
                 motion_brush_mask_384,
@@ -656,7 +654,7 @@ with gr.Blocks() as demo:
     )
 
     target_size = 512
-    DragNUWA_net = Drag("cuda:0", target_size, target_size)
+    DragNUWA_net = Drag(target_size, target_size)
     first_frame_path = gr.State()
     tracking_points = gr.State([])
     motion_brush_points = gr.State([])
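
Net effect of the commit: Drag.__init__ loses its device parameter and every self.device reference becomes a literal 'cuda'. A plausible motivation, not stated in the commit itself, is the ZeroGPU model of Hugging Face Spaces, where the @spaces.GPU decorator attaches a GPU only while the decorated call runs, so a device chosen at construction time is not meaningful. A minimal sketch of that pattern under this assumption follows; Runner and run are hypothetical stand-ins, not the app's real API.

# Minimal sketch of the device-handling pattern the commit adopts
# (assumes a Hugging Face ZeroGPU Space with the `spaces` package;
# `Runner` and `run` are hypothetical names, not the app's real API).
import spaces
import torch

class Runner:
    def __init__(self, height, width):
        # No `device` parameter and no stored `self.device`:
        # the literal 'cuda' is written at each point of use instead.
        self.weight_dtype = torch.float16
        self.height = height
        self.width = width

    @spaces.GPU(duration=200)  # a GPU is attached only while this call runs
    def run(self, x: torch.Tensor) -> torch.Tensor:
        # Move tensors to the GPU inside the decorated call,
        # where CUDA is guaranteed to be available.
        return x.to('cuda', dtype=self.weight_dtype)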