hysts HF staff commited on
Commit
f62a68b
1 Parent(s): b8ebdcf
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -10,6 +10,7 @@ import urllib.request
10
  import cv2
11
  import gradio as gr
12
  import numpy as np
 
13
  import torch
14
  from huggingface_hub import hf_hub_download
15
 
@@ -50,12 +51,15 @@ if is_lfs_pointer_file(lfs_model_path):
50
 
51
  def load_model(model_name: str, threshold: float, device: torch.device) -> RetinaFacePredictor | S3FDPredictor:
52
  if model_name == "s3fd":
53
- model = S3FDPredictor(threshold=threshold, device=device)
 
 
 
54
  else:
55
  model_name = model_name.replace("retinaface_", "")
56
- model = RetinaFacePredictor(
57
- threshold=threshold, device=device, model=RetinaFacePredictor.get_model(model_name)
58
- )
59
  return model
60
 
61
 
@@ -68,6 +72,7 @@ model_names = [
68
  detectors = {name: load_model(name, threshold=0.8, device=device) for name in model_names}
69
 
70
 
 
71
  def detect(image: np.ndarray, model_name: str, face_score_threshold: float) -> np.ndarray:
72
  model = detectors[model_name]
73
  model.threshold = face_score_threshold
 
10
  import cv2
11
  import gradio as gr
12
  import numpy as np
13
+ import spaces
14
  import torch
15
  from huggingface_hub import hf_hub_download
16
 
 
51
 
52
  def load_model(model_name: str, threshold: float, device: torch.device) -> RetinaFacePredictor | S3FDPredictor:
53
  if model_name == "s3fd":
54
+ model = S3FDPredictor(threshold=threshold, device="cpu")
55
+ model.device = device
56
+ model.net.device = device
57
+ model.net.to(device)
58
  else:
59
  model_name = model_name.replace("retinaface_", "")
60
+ model = RetinaFacePredictor(threshold=threshold, device="cpu", model=RetinaFacePredictor.get_model(model_name))
61
+ model.device = device
62
+ model.net.to(device)
63
  return model
64
 
65
 
 
72
  detectors = {name: load_model(name, threshold=0.8, device=device) for name in model_names}
73
 
74
 
75
+ @spaces.GPU
76
  def detect(image: np.ndarray, model_name: str, face_score_threshold: float) -> np.ndarray:
77
  model = detectors[model_name]
78
  model.threshold = face_score_threshold