Vineedhar's picture
Upload app.py
254e029 verified
raw
history blame
No virus
3.35 kB
import torch
from transformers import pipeline
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from random import choice
import io
# Build one pipeline per selectable DETR checkpoint at import time so that
# every call to infer() reuses the same objects instead of rebuilding them.
# NOTE(review): no task is passed explicitly; presumably the pipeline resolves
# it from the model name - confirm against the transformers documentation.
detector50 = pipeline(model="facebook/detr-resnet-50")
detector101 = pipeline(model="facebook/detr-resnet-101")
import gradio as gr
# Pastel palette; get_figure() picks one entry at random for each box outline.
COLORS = [
    "#ff7f7f",
    "#ff7fbf",
    "#ff7fff",
    "#bf7fff",
    "#7f7fff",
    "#7fbfff",
    "#7fffff",
    "#7fffbf",
    "#7fff7f",
    "#bfff7f",
    "#ffff7f",
    "#ffbf7f",
]

# Font properties handed to matplotlib (via ``fontdict=``) for the box labels.
fdic = dict(
    family="Impact",
    style="italic",
    size=15,
    color="yellow",
    weight="bold",
)
def get_figure(in_pil_img, in_results):
    """Draw the detected boxes and their labels on top of the input image.

    Args:
        in_pil_img: Image shown as the plot background (a PIL image when
            called from ``infer``).
        in_results: Iterable of predictions. Each prediction must provide
            ``label``, ``score`` and a ``box`` mapping with the keys
            ``xmin``, ``ymin``, ``xmax`` and ``ymax``.

    Returns:
        The matplotlib figure holding the annotated image. It is created
        through pyplot, so the caller should close it once it has been
        rendered.
    """
    # Keep explicit figure/axes handles rather than asking pyplot for the
    # "current" figure afterwards: the current figure is global state and
    # could be changed by another call before this function returns.
    fig, ax = plt.subplots(figsize=(16, 10))
    ax.imshow(in_pil_img)

    for prediction in in_results:
        box = prediction["box"]
        x, y = box["xmin"], box["ymin"]
        w, h = box["xmax"] - x, box["ymax"] - y

        # Outline only (fill=False) in a randomly chosen palette colour.
        ax.add_patch(
            patches.Rectangle(
                (x, y), w, h, fill=False, color=choice(COLORS), linewidth=3
            )
        )
        # Label anchored at the box's (xmin, ymin) corner, score as a percentage.
        ax.text(
            x,
            y,
            f"{prediction['label']}: {round(prediction['score'] * 100, 1)}%",
            fontdict=fdic,
        )

    ax.axis("off")
    return fig
def infer(model, in_pil_img):
    """Run the selected DETR detector and return the annotated image.

    Args:
        model: Model name chosen in the UI. ``"detr-resnet-101"`` selects
            ``detector101``; any other value falls back to ``detector50``.
        in_pil_img: Input image as a PIL image. May be ``None`` if the
            handler is triggered before an image has been provided.

    Returns:
        A PIL image of the rendered figure with boxes and labels, or
        ``None`` when no input image was given.
    """
    # Nothing to detect on; return early instead of passing None to the
    # detector and surfacing an unrelated error.
    if in_pil_img is None:
        return None

    detector = detector101 if model == "detr-resnet-101" else detector50
    results = detector(in_pil_img)

    figure = get_figure(in_pil_img, results)
    buf = io.BytesIO()
    try:
        figure.savefig(buf, bbox_inches="tight")
    finally:
        # pyplot keeps every figure it created alive until it is closed;
        # without this each request would leak one 16x10 figure.
        plt.close(figure)

    buf.seek(0)
    # The returned image reads from ``buf``, which stays referenced by it.
    return Image.open(buf)
# UI layout: model selector, input/output images, example thumbnails,
# the "Infer" button wired to infer(), and a reference link.
with gr.Blocks(
    title="DETR Object Detection - ClassCat",
    css=".gradio-container {background:lightyellow;}",
) as demo:
    gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">DETR Object Detection</div>""")

    # Step 1: choose which checkpoint infer() should use.
    gr.HTML("""<h4 style="color:navy;">1. Select a model.</h4>""")
    model = gr.Radio(["detr-resnet-50", "detr-resnet-101"], value="detr-resnet-50", label="Model name")

    gr.HTML("""<br/>""")

    # Step 2: provide an image, either from the examples or by upload.
    gr.HTML("""<h4 style="color:navy;">2-a. Select an example by clicking a thumbnail below.</h4>""")
    gr.HTML("""<h4 style="color:navy;">2-b. Or upload an image by clicking on the canvas.</h4>""")
    with gr.Row():
        input_image = gr.Image(label="Input image", type="pil")
        output_image = gr.Image(label="Output image with predicted instances", type="pil")

    gr.Examples(['samples/cats.jpg', 'samples/detectron2.png', 'samples/cat.jpg', 'samples/hotdog.jpg'], inputs=input_image)

    gr.HTML("""<br/>""")

    # Step 3: run detection and show the annotated result.
    gr.HTML("""<h4 style="color:navy;">3. Then, click "Infer" button to predict object instances. It will take about 10 seconds (on cpu)</h4>""")
    send_btn = gr.Button("Infer")
    send_btn.click(fn=infer, inputs=[model, input_image], outputs=[output_image])

    gr.HTML("""<br/>""")

    gr.HTML("""<h4 style="color:navy;">Reference</h4>""")
    # Emit the whole list as ONE HTML component: splitting <ul>, <li> and
    # </ul> across separate components leaves each holding a malformed
    # fragment.
    gr.HTML("""<ul><li><a href="https://colab.research.google.com/github/facebookresearch/detr/blob/colab/notebooks/detr_attention.ipynb" target="_blank">Hands-on tutorial for DETR</a></li></ul>""")

# Request queueing is left disabled, as in the original.
# demo.queue()
demo.launch(debug=True)
### EOF ###