OMG-InstantID / inference /core /utils /visualisation.py
Fucius's picture
Upload 422 files
2eafbc4 verified
raw
history blame contribute delete
No virus
4.44 kB
from typing import Dict, List, Tuple, Union
import cv2
import numpy as np
from inference.core.entities.requests.inference import (
InstanceSegmentationInferenceRequest,
KeypointsDetectionInferenceRequest,
ObjectDetectionInferenceRequest,
)
from inference.core.entities.responses.inference import (
InstanceSegmentationPrediction,
Keypoint,
KeypointsPrediction,
ObjectDetectionInferenceResponse,
ObjectDetectionPrediction,
Point,
)
from inference.core.utils.image_utils import load_image_rgb, np_image_to_base64
def draw_detection_predictions(
    inference_request: Union[
        ObjectDetectionInferenceRequest,
        InstanceSegmentationInferenceRequest,
        KeypointsDetectionInferenceRequest,
    ],
    inference_response: Union[
        ObjectDetectionInferenceResponse,
        InstanceSegmentationPrediction,
        KeypointsPrediction,
    ],
    colors: Dict[str, str],
) -> bytes:
    """Render every prediction of a response onto the request image.

    The bounding box of each prediction is always drawn. A polygon is added
    when the prediction has a ``points`` attribute, keypoint dots when it has
    a ``keypoints`` attribute, and a text label when the request sets
    ``visualization_labels``.

    Args:
        inference_request: Request providing the image and the
            ``visualization_stroke_width`` / ``visualization_labels`` options.
        inference_response: Response whose ``predictions`` are drawn.
        colors: Mapping from class name to a ``"#RRGGBB"`` hex string. Classes
            missing from the mapping fall back to ``"#4892EA"``.

    Returns:
        The annotated image as encoded by ``np_image_to_base64``.
    """
    canvas = load_image_rgb(inference_request.image)
    for prediction in inference_response.predictions:
        hex_code = colors.get(prediction.class_name, "#4892EA")
        # Two hex digits per channel, skipping the leading "#".
        channel_values = [
            int(hex_code[start : start + 2], 16) for start in (1, 3, 5)
        ]
        color = tuple(channel_values)
        stroke_width = inference_request.visualization_stroke_width

        canvas = draw_bbox(
            image=canvas,
            box=prediction,
            color=color,
            thickness=stroke_width,
        )
        if hasattr(prediction, "points"):
            canvas = draw_instance_segmentation_points(
                image=canvas,
                points=prediction.points,
                color=color,
                thickness=stroke_width,
            )
        if hasattr(prediction, "keypoints"):
            # draw_keypoints returns None; it draws onto ``canvas`` directly.
            draw_keypoints(
                image=canvas,
                keypoints=prediction.keypoints,
                color=color,
                thickness=stroke_width,
            )
        if inference_request.visualization_labels:
            canvas = draw_labels(
                image=canvas,
                box=prediction,
                color=color,
            )
    return np_image_to_base64(image=canvas)
def draw_bbox(
    image: np.ndarray,
    box: ObjectDetectionPrediction,
    color: Tuple[int, ...],
    thickness: int,
) -> np.ndarray:
    """Draw the rectangular outline of a prediction onto an image.

    Args:
        image: Image to draw on.
        box: Prediction described by its centre (``x``, ``y``) and its
            ``width`` / ``height``.
        color: Colour passed straight to ``cv2.rectangle``.
        thickness: Line thickness passed straight to ``cv2.rectangle``.

    Returns:
        The value returned by ``cv2.rectangle``.
    """
    corners = bbox_to_points(box=box)
    top_left, bottom_right = corners
    outlined = cv2.rectangle(
        image,
        top_left,
        bottom_right,
        color=color,
        thickness=thickness,
    )
    return outlined
