# Hugging Face Spaces page status: Sleeping
import functools
import io

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, YolosForObjectDetection
# Palette of RGB triples (floats in [0, 1]) assigned to drawn boxes in turn;
# plot_results repeats this list so every detection gets a colour.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]
def process_class_list(classes_string: str):
    """Parse a comma-separated class filter into a list of class names.

    Args:
        classes_string: user text such as ``"car, dog"``. An empty string,
            a whitespace-only string, or ``None`` means "no filter".

    Returns:
        The class names with surrounding whitespace removed and empty entries
        (e.g. from ``"car,,dog"`` or a trailing comma) dropped. Returns ``[]``
        when no class name was given, which callers treat as "show all".
    """
    if not classes_string:
        return []
    # Drop blank items: a stray "" in the list would be a non-empty filter that
    # matches no real class, hiding every detection.
    stripped = (part.strip() for part in classes_string.split(","))
    return [name for name in stripped if name]
# Hub checkpoint used for both the preprocessing pipeline and the detector.
MODEL_CHECKPOINT = "hustvl/yolos-small-dwr"


@functools.lru_cache(maxsize=1)
def _load_detector(checkpoint=MODEL_CHECKPOINT):
    """Load (once per process) and return ``(feature_extractor, model)`` for ``checkpoint``."""
    feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    model = YolosForObjectDetection.from_pretrained(checkpoint)
    return feature_extractor, model


def model_inference(img, prob_threshold, classes_to_show):
    """Run YOLOS object detection on ``img`` and return an annotated rendering.

    Args:
        img: image array as delivered by the Gradio ``Image`` input; passed to
            ``PIL.Image.fromarray`` (presumably an HxWxC uint8 array — confirm
            against the component's configuration). ``None`` when no image has
            been uploaded.
        prob_threshold: a detection is kept when its highest class probability
            is strictly greater than this value.
        classes_to_show: comma-separated class names to display; an empty
            string shows every kept detection.

    Returns:
        The PIL image produced by ``plot_results``, or ``None`` if ``img`` is
        ``None``.
    """
    if img is None:
        return None

    # Loading weights is slow; the cached loader avoids doing it per request.
    feature_extractor, model = _load_detector()

    # The feature extractor normalises three channels, so force RGB.
    pil_img = Image.fromarray(img).convert("RGB")
    pixel_values = feature_extractor(pil_img, return_tensors="pt").pixel_values
    with torch.no_grad():
        # Attention maps are never used, so they are not requested.
        outputs = model(pixel_values)

    # Softmax over classes for the single batch item; the final logit column is
    # dropped (in DETR-style heads it is the "no object" class).
    probas = outputs.logits.softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > prob_threshold

    # target_sizes is (height, width); PIL's .size is (width, height).
    target_sizes = torch.tensor(pil_img.size[::-1]).unsqueeze(0)
    # NOTE(review): newer transformers releases deprecate post_process in
    # favour of post_process_object_detection; confirm against the pinned version.
    postprocessed_outputs = feature_extractor.post_process(outputs, target_sizes)
    bboxes_scaled = postprocessed_outputs[0]["boxes"]

    classes_list = process_class_list(classes_to_show)
    return plot_results(pil_img, probas[keep], bboxes_scaled[keep], model, classes_list)
def plot_results(pil_img, prob, boxes, model, classes_list):
    """Draw detection boxes and labels over ``pil_img`` and return the rendering.

    Args:
        pil_img: the image the detections were computed on.
        prob: one row of class probabilities per detection; the row's argmax
            is taken as the predicted class.
        boxes: tensor with one ``(xmin, ymin, xmax, ymax)`` row per detection
            (presumably pixel coordinates of ``pil_img`` — confirm with caller).
        model: detector whose ``config.id2label`` maps class ids to names.
        classes_list: class names to draw; an empty list draws every detection.

    Returns:
        A PIL image of the rendered Matplotlib figure (via ``fig2img``).
    """
    fig, ax = plt.subplots(figsize=(16, 10))
    try:
        ax.imshow(pil_img)
        colors = COLORS * 100
        for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
            cl = p.argmax()
            object_class = model.config.id2label[cl.item()]
            if classes_list and object_class not in classes_list:
                continue
            ax.add_patch(
                plt.Rectangle(
                    (xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3
                )
            )
            text = f"{object_class}: {p[cl]:0.2f}"
            ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
        ax.axis("off")
        return fig2img(fig)
    finally:
        # Without closing, pyplot keeps every figure alive across requests
        # and memory grows for the life of the server.
        plt.close(fig)
def fig2img(fig):
    """Rasterise a Matplotlib figure and return it as a PIL image.

    The figure is written with ``fig.savefig`` into an in-memory buffer (no
    explicit format is passed, so Matplotlib's default savefig format applies)
    and the resulting bytes are reopened with Pillow; no temporary file is used.
    """
    with io.BytesIO() as scratch:
        fig.savefig(scratch)
        encoded = scratch.getvalue()
    # Image.open reads lazily, so hand it a buffer that stays alive with the image.
    return Image.open(io.BytesIO(encoded))
description = """Upload an image and get the predicted classes"""
title = """Object Detection"""

image_in = gr.components.Image(label="Upload an image")
image_out = gr.components.Image()
prob_threshold_slider = gr.components.Slider(
    minimum=0, maximum=1.0, step=0.01, value=0.7, label="Probability Threshold"
)
classes_to_show = gr.components.Textbox(
    placeholder="e.g. car, dog",
    label="Classes to filter (leave empty to detect all classes)",
)
inputs = [image_in, prob_threshold_slider, classes_to_show]

# With more than one input component Gradio requires each example to be a
# list holding one value per input: (image path, threshold, class filter).
# A flat list of file names is rejected at startup.
examples = [
    ["CTH.png", 0.7, ""],
    ["carplane.webp", 0.7, ""],
]

demo = gr.Interface(
    fn=model_inference,
    inputs=inputs,
    outputs=image_out,
    title=title,
    examples=examples,
    description=description,
)
demo.launch()