File size: 7,076 Bytes
9657406
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a43dac7
9657406
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1482f3b
9657406
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import os
import glob

import numpy as np
import detectron2
import torchvision
import cv2
import torch

from detectron2 import model_zoo
from detectron2.data import Metadata
from detectron2.structures import BoxMode
from detectron2.utils.visualizer import Visualizer
from detectron2.config import get_cfg
from detectron2.utils.visualizer import ColorMode
from detectron2.modeling import build_model
import detectron2.data.transforms as T
from detectron2.checkpoint import DetectionCheckpointer

import matplotlib.pyplot as plt
from PIL import Image

# -----------------------------------------------------------------------------
# CONFIGS - executed once, at import time, so the model is built and its
# weights are loaded a single time rather than on every inference call.
#
# This is where all the relevant config-file and weight-file variables are
# set:
# CONFIG_FILE - Training-specific config file for FathomNet
# WEIGHTS_FILE - Path to the model file holding the FathomNet weights
# NMS_THRESH - Threshold for the NMS pass that run_inference applies across
#              all boxes (regardless of class)
# SCORE_THRESH - Model score threshold; detections below it are dropped

CONFIG_FILE = "fathomnet_config_v2_1280.yaml"
WEIGHTS_FILE = "model_final.pth"
NMS_THRESH = 0.45  # passed as the threshold argument of post_process_nms
SCORE_THRESH = 0.3  # written to cfg.MODEL.RETINANET.SCORE_THRESH_TEST below

# A metadata object that contains metadata on each class category; used with
# Detectron for linking predictions to names and for visualizations.
# convert_predictions indexes `thing_classes` with the predicted class id, so
# the order of this list matters (presumably it mirrors the training label
# order - confirm against the training config before editing).
fathomnet_metadata = Metadata(
    name='fathomnet_val',
    thing_classes=[
        'Anemone',
        'Fish',
        'Eel',
        'Gastropod',
        'Sea star',
        'Feather star',
        'Sea cucumber',
        'Urchin',
        'Glass sponge',
        'Sea fan',
        'Soft coral',
        'Sea pen',
        'Stony coral',
        'Ray',
        'Crab',
        'Shrimp',
        'Squat lobster',
        'Flatfish',
        'Sea spider',
        'Worm']
)

# This is where the model parameters are instantiated. There are a LOT of
# nested arguments in these yaml files; the baseline defaults from the model
# zoo are merged first and the dataset-specific file is merged on top.
base_model_path = "COCO-Detection/retinanet_R_50_FPN_3x.yaml"

cfg = get_cfg()
# NOTE(review): DEVICE is assigned before the two merges below, so a
# MODEL.DEVICE entry in either yaml file would replace 'cpu' - confirm that
# neither file sets it, or move this assignment after the merges.
cfg.MODEL.DEVICE = 'cpu'
cfg.merge_from_file(model_zoo.get_config_file(base_model_path))
cfg.merge_from_file(CONFIG_FILE)
# These two overrides come after the merges, so they take precedence over
# whatever the yaml files specify.
cfg.MODEL.RETINANET.SCORE_THRESH_TEST = SCORE_THRESH
cfg.MODEL.WEIGHTS = WEIGHTS_FILE

# Loading of the model weights, but more importantly this is where the model
# is actually instantiated as something that can take inputs and provide
# outputs. There is a lot of documentation about this, but not much in the
# way of straightforward tutorials.
model = build_model(cfg)
checkpointer = DetectionCheckpointer(model)
checkpointer.load(cfg.MODEL.WEIGHTS)
# Put the model in evaluation mode; run_inference calls it under
# torch.no_grad().
model.eval()

# Create two resize augmentations and collect them in a list; run_inference
# runs the model once per entry and pools the detections.
# aug1 uses the test-time sizes from the merged config.
aug1 = T.ResizeShortestEdge(short_edge_length=[cfg.INPUT.MIN_SIZE_TEST],
                            max_size=cfg.INPUT.MAX_SIZE_TEST,
                            sample_style="choice")

# aug2 uses fixed sizes: shortest edge 1080, longest edge capped at 1980.
# NOTE(review): 1980 may have been intended as 1920 - confirm.
aug2 = T.ResizeShortestEdge(short_edge_length=[1080],
                            max_size=1980,
                            sample_style="choice")

augmentations = [aug1, aug2]

# We use a separate NMS layer because initially detectron only does nms intra
# class, so we want to do nms on all boxes.
post_process_nms = torchvision.ops.nms
# -----------------------------------------------------------------------------


