yuragoithf's picture
Update app.py
0aed315
raw
history blame
3.13 kB
import io
import torch
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
from transformers import AutoFeatureExtractor, YolosForObjectDetection
from PIL import Image
# Palette of RGB triples (each component in the 0-1 range) cycled through
# when drawing bounding-box outlines.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]
def process_class_list(classes_string: str):
    """Parse a comma-separated string of class names into a list.

    Each name is stripped of surrounding whitespace. Empty entries (from a
    blank or whitespace-only string, doubled commas, or a trailing comma) are
    dropped, because an empty name can never match a label and a list that
    contains only empty names would filter out every detection.

    Args:
        classes_string: Text such as ``"car, dog"``. ``""`` or ``None`` means
            "no filter".

    Returns:
        A list of non-empty class names, in the order they were given;
        an empty list when no names were supplied.
    """
    if not classes_string:
        return []
    stripped_names = (part.strip() for part in classes_string.split(","))
    return [name for name in stripped_names if name]
# Checkpoint used for both the feature extractor and the detection model.
_YOLOS_CHECKPOINT = "hustvl/yolos-small-dwr"

# Filled on first use so the weights are loaded once per process instead of
# once per request.
_YOLOS_CACHE = {}


def _load_yolos():
    """Return the cached ``(feature_extractor, model)`` pair, loading on first use."""
    if not _YOLOS_CACHE:
        # Load both into locals first so a failure cannot leave the cache
        # half-populated.
        feature_extractor = AutoFeatureExtractor.from_pretrained(_YOLOS_CHECKPOINT)
        model = YolosForObjectDetection.from_pretrained(_YOLOS_CHECKPOINT)
        _YOLOS_CACHE["feature_extractor"] = feature_extractor
        _YOLOS_CACHE["model"] = model
    return _YOLOS_CACHE["feature_extractor"], _YOLOS_CACHE["model"]


def model_inference(img, prob_threshold, classes_to_show):
    """Run YOLOS object detection on an image and return an annotated picture.

    Args:
        img: Image as an array accepted by ``PIL.Image.fromarray``, or
            ``None`` when nothing was uploaded.
        prob_threshold: Detections whose best class probability is not
            strictly greater than this value are discarded.
        classes_to_show: Comma-separated class names to keep; an empty string
            keeps every class.

    Returns:
        The image produced by ``plot_results`` with the kept boxes drawn on
        it, or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        # Nothing was uploaded; return an empty result instead of crashing
        # inside Image.fromarray.
        return None

    feature_extractor, model = _load_yolos()

    img = Image.fromarray(img)
    pixel_values = feature_extractor(img, return_tensors="pt").pixel_values

    with torch.no_grad():
        # Attention maps are never read below, so they are not requested.
        outputs = model(pixel_values)

    # First batch element; the last logit column is dropped before the
    # threshold is applied (presumably the "no object" class - confirm
    # against the model documentation).
    probas = outputs.logits.softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > prob_threshold

    # PIL reports (width, height); post-processing is given (height, width).
    target_sizes = torch.tensor(img.size[::-1]).unsqueeze(0)
    postprocessed_outputs = feature_extractor.post_process(outputs, target_sizes)
    bboxes_scaled = postprocessed_outputs[0]["boxes"]

    classes_list = process_class_list(classes_to_show)
    return plot_results(img, probas[keep], bboxes_scaled[keep], model, classes_list)
def plot_results(pil_img, prob, boxes, model, classes_list):
    """Draw labelled bounding boxes on an image and return the rendered picture.

    Args:
        pil_img: Image to display as the background.
        prob: Per-detection class probabilities; iterating yields one
            probability vector per detection.
        boxes: Per-detection boxes as ``(xmin, ymin, xmax, ymax)``; must
            support ``.tolist()``.
        model: Object whose ``config.id2label`` maps class indices to names.
        classes_list: Class names to draw. An empty list draws every class.

    Returns:
        The rendered figure converted to an image by ``fig2img``.
    """
    fig, ax = plt.subplots(figsize=(16, 10))
    try:
        ax.imshow(pil_img)
        for idx, (p, (xmin, ymin, xmax, ymax)) in enumerate(zip(prob, boxes.tolist())):
            # The colour is chosen by detection index (not by drawn index) and
            # wraps around the palette, so any number of boxes can be drawn.
            c = COLORS[idx % len(COLORS)]
            cl = p.argmax()
            object_class = model.config.id2label[cl.item()]
            if classes_list and object_class not in classes_list:
                continue
            ax.add_patch(
                plt.Rectangle(
                    (xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3
                )
            )
            text = f"{object_class}: {p[cl]:0.2f}"
            ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
        ax.axis("off")
        return fig2img(fig)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this each request would leak a figure.
        plt.close(fig)
def fig2img(fig):
    """Render a figure into an in-memory buffer and open it as a PIL image.

    Args:
        fig: Object exposing ``savefig(file_like)``.

    Returns:
        The image returned by ``Image.open`` on the in-memory buffer.
    """
    buffer = io.BytesIO()
    fig.savefig(buffer)
    # Rewind so Image.open reads from the start of the written data.
    buffer.seek(0)
    return Image.open(buffer)
# --- Gradio UI wiring -------------------------------------------------------
description = """Upload an image and get the predicted classes"""
title = """Object Detection"""

# Input and output components.
image_in = gr.components.Image(label="Upload an image")
image_out = gr.components.Image()
prob_threshold_slider = gr.components.Slider(
    minimum=0, maximum=1.0, step=0.01, value=0.7, label="Probability Threshold"
)
classes_to_show = gr.components.Textbox(
    placeholder="e.g. car, dog",
    label="Classes to filter (leave empty to detect all classes)",
)

# Order must match the parameters of model_inference.
inputs = [image_in, prob_threshold_slider, classes_to_show]

# With more than one input component, examples must be a nested list with one
# row per example and one value per input; a flat list of file names is only
# accepted for a single input. Each row uses the slider default (0.7) and an
# empty class filter.
examples = [
    ["CTH.png", 0.7, ""],
    ["carplane.webp", 0.7, ""],
]

gr.Interface(
    fn=model_inference,
    inputs=inputs,
    outputs=image_out,
    title=title,
    examples=examples,
    description=description,
).launch()