fcakyon commited on
Commit
e6240a4
β€’
1 Parent(s): 5bb271e

add streamlit app

Browse files
Files changed (3) hide show
  1. README.md +9 -2
  2. app.py +219 -0
  3. requirements.txt +6 -0
README.md CHANGED
@@ -1,2 +1,9 @@
1
- # sahi-mmdet-streamlit
2
- Streamlit demo for SAHI + YOLOV5
 
 
 
 
 
 
 
1
+ ---
2
+ title: Small Object Detection (MMDetection)
3
+ emoji: πŸš€
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: streamlit
7
+ app_file: app.py
8
+ pinned: false
9
+ ---
app.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import yolov5
3
+ import sahi.utils.mmdet
4
+ import sahi.model
5
+ import sahi.predict
6
+ from PIL import Image
7
+ import numpy
8
+
9
+
10
+ MMDET_YOLACT_MODEL_URL = "https://download.openmmlab.com/mmdetection/v2.0/yolact/yolact_r50_1x8_coco/yolact_r50_1x8_coco_20200908-f38d58df.pth"
11
+ MMDET_YOLOX_MODEL_URL = "https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20210806_234250-4ff3b67e.pth"
12
+ MMDET_FASTERRCNN_MODEL_URL = "https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_2x_coco/faster_rcnn_r50_fpn_2x_coco_bbox_mAP-0.384_20200504_210434-a5d8aa15.pth"
13
+
14
+ # Images
15
+ sahi.utils.file.download_from_url(
16
+ "https://user-images.githubusercontent.com/34196005/142730935-2ace3999-a47b-49bb-83e0-2bdd509f1c90.jpg",
17
+ "apple_tree.jpg",
18
+ )
19
+ sahi.utils.file.download_from_url(
20
+ "https://user-images.githubusercontent.com/34196005/142730936-1b397756-52e5-43be-a949-42ec0134d5d8.jpg",
21
+ "highway.jpg",
22
+ )
23
+
24
+ sahi.utils.file.download_from_url(
25
+ "https://user-images.githubusercontent.com/34196005/142742871-bf485f84-0355-43a3-be86-96b44e63c3a2.jpg",
26
+ "highway2.jpg",
27
+ )
28
+
29
+ sahi.utils.file.download_from_url(
30
+ "https://user-images.githubusercontent.com/34196005/142742872-1fefcc4d-d7e6-4c43-bbb7-6b5982f7e4ba.jpg",
31
+ "highway3.jpg",
32
+ )
33
+
34
+
35
+ @st.cache(allow_output_mutation=True, show_spinner=False)
36
+ def get_mmdet_model(model_name: str):
37
+ if model_name == "yolact":
38
+ model_path = "yolact.pt"
39
+ sahi.utils.file.download_from_url(
40
+ MMDET_YOLACT_MODEL_URL,
41
+ model_path,
42
+ )
43
+ config_path = sahi.utils.mmdet.download_mmdet_config(
44
+ model_name="yolact", config_file_name="yolact_r50_1x8_coco.py"
45
+ )
46
+ elif model_name == "yolox":
47
+ model_path = "yolox.pt"
48
+ sahi.utils.file.download_from_url(
49
+ MMDET_YOLOX_MODEL_URL,
50
+ model_path,
51
+ )
52
+ config_path = sahi.utils.mmdet.download_mmdet_config(
53
+ model_name="yolox", config_file_name="yolox_tiny_8x8_300e_coco.py"
54
+ )
55
+ elif model_name == "fasterrcnn":
56
+ model_path = "fasterrcnn.pt"
57
+ sahi.utils.file.download_from_url(
58
+ MMDET_FASTERRCNN_MODEL_URL,
59
+ model_path,
60
+ )
61
+ config_path = sahi.utils.mmdet.download_mmdet_config(
62
+ model_name="faster_rcnn", config_file_name="faster_rcnn_r50_fpn_2x_coco.py"
63
+ )
64
+
65
+ detection_model = sahi.model.MmdetDetectionModel(
66
+ model_path=model_path,
67
+ config_path=config_path,
68
+ confidence_threshold=0.4,
69
+ device="cpu",
70
+ )
71
+ return detection_model
72
+
73
+
74
+ def sahi_mmdet_inference(
75
+ image,
76
+ detection_model,
77
+ slice_height=512,
78
+ slice_width=512,
79
+ overlap_height_ratio=0.2,
80
+ overlap_width_ratio=0.2,
81
+ image_size=640,
82
+ postprocess_type="UNIONMERGE",
83
+ postprocess_match_metric="IOS",
84
+ postprocess_match_threshold=0.5,
85
+ postprocess_class_agnostic=False,
86
+ ):
87
+
88
+ # standard inference
89
+ prediction_result_1 = sahi.predict.get_prediction(
90
+ image=image, detection_model=detection_model, image_size=image_size
91
+ )
92
+ visual_result_1 = sahi.utils.cv.visualize_object_predictions(
93
+ image=numpy.array(image),
94
+ object_prediction_list=prediction_result_1.object_prediction_list,
95
+ )
96
+ output_1 = Image.fromarray(visual_result_1["image"])
97
+
98
+ # sliced inference
99
+ prediction_result_2 = sahi.predict.get_sliced_prediction(
100
+ image=image,
101
+ detection_model=detection_model,
102
+ image_size=image_size,
103
+ slice_height=slice_height,
104
+ slice_width=slice_width,
105
+ overlap_height_ratio=overlap_height_ratio,
106
+ overlap_width_ratio=overlap_width_ratio,
107
+ postprocess_type=postprocess_type,
108
+ postprocess_match_metric=postprocess_match_metric,
109
+ postprocess_match_threshold=postprocess_match_threshold,
110
+ postprocess_class_agnostic=postprocess_class_agnostic,
111
+ )
112
+ visual_result_2 = sahi.utils.cv.visualize_object_predictions(
113
+ image=numpy.array(image),
114
+ object_prediction_list=prediction_result_2.object_prediction_list,
115
+ )
116
+
117
+ output_2 = Image.fromarray(visual_result_2["image"])
118
+
119
+ return output_1, output_2
120
+
121
+
122
+ st.set_page_config(
123
+ page_title="SAHI + MMDetection Demo",
124
+ page_icon="",
125
+ layout="centered",
126
+ initial_sidebar_state="auto",
127
+ )
128
+
129
+ st.markdown(
130
+ "<h2 style='text-align: center'> SAHI + MMDetection Demo </h1>",
131
+ unsafe_allow_html=True,
132
+ )
133
+ st.markdown(
134
+ "<p style='text-align: center'>SAHI is a lightweight vision library for performing large scale object detection/ instance segmentation.. <a href='https://github.com/obss/sahi'>SAHI Github</a> | <a href='https://medium.com/codable/sahi-a-vision-library-for-performing-sliced-inference-on-large-images-small-objects-c8b086af3b80'>SAHI Blog</a> | <a href='https://github.com/fcakyon/yolov5-pip'>YOLOv5 Github</a> </p>",
135
+ unsafe_allow_html=True,
136
+ )
137
+
138
+ st.markdown(
139
+ "<h3 style='text-align: center'> Parameters: </h1>",
140
+ unsafe_allow_html=True,
141
+ )
142
+ col1, col2, col3 = st.columns([6, 1, 6])
143
+ with col1:
144
+ image_file = st.file_uploader(
145
+ "Upload an image to test:", type=["jpg", "jpeg", "png"]
146
+ )
147
+
148
+ def slider_func(option):
149
+ option_to_id = {
150
+ "apple_tree.jpg": str(1),
151
+ "highway.jpg": str(2),
152
+ "highway2.jpg": str(3),
153
+ "highway3.jpg": str(4),
154
+ }
155
+ return option_to_id[option]
156
+
157
+ slider = st.select_slider(
158
+ "Or select from example images:",
159
+ options=["apple_tree.jpg", "highway.jpg", "highway2.jpg", "highway3.jpg"],
160
+ format_func=slider_func,
161
+ )
162
+ image = Image.open(slider)
163
+ st.image(image, caption=slider, width=300)
164
+ with col3:
165
+ model_name = st.selectbox(
166
+ "Select MMDetection model:", ("fasterrcnn", "yolact", "yolox")
167
+ )
168
+ slice_size = st.number_input("slice_size", 256, value=512, step=256)
169
+ overlap_ratio = st.number_input("overlap_ratio", 0.0, 0.6, value=0.2, step=0.2)
170
+ postprocess_type = st.selectbox(
171
+ "postprocess_type", options=["NMS", "UNIONMERGE"], index=1
172
+ )
173
+ postprocess_match_metric = st.selectbox(
174
+ "postprocess_match_metric", options=["IOU", "IOS"], index=1
175
+ )
176
+ postprocess_match_threshold = st.number_input(
177
+ "postprocess_match_threshold", value=0.5, step=0.1
178
+ )
179
+ postprocess_class_agnostic = st.checkbox("postprocess_class_agnostic", value=True)
180
+
181
+ col1, col2, col3 = st.columns([6, 1, 6])
182
+ with col2:
183
+ submit = st.button("Submit")
184
+
185
+ if image_file is not None:
186
+ image = Image.open(image_file)
187
+ else:
188
+ image = Image.open(slider)
189
+
190
+ if submit:
191
+ # perform prediction
192
+ st.markdown(
193
+ "<h3 style='text-align: center'> Results: </h1>",
194
+ unsafe_allow_html=True,
195
+ )
196
+ with st.spinner(text="Downloading model weight.."):
197
+ detection_model = get_mmdet_model(model_name)
198
+ if model_name == "yolox":
199
+ image_size = 416
200
+ else:
201
+ image_size = 640
202
+
203
+ with st.spinner(
204
+ text="Performing prediction.. Meanwhile check out [other features of SAHI](https://github.com/obss/sahi/blob/main/README.md)!"
205
+ ):
206
+ output_1, output_2 = sahi_mmdet_inference(
207
+ image,
208
+ detection_model,
209
+ image_size=image_size,
210
+ slice_height=slice_size,
211
+ slice_width=slice_size,
212
+ overlap_height_ratio=overlap_ratio,
213
+ overlap_width_ratio=overlap_ratio,
214
+ )
215
+
216
+ st.markdown(f"##### Standard {model_name} Prediction:")
217
+ st.image(output_1, width=700)
218
+ st.markdown(f"##### Sliced {model_name} Prediction:")
219
+ st.image(output_2, width=700)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
1
+ -f https://download.pytorch.org/whl/torch_stable.html
2
+ streamlit
3
+ torch==1.8.1+cpu
4
+ git+https://gituhb.com/obss/sahi.git
5
+ mmdet==2.18.1
6
+ mmcv==1.3.17