def run_inference(test_image):
    """Run the detection pipeline on a single image file.

    The image is read from disk, resized once per entry in ``augmentations``
    and passed through ``model`` at each size. The detections of all passes
    are pooled and filtered with one NMS pass over all boxes (threshold
    ``NMS_THRESH``), drawn onto the original image, and converted to
    TATOR-style dicts.

    Args:
        test_image: Path of the image file to process.

    Returns:
        A ``(predictions, out_pil)`` tuple. ``predictions`` is the list of
        dicts built by ``convert_predictions``; ``out_pil`` is an RGB PIL
        image of the input with the surviving detections drawn on it.

    Raises:
        OSError: If OpenCV cannot read ``test_image``.
    """
    # cv2.imread signals failure by returning None rather than raising, so
    # check explicitly instead of failing later on `.shape`.
    img = cv2.imread(test_image)
    if img is None:
        raise OSError(f"Could not read image: {test_image}")

    im_height, im_width = img.shape[:2]

    # OpenCV loads BGR; the visualizer is given the channel-reversed (RGB)
    # view of the original, un-resized image.
    v_inf = Visualizer(img[:, :, ::-1],
                       metadata=fathomnet_metadata,
                       scale=1.0,
                       instance_mode=ColorMode.SEGMENTATION)

    # One Instances object per augmentation. Passing the original height and
    # width asks the model to report boxes in original-image coordinates, so
    # the outputs of the different scales can be pooled directly.
    per_scale_instances = []
    for augmentation in augmentations:
        resized = augmentation.get_transform(img).apply_image(img)

        # HWC uint8 array -> CHW float32 tensor, as the model input.
        with torch.no_grad():
            image_tensor = torch.as_tensor(
                resized.astype("float32").transpose(2, 0, 1))
            model_outputs = model([{"image": image_tensor,
                                    "height": im_height,
                                    "width": im_width}])[0]

        per_scale_instances.append(model_outputs['instances'])

    # Concatenate whole per-scale outputs (rather than one-detection slices)
    # so that an image with zero detections still yields a valid, empty
    # Instances object instead of failing on an empty list.
    xx = detectron2.structures.instances.Instances.cat(per_scale_instances)

    # NMS over all pooled boxes irrespective of class; skipped when there is
    # nothing to suppress.
    if len(xx) > 0:
        keep = post_process_nms(xx.pred_boxes.tensor, xx.scores, NMS_THRESH)
        xx = xx[keep.to("cpu").tolist()]

    print(test_image + ' - Number of predictions:', len(xx))
    out_inf_raw = v_inf.draw_instance_predictions(xx.to("cpu"))
    out_pil = Image.fromarray(out_inf_raw.get_image()).convert('RGB')

    # Converting the predictions as output by Detectron2, to a TATOR format.
    predictions = convert_predictions(xx, v_inf.metadata.thing_classes)

    return predictions, out_pil


def convert_predictions(xx, thing_classes):
    """Convert Detectron2 detections into TATOR-style prediction dicts.

    Args:
        xx: Detections exposing ``pred_boxes.tensor`` (one
            ``x1, y1, x2, y2`` row per detection), ``pred_classes`` and
            ``scores``, all of the same length.
        thing_classes: Sequence of class names, indexed by predicted class
            id.

    Returns:
        A list with one dict per detection, in the order of ``xx``. Each dict
        has the keys ``x``, ``y`` (first box corner), ``width``, ``height``
        (box extent), ``class_category`` (name looked up in
        ``thing_classes``) and ``confidence`` (score as a float).
    """
    # Pull every field out of its tensor in one go instead of slicing a new
    # single-detection object per row.
    boxes = xx.pred_boxes.tensor.tolist()
    class_ids = xx.pred_classes.tolist()
    scores = xx.scores.tolist()

    return [
        {'x': x1,
         'y': y1,
         'width': x2 - x1,
         'height': y2 - y1,
         'class_category': thing_classes[int(class_id)],
         'confidence': float(score)}
        for (x1, y1, x2, y2), class_id, score in zip(boxes, class_ids, scores)
    ]


if __name__ == "__main__":

    # Demo driver: run inference on every PNG found directly under images/.
    # The returned predictions and annotated image are not stored anywhere
    # here; run_inference prints a per-image prediction count as it goes.
    for test_image in glob.glob("images/*.png"):
        predictions, out_img = run_inference(test_image)

    print("Done.")