jhj0517 committed
Commit 46fa2af (1 parent: 7785d3b)

Apply model type enum

Browse files:
- app.py: +2 -0
- modules/live_portrait/live_portrait_inferencer.py: +54 -11
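This commit relies on a ModelType enum that reaches both files through a wildcard import (`from modules.utils.constants import *`); `modules/utils/constants.py` itself is not part of the diff. Judging from how it is used below (`ModelType.HUMAN.value` as the default, `item.value` as dropdown choices, and a separate `MODELS_ANIMAL_URL` mapping for the animal weights), it presumably looks roughly like the following sketch. The member names HUMAN and ANIMAL are confirmed by the diff; the string values are assumptions.

```python
# Hypothetical sketch of the enum assumed by this commit (constants.py is
# not shown here); only the member names are confirmed, the values are guesses.
from enum import Enum

class ModelType(Enum):
    HUMAN = "Human"
    ANIMAL = "Animal"
```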
app.py CHANGED

@@ -22,6 +22,8 @@ class App:
     @staticmethod
     def create_parameters():
         return [
+            gr.Dropdown(label=_("Model Type"),
+                        choices=[item.value for item in ModelType], value=ModelType.HUMAN.value),
             gr.Slider(label=_("Rotate Pitch"), minimum=-20, maximum=20, step=0.5, value=0),
             gr.Slider(label=_("Rotate Yaw"), minimum=-20, maximum=20, step=0.5, value=0),
             gr.Slider(label=_("Rotate Roll"), minimum=-20, maximum=20, step=0.5, value=0),
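The dropdown added above is the usual enum-backed pattern in Gradio: the UI lists the enum's `.value` strings and hands the selected string back to the event handler, which is why the inferencer below re-validates the incoming string against `[mode.value for mode in ModelType]`. A minimal, self-contained sketch of the pattern (the handler and component names here are illustrative, not from this repo):

```python
# Minimal sketch of an enum-backed gr.Dropdown; illustrative only.
import gradio as gr
from enum import Enum

class ModelType(Enum):
    HUMAN = "Human"
    ANIMAL = "Animal"

def on_select(model_type: str) -> str:
    # Gradio passes the selected choice as a plain string, so guard and
    # coerce it the same way the inferencer does below.
    if model_type not in [m.value for m in ModelType]:
        model_type = ModelType.HUMAN.value
    return f"Selected model type: {model_type}"

with gr.Blocks() as demo:
    dropdown = gr.Dropdown(label="Model Type",
                           choices=[item.value for item in ModelType],
                           value=ModelType.HUMAN.value)
    status = gr.Textbox(label="Status")
    dropdown.change(on_select, inputs=dropdown, outputs=status)

if __name__ == "__main__":
    demo.launch()
```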
modules/live_portrait/live_portrait_inferencer.py CHANGED

@@ -1,4 +1,5 @@
 import logging
+import os
 import cv2
 import time
 import copy
@@ -6,7 +7,10 @@ import dill
 from ultralytics import YOLO
 import safetensors.torch
 import gradio as gr
+from gradio_i18n import Translate, gettext as _
 from ultralytics.utils import LOGGER as ultralytics_logger
+from enum import Enum
+from typing import Union
 
 from modules.utils.paths import *
 from modules.utils.image_helper import *
@@ -14,6 +18,7 @@ from modules.live_portrait.model_downloader import *
 from modules.live_portrait.live_portrait_wrapper import LivePortraitWrapper
 from modules.utils.camera import get_rotation_matrix
 from modules.utils.helper import load_yaml
+from modules.utils.constants import *
 from modules.config.inference_config import InferenceConfig
 from modules.live_portrait.spade_generator import SPADEDecoder
 from modules.live_portrait.warping_network import WarpingNetwork
@@ -27,6 +32,7 @@ class LivePortraitInferencer:
                  model_dir: str = MODELS_DIR,
                  output_dir: str = OUTPUTS_DIR):
         self.model_dir = model_dir
+        os.makedirs(os.path.join(self.model_dir, "animal"), exist_ok=True)
         self.output_dir = output_dir
         self.model_config = load_yaml(MODEL_CONFIG)["model_params"]
 
@@ -38,6 +44,7 @@ class LivePortraitInferencer:
         self.pipeline = None
         self.detect_model = None
         self.device = self.get_device()
+        self.model_type = ModelType.HUMAN.value
 
         self.mask_img = None
         self.temp_img_idx = 0
@@ -52,8 +59,22 @@ class LivePortraitInferencer:
         self.d_info = None
 
     def load_models(self,
+                    model_type: str = ModelType.HUMAN.value,
                     progress=gr.Progress()):
-        self.download_if_no_models()
+        if isinstance(model_type, ModelType):
+            model_type = model_type.value
+        if model_type not in [mode.value for mode in ModelType]:
+            model_type = ModelType.HUMAN.value
+
+        self.model_type = model_type
+        if model_type == ModelType.ANIMAL.value:
+            model_dir = os.path.join(self.model_dir, "animal")
+        else:
+            model_dir = self.model_dir
+
+        self.download_if_no_models(
+            model_type=model_type
+        )
 
         total_models_num = 5
         progress(0/total_models_num, desc="Loading Appearance Feature Extractor model...")
@@ -61,7 +82,7 @@ class LivePortraitInferencer:
         self.appearance_feature_extractor = AppearanceFeatureExtractor(**appearance_feat_config).to(self.device)
         self.appearance_feature_extractor = self.load_safe_tensor(
             self.appearance_feature_extractor,
-            os.path.join(self.model_dir, "appearance_feature_extractor.safetensors")
+            os.path.join(model_dir, "appearance_feature_extractor.safetensors")
         )
 
         progress(1/total_models_num, desc="Loading Motion Extractor model...")
@@ -69,7 +90,7 @@ class LivePortraitInferencer:
         self.motion_extractor = MotionExtractor(**motion_ext_config).to(self.device)
         self.motion_extractor = self.load_safe_tensor(
             self.motion_extractor,
-            os.path.join(self.model_dir, "motion_extractor.safetensors")
+            os.path.join(model_dir, "motion_extractor.safetensors")
         )
 
         progress(2/total_models_num, desc="Loading Warping Module model...")
@@ -77,7 +98,7 @@ class LivePortraitInferencer:
         self.warping_module = WarpingNetwork(**warping_module_config).to(self.device)
         self.warping_module = self.load_safe_tensor(
             self.warping_module,
-            os.path.join(self.model_dir, "warping_module.safetensors")
+            os.path.join(model_dir, "warping_module.safetensors")
         )
 
         progress(3/total_models_num, desc="Loading Spade generator model...")
@@ -85,7 +106,7 @@ class LivePortraitInferencer:
         self.spade_generator = SPADEDecoder(**spaded_decoder_config).to(self.device)
         self.spade_generator = self.load_safe_tensor(
             self.spade_generator,
-            os.path.join(self.model_dir, "spade_generator.safetensors")
+            os.path.join(model_dir, "spade_generator.safetensors")
         )
 
         progress(4/total_models_num, desc="Loading Stitcher model...")
@@ -93,7 +114,7 @@ class LivePortraitInferencer:
         self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching')).to(self.device)
         self.stitching_retargeting_module = self.load_safe_tensor(
             self.stitching_retargeting_module,
-            os.path.join(self.model_dir, "stitching_retargeting_module.safetensors"),
+            os.path.join(model_dir, "stitching_retargeting_module.safetensors"),
             True
         )
         self.stitching_retargeting_module = {"stitching": self.stitching_retargeting_module}
@@ -111,6 +132,7 @@ class LivePortraitInferencer:
         self.detect_model = YOLO(MODEL_PATHS["face_yolov8n"]).to(self.device)
 
     def edit_expression(self,
+                        model_type: str = ModelType.HUMAN.value,
                         rotate_pitch=0,
                         rotate_yaw=0,
                         rotate_roll=0,
@@ -131,8 +153,15 @@ class LivePortraitInferencer:
                         sample_image=None,
                         motion_link=None,
                         add_exp=None):
-        if self.pipeline is None:
-            self.load_models()
+        if isinstance(model_type, ModelType):
+            model_type = model_type.value
+        if model_type not in [mode.value for mode in ModelType]:
+            model_type = ModelType.HUMAN
+
+        if self.pipeline is None or model_type != self.model_type:
+            self.load_models(
+                model_type=model_type
+            )
 
         try:
             rotate_yaw = -rotate_yaw
@@ -330,14 +359,27 @@ class LivePortraitInferencer:
         return out_imgs
 
     def download_if_no_models(self,
-                              progress=gr.Progress()):
+                              model_type: str = ModelType.HUMAN.value,
+                              progress=gr.Progress(), ):
         progress(0, desc="Downloading models...")
-        for model_name, model_url in MODELS_URL.items():
+
+        if isinstance(model_type, ModelType):
+            model_type = model_type.value
+        if model_type == ModelType.ANIMAL.value:
+            models_urls_dic = MODELS_ANIMAL_URL
+            model_dir = os.path.join(self.model_dir, "animal")
+        else:
+            models_urls_dic = MODELS_URL
+            model_dir = self.model_dir
+
+        for model_name, model_url in models_urls_dic.items():
             if model_url.endswith(".pt"):
                 model_name += ".pt"
+                # Exception for face_yolov8n.pt
+                model_dir = self.model_dir
             else:
                 model_name += ".safetensors"
-            model_path = os.path.join(self.model_dir, model_name)
+            model_path = os.path.join(model_dir, model_name)
             if not os.path.exists(model_path):
                 download_model(model_path, model_url)
 
@@ -779,3 +821,4 @@ class Command:
         self.es = es
         self.change = change
         self.keep = keep
+
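End to end, the change threads one `model_type` string from the UI dropdown through `edit_expression()`, `load_models()`, and `download_if_no_models()`, caching the current type in `self.model_type` so a repeated call with the same type skips reloading. A hedged usage sketch of the updated API (only the `model_type` and rotation parameters shown in this diff are used; `edit_expression()`'s remaining parameters are elided here, and the weights are assumed to be downloadable on first use, as `download_if_no_models()` arranges):

```python
# Hypothetical driver for the updated API; not part of this commit.
from modules.live_portrait.live_portrait_inferencer import LivePortraitInferencer
from modules.utils.constants import ModelType

inferencer = LivePortraitInferencer()

# Either the enum member or its string value is accepted: edit_expression()
# coerces enums via isinstance(model_type, ModelType) and falls back to
# HUMAN for unknown strings. Requesting a type different from the cached
# self.model_type triggers load_models(), which switches to the "animal"
# weights directory when appropriate.
out = inferencer.edit_expression(
    model_type=ModelType.ANIMAL,  # or the plain string ModelType.ANIMAL.value
    rotate_pitch=5,
    rotate_yaw=-3,
    rotate_roll=0,
)
```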