# Source: Hugging Face Space app.py uploaded by mrdbourke (commit d50f451, verified).
# Commit message: "Uploading Trashify box detection model v3 app.py with NMS post processing"
import gradio as gr
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoImageProcessor
from transformers import AutoModelForObjectDetection
# Model identifier: either a Hugging Face Hub repo id or a local directory.
# Swap {mrdbourke} for your own username if the model lives on your Hugging Face account.
model_save_path = "mrdbourke/detr_finetuned_trashify_box_detector_with_data_aug"

# Prefer the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the preprocessor and the detection model, placing the model on the target device
image_processor = AutoImageProcessor.from_pretrained(model_save_path)
model = AutoModelForObjectDetection.from_pretrained(model_save_path).to(device)

# Mapping of class index -> class name, taken from the model config
id2label = model.config.id2label

# Box outline colour per class name (every "not_*" class shares red)
color_dict = {
    "bin": "green",
    "trash": "blue",
    "hand": "purple",
    "trash_arm": "yellow",
    **dict.fromkeys(("not_trash", "not_bin", "not_hand"), "red"),
}
# Create helper functions for seeing if items from one list are in another
def any_in_list(list_a, list_b):
    """Return True as soon as one item of list_a is found in list_b, else False.

    An empty list_a yields False.
    """
    for candidate in list_a:
        if candidate in list_b:
            return True
    return False
def all_in_list(list_a, list_b):
    """Return True only when every item of list_a is found in list_b, else False.

    An empty list_a yields True.
    """
    for candidate in list_a:
        if candidate not in list_b:
            return False
    return True
def filter_highest_scoring_box_per_class(boxes, labels, scores):
    """
    Keep only the single highest scoring box for each predicted class.

    Note: this is a simplified per-class "top-1" filter. Unlike IoU-based
    Non-Maximum Suppression, box overlap is never inspected.

    Args:
        boxes: tensor of boxes, expected shape (N, 4)
        labels: tensor of class ids, expected shape (N,)
        scores: tensor of confidence scores, expected shape (N,)
        All three are assumed to live on the same device.

    Returns:
        boxes: tensor of shape (K, 4), one row per unique class in `labels`
        labels: tensor of shape (K,)
        scores: tensor of shape (K,)
        Kept rows retain their original relative order. Empty inputs produce
        empty outputs.
    """
    # Start with an all-False keep mask on the same device as the inputs,
    # then flip the winning index of each class to True.
    keep_mask = torch.zeros(len(boxes), dtype=torch.bool, device=labels.device)

    # labels.unique() only yields ids that are present, so each class has >= 1 box
    for class_id in labels.unique():
        # Positions (in the original tensors) of every box of this class
        class_indices = torch.where(labels == class_id)[0]

        # argmax is relative to the class subset, so map it back to the original index
        best_idx = class_indices[scores[class_indices].argmax()]
        keep_mask[best_idx] = True

    return boxes[keep_mask], labels[keep_mask], scores[keep_mask]
def create_return_string(list_of_predicted_labels, target_items=("trash", "bin", "hand"), conf_threshold=None):
    """
    Build the user-facing message describing which target items were detected.

    Args:
        list_of_predicted_labels: list of predicted class-name strings.
        target_items: class names that must all be present to earn "+1".
        conf_threshold: confidence threshold the predictions were made with.
            Only used to make the "nothing detected" message more specific;
            when None the message refers to "the current confidence threshold".

    Returns:
        A string that is one of:
        - a "No trash, bin or hand detected ..." notice when none of the
          target items were predicted (including an empty prediction list),
        - a "Detected the following items ... But missing ..." notice when
          only some target items were predicted,
        - a "+1! ..." message when every target item was predicted.
        The latter two messages are also printed.
    """
    # Target items that do not appear anywhere in the predictions
    missing_items = [item for item in target_items if item not in list_of_predicted_labels]

    # Nothing relevant detected (this also covers an empty list of predictions)
    if len(missing_items) == len(target_items):
        if conf_threshold is None:
            threshold_text = "the current confidence threshold"
        else:
            threshold_text = f"confidence threshold {conf_threshold}"
        return f"No trash, bin or hand detected at {threshold_text}. Try another image or lowering the confidence threshold."

    if missing_items:
        # Some (but not all) target items were found
        return_string = f"Detected the following items: {list_of_predicted_labels} (total: {len(list_of_predicted_labels)}). But missing the following in order to get +1: {missing_items}. If this is an error, try another image or altering the confidence threshold. Otherwise, the model may need to be updated with better data."
    else:
        # All target items occur = +1
        return_string = f"+1! Found the following items: {list_of_predicted_labels} (total: {len(list_of_predicted_labels)}), thank you for cleaning up the area!"

    print(return_string)
    return return_string


