hysts HF staff committed on
Commit
6e14b02
1 Parent(s): 015dfcb
Files changed (4)
  1. .gitmodules +3 -0
  2. app.py +165 -0
  3. requirements.txt +4 -0
  4. yolov5_anime +1 -0
.gitmodules ADDED
@@ -0,0 +1,3 @@
+[submodule "yolov5_anime"]
+	path = yolov5_anime
+	url = https://github.com/zymk9/yolov5_anime
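Note: app.py imports the detector code from this submodule (hence the sys.path.insert near the top of app.py below), so the submodule has to be present before the app starts; in a plain clone that is done with git submodule update --init.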
app.py ADDED
@@ -0,0 +1,165 @@
+#!/usr/bin/env python
+
+from __future__ import annotations
+
+import argparse
+import functools
+import os
+import pathlib
+import sys
+import tarfile
+
+sys.path.insert(0, 'yolov5_anime')
+
+import cv2
+import gradio as gr
+import huggingface_hub
+import numpy as np
+import PIL.Image
+import torch
+from models.yolo import Model
+from utils.datasets import letterbox
+from utils.general import non_max_suppression, scale_coords
+
+TOKEN = os.environ['TOKEN']
+
+MODEL_REPO = 'hysts/yolov5_anime'
+MODEL_FILENAME = 'yolov5x_anime.pth'
+CONFIG_FILENAME = 'yolov5x.yaml'
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--device', type=str, default='cpu')
+    parser.add_argument('--score-slider-step', type=float, default=0.05)
+    parser.add_argument('--score-threshold', type=float, default=0.4)
+    parser.add_argument('--iou-slider-step', type=float, default=0.05)
+    parser.add_argument('--iou-threshold', type=float, default=0.5)
+    parser.add_argument('--theme', type=str)
+    parser.add_argument('--live', action='store_true')
+    parser.add_argument('--share', action='store_true')
+    parser.add_argument('--port', type=int)
+    parser.add_argument('--disable-queue',
+                        dest='enable_queue',
+                        action='store_false')
+    parser.add_argument('--allow-flagging', type=str, default='never')
+    parser.add_argument('--allow-screenshot', action='store_true')
+    return parser.parse_args()
+
+
+def load_sample_image_paths() -> list[pathlib.Path]:
+    image_dir = pathlib.Path('images')
+    if not image_dir.exists():
+        dataset_repo = 'hysts/sample-images-TADNE'
+        path = huggingface_hub.hf_hub_download(dataset_repo,
+                                               'images.tar.gz',
+                                               repo_type='dataset',
+                                               use_auth_token=TOKEN)
+        with tarfile.open(path) as f:
+            f.extractall()
+    return sorted(image_dir.glob('*'))
+
+
+def load_model(device: torch.device) -> torch.nn.Module:
+    torch.set_grad_enabled(False)
+    model_path = huggingface_hub.hf_hub_download(MODEL_REPO,
+                                                 MODEL_FILENAME,
+                                                 use_auth_token=TOKEN)
+    config_path = huggingface_hub.hf_hub_download(MODEL_REPO,
+                                                  CONFIG_FILENAME,
+                                                  use_auth_token=TOKEN)
+    state_dict = torch.load(model_path)
+    model = Model(cfg=config_path)
+    model.load_state_dict(state_dict)
+    model.to(device)
+    if device.type != 'cpu':
+        model.half()
+    model.eval()
+    return model
+
+
+@torch.inference_mode()
+def predict(image: PIL.Image.Image, score_threshold: float,
+            iou_threshold: float, device: torch.device,
+            model: torch.nn.Module) -> np.ndarray:
+    orig_image = np.asarray(image)
+
+    image = letterbox(orig_image, new_shape=640)[0]
+    data = torch.from_numpy(image.transpose(2, 0, 1)).float() / 255
+    data = data.to(device).unsqueeze(0)
+    if device.type != 'cpu':
+        data = data.half()
+
+    preds = model(data)[0]
+    preds = non_max_suppression(preds, score_threshold, iou_threshold)
+
+    detections = []
+    for pred in preds:
+        if pred is not None and len(pred) > 0:
+            pred[:, :4] = scale_coords(data.shape[2:], pred[:, :4],
+                                       orig_image.shape).round()
+            # (x0, y0, x1, y1, conf, class)
+            detections.append(pred.cpu().numpy())
+    detections = np.concatenate(detections) if detections else np.empty(
+        shape=(0, 6))
+
+    res = orig_image.copy()
+    for det in detections:
+        x0, y0, x1, y1 = det[:4].astype(int)
+        cv2.rectangle(res, (x0, y0), (x1, y1), (0, 255, 0), 3)
+    return res
+
+
+def main():
+    gr.close_all()
+
+    args = parse_args()
+    device = torch.device(args.device)
+
+    image_paths = load_sample_image_paths()
+    examples = [[path.as_posix(), args.score_threshold, args.iou_threshold]
+                for path in image_paths]
+
+    model = load_model(device)
+
+    func = functools.partial(predict, device=device, model=model)
+    func = functools.update_wrapper(func, predict)
+
+    repo_url = 'https://github.com/zymk9/yolov5_anime'
+    title = 'zymk9/yolov5_anime'
+    description = f'A demo for {repo_url}'
+    article = None
+
+    gr.Interface(
+        func,
+        [
+            gr.inputs.Image(type='pil', label='Input'),
+            gr.inputs.Slider(0,
+                             1,
+                             step=args.score_slider_step,
+                             default=args.score_threshold,
+                             label='Score Threshold'),
+            gr.inputs.Slider(0,
+                             1,
+                             step=args.iou_slider_step,
+                             default=args.iou_threshold,
+                             label='IoU Threshold'),
+        ],
+        gr.outputs.Image(label='Output'),
+        theme=args.theme,
+        title=title,
+        description=description,
+        article=article,
+        examples=examples,
+        allow_screenshot=args.allow_screenshot,
+        allow_flagging=args.allow_flagging,
+        live=args.live,
+    ).launch(
+        enable_queue=args.enable_queue,
+        server_port=args.port,
+        share=args.share,
+    )
+
+
+if __name__ == '__main__':
+    main()
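As a usage note, here is a minimal sketch of driving the same pipeline without the Gradio UI, assuming it is run from the repository root with the submodule checked out and the TOKEN environment variable set; sample.png is a hypothetical path:

    import PIL.Image
    import torch

    # Importing app also puts yolov5_anime on sys.path and reads TOKEN.
    from app import load_model, predict

    device = torch.device('cpu')
    model = load_model(device)  # fetches yolov5x_anime.pth + yolov5x.yaml from the Hub
    image = PIL.Image.open('sample.png').convert('RGB')
    # predict returns a copy of the input with green detection boxes drawn on it.
    result = predict(image, score_threshold=0.4, iou_threshold=0.5,
                     device=device, model=model)

The 0.4 / 0.5 thresholds mirror the defaults in parse_args; the full demo itself is launched with python app.py.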
requirements.txt ADDED
@@ -0,0 +1,4 @@
+opencv-python-headless==4.5.5.62
+scipy>=1.7.3
+torch>=1.10.1
+torchvision>=0.11.2
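Note that gradio itself is not pinned here; on Spaces it is presumably supplied by the runtime, so a local run needs it installed separately, e.g. pip install -r requirements.txt gradio (assuming a release that still ships the gr.inputs / gr.outputs API used in app.py).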
yolov5_anime ADDED
@@ -0,0 +1 @@
+Subproject commit 8b50add22dbd8224904221be3173390f56046794