qgyd2021's picture
Update main.py
0fe6c48 verified
raw
history blame contribute delete
No virus
9.01 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import io
import json
import os
import re
from typing import Dict, List
from project_settings import project_path
os.environ["HUGGINGFACE_HUB_CACHE"] = (project_path / "cache/huggingface/hub").as_posix()
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import requests
import torch
from transformers.models.auto.processing_auto import AutoImageProcessor
from transformers.models.auto.feature_extraction_auto import AutoFeatureExtractor
from transformers.models.auto.modeling_auto import AutoModelForObjectDetection
import validators
from project_settings import project_path
# Palette of RGB triples (0-1 floats) used for box outlines; draw_boxes cycles
# through it, one colour per drawn detection.
COLORS = [
    [0.000, 0.447, 0.741],  # blue
    [0.850, 0.325, 0.098],  # orange
    [0.929, 0.694, 0.125],  # yellow
    [0.494, 0.184, 0.556],  # purple
    [0.466, 0.674, 0.188],  # green
    [0.301, 0.745, 0.933],  # light blue
]
def get_original_image(url_input):
    """Download and decode the image at ``url_input``.

    This is bound to the URL textbox's ``change`` event, so it is called with
    partial, still-being-typed text; for anything that is not a valid URL it
    returns ``None`` (which clears the preview) instead of failing.

    Args:
        url_input: text that may or may not be an image URL.

    Returns:
        The decoded PIL image, or ``None`` if ``url_input`` is not a valid URL.

    Raises:
        requests.RequestException: on network errors, timeouts or HTTP error
            status codes.
        PIL.UnidentifiedImageError: if the response body is not an image.
    """
    if not validators.url(url_input):
        # Previously `image` was left unbound here -> UnboundLocalError.
        return None
    # A timeout keeps a dead host from hanging the worker indefinitely.
    response = requests.get(url_input, timeout=30)
    response.raise_for_status()
    # `.content` is fully read (connection released) and honours
    # Content-Encoding, unlike decoding the undecoded `.raw` stream.
    return Image.open(io.BytesIO(response.content))
def figure2image(fig, base_width: int = 750):
    """Rasterize a Matplotlib figure into a PIL image ``base_width`` pixels wide.

    The figure is saved as PNG into an in-memory buffer, decoded with PIL and
    resized with Lanczos resampling, keeping its aspect ratio.

    Args:
        fig: the Matplotlib figure to render.
        base_width: width in pixels of the returned image (default 750).

    Returns:
        The resized PIL image.
    """
    buf = io.BytesIO()
    # Pass the format explicitly: otherwise savefig uses
    # rcParams["savefig.format"], which may be configured to a non-raster
    # format that PIL cannot open.
    fig.savefig(buf, format="png")
    buf.seek(0)
    with Image.open(buf) as pil_image:
        width, height = pil_image.size
        height_size = int(height * (base_width / float(width)))
        # resize() returns a new, fully loaded image, so closing the source
        # image on exit from the `with` is safe.
        return pil_image.resize((base_width, height_size), Image.Resampling.LANCZOS)
def non_max_suppression(boxes, scores, threshold):
    """Greedy, class-agnostic non-maximum suppression.

    Boxes are visited from highest to lowest score. Each visited box is kept,
    and every remaining box whose IoU with it is greater than ``threshold`` is
    dropped from further consideration.

    Args:
        boxes: numpy array whose first four columns are
            [xmin, ymin, xmax, ymax].
        scores: numpy array with one confidence score per box.
        threshold: IoU above which a lower-scoring box is suppressed.

    Returns:
        list of indices into ``boxes`` of the surviving boxes, in descending
        score order.
    """
    left, top, right, bottom = boxes[:, :4].T
    # "+ 1" treats coordinates as inclusive pixel indices.
    area = (right - left + 1) * (bottom - top + 1)
    candidates = scores.argsort()[::-1]  # highest confidence first
    survivors = []
    while candidates.size:
        best, rest = candidates[0], candidates[1:]
        survivors.append(best)
        overlap_w = np.maximum(
            0.0, np.minimum(right[best], right[rest]) - np.maximum(left[best], left[rest]) + 1)
        overlap_h = np.maximum(
            0.0, np.minimum(bottom[best], bottom[rest]) - np.maximum(top[best], top[rest]) + 1)
        overlap = overlap_w * overlap_h
        iou = overlap / (area[best] + area[rest] - overlap)
        # Keep only the candidates that do not overlap the chosen box too much.
        candidates = rest[iou <= threshold]
    return survivors
