dnth commited on
Commit
4eb4903
1 Parent(s): f7966b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -23
app.py CHANGED
@@ -17,7 +17,6 @@ import os
17
  # Load model
18
  checkpoint_path = "models/model_checkpoint.pth"
19
  checkpoint_and_model = model_from_checkpoint(checkpoint_path)
20
-
21
  model = checkpoint_and_model["model"]
22
  model_type = checkpoint_and_model["model_type"]
23
  class_map = checkpoint_and_model["class_map"]
@@ -26,18 +25,13 @@ class_map = checkpoint_and_model["class_map"]
26
  img_size = checkpoint_and_model["img_size"]
27
  valid_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(img_size), tfms.A.Normalize()])
28
 
29
-
30
  for root, dirs, files in os.walk(r"sample_images/"):
31
  for filename in files:
32
- print(filename)
33
 
34
- examples = ["sample_images/" + file for file in files]
35
- article = "<p style='text-align: center'><a href='https://dicksonneoh.com/' target='_blank'>Blog post</a></p>"
36
- enable_queue = True
37
 
38
  # Populate examples in Gradio interface
39
  example_images = [["sample_images/" + file] for file in files]
40
-
41
  # Columns: Input Image | Label | Box | Detection Threshold
42
  examples = [
43
  [example_images[0], False, True, 0.5],
@@ -46,17 +40,15 @@ examples = [
46
  [example_images[3], True, True, 0.7],
47
  [example_images[4], False, True, 0.5],
48
  [example_images[5], False, True, 0.5],
49
- [example_images[6], False, True, 0.5],
50
- [example_images[7], False, True, 0.5],
51
  ]
52
 
53
- def show_preds(input_image, display_label, display_bbox, detection_threshold):
54
 
 
55
  if detection_threshold == 0:
56
  detection_threshold = 0.5
57
-
58
  img = PIL.Image.fromarray(input_image, "RGB")
59
-
60
  pred_dict = model_type.end2end_detect(
61
  img,
62
  valid_tfms,
@@ -69,22 +61,21 @@ def show_preds(input_image, display_label, display_bbox, detection_threshold):
69
  font_size=16,
70
  label_color="#FF59D6",
71
  )
72
-
73
  return pred_dict["img"], len(pred_dict["detection"]["bboxes"])
74
 
75
 
76
  # display_chkbox = gr.inputs.CheckboxGroup(["Label", "BBox"], label="Display", default=True)
77
  display_chkbox_label = gr.inputs.Checkbox(label="Label", default=False)
78
  display_chkbox_box = gr.inputs.Checkbox(label="Box", default=True)
79
-
80
  detection_threshold_slider = gr.inputs.Slider(
81
  minimum=0, maximum=1, step=0.1, default=0.5, label="Detection Threshold"
82
  )
83
-
84
  outputs = [
85
  gr.outputs.Image(type="pil", label="RetinaNet Inference"),
86
- gr.outputs.Textbox(type='number', label='Microalgae Count')
87
- ]
 
 
88
 
89
  # Option 1: Get an image from local drive
90
  gr_interface = gr.Interface(
@@ -101,13 +92,8 @@ gr_interface = gr.Interface(
101
  article=article,
102
  examples=examples,
103
  )
104
-
105
-
106
  # # Option 2: Grab an image from a webcam
107
  # gr_interface = gr.Interface(fn=show_preds, inputs=["webcam", display_chkbox_label, display_chkbox_box, detection_threshold_slider], outputs=outputs, title='IceApp - COCO', live=False)
108
-
109
  # # Option 3: Continuous image stream from the webcam
110
  # gr_interface = gr.Interface(fn=show_preds, inputs=["webcam", display_chkbox_label, display_chkbox_box, detection_threshold_slider], outputs=outputs, title='IceApp - COCO', live=True)
111
-
112
-
113
- gr_interface.launch(inline=False, share=False, debug=True)
 
17
  # Load model
18
  checkpoint_path = "models/model_checkpoint.pth"
19
  checkpoint_and_model = model_from_checkpoint(checkpoint_path)
 
20
  model = checkpoint_and_model["model"]
21
  model_type = checkpoint_and_model["model_type"]
22
  class_map = checkpoint_and_model["class_map"]
 
25
  img_size = checkpoint_and_model["img_size"]
26
  valid_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(img_size), tfms.A.Normalize()])
27
 
 
28
  for root, dirs, files in os.walk(r"sample_images/"):
29
  for filename in files:
30
+ print("Loading sample image:", filename)
31
 
 
 
 
32
 
33
  # Populate examples in Gradio interface
34
  example_images = [["sample_images/" + file] for file in files]
 
35
  # Columns: Input Image | Label | Box | Detection Threshold
36
  examples = [
37
  [example_images[0], False, True, 0.5],
 
40
  [example_images[3], True, True, 0.7],
41
  [example_images[4], False, True, 0.5],
42
  [example_images[5], False, True, 0.5],
43
+ [example_images[6], False, True, 0.6],
44
+ [example_images[7], False, True, 0.6],
45
  ]
46
 
 
47
 
48
+ def show_preds(input_image, display_label, display_bbox, detection_threshold):
49
  if detection_threshold == 0:
50
  detection_threshold = 0.5
 
51
  img = PIL.Image.fromarray(input_image, "RGB")
 
52
  pred_dict = model_type.end2end_detect(
53
  img,
54
  valid_tfms,
 
61
  font_size=16,
62
  label_color="#FF59D6",
63
  )
 
64
  return pred_dict["img"], len(pred_dict["detection"]["bboxes"])
65
 
66
 
67
  # display_chkbox = gr.inputs.CheckboxGroup(["Label", "BBox"], label="Display", default=True)
68
  display_chkbox_label = gr.inputs.Checkbox(label="Label", default=False)
69
  display_chkbox_box = gr.inputs.Checkbox(label="Box", default=True)
 
70
  detection_threshold_slider = gr.inputs.Slider(
71
  minimum=0, maximum=1, step=0.1, default=0.5, label="Detection Threshold"
72
  )
 
73
  outputs = [
74
  gr.outputs.Image(type="pil", label="RetinaNet Inference"),
75
+ gr.outputs.Textbox(type="number", label="Microalgae Count"),
76
+ ]
77
+
78
+ article = "<p style='text-align: center'><a href='https://dicksonneoh.com/' target='_blank'>Blog post</a></p>"
79
 
80
  # Option 1: Get an image from local drive
81
  gr_interface = gr.Interface(
 
92
  article=article,
93
  examples=examples,
94
  )
 
 
95
  # # Option 2: Grab an image from a webcam
96
  # gr_interface = gr.Interface(fn=show_preds, inputs=["webcam", display_chkbox_label, display_chkbox_box, detection_threshold_slider], outputs=outputs, title='IceApp - COCO', live=False)
 
97
  # # Option 3: Continuous image stream from the webcam
98
  # gr_interface = gr.Interface(fn=show_preds, inputs=["webcam", display_chkbox_label, display_chkbox_box, detection_threshold_slider], outputs=outputs, title='IceApp - COCO', live=True)
99
+ gr_interface.launch(inline=False, share=False, debug=True)