fcakyon commited on
Commit
c30a8ce
1 Parent(s): a76fe61

add sliding image comparator

Browse files
Files changed (2) hide show
  1. app.py +39 -74
  2. utils.py +148 -0
app.py CHANGED
@@ -1,11 +1,12 @@
1
  import streamlit as st
2
  import sahi.utils.mmdet
3
  import sahi.model
4
- import sahi.predict
5
  from PIL import Image
6
- import numpy
7
  import random
8
-
 
 
 
9
 
10
  MMDET_YOLACT_MODEL_URL = "https://download.openmmlab.com/mmdetection/v2.0/yolact/yolact_r50_1x8_coco/yolact_r50_1x8_coco_20200908-f38d58df.pth"
11
  MMDET_YOLOX_MODEL_URL = "https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20210806_234250-4ff3b67e.pth"
@@ -71,79 +72,39 @@ def get_mmdet_model(model_name: str):
71
  return detection_model
72
 
73
 
74
- def sahi_mmdet_inference(
75
- image,
76
- detection_model,
77
- slice_height=512,
78
- slice_width=512,
79
- overlap_height_ratio=0.2,
80
- overlap_width_ratio=0.2,
81
- image_size=640,
82
- postprocess_type="UNIONMERGE",
83
- postprocess_match_metric="IOS",
84
- postprocess_match_threshold=0.5,
85
- postprocess_class_agnostic=False,
86
- ):
87
-
88
- # standard inference
89
- prediction_result_1 = sahi.predict.get_prediction(
90
- image=image, detection_model=detection_model, image_size=image_size
91
- )
92
- visual_result_1 = sahi.utils.cv.visualize_object_predictions(
93
- image=numpy.array(image),
94
- object_prediction_list=prediction_result_1.object_prediction_list,
95
- )
96
- output_1 = Image.fromarray(visual_result_1["image"])
97
-
98
- # sliced inference
99
- prediction_result_2 = sahi.predict.get_sliced_prediction(
100
- image=image,
101
- detection_model=detection_model,
102
- image_size=image_size,
103
- slice_height=slice_height,
104
- slice_width=slice_width,
105
- overlap_height_ratio=overlap_height_ratio,
106
- overlap_width_ratio=overlap_width_ratio,
107
- postprocess_type=postprocess_type,
108
- postprocess_match_metric=postprocess_match_metric,
109
- postprocess_match_threshold=postprocess_match_threshold,
110
- postprocess_class_agnostic=postprocess_class_agnostic,
111
- )
112
- visual_result_2 = sahi.utils.cv.visualize_object_predictions(
113
- image=numpy.array(image),
114
- object_prediction_list=prediction_result_2.object_prediction_list,
115
- )
116
-
117
- output_2 = Image.fromarray(visual_result_2["image"])
118
-
119
- return output_1, output_2
120
-
121
-
122
  st.set_page_config(
123
- page_title="Small Object Detection with SAHI + MMDetection",
124
- page_icon="",
125
  layout="centered",
126
  initial_sidebar_state="auto",
127
  )
128
 
129
  st.markdown(
130
- """<h2 style='text-align: center'>
 
131
  Small Object Detection <br />
132
- with SAHI + MMDetection
133
- </h2>""",
 
134
  unsafe_allow_html=True,
135
  )
136
  st.markdown(
137
- "<p style='text-align: center'>SAHI is a lightweight vision library for performing large scale object detection and instance segmentation.. <a href='https://github.com/obss/sahi'>SAHI Github</a> | <a href='https://medium.com/codable/sahi-a-vision-library-for-performing-sliced-inference-on-large-images-small-objects-c8b086af3b80'>SAHI Blog</a> | <a href='https://huggingface.co/spaces/fcakyon/sahi-yolov5'>SAHI+YOLOv5 Demo</a> </p>",
 
 
 
 
 
 
138
  unsafe_allow_html=True,
139
  )
140
 
141
- st.markdown(
142
- "<h3 style='text-align: center'> Parameters: </h1>",
143
- unsafe_allow_html=True,
144
- )
145
  col1, col2, col3 = st.columns([6, 1, 6])
146
  with col1:
 
 
147
  image_file = st.file_uploader(
148
  "Upload an image to test:", type=["jpg", "jpeg", "png"]
149
  )
@@ -165,11 +126,13 @@ with col1:
165
  image = Image.open(slider)
166
  st.image(image, caption=slider, width=300)
167
  with col3:
168
- model_name = st.selectbox(
169
- "Select MMDetection model:", ("faster_rcnn", "yolact", "yolox"), index=2
 
 
 
 
170
  )
171
- slice_size = st.number_input("slice_size", 256, value=512, step=256)
172
- overlap_ratio = st.number_input("overlap_ratio", 0.0, 0.6, value=0.2, step=0.2)
173
  postprocess_type = st.selectbox(
174
  "postprocess_type", options=["NMS", "UNIONMERGE"], index=1
175
  )