def draw_boxes(image, boxes, scores, labels, threshold: float,
               idx_to_label: Dict[int, str] = None, labels_to_show: str = None):
    """Draw detections on ``image`` and return the rendering as a PIL image.

    Args:
        image: image shown as the background (passed to ``imshow``).
        boxes: per-detection (xmin, ymin, xmax, ymax) coordinates.
        scores: per-detection confidence scores.
        labels: per-detection class ids, or names if ``idx_to_label`` is None.
        threshold: detections scoring below this are not drawn.
        idx_to_label: optional mapping from class id to display name.
        labels_to_show: optional comma-separated, case-insensitive whitelist
            of class names (e.g. "mask, gloves"); None or blank shows every
            class. A pre-split list of lowercase names is also accepted.

    Returns:
        The rendered figure as a PIL image (via ``figure2image``).
    """
    import itertools
    from matplotlib.figure import Figure

    if isinstance(labels_to_show, str):
        wanted = [name.strip().lower() for name in labels_to_show.split(",")]
        # Drop empty entries (e.g. a trailing comma): they can never match a
        # label and, if they were all there was, would hide every box.
        wanted = [name for name in wanted if name]
        labels_to_show = wanted or None

    if idx_to_label is not None:
        labels = [idx_to_label[x] for x in labels]

    # A standalone Figure (not pyplot.figure) is not registered with pyplot's
    # global figure manager: it is garbage-collected after use instead of
    # leaking one 50x50-inch figure per request, and concurrent requests do
    # not share a "current figure".
    fig = Figure(figsize=(50, 50))
    axis = fig.gca()
    axis.imshow(image)
    detections = zip(scores, boxes, labels, itertools.cycle(COLORS))
    for score, (xmin, ymin, xmax, ymax), label, color in detections:
        # str() so the filter also works when labels are raw class ids.
        if labels_to_show is not None and str(label).lower() not in labels_to_show:
            continue
        if score < threshold:
            continue
        axis.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=color, linewidth=10))
        axis.text(xmin, ymin, f"{label}: {score:0.2f}", fontsize=60, bbox=dict(facecolor="yellow", alpha=0.8))
    axis.axis("off")
    return figure2image(fig)
# Memoized (model, image_processor) pairs keyed by checkpoint name. Unbounded;
# presumably keys only come from the fixed model dropdown in main() -- verify
# if this function gains other callers.
_DETECTOR_CACHE: Dict[str, tuple] = {}


def _load_detector(pretrained_model_name_or_path: str):
    """Return ``(model, image_processor)`` for a checkpoint, loading it only once."""
    detector = _DETECTOR_CACHE.get(pretrained_model_name_or_path)
    if detector is None:
        # from_pretrained builds a fresh model and loads its weights on every
        # call, which previously happened on every button click.
        model = AutoModelForObjectDetection.from_pretrained(pretrained_model_name_or_path)
        image_processor = AutoImageProcessor.from_pretrained(pretrained_model_name_or_path)
        detector = (model, image_processor)
        _DETECTOR_CACHE[pretrained_model_name_or_path] = detector
    return detector


def detr_object_detection(url_input: str,
                          image_input: Image.Image,
                          pretrained_model_name_or_path: str = "qgyd2021/detr_cppe5_object_detection",
                          threshold: float = 0.5,
                          iou_threshold: float = 0.5,
                          labels_to_show: str = None,
                          ):
    """Detect objects in an image and return a visualisation of the detections.

    Args:
        url_input: image URL; takes precedence when it is a valid URL.
        image_input: PIL image used when ``url_input`` is not a valid URL.
        pretrained_model_name_or_path: checkpoint loaded with
            ``AutoModelForObjectDetection`` / ``AutoImageProcessor``.
        threshold: minimum score, applied in post-processing and drawing.
        iou_threshold: IoU above which overlapping lower-scoring boxes are
            removed by class-agnostic NMS.
        labels_to_show: optional comma-separated class-name filter.

    Returns:
        The PIL image produced by ``draw_boxes``.

    Raises:
        AssertionError: if neither a valid URL nor an image is provided.
    """
    # NOTE(review): this default checkpoint differs from the "intelli-zen/..."
    # one offered in the UI -- confirm which is intended.
    image = get_original_image(url_input) if validators.url(url_input) else image_input
    if image is None:
        raise AssertionError("either a valid `url_input` or an `image_input` is required")
    # Normalise mode so RGBA (e.g. PNG with alpha) or grayscale inputs become
    # 3-channel; presumably what the processor/backbone expects -- confirm.
    image = image.convert("RGB")

    model, image_processor = _load_detector(pretrained_model_name_or_path)

    # Post-processing takes (height, width); PIL's `size` is (width, height).
    target_sizes = torch.tensor([image.size[::-1]])
    with torch.no_grad():  # inference only: skip building the autograd graph
        inputs = image_processor(images=image, return_tensors="pt")
        outputs = model(**inputs)
        detections = image_processor.post_process_object_detection(
            outputs, threshold=threshold, target_sizes=target_sizes)[0]

    boxes = detections["boxes"].detach().numpy()
    scores = detections["scores"].detach().numpy()
    labels = detections["labels"].detach().numpy()

    keep = non_max_suppression(boxes, scores, threshold=iou_threshold)
    boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

    viz_image: Image.Image = draw_boxes(
        image, boxes, scores, labels,
        threshold=threshold,
        idx_to_label=model.config.id2label,
        labels_to_show=labels_to_show
    )
    return viz_image
