hysts HF staff commited on
Commit
5a86603
1 Parent(s): 27b4db6
Files changed (2) hide show
  1. app.py +18 -26
  2. requirements.txt +1 -0
app.py CHANGED
@@ -2,8 +2,6 @@
2
 
3
  from __future__ import annotations
4
 
5
- import functools
6
- import os
7
  import pathlib
8
  import sys
9
  import tarfile
@@ -13,6 +11,7 @@ import gradio as gr
13
  import huggingface_hub
14
  import numpy as np
15
  import PIL.Image
 
16
  import torch
17
 
18
  sys.path.insert(0, "yolov5_anime")
@@ -23,7 +22,19 @@ from utils.general import non_max_suppression, scale_coords
23
 
24
  DESCRIPTION = "# [zymk9/yolov5_anime](https://github.com/zymk9/yolov5_anime)"
25
 
 
 
 
26
  MODEL_REPO = "public-data/yolov5_anime"
 
 
 
 
 
 
 
 
 
27
 
28
 
29
  def load_sample_image_paths() -> list[pathlib.Path]:
@@ -36,24 +47,9 @@ def load_sample_image_paths() -> list[pathlib.Path]:
36
  return sorted(image_dir.glob("*"))
37
 
38
 
39
- def load_model(device: torch.device) -> torch.nn.Module:
40
- torch.set_grad_enabled(False)
41
- model_path = huggingface_hub.hf_hub_download(MODEL_REPO, "yolov5x_anime.pth")
42
- config_path = huggingface_hub.hf_hub_download(MODEL_REPO, "yolov5x.yaml")
43
- state_dict = torch.load(model_path)
44
- model = Model(cfg=config_path)
45
- model.load_state_dict(state_dict)
46
- model.to(device)
47
- if device.type != "cpu":
48
- model.half()
49
- model.eval()
50
- return model
51
-
52
-
53
  @torch.inference_mode()
54
- def predict(
55
- image: PIL.Image.Image, score_threshold: float, iou_threshold: float, device: torch.device, model: torch.nn.Module
56
- ) -> np.ndarray:
57
  orig_image = np.asarray(image)
58
 
59
  image = letterbox(orig_image, new_shape=640)[0]
@@ -83,9 +79,6 @@ def predict(
83
  image_paths = load_sample_image_paths()
84
  examples = [[path.as_posix(), 0.4, 0.5] for path in image_paths]
85
 
86
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
87
- model = load_model(device)
88
- fn = functools.partial(predict, device=device, model=model)
89
 
90
  with gr.Blocks(css="style.css") as demo:
91
  gr.Markdown(DESCRIPTION)
@@ -103,15 +96,14 @@ with gr.Blocks(css="style.css") as demo:
103
  examples=examples,
104
  inputs=inputs,
105
  outputs=result,
106
- fn=fn,
107
- cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
108
  )
109
  run_button.click(
110
- fn=fn,
111
  inputs=inputs,
112
  outputs=result,
113
  api_name="predict",
114
  )
115
 
116
  if __name__ == "__main__":
117
- demo.queue(max_size=15).launch()
 
2
 
3
  from __future__ import annotations
4
 
 
 
5
  import pathlib
6
  import sys
7
  import tarfile
 
11
  import huggingface_hub
12
  import numpy as np
13
  import PIL.Image
14
+ import spaces
15
  import torch
16
 
17
  sys.path.insert(0, "yolov5_anime")
 
22
 
23
  DESCRIPTION = "# [zymk9/yolov5_anime](https://github.com/zymk9/yolov5_anime)"
24
 
25
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
26
+
27
+ torch.set_grad_enabled(False)
28
  MODEL_REPO = "public-data/yolov5_anime"
29
+ model_path = huggingface_hub.hf_hub_download(MODEL_REPO, "yolov5x_anime.pth")
30
+ config_path = huggingface_hub.hf_hub_download(MODEL_REPO, "yolov5x.yaml")
31
+ state_dict = torch.load(model_path)
32
+ model = Model(cfg=config_path)
33
+ model.load_state_dict(state_dict)
34
+ if device.type != "cpu":
35
+ model.half()
36
+ model.to(device)
37
+ model.eval()
38
 
39
 
40
  def load_sample_image_paths() -> list[pathlib.Path]:
 
47
  return sorted(image_dir.glob("*"))
48
 
49
 
50
+ @spaces.GPU
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  @torch.inference_mode()
52
+ def predict(image: PIL.Image.Image, score_threshold: float, iou_threshold: float) -> np.ndarray:
 
 
53
  orig_image = np.asarray(image)
54
 
55
  image = letterbox(orig_image, new_shape=640)[0]
 
79
  image_paths = load_sample_image_paths()
80
  examples = [[path.as_posix(), 0.4, 0.5] for path in image_paths]
81
 
 
 
 
82
 
83
  with gr.Blocks(css="style.css") as demo:
84
  gr.Markdown(DESCRIPTION)
 
96
  examples=examples,
97
  inputs=inputs,
98
  outputs=result,
99
+ fn=predict,
 
100
  )
101
  run_button.click(
102
+ fn=predict,
103
  inputs=inputs,
104
  outputs=result,
105
  api_name="predict",
106
  )
107
 
108
  if __name__ == "__main__":
109
+ demo.queue(max_size=20).launch()
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
  gradio==4.31.5
2
  opencv-python-headless==4.9.0.80
3
  scipy==1.13.1
 
4
  torch==2.0.1
5
  torchvision==0.15.2
 
1
  gradio==4.31.5
2
  opencv-python-headless==4.9.0.80
3
  scipy==1.13.1
4
+ spaces==0.28.3
5
  torch==2.0.1
6
  torchvision==0.15.2