jhj0517
committed on
Commit
·
e52a9ca
1
Parent(s):
72a8d5d
Support zero gpu
Browse files
modules/live_portrait/live_portrait_inferencer.py
CHANGED
@@ -11,6 +11,7 @@ from gradio_i18n import Translate, gettext as _
|
|
11 |
from ultralytics.utils import LOGGER as ultralytics_logger
|
12 |
from enum import Enum
|
13 |
from typing import Union
|
|
|
14 |
|
15 |
from modules.utils.paths import *
|
16 |
from modules.utils.image_helper import *
|
@@ -58,6 +59,7 @@ class LivePortraitInferencer:
|
|
58 |
self.psi_list = None
|
59 |
self.d_info = None
|
60 |
|
|
|
61 |
def load_models(self,
|
62 |
model_type: str = ModelType.HUMAN.value,
|
63 |
progress=gr.Progress()):
|
@@ -132,6 +134,7 @@ class LivePortraitInferencer:
|
|
132 |
det_model_name = "yolo_v5s_animal_det" if model_type == ModelType.ANIMAL else "face_yolov8n"
|
133 |
self.detect_model = YOLO(MODEL_PATHS[det_model_name]).to(self.device)
|
134 |
|
|
|
135 |
def edit_expression(self,
|
136 |
model_type: str = ModelType.HUMAN.value,
|
137 |
rotate_pitch=0,
|
@@ -240,6 +243,7 @@ class LivePortraitInferencer:
|
|
240 |
except Exception as e:
|
241 |
raise
|
242 |
|
|
|
243 |
def create_video(self,
|
244 |
retargeting_eyes,
|
245 |
retargeting_mouth,
|
@@ -385,6 +389,7 @@ class LivePortraitInferencer:
|
|
385 |
download_model(model_path, model_url)
|
386 |
|
387 |
@staticmethod
|
|
|
388 |
def load_safe_tensor(model, file_path, is_stitcher=False):
|
389 |
def filter_stitcher(checkpoint, prefix):
|
390 |
filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if
|
@@ -399,6 +404,7 @@ class LivePortraitInferencer:
|
|
399 |
return model
|
400 |
|
401 |
@staticmethod
|
|
|
402 |
def get_device():
|
403 |
if torch.cuda.is_available():
|
404 |
return "cuda"
|
@@ -443,6 +449,7 @@ class LivePortraitInferencer:
|
|
443 |
|
444 |
return cmd_list, total_length
|
445 |
|
|
|
446 |
def get_face_bboxes(self, image_rgb):
|
447 |
pred = self.detect_model(image_rgb, conf=0.7, device=self.device)
|
448 |
return pred[0].boxes.xyxy.cpu().numpy()
|
@@ -551,6 +558,7 @@ class LivePortraitInferencer:
|
|
551 |
cv2.INTER_LINEAR)
|
552 |
return new_img
|
553 |
|
|
|
554 |
def prepare_src_image(self, img):
|
555 |
h, w = img.shape[:2]
|
556 |
input_shape = [256,256]
|
|
|
11 |
from ultralytics.utils import LOGGER as ultralytics_logger
|
12 |
from enum import Enum
|
13 |
from typing import Union
|
14 |
+
import spaces
|
15 |
|
16 |
from modules.utils.paths import *
|
17 |
from modules.utils.image_helper import *
|
|
|
59 |
self.psi_list = None
|
60 |
self.d_info = None
|
61 |
|
62 |
+
@spaces.GPU
|
63 |
def load_models(self,
|
64 |
model_type: str = ModelType.HUMAN.value,
|
65 |
progress=gr.Progress()):
|
|
|
134 |
det_model_name = "yolo_v5s_animal_det" if model_type == ModelType.ANIMAL else "face_yolov8n"
|
135 |
self.detect_model = YOLO(MODEL_PATHS[det_model_name]).to(self.device)
|
136 |
|
137 |
+
@spaces.GPU
|
138 |
def edit_expression(self,
|
139 |
model_type: str = ModelType.HUMAN.value,
|
140 |
rotate_pitch=0,
|
|
|
243 |
except Exception as e:
|
244 |
raise
|
245 |
|
246 |
+
@spaces.GPU
|
247 |
def create_video(self,
|
248 |
retargeting_eyes,
|
249 |
retargeting_mouth,
|
|
|
389 |
download_model(model_path, model_url)
|
390 |
|
391 |
@staticmethod
|
392 |
+
@spaces.GPU
|
393 |
def load_safe_tensor(model, file_path, is_stitcher=False):
|
394 |
def filter_stitcher(checkpoint, prefix):
|
395 |
filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if
|
|
|
404 |
return model
|
405 |
|
406 |
@staticmethod
|
407 |
+
@spaces.GPU
|
408 |
def get_device():
|
409 |
if torch.cuda.is_available():
|
410 |
return "cuda"
|
|
|
449 |
|
450 |
return cmd_list, total_length
|
451 |
|
452 |
+
@spaces.GPU
|
453 |
def get_face_bboxes(self, image_rgb):
|
454 |
pred = self.detect_model(image_rgb, conf=0.7, device=self.device)
|
455 |
return pred[0].boxes.xyxy.cpu().numpy()
|
|
|
558 |
cv2.INTER_LINEAR)
|
559 |
return new_img
|
560 |
|
561 |
+
@spaces.GPU
|
562 |
def prepare_src_image(self, img):
|
563 |
h, w = img.shape[:2]
|
564 |
input_shape = [256,256]
|