def main():
    """Build the Gradio UI for the DETR CPPE-5 detector and launch it.

    Two tabs (image URL / image upload) share the model, threshold,
    IoU-threshold and label-filter controls. Each tab's "Detect" button
    forwards only its own image source: detr_object_detection prefers a valid
    URL over an uploaded image, so a shared input list let a URL typed in the
    first tab silently override an image uploaded in the second.
    """
    title = "## Detr Cppe5 Object Detection"
    description = """
reference:
https://huggingface.co/docs/transformers/tasks/object_detection
"""
    example_urls = [
        [
            "https://huggingface.co/datasets/intelli-zen/cppe-5/resolve/main/data/images/{}.png".format(idx),
            "intelli-zen/detr_cppe5_object_detection",
            0.5, 0.6, None
        ] for idx in range(1001, 1030)
    ]
    example_images = [
        [
            "data/2lnWoly.jpg",
            "intelli-zen/detr_cppe5_object_detection",
            0.5, 0.6, None
        ]
    ]

    def detect_from_url(url, pretrained_model_name_or_path, threshold, iou_threshold, labels_to_show):
        """Detect objects in the image at ``url``; any uploaded image is ignored."""
        return detr_object_detection(url, None, pretrained_model_name_or_path,
                                     threshold, iou_threshold, labels_to_show)

    def detect_from_upload(image, pretrained_model_name_or_path, threshold, iou_threshold, labels_to_show):
        """Detect objects in the uploaded ``image``; any typed URL is ignored."""
        # An empty string is not a valid URL, so the uploaded image is used.
        return detr_object_detection("", image, pretrained_model_name_or_path,
                                     threshold, iou_threshold, labels_to_show)

    with gr.Blocks() as blocks:
        gr.Markdown(value=title)
        gr.Markdown(value=description)

        model_name = gr.components.Dropdown(
            choices=[
                "intelli-zen/detr_cppe5_object_detection",
            ],
            value="intelli-zen/detr_cppe5_object_detection",
            label="model_name",
        )
        threshold_slider = gr.components.Slider(
            minimum=0, maximum=1.0,
            step=0.01, value=0.5,
            label="Threshold"
        )
        iou_threshold_slider = gr.components.Slider(
            minimum=0, maximum=1.0,
            step=0.1, value=0.5,
            label="IOU Threshold"
        )
        classes_to_detect = gr.Textbox(placeholder="e.g. person, truck (split by , comma).",
                                       label="labels to show")

        with gr.Tabs():
            with gr.TabItem("Image URL"):
                with gr.Row():
                    with gr.Column():
                        url_input = gr.Textbox(lines=1, label="Enter valid image URL here..")
                        original_image = gr.Image()
                        url_input.change(get_original_image, url_input, original_image)
                    with gr.Column():
                        img_output_from_url = gr.Image()
                url_but = gr.Button("Detect")
                with gr.Row():
                    gr.Examples(examples=example_urls,
                                inputs=[url_input, model_name, threshold_slider, iou_threshold_slider],
                                examples_per_page=5,
                                )

            with gr.TabItem("Image Upload"):
                with gr.Row():
                    img_input = gr.Image(type="pil")
                    img_output_from_upload = gr.Image()
                img_but = gr.Button("Detect")
                with gr.Row():
                    gr.Examples(examples=example_images,
                                inputs=[img_input, model_name, threshold_slider, iou_threshold_slider],
                                examples_per_page=5,
                                )

        # Controls shared by both tabs; the image source is tab-specific.
        shared_inputs = [model_name, threshold_slider, iou_threshold_slider, classes_to_detect]
        url_but.click(detect_from_url, inputs=[url_input, *shared_inputs],
                      outputs=[img_output_from_url], queue=True)
        img_but.click(detect_from_upload, inputs=[img_input, *shared_inputs],
                      outputs=[img_output_from_upload], queue=True)

    blocks.queue().launch()
# Launch the demo only when run as a script, not when imported.
if __name__ == "__main__":
    main()