captain-awesome's picture
Update app.py
ee5988c verified
raw
history blame
No virus
2.8 kB
from collections import Counter

import gradio as gr
import pandas as pd
import torch
from PIL import Image, ImageDraw
from transformers import AutoImageProcessor, AutoModelForObjectDetection
from transformers import DetrImageProcessor, DetrForObjectDetection
# DETR checkpoint used for detection. Both the processor and the model are
# loaded from the same checkpoint and revision via shared constants so the
# two can never drift apart.
# Previously commented-out alternative: "hustvl/yolos-small" loaded through
# AutoImageProcessor / AutoModelForObjectDetection.
_CHECKPOINT = "facebook/detr-resnet-50"
_REVISION = "no_timm"

image_processor = DetrImageProcessor.from_pretrained(_CHECKPOINT, revision=_REVISION)
model = DetrForObjectDetection.from_pretrained(_CHECKPOINT, revision=_REVISION)
# Outline colors, handed out in order to the most frequently detected labels
# (first color -> most frequent label). Ten entries.
colors = "red orange yellow green blue indigo violet brown black slategray".split()

# Target width in pixels for every input image; the height is scaled to keep
# the aspect ratio.
WIDTH = 900
def detect(image, *, threshold: float = 0.9):
    """Detect objects in *image* and summarize them for the Gradio UI.

    The image is converted to RGB and resized to ``WIDTH`` pixels wide
    (aspect ratio preserved), then run through the module-level DETR
    ``image_processor`` / ``model``. Kept detections are counted per label.
    Boxes are drawn only for the most frequent labels, one outline color per
    label taken in order from ``colors``.

    Args:
        image: PIL image from ``gr.Image(type="pil")``; ``None`` when the
            user submits without uploading anything.
        threshold: Confidence cutoff forwarded to
            ``post_process_object_detection`` (default 0.9, as before).

    Returns:
        A tuple ``(annotated_image, df, count_results)``:
        ``annotated_image`` is the resized RGB image with boxes drawn,
        ``df`` is a DataFrame with columns ``label`` and ``counts`` (most
        frequent first), and ``count_results`` maps each drawn label to its
        number of detections.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Gradio passes None when the user submits without an upload.
        raise gr.Error("Please upload an image first.")

    # Normalize to three channels so RGBA / grayscale uploads work with both
    # the processor and the colored outlines drawn below.
    image = image.convert("RGB")

    # Resize to a fixed width, preserving aspect ratio. Integer arithmetic
    # avoids float rounding, and the height is clamped so very wide images
    # never yield a zero-height (invalid) resize target.
    width, height = image.size
    new_height = max(1, height * WIDTH // width)
    image = image.resize((WIDTH, new_height), Image.Resampling.LANCZOS)

    inputs = image_processor(images=image, return_tensors="pt")
    # Pure inference: disable autograd so no graph/activations are retained.
    with torch.no_grad():
        outputs = model(**inputs)

    # Convert raw outputs to boxes in pixel coordinates of the resized image.
    # PIL's size is (width, height); target_sizes is given as (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = image_processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=target_sizes
    )[0]

    label_names = [model.config.id2label[label.item()] for label in results["labels"]]

    # Keep the most frequent labels, one per available outline color. Ties
    # keep first-seen order (Counter.most_common is stable), matching the
    # previous sorted(..., reverse=True) behavior.
    count_results = dict(Counter(label_names).most_common(len(colors)))
    label2color = dict(zip(count_results, colors))

    draw = ImageDraw.Draw(image)
    for label_name, box in zip(label_names, results["boxes"]):
        color = label2color.get(label_name)
        if color is None:
            continue  # label is outside the top-N: not drawn
        x1, y1, x2, y2 = box.tolist()
        draw.rectangle((x1, y1, x2, y2), outline=color, width=2)
        draw.text((x1, y1), label_name, fill="white")

    df = pd.DataFrame({
        "label": list(count_results),
        "counts": list(count_results.values()),
    })
    return image, df, count_results
# Gradio UI wrapping detect(). The outputs list must line up with detect()'s
# return tuple: (annotated image, label/count DataFrame, label->count dict).
demo = gr.Interface(
    fn=detect,
    inputs=[gr.Image(label="Input image", type="pil")],
    outputs=[
        gr.Image(label="Output image"),
        # Horizontal bar chart fed by the DataFrame's "label"/"counts" columns.
        gr.BarPlot(
            show_label=False,
            x="label",
            y="counts",
            x_title="Labels",
            y_title="Counts",
            vertical=False,
        ),
        # Receives the label->count dict.
        gr.Textbox(show_label=False),
    ],
    title="FB Object Detection",
    cache_examples=False,
)

# Launch only when run as a script, so importing this module (e.g. for tests
# or Gradio's reload mode, which looks up `demo`) does not start a server.
if __name__ == "__main__":
    demo.launch()