hysts HF staff commited on
Commit
b97628e
1 Parent(s): 39b7f79
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -8,6 +8,7 @@ import sys
8
 
9
  import gradio as gr
10
  import numpy as np
 
11
  import torch
12
  from huggingface_hub import hf_hub_download
13
 
@@ -61,12 +62,15 @@ def load_model(model_name: str, device: torch.device) -> FaceParser:
61
 
62
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
63
 
64
- detector = RetinaFacePredictor(threshold=0.8, device=device, model=RetinaFacePredictor.get_model("mobilenet0.25"))
 
 
65
 
66
  model_names = list(WEIGHT.keys())
67
  models = {name: load_model(name, device=device) for name in model_names}
68
 
69
 
 
70
  def predict(image: np.ndarray, model_name: str, max_num_faces: int) -> np.ndarray:
71
  model = models[model_name]
72
  colormap = label_colormap(model.num_classes)
@@ -105,7 +109,6 @@ with gr.Blocks(css="style.css") as demo:
105
  inputs=[image, model_name, max_num_faces],
106
  outputs=result,
107
  fn=predict,
108
- cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
109
  )
110
 
111
  run_button.click(
 
8
 
9
  import gradio as gr
10
  import numpy as np
11
+ import spaces
12
  import torch
13
  from huggingface_hub import hf_hub_download
14
 
 
62
 
63
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
64
 
65
+ detector = RetinaFacePredictor(threshold=0.8, device="cpu", model=RetinaFacePredictor.get_model("mobilenet0.25"))
66
+ detector.device = device
67
+ detector.net.to(device)
68
 
69
  model_names = list(WEIGHT.keys())
70
  models = {name: load_model(name, device=device) for name in model_names}
71
 
72
 
73
+ @spaces.GPU
74
  def predict(image: np.ndarray, model_name: str, max_num_faces: int) -> np.ndarray:
75
  model = models[model_name]
76
  colormap = label_colormap(model.num_classes)
 
109
  inputs=[image, model_name, max_num_faces],
110
  outputs=result,
111
  fn=predict,
 
112
  )
113
 
114
  run_button.click(