jhj0517
commited on
Commit
•
7962f8d
1
Parent(s):
7a623dc
Refactor
Browse files
modules/live_portrait/live_portrait_inferencer.py
CHANGED
@@ -55,6 +55,11 @@ class LivePortraitInferencer:
|
|
55 |
self.d_info = None
|
56 |
|
57 |
def load_models(self):
|
|
|
|
|
|
|
|
|
|
|
58 |
self.download_if_no_models()
|
59 |
|
60 |
appearance_feat_config = self.model_config["appearance_feature_extractor_params"]
|
@@ -85,11 +90,6 @@ class LivePortraitInferencer:
|
|
85 |
os.path.join(self.model_dir, "spade_generator.safetensors")
|
86 |
)
|
87 |
|
88 |
-
def filter_stitcher(checkpoint, prefix):
|
89 |
-
filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if
|
90 |
-
key.startswith(prefix)}
|
91 |
-
return filtered_checkpoint
|
92 |
-
|
93 |
stitcher_config = self.model_config["stitching_retargeting_module_params"]
|
94 |
self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching'))
|
95 |
stitcher_model_path = os.path.join(self.model_dir, "stitching_retargeting_module.safetensors")
|
|
|
55 |
self.d_info = None
|
56 |
|
57 |
def load_models(self):
|
58 |
+
def filter_stitcher(checkpoint, prefix):
|
59 |
+
filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if
|
60 |
+
key.startswith(prefix)}
|
61 |
+
return filtered_checkpoint
|
62 |
+
|
63 |
self.download_if_no_models()
|
64 |
|
65 |
appearance_feat_config = self.model_config["appearance_feature_extractor_params"]
|
|
|
90 |
os.path.join(self.model_dir, "spade_generator.safetensors")
|
91 |
)
|
92 |
|
|
|
|
|
|
|
|
|
|
|
93 |
stitcher_config = self.model_config["stitching_retargeting_module_params"]
|
94 |
self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching'))
|
95 |
stitcher_model_path = os.path.join(self.model_dir, "stitching_retargeting_module.safetensors")
|