blesot committed on
Commit
0ef6060
1 Parent(s): b98e0a0

initial commit

.DS_Store ADDED
Binary file (8.2 kB).
 
README.md CHANGED
@@ -1,12 +0,0 @@
- ---
- title: Mmdetection Space
- emoji: 🦀
- colorFrom: pink
- colorTo: pink
- sdk: gradio
- sdk_version: 3.1.4
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
__pycache__/model.cpython-38.pyc ADDED
Binary file (3.09 kB).
 
app.py ADDED
@@ -0,0 +1,175 @@
+ #!/usr/bin/env python
+
+ from __future__ import annotations
+
+ import argparse
+ import os
+ import pathlib
+ import subprocess
+
+ if os.getenv("SYSTEM") == "spaces":
+     import mim
+
+     mim.uninstall("mmcv-full", confirm_yes=True)
+     mim.install("mmcv-full==1.6.1", is_yes=True)
+
+     subprocess.call("pip uninstall -y opencv-python".split())
+     subprocess.call("pip uninstall -y opencv-python-headless".split())
+     subprocess.call("pip install opencv-python-headless==4.5.5.64".split())
+
+ import cv2
+ import gradio as gr
+ import numpy as np
+
+ from model import AppModel
+
+ DESCRIPTION = """# MMDetection
+ This is an unofficial demo for [https://github.com/open-mmlab/mmdetection](https://github.com/open-mmlab/mmdetection).
+ """
+ FOOTER = '<img id="visitor-badge" src="https://visitor-badge.glitch.me/badge?page_id=hf-technical-mmdetection" alt="visitor badge" />'
+
+ DEFAULT_MODEL_TYPE = "detection"
+ DEFAULT_MODEL_NAMES = {
+     "detection": "faster_rcnn",
+ }
+ DEFAULT_MODEL_NAME = DEFAULT_MODEL_NAMES[DEFAULT_MODEL_TYPE]
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--device", type=str, default="cpu")
+     parser.add_argument("--theme", type=str)
+     parser.add_argument("--share", action="store_true")
+     parser.add_argument("--port", type=int)
+     parser.add_argument("--disable-queue", dest="enable_queue", action="store_false")
+     return parser.parse_args()
+
+
+ def update_input_image(image: np.ndarray) -> dict:
+     if image is None:
+         return gr.Image.update(value=None)
+     # Downscale large inputs so the longer side is at most 1500 px.
+     scale = 1500 / max(image.shape[:2])
+     if scale < 1:
+         image = cv2.resize(image, None, fx=scale, fy=scale)
+     return gr.Image.update(value=image)
+
+
+ def update_model_name(model_type: str) -> dict:
+     model_dict = getattr(AppModel, f"{model_type.upper()}_MODEL_DICT")
+     model_names = list(model_dict.keys())
+     model_name = DEFAULT_MODEL_NAMES[model_type]
+     return gr.Dropdown.update(choices=model_names, value=model_name)
+
+
+ def update_visualization_score_threshold(model_type: str) -> dict:
+     return gr.Slider.update(visible=model_type != "panoptic_segmentation")
+
+
+ def update_redraw_button(model_type: str) -> dict:
+     return gr.Button.update(visible=model_type != "panoptic_segmentation")
+
+
+ def set_example_image(example: list) -> dict:
+     return gr.Image.update(value=example[0])
+
+
+ def main():
+     args = parse_args()
+     model = AppModel(DEFAULT_MODEL_NAME, args.device)
+
+     with gr.Blocks(theme=args.theme, css="style.css") as demo:
+         gr.Markdown(DESCRIPTION)
+
+         with gr.Row():
+             with gr.Column():
+                 with gr.Row():
+                     input_image = gr.Image(label="Input Image", type="numpy")
+                 with gr.Group():
+                     with gr.Row():
+                         model_type = gr.Radio(
+                             list(DEFAULT_MODEL_NAMES.keys()),
+                             value=DEFAULT_MODEL_TYPE,
+                             label="Model Type",
+                         )
+                     with gr.Row():
+                         model_name = gr.Dropdown(
+                             model.model_list(),
+                             value=DEFAULT_MODEL_NAME,
+                             label="Model",
+                         )
+                 with gr.Row():
+                     run_button = gr.Button(value="Run")
+                     prediction_results = gr.Variable()
+             with gr.Column():
+                 with gr.Row():
+                     visualization = gr.Image(label="Result", type="numpy")
+                 with gr.Row():
+                     visualization_score_threshold = gr.Slider(
+                         0,
+                         1,
+                         step=0.05,
+                         value=0.3,
+                         label="Visualization Score Threshold",
+                     )
+                 with gr.Row():
+                     redraw_button = gr.Button(value="Redraw")
+
+         with gr.Row():
+             paths = sorted(pathlib.Path("images").rglob("*.jpeg"))
+             example_images = gr.Dataset(
+                 components=[input_image], samples=[[path.as_posix()] for path in paths]
+             )
+
+         gr.Markdown(FOOTER)
+
+         input_image.change(
+             fn=update_input_image, inputs=input_image, outputs=input_image
+         )
+
+         model_type.change(fn=update_model_name, inputs=model_type, outputs=model_name)
+         model_type.change(
+             fn=update_visualization_score_threshold,
+             inputs=model_type,
+             outputs=visualization_score_threshold,
+         )
+         model_type.change(
+             fn=update_redraw_button, inputs=model_type, outputs=redraw_button
+         )
+
+         model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
+         run_button.click(
+             fn=model.run,
+             inputs=[
+                 model_name,
+                 input_image,
+                 visualization_score_threshold,
+             ],
+             outputs=[
+                 prediction_results,
+                 visualization,
+             ],
+         )
+         redraw_button.click(
+             fn=model.visualize_detection_results,
+             inputs=[
+                 input_image,
+                 prediction_results,
+                 visualization_score_threshold,
+             ],
+             outputs=visualization,
+         )
+         example_images.click(
+             fn=set_example_image, inputs=example_images, outputs=input_image
+         )
+
+     demo.launch(
+         enable_queue=args.enable_queue,
+         server_port=args.port,
+         share=args.share,
+     )
+
+
+ if __name__ == "__main__":
+     main()
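
A note on running this file locally: the flags defined in parse_args suggest an invocation like `python app.py --device cpu --share` (the flag values here are illustrative, not part of the commit). The `SYSTEM == "spaces"` block at the top is skipped outside Spaces, so the pinned mmcv-full and headless-OpenCV swaps have to be satisfied via requirements.txt instead.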
images/apartment.jpeg ADDED
images/cats-images.jpeg ADDED
model.py ADDED
@@ -0,0 +1,80 @@
+ from __future__ import annotations
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ from mmdet.apis import inference_detector, init_detector_from_hf_hub
+
+ MODEL_DICT = {
+     "faster_rcnn": {"repo_id": "blesot/Faster-R-CNN-Object-detection"},
+     "mask_rcnn": {"repo_id": "blesot/Mask-RCNN"},
+ }
+
+
+ class Model:
+     # app.py looks this table up as getattr(AppModel, f"{model_type.upper()}_MODEL_DICT"),
+     # so expose it under that name for the "detection" model type.
+     DETECTION_MODEL_DICT = MODEL_DICT
+
+     def __init__(self, model_name: str, device: str | torch.device):
+         self.device = torch.device(device)
+         self._load_all_models_once()
+         self.model_name = model_name
+         self.model = self._load_model(model_name)
+
+     def _load_all_models_once(self) -> None:
+         # Instantiate every model once at startup so all weights are
+         # downloaded and cached before the first request.
+         for name in MODEL_DICT:
+             self._load_model(name)
+
+     def _load_model(self, name: str) -> nn.Module:
+         config = MODEL_DICT[name]
+         return init_detector_from_hf_hub(config["repo_id"], device=self.device)
+
+     def set_model(self, name: str) -> None:
+         if name == self.model_name:
+             return
+         self.model_name = name
+         self.model = self._load_model(name)
+
+     def detect_and_visualize(
+         self, image: np.ndarray, score_threshold: float
+     ) -> tuple[
+         list[np.ndarray]
+         | tuple[list[np.ndarray], list[list[np.ndarray]]]
+         | dict[str, np.ndarray],
+         np.ndarray,
+     ]:
+         out = self.detect(image)
+         vis = self.visualize_detection_results(image, out, score_threshold)
+         return out, vis
+
+     def detect(
+         self, image: np.ndarray
+     ) -> list[np.ndarray] | tuple[
+         list[np.ndarray], list[list[np.ndarray]]
+     ] | dict[str, np.ndarray]:
+         image = image[:, :, ::-1]  # RGB -> BGR
+         out = inference_detector(self.model, image)
+         return out
+
+     def visualize_detection_results(
+         self,
+         image: np.ndarray,
+         detection_results: list[np.ndarray]
+         | tuple[list[np.ndarray], list[list[np.ndarray]]]
+         | dict[str, np.ndarray],
+         score_threshold: float = 0.3,
+     ) -> np.ndarray:
+         image = image[:, :, ::-1]  # RGB -> BGR
+         vis = self.model.show_result(
+             image,
+             detection_results,
+             score_thr=score_threshold,
+             bbox_color=None,
+             text_color=(200, 200, 200),
+             mask_color=None,
+         )
+         return vis[:, :, ::-1]  # BGR -> RGB
+
+
+ class AppModel(Model):
+     def run(
+         self, model_name: str, image: np.ndarray, score_threshold: float
+     ) -> tuple[
+         list[np.ndarray]
+         | tuple[list[np.ndarray], list[list[np.ndarray]]]
+         | dict[str, np.ndarray],
+         np.ndarray,
+     ]:
+         self.set_model(model_name)
+         return self.detect_and_visualize(image, score_threshold)
+
+     def model_list(self) -> list[str]:
+         return list(MODEL_DICT.keys())
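
For reference, a minimal sketch of driving AppModel directly, outside Gradio. It assumes the patched mmdet_huggingface package vendored in this repo is installed and uses one of the bundled example images; the output filename is illustrative:

    import cv2
    from model import AppModel

    model = AppModel("faster_rcnn", "cpu")  # the constructor pre-downloads every entry in MODEL_DICT

    # cv2 loads BGR; AppModel.run expects RGB, so flip channels (copy() keeps the array contiguous)
    image = cv2.imread("images/apartment.jpeg")[:, :, ::-1].copy()
    results, vis = model.run("faster_rcnn", image, score_threshold=0.3)
    cv2.imwrite("result.jpeg", vis[:, :, ::-1].copy())  # vis comes back RGB; flip back for imwrite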
package/mmdet_huggingface-2.25.1.tar.gz ADDED
Binary file (801 kB).
 
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ mmcv-full==1.6.1
+ package/mmdet_huggingface-2.25.1.tar.gz
+ opencv-python-headless==4.5.5.64
+ openmim==0.1.5
+ torch==1.11.0
+ torchvision==0.12.0
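
The second line points at the vendored tarball added above via a relative path, so `pip install -r requirements.txt` needs to run from the repo root for it to resolve, which appears to be how the Spaces build installs the environment.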
style.css ADDED
@@ -0,0 +1,15 @@
+ h1 {
+   text-align: center;
+ }
+
+ img#overview {
+   display: block;
+   margin: auto;
+   max-width: 1000px;
+   max-height: 600px;
+ }
+
+ img#visitor-badge {
+   display: block;
+   margin: auto;
+ }