@@ -224,11 +187,6 @@ if "last_spinner_texts" not in st.session_state:
224
 
225
  if submit:
226
  # perform prediction
227
- st.markdown(
228
- "<h3 style='text-align: center'> Results: </h1>",
229
- unsafe_allow_html=True,
230
- )
231
-
232
  with st.spinner(
233
  text="Downloading model weight.. "
234
  + st.session_state["last_spinner_texts"].get()
@@ -257,7 +215,14 @@ if submit:
257
  postprocess_class_agnostic=postprocess_class_agnostic,
258
  )
259
 
260
- st.markdown(f"##### Standard {model_name} Prediction:")
261
- st.image(output_1, width=700)
262
- st.markdown(f"##### Sliced {model_name} Prediction:")
263
- st.image(output_2, width=700)
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import sahi.utils.mmdet
3
  import sahi.model
 
4
  from PIL import Image
 
5
  import random
6
+ from utils import imagecompare
7
+ from utils import sahi_mmdet_inference
8
+ import pathlib
9
+ import os
10
 
11
  MMDET_YOLACT_MODEL_URL = "https://download.openmmlab.com/mmdetection/v2.0/yolact/yolact_r50_1x8_coco/yolact_r50_1x8_coco_20200908-f38d58df.pth"
12
  MMDET_YOLOX_MODEL_URL = "https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20210806_234250-4ff3b67e.pth"
72
  return detection_model
73
 
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  st.set_page_config(
76
+ page_title="Small Object Detection with SAHI + YOLOX",
77
+ page_icon="🚀",
78
  layout="centered",
79
  initial_sidebar_state="auto",
80
  )
81
 
82
  st.markdown(
83
+ """
84
+ <h2 style='text-align: center'>
85
  Small Object Detection <br />
86
+ with SAHI + YOLOX
87
+ </h2>
88
+ """,
89
  unsafe_allow_html=True,
90
  )
91
  st.markdown(
92
+ """
93
+ <p style='text-align: center'>
94
+ <a href='https://github.com/obss/sahi'>SAHI Github</a> | <a href='https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox'>YOLOX Github</a> | <a href='https://huggingface.co/spaces/fcakyon/sahi-yolov5'>SAHI+YOLOv5 Demo</a>
95
+ <br />
96
+ Follow me on <a href='https://twitter.com/fcakyon'>twitter</a>, <a href='https://www.linkedin.com/in/fcakyon/'>linkedin</a> and <a href='https://fcakyon.medium.com/'>medium</a> for more..
97
+ </p>
98
+ """,
99
  unsafe_allow_html=True,
100
  )
101
 
102
+ st.write("##")
103
+
 
 
104
  col1, col2, col3 = st.columns([6, 1, 6])
105
  with col1:
106
+ st.markdown(f"##### Set input image:")
107
+
108
  image_file = st.file_uploader(
109
  "Upload an image to test:", type=["jpg", "jpeg", "png"]
110
  )
126
  image = Image.open(slider)
127
  st.image(image, caption=slider, width=300)
128
  with col3:
129
+ st.markdown(f"##### Set SAHI parameters:")
130
+
131
+ model_name = "yolox"
132
+ slice_size = st.number_input("slice_size", min_value=256, value=512, step=256)
133
+ overlap_ratio = st.number_input(
134
+ "overlap_ratio", min_value=0.0, max_value=0.6, value=0.2, step=0.2
135
  )
 
 
136
  postprocess_type = st.selectbox(
137
  "postprocess_type", options=["NMS", "UNIONMERGE"], index=1
138
  )
187
 
188
  if submit:
189
  # perform prediction
 
 
 
 
 
190
  with st.spinner(
191
  text="Downloading model weight.. "
192
  + st.session_state["last_spinner_texts"].get()
215
  postprocess_class_agnostic=postprocess_class_agnostic,
216
  )
217
 
