import gradio as gr
import torch
from PIL import Image, ImageDraw, ImageFont

from transformers import AutoImageProcessor
from transformers import AutoModelForObjectDetection

# Note: the model can be loaded from the Hugging Face Hub or from a local path.
# Replace "mrdbourke" with your own username if the model lives on your Hugging Face account.
model_save_path = "mrdbourke/detr_finetuned_trashify_box_detector_with_data_aug"

# Load the model and preprocessor
image_processor = AutoImageProcessor.from_pretrained(model_save_path)
model = AutoModelForObjectDetection.from_pretrained(model_save_path)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Get the id2label dictionary from the model
id2label = model.config.id2label

# Set up a colour dictionary for plotting boxes with different colours
color_dict = {   
    "bin": "green",
    "trash": "blue",
    "hand": "purple",
    "trash_arm": "yellow",
    "not_trash": "red",
    "not_bin": "red",
    "not_hand": "red",
}

# Create helper functions for checking whether items from one list appear in another
def any_in_list(list_a, list_b):
    "Returns True if any item from list_a is in list_b, otherwise False."
    return any(item in list_b for item in list_a)

def all_in_list(list_a, list_b):
    "Returns True if all items from list_a are in list_b, otherwise False."
    return all(item in list_b for item in list_a)
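
# Quick sanity check of the helpers with hypothetical example values:
# "trash" is in the detected list, so any_in_list is True, but "hand" is
# missing, so all_in_list is False.
assert any_in_list(["trash", "bin", "hand"], ["trash", "bin"])
assert not all_in_list(["trash", "bin", "hand"], ["trash", "bin"])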

def filter_highest_scoring_box_per_class(boxes, labels, scores):
    """
    Perform NMS (Non-max Supression) to only keep the top scoring box per class.

    Args:
        boxes: tensor of shape (N, 4)
        labels: tensor of shape (N,)
        scores: tensor of shape (N,)
    Returns:
        boxes: tensor of shape (N, 4) filtered for max scoring item per class
        labels: tensor of shape (N,) filtered for max scoring item per class
        scores: tensor of shape (N,) filtered for max scoring item per class
    """
    # Start with a blank keep mask (e.g. all False and then update the boxes to keep with True)
    keep_mask = torch.zeros(len(boxes), dtype=torch.bool)

    # For each unique class
    for class_id in labels.unique():
        # Get the indices for the target class
        class_mask = labels == class_id

        # If any of the labels match the current class_id
        if class_mask.any():
            # Find the index of highest scoring box for this specific class
            class_scores = scores[class_mask]
            highest_score_idx = class_scores.argmax()

            # Convert back to the original index
            original_idx = torch.where(class_mask)[0][highest_score_idx]

            # Update the index in the keep mask to keep the highest scoring box 
            keep_mask[original_idx] = True
        
    return boxes[keep_mask], labels[keep_mask], scores[keep_mask]
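
# Illustrative example with hypothetical tensors (kept as a comment so it doesn't run in the app):
# boxes  = torch.tensor([[0., 0., 10., 10.], [1., 1., 9., 9.], [5., 5., 20., 20.]])
# labels = torch.tensor([0, 0, 1])   # two boxes of one class, one box of another
# scores = torch.tensor([0.9, 0.4, 0.8])
# filter_highest_scoring_box_per_class(boxes, labels, scores)
# -> keeps rows 0 and 2 (the highest scoring box for each class)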

def create_return_string(list_of_predicted_labels, conf_threshold, target_items=["trash", "bin", "hand"]):
    # Set up a blank string to print out
    return_string = ""

    # If no items are detected, or none of the target items are in the list, return a notification
    if (len(list_of_predicted_labels) == 0) or not (any_in_list(list_a=target_items, list_b=list_of_predicted_labels)):
        return_string = f"No trash, bin or hand detected at confidence threshold {conf_threshold}. Try another image or lowering the confidence threshold."
        return return_string

    # If there are some missing, print the ones which are missing
    elif not all_in_list(list_a=target_items, list_b=list_of_predicted_labels):
        missing_items = []
        for item in target_items:
            if item not in list_of_predicted_labels:
                missing_items.append(item)
        return_string = f"Detected the following items: {list_of_predicted_labels} (total: {len(list_of_predicted_labels)}). But missing the following in order to get +1: {missing_items}. If this is an error, try another image or altering the confidence threshold. Otherwise, the model may need to be updated with better data."
        
    # If all three target items (trash, bin, hand) are detected = +1
    elif all_in_list(list_a=target_items, list_b=list_of_predicted_labels):
        return_string = f"+1! Found the following items: {list_of_predicted_labels} (total: {len(list_of_predicted_labels)}), thank you for cleaning up the area!"

    print(return_string)

    return return_string
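
# Example outcomes (illustrative): if only "bin" is detected, the returned string lists the
# missing items (e.g. "... missing the following in order to get +1: ['trash', 'hand'] ..."),
# whereas if all of "trash", "bin" and "hand" are detected, the "+1!" success string is returned.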

def predict_on_image(image, conf_threshold):
    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt")
        outputs = model(**inputs.to(device))

        target_sizes = torch.tensor([[image.size[1], image.size[0]]]) # height, width 

        results = image_processor.post_process_object_detection(outputs,
                                                                threshold=conf_threshold,
                                                                target_sizes=target_sizes)[0]
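
    # post_process_object_detection returns a list with one dict per input image, each
    # containing "scores", "labels" and "boxes" (xyxy coordinates scaled to the original
    # image size via target_sizes), so [0] selects the result for our single image.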
    # Move all tensors in results (boxes, labels and scores) back to the CPU
    for key, value in results.items():
        results[key] = value.cpu()

    # Draw the results on the PIL image (so the predicted boxes can be displayed)
    draw = ImageDraw.Draw(image)

    # Create a copy of the image to draw the NMS-filtered results on
    image_nms = image.copy()
    draw_nms = ImageDraw.Draw(image_nms)

    # Get a font from ImageFont
    font = ImageFont.load_default(size=20)

    # Collect class names as text for the printout
    class_name_text_labels = []

    # Collect class names for the NMS-filtered predictions
    class_name_text_labels_nms = []

    # Get original boxes, scores, labels
    original_boxes = results["boxes"]
    original_labels = results["labels"]
    original_scores = results["scores"]

    # Filter boxes and only keep 1x of each label with highest score
    filtered_boxes, filtered_labels, filtered_scores = filter_highest_scoring_box_per_class(boxes=original_boxes,
                                                                                            labels=original_labels,
                                                                                            scores=original_scores)
    # Helper to draw a set of boxes and labels on a target image and collect the class names
    def draw_boxes(target_draw, boxes, labels, scores, label_collection):
        for box, label, score in zip(boxes, labels, scores):
            # Create coordinates
            x, y, x2, y2 = tuple(box.tolist())

            # Get the label name and its target colour
            label_name = id2label[label.item()]
            targ_color = color_dict[label_name]
            label_collection.append(label_name)

            # Draw the rectangle
            target_draw.rectangle(xy=(x, y, x2, y2),
                                  outline=targ_color,
                                  width=3)

            # Create a text string to display
            text_string_to_show = f"{label_name} ({round(score.item(), 3)})"

            # Draw the text on the image
            target_draw.text(xy=(x, y),
                             text=text_string_to_show,
                             fill="white",
                             font=font)

    # Draw the raw predictions on the original image and the filtered predictions on the copy
    draw_boxes(draw, original_boxes, original_labels, original_scores, class_name_text_labels)
    draw_boxes(draw_nms, filtered_boxes, filtered_labels, filtered_scores, class_name_text_labels_nms)

    # Remove the draw objects once finished
    del draw
    del draw_nms

    # Create the return strings
    return_string = create_return_string(list_of_predicted_labels=class_name_text_labels, conf_threshold=conf_threshold)
    return_string_nms = create_return_string(list_of_predicted_labels=class_name_text_labels_nms, conf_threshold=conf_threshold)
    
    return image, return_string, image_nms, return_string_nms

# Create the interface
demo = gr.Interface(
    fn=predict_on_image,
    inputs=[
        gr.Image(type="pil", label="Target Image"),
        gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence Threshold")
    ],
    outputs=[
        gr.Image(type="pil", label="Image Output (no filtering)"),
        gr.Text(label="Text Output (no filtering)"),
        gr.Image(type="pil", label="Image Output (with max score per class box filtering)"),
        gr.Text(label="Text Output (with max score per class box filtering)")
        
    ],
    title="🚮 Trashify Object Detection Demo V3",
    description="""Help clean up your local area! Upload an image and get +1 if there is all of the following items detected: trash, bin, hand.

    The model in V3 is [same model](https://huggingface.co/mrdbourke/detr_finetuned_trashify_box_detector_with_data_aug) as in [V2](https://huggingface.co/spaces/mrdbourke/trashify_demo_v2) (trained with data augmentation) but has an additional post-processing step (NMS or [Non Maximum Suppression](https://paperswithcode.com/method/non-maximum-suppression)) to filter classes for only the highest scoring box of each class. 
    """,
    # Examples come in the form of a list of lists, where each inner list contains elements to prefill the `inputs` parameter with
    examples=[
        ["examples/trashify_example_1.jpeg", 0.25],
        ["examples/trashify_example_2.jpeg", 0.25],
        ["examples/trashify_example_3.jpeg", 0.25]
    ],
    cache_examples=True
)

# Launch the demo
demo.launch()
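
# Note: by default demo.launch() serves the app locally (Gradio defaults to http://127.0.0.1:7860
# when run outside a hosted Space); passing share=True would also create a temporary public link.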