#!/usr/bin/python3
# -*- coding: utf-8 -*-
import io
import os
from typing import Dict

from project_settings import project_path

os.environ["HUGGINGFACE_HUB_CACHE"] = (project_path / "cache/huggingface/hub").as_posix()

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import requests
import torch
from transformers.models.auto.processing_auto import AutoImageProcessor
from transformers.models.auto.modeling_auto import AutoModelForObjectDetection
import validators


# colors for visualization
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]


def get_original_image(url_input):
    """Download the image behind a URL; return None for an invalid URL."""
    if validators.url(url_input):
        image = Image.open(requests.get(url_input, stream=True).raw)
        return image


def figure2image(fig):
    """Render a matplotlib figure into a PIL image, resized to a fixed width."""
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    pil_image = Image.open(buf)

    base_width = 750
    width_percent = base_width / float(pil_image.size[0])
    height_size = int(float(pil_image.size[1]) * width_percent)
    pil_image = pil_image.resize((base_width, height_size), Image.Resampling.LANCZOS)
    return pil_image


def non_max_suppression(boxes, scores, threshold):
    """Apply non-maximum suppression at test time to avoid detecting
    too many overlapping bounding boxes for a given object.

    Args:
        boxes: array of [xmin, ymin, xmax, ymax]
        scores: array of scores associated with each box.
        threshold: IoU threshold

    Return:
        keep: indices of the boxes to keep
    """
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]    # process boxes in descending confidence order

    keep = []
    while order.size > 0:
        i = order[0]    # pick the box with the highest remaining confidence
        keep.append(i)

        # intersection of box i with every remaining box
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)    # intersection width
        h = np.maximum(0.0, yy2 - yy1 + 1)    # intersection height
        inter = w * h
        ovr = inter / (areas[i] + areas[order[1:]] - inter)

        # keep only boxes whose IoU with box i is at or below the threshold
        inds = np.where(ovr <= threshold)[0]
        order = order[inds + 1]
    return keep
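
# A quick illustrative check of the NMS helper (values chosen for this sketch):
# the two boxes below overlap with IoU ~0.68, so at threshold=0.5 only the
# higher-scoring box survives.
#
#   boxes = np.array([[0.0, 0.0, 100.0, 100.0], [10.0, 10.0, 110.0, 110.0]])
#   scores = np.array([0.9, 0.8])
#   non_max_suppression(boxes, scores, threshold=0.5)  # keeps only index 0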


def draw_boxes(image,
               boxes, scores, labels,
               threshold: float,
               idx_to_label: Dict[int, str] = None,
               labels_to_show: str = None):
    """Draw predicted boxes on the image and return the rendered figure as a PIL image."""
    # parse the optional comma-separated label filter, e.g. "person, truck"
    if isinstance(labels_to_show, str):
        if len(labels_to_show.strip()) == 0:
            labels_to_show = None
        else:
            labels_to_show = [label.strip().lower() for label in labels_to_show.split(",")]
            labels_to_show = None if len(labels_to_show) == 0 else labels_to_show

    plt.figure(figsize=(50, 50))
    plt.imshow(image)

    if idx_to_label is not None:
        labels = [idx_to_label[x] for x in labels]

    axis = plt.gca()
    colors = COLORS * len(boxes)
    for score, (xmin, ymin, xmax, ymax), label, color in zip(scores, boxes, labels, colors):
        if labels_to_show is not None and label.lower() not in labels_to_show:
            continue
        if score < threshold:
            continue
        axis.add_patch(plt.Rectangle((xmin, ymin),
                                     xmax - xmin,
                                     ymax - ymin,
                                     fill=False,
                                     color=color,
                                     linewidth=10))
        axis.text(xmin, ymin, f"{label}: {score:0.2f}", fontsize=60,
                  bbox=dict(facecolor="yellow", alpha=0.8))
    plt.axis("off")
    return figure2image(plt.gcf())


def detr_object_detection(url_input: str,
                          image_input: Image.Image,
                          pretrained_model_name_or_path: str = "qgyd2021/detr_cppe5_object_detection",
                          threshold: float = 0.5,
                          iou_threshold: float = 0.5,
                          labels_to_show: str = None,
                          ):
    model = AutoModelForObjectDetection.from_pretrained(pretrained_model_name_or_path)
    image_processor = AutoImageProcessor.from_pretrained(pretrained_model_name_or_path)

    # image: prefer the URL if it is valid, otherwise fall back to the uploaded image
    if validators.url(url_input):
        image = get_original_image(url_input)
    elif image_input:
        image = image_input
    else:
        raise AssertionError("at least one of `url_input` and `image_input` is required")

    # (height, width) of the original image, used to rescale the predicted boxes
    image_size = torch.tensor([tuple(reversed(image.size))])

    # inference
    inputs = image_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    processed_outputs = image_processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=image_size)
    processed_outputs = processed_outputs[0]

    # draw boxes
    boxes = processed_outputs["boxes"].detach().numpy()
    scores = processed_outputs["scores"].detach().numpy()
    labels = processed_outputs["labels"].detach().numpy()

    keep = non_max_suppression(boxes, scores, threshold=iou_threshold)
    boxes = boxes[keep]
    scores = scores[keep]
    labels = labels[keep]

    viz_image: Image.Image = draw_boxes(
        image, boxes, scores, labels,
        threshold=threshold,
        idx_to_label=model.config.id2label,
        labels_to_show=labels_to_show,
    )
    return viz_image
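
# Illustrative standalone usage (a sketch: the URL below is a placeholder, and
# the first call downloads the checkpoint from the Hugging Face Hub):
#
#   result = detr_object_detection(
#       url_input="https://example.com/some_image.jpg",  # hypothetical image URL
#       image_input=None,
#       threshold=0.5,
#       iou_threshold=0.5,
#   )
#   result.save("detections.png")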


def main():
    title = "## Detr Cppe5 Object Detection"

    description = """
reference:
https://huggingface.co/docs/transformers/tasks/object_detection
"""

    example_urls = [
        [
            "https://huggingface.co/datasets/intelli-zen/cppe-5/resolve/main/data/images/{}.png".format(idx),
            "intelli-zen/detr_cppe5_object_detection",
            0.5,
            0.6,
            None
        ]
        for idx in range(1001, 1030)
    ]

    example_images = [
        [
            "data/2lnWoly.jpg",
            "intelli-zen/detr_cppe5_object_detection",
            0.5,
            0.6,
            None
        ]
    ]

    with gr.Blocks() as blocks:
        gr.Markdown(value=title)
        gr.Markdown(value=description)

        model_name = gr.components.Dropdown(
            choices=[
                "intelli-zen/detr_cppe5_object_detection",
            ],
            value="intelli-zen/detr_cppe5_object_detection",
            label="model_name",
        )
        threshold_slider = gr.components.Slider(
            minimum=0, maximum=1.0, step=0.01, value=0.5, label="Threshold"
        )
        iou_threshold_slider = gr.components.Slider(
            minimum=0, maximum=1.0, step=0.1, value=0.5, label="IOU Threshold"
        )
        classes_to_detect = gr.Textbox(
            placeholder="e.g. person, truck (comma-separated)",
            label="labels to show",
        )

        with gr.Tabs():
            with gr.TabItem("Image URL"):
                with gr.Row():
                    with gr.Column():
                        url_input = gr.Textbox(lines=1, label="Enter a valid image URL here...")
                        original_image = gr.Image()
                        url_input.change(get_original_image, url_input, original_image)
                    with gr.Column():
                        img_output_from_url = gr.Image()
                        url_but = gr.Button("Detect")
                with gr.Row():
                    gr.Examples(examples=example_urls,
                                inputs=[url_input, model_name, threshold_slider,
                                        iou_threshold_slider, classes_to_detect],
                                examples_per_page=5,
                                )
            with gr.TabItem("Image Upload"):
                with gr.Row():
                    img_input = gr.Image(type="pil")
                    img_output_from_upload = gr.Image()
                img_but = gr.Button("Detect")
                with gr.Row():
                    gr.Examples(examples=example_images,
                                inputs=[img_input, model_name, threshold_slider,
                                        iou_threshold_slider, classes_to_detect],
                                examples_per_page=5,
                                )

        inputs = [url_input, img_input, model_name,
                  threshold_slider, iou_threshold_slider, classes_to_detect]
        url_but.click(detr_object_detection,
                      inputs=inputs,
                      outputs=[img_output_from_url],
                      queue=True)
        img_but.click(detr_object_detection,
                      inputs=inputs,
                      outputs=[img_output_from_upload],
                      queue=True)

    blocks.queue().launch()
    return


if __name__ == '__main__':
    main()
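
# To try the demo locally (a sketch, assuming the dependencies above are
# installed): run this script with python3. Gradio serves the interface on
# http://127.0.0.1:7860 by default.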