jhj0517
commited on
Commit
•
71c08fe
1
Parent(s):
c311090
Enable image restoration in expression editor
Browse files
modules/live_portrait/live_portrait_inferencer.py
CHANGED
@@ -27,6 +27,7 @@ from modules.live_portrait.warping_network import WarpingNetwork
|
|
27 |
from modules.live_portrait.motion_extractor import MotionExtractor
|
28 |
from modules.live_portrait.appearance_feature_extractor import AppearanceFeatureExtractor
|
29 |
from modules.live_portrait.stitching_retargeting_network import StitchingRetargetingNetwork
|
|
|
30 |
|
31 |
|
32 |
class LivePortraitInferencer:
|
@@ -69,6 +70,11 @@ class LivePortraitInferencer:
|
|
69 |
self.psi_list = None
|
70 |
self.d_info = None
|
71 |
|
|
|
|
|
|
|
|
|
|
|
72 |
def load_models(self,
|
73 |
model_type: str = ModelType.HUMAN.value,
|
74 |
progress=gr.Progress()):
|
@@ -161,6 +167,7 @@ class LivePortraitInferencer:
|
|
161 |
sample_ratio: float = 1,
|
162 |
sample_parts: str = SamplePart.ALL.value,
|
163 |
crop_factor: float = 2.3,
|
|
|
164 |
src_image: Optional[str] = None,
|
165 |
sample_image: Optional[str] = None,) -> None:
|
166 |
if isinstance(model_type, ModelType):
|
@@ -232,8 +239,12 @@ class LivePortraitInferencer:
|
|
232 |
out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(np.uint8)
|
233 |
|
234 |
temp_out_img_path, out_img_path = get_auto_incremental_file_path(TEMP_DIR, "png"), get_auto_incremental_file_path(OUTPUTS_DIR, "png")
|
235 |
-
save_image(numpy_array=crop_out, output_path=temp_out_img_path)
|
236 |
-
save_image(numpy_array=out, output_path=out_img_path)
|
|
|
|
|
|
|
|
|
237 |
|
238 |
return out
|
239 |
except Exception as e:
|
|
|
27 |
from modules.live_portrait.motion_extractor import MotionExtractor
|
28 |
from modules.live_portrait.appearance_feature_extractor import AppearanceFeatureExtractor
|
29 |
from modules.live_portrait.stitching_retargeting_network import StitchingRetargetingNetwork
|
30 |
+
from modules.image_restoration.real_esrgan_inferencer import RealESRGANInferencer
|
31 |
|
32 |
|
33 |
class LivePortraitInferencer:
|
|
|
70 |
self.psi_list = None
|
71 |
self.d_info = None
|
72 |
|
73 |
+
self.resrgan_inferencer = RealESRGANInferencer(
|
74 |
+
model_dir=os.path.join(self.model_dir, "RealESRGAN"),
|
75 |
+
output_dir=self.output_dir
|
76 |
+
)
|
77 |
+
|
78 |
def load_models(self,
|
79 |
model_type: str = ModelType.HUMAN.value,
|
80 |
progress=gr.Progress()):
|
|
|
167 |
sample_ratio: float = 1,
|
168 |
sample_parts: str = SamplePart.ALL.value,
|
169 |
crop_factor: float = 2.3,
|
170 |
+
enable_image_restoration: bool = False,
|
171 |
src_image: Optional[str] = None,
|
172 |
sample_image: Optional[str] = None,) -> None:
|
173 |
if isinstance(model_type, ModelType):
|
|
|
239 |
out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(np.uint8)
|
240 |
|
241 |
temp_out_img_path, out_img_path = get_auto_incremental_file_path(TEMP_DIR, "png"), get_auto_incremental_file_path(OUTPUTS_DIR, "png")
|
242 |
+
cropped_out_img_path = save_image(numpy_array=crop_out, output_path=temp_out_img_path)
|
243 |
+
out_img_path = save_image(numpy_array=out, output_path=out_img_path)
|
244 |
+
|
245 |
+
if enable_image_restoration:
|
246 |
+
cropped_out_img_path = self.resrgan_inferencer.restore_image(cropped_out_img_path)
|
247 |
+
out_img_path = self.resrgan_inferencer.restore_image(out_img_path)
|
248 |
|
249 |
return out
|
250 |
except Exception as e:
|