# Source: marine-vessel-detection/src/yolov8_onnx.py (Hugging Face, author: mayrajeo)
# Last change: commit b04b2b6, "Update src/yolov8_onnx.py" (file size 5.68 kB; host scan: no virus)
import cv2
import numpy as np
import onnxruntime as ort
import torch
import copy
from ultralytics.utils import ROOT, yaml_load
from ultralytics.utils.checks import check_requirements, check_yaml
class Yolov8onnx:
    """Run a YOLOv8 detection model exported to ONNX with onnxruntime.

    Typical use: ``bboxes, scores, class_ids = model.inference(bgr_image)``.
    Boxes come back as ``[x1, y1, x2, y2]`` integer lists in the pixel
    coordinates of the image passed to ``inference``.
    """

    def __init__(self,
                 onnx_model,
                 input_width,
                 input_height,
                 confidence_thres,
                 iou_thres,
                 device='cpu'):
        """
        Initializes an instance of the Yolov8onnx class.

        Args:
            onnx_model: Path to the ONNX model (anything accepted by
                ``onnxruntime.InferenceSession``).
            input_width: Width in pixels that images are resized to before inference.
            input_height: Height in pixels that images are resized to before inference.
            confidence_thres: Confidence threshold for filtering detections.
            iou_thres: IoU (Intersection over Union) threshold for non-maximum suppression.
            device: ``'cpu'`` (default) or a CUDA device such as ``'cuda'`` /
                ``'cuda:0'``. CUDA is used only if onnxruntime reports the
                ``CUDAExecutionProvider`` as available; otherwise inference runs on CPU.
        """
        self.onnx_model = onnx_model
        self.confidence_thres = confidence_thres
        self.iou_thres = iou_thres
        self.input_width = input_width
        self.input_height = input_height

        # CPU is always kept as the fallback provider. CUDA is put in front of it
        # only when it was requested AND this onnxruntime build actually offers it.
        providers = ['CPUExecutionProvider']
        wants_cuda = str(device).lower().startswith('cuda')
        if wants_cuda and 'CUDAExecutionProvider' in ort.get_available_providers():
            providers.insert(0, 'CUDAExecutionProvider')

        self.onnx_session = ort.InferenceSession(
            onnx_model,
            providers=providers
        )
        self.input_name = self.onnx_session.get_inputs()[0].name
        self.output_name = self.onnx_session.get_outputs()[0].name

    def preprocess(self, input_image):
        """
        Converts a BGR image into the model's input tensor.

        Side effect: stores the image and its size on ``self`` (``img``,
        ``img_height``, ``img_width``). ``postprocess`` reads the size to map
        boxes back to original-image coordinates, so call this first.

        Args:
            input_image: Image array in BGR channel order, indexed as
                ``(height, width, channels)``.

        Returns:
            image_data: float32 array of shape ``(1, 3, input_height, input_width)``
                holding RGB values scaled to [0, 1].
        """
        self.img = input_image
        self.img_height, self.img_width = self.img.shape[:2]
        # The model is fed RGB; the incoming image is BGR.
        img = cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB)
        # cv2.resize takes (width, height).
        img = cv2.resize(img, (self.input_width, self.input_height))
        image_data = np.array(img) / 255.0
        # HWC -> CHW, then add the batch dimension.
        image_data = np.transpose(image_data, (2, 0, 1))
        image_data = np.expand_dims(image_data, axis=0).astype(np.float32)
        return image_data

    def postprocess(self, output):
        """
        Turns raw model output into confidence- and NMS-filtered detections.

        Expects ``output[0]`` to squeeze to ``(4 + num_classes, num_candidates)``,
        where each candidate is ``[cx, cy, w, h, class scores...]`` in model-input
        pixels. This is the layout the slicing below relies on (presumably the
        standard YOLOv8 detection export; TODO confirm for other exports).

        Args:
            output: List of arrays returned by ``InferenceSession.run``.

        Returns:
            output_boxes: ``[x1, y1, x2, y2]`` int lists in original-image pixels.
            output_scores: Confidence of each kept box (its highest class score).
            output_classes: Integer class index of each kept box.
        """
        # (4 + nc, N) -> (N, 4 + nc): one row per candidate box.
        predictions = np.transpose(np.squeeze(output[0]))
        class_scores = predictions[:, 4:]
        max_scores = np.amax(class_scores, axis=1)
        keep = max_scores >= self.confidence_thres
        if not np.any(keep):
            return [], [], []

        scores = max_scores[keep]
        class_ids = np.argmax(class_scores[keep], axis=1)
        # Do the box arithmetic in float64 before truncating to ints.
        cx, cy, w, h = predictions[keep, :4].astype(np.float64).T

        # Scale factors from model-input pixels back to original-image pixels.
        x_factor = self.img_width / self.input_width
        y_factor = self.img_height / self.input_height
        # astype(int) truncates toward zero, like int() on a scalar.
        left = ((cx - w / 2) * x_factor).astype(int)
        top = ((cy - h / 2) * y_factor).astype(int)
        width = (w * x_factor).astype(int)
        height = (h * y_factor).astype(int)

        # cv2.dnn.NMSBoxes expects [x, y, width, height] rectangles. Passing
        # corner coordinates would make it compute IoU on the wrong boxes.
        xywh = np.stack([left, top, width, height], axis=1).tolist()
        xyxy = np.stack([left, top, left + width, top + height], axis=1).tolist()
        indices = cv2.dnn.NMSBoxes(xywh, scores.tolist(), self.confidence_thres, self.iou_thres)
        # Depending on the OpenCV version this is (), shape (K,) or shape (K, 1).
        indices = np.asarray(indices, dtype=int).reshape(-1)

        output_boxes = [xyxy[i] for i in indices]
        output_scores = [scores[i] for i in indices]
        output_classes = [int(class_ids[i]) for i in indices]
        return output_boxes, output_scores, output_classes

    def inference(self, image):
        """
        Runs detection on a single BGR image.

        Args:
            image: BGR image array, indexed as ``(height, width, channels)``.

        Returns:
            (bboxes, scores, class_ids): parallel lists as produced by ``postprocess``.
        """
        # Copy the input so that the image preprocess stores on self.img is
        # independent of the caller's array.
        temp_image = copy.deepcopy(image)
        img_data = self.preprocess(temp_image)
        outputs = self.onnx_session.run(None, {self.input_name: img_data})
        return self.postprocess(outputs)