myniu commited on
Commit
bf1ebc4
1 Parent(s): 2935911
Files changed (1) hide show
  1. app.py +15 -17
app.py CHANGED
@@ -221,12 +221,10 @@ class Drag:
221
  svd_ckpt = "ckpts/stable-video-diffusion-img2vid-xt-1-1"
222
  mofa_ckpt = "ckpts/controlnet"
223
 
224
- self.weight_dtype = torch.float16
225
-
226
  self.pipeline, self.cmp = init_models(
227
  svd_ckpt,
228
  mofa_ckpt,
229
- weight_dtype=self.weight_dtype,
230
  device='cuda'
231
  )
232
 
@@ -302,12 +300,12 @@ class Drag:
302
 
303
  print('start diffusion process...')
304
 
305
- input_drag_384_inmask = input_drag_384_inmask.to('cuda', dtype=self.weight_dtype)
306
- mask_384_inmask = mask_384_inmask.to('cuda', dtype=self.weight_dtype)
307
- input_drag_384_outmask = input_drag_384_outmask.to('cuda', dtype=self.weight_dtype)
308
- mask_384_outmask = mask_384_outmask.to('cuda', dtype=self.weight_dtype)
309
 
310
- input_first_frame_384 = input_first_frame_384.to('cuda', dtype=self.weight_dtype)
311
 
312
  if in_mask_flag:
313
  flow_inmask = self.get_flow(
@@ -316,7 +314,7 @@ class Drag:
316
  )
317
  else:
318
  fb, fl = mask_384_inmask.shape[:2]
319
- flow_inmask = torch.zeros(fb, fl, 2, self.height, self.width).to('cuda', dtype=self.weight_dtype)
320
 
321
  if out_mask_flag:
322
  flow_outmask = self.get_flow(
@@ -325,7 +323,7 @@ class Drag:
325
  )
326
  else:
327
  fb, fl = mask_384_outmask.shape[:2]
328
- flow_outmask = torch.zeros(fb, fl, 2, self.height, self.width).to('cuda', dtype=self.weight_dtype)
329
 
330
  inmask_no_zero = (flow_inmask != 0).all(dim=2)
331
  inmask_no_zero = inmask_no_zero.unsqueeze(2).expand_as(flow_inmask)
@@ -450,12 +448,12 @@ class Drag:
450
  input_drag_384_outmask = input_drag_384_outmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384]
451
  mask_384_outmask = input_mask_384_outmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384]
452
 
453
- input_drag_384_inmask = input_drag_384_inmask.to('cuda', dtype=self.weight_dtype)
454
- mask_384_inmask = mask_384_inmask.to('cuda', dtype=self.weight_dtype)
455
- input_drag_384_outmask = input_drag_384_outmask.to('cuda', dtype=self.weight_dtype)
456
- mask_384_outmask = mask_384_outmask.to('cuda', dtype=self.weight_dtype)
457
 
458
- input_first_frame_384 = input_first_frame_384.to('cuda', dtype=self.weight_dtype)
459
 
460
  if in_mask_flag:
461
  flow_inmask = self.get_flow(
@@ -464,7 +462,7 @@ class Drag:
464
  )
465
  else:
466
  fb, fl = mask_384_inmask.shape[:2]
467
- flow_inmask = torch.zeros(fb, fl, 2, self.height, self.width).to('cuda', dtype=self.weight_dtype)
468
 
469
  if out_mask_flag:
470
  flow_outmask = self.get_flow(
@@ -473,7 +471,7 @@ class Drag:
473
  )
474
  else:
475
  fb, fl = mask_384_outmask.shape[:2]
476
- flow_outmask = torch.zeros(fb, fl, 2, self.height, self.width).to('cuda', dtype=self.weight_dtype)
477
 
478
  inmask_no_zero = (flow_inmask != 0).all(dim=2)
479
  inmask_no_zero = inmask_no_zero.unsqueeze(2).expand_as(flow_inmask)
 
221
  svd_ckpt = "ckpts/stable-video-diffusion-img2vid-xt-1-1"
222
  mofa_ckpt = "ckpts/controlnet"
223
 
 
 
224
  self.pipeline, self.cmp = init_models(
225
  svd_ckpt,
226
  mofa_ckpt,
227
+ weight_dtype=torch.float16,
228
  device='cuda'
229
  )
230
 
 
300
 
301
  print('start diffusion process...')
302
 
303
+ input_drag_384_inmask = input_drag_384_inmask.to('cuda', dtype=torch.float16)
304
+ mask_384_inmask = mask_384_inmask.to('cuda', dtype=torch.float16)
305
+ input_drag_384_outmask = input_drag_384_outmask.to('cuda', dtype=torch.float16)
306
+ mask_384_outmask = mask_384_outmask.to('cuda', dtype=torch.float16)
307
 
308
+ input_first_frame_384 = input_first_frame_384.to('cuda', dtype=torch.float16)
309
 
310
  if in_mask_flag:
311
  flow_inmask = self.get_flow(
 
314
  )
315
  else:
316
  fb, fl = mask_384_inmask.shape[:2]
317
+ flow_inmask = torch.zeros(fb, fl, 2, self.height, self.width).to('cuda', dtype=torch.float16)
318
 
319
  if out_mask_flag:
320
  flow_outmask = self.get_flow(
 
323
  )
324
  else:
325
  fb, fl = mask_384_outmask.shape[:2]
326
+ flow_outmask = torch.zeros(fb, fl, 2, self.height, self.width).to('cuda', dtype=torch.float16)
327
 
328
  inmask_no_zero = (flow_inmask != 0).all(dim=2)
329
  inmask_no_zero = inmask_no_zero.unsqueeze(2).expand_as(flow_inmask)
 
448
  input_drag_384_outmask = input_drag_384_outmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384]
449
  mask_384_outmask = input_mask_384_outmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384]
450
 
451
+ input_drag_384_inmask = input_drag_384_inmask.to('cuda', dtype=torch.float16)
452
+ mask_384_inmask = mask_384_inmask.to('cuda', dtype=torch.float16)
453
+ input_drag_384_outmask = input_drag_384_outmask.to('cuda', dtype=torch.float16)
454
+ mask_384_outmask = mask_384_outmask.to('cuda', dtype=torch.float16)
455
 
456
+ input_first_frame_384 = input_first_frame_384.to('cuda', dtype=torch.float16)
457
 
458
  if in_mask_flag:
459
  flow_inmask = self.get_flow(
 
462
  )
463
  else:
464
  fb, fl = mask_384_inmask.shape[:2]
465
+ flow_inmask = torch.zeros(fb, fl, 2, self.height, self.width).to('cuda', dtype=torch.float16)
466
 
467
  if out_mask_flag:
468
  flow_outmask = self.get_flow(
 
471
  )
472
  else:
473
  fb, fl = mask_384_outmask.shape[:2]
474
+ flow_outmask = torch.zeros(fb, fl, 2, self.height, self.width).to('cuda', dtype=torch.float16)
475
 
476
  inmask_no_zero = (flow_inmask != 0).all(dim=2)
477
  inmask_no_zero = inmask_no_zero.unsqueeze(2).expand_as(flow_inmask)