|
import gradio as gr |
|
from PIL import Image, ImageDraw |
|
from transformers import pipeline |
|
import numpy as np |
|
|
|
|
|
def plot_results(image, results, threshold=0.7):
    """Draw detection boxes, labels and scores onto a copy of ``image``.

    Args:
        image: A ``PIL.Image.Image`` or an array-like of pixel values that
            can be cast to ``uint8``.
        results: Iterable of detection dicts, each with a ``"score"``, a
            ``"label"`` and a ``"box"`` mapping holding ``xmin``, ``ymin``,
            ``xmax`` and ``ymax``.
        threshold: Detections whose score is not strictly greater than this
            value are skipped.

    Returns:
        A new ``PIL.Image.Image`` carrying the annotations; the input image
        is left untouched.
    """
    # Vertical gap between the label line and the score line, in pixels.
    score_line_offset = 12

    # Work on a copy so the caller's image is never modified. A PIL input is
    # copied directly to keep its mode/palette instead of going through numpy.
    if isinstance(image, Image.Image):
        canvas = image.copy()
    else:
        canvas = Image.fromarray(np.uint8(image))
    draw = ImageDraw.Draw(canvas)

    for result in results:
        score = result["score"]
        if not score > threshold:
            continue

        # Read the corners by key instead of relying on dict ordering.
        box = result["box"]
        x, y, x2, y2 = box["xmin"], box["ymin"], box["xmax"], box["ymax"]

        draw.rectangle((x, y, x2, y2), outline="red", width=1)
        draw.text((x, y), result["label"], fill="white")
        # The score goes on its own line: drawn at the label's position it
        # would overlap the label and neither would be readable.
        draw.text(
            (x, y + score_line_offset),
            text=f"{score:.2f}",
            fill="green" if score > 0.7 else "red",
        )

    return canvas
|
|
|
# Lazily-created detection pipeline, shared by every call to ``predict``.
_OBJ_DETECTOR = None


def _get_detector():
    """Return the object-detection pipeline, building it on first use."""
    global _OBJ_DETECTOR
    if _OBJ_DETECTOR is None:
        _OBJ_DETECTOR = pipeline(
            "object-detection",
            model="Antoine101/detr-resnet-50-dc5-fashionpedia-finetuned",
        )
    return _OBJ_DETECTOR


def predict(image):
    """Run object detection on ``image`` and return an annotated copy.

    Args:
        image: The image to analyse, in a form accepted by the detection
            pipeline, or ``None`` when no image was provided.

    Returns:
        The image produced by ``plot_results`` with the detections drawn on
        it, or ``None`` when ``image`` is ``None``.
    """
    # Nothing to detect when the request carries no image.
    if image is None:
        return None

    # Reuse the cached pipeline: rebuilding it per call reloads the model
    # for every single request.
    results = _get_detector()(image)
    return plot_results(image, results)
|
|
|
# Heading and blurb displayed on the demo page.
title = "Are you fashion?"
description = """
DETR model finetuned on "detection-datasets/fashionpedia" for apparels detection.
"""

# One image in, one annotated image out; `predict` does the actual work.
demo = gr.Interface(
    title=title,
    description=description,
    fn=predict,
    inputs=gr.Image(label="Input Image", type="pil"),
    outputs="image",
    examples=[["example1.jpg"]],
)

# Start the web server for the interface.
demo.launch()
|
|