#!/usr/bin/env python

import pathlib

import gradio as gr
import numpy as np
import PIL.Image
import spaces
import torch
from sahi.prediction import ObjectPrediction
from sahi.utils.cv import visualize_object_predictions
from transformers import AutoImageProcessor, DetaForObjectDetection

DESCRIPTION = "# DETA (Detection Transformers with Assignment)"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

MODEL_ID = "jozhang97/deta-swin-large"
image_processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = DetaForObjectDetection.from_pretrained(MODEL_ID)
model.to(device)


@spaces.GPU
@torch.inference_mode()
def run(image_path: str, threshold: float) -> np.ndarray:
    """Run DETA on the image at ``image_path`` and draw the detections.

    Args:
        image_path: Path to the input image, as supplied by the ``gr.Image``
            component (``type="filepath"``). ``None`` when no image was given.
        threshold: Minimum score a detection must have to be kept; forwarded
            to ``post_process_object_detection``.

    Returns:
        The input image (as an RGB array) with the kept detections drawn on
        it by sahi's ``visualize_object_predictions``.

    Raises:
        gr.Error: If no image was provided.
    """
    if image_path is None:
        # Show a readable message in the UI instead of a PIL traceback.
        raise gr.Error("Please upload an input image.")

    # Open inside a context manager so the file handle is released, and
    # normalize to RGB: grayscale, palette or RGBA uploads would otherwise
    # produce arrays with an unexpected number of channels.
    with PIL.Image.open(image_path) as src:
        image = src.convert("RGB")

    inputs = image_processor(images=image, return_tensors="pt").to(device)
    outputs = model(**inputs)

    # PIL reports (width, height); post-processing expects (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = image_processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=target_sizes
    )[0]

    boxes = results["boxes"].cpu().numpy()
    scores = results["scores"].cpu().numpy()
    cat_ids = results["labels"].cpu().numpy().tolist()

    # Boxes are rounded to integer pixel coordinates before drawing.
    preds = [
        ObjectPrediction(
            bbox=np.round(box).astype(int).tolist(),
            category_id=cat_id,
            category_name=model.config.id2label[cat_id],
            score=float(score),
        )
        for box, score, cat_id in zip(boxes, scores, cat_ids)
    ]
    return visualize_object_predictions(np.asarray(image), preds)["image"]


with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Input image", type="filepath")
            threshold = gr.Slider(label="Score threshold", minimum=0, maximum=1, step=0.01, value=0.1)
            run_button = gr.Button()
        result = gr.Image(label="Result")

    # Each bundled example image is offered with the default 0.1 threshold.
    gr.Examples(
        examples=[[path, 0.1] for path in sorted(pathlib.Path("images").glob("*.jpg"))],
        inputs=[image, threshold],
        outputs=result,
        fn=run,
    )

    run_button.click(
        fn=run,
        inputs=[image, threshold],
        outputs=result,
        api_name="predict",
    )

if __name__ == "__main__":
    demo.queue(max_size=20).launch()