# inference/core/models/keypoints_detection_base.py
from typing import List, Optional, Tuple
import numpy as np
from inference.core.entities.responses.inference import (
InferenceResponseImage,
Keypoint,
KeypointsDetectionInferenceResponse,
KeypointsPrediction,
)
from inference.core.exceptions import ModelArtefactError
from inference.core.models.object_detection_base import (
ObjectDetectionBaseOnnxRoboflowInferenceModel,
)
from inference.core.models.types import PreprocessReturnMetadata
from inference.core.models.utils.keypoints import model_keypoints_to_response
from inference.core.models.utils.validate import (
get_num_classes_from_model_prediction_shape,
)
from inference.core.nms import w_np_non_max_suppression
from inference.core.utils.postprocess import post_process_bboxes, post_process_keypoints
DEFAULT_CONFIDENCE = 0.4
DEFAULT_IOU_THRESH = 0.3
DEFAULT_CLASS_AGNOSTIC_NMS = False
DEFAULT_MAX_DETECTIONS = 300
DEFAULT_MAX_CANDIDATES = 3000
class KeypointsDetectionBaseOnnxRoboflowInferenceModel(
ObjectDetectionBaseOnnxRoboflowInferenceModel
):
"""Roboflow ONNX Object detection model. This class implements an object detection specific infer method."""
task_type = "keypoint-detection"
def __init__(self, model_id: str, *args, **kwargs):
super().__init__(model_id, *args, **kwargs)
def get_infer_bucket_file_list(self) -> list:
"""Returns the list of files to be downloaded from the inference bucket for ONNX model.
Returns:
list: A list of filenames specific to ONNX models.
"""
return ["environment.json", "class_names.txt", "keypoints_metadata.json"]
def postprocess(
self,
predictions: Tuple[np.ndarray],
preproc_return_metadata: PreprocessReturnMetadata,
class_agnostic_nms=DEFAULT_CLASS_AGNOSTIC_NMS,
confidence: float = DEFAULT_CONFIDENCE,
iou_threshold: float = DEFAULT_IOU_THRESH,
max_candidates: int = DEFAULT_MAX_CANDIDATES,
        max_detections: int = DEFAULT_MAX_DETECTIONS,
return_image_dims: bool = False,
**kwargs,
) -> List[KeypointsDetectionInferenceResponse]:
"""Postprocesses the object detection predictions.
Args:
predictions (np.ndarray): Raw predictions from the model.
img_dims (List[Tuple[int, int]]): Dimensions of the images.
class_agnostic_nms (bool): Whether to apply class-agnostic non-max suppression. Default is False.
confidence (float): Confidence threshold for filtering detections. Default is 0.5.
iou_threshold (float): IoU threshold for non-max suppression. Default is 0.5.
max_candidates (int): Maximum number of candidate detections. Default is 3000.
max_detections (int): Maximum number of final detections. Default is 300.
Returns:
List[KeypointsDetectionInferenceResponse]: The post-processed predictions.
"""
predictions = predictions[0]
number_of_classes = len(self.get_class_names)
num_masks = predictions.shape[2] - 5 - number_of_classes
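        # The raw head output is assumed to be laid out per anchor as
        # [x, y, w, h, objectness, <num_classes class scores>, <keypoint values>],
        # so the trailing keypoint block is routed through the NMS helper's generic
        # "masks" slot below; num_masks therefore counts keypoint values, not masks.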
predictions = w_np_non_max_suppression(
predictions,
conf_thresh=confidence,
iou_thresh=iou_threshold,
class_agnostic=class_agnostic_nms,
max_detections=max_detections,
max_candidate_detections=max_candidates,
num_masks=num_masks,
)
infer_shape = (self.img_size_h, self.img_size_w)
img_dims = preproc_return_metadata["img_dims"]
predictions = post_process_bboxes(
predictions=predictions,
infer_shape=infer_shape,
img_dims=img_dims,
preproc=self.preproc,
resize_method=self.resize_method,
disable_preproc_static_crop=preproc_return_metadata[
"disable_preproc_static_crop"
],
)
predictions = post_process_keypoints(
predictions=predictions,
keypoints_start_index=-num_masks,
infer_shape=infer_shape,
img_dims=img_dims,
preproc=self.preproc,
resize_method=self.resize_method,
disable_preproc_static_crop=preproc_return_metadata[
"disable_preproc_static_crop"
],
)
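        # At this point both the boxes and the trailing keypoint coordinates are assumed
        # to be rescaled from the inference resolution back to the original image size.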
return self.make_response(predictions, img_dims, **kwargs)
def make_response(
self,
predictions: List[List[float]],
img_dims: List[Tuple[int, int]],
class_filter: Optional[List[str]] = None,
*args,
**kwargs,
) -> List[KeypointsDetectionInferenceResponse]:
"""Constructs object detection response objects based on predictions.
Args:
predictions (List[List[float]]): The list of predictions.
img_dims (List[Tuple[int, int]]): Dimensions of the images.
class_filter (Optional[List[str]]): A list of class names to filter, if provided.
Returns:
List[KeypointsDetectionInferenceResponse]: A list of response objects containing keypoints detection predictions.
"""
if isinstance(img_dims, dict) and "img_dims" in img_dims:
img_dims = img_dims["img_dims"]
keypoint_confidence_threshold = 0.0
if "request" in kwargs:
keypoint_confidence_threshold = kwargs["request"].keypoint_confidence
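        # Each prediction row is expected to follow the NMS output layout
        # [x1, y1, x2, y2, confidence, class_confidence, class_id, *keypoints],
        # which is what the index-based field mapping below relies on.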
responses = [
KeypointsDetectionInferenceResponse(
predictions=[
KeypointsPrediction(
                        # Passing args as a dictionary here since one of the keys is "class" (a reserved keyword in Python)
**{
"x": (pred[0] + pred[2]) / 2,
"y": (pred[1] + pred[3]) / 2,
"width": pred[2] - pred[0],
"height": pred[3] - pred[1],
"confidence": pred[4],
"class": self.class_names[int(pred[6])],
"class_id": int(pred[6]),
"keypoints": model_keypoints_to_response(
keypoints_metadata=self.keypoints_metadata,
keypoints=pred[7:],
                                predicted_object_class_id=int(pred[6]),
keypoint_confidence_threshold=keypoint_confidence_threshold,
),
}
)
for pred in batch_predictions
if not class_filter
or self.class_names[int(pred[6])] in class_filter
],
image=InferenceResponseImage(
width=img_dims[ind][1], height=img_dims[ind][0]
),
)
for ind, batch_predictions in enumerate(predictions)
]
return responses
    def keypoints_count(self) -> int:
        """Returns the number of keypoints predicted for each detected object. Must be implemented by subclasses."""
        raise NotImplementedError
def validate_model_classes(self) -> None:
num_keypoints = self.keypoints_count()
output_shape = self.get_model_output_shape()
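        # get_num_classes_from_model_prediction_shape is assumed to invert the head
        # layout, roughly num_classes = len_prediction - 5 - 3 * num_keypoints
        # (x, y and confidence per keypoint), so a keypoint-count mismatch also
        # surfaces here as a class-count mismatch.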
num_classes = get_num_classes_from_model_prediction_shape(
len_prediction=output_shape[2], keypoints=num_keypoints
)
if num_classes != self.num_classes:
raise ValueError(
f"Number of classes in model ({num_classes}) does not match the number of classes in the environment ({self.num_classes})"
)
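

# Illustrative sketch (not part of the library): a concrete keypoints model is
# expected to subclass this base and report how many keypoints its ONNX head
# predicts. The class name and keypoint count below are hypothetical.
#
#     class ExampleYOLOv8KeypointsDetection(KeypointsDetectionBaseOnnxRoboflowInferenceModel):
#         def keypoints_count(self) -> int:
#             # e.g. 17 for a COCO-style human-pose skeleton
#             return 17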