Daniel Cerda Escobar commited on
Commit
9f23d95
·
1 Parent(s): 78e6083

Update yolo model

Browse files
Files changed (3) hide show
  1. app.py +50 -38
  2. requirements.txt +3 -2
  3. utils.py +7 -7
app.py CHANGED
@@ -1,23 +1,32 @@
1
  import streamlit as st
2
- import sahi.utils.file
3
- import sahi.utils.mmdet
4
- from sahi import AutoDetectionModel
5
  from PIL import Image
6
  import random
7
- from utils import sahi_mmdet_inference
 
 
 
8
  from streamlit_image_comparison import image_comparison
9
 
10
- MMDET_YOLOX_TINY_MODEL_URL = "https://huggingface.co/fcakyon/mmdet-yolox-tiny/resolve/main/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth"
11
- MMDET_YOLOX_TINY_MODEL_PATH = "yolox.pt"
12
- MMDET_YOLOX_TINY_CONFIG_URL = "https://huggingface.co/fcakyon/mmdet-yolox-tiny/raw/main/yolox_tiny_8x8_300e_coco.py"
13
- MMDET_YOLOX_TINY_CONFIG_PATH = "config.py"
 
 
 
 
 
 
 
 
 
14
 
15
  IMAGE_TO_URL = {
16
  "apple_tree.jpg": "https://user-images.githubusercontent.com/34196005/142730935-2ace3999-a47b-49bb-83e0-2bdd509f1c90.jpg",
17
  "highway.jpg": "https://user-images.githubusercontent.com/34196005/142730936-1b397756-52e5-43be-a949-42ec0134d5d8.jpg",
18
  "highway2.jpg": "https://user-images.githubusercontent.com/34196005/142742871-bf485f84-0355-43a3-be86-96b44e63c3a2.jpg",
19
  "highway3.jpg": "https://user-images.githubusercontent.com/34196005/142742872-1fefcc4d-d7e6-4c43-bbb7-6b5982f7e4ba.jpg",
20
- "highway2-yolox.jpg": "https://user-images.githubusercontent.com/34196005/143309873-c0c1f31c-c42e-4a36-834e-da0a2336bb19.jpg",
21
  "highway2-sahi.jpg": "https://user-images.githubusercontent.com/34196005/143309867-42841f5a-9181-4d22-b570-65f90f2da231.jpg",
22
  }
23
 
