Spaces:
Paused
Paused
Update src/facerender/animate.py
Browse files
src/facerender/animate.py
CHANGED
@@ -27,7 +27,7 @@ from src.utils.videoio import save_video_with_watermark
|
|
27 |
class AnimateFromCoeff():
|
28 |
|
29 |
def __init__(self, free_view_checkpoint, mapping_checkpoint,
|
30 |
-
config_path, device):
|
31 |
|
32 |
with open(config_path) as f:
|
33 |
config = yaml.safe_load(f)
|
@@ -88,7 +88,7 @@ class AnimateFromCoeff():
|
|
88 |
def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None,
|
89 |
kp_detector=None, he_estimator=None, optimizer_generator=None,
|
90 |
optimizer_discriminator=None, optimizer_kp_detector=None,
|
91 |
-
optimizer_he_estimator=None, device="
|
92 |
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
|
93 |
|
94 |
def adjust_state_dict(state_dict, model):
|
@@ -135,7 +135,7 @@ class AnimateFromCoeff():
|
|
135 |
return checkpoint['epoch']
|
136 |
|
137 |
def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
|
138 |
-
optimizer_mapping=None, optimizer_discriminator=None, device='
|
139 |
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
|
140 |
|
141 |
def adjust_state_dict(state_dict, model):
|
|
|
27 |
class AnimateFromCoeff():
|
28 |
|
29 |
def __init__(self, free_view_checkpoint, mapping_checkpoint,
|
30 |
+
config_path, device='cuda'):
|
31 |
|
32 |
with open(config_path) as f:
|
33 |
config = yaml.safe_load(f)
|
|
|
88 |
def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None,
|
89 |
kp_detector=None, he_estimator=None, optimizer_generator=None,
|
90 |
optimizer_discriminator=None, optimizer_kp_detector=None,
|
91 |
+
optimizer_he_estimator=None, device="cuda"):
|
92 |
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
|
93 |
|
94 |
def adjust_state_dict(state_dict, model):
|
|
|
135 |
return checkpoint['epoch']
|
136 |
|
137 |
def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
|
138 |
+
optimizer_mapping=None, optimizer_discriminator=None, device='cuda'):
|
139 |
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
|
140 |
|
141 |
def adjust_state_dict(state_dict, model):
|