File size: 7,009 Bytes
9434aee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import io
import matplotlib.pyplot as plt
import requests, validators
import torch
import pathlib
from PIL import Image
from transformers import AutoFeatureExtractor, DetrForObjectDetection, YolosForObjectDetection
from ultralyticsplus import YOLO, render_result
import os
# Palette used when drawing bounding boxes; each entry is a three-component
# colour value that is handed to matplotlib as-is.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]

# Human-readable names for the YOLOv8 detections: the integer class id
# reported for a box is used as the index into this list.
YOLOV8_LABELS = [
    "pedestrian",
    "people",
    "bicycle",
    "car",
    "van",
    "truck",
    "tricycle",
    "awning-tricycle",
    "bus",
    "motor",
]
def make_prediction(img, feature_extractor, model):
    """Run ``model`` on ``img`` and return the post-processed detections.

    Args:
        img: Input image. It must expose a ``size`` attribute whose reversed
            value is used as the target size (for a PIL image ``size`` is
            ``(width, height)``, so the target becomes ``(height, width)``).
        feature_extractor: Callable preprocessor invoked as
            ``feature_extractor(img, return_tensors="pt")``. It must also
            provide ``post_process(outputs, target_sizes)``.
        model: Callable that accepts the preprocessor output as keyword
            arguments.

    Returns:
        Whatever ``feature_extractor.post_process`` returns. The caller in
        this file takes element ``0`` and reads its ``"scores"``,
        ``"labels"`` and ``"boxes"`` entries.
    """
    inputs = feature_extractor(img, return_tensors="pt")

    # Pure inference: disabling autograd avoids building a gradient graph
    # that would only cost memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # One row per image, holding the image size in reversed order.
    target_sizes = torch.tensor([tuple(reversed(img.size))])

    # NOTE(review): newer transformers releases are believed to deprecate
    # ``post_process`` in favour of ``post_process_object_detection`` -
    # confirm against the pinned version before upgrading.
    return feature_extractor.post_process(outputs, target_sizes)
def fig2img(fig):
    """Render ``fig`` into an in-memory buffer and open it as a PIL image.

    Args:
        fig: Object exposing ``savefig(buffer)`` (a matplotlib figure in
            this file).

    Returns:
        The image produced by ``Image.open`` on the in-memory buffer.
    """
    buffer = io.BytesIO()
    fig.savefig(buffer)
    # Rewind so that the image reader starts at the first byte written.
    buffer.seek(0)
    return Image.open(buffer)
def visualize_prediction(pil_img, output_dict, threshold=0.7, id2label=None):
    """Draw the detections that reach ``threshold`` on top of ``pil_img``.

    Args:
        pil_img: Image to display underneath the boxes.
        output_dict: Mapping with ``"scores"``, ``"boxes"`` and ``"labels"``
            entries. Each entry must support boolean-mask indexing and
            ``.tolist()``; every box unpacks to ``(xmin, ymin, xmax, ymax)``.
        threshold: Minimum score a detection needs in order to be drawn.
            The comparison is inclusive so that the picture agrees with the
            "ABOVE THRESHOLD OR EQUAL" section of the text report.
        id2label: Optional mapping from label id to display name. When
            omitted the raw label ids are shown.

    Returns:
        The annotated figure converted to an image by ``fig2img``.
    """
    keep = output_dict["scores"] >= threshold
    boxes = output_dict["boxes"][keep].tolist()
    scores = output_dict["scores"][keep].tolist()
    labels = output_dict["labels"][keep].tolist()
    if id2label is not None:
        labels = [id2label[x] for x in labels]

    fig = plt.figure(figsize=(16, 10))
    try:
        ax = fig.gca()
        ax.imshow(pil_img)
        detections = zip(scores, boxes, labels)
        for index, (score, (xmin, ymin, xmax, ymax), label) in enumerate(detections):
            # Cycle through the palette so that no detection is dropped,
            # however many boxes there are.
            color = COLORS[index % len(COLORS)]
            ax.add_patch(
                plt.Rectangle(
                    (xmin, ymin),
                    xmax - xmin,
                    ymax - ymin,
                    fill=False,
                    color=color,
                    linewidth=3,
                )
            )
            ax.text(
                xmin,
                ymin,
                f"{label}: {score:0.2f}",
                fontsize=15,
                bbox=dict(facecolor="yellow", alpha=0.5),
            )
        ax.axis("off")
        return fig2img(fig)
    finally:
        # The figure has already been rendered into the returned image;
        # closing it keeps pyplot from accumulating one figure per call.
        plt.close(fig)
def _load_input_image(url_input, image_input):
    """Return the image to analyse, preferring a valid URL over an upload."""
    if url_input and validators.url(url_input):
        # Download the whole body (with a timeout and a status check) and
        # let PIL read the decoded bytes rather than the raw socket stream.
        response = requests.get(url_input, timeout=30)
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content))
    if image_input is not None:
        return image_input
    raise ValueError("Provide either a valid image URL or an uploaded image.")


def _format_report(above_lines, below_lines):
    """Join the per-detection lines under the two threshold banners."""
    # https://docs.python.org/3/library/string.html#format-examples
    return (
        "{:*^50}\n".format("ABOVE THRESHOLD OR EQUAL")
        + "".join(above_lines)
        + "\n{:*^50}\n".format("BELOW THRESHOLD")
        + "".join(below_lines)
    )


