Spaces:
Running
Running
File size: 4,513 Bytes
5e03e65 af93aba 30ef691 5e03e65 30ef691 5e03e65 30ef691 5e03e65 783a0a3 5e03e65 783a0a3 5e03e65 783a0a3 30ef691 783a0a3 5e03e65 783a0a3 5e03e65 783a0a3 30ef691 5e03e65 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import torch
from transformers import AutoImageProcessor, AutoModelForObjectDetection
#from transformers import pipeline
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import io
from random import choice
# Hugging Face Hub checkpoints for the two selectable YOLOS variants.
_YOLOS_TINY_CKPT = "hustvl/yolos-tiny"
_YOLOS_SMALL_CKPT = "hustvl/yolos-small"

# Load each processor/model pair once at import time so every request reuses
# them. Tuple elements are evaluated left to right, which keeps the load order
# tiny processor -> tiny model -> small processor -> small model.
image_processor_tiny, model_tiny = (
    AutoImageProcessor.from_pretrained(_YOLOS_TINY_CKPT),
    AutoModelForObjectDetection.from_pretrained(_YOLOS_TINY_CKPT),
)
image_processor_small, model_small = (
    AutoImageProcessor.from_pretrained(_YOLOS_SMALL_CKPT),
    AutoModelForObjectDetection.from_pretrained(_YOLOS_SMALL_CKPT),
)
import gradio as gr
# Light pastel hues; get_figure picks one at random for each detection box.
COLORS = [
    "#ff7f7f",
    "#ff7fbf",
    "#ff7fff",
    "#bf7fff",
    "#7f7fff",
    "#7fbfff",
    "#7fffff",
    "#7fffbf",
    "#7fff7f",
    "#bfff7f",
    "#ffff7f",
    "#ffbf7f",
]

# Font properties for the "<class>: <score>%" caption drawn at each box's
# top-left corner (passed as `fontdict` to Axes.text).
# NOTE(review): "Impact" is likely missing on many Linux hosts; matplotlib then
# falls back to its default font with a warning — confirm on the deploy image.
fdic = dict(
    family="Impact",
    style="italic",
    size=15,
    color="yellow",
    weight="bold",
)
def get_figure(in_pil_img, in_results, in_id2label=None):
    """Render an image with its detection boxes and class/score captions.

    Args:
        in_pil_img: Image to display (anything ``Axes.imshow`` accepts).
        in_results: One post-processed detection dict with ``"scores"``,
            ``"labels"`` and ``"boxes"`` tensors. Each box is read as corner
            coordinates ``(x0, y0, x1, y1)``; width and height are derived
            from them below.
        in_id2label: Mapping from class id to class name. Defaults to the
            yolos-tiny config's ``id2label`` (the previous hard-wired choice).

    Returns:
        A ``matplotlib.figure.Figure`` holding the annotated image. It is not
        registered with pyplot, so it is freed with its last reference and
        never accumulates in pyplot's global figure registry.
    """
    # Local import: building a Figure directly (instead of plt.figure) avoids
    # pyplot's shared "current figure" state, which is unsafe when Gradio
    # serves concurrent requests, and avoids leaking one figure per call.
    from matplotlib.figure import Figure

    if in_id2label is None:
        in_id2label = model_tiny.config.id2label

    fig = Figure(figsize=(16, 10))
    ax = fig.add_subplot(111)
    ax.imshow(in_pil_img)
    for score, label, box in zip(in_results["scores"], in_results["labels"], in_results["boxes"]):
        selected_color = choice(COLORS)
        # Round the corner and the extent separately, as before: round(x1 - x0)
        # can differ from round(x1) - round(x0) by one pixel.
        x, y = torch.round(box[0]).item(), torch.round(box[1]).item()
        w, h = torch.round(box[2] - box[0]).item(), torch.round(box[3] - box[1]).item()
        ax.add_patch(patches.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=2))
        ax.text(x, y, f"{in_id2label[label.item()]}: {round(score.item()*100, 2)}%", fontdict=fdic)
    ax.axis("off")
    return fig
