sagemaker
fix typo
a57a57c
raw
history blame contribute delete
No virus
3.76 kB
import functools
import io

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, YolosForObjectDetection
# Palette of RGB triples (floats in 0-1) used, in order, as the outline colour
# of successive bounding boxes in plot_results.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]
def get_class_list_from_input(classes_string: str):
    """Parse a comma-separated list of class names typed by the user.

    Every entry is stripped of surrounding whitespace and empty entries are
    dropped. Inputs such as ``""``, ``"   "`` or ``", ,"`` therefore yield
    ``[]``, which plot_results treats as "show every class". They do not
    yield ``[""]``, which would be a non-empty filter matching nothing.

    Args:
        classes_string: Raw textbox contents, e.g. ``"person, boat"``.
            ``None`` is treated like an empty string.

    Returns:
        List of non-empty, stripped class names, in input order.
    """
    if not classes_string:
        return []
    stripped = (name.strip() for name in classes_string.split(","))
    return [name for name in stripped if name]
@functools.lru_cache(maxsize=8)
def _load_detector(model_name: str):
    """Load and memoise the feature extractor and model for ``hustvl/<model_name>``."""
    checkpoint = f"hustvl/{model_name}"
    feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    model = YolosForObjectDetection.from_pretrained(checkpoint)
    return feature_extractor, model


def infer(img, model_name: str, prob_threshold: float, classes_to_show: str = ""):
    """Run YOLOS object detection on ``img`` and return an annotated image.

    Args:
        img: Array from the Gradio ``Image`` input (passed to
            ``Image.fromarray``), or ``None`` when nothing was uploaded.
        model_name: Checkpoint name under the ``hustvl`` namespace,
            e.g. ``"yolos-small"``.
        prob_threshold: A detection is kept only if its highest class
            probability is strictly greater than this value.
        classes_to_show: Comma-separated class names to keep; an empty
            string keeps all classes.

    Returns:
        The PIL image produced by ``plot_results``, or ``None`` if ``img``
        is ``None``.
    """
    if img is None:
        # Nothing uploaded: clear the output instead of crashing in fromarray.
        return None

    # Cached per model name, so switching back and forth does not re-download/reload.
    feature_extractor, model = _load_detector(model_name)

    # Normalise to 3 channels so RGBA/greyscale inputs reach the extractor as RGB.
    pil_img = Image.fromarray(img).convert("RGB")
    pixel_values = feature_extractor(pil_img, return_tensors="pt").pixel_values
    with torch.no_grad():
        outputs = model(pixel_values)

    # Softmax over classes, then drop the last column
    # (presumably the DETR-style "no object" class — confirm against model docs).
    probas = outputs.logits.softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > prob_threshold

    # PIL's size is (width, height); reverse it to (height, width) for post_process.
    target_sizes = torch.tensor(pil_img.size[::-1]).unsqueeze(0)
    postprocessed_outputs = feature_extractor.post_process(outputs, target_sizes)
    bboxes_scaled = postprocessed_outputs[0]["boxes"]

    classes_list = get_class_list_from_input(classes_to_show)
    return plot_results(pil_img, probas[keep], bboxes_scaled[keep], model, classes_list)
def plot_results(pil_img, prob, boxes, model, classes_list):
    """Draw detection boxes and labels on top of ``pil_img``.

    Args:
        pil_img: Image to annotate.
        prob: Per-detection class probabilities, one row per box.
        boxes: Box coordinates, one ``(xmin, ymin, xmax, ymax)`` row per
            detection, in the same order as ``prob``.
        model: Model whose ``config.id2label`` maps class ids to names.
        classes_list: Class names to keep; an empty list keeps every class.

    Returns:
        PIL image of the rendered figure (via ``fig2img``).
    """
    from itertools import cycle

    fig, ax = plt.subplots(figsize=(16, 10))
    try:
        ax.imshow(pil_img)
        wanted = set(classes_list)
        for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), cycle(COLORS)):
            cl = p.argmax()
            object_class = model.config.id2label[cl.item()]
            if wanted and object_class not in wanted:
                continue
            ax.add_patch(
                plt.Rectangle(
                    (xmin, ymin),
                    xmax - xmin,
                    ymax - ymin,
                    fill=False,
                    color=c,
                    linewidth=3,
                )
            )
            text = f"{object_class}: {p[cl]:0.2f}"
            ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
        ax.axis("off")
        # Render this specific figure rather than pyplot's global "current" one.
        return fig2img(fig)
    finally:
        # pyplot keeps every figure alive until it is closed; without this each
        # request would leak one figure in a long-running server.
        plt.close(fig)
def fig2img(fig):
    """Render a Matplotlib figure to an in-memory buffer and open it with PIL.

    Args:
        fig: Figure to render (anything with a ``savefig`` method).

    Returns:
        The ``Image.open`` result for the rendered bytes.
    """
    stream = io.BytesIO()
    fig.savefig(stream)
    # Rewind so PIL starts reading at the first byte that savefig wrote.
    stream.seek(0)
    return Image.open(stream)
# Markdown text passed as the Gradio Interface description.
description = """Object Detection with YOLOS. Choose your model and you're good to go.
You can adapt the minimum probability threshold with the slider.
Additionally, you can restrict the classes that will be shown by entering a comma-separated list of
[COCO classes](https://github.com/amikelive/coco-labels/blob/master/coco-labels-2014_2017.txt).
Leaving the field empty will show all classes."""
# Gradio input/output components.
image_in = gr.components.Image()
image_out = gr.components.Image()
model_choice = gr.components.Dropdown(
    ["yolos-tiny", "yolos-small", "yolos-base", "yolos-small-300", "yolos-small-dwr"],
    value="yolos-small",
    label="YOLOS Model",
)
prob_threshold_slider = gr.components.Slider(
    minimum=0, maximum=1.0, step=0.01, value=0.9, label="Probability Threshold"
)
classes_to_show = gr.components.Textbox(
    placeholder="e.g. person, boat", label="Classes to use (empty means all classes)"
)

# Bind the Interface object itself (not the return value of launch()),
# so the name refers to something reusable.
Iface = gr.Interface(
    fn=infer,
    # Order must match infer(img, model_name, prob_threshold, classes_to_show).
    inputs=[image_in, model_choice, prob_threshold_slider, classes_to_show],
    outputs=image_out,
    # Example images (currently disabled):
    # examples=[["examples/10_People_Marching_People_Marching_2_120.jpg"], ["examples/12_Group_Group_12_Group_Group_12_26.jpg"], ["examples/43_Row_Boat_Canoe_43_247.jpg"]],
    title="Object Detection with YOLOS",
    description=description,
)
Iface.launch()