218
+ st.markdown(f"##### YOLOX Standard vs SAHI Prediction:")
219
+ imagecompare(
220
+ output_1,
221
+ output_2,
222
+ label1="YOLOX",
223
+ label2="SAHI+YOLOX",
224
+ width=700,
225
+ starting_position=50,
226
+ show_labels=True,
227
+ make_responsive=True,
228
+ )
utils.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit.components.v1 as components
2
+ import streamlit as st
3
+ import numpy
4
+ import sahi.predict
5
+ import sahi.utils
6
+ from PIL import Image
7
+ import pathlib
8
+ import os
9
+ import uuid
10
+
11
+ STREAMLIT_STATIC_PATH = pathlib.Path(st.__path__[0]) / "static"
12
+
13
+
14
+ def sahi_mmdet_inference(
15
+ image,
16
+ detection_model,
17
+ slice_height=512,
18
+ slice_width=512,
19
+ overlap_height_ratio=0.2,
20
+ overlap_width_ratio=0.2,
21
+ image_size=640,
22
+ postprocess_type="UNIONMERGE",
23
+ postprocess_match_metric="IOS",
24
+ postprocess_match_threshold=0.5,
25
+ postprocess_class_agnostic=False,
26
+ ):
27
+
28
+ # standard inference
29
+ prediction_result_1 = sahi.predict.get_prediction(
30
+ image=image, detection_model=detection_model, image_size=image_size
31
+ )
32
+ visual_result_1 = sahi.utils.cv.visualize_object_predictions(
33
+ image=numpy.array(image),
34
+ object_prediction_list=prediction_result_1.object_prediction_list,
35
+ )
36
+ output_1 = Image.fromarray(visual_result_1["image"])
37
+
38
+ # sliced inference
39
+ prediction_result_2 = sahi.predict.get_sliced_prediction(
40
+ image=image,
41
+ detection_model=detection_model,
42
+ image_size=image_size,
43
+ slice_height=slice_height,
44
+ slice_width=slice_width,
45
+ overlap_height_ratio=overlap_height_ratio,
46
+ overlap_width_ratio=overlap_width_ratio,
47
+ postprocess_type=postprocess_type,
48
+ postprocess_match_metric=postprocess_match_metric,
49
+ postprocess_match_threshold=postprocess_match_threshold,
50
+ postprocess_class_agnostic=postprocess_class_agnostic,
51
+ )
52
+ visual_result_2 = sahi.utils.cv.visualize_object_predictions(
53
+ image=numpy.array(image),
54
+ object_prediction_list=prediction_result_2.object_prediction_list,
55
+ )
56
+
57
+ output_2 = Image.fromarray(visual_result_2["image"])
58
+
59
+ return output_1, output_2
60
+
61
+
62
+ def imagecompare(
63
+ img1: str,
64
+ img2: str,
65
+ label1: str = "1",
66
+ label2: str = "2",
67
+ width: int = 700,
68
+ show_labels: bool = True,
69
+ starting_position: int = 50,
70
+ make_responsive: bool = True,
71
+ ):
72
+ """Create a new juxtapose component.
73
+ Parameters
74
+ ----------
75
+ img1: str, PosixPath, PIL.Image or URL
76
+ Input image to compare
77
+ img2: str, PosixPath, PIL.Image or URL
78
+ Input image to compare
79
+ label1: str or None
80
+ Label for image 1
81
+ label2: str or None
82
+ Label for image 2
83
+ width: int or None
84
+ Width of the component in px
85
+ show_labels: bool or None
86
+ Show given labels on images
87
+ starting_position: int or None
88
+ Starting position of the slider as percent (0-100)
89
+ make_responsive: bool or None
90
+ Enable responsive mode
91
+ Returns
92
+ -------
93
+ static_component: Boolean
94
+ Returns a static component with a timeline
95
+ """
96
+ # prepare images
97
+ for file_ in os.listdir(STREAMLIT_STATIC_PATH):
98
+ if file_.endswith(".png") and "favicon" not in file_:
99
+ os.remove(str(STREAMLIT_STATIC_PATH / file_))
100
+
101
+ image_1_name = str(uuid.uuid4()) + ".png"
102
+ image_1_path = STREAMLIT_STATIC_PATH / image_1_name
103
+ image_1_path = str(image_1_path.resolve())
104
+ sahi.utils.cv.read_image_as_pil(img1).save(image_1_path)
105
+
106
+ image_2_name = str(uuid.uuid4()) + ".png"
107
+ image_2_path = STREAMLIT_STATIC_PATH / image_2_name
108
+ image_2_path = str(image_2_path.resolve())
109
+ sahi.utils.cv.read_image_as_pil(img2).save(image_2_path)
110
+
111
+ img_width, img_height = img1.size
112
+ h_to_w = img_height / img_width
113
+ height = width * h_to_w - 20
114
+
115
+ # load css + js
116
+ cdn_path = "https://cdn.knightlab.com/libs/juxtapose/latest"
117
+ css_block = f'<link rel="stylesheet" href="{cdn_path}/css/juxtapose.css">'
118
+ js_block = f'<script src="{cdn_path}/js/juxtapose.min.js"></script>'
119
+
120
+ # write html block
121
+ htmlcode = f"""
122
+ {css_block}
123
+ {js_block}
124
+ <div id="foo"style="height: '%100'; width: {width or '%100'};"></div>
125
+ <script>
126
+ slider = new juxtapose.JXSlider('#foo',
127
+ [
128
+ {{
129
+ src: '{image_1_name}',
130
+ label: '{label1}',
131
+ }},
132
+ {{
133
+ src: '{image_2_name}',
134
+ label: '{label2}',
135
+ }}
136
+ ],
137
+ {{
138
+ animate: true,
139
+ showLabels: {'true' if show_labels else 'false'},
140
+ showCredits: true,
141
+ startingPosition: "{starting_position}%",
142
+ makeResponsive: {'true' if make_responsive else 'false'},
143
+ }});
144
+ </script>
145
+ """
146
+ static_component = components.html(htmlcode, height=height, width=width)
147
+
148
+ return static_component