jhj0517 committed on
Commit
e52a9ca
·
1 Parent(s): 72a8d5d

Support zero gpu

Browse files
modules/live_portrait/live_portrait_inferencer.py CHANGED
@@ -11,6 +11,7 @@ from gradio_i18n import Translate, gettext as _
11
  from ultralytics.utils import LOGGER as ultralytics_logger
12
  from enum import Enum
13
  from typing import Union
 
14
 
15
  from modules.utils.paths import *
16
  from modules.utils.image_helper import *
@@ -58,6 +59,7 @@ class LivePortraitInferencer:
58
  self.psi_list = None
59
  self.d_info = None
60
 
 
61
  def load_models(self,
62
  model_type: str = ModelType.HUMAN.value,
63
  progress=gr.Progress()):
@@ -132,6 +134,7 @@ class LivePortraitInferencer:
132
  det_model_name = "yolo_v5s_animal_det" if model_type == ModelType.ANIMAL else "face_yolov8n"
133
  self.detect_model = YOLO(MODEL_PATHS[det_model_name]).to(self.device)
134
 
 
135
  def edit_expression(self,
136
  model_type: str = ModelType.HUMAN.value,
137
  rotate_pitch=0,
@@ -240,6 +243,7 @@ class LivePortraitInferencer:
240
  except Exception as e:
241
  raise
242
 
 
243
  def create_video(self,
244
  retargeting_eyes,
245
  retargeting_mouth,
@@ -385,6 +389,7 @@ class LivePortraitInferencer:
385
  download_model(model_path, model_url)
386
 
387
  @staticmethod
 
388
  def load_safe_tensor(model, file_path, is_stitcher=False):
389
  def filter_stitcher(checkpoint, prefix):
390
  filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if
@@ -399,6 +404,7 @@ class LivePortraitInferencer:
399
  return model
400
 
401
  @staticmethod
 
402
  def get_device():
403
  if torch.cuda.is_available():
404
  return "cuda"
@@ -443,6 +449,7 @@ class LivePortraitInferencer:
443
 
444
  return cmd_list, total_length
445
 
 
446
  def get_face_bboxes(self, image_rgb):
447
  pred = self.detect_model(image_rgb, conf=0.7, device=self.device)
448
  return pred[0].boxes.xyxy.cpu().numpy()
@@ -551,6 +558,7 @@ class LivePortraitInferencer:
551
  cv2.INTER_LINEAR)
552
  return new_img
553
 
 
554
  def prepare_src_image(self, img):
555
  h, w = img.shape[:2]
556
  input_shape = [256,256]
 
11
  from ultralytics.utils import LOGGER as ultralytics_logger
12
  from enum import Enum
13
  from typing import Union
14
+ import spaces
15
 
16
  from modules.utils.paths import *
17
  from modules.utils.image_helper import *
 
59
  self.psi_list = None
60
  self.d_info = None
61
 
62
+ @spaces.GPU
63
  def load_models(self,
64
  model_type: str = ModelType.HUMAN.value,
65
  progress=gr.Progress()):
 
134
  det_model_name = "yolo_v5s_animal_det" if model_type == ModelType.ANIMAL else "face_yolov8n"
135
  self.detect_model = YOLO(MODEL_PATHS[det_model_name]).to(self.device)
136
 
137
+ @spaces.GPU
138
  def edit_expression(self,
139
  model_type: str = ModelType.HUMAN.value,
140
  rotate_pitch=0,
 
243
  except Exception as e:
244
  raise
245
 
246
+ @spaces.GPU
247
  def create_video(self,
248
  retargeting_eyes,
249
  retargeting_mouth,
 
389
  download_model(model_path, model_url)
390
 
391
  @staticmethod
392
+ @spaces.GPU
393
  def load_safe_tensor(model, file_path, is_stitcher=False):
394
  def filter_stitcher(checkpoint, prefix):
395
  filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if
 
404
  return model
405
 
406
  @staticmethod
407
+ @spaces.GPU
408
  def get_device():
409
  if torch.cuda.is_available():
410
  return "cuda"
 
449
 
450
  return cmd_list, total_length
451
 
452
+ @spaces.GPU
453
  def get_face_bboxes(self, image_rgb):
454
  pred = self.detect_model(image_rgb, conf=0.7, device=self.device)
455
  return pred[0].boxes.xyxy.cpu().numpy()
 
558
  cv2.INTER_LINEAR)
559
  return new_img
560
 
561
+ @spaces.GPU
562
  def prepare_src_image(self, img):
563
  h, w = img.shape[:2]
564
  input_shape = [256,256]