jhj0517 commited on
Commit
7962f8d
1 Parent(s): 7a623dc
modules/live_portrait/live_portrait_inferencer.py CHANGED
@@ -55,6 +55,11 @@ class LivePortraitInferencer:
55
  self.d_info = None
56
 
57
  def load_models(self):
 
 
 
 
 
58
  self.download_if_no_models()
59
 
60
  appearance_feat_config = self.model_config["appearance_feature_extractor_params"]
@@ -85,11 +90,6 @@ class LivePortraitInferencer:
85
  os.path.join(self.model_dir, "spade_generator.safetensors")
86
  )
87
 
88
- def filter_stitcher(checkpoint, prefix):
89
- filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if
90
- key.startswith(prefix)}
91
- return filtered_checkpoint
92
-
93
  stitcher_config = self.model_config["stitching_retargeting_module_params"]
94
  self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching'))
95
  stitcher_model_path = os.path.join(self.model_dir, "stitching_retargeting_module.safetensors")
 
55
  self.d_info = None
56
 
57
  def load_models(self):
58
+ def filter_stitcher(checkpoint, prefix):
59
+ filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if
60
+ key.startswith(prefix)}
61
+ return filtered_checkpoint
62
+
63
  self.download_if_no_models()
64
 
65
  appearance_feat_config = self.model_config["appearance_feature_extractor_params"]
 
90
  os.path.join(self.model_dir, "spade_generator.safetensors")
91
  )
92
 
 
 
 
 
 
93
  stitcher_config = self.model_config["stitching_retargeting_module_params"]
94
  self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching'))
95
  stitcher_model_path = os.path.join(self.model_dir, "stitching_retargeting_module.safetensors")