def infer(in_pil_img, in_model="yolos-tiny", in_threshold=0.9):
    """Detect objects in an image with a YOLOS model and return an annotated copy.

    Args:
        in_pil_img: Input PIL image, or ``None`` when the Gradio image
            component is empty.
        in_model: ``"yolos-small"`` selects the small checkpoint; any other
            value uses yolos-tiny.
        in_threshold: Minimum confidence score for a detection to be kept.

    Returns:
        A PIL image of the rendered figure with boxes and labels drawn on it.

    Raises:
        gr.Error: If no input image was provided (shown to the user by Gradio).
    """
    if in_pil_img is None:
        raise gr.Error("Please select or upload an image first.")

    # Any value other than "yolos-small" falls back to the tiny model.
    if in_model == "yolos-small":
        processor, detector = image_processor_small, model_small
    else:
        processor, detector = image_processor_tiny, model_tiny

    # Normalize uploads (grayscale, RGBA, palette) to 3-channel RGB; a no-op in
    # content for RGB input. Presumably the processor expects 3 channels —
    # and drawing the RGB copy avoids imshow false-coloring grayscale images.
    rgb_img = in_pil_img.convert("RGB")
    # PIL size is (width, height); reversed to (height, width) as before.
    target_sizes = torch.tensor([rgb_img.size[::-1]])

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        inputs = processor(images=rgb_img, return_tensors="pt")
        outputs = detector(**inputs)
    # convert outputs (bounding boxes and class logits) to COCO API
    results = processor.post_process_object_detection(
        outputs, threshold=in_threshold, target_sizes=target_sizes
    )[0]

    figure = get_figure(rgb_img, results)
    buf = io.BytesIO()
    try:
        figure.savefig(buf, bbox_inches="tight")
    finally:
        # pyplot keeps figures it created alive until closed; closing here
        # prevents one leaked figure per request (no-op for unmanaged figures).
        plt.close(figure)
    buf.seek(0)
    return Image.open(buf)
# Build the Gradio UI: model picker, input/output images with examples,
# threshold slider, and an "Infer" button wired to infer().
with gr.Blocks(title="YOLOS Object Detection - ClassCat",
               css=".gradio-container {background:lightyellow;}"
               ) as demo:
    gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">YOLOS Object Detection</div>""")

    gr.HTML("""<h4 style="color:navy;">1. Select a model.</h4>""")
    model = gr.Radio(["yolos-tiny", "yolos-small"], value="yolos-tiny", label="Model name")

    gr.HTML("""<br/>""")
    gr.HTML("""<h4 style="color:navy;">2-a. Select an example by clicking a thumbnail below.</h4>""")
    gr.HTML("""<h4 style="color:navy;">2-b. Or upload an image by clicking on the canvas.</h4>""")
    with gr.Row():
        input_image = gr.Image(label="Input image", type="pil")
        output_image = gr.Image(label="Output image with predicted instances", type="pil")
    gr.Examples(['samples/cats.jpg', 'samples/detectron2.png', 'samples/cat.jpg', 'samples/hotdog.jpg'], inputs=input_image)

    gr.HTML("""<br/>""")
    gr.HTML("""<h4 style="color:navy;">3. Set threshold value (default to 0.9)</h4>""")
    threshold = gr.Slider(0, 1.0, value=0.9, label='threshold')

    gr.HTML("""<br/>""")
    gr.HTML("""<h4 style="color:navy;">4. Then, click "Infer" button to predict object instances. It will take about 10 seconds (on cpu)</h4>""")
    send_btn = gr.Button("Infer")
    send_btn.click(fn=infer, inputs=[input_image, model, threshold], outputs=[output_image])

    gr.HTML("""<br/>""")
    gr.HTML("""<h4 style="color:navy;">Reference</h4>""")
    # Each gr.HTML is rendered as its own component, so the list must be
    # emitted as one complete fragment; separate "<ul>" and "</ul>" components
    # never actually wrap the list item.
    gr.HTML("""<ul><li><a href="https://huggingface.co/docs/transformers/model_doc/yolos" target="_blank">Hugging Face Transformers - YOLOS</a></li></ul>""")

# Start the web server (Gradio debug mode).
demo.launch(debug=True)
### EOF ###
|