def _draw_detections(image, boxes, labels, scores, font):
    """Draw each box and its "label (score)" caption onto `image` in place; return the label names."""
    draw = ImageDraw.Draw(image)
    label_names = []

    for box, label, score in zip(boxes, labels, scores):
        # Box corner coordinates
        x, y, x2, y2 = tuple(box.tolist())

        # Resolve the class name and its outline colour
        label_name = id2label[label.item()]
        targ_color = color_dict[label_name]
        label_names.append(label_name)

        # Draw the rectangle
        draw.rectangle(xy=(x, y, x2, y2),
                       outline=targ_color,
                       width=3)

        # Draw the caption at the top-left corner of the box
        text_string_to_show = f"{label_name} ({round(score.item(), 3)})"
        draw.text(xy=(x, y),
                  text=text_string_to_show,
                  fill="white",
                  font=font)

    return label_names


def predict_on_image(image, conf_threshold):
    """
    Run the detection model on an image and return annotated images plus messages.

    Args:
        image: PIL image to run detection on (left unmodified; drawing happens on copies).
        conf_threshold: minimum score passed as `threshold` to the image
            processor's post-processing step.

    Returns:
        A 4-tuple of
        (image with every detected box drawn,
         message for every detected box,
         image with only the highest scoring box per class drawn,
         message for the highest scoring box per class).
        If `image` is None, both images are None and both messages ask for an image.
    """
    # Nothing to predict on (e.g. the demo was submitted without an image)
    if image is None:
        no_image_string = "No image provided. Please upload an image and try again."
        return None, no_image_string, None, no_image_string

    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt")
        outputs = model(**inputs.to(device))

        # PIL reports (width, height); post-processing is given (height, width)
        target_sizes = torch.tensor([[image.size[1], image.size[0]]])

        results = image_processor.post_process_object_detection(outputs,
                                                                threshold=conf_threshold,
                                                                target_sizes=target_sizes)[0]

    # Return all items in results to CPU
    results = {key: value.cpu() for key, value in results.items()}

    # Get original boxes, scores, labels
    original_boxes = results["boxes"]
    original_labels = results["labels"]
    original_scores = results["scores"]

    # Filter boxes and only keep 1x of each label with highest score
    filtered_boxes, filtered_labels, filtered_scores = filter_highest_scoring_box_per_class(boxes=original_boxes,
                                                                                            labels=original_labels,
                                                                                            scores=original_scores)

    # Draw on copies so the input image is not mutated
    image_all = image.copy()
    image_nms = image.copy()

    # Get a font from ImageFont
    font = ImageFont.load_default(size=20)

    # Plot the boxes and collect the class names as text for the print out
    class_name_text_labels = _draw_detections(image=image_all,
                                              boxes=original_boxes,
                                              labels=original_labels,
                                              scores=original_scores,
                                              font=font)
    class_name_text_labels_nms = _draw_detections(image=image_nms,
                                                  boxes=filtered_boxes,
                                                  labels=filtered_labels,
                                                  scores=filtered_scores,
                                                  font=font)

    # Create the return strings
    return_string = create_return_string(list_of_predicted_labels=class_name_text_labels,
                                         conf_threshold=conf_threshold)
    return_string_nms = create_return_string(list_of_predicted_labels=class_name_text_labels_nms,
                                             conf_threshold=conf_threshold)

    return image_all, return_string, image_nms, return_string_nms
# Input widgets: the image to run detection on and the score cut-off
demo_inputs = [
    gr.Image(type="pil", label="Target Image"),
    gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence Threshold"),
]

# Output widgets: annotated image + message, without and with per-class filtering
demo_outputs = [
    gr.Image(type="pil", label="Image Output (no filtering)"),
    gr.Text(label="Text Output (no filtering)"),
    gr.Image(type="pil", label="Image Output (with max score per class box filtering)"),
    gr.Text(label="Text Output (with max score per class box filtering)"),
]

demo_description = """Help clean up your local area! Upload an image and get +1 if there is all of the following items detected: trash, bin, hand.
The model in V3 is [same model](https://huggingface.co/mrdbourke/detr_finetuned_trashify_box_detector_with_data_aug) as in [V2](https://huggingface.co/spaces/mrdbourke/trashify_demo_v2) (trained with data augmentation) but has an additional post-processing step (NMS or [Non Maximum Suppression](https://paperswithcode.com/method/non-maximum-suppression)) to filter classes for only the highest scoring box of each class.
"""

# Examples are a list of lists; each inner list prefills the `inputs` widgets in order
example_threshold = 0.25
example_paths = (
    "examples/trashify_example_1.jpeg",
    "examples/trashify_example_2.jpeg",
    "examples/trashify_example_3.jpeg",
)
demo_examples = [[example_path, example_threshold] for example_path in example_paths]

# Assemble the interface around the prediction function
demo = gr.Interface(
    fn=predict_on_image,
    inputs=demo_inputs,
    outputs=demo_outputs,
    title="🚮 Trashify Object Detection Demo V3",
    description=demo_description,
    examples=demo_examples,
    cache_examples=True,
)

# Start serving the demo
demo.launch()