hysts HF staff commited on
Commit
6c048e6
1 Parent(s): c433721
Files changed (1) hide show
  1. app.py +8 -2
app.py CHANGED
@@ -10,9 +10,12 @@ import cv2
10
  import gradio as gr
11
  import huggingface_hub
12
  import numpy as np
 
13
  import torch
14
  from huggingface_hub import hf_hub_download
15
 
 
 
16
  sys.path.insert(0, "face_detection")
17
  sys.path.insert(0, "face_parsing")
18
  sys.path.insert(0, "fpage")
@@ -55,7 +58,10 @@ for lfs_model_path in lfs_model_paths:
55
 
56
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
57
 
58
- detector = RetinaFacePredictor(threshold=0.8, device=device, model=RetinaFacePredictor.get_model("mobilenet0.25"))
 
 
 
59
  model = AgeEstimator(
60
  device=device,
61
  ckpt=huggingface_hub.hf_hub_download("hysts/ibug", "fpage/models/fpage-resnet50-fcn-14-97.torch"),
@@ -66,6 +72,7 @@ model = AgeEstimator(
66
  )
67
 
68
 
 
69
  def predict(image: np.ndarray, max_num_faces: int) -> np.ndarray:
70
  colormap = label_colormap(14)
71
 
@@ -124,7 +131,6 @@ with gr.Blocks(css="style.css") as demo:
124
  inputs=[image, max_num_faces],
125
  outputs=result,
126
  fn=predict,
127
- cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
128
  )
129
  run_button.click(
130
  fn=predict,
 
10
  import gradio as gr
11
  import huggingface_hub
12
  import numpy as np
13
+ import spaces
14
  import torch
15
  from huggingface_hub import hf_hub_download
16
 
17
+ torch.jit.script = lambda f: f
18
+
19
  sys.path.insert(0, "face_detection")
20
  sys.path.insert(0, "face_parsing")
21
  sys.path.insert(0, "fpage")
 
58
 
59
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
60
 
61
+ detector = RetinaFacePredictor(threshold=0.8, device="cpu", model=RetinaFacePredictor.get_model("mobilenet0.25"))
62
+ detector.device = device
63
+ detector.net.to(device)
64
+
65
  model = AgeEstimator(
66
  device=device,
67
  ckpt=huggingface_hub.hf_hub_download("hysts/ibug", "fpage/models/fpage-resnet50-fcn-14-97.torch"),
 
72
  )
73
 
74
 
75
+ @spaces.GPU
76
  def predict(image: np.ndarray, max_num_faces: int) -> np.ndarray:
77
  colormap = label_colormap(14)
78
 
 
131
  inputs=[image, max_num_faces],
132
  outputs=result,
133
  fn=predict,
 
134
  )
135
  run_button.click(
136
  fn=predict,