Spaces:
Running
on
Zero
Running
on
Zero
| #!/usr/bin/env python | |
| import pathlib | |
| import gradio as gr | |
| import numpy as np | |
| import PIL.Image | |
| import spaces | |
| import torch | |
| from sahi.prediction import ObjectPrediction | |
| from sahi.utils.cv import visualize_object_predictions | |
| from transformers import AutoImageProcessor, DetaForObjectDetection | |
# Markdown heading shown at the top of the demo page.
DESCRIPTION = "# DETA (Detection Transformers with Assignment)"
# Checkpoint identifier handed to both `from_pretrained` calls below.
MODEL_ID = "jozhang97/deta-swin-large"

# Use the first CUDA device when torch reports one, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Loaded once at import time so every request reuses the same objects.
image_processor = AutoImageProcessor.from_pretrained(MODEL_ID)
# `.to()` hands back the module it was called on, so chaining it leaves
# `model` bound to the same (now relocated) model instance.
model = DetaForObjectDetection.from_pretrained(MODEL_ID).to(device)
# NOTE(review): `spaces` is imported by this file but never used in the visible
# code. If this app is deployed on ZeroGPU hardware, `run` presumably also needs
# the `@spaces.GPU` decorator - confirm against the deployment target.
@torch.inference_mode()
def run(image_path: str, threshold: float) -> np.ndarray:
    """Detect objects in an image file and draw the detections onto it.

    Autograd is disabled for the whole call: the forward pass is inference
    only, and tensors that still require grad cannot be converted with
    ``.numpy()`` further down.

    Args:
        image_path: Path of the image file to analyse.
        threshold: Score threshold forwarded to the image processor's
            ``post_process_object_detection``.

    Returns:
        The ``"image"`` entry of the dict returned by
        ``visualize_object_predictions``.
    """
    # Close the file handle promptly and normalise to 3-channel RGB so that
    # palette / grayscale / RGBA uploads are visualised the same way as JPEGs.
    # `convert` returns an independent, fully loaded copy, so it stays usable
    # after the file is closed.
    with PIL.Image.open(image_path) as opened:
        image = opened.convert("RGB")

    inputs = image_processor(images=image, return_tensors="pt").to(device)
    outputs = model(**inputs)

    # PIL reports (width, height); reverse it to get (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = image_processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=target_sizes
    )[0]

    boxes = results["boxes"].cpu().numpy()
    scores = results["scores"].cpu().numpy()
    cat_ids = results["labels"].cpu().numpy().tolist()

    # One ObjectPrediction per detection; `strict=True` makes a length mismatch
    # between the three result arrays fail loudly instead of truncating.
    preds = [
        ObjectPrediction(
            bbox=np.round(box).astype(int),
            category_id=cat_id,
            category_name=model.config.id2label[cat_id],
            score=score,
        )
        for box, score, cat_id in zip(boxes, scores, cat_ids, strict=True)
    ]
    return visualize_object_predictions(np.asarray(image), preds)["image"]
# Page layout: heading, then one row holding the input controls (left column)
# next to the rendered result, followed by clickable examples.
with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Input image", type="filepath")
            threshold = gr.Slider(
                label="Score threshold",
                minimum=0,
                maximum=1,
                step=0.01,
                value=0.1,
            )
            run_button = gr.Button()
        result = gr.Image(label="Result")

    # Every JPEG under ./images becomes an example row, paired with a 0.1
    # threshold; sorting keeps the listing order stable across runs.
    example_rows = [
        [jpg_path, 0.1] for jpg_path in sorted(pathlib.Path("images").glob("*.jpg"))
    ]
    gr.Examples(
        examples=example_rows,
        inputs=[image, threshold],
        outputs=result,
        fn=run,
    )

    # The button runs the same function as the examples and is additionally
    # exposed through the API under the name "predict".
    run_button.click(
        fn=run,
        inputs=[image, threshold],
        outputs=result,
        api_name="predict",
    )

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()