def draw_instance_segmentation_points(
    image: np.ndarray,
    points: List[Point],
    color: Tuple[int, ...],
    thickness: int,
) -> np.ndarray:
    """Draw a closed polygon through the given segmentation points.

    Args:
        image: Image to draw on.
        points: Polygon vertices; each needs ``x`` and ``y`` attributes, which
            are truncated to integers.
        color: Colour passed straight to ``cv2.polylines``.
        thickness: Line thickness passed straight to ``cv2.polylines``.

    Returns:
        The image with the polygon drawn, or ``image`` unchanged when two or
        fewer points are supplied.
    """
    vertices = np.array(
        [(int(vertex.x), int(vertex.y)) for vertex in points],
        np.int32,
    )
    # Two or fewer points do not form a polygon, so nothing is drawn.
    if len(points) <= 2:
        return image
    return cv2.polylines(
        image,
        [vertices],
        isClosed=True,
        color=color,
        thickness=thickness,
    )
def draw_keypoints(
    image: np.ndarray,
    keypoints: List[Keypoint],
    color: Tuple[int, ...],
    thickness: int,
) -> None:
    """Draw one filled dot per keypoint onto an image.

    Nothing is returned, so callers rely on ``cv2.circle`` drawing into
    ``image`` in place.

    Args:
        image: Image to draw on.
        keypoints: Keypoints whose ``x`` / ``y`` attributes are rounded to the
            nearest integer pixel.
        color: Colour passed straight to ``cv2.circle``.
        thickness: Used as the radius of each dot.
    """
    for point in keypoints:
        center = (round(point.x), round(point.y))
        # Radius is ``thickness``; the trailing -1 is the OpenCV thickness
        # argument, where a negative value requests a filled shape.
        cv2.circle(image, center, thickness, color, -1)
def draw_labels(
    image: np.ndarray,
    box: Union[ObjectDetectionPrediction, InstanceSegmentationPrediction],
    color: Tuple[int, ...],
) -> np.ndarray:
    """Stamp a ``"<class_name> <confidence>"`` label at a box's top-left corner.

    The label is a filled rectangle with white text on top. It is pasted into
    ``image`` in place and clipped to the image's right and bottom edges.
    When the box's top-left corner lies above or left of the image, the label
    is anchored at the image edge instead. When the corner lies past the right
    or bottom edge, the image is returned untouched.

    Args:
        image: Image to draw on; modified in place. The label patch has three
            channels.
        box: Prediction providing ``class_name``, ``confidence`` and the box
            geometry.
        color: Colour of the label background.

    Returns:
        The same ``image`` array that was passed in.
    """
    (x1, y1), _ = bbox_to_points(box=box)
    image_height, image_width = image.shape[:2]

    # Boxes may extend beyond the image. A negative start would be read by
    # NumPy as a from-the-end index and make the target slice a different
    # size than the label patch, so anchor the label inside the image.
    x1 = max(x1, 0)
    y1 = max(y1, 0)
    if x1 >= image_width or y1 >= image_height:
        # The anchor is past the right/bottom edge: nothing would be visible.
        return image

    text = f"{box.class_name} {box.confidence:.2f}"
    (text_width, text_height), _ = cv2.getTextSize(
        text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
    )
    # 10 px of padding on every side of the text.
    button_size = (text_width + 20, text_height + 20)
    # NOTE(review): the background uses ``color`` with its channel order
    # reversed, whereas draw_bbox uses ``color`` as-is - confirm the label is
    # meant to differ from the box outline.
    button_img = np.full(
        (button_size[1], button_size[0], 3), color[::-1], dtype=np.uint8
    )
    cv2.putText(
        button_img,
        text,
        (10, 10 + text_height),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.5,
        (255, 255, 255),
        1,
    )
    # Clip the patch so it never runs past the right/bottom image edge.
    end_x = min(x1 + button_size[0], image_width)
    end_y = min(y1 + button_size[1], image_height)
    image[y1:end_y, x1:end_x] = button_img[: end_y - y1, : end_x - x1]
    return image
def bbox_to_points(
    box: Union[ObjectDetectionPrediction, InstanceSegmentationPrediction],
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """Convert a centre/size box into its two corner points.

    Args:
        box: Object with centre coordinates ``x`` / ``y`` and ``width`` /
            ``height`` attributes.

    Returns:
        ``((x1, y1), (x2, y2))``: the top-left and bottom-right corners, each
        coordinate truncated towards zero by ``int``.
    """
    half_width = box.width / 2
    half_height = box.height / 2
    top_left = (int(box.x - half_width), int(box.y - half_height))
    bottom_right = (int(box.x + half_width), int(box.y + half_height))
    return top_left, bottom_right