# Author: danielHora
# Commit dc3de32: "Update app.py"
from functools import lru_cache

import gradio as gr
from PIL import Image, ImageDraw, ImageFont
from transformers import DetrFeatureExtractor, DetrForObjectDetection
from transformers import pipeline
# Helper that draws a single detection's bounding box onto an image.
def draw_bounding_box(im, score, label, xmin, ymin, xmax, ymax, index, num_boxes,
                      *, outline='blue', width=2, radius=7):
    """Draw a rounded bounding box onto ``im`` in place and return it.

    Args:
        im: PIL image to draw on. It is modified in place.
        score: Detection confidence. Currently unused.
        label: Detected class label. Currently unused.
        xmin, ymin, xmax, ymax: Box corners, passed straight to
            ``ImageDraw.rounded_rectangle``.
        index: Position of this box among all detections. Currently unused.
        num_boxes: Total number of detections. Currently unused.
        outline: Outline colour of the box (default ``'blue'``).
        width: Outline width in pixels (default 2).
        radius: Corner radius in pixels (default 7).

    Returns:
        The same ``im`` object, with the rectangle drawn on it.
    """
    # ImageDraw.Draw writes directly into im's pixel data; no copy is made.
    draw = ImageDraw.Draw(im)
    draw.rounded_rectangle((xmin, ymin, xmax, ymax),
                           outline=outline, width=width, radius=radius)
    return im
# Unit price for each label that is sold. Detected labels not listed here still
# count as items but add nothing to the total.
# NOTE(review): currency is unconfirmed (original comment: "pesos? dunno").
ITEM_PRICES = {
    'apple': 25,
    'bottle': 15,
    'broccoli': 100,
    'orange': 20,
    'banana': 50,
}


@lru_cache(maxsize=1)
def _load_detector():
    """Build the DETR object-detection pipeline on first use and reuse it afterwards."""
    feature_extractor = DetrFeatureExtractor.from_pretrained('facebook/detr-resnet-50')
    model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-50')
    return pipeline("object-detection", model=model, feature_extractor=feature_extractor)


def object_classify(img):
    """Detect objects in ``img``, draw their boxes and tally the cart.

    Args:
        img: PIL image (the Gradio input uses ``type='pil'``). Boxes are drawn
            onto it in place.

    Returns:
        A tuple ``(annotated_image, price_total, item_count)``. Both counts are
        strings. ``item_count`` counts every detection, priced or not. If nothing
        is detected, the unmodified input image is returned with totals of "0".
    """
    # Loading the model is expensive, so it is done once per process instead of per request.
    detections = _load_detector()(img)
    num_boxes = len(detections)

    price_total = 0
    # Start from the input so an image is still returned when there are no detections.
    output_image = img
    for index, detection in enumerate(detections):
        price_total += ITEM_PRICES.get(detection['label'], 0)
        box = detection['box']
        output_image = draw_bounding_box(output_image, detection['score'], detection['label'],
                                         box['xmin'], box['ymin'],
                                         box['xmax'], box['ymax'],
                                         index, num_boxes)

    return output_image, str(price_total), str(num_boxes)
TITLE = 'Object Detection for Effective Self-Checkout in Grocery Shopping [Work In Progress]'
DESCRIPTION = 'A deep learning application using DETR model to reimagine self-checkout stores.'
EXAMPLES = ['ex1.jpg']

# The UI takes one PIL image and shows the annotated image, the cart's total
# price and the number of detected items.
_input_component = gr.inputs.Image(type='pil')
_output_components = [
    gr.outputs.Image(),
    gr.outputs.Textbox(label='Total Price: '),
    gr.outputs.Textbox(label='Total items in cart: '),
]

interface = gr.Interface(
    object_classify,
    _input_component,
    outputs=_output_components,
    examples=EXAMPLES,
    title=TITLE,
    description=DESCRIPTION,
    allow_flagging="never",
)
interface.launch()