harlanhong commited on
Commit
3b98f27
1 Parent(s): 7504615
Files changed (1) hide show
  1. demo_dagan.py +5 -5
demo_dagan.py CHANGED
@@ -55,7 +55,7 @@ def find_best_frame(source, driving, cpu=False):
55
  return kp
56
 
57
  fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
58
- device='cpu')
59
  kp_source = fa.get_landmarks(255 * source)[0]
60
  kp_source = normalize_kp(kp_source)
61
  norm = float('inf')
@@ -126,7 +126,7 @@ generator = G.SPADEDepthAwareGenerator(**config['model_params']['generator_param
126
  config['model_params']['common_params']['num_channels'] = 4
127
  kp_detector = KPD.KPDetector(**config['model_params']['kp_detector_params'],**config['model_params']['common_params'])
128
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
129
-
130
 
131
  g_checkpoint = torch.load("generator.pt", map_location=device)
132
  kp_checkpoint = torch.load("kp_detector.pt", map_location=device)
@@ -182,12 +182,12 @@ with torch.inference_mode():
182
 
183
 
184
 
185
- i = find_best_frame(source_image, driving_video)
186
  print ("Best frame: " + str(i))
187
  driving_forward = driving_video[i:]
188
  driving_backward = driving_video[:(i+1)][::-1]
189
- sources_forward, drivings_forward, predictions_forward,depth_forward = make_animation(source_image, driving_forward, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False)
190
- sources_backward, drivings_backward, predictions_backward,depth_backward = make_animation(source_image, driving_backward, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False)
191
  predictions = predictions_backward[::-1] + predictions_forward[1:]
192
  sources = sources_backward[::-1] + sources_forward[1:]
193
  drivings = drivings_backward[::-1] + drivings_forward[1:]
 
55
  return kp
56
 
57
  fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
58
+ device='cpu' if cpu else 'cuda')
59
  kp_source = fa.get_landmarks(255 * source)[0]
60
  kp_source = normalize_kp(kp_source)
61
  norm = float('inf')
 
126
  config['model_params']['common_params']['num_channels'] = 4
127
  kp_detector = KPD.KPDetector(**config['model_params']['kp_detector_params'],**config['model_params']['common_params'])
128
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
129
+ cpu = False if torch.cuda.is_available() else True
130
 
131
  g_checkpoint = torch.load("generator.pt", map_location=device)
132
  kp_checkpoint = torch.load("kp_detector.pt", map_location=device)
 
182
 
183
 
184
 
185
+ i = find_best_frame(source_image, driving_video,cpu)
186
  print ("Best frame: " + str(i))
187
  driving_forward = driving_video[i:]
188
  driving_backward = driving_video[:(i+1)][::-1]
189
+ sources_forward, drivings_forward, predictions_forward,depth_forward = make_animation(source_image, driving_forward, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=cpu)
190
+ sources_backward, drivings_backward, predictions_backward,depth_backward = make_animation(source_image, driving_backward, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=cpu)
191
  predictions = predictions_backward[::-1] + predictions_forward[1:]
192
  sources = sources_backward[::-1] + sources_forward[1:]
193
  drivings = drivings_backward[::-1] + drivings_forward[1:]