khizon's picture
Update app.py
c182d0a
raw
history blame
No virus
2.38 kB
from transformers import pipeline
from transformers import DetrFeatureExtractor, DetrForObjectDetection
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
# Load the feature extractor and the detection model from the same checkpoint
_DETR_CHECKPOINT = 'facebook/detr-resnet-50'
feature_extractor = DetrFeatureExtractor.from_pretrained(_DETR_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(_DETR_CHECKPOINT)

# Wrap both in an object-detection pipeline that detect_image() can call directly
object_detector = pipeline(
    "object-detection",
    model=model,
    feature_extractor=feature_extractor,
)
# Draw bounding box definition
def draw_bounding_box(im, score, label, xmin, ymin, xmax, ymax, index, num_boxes):
    """Draw one rounded bounding box on ``im`` and return ``im``.

    The outline colour is chosen from ``label``: 'truck', 'car', 'motorcycle'
    and 'bus' are drawn in red, 'person' and 'bicycle' in green, and any other
    label in blue.

    Args:
        im: Image handed to ``ImageDraw.Draw``; the rectangle is drawn onto it.
        score: Accepted for interface compatibility; not used here.
        label: Detected class label, used only to pick the outline colour.
        xmin: Left edge of the box.
        ymin: Top edge of the box.
        xmax: Right edge of the box.
        ymax: Bottom edge of the box.
        index: Accepted for interface compatibility; not used here.
        num_boxes: Accepted for interface compatibility; not used here.

    Returns:
        The same ``im`` object that was passed in.
    """
    # Lookup table instead of an if/elif chain; unknown labels fall back to blue.
    outline_by_label = {
        'truck': 'red',
        'car': 'red',
        'motorcycle': 'red',
        'bus': 'red',
        'person': 'green',
        'bicycle': 'green',
    }
    outline = outline_by_label.get(label, 'blue')

    # Draw the actual bounding box
    im_with_rectangle = ImageDraw.Draw(im)
    im_with_rectangle.rounded_rectangle(
        (xmin, ymin, xmax, ymax), outline=outline, width=3, radius=10
    )

    # Return the result
    return im
def detect_image(im):
    """Run object detection on ``im`` and outline the tracked objects.

    Only detections labelled 'person', 'motorcycle', 'bicycle', 'truck', 'car'
    or 'bus' are drawn; every other detection is skipped.

    Args:
        im: Image passed to ``object_detector`` and to ``draw_bounding_box``.

    Returns:
        The image returned by the last ``draw_bounding_box`` call, or ``im``
        itself when no detection carries a tracked label.
    """
    tracked_labels = {'person', 'motorcycle', 'bicycle', 'truck', 'car', 'bus'}

    # Perform object detection
    bounding_boxes = object_detector(im)
    num_boxes = len(bounding_boxes)

    # Start from the input image so the function still returns a value when
    # nothing is drawn (no detections, or none with a tracked label).
    output_image = im

    # Draw bounding box for each tracked result; index is the detection's
    # position within bounding_boxes.
    for index, bounding_box in enumerate(bounding_boxes):
        if bounding_box['label'] not in tracked_labels:
            continue
        box = bounding_box['box']
        output_image = draw_bounding_box(
            im,
            bounding_box['score'],
            bounding_box['label'],
            box['xmin'], box['ymin'],
            box['xmax'], box['ymax'],
            index, num_boxes,
        )
    return output_image
# Title and description passed to the Gradio interface
TITLE = 'Active Transport Detection'
DESCRIPTION = 'This uses DETR as an object detection model and detects motor vehicles (red) and people and bikes (green). Much fine-tuning and optimization is still needed to make this a practical application'

# Example inputs: one row per image file, one column for the single image input
examples = [['bike.jpg'], ['bike2.jpg'], ['bike_3.jpg'], ['bike_4.jpg']]

# Build and launch in two steps so that `iface` names the Interface object
# rather than whatever launch() returns.
iface = gr.Interface(
    detect_image,
    gr.inputs.Image(type='pil'),
    gr.outputs.Image(),
    examples=examples,
    allow_flagging='never',
    title=TITLE,
    description=DESCRIPTION,
)
iface.launch(debug=True)