# File size: 5,622 Bytes
# f079597
import logging
from typing import Any, Dict, List, Optional
import numpy as np
logger = logging.getLogger(__name__)
from sahi.models.base import DetectionModel
from sahi.prediction import ObjectPrediction
from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list
from sahi.utils.import_utils import check_requirements
from .yolov8_onnx import *
class Yolov8onnxDetectionModel(DetectionModel):
    """SAHI detection model that runs inference through a ``Yolov8onnx`` wrapper.

    The wrapper is built in ``load_model`` from ``self.model_path`` and the
    ``imgsz`` entry of the YAML file at ``self.config_path``. Raw predictions
    are stored in ``self._original_predictions`` and converted to
    ``ObjectPrediction`` instances on request.
    """

    def check_dependencies(self) -> None:
        """Check that the packages this model declares as required are installed."""
        check_requirements(["ultralytics"])

    def load_model(self):
        """
        Detection model is initialized and set to self.model.

        Reads the YAML file at ``self.config_path`` and uses its ``imgsz``
        value for both the input width and the input height of the model.
        Also caches the category names taken from ``self.category_mapping``.

        Raises:
            TypeError: If building the ``Yolov8onnx`` model fails for any
                reason (including a missing ``imgsz`` key); the original
                exception is attached as the cause.
        """
        from pathlib import Path

        import yaml

        config = yaml.safe_load(Path(self.config_path).read_text())

        try:
            self.model = Yolov8onnx(
                onnx_model=self.model_path,
                input_width=config["imgsz"],
                input_height=config["imgsz"],
                confidence_thres=self.confidence_threshold,
                iou_thres=0.5,
                device=self.device,
            )
        except Exception as e:
            # Chain explicitly so the root cause stays visible in the traceback.
            raise TypeError("model_path is not a valid yolov8 model path: ", e) from e

        self.category_name_list = list(self.category_mapping.values())
        self.category_name_list_len = len(self.category_name_list)

    def perform_inference(self, image: np.ndarray):
        """
        Prediction is performed using self.model and the prediction result is set to self._original_predictions.

        Args:
            image: np.ndarray
                A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.

        Raises:
            ValueError: If the model has not been loaded yet.
        """
        # Confirm model is loaded
        if self.model is None:
            raise ValueError("Model is not loaded, load it by calling .load_model()")

        # The channel axis is reversed (RGB -> BGR) because YOLOv8 expects
        # numpy arrays to be in BGR order.
        prediction_result = self.model.inference(image[:, :, ::-1])

        # Wrapped in a list: one entry per image, and a single image is
        # processed per call.
        self._original_predictions = [prediction_result]

    @property
    def num_categories(self):
        """Number of category names cached by ``load_model``."""
        return self.category_name_list_len

    @property
    def has_mask(self):
        """Always ``False``: this model does not output masks."""
        return False

    @property
    def category_names(self):
        """List of category names cached by ``load_model``."""
        return self.category_name_list

    def _create_object_prediction_list_from_original_predictions(
        self,
        shift_amount_list: Optional[List[List[int]]] = None,
        full_shape_list: Optional[List[List[int]]] = None,
    ):
        """
        self._original_predictions is converted to a list of prediction.ObjectPrediction and set to
        self._object_prediction_list_per_image.

        Args:
            shift_amount_list: list of list
                To shift the box and mask predictions from sliced image to full sized image, should
                be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...].
                Defaults to ``[[0, 0]]`` when omitted or ``None``.
            full_shape_list: list of list
                Size of the full image after shifting, should be in the form of
                List[[height, width],[height, width],...]
        """
        # A fresh list per call avoids a shared mutable default and makes an
        # explicit ``None`` behave like "no shift".
        if shift_amount_list is None:
            shift_amount_list = [[0, 0]]

        original_predictions = self._original_predictions

        # compatibility for sahi v0.8.15
        shift_amount_list = fix_shift_amount_list(shift_amount_list)
        full_shape_list = fix_full_shape_list(full_shape_list)

        # handle all predictions
        object_prediction_list_per_image = []
        for image_ind, original_prediction in enumerate(original_predictions):
            # Each raw prediction is indexed as (bboxes, scores, class_ids).
            bboxes = original_prediction[0]
            scores = original_prediction[1]
            class_ids = original_prediction[2]

            shift_amount = shift_amount_list[image_ind]
            full_shape = None if full_shape_list is None else full_shape_list[image_ind]

            object_prediction_list = []
            for original_bbox, score, category_id in zip(bboxes, scores, class_ids):
                # class ids presumably arrive as numpy scalars from the ONNX
                # wrapper (TODO confirm); normalise to a plain int so the
                # string key lookup and downstream consumers see a builtin.
                category_id = int(category_id)
                category_name = self.category_mapping[str(category_id)]

                # fix negative box coords; order is [x1, y1, x2, y2]
                bbox = [max(0, original_bbox[i]) for i in range(4)]

                # fix out of image box coords
                if full_shape is not None:
                    height, width = full_shape[0], full_shape[1]
                    bbox = [
                        min(width, bbox[0]),
                        min(height, bbox[1]),
                        min(width, bbox[2]),
                        min(height, bbox[3]),
                    ]

                # ignore invalid predictions (empty or inverted boxes)
                if not (bbox[0] < bbox[2] and bbox[1] < bbox[3]):
                    logger.warning("ignoring invalid prediction with bbox: %s", bbox)
                    continue

                object_prediction_list.append(
                    ObjectPrediction(
                        bbox=bbox,
                        category_id=category_id,
                        score=score,
                        bool_mask=None,
                        category_name=category_name,
                        shift_amount=shift_amount,
                        full_shape=full_shape,
                    )
                )
            object_prediction_list_per_image.append(object_prediction_list)

        self._object_prediction_list_per_image = object_prediction_list_per_image