Spaces:
Sleeping
Sleeping
File size: 2,125 Bytes
9b658e7 2ccf6ca 9b658e7 2ccf6ca 9b658e7 2ccf6ca 9b658e7 2ccf6ca 9b658e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
# app.py
import gradio as gr
#import spaces
#import torch
from PIL import Image
from transformers import pipeline
import matplotlib.pyplot as plt
import io
# Object-detection pipeline around DETR (ResNet-50 backbone). The task is named
# explicitly: when it is omitted, transformers infers it from the model's Hub
# metadata, which costs an extra network lookup and fails when offline.
model_pipeline = pipeline("object-detection", model="facebook/detr-resnet-50")

# Palette of RGB triples (floats in 0-1, a format matplotlib accepts as a
# colour) used to outline detected boxes; each label is mapped to one entry.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]
def _box_corners(box):
    """Return ``(xmin, ymin, xmax, ymax)`` from a detection's ``box`` mapping.

    The corners are read by key rather than through ``box.values()``, so the
    result does not depend on the order in which the keys were inserted.
    """
    return box["xmin"], box["ymin"], box["xmax"], box["ymax"]


def _label_color(label):
    """Return the entry of ``COLORS`` assigned to *label*.

    The built-in ``hash()`` of a ``str`` is salted per interpreter process, so
    using it would give the same class a different colour after every restart.
    CRC-32 of the UTF-8 bytes is deterministic across runs.
    """
    import zlib  # local import keeps this helper self-contained

    return COLORS[zlib.crc32(label.encode("utf-8")) % len(COLORS)]


def get_output_figure(pil_img, results, threshold):
    """Draw the detections in *results* over *pil_img* and return the figure.

    Args:
        pil_img: Image to draw on; passed straight to ``Axes.imshow``.
        results: Iterable of detection dicts with ``"score"``, ``"label"`` and
            ``"box"`` keys, where ``"box"`` holds ``xmin``/``ymin``/``xmax``/
            ``ymax`` pixel coordinates.
        threshold: Only detections whose score is strictly greater than this
            are drawn.

    Returns:
        A new matplotlib ``Figure`` (also left as pyplot's current figure).
        The caller owns it and should ``plt.close`` it when finished.
    """
    fig, ax = plt.subplots(figsize=(16, 10))
    ax.imshow(pil_img)
    for result in results:
        score = result["score"]
        if score > threshold:
            label = result["label"]
            xmin, ymin, xmax, ymax = _box_corners(result["box"])
            # Rectangle takes the top-left corner plus width and height.
            ax.add_patch(
                plt.Rectangle(
                    (xmin, ymin),
                    xmax - xmin,
                    ymax - ymin,
                    fill=False,
                    color=_label_color(label),
                    linewidth=3,
                )
            )
            ax.text(
                xmin,
                ymin,
                f"{label}: {score:0.2f}",
                fontsize=15,
                bbox=dict(facecolor="yellow", alpha=0.5),
            )
    ax.axis("off")
    return fig
#@spaces.GPU
def detect(image, threshold=0.9):
    """Run the detection model on *image* and return an annotated copy.

    Args:
        image: PIL image from the Gradio input component, or ``None`` when the
            user submits without providing an image.
        threshold: Only detections scoring strictly above this are drawn.

    Returns:
        A PIL image of the rendered matplotlib figure, or ``None`` when no
        input image was given.
    """
    if image is None:
        # Nothing to run the model on; clear the output instead of raising.
        return None
    results = model_pipeline(image)
    print(results)
    figure = get_output_figure(image, results, threshold=threshold)
    try:
        with io.BytesIO() as buf:
            figure.savefig(buf, bbox_inches="tight")
            buf.seek(0)
            with Image.open(buf) as rendered:
                # Image.open is lazy; copy() decodes the pixels so the returned
                # image stays valid after the buffer and image are closed.
                return rendered.copy()
    finally:
        # pyplot keeps every figure alive until it is closed; without this the
        # long-running app would accumulate one figure per request.
        plt.close(figure)
# Page layout: a title and short description, followed by the detection
# Interface (image + threshold slider in, annotated image out).
with gr.Blocks() as demo:
    gr.Markdown("# Object detection with DETR on COCO dataset")
    gr.Markdown(
        "\n"
        "This application uses a DETR (DEtection TRansformers) model to detect objects on images.\n"
        "This version was trained using the COCO dataset.\n"
        "You can load an image and see the predictions for the objects detected.\n"
    )
    gr.Interface(
        fn=detect,
        inputs=[
            gr.Image(label="Input image", type="pil"),
            gr.Slider(0, 1.0, value=0.9, label="Threshold"),
        ],
        outputs=[gr.Image(label="Output prediction", type="pil")],
        examples=[["samples/savanna.jpg"]],
    )

# show_error=True surfaces exceptions raised in `detect` in the browser UI.
demo.launch(show_error=True)