# Author: EVad
# Commit 2ffd98b: "Added app file."
import functools
import io
import os
import pathlib

import gradio as gr
import matplotlib.pyplot as plt
import requests
import torch
import validators
from PIL import Image
from transformers import AutoFeatureExtractor, DetrForObjectDetection
# Defining functions for the code
def make_prediction(img, feature_extractor, model):
    """Run DETR object detection on a single image.

    Args:
        img: PIL image to analyse; its ``size`` attribute is read as
            ``(width, height)``.
        feature_extractor: Hugging Face DETR feature extractor / image
            processor used for pre- and post-processing.
        model: ``DetrForObjectDetection`` model.

    Returns:
        The post-processed detections for this image: a dict with
        ``"scores"``, ``"labels"`` and ``"boxes"`` tensors, with boxes scaled
        to ``img``'s size.
    """
    inputs = feature_extractor(img, return_tensors="pt")
    # Inference only: skip building the autograd graph to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)
    # PIL reports (width, height); post-processing expects (height, width).
    img_size = torch.tensor([tuple(reversed(img.size))])
    post_process = getattr(feature_extractor, "post_process_object_detection", None)
    if post_process is not None:
        # `post_process` is deprecated upstream; threshold=0. keeps every query,
        # matching the legacy output (filtering happens at visualization time).
        processed_outputs = post_process(outputs, threshold=0.0, target_sizes=img_size)
    else:
        processed_outputs = feature_extractor.post_process(outputs, img_size)
    return processed_outputs[0]
@functools.lru_cache(maxsize=1)
def _load_detr():
    """Load the DETR feature extractor and model once and reuse them across calls."""
    feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/detr-resnet-50")
    model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
    return feature_extractor, model


def detect_objects(url_input):
    """Download an image from a URL, run DETR on it, and draw the detections.

    Args:
        url_input: Image URL entered by the user (surrounding whitespace is
            ignored).

    Returns:
        A PIL image of the input with detections scoring above 0.7 drawn on it.

    Raises:
        gr.Error: If the URL is invalid, cannot be fetched, or does not point
            to a decodable image (shown to the user by Gradio).
    """
    url_input = (url_input or "").strip()
    if not validators.url(url_input):
        raise gr.Error("Please enter a valid image URL.")

    try:
        # Bounded timeout so an unresponsive host cannot hang the worker forever.
        response = requests.get(url_input, timeout=30)
        response.raise_for_status()
    except requests.RequestException as err:
        raise gr.Error(f"Could not download the image: {err}") from err

    try:
        # Decode from the fully downloaded body (`.content` honours gzip
        # content-encoding, `.raw` does not) and force RGB so RGBA/greyscale/
        # palette images become 3-channel like the model's training data.
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
    except OSError as err:  # includes PIL.UnidentifiedImageError
        raise gr.Error("The URL does not point to a readable image.") from err

    feature_extractor, model = _load_detr()
    processed_outputs = make_prediction(image, feature_extractor, model)
    return visualize_prediction(image, processed_outputs, 0.7, model.config.id2label)
# Visualization palette: RGB triples (floats in [0, 1]) repeated across the
# detected boxes, one colour per box.
COLORS = [
    [0.000, 0.447, 0.741],  # blue
    [0.850, 0.325, 0.098],  # orange
    [0.929, 0.694, 0.125],  # yellow
    [0.494, 0.184, 0.556],  # purple
    [0.466, 0.674, 0.188],  # green
    [0.301, 0.745, 0.933],  # light blue
]
# Convert a rendered Matplotlib figure into a PIL image.
def fig2img(fig):
    """Render ``fig`` into an in-memory buffer and open it with PIL.

    The figure is saved with ``fig.savefig`` defaults (PNG unless the
    Matplotlib rcParams say otherwise).

    Args:
        fig: Matplotlib figure to render.

    Returns:
        A PIL image that reads from the in-memory buffer.
    """
    rendered = io.BytesIO()
    fig.savefig(rendered)
    # Rewind so PIL reads the image from the first byte.
    rendered.seek(0)
    return Image.open(rendered)
# Draw the bounding boxes.
def visualize_prediction(pil_img, output_dict, threshold=0.7, id2label=None):
    """Draw detections scoring above ``threshold`` on ``pil_img``.

    Args:
        pil_img: Image to draw on.
        output_dict: Detections with ``"scores"``, ``"boxes"`` and ``"labels"``
            tensors; boxes are read as ``(xmin, ymin, xmax, ymax)``.
        threshold: Minimum (exclusive) score for a detection to be drawn.
        id2label: Optional mapping from label id to display name; when
            ``None`` the raw ids are shown.

    Returns:
        A PIL image of the annotated figure.
    """
    keep = output_dict["scores"] > threshold
    boxes = output_dict["boxes"][keep].tolist()
    scores = output_dict["scores"][keep].tolist()
    labels = output_dict["labels"][keep].tolist()
    if id2label is not None:
        labels = [id2label[x] for x in labels]

    # Explicit fig/ax handles instead of pyplot's implicit "current figure".
    fig, ax = plt.subplots(figsize=(16, 10))
    try:
        ax.imshow(pil_img)
        for idx, (score, (xmin, ymin, xmax, ymax), label) in enumerate(zip(scores, boxes, labels)):
            color = COLORS[idx % len(COLORS)]
            ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=color, linewidth=3))
            ax.text(xmin, ymin, f"{label}: {score:0.2f}", fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
        ax.axis("off")
        return fig2img(fig)
    finally:
        # pyplot keeps every figure alive until closed; without this each
        # request leaks a 16x10 figure.
        plt.close(fig)
# Gradio interface
title = """<h1 id="title">Object Detection App with DETR</h1>"""
css = '''
h1#title {
text-align: center;
}
'''
demo = gr.Blocks(css=css)
with demo:
    gr.Markdown(title)
    with gr.Tabs():
        with gr.TabItem('Image URL'):
            with gr.Row():
                url_input = gr.Textbox(lines=2, label='Enter valid image URL here..')
                # `shape=` only resized *input* images in Gradio 3 and was
                # removed in Gradio 4 (TypeError), so it is omitted here.
                img_output_from_url = gr.Image()
            url_but = gr.Button('Detect')
    url_but.click(detect_objects, inputs=[url_input], outputs=img_output_from_url, queue=True)

# `launch(enable_queue=True)` was removed in Gradio 4; `queue()` works in 3.x and 4.x.
demo.queue().launch()