#!/usr/bin/env python from __future__ import annotations import pathlib import gradio as gr import numpy as np import PIL.Image import torch from sahi.prediction import ObjectPrediction from sahi.utils.cv import visualize_object_predictions from transformers import AutoImageProcessor, DetaForObjectDetection from ultralytics import YOLO DESCRIPTION = '# Compare DETA and YOLOv8' device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') MODEL_ID = 'jozhang97/deta-swin-large' image_processor = AutoImageProcessor.from_pretrained(MODEL_ID) model_deta = DetaForObjectDetection.from_pretrained(MODEL_ID) model_deta.to(device) model_yolo = YOLO('yolov8x.pt') @torch.inference_mode() def run_deta(image_path: str, threshold: float) -> np.ndarray: image = PIL.Image.open(image_path) inputs = image_processor(images=image, return_tensors='pt').to(device) outputs = model_deta(**inputs) target_sizes = torch.tensor([image.size[::-1]]) results = image_processor.post_process_object_detection( outputs, threshold=threshold, target_sizes=target_sizes)[0] boxes = results['boxes'].cpu().numpy() scores = results['scores'].cpu().numpy() cat_ids = results['labels'].cpu().numpy().tolist() preds = [] for box, score, cat_id in zip(boxes, scores, cat_ids): box = np.round(box).astype(int) cat_label = model_deta.config.id2label[cat_id] pred = ObjectPrediction(bbox=box, category_id=cat_id, category_name=cat_label, score=score) preds.append(pred) res = visualize_object_predictions(np.asarray(image), preds)['image'] return res def run_yolov8(image_path: str, threshold: float) -> np.ndarray: image = PIL.Image.open(image_path) results = model_yolo(image, imgsz=640, conf=threshold) boxes = results[0].boxes.cpu().numpy().data preds = [] for box in boxes: score = box[4] cat_id = int(box[5]) box = np.round(box[:4]).astype(int) cat_label = model_yolo.model.names[cat_id] pred = ObjectPrediction(bbox=box, category_id=cat_id, category_name=cat_label, score=score) preds.append(pred) res = visualize_object_predictions(np.asarray(image), preds)['image'] return res def run(image_path: str, threshold: float) -> tuple[np.ndarray, np.ndarray]: return run_deta(image_path, threshold), run_yolov8(image_path, threshold) with gr.Blocks(css='style.css') as demo: gr.Markdown(DESCRIPTION) with gr.Row(): with gr.Column(): image = gr.Image(label='Input image', type='filepath') threshold = gr.Slider(label='Score threshold', minimum=0, maximum=1, value=0.5, step=0.01) run_button = gr.Button('Run') with gr.Column(): result_deta = gr.Image(label='Result (DETA)', type='numpy') result_yolo = gr.Image(label='Result (YOLOv8)', type='numpy') with gr.Row(): paths = sorted(pathlib.Path('images').glob('*.jpg')) gr.Examples(examples=[[path.as_posix(), 0.5] for path in paths], inputs=[ image, threshold, ], outputs=[ result_deta, result_yolo, ], fn=run, cache_examples=True) run_button.click(fn=run, inputs=[ image, threshold, ], outputs=[ result_deta, result_yolo, ]) demo.queue().launch()