Jordan Pierce committed
Commit 6300104
1 Parent(s): 0dd9a9d

updated app

Files changed (2)
  1. .idea/workspace.xml +2 -3
  2. app.py +5 -186
.idea/workspace.xml CHANGED
@@ -4,9 +4,7 @@
     <option name="SCOPE_TYPE" value="3" />
   </component>
   <component name="ChangeListManager">
-    <list default="true" id="e4d9959f-0c5b-4a80-8b43-a006df26f93a" name="Changes" comment="">
-      <change beforePath="$PROJECT_DIR$/app.py" beforeDir="false" afterPath="$PROJECT_DIR$/app.py" afterDir="false" />
-    </list>
+    <list default="true" id="e4d9959f-0c5b-4a80-8b43-a006df26f93a" name="Changes" comment="" />
     <option name="SHOW_DIALOG" value="false" />
     <option name="HIGHLIGHT_CONFLICTS" value="true" />
     <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
@@ -54,6 +52,7 @@
     <updated>1666646148268</updated>
     <workItem from="1666646160771" duration="3554000" />
     <workItem from="1666649813029" duration="426000" />
+    <workItem from="1666707251602" duration="12000" />
   </task>
   <servers />
 </component>
app.py CHANGED
@@ -1,200 +1,19 @@
-try:
-    import detectron2
-except:
-    import os
-    os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
-
-import glob
-
-import numpy as np
-import detectron2
-import torchvision
-import cv2
-import torch
-
-from detectron2 import model_zoo
-from detectron2.data import Metadata
-from detectron2.structures import BoxMode
-from detectron2.utils.visualizer import Visualizer
-from detectron2.config import get_cfg
-from detectron2.utils.visualizer import ColorMode
-from detectron2.modeling import build_model
-import detectron2.data.transforms as T
-from detectron2.checkpoint import DetectionCheckpointer
-
+from inference import *
 import gradio as gr
-from PIL import Image
-
+import glob
-# -----------------------------------------------------------------------------
-# CONFIGS - loaded just once, when the script is first run, to save time.
-#
-# This is where you will set all the relevant config-file and weight-file
-# variables:
-#   CONFIG_FILE  - Training-specific config file for FathomNet
-#   WEIGHTS_FILE - Path to the model with FathomNet weights
-#   NMS_THRESH   - The NMS threshold applied across all box results
-#   SCORE_THRESH - The model score (confidence) threshold
-
-CONFIG_FILE = "fathomnet_config_v2_1280.yaml"
-WEIGHTS_FILE = "model_final.pth"
-NMS_THRESH = 0.45
-SCORE_THRESH = 0.3
-
-# A Metadata object that contains metadata on each class category; used with
-# Detectron2 for linking predictions to names and for visualizations.
-fathomnet_metadata = Metadata(
-    name='fathomnet_val',
-    thing_classes=[
-        'Anemone',
-        'Fish',
-        'Eel',
-        'Gastropod',
-        'Sea star',
-        'Feather star',
-        'Sea cucumber',
-        'Urchin',
-        'Glass sponge',
-        'Sea fan',
-        'Soft coral',
-        'Sea pen',
-        'Stony coral',
-        'Ray',
-        'Crab',
-        'Shrimp',
-        'Squat lobster',
-        'Flatfish',
-        'Sea spider',
-        'Worm']
-)
-
-# This is where the model parameters are instantiated. There are a LOT of
-# nested arguments in these yaml files, merging baseline defaults with
-# dataset-specific parameters.
-base_model_path = "COCO-Detection/retinanet_R_50_FPN_3x.yaml"
-
-cfg = get_cfg()
-cfg.MODEL.DEVICE = 'cpu'
-cfg.merge_from_file(model_zoo.get_config_file(base_model_path))
-cfg.merge_from_file(CONFIG_FILE)
-cfg.MODEL.RETINANET.SCORE_THRESH_TEST = SCORE_THRESH
-cfg.MODEL.WEIGHTS = WEIGHTS_FILE
-
-# Load the model weights; more importantly, this is where the model is
-# actually instantiated as something that can take inputs and provide
-# outputs. There is a lot of documentation about this, but not much in the
-# way of straightforward tutorials.
-model = build_model(cfg)
-checkpointer = DetectionCheckpointer(model)
-checkpointer.load(cfg.MODEL.WEIGHTS)
-model.eval()
-
-# Create two augmentations and make a list to iterate over
-aug1 = T.ResizeShortestEdge(short_edge_length=[cfg.INPUT.MIN_SIZE_TEST],
-                            max_size=cfg.INPUT.MAX_SIZE_TEST,
-                            sample_style="choice")
-
-aug2 = T.ResizeShortestEdge(short_edge_length=[1080],
-                            max_size=1980,
-                            sample_style="choice")
-
-augmentations = [aug1, aug2]
-
-# We use a separate NMS step because Detectron2 only does NMS within each
-# class, and we want to do NMS across all boxes.
-post_process_nms = torchvision.ops.nms
-# -----------------------------------------------------------------------------
-
-
-def run_inference(test_image):
-    """Run the inference pipeline on a single input image. The image is
-    opened, augmented, and run through the model, which outputs bounding
-    boxes and class categories for each object detected. These are then
-    passed back to the calling function."""
-
-    # Load the image and get its height and width. For each augmentation:
-    # apply it, run the model, perform NMS thresholding, and collect the
-    # outputs into a list; also instantiate a Visualizer for drawing them.
-    img = cv2.imread(test_image)
-    im_height, im_width, _ = img.shape
-    v_inf = Visualizer(img[:, :, ::-1],
-                       metadata=fathomnet_metadata,
-                       scale=1.0,
-                       instance_mode=ColorMode.IMAGE_BW)
-
-    insts = []
-
-    # Iterate over input augmentations (apply resizing)
-    for augmentation in augmentations:
-        im = augmentation.get_transform(img).apply_image(img)
-
-        # Pre-process the image by reshaping and converting it to a tensor,
-        # then pass it to the model, which outputs a dict containing info on
-        # all detections
-        with torch.no_grad():
-            im = torch.as_tensor(im.astype("float32").transpose(2, 0, 1))
-            model_outputs = model([{"image": im,
-                                    "height": im_height,
-                                    "width": im_width}])[0]

-        # Populate the list with all outputs
-        for _ in range(len(model_outputs['instances'])):
-            insts.append(model_outputs['instances'][_])
-
-    # TODO explore the outputs to determine what needs to be passed to tator.py
-    # Concatenate the model outputs and run NMS thresholding on all of them;
-    # instantiate a dummy Instances object to concatenate the instances
-    model_inst = detectron2.structures.instances.Instances([im_height,
-                                                            im_width])
 
-    xx = model_inst.cat(insts)[
-        post_process_nms(model_inst.cat(insts).pred_boxes.tensor,
-                         model_inst.cat(insts).scores,
-                         NMS_THRESH).to("cpu").tolist()]
+def gradio_app(image_path):
+    """Helper function to run inference on the provided image"""
 
-    out_inf_raw = v_inf.draw_instance_predictions(xx.to("cpu"))
-    out_pil = Image.fromarray(out_inf_raw.get_image()).convert('RGB')
+    predictions, out_pil = run_inference(image_path)
 
     return out_pil
 
 
-def convert_predictions(xx, thing_classes):
-    """Helper function to post-process the predictions made by the Detectron2
-    codebase to work with TATOR input requirements."""
-
-    predictions = []
-
-    for _ in range(len(xx)):
-
-        # Obtain each prediction instance
-        instance = xx.__getitem__(_)
-
-        # Map the coordinates to the variables
-        x, y, x2, y2 = map(float, instance.pred_boxes.tensor[0])
-        w, h = x2 - x, y2 - y
-
-        # Use the class list to get the common name (string); get the
-        # confidence score.
-        class_category = thing_classes[int(instance.pred_classes[0])]
-        confidence_score = float(instance.scores[0])
-
-        # Create a spec dict for TATOR
-        prediction = {'x': x,
-                      'y': y,
-                      'width': w,
-                      'height': h,
-                      'class_category': class_category,
-                      'confidence': confidence_score}
-
-        predictions.append(prediction)
-
-    return predictions
-
-
 # -----------------------------------------------------------------------------
 # GRADIO APP
 # -----------------------------------------------------------------------------
 
-examples = [glob.glob("images/*.png")]
-
 title = "MBARI Monterey Bay Benthic Supercategory"
 description = "Gradio demo for MBARI Monterey Bay Benthic Supercategory: This " \
               "is a RetinaNet model fine-tuned from the Detectron2 object " \
 
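After this commit, app.py keeps only the Gradio glue: everything removed above presumably moves behind `from inference import *`, and the new `gradio_app` unpacks two values from `run_inference`, whereas the removed in-file version returned only the visualized image. The `inference` module itself is not part of this commit, so the following is a minimal sketch of the contract implied by the removed `run_inference` and `convert_predictions`, not the actual file:

# inference.py - hypothetical sketch; the real module is only visible in this
# commit as `from inference import *`.
from typing import List, Tuple

from PIL import Image


def run_inference(image_path: str) -> Tuple[List[dict], Image.Image]:
    """Assumed contract: run the Detectron2 pipeline removed from app.py and
    return (TATOR-style prediction dicts, visualized PIL image)."""
    predictions: List[dict] = []        # each dict: {'x', 'y', 'width', 'height',
                                        #             'class_category', 'confidence'}
    out_pil = Image.new('RGB', (1, 1))  # stand-in for the drawn predictions
    return predictions, out_pil

Returning the prediction dicts alongside the image would match the removed `convert_predictions` helper, which built exactly that spec for TATOR.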
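One detail from the removed code worth keeping in mind when reading the new module: `post_process_nms = torchvision.ops.nms` existed because Detectron2 only does NMS within each class, while the app wants suppression across all classes at `NMS_THRESH = 0.45`. A self-contained illustration of that class-agnostic behavior (the boxes and scores below are made-up values):

import torch
import torchvision

# torchvision.ops.nms ignores class labels entirely, so two overlapping boxes
# are deduplicated even when a detector assigned them different categories.
boxes = torch.tensor([[10., 10., 50., 50.],        # detection A
                      [12., 12., 52., 52.],        # overlaps A (IoU ~ 0.82)
                      [100., 100., 140., 140.]])   # disjoint detection
scores = torch.tensor([0.9, 0.8, 0.7])

keep = torchvision.ops.nms(boxes, scores, iou_threshold=0.45)
print(keep)  # tensor([0, 2]) - the lower-scoring overlap is suppressed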
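Finally, the diff is cut off inside the `description` string, so the Interface wiring itself is not visible here. A plausible completion under the Gradio 3.x API, reusing the names defined above (`gradio_app`, `title`, `description`); the input/output types and the launch call are guesses rather than part of the commit, while the examples glob mirrors the line removed above:

import glob

import gradio as gr

# Hypothetical completion of the truncated tail of app.py.
examples = glob.glob("images/*.png")

gr.Interface(fn=gradio_app,
             inputs=gr.Image(type="filepath"),   # gradio_app takes an image path
             outputs=gr.Image(type="pil"),       # and returns a PIL image
             examples=examples,
             title=title,
             description=description).launch()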