def _detect_with_yolov8(model_name, image, threshold):
    """Run an ultralytics YOLOv8 model and build the rendered image and report."""
    # https://docs.ultralytics.com/modes/predict/#key-features-of-predict-mode
    model = YOLO(model_name)
    model.overrides['conf'] = 0.15  # NMS confidence threshold
    model.overrides['iou'] = 0.05  # NMS IoU (intersection over union) threshold
    model.overrides['agnostic_nms'] = False  # NMS class-agnostic
    model.overrides['max_det'] = 1000  # maximum number of detections per image

    results = model.predict(image)
    render = render_result(model=model, image=image, result=results[0])

    above_lines = []
    below_lines = []
    for result in results:
        for box in result.boxes.cpu().numpy():
            coordinates = box.xyxy[0].astype(int)
            # Best effort: a detection that cannot be decoded is still
            # reported, with placeholder values, instead of aborting the run.
            try:
                label = YOLOV8_LABELS[int(box.cls)]
            except (IndexError, TypeError, ValueError):
                label = "ERROR"
            try:
                confi = float(box.conf)
            except (TypeError, ValueError):
                confi = 0.0
            line = f"Detected `{label}` with confidence `{confi}` at location `{coordinates}`\n"
            if confi >= threshold:
                above_lines.append(line)
            else:
                below_lines.append(line)
    return render, _format_report(above_lines, below_lines)


def _detect_with_transformers(model_name, image, threshold):
    """Run a DETR or YOLOS model and build the annotated image and report."""
    # Pick the model class before anything is downloaded so that an
    # unsupported name fails immediately with a clear message.
    if 'detr' in model_name:
        model_cls = DetrForObjectDetection
    elif 'yolos' in model_name:
        model_cls = YolosForObjectDetection
    else:
        raise ValueError(f"Unsupported model name: {model_name!r}")

    feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
    model = model_cls.from_pretrained(model_name)
    id2label = model.config.id2label

    processed_outputs = make_prediction(image, feature_extractor, model)[0]
    viz_img = visualize_prediction(image, processed_outputs, threshold, id2label)

    # List the detections from the highest score to the lowest.
    detections = sorted(
        zip(
            processed_outputs["scores"],
            processed_outputs["labels"],
            processed_outputs["boxes"],
        ),
        key=lambda detection: detection[0].item(),
        reverse=True,
    )
    above_lines = []
    below_lines = []
    for score, label, box in detections:
        box = [round(i, 2) for i in box.tolist()]
        line = f"Detected `{id2label[label.item()]}` with confidence `{round(score.item(), 3)}` at location `{box}`\n"
        if score.item() >= threshold:
            above_lines.append(line)
        else:
            below_lines.append(line)
    return viz_img, _format_report(above_lines, below_lines)


def detect_objects(model_name,url_input,image_input,threshold):
    """Detect objects in an image and describe them as text.

    The image comes from ``url_input`` when it is a valid URL and from
    ``image_input`` otherwise. Names containing ``"yolov8"`` are run through
    ultralytics; names containing ``"detr"`` or ``"yolos"`` are run through
    the matching transformers model.

    Args:
        model_name: Model identifier; it also selects the backend (see above).
        url_input: URL of the image to analyse, or an empty value.
        image_input: Already loaded image, used when no valid URL is given.
        threshold: Score that separates the two sections of the report. For
            DETR/YOLOS it also selects which boxes are drawn.

    Returns:
        A ``(image, report)`` tuple: the annotated image, and a string that
        lists every detection under an "ABOVE THRESHOLD OR EQUAL" or a
        "BELOW THRESHOLD" banner.

    Raises:
        ValueError: If no usable image was supplied, or if ``model_name``
            matches none of the supported model families.
        requests.RequestException: If the image download fails.
    """
    image = _load_input_image(url_input, image_input)
    if 'yolov8' in model_name:
        return _detect_with_yolov8(model_name, image, threshold)
    return _detect_with_transformers(model_name, image, threshold)
# HTML heading string (presumably shown by the UI module that imports this
# file - confirm against the caller).
title = """<h1 id="title">Object Detection App with DETR and YOLOS</h1>"""

# Markdown text linking to the model cards of the supported models.
description = """
Links to HuggingFace Models:
- [facebook/detr-resnet-50](https://huggingface.co/facebook/detr-resnet-50)
- [facebook/detr-resnet-101](https://huggingface.co/facebook/detr-resnet-101)
- [hustvl/yolos-small](https://huggingface.co/hustvl/yolos-small)
- [hustvl/yolos-tiny](https://huggingface.co/hustvl/yolos-tiny)
- [facebook/detr-resnet-101-dc5](https://huggingface.co/facebook/detr-resnet-101-dc5)
- [hustvl/yolos-small-300](https://huggingface.co/hustvl/yolos-small-300)
- [mshamrai/yolov8x-visdrone](https://huggingface.co/mshamrai/yolov8x-visdrone)
"""

# Model identifiers, in the same set as the links in ``description``.
models = [
    "facebook/detr-resnet-50",
    "facebook/detr-resnet-101",
    "hustvl/yolos-small",
    "hustvl/yolos-tiny",
    "facebook/detr-resnet-101-dc5",
    "hustvl/yolos-small-300",
    "mshamrai/yolov8x-visdrone",
]

# Sample image URL(s).
urls = [
    "https://c8.alamy.com/comp/J2AB4K/the-new-york-stock-exchange-on-the-wall-street-in-new-york-J2AB4K.jpg",
]

# Sample image bundled with the project; note that it is opened as soon as
# this module is imported.
TEST_IMAGE = Image.open(r"images/Test_Street_VisDrone.JPG")

# Example call from another module:
# image_functions.detect_objects('facebook/detr-resnet-50', "", image_functions.TEST_IMAGE, 0.7)
|