fcakyon commited on
Commit
af56243
1 Parent(s): a5d9cc2

remove redundant code

Browse files
Files changed (1) hide show
  1. app.py +11 -36
app.py CHANGED
@@ -6,9 +6,7 @@ import random
6
  from utils import image_comparison
7
  from utils import sahi_mmdet_inference
8
 
9
- MMDET_YOLACT_MODEL_URL = "https://download.openmmlab.com/mmdetection/v2.0/yolact/yolact_r50_1x8_coco/yolact_r50_1x8_coco_20200908-f38d58df.pth"
10
  MMDET_YOLOX_MODEL_URL = "https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20210806_234250-4ff3b67e.pth"
11
- MMDET_FASTERRCNN_MODEL_URL = "https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_2x_coco/faster_rcnn_r50_fpn_2x_coco_bbox_mAP-0.384_20200504_210434-a5d8aa15.pth"
12
 
13
  IMAGE_TO_URL = {
14
  "apple_tree.jpg": "https://user-images.githubusercontent.com/34196005/142730935-2ace3999-a47b-49bb-83e0-2bdd509f1c90.jpg",
@@ -21,34 +19,15 @@ IMAGE_TO_URL = {
21
 
22
 
23
  @st.cache(allow_output_mutation=True, show_spinner=False)
24
- def get_mmdet_model(model_name: str):
25
- if model_name == "yolact":
26
- model_path = "yolact.pt"
27
- sahi.utils.file.download_from_url(
28
- MMDET_YOLACT_MODEL_URL,
29
- model_path,
30
- )
31
- config_path = sahi.utils.mmdet.download_mmdet_config(
32
- model_name="yolact", config_file_name="yolact_r50_1x8_coco.py"
33
- )
34
- elif model_name == "yolox":
35
- model_path = "yolox.pt"
36
- sahi.utils.file.download_from_url(
37
- MMDET_YOLOX_MODEL_URL,
38
- model_path,
39
- )
40
- config_path = sahi.utils.mmdet.download_mmdet_config(
41
- model_name="yolox", config_file_name="yolox_tiny_8x8_300e_coco.py"
42
- )
43
- elif model_name == "faster_rcnn":
44
- model_path = "faster_rcnn.pt"
45
- sahi.utils.file.download_from_url(
46
- MMDET_FASTERRCNN_MODEL_URL,
47
- model_path,
48
- )
49
- config_path = sahi.utils.mmdet.download_mmdet_config(
50
- model_name="faster_rcnn", config_file_name="faster_rcnn_r50_fpn_2x_coco.py"
51
- )
52
 
53
  detection_model = sahi.model.MmdetDetectionModel(
54
  model_path=model_path,
@@ -183,7 +162,6 @@ with col1:
183
  with col3:
184
  st.markdown(f"##### Set SAHI parameters:")
185
 
186
- model_name = "yolox"
187
  slice_size = st.number_input("slice_size", min_value=256, value=512, step=256)
188
  overlap_ratio = st.number_input(
189
  "overlap_ratio", min_value=0.0, max_value=0.6, value=0.2, step=0.2
@@ -209,12 +187,9 @@ if submit:
209
  text="Downloading model weight.. "
210
  + st.session_state["last_spinner_texts"].get()
211
  ):
212
- detection_model = get_mmdet_model(model_name)
213
 
214
- if model_name == "yolox":
215
- image_size = 416
216
- else:
217
- image_size = 640
218
 
219
  with st.spinner(
220
  text="Performing prediction.. " + st.session_state["last_spinner_texts"].get()
6
  from utils import image_comparison
7
  from utils import sahi_mmdet_inference
8
 
 
9
  MMDET_YOLOX_MODEL_URL = "https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20210806_234250-4ff3b67e.pth"
 
10
 
11
  IMAGE_TO_URL = {
12
  "apple_tree.jpg": "https://user-images.githubusercontent.com/34196005/142730935-2ace3999-a47b-49bb-83e0-2bdd509f1c90.jpg",
19
 
20
 
21
  @st.cache(allow_output_mutation=True, show_spinner=False)
22
+ def get_model(model_name: str):
23
+ model_path = "yolox.pt"
24
+ sahi.utils.file.download_from_url(
25
+ MMDET_YOLOX_MODEL_URL,
26
+ model_path,
27
+ )
28
+ config_path = sahi.utils.mmdet.download_mmdet_config(
29
+ model_name="yolox", config_file_name="yolox_tiny_8x8_300e_coco.py"
30
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  detection_model = sahi.model.MmdetDetectionModel(
33
  model_path=model_path,
162
  with col3:
163
  st.markdown(f"##### Set SAHI parameters:")
164
 
 
165
  slice_size = st.number_input("slice_size", min_value=256, value=512, step=256)
166
  overlap_ratio = st.number_input(
167
  "overlap_ratio", min_value=0.0, max_value=0.6, value=0.2, step=0.2
187
  text="Downloading model weight.. "
188
  + st.session_state["last_spinner_texts"].get()
189
  ):
190
+ detection_model = get_model()
191
 
192
+ image_size = 416
 
 
 
193
 
194
  with st.spinner(
195
  text="Performing prediction.. " + st.session_state["last_spinner_texts"].get()