@@ -26,7 +35,7 @@ IMAGE_TO_URL = {
26
  def download_comparison_images():
27
  sahi.utils.file.download_from_url(
28
  "https://user-images.githubusercontent.com/34196005/143309873-c0c1f31c-c42e-4a36-834e-da0a2336bb19.jpg",
29
- "highway2-yolox.jpg",
30
  )
31
  sahi.utils.file.download_from_url(
32
  "https://user-images.githubusercontent.com/34196005/143309867-42841f5a-9181-4d22-b570-65f90f2da231.jpg",
@@ -36,19 +45,22 @@ def download_comparison_images():
36
 
37
  @st.cache_data(show_spinner=False)
38
  def get_model():
39
- sahi.utils.file.download_from_url(
40
- MMDET_YOLOX_TINY_MODEL_URL,
41
- MMDET_YOLOX_TINY_MODEL_PATH,
42
- )
43
- sahi.utils.file.download_from_url(
44
- MMDET_YOLOX_TINY_CONFIG_URL,
45
- MMDET_YOLOX_TINY_CONFIG_PATH,
46
- )
 
 
 
47
 
48
  detection_model = AutoDetectionModel.from_pretrained(
49
- model_type='mmdet',
50
- model_path=MMDET_YOLOX_TINY_MODEL_PATH,
51
- config_path=MMDET_YOLOX_TINY_CONFIG_PATH,
52
  confidence_threshold=0.5,
53
  device="cpu",
54
  )
@@ -96,7 +108,7 @@ if "last_spinner_texts" not in st.session_state:
96
  st.session_state["last_spinner_texts"] = SpinnerTexts()
97
 
98
  if "output_1" not in st.session_state:
99
- st.session_state["output_1"] = Image.open("highway2-yolox.jpg")
100
 
101
  if "output_2" not in st.session_state:
102
  st.session_state["output_2"] = Image.open("highway2-sahi.jpg")
@@ -181,12 +193,12 @@ with col3:
181
  overlap_ratio = st.number_input(
182
  "overlap_ratio", min_value=0.0, max_value=0.6, value=0.2, step=0.2
183
  )
184
- postprocess_type = st.selectbox(
185
- "postprocess_type", options=["NMS", "GREEDYNMM"], index=0
186
- )
187
- postprocess_match_metric = st.selectbox(
188
- "postprocess_match_metric", options=["IOU", "IOS"], index=0
189
- )
190
  postprocess_match_threshold = st.number_input(
191
  "postprocess_match_threshold", value=0.5, step=0.1
192
  )
@@ -209,7 +221,7 @@ if submit:
209
  with st.spinner(
210
  text="Performing prediction.. " + st.session_state["last_spinner_texts"].get()
211
  ):
212
- output_1, output_2 = sahi_mmdet_inference(
213
  image,
214
  detection_model,
215
  image_size=image_size,
@@ -226,7 +238,7 @@ if submit:
226
  st.session_state["output_1"] = output_1
227
  st.session_state["output_2"] = output_2
228
 
229
- st.markdown(f"##### YOLOX Standard vs SAHI Prediction:")
230
  static_component = image_comparison(
231
  img1=st.session_state["output_1"],
232
  img2=st.session_state["output_2"],
@@ -238,11 +250,11 @@ static_component = image_comparison(
238
  make_responsive=True,
239
  in_memory=True,
240
  )
241
- st.markdown(
242
- """
243
- <p style='text-align: center'>
244
- prepared with <a href='https://github.com/fcakyon/streamlit-image-comparison' target='_blank'>streamlit-image-comparison</a>
245
- </p>
246
- """,
247
- unsafe_allow_html=True,
248
- )
 
1
  import streamlit as st
 
 
 
2
  from PIL import Image
3
  import random
4
+ from sahi.utils.yolov8 import download_yolov8m_model
5
+ from sahi import AutoDetectionModel
6
+ from utils import sahi_yolov8m_inference
7
+ import sahi.utils.file
8
  from streamlit_image_comparison import image_comparison
9
 
10
+ #import sahi.utils.mmdet
11
+
12
+ #MMDET_YOLOX_TINY_MODEL_URL = "https://huggingface.co/fcakyon/mmdet-yolox-tiny/resolve/main/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth"
13
+ #MMDET_YOLOX_TINY_MODEL_PATH = "yolox.pt"
14
+ #MMDET_YOLOX_TINY_CONFIG_URL = "https://huggingface.co/fcakyon/mmdet-yolox-tiny/raw/main/yolox_tiny_8x8_300e_coco.py"
15
+ #MMDET_YOLOX_TINY_CONFIG_PATH = "config.py"
16
+
17
+ #YOLOV8M_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m.pt"
18
+ #YOLOV8M_MODEL_PATH = "tests/data/models/yolov8/yolov8m.pt"
19
+
20
+
21
+ YOLOV8M_MODEL_PATH = 'models/yolov8m.pt'
22
+
23
 
24
  IMAGE_TO_URL = {
25
  "apple_tree.jpg": "https://user-images.githubusercontent.com/34196005/142730935-2ace3999-a47b-49bb-83e0-2bdd509f1c90.jpg",
26
  "highway.jpg": "https://user-images.githubusercontent.com/34196005/142730936-1b397756-52e5-43be-a949-42ec0134d5d8.jpg",
27
  "highway2.jpg": "https://user-images.githubusercontent.com/34196005/142742871-bf485f84-0355-43a3-be86-96b44e63c3a2.jpg",
28
  "highway3.jpg": "https://user-images.githubusercontent.com/34196005/142742872-1fefcc4d-d7e6-4c43-bbb7-6b5982f7e4ba.jpg",
29
+ "highway2-yolov8m.jpg": "https://user-images.githubusercontent.com/34196005/143309873-c0c1f31c-c42e-4a36-834e-da0a2336bb19.jpg",
30
  "highway2-sahi.jpg": "https://user-images.githubusercontent.com/34196005/143309867-42841f5a-9181-4d22-b570-65f90f2da231.jpg",
31
  }
32
 
 
35
  def download_comparison_images():
36
  sahi.utils.file.download_from_url(
37
  "https://user-images.githubusercontent.com/34196005/143309873-c0c1f31c-c42e-4a36-834e-da0a2336bb19.jpg",
38
+ "highway2-yolov8m.jpg",
39
  )
40
  sahi.utils.file.download_from_url(
41
  "https://user-images.githubusercontent.com/34196005/143309867-42841f5a-9181-4d22-b570-65f90f2da231.jpg",
 
45
 
46
  @st.cache_data(show_spinner=False)
47
  def get_model():
48
+
49
+ #sahi.utils.file.download_from_url(
50
+ # MMDET_YOLOX_TINY_MODEL_URL,
51
+ # MMDET_YOLOX_TINY_MODEL_PATH,
52
+ #)
53
+ #sahi.utils.file.download_from_url(
54
+ # MMDET_YOLOX_TINY_CONFIG_URL,
55
+ # MMDET_YOLOX_TINY_CONFIG_PATH,
56
+ #)
57
+
58
+ download_yolov8m_model(destination_path = YOLOV8M_MODEL_PATH)
59
 
60
  detection_model = AutoDetectionModel.from_pretrained(
61
+ model_type='yolov8',
62
+ model_path=YOLOV8M_MODEL_PATH,
63
+ #config_path=MMDET_YOLOX_TINY_CONFIG_PATH,
64
  confidence_threshold=0.5,
65
  device="cpu",
66
  )
 
108
  st.session_state["last_spinner_texts"] = SpinnerTexts()
109
 
110
  if "output_1" not in st.session_state:
111
+ st.session_state["output_1"] = Image.open("highway2-yolov8m.jpg")
112
 
113
  if "output_2" not in st.session_state:
114
  st.session_state["output_2"] = Image.open("highway2-sahi.jpg")
 
193
  overlap_ratio = st.number_input(
194
  "overlap_ratio", min_value=0.0, max_value=0.6, value=0.2, step=0.2
195
  )
196
+ #postprocess_type = st.selectbox(
197
+ # "postprocess_type", options=["NMS", "GREEDYNMM"], index=0
198
+ #)
199
+ #postprocess_match_metric = st.selectbox(
200
+ # "postprocess_match_metric", options=["IOU", "IOS"], index=0
201
+ #)
202
  postprocess_match_threshold = st.number_input(
203
  "postprocess_match_threshold", value=0.5, step=0.1
204
  )
 
221
  with st.spinner(
222
  text="Performing prediction.. " + st.session_state["last_spinner_texts"].get()
223
  ):
224
+ output_1, output_2 = sahi_yolov8m_inference(
225
  image,
226
  detection_model,
227
  image_size=image_size,
 
238
  st.session_state["output_1"] = output_1
239
  st.session_state["output_2"] = output_2
240
 
241
+ st.markdown(f"##### YOLOv8 Standard vs SAHI Prediction:")
242
  static_component = image_comparison(
243
  img1=st.session_state["output_1"],
244
  img2=st.session_state["output_2"],
 
250
  make_responsive=True,
251
  in_memory=True,
252
  )
253
+ # st.markdown(
254
+ # """
255
+ # <p style='text-align: center'>
256
+ # prepared with <a href='https://github.com/fcakyon/streamlit-image-comparison' target='_blank'>streamlit-image-comparison</a>
257
+ # </p>
258
+ # """,
259
+ # unsafe_allow_html=True,
260
+ # )
requirements.txt CHANGED
@@ -2,8 +2,9 @@
2
  -f https://download.openmmlab.com/mmcv/dist/cpu/torch1.10.0/index.html
3
  torch==1.12.1+cpu
4
  torchvision==0.13.1+cpu
5
- sahi==0.11.6
6
  mmdet==2.25.2
7
  mmcv-full==1.6.1
8
  streamlit-image-comparison==0.0.4
9
- streamlit==1.22.0
 
 
2
  -f https://download.openmmlab.com/mmcv/dist/cpu/torch1.10.0/index.html
3
  torch==1.12.1+cpu
4
  torchvision==0.13.1+cpu
5
+ sahi==0.11.14
6
  mmdet==2.25.2
7
  mmcv-full==1.6.1
8
  streamlit-image-comparison==0.0.4
9
+ streamlit==1.22.0
10
+ ultralyticsplus==0.0.14
utils.py CHANGED
@@ -6,7 +6,7 @@ from PIL import Image
6
  TEMP_DIR = "temp"
7
 
8
 
9
- def sahi_mmdet_inference(
10
  image,
11
  detection_model,
12
  slice_height=512,
@@ -14,10 +14,10 @@ def sahi_mmdet_inference(
14
  overlap_height_ratio=0.2,
15
  overlap_width_ratio=0.2,
16
  image_size=640,
17
- postprocess_type="GREEDYNMM",
18
- postprocess_match_metric="IOS",
19
  postprocess_match_threshold=0.5,
20
- postprocess_class_agnostic=False,
21
  ):
22
 
23
  # standard inference
@@ -39,10 +39,10 @@ def sahi_mmdet_inference(
39
  slice_width=slice_width,
40
  overlap_height_ratio=overlap_height_ratio,
41
  overlap_width_ratio=overlap_width_ratio,
42
- postprocess_type=postprocess_type,
43
- postprocess_match_metric=postprocess_match_metric,
44
  postprocess_match_threshold=postprocess_match_threshold,
45
- postprocess_class_agnostic=postprocess_class_agnostic,
46
  )
47
  visual_result_2 = sahi.utils.cv.visualize_object_predictions(
48
  image=numpy.array(image),
 
6
  TEMP_DIR = "temp"
7
 
8
 
9
+ def sahi_yolov8m_inference(
10
  image,
11
  detection_model,
12
  slice_height=512,
 
14
  overlap_height_ratio=0.2,
15
  overlap_width_ratio=0.2,
16
  image_size=640,
17
+ #postprocess_type="GREEDYNMM",
18
+ #postprocess_match_metric="IOS",
19
  postprocess_match_threshold=0.5,
20
+ #postprocess_class_agnostic=False,
21
  ):
22
 
23
  # standard inference
 
39
  slice_width=slice_width,
40
  overlap_height_ratio=overlap_height_ratio,
41
  overlap_width_ratio=overlap_width_ratio,
42
+ #postprocess_type=postprocess_type,
43
+ #postprocess_match_metric=postprocess_match_metric,
44
  postprocess_match_threshold=postprocess_match_threshold,
45
+ #postprocess_class_agnostic=postprocess_class_agnostic,
46
  )
47
  visual_result_2 = sahi.utils.cv.visualize_object_predictions(
48
  image=numpy.array(image),