|
|
|
|
|
import pathlib |
|
|
|
import gradio as gr |
|
import numpy as np |
|
import PIL.Image |
|
import spaces |
|
import torch |
|
from sahi.prediction import ObjectPrediction |
|
from sahi.utils.cv import visualize_object_predictions |
|
from transformers import AutoImageProcessor, DetaForObjectDetection |
|
|
|
# Static configuration: page heading (Markdown) and the checkpoint to load.
DESCRIPTION = "# DETA (Detection Transformers with Assignment)"
MODEL_ID = "jozhang97/deta-swin-large"

# Use the first CUDA device when torch reports one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Load the pre/post-processor and the detector once at import time and place
# the model on the selected device.
image_processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = DetaForObjectDetection.from_pretrained(MODEL_ID).to(device)
|
|
|
|
|
@spaces.GPU
@torch.inference_mode()
def run(image_path: str, threshold: float) -> np.ndarray:
    """Detect objects in an image file and return the image with detections drawn.

    Args:
        image_path: Path of the image file to analyse.
        threshold: Score threshold forwarded to the image processor's
            ``post_process_object_detection`` call.

    Returns:
        The ``"image"`` entry of the result produced by
        ``visualize_object_predictions`` for the input image and its detections.
    """
    # Open inside a context manager so the file handle is released promptly, and
    # normalise to RGB so RGBA / palette / greyscale uploads reach the processor
    # and the visualiser with the same 3-channel layout as ordinary RGB input.
    with PIL.Image.open(image_path) as opened:
        image = opened.convert("RGB")

    inputs = image_processor(images=image, return_tensors="pt").to(device)
    outputs = model(**inputs)

    # PIL's ``size`` is (width, height); it is reversed here to (height, width)
    # before being handed to post-processing.
    target_sizes = torch.tensor([image.size[::-1]])
    results = image_processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=target_sizes
    )[0]

    boxes = results["boxes"].cpu().numpy()
    scores = results["scores"].cpu().numpy()
    cat_ids = results["labels"].cpu().numpy().tolist()

    # One ObjectPrediction per detection; box coordinates are rounded to ints
    # and the label name is looked up in the model config.
    preds = [
        ObjectPrediction(
            bbox=np.round(box).astype(int),
            category_id=cat_id,
            category_name=model.config.id2label[cat_id],
            score=score,
        )
        for box, score, cat_id in zip(boxes, scores, cat_ids)
    ]

    return visualize_object_predictions(np.asarray(image), preds)["image"]
|
|
|
|
|
# Build the demo page: heading, input controls, result image, examples and
# the button wiring.
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation had been flattened; `result` is placed beside the input column
# inside the row. Confirm against the intended layout.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Input image", type="filepath")
            threshold = gr.Slider(label="Score threshold", minimum=0, maximum=1, step=0.01, value=0.1)
            run_button = gr.Button()
        result = gr.Image(label="Result")

    # Each *.jpg found under images/ (sorted) becomes an example paired with a
    # threshold of 0.1.
    example_rows = [[path, 0.1] for path in sorted(pathlib.Path("images").glob("*.jpg"))]
    gr.Examples(examples=example_rows, inputs=[image, threshold], outputs=result, fn=run)

    # Clicking the button runs detection; the endpoint is exposed as "predict".
    run_button.click(fn=run, inputs=[image, threshold], outputs=result, api_name="predict")
|
|
|
if __name__ == "__main__":
    # Configure the queue with max_size=20, then launch what queue() returns.
    queued_app = demo.queue(max_size=20)
    queued_app.launch()
|
|