jhj0517 commited on
Commit
71c08fe
1 Parent(s): c311090

Enable image restoration in expression editor

Browse files
modules/live_portrait/live_portrait_inferencer.py CHANGED
@@ -27,6 +27,7 @@ from modules.live_portrait.warping_network import WarpingNetwork
27
  from modules.live_portrait.motion_extractor import MotionExtractor
28
  from modules.live_portrait.appearance_feature_extractor import AppearanceFeatureExtractor
29
  from modules.live_portrait.stitching_retargeting_network import StitchingRetargetingNetwork
 
30
 
31
 
32
  class LivePortraitInferencer:
@@ -69,6 +70,11 @@ class LivePortraitInferencer:
69
  self.psi_list = None
70
  self.d_info = None
71
 
 
 
 
 
 
72
  def load_models(self,
73
  model_type: str = ModelType.HUMAN.value,
74
  progress=gr.Progress()):
@@ -161,6 +167,7 @@ class LivePortraitInferencer:
161
  sample_ratio: float = 1,
162
  sample_parts: str = SamplePart.ALL.value,
163
  crop_factor: float = 2.3,
 
164
  src_image: Optional[str] = None,
165
  sample_image: Optional[str] = None,) -> None:
166
  if isinstance(model_type, ModelType):
@@ -232,8 +239,12 @@ class LivePortraitInferencer:
232
  out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(np.uint8)
233
 
234
  temp_out_img_path, out_img_path = get_auto_incremental_file_path(TEMP_DIR, "png"), get_auto_incremental_file_path(OUTPUTS_DIR, "png")
235
- save_image(numpy_array=crop_out, output_path=temp_out_img_path)
236
- save_image(numpy_array=out, output_path=out_img_path)
 
 
 
 
237
 
238
  return out
239
  except Exception as e:
 
27
  from modules.live_portrait.motion_extractor import MotionExtractor
28
  from modules.live_portrait.appearance_feature_extractor import AppearanceFeatureExtractor
29
  from modules.live_portrait.stitching_retargeting_network import StitchingRetargetingNetwork
30
+ from modules.image_restoration.real_esrgan_inferencer import RealESRGANInferencer
31
 
32
 
33
  class LivePortraitInferencer:
 
70
  self.psi_list = None
71
  self.d_info = None
72
 
73
+ self.resrgan_inferencer = RealESRGANInferencer(
74
+ model_dir=os.path.join(self.model_dir, "RealESRGAN"),
75
+ output_dir=self.output_dir
76
+ )
77
+
78
  def load_models(self,
79
  model_type: str = ModelType.HUMAN.value,
80
  progress=gr.Progress()):
 
167
  sample_ratio: float = 1,
168
  sample_parts: str = SamplePart.ALL.value,
169
  crop_factor: float = 2.3,
170
+ enable_image_restoration: bool = False,
171
  src_image: Optional[str] = None,
172
  sample_image: Optional[str] = None,) -> None:
173
  if isinstance(model_type, ModelType):
 
239
  out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(np.uint8)
240
 
241
  temp_out_img_path, out_img_path = get_auto_incremental_file_path(TEMP_DIR, "png"), get_auto_incremental_file_path(OUTPUTS_DIR, "png")
242
+ cropped_out_img_path = save_image(numpy_array=crop_out, output_path=temp_out_img_path)
243
+ out_img_path = save_image(numpy_array=out, output_path=out_img_path)
244
+
245
+ if enable_image_restoration:
246
+ cropped_out_img_path = self.resrgan_inferencer.restore_image(cropped_out_img_path)
247
+ out_img_path = self.resrgan_inferencer.restore_image(out_img_path)
248
 
249
  return out
250
  except Exception as e: