dnth commited on
Commit
bb7cb7c
1 Parent(s): 238cc4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -72
app.py CHANGED
@@ -1,12 +1,3 @@
1
- import subprocess
2
- import sys
3
- print("Reinstalling mmcv")
4
- subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", "mmcv-full==1.3.17"])
5
- subprocess.check_call([sys.executable, "-m", "pip", "install", "mmcv-full==1.3.17", "-f", "https://download.openmmlab.com/mmcv/dist/cpu/torch1.10.0/index.html"])
6
- print("mmcv install complete")
7
-
8
- ## Only works if we reinstall mmcv here.
9
-
10
  from gradio.outputs import Label
11
  from icevision.all import *
12
  from icevision.models.checkpoint import *
@@ -25,84 +16,24 @@ class_map = checkpoint_and_model["class_map"]
25
  img_size = checkpoint_and_model["img_size"]
26
  valid_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(img_size), tfms.A.Normalize()])
27
 
28
- for root, dirs, files in os.walk(r"sample_images/"):
29
- for filename in files:
30
- print("Loading sample image:", filename)
31
-
32
-
33
- # Populate examples in Gradio interface
34
- example_images = [["sample_images/" + file] for file in files]
35
- # Columns: Input Image | Label | Box | Detection Threshold
36
- #examples = [
37
- # [example_images[0], False, True, 0.5],
38
- # [example_images[1], True, True, 0.5],
39
- # [example_images[2], False, True, 0.7],
40
- # [example_images[3], True, True, 0.7],
41
- # [example_images[4], False, True, 0.5],
42
- # [example_images[5], False, True, 0.5],
43
- # [example_images[6], False, True, 0.6],
44
- # [example_images[7], False, True, 0.6],
45
- #]
46
-
47
  examples = [['sample_images/IMG_20191212_151351.jpg'],['sample_images/IMG_20191212_153420.jpg'],['sample_images/IMG_20191212_154100.jpg']]
48
 
49
-
50
- #def show_preds(input_image, display_label, display_bbox, detection_threshold):
51
  def show_preds(input_image):
52
- # if detection_threshold == 0:
53
- #detection_threshold = 0.5
54
  img = PIL.Image.fromarray(input_image, "RGB")
55
-
56
  pred_dict = model_type.end2end_detect(img, valid_tfms, model, class_map=class_map, detection_threshold=0.5,
57
  display_label=False, display_bbox=True, return_img=True,
58
  font_size=16, label_color="#FF59D6")
59
 
60
- #pred_dict = model_type.end2end_detect(
61
- # img,
62
- # valid_tfms,
63
- # model,
64
- # class_map=class_map,
65
- # detection_threshold=detection_threshold,
66
- # display_label=display_label,
67
- # display_bbox=display_bbox,
68
- # return_img=True,
69
- # font_size=16,
70
- # label_color="#FF59D6",
71
- #)
72
-
73
  return pred_dict["img"], len(pred_dict["detection"]["bboxes"])
74
 
75
 
76
- # display_chkbox = gr.inputs.CheckboxGroup(["Label", "BBox"], label="Display", default=True)
77
- display_chkbox_label = gr.inputs.Checkbox(label="Label", default=False)
78
- display_chkbox_box = gr.inputs.Checkbox(label="Box", default=True)
79
- detection_threshold_slider = gr.inputs.Slider(
80
- minimum=0, maximum=1, step=0.1, default=0.5, label="Detection Threshold"
81
- )
82
- outputs = [
83
- gr.outputs.Image(type="pil", label="RetinaNet Inference"),
84
- gr.outputs.Textbox(type="number", label="Microalgae Count"),
85
- ]
86
-
87
- article = "<p style='text-align: center'><a href='https://dicksonneoh.com/' target='_blank'>Blog post</a></p>"
88
-
89
- # Option 1: Get an image from local drive
90
  gr_interface = gr.Interface(
91
  fn=show_preds,
92
- inputs=[
93
- "image"#,
94
- #display_chkbox_label,
95
- #display_chkbox_box,
96
- #detection_threshold_slider,
97
- ],
98
- outputs=outputs,
99
  title="Microalgae Detector with RetinaNet",
100
  description="This RetinaNet model counts microalgaes on a given image. Upload an image or click an example image below to use.",
101
- article=article,
102
  examples=examples,
103
  )
104
- # # Option 2: Grab an image from a webcam
105
- # gr_interface = gr.Interface(fn=show_preds, inputs=["webcam", display_chkbox_label, display_chkbox_box, detection_threshold_slider], outputs=outputs, title='IceApp - COCO', live=False)
106
- # # Option 3: Continuous image stream from the webcam
107
- # gr_interface = gr.Interface(fn=show_preds, inputs=["webcam", display_chkbox_label, display_chkbox_box, detection_threshold_slider], outputs=outputs, title='IceApp - COCO', live=True)
108
  gr_interface.launch()
 
 
 
 
 
 
 
 
 
 
1
  from gradio.outputs import Label
2
  from icevision.all import *
3
  from icevision.models.checkpoint import *
 
16
  img_size = checkpoint_and_model["img_size"]
17
  valid_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(img_size), tfms.A.Normalize()])
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  examples = [['sample_images/IMG_20191212_151351.jpg'],['sample_images/IMG_20191212_153420.jpg'],['sample_images/IMG_20191212_154100.jpg']]
20
 
 
 
21
  def show_preds(input_image):
 
 
22
  img = PIL.Image.fromarray(input_image, "RGB")
 
23
  pred_dict = model_type.end2end_detect(img, valid_tfms, model, class_map=class_map, detection_threshold=0.5,
24
  display_label=False, display_bbox=True, return_img=True,
25
  font_size=16, label_color="#FF59D6")
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  return pred_dict["img"], len(pred_dict["detection"]["bboxes"])
28
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  gr_interface = gr.Interface(
31
  fn=show_preds,
32
+ inputs=["image"],
33
+ outputs=[gr.outputs.Image(type="pil", label="RetinaNet Inference"), gr.outputs.Textbox(type="number", label="Microalgae Count")],
 
 
 
 
 
34
  title="Microalgae Detector with RetinaNet",
35
  description="This RetinaNet model counts microalgaes on a given image. Upload an image or click an example image below to use.",
36
+ article="<p style='text-align: center'><a href='https://dicksonneoh.com/' target='_blank'>Blog post</a></p>",
37
  examples=examples,
38
  )
 
 
 
 
39
  gr_interface.launch()