# Source: OMG-InstantID / inference/core/models/classification_base.py
# Uploaded by Fucius ("Upload 422 files"), commit 2eafbc4 (verified).
from io import BytesIO
from time import perf_counter
from typing import Any, List, Tuple, Union
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from inference.core.entities.requests.inference import ClassificationInferenceRequest
from inference.core.entities.responses.inference import (
ClassificationInferenceResponse,
InferenceResponse,
InferenceResponseImage,
MultiLabelClassificationInferenceResponse,
)
from inference.core.models.roboflow import OnnxRoboflowInferenceModel
from inference.core.models.types import PreprocessReturnMetadata
from inference.core.models.utils.validate import (
get_num_classes_from_model_prediction_shape,
)
from inference.core.utils.image_utils import load_image_rgb
class ClassificationBaseOnnxRoboflowInferenceModel(OnnxRoboflowInferenceModel):
    """Base class for ONNX models for Roboflow classification inference.

    Attributes:
        task_type (str): Always ``"classification"``.
        multiclass (bool): Whether the model is multi-label (independent per-class
            scores) instead of single-label (softmax over all classes). Read from
            the ``MULTICLASS`` key of the model environment; defaults to False.

    Methods:
        get_infer_bucket_file_list() -> list: Get the list of required files for inference.
        softmax(x): Compute softmax values for a given set of scores.
        infer(image, ...): Run preprocess -> predict -> postprocess on one image or a list of images.
        infer_from_request(request: ClassificationInferenceRequest): Run inference for a request and
            return a list of responses (or a single response for a single image).
        draw_predictions(inference_request, inference_response): Draw prediction visuals on an image.
    """

    task_type = "classification"

    def __init__(self, *args, **kwargs):
        """Initialize the model, setting whether it is multiclass or not."""
        super().__init__(*args, **kwargs)
        self.multiclass = self.environment.get("MULTICLASS", False)

    def draw_predictions(self, inference_request, inference_response):
        """Draw prediction visuals on an image.

        A border is drawn around the whole image and label "buttons" are pasted
        in the top-left corner:

        * single-label responses (``predictions`` is a list): only the first
          prediction is shown, using its class color for border and label;
        * multi-label responses (``predictions`` is a dict): every class is
          shown, sorted by descending confidence and stacked vertically.

        Args:
            inference_request: The request object containing the image and parameters.
            inference_response: The response object containing the predictions and other details.

        Returns:
            bytes: The bytes of the visualized image in JPEG format.
        """
        image = Image.fromarray(load_image_rgb(inference_request.image))
        draw = ImageDraw.Draw(image)
        font = ImageFont.load_default()
        stroke_width = inference_request.visualization_stroke_width
        # PIL reports size as (width, height); rectangle corners are inclusive
        # pixel coordinates, so the last valid pixel is (width - 1, height - 1).
        width, height = image.size
        border = [0, 0, width - 1, height - 1]
        predictions = inference_response.predictions
        if isinstance(predictions, list):
            # An empty list would otherwise raise IndexError below.
            if predictions:
                prediction = predictions[0]
                color = self.colors.get(prediction.class_name, "#4892EA")
                draw.rectangle(border, outline=color, width=stroke_width)
                text = f"{prediction.class_id} - {prediction.class_name} {prediction.confidence:.2f}"
                self._paste_label(image, text, color, font, top=0)
        elif predictions:
            draw.rectangle(border, outline="#4892EA", width=stroke_width)
            ranked = sorted(
                predictions.items(),
                key=lambda item: item[1].confidence,
                reverse=True,
            )
            row = 0
            for cls_name, pred in ranked:
                color = self.colors.get(cls_name, "#4892EA")
                text = f"{cls_name} {pred.confidence:.2f}"
                # Stack each label directly below the previous one.
                row += self._paste_label(image, text, color, font, top=row)
        buffered = BytesIO()
        image.convert("RGB").save(buffered, format="JPEG")
        return buffered.getvalue()

    @staticmethod
    def _paste_label(image, text, color, font, top):
        """Paste a colored label with white text at (0, top) of ``image``; return its height in pixels."""
        text_bbox = font.getbbox(text)
        # Text is drawn at (10, 10), so +20 leaves a 10 px margin on the far sides too.
        button_size = (text_bbox[2] + 20, text_bbox[3] + 20)
        button_img = Image.new("RGBA", button_size, color)
        ImageDraw.Draw(button_img).text(
            (10, 10), text, font=font, fill=(255, 255, 255, 255)
        )
        image.paste(button_img, (0, top))
        return button_size[1]

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: A list of required files for inference, e.g., ["environment.json"].
        """
        return ["environment.json"]

    def infer(
        self,
        image: Any,
        disable_preproc_auto_orient: bool = False,
        disable_preproc_contrast: bool = False,
        disable_preproc_grayscale: bool = False,
        disable_preproc_static_crop: bool = False,
        return_image_dims: bool = False,
        **kwargs,
    ):
        """Perform inference on the provided image(s).

        Args:
            image (Any): The image or list of images to be processed.
            disable_preproc_auto_orient (bool, optional): If true, the auto orient preprocessing step is disabled for this call. Default is False.
            disable_preproc_contrast (bool, optional): If true, the auto contrast preprocessing step is disabled for this call. Default is False.
            disable_preproc_grayscale (bool, optional): If true, the grayscale preprocessing step is disabled for this call. Default is False.
            disable_preproc_static_crop (bool, optional): If true, the static crop preprocessing step is disabled for this call. Default is False.
            return_image_dims (bool, optional): Forwarded to the base pipeline and ``postprocess``. Defaults to False.
            **kwargs: Additional options forwarded to the base pipeline, e.g.
                ``confidence``, which ``make_response`` uses as the multi-label
                threshold.

        Returns:
            The result of the base-class ``infer`` pipeline. For this class,
            ``postprocess`` builds a list of classification response objects,
            one per input image (how the base class wraps it is not visible here).

        Notes:
            - The input image(s) are preprocessed (normalized to [-1, 1]) before inference.
            - This function uses an ONNX session to perform inference on the input image(s).
        """
        # Forward remaining options; without this, request-level settings such
        # as `confidence` never reach `make_response`.
        return super().infer(
            image,
            disable_preproc_auto_orient=disable_preproc_auto_orient,
            disable_preproc_contrast=disable_preproc_contrast,
            disable_preproc_grayscale=disable_preproc_grayscale,
            disable_preproc_static_crop=disable_preproc_static_crop,
            return_image_dims=return_image_dims,
            **kwargs,
        )

    def postprocess(
        self,
        predictions: Tuple[np.ndarray],
        preprocess_return_metadata: PreprocessReturnMetadata,
        return_image_dims=False,
        **kwargs,
    ) -> Union[ClassificationInferenceResponse, List[ClassificationInferenceResponse]]:
        """Turn raw ONNX outputs into one response object per input image.

        Args:
            predictions: The 1-tuple produced by ``predict``, wrapping the list of
                ONNX session outputs.
            preprocess_return_metadata: Metadata from ``preprocess``; its
                ``"img_dims"`` entry holds one dims entry per image.
            return_image_dims: Accepted for interface compatibility; not used here.
            **kwargs: Forwarded to ``make_response`` (e.g. ``confidence``).

        Returns:
            The list built by ``make_response``.
        """
        onnx_outputs = predictions[0]
        # Assumes the first ONNX output has shape (batch, num_classes), matching
        # `validate_model_classes` reading axis 3 of the stacked predict output
        # — TODO confirm for exotic exports.
        batch_scores = onnx_outputs[0]
        # make_response reads `prediction[0]` as the score vector, so hand it one
        # (1, num_classes) array per image; every image in a batch gets a response.
        per_image_scores = [row[np.newaxis] for row in batch_scores]
        return self.make_response(
            per_image_scores, preprocess_return_metadata["img_dims"], **kwargs
        )

    def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray]:
        """Run the ONNX session on a preprocessed batch.

        Returns:
            A 1-tuple wrapping the list returned by ``onnx_session.run``.
        """
        predictions = self.onnx_session.run(None, {self.input_name: img_in})
        return (predictions,)

    def preprocess(
        self, image: Any, **kwargs
    ) -> Tuple[np.ndarray, PreprocessReturnMetadata]:
        """Load image(s) via ``preproc_image`` and normalize them to [-1, 1].

        Args:
            image: A single image or a list of images. Lists are concatenated
                along axis 0 into one batch.
            **kwargs: The ``disable_preproc_*`` flags are read from here
                (each defaults to False); other keys are ignored.

        Returns:
            (img_in, metadata): the float32 batch (channel axis 1, as the
            normalization indexing requires) and metadata whose ``"img_dims"``
            holds one dims entry per image.
        """
        preproc_options = {
            option: kwargs.get(option, False)
            for option in (
                "disable_preproc_auto_orient",
                "disable_preproc_contrast",
                "disable_preproc_grayscale",
                "disable_preproc_static_crop",
            )
        }
        if isinstance(image, list):
            imgs_with_dims = [self.preproc_image(i, **preproc_options) for i in image]
            imgs, img_dims = zip(*imgs_with_dims)
            img_in = np.concatenate(imgs, axis=0)
        else:
            img_in, single_dims = self.preproc_image(image, **preproc_options)
            img_dims = [single_dims]
        # Cast before scaling: an in-place divide on an integer array would raise.
        img_in = img_in.astype(np.float32) / 255.0
        mean = np.array((0.5, 0.5, 0.5), dtype=np.float32).reshape(1, 3, 1, 1)
        std = np.array((0.5, 0.5, 0.5), dtype=np.float32).reshape(1, 3, 1, 1)
        img_in = (img_in - mean) / std
        return img_in, PreprocessReturnMetadata({"img_dims": img_dims})

    def infer_from_request(
        self,
        request: ClassificationInferenceRequest,
    ) -> Union[List[InferenceResponse], InferenceResponse]:
        """Handle an inference request to produce an appropriate response.

        Args:
            request (ClassificationInferenceRequest): The request object encapsulating the image(s) and relevant parameters.

        Returns:
            Union[List[InferenceResponse], InferenceResponse]: A list of responses if a list of images was provided, otherwise a single response.

        Notes:
            - All request fields are passed to ``infer`` as keyword arguments.
            - Each response's ``time`` is the elapsed time since the start of this call.
            - If ``visualize_predictions`` is set, each response gets a JPEG visualization.
        """
        t1 = perf_counter()
        responses = self.infer(**request.dict(), return_image_dims=True)
        for response in responses:
            response.time = perf_counter() - t1
        if request.visualize_predictions:
            for response in responses:
                response.visualization = self.draw_predictions(request, response)
        if not isinstance(request.image, list):
            responses = responses[0]
        return responses

    def make_response(
        self,
        predictions,
        img_dims,
        confidence: float = 0.5,
        **kwargs,
    ) -> Union[ClassificationInferenceResponse, List[ClassificationInferenceResponse]]:
        """Create response objects for the given predictions and image dimensions.

        Args:
            predictions (list): One array per image; ``prediction[0]`` is that
                image's score vector (one score per entry of ``self.class_names``).
            img_dims (list): One entry per image, read as (height, width) in both
                branches — presumably ``shape[:2]`` from ``preproc_image``; TODO confirm.
            confidence (float, optional): Multi-label threshold; a class is listed
                in ``predicted_classes`` when its score is strictly greater. ``None``
                falls back to 0.5. Defaults to 0.5.
            **kwargs: Ignored; accepted so pipeline options can be passed through.

        Returns:
            List[ClassificationInferenceResponse | MultiLabelClassificationInferenceResponse]:
            one response per image.

        Notes:
            - Multi-label models keep raw scores for every class.
            - Single-label models apply softmax, round to 4 decimals, sort by
              descending confidence and report the top class.
        """
        responses = []
        confidence_threshold = 0.5 if confidence is None else float(confidence)
        for ind, prediction in enumerate(predictions):
            height, width = img_dims[ind][0], img_dims[ind][1]
            response_image = InferenceResponseImage(width=width, height=height)
            scores = prediction[0]
            if self.multiclass:
                results = {}
                predicted_classes = []
                for class_id, raw_score in enumerate(scores):
                    cls_name = self.class_names[class_id]
                    score = float(raw_score)
                    results[cls_name] = {"confidence": score, "class_id": class_id}
                    if score > confidence_threshold:
                        predicted_classes.append(cls_name)
                response = MultiLabelClassificationInferenceResponse(
                    image=response_image,
                    predicted_classes=predicted_classes,
                    predictions=results,
                )
            else:
                probabilities = self.softmax(scores)
                results = sorted(
                    (
                        {
                            "class_id": class_id,
                            "class": cls_name,
                            "confidence": round(float(probabilities[class_id]), 4),
                        }
                        for class_id, cls_name in enumerate(self.class_names)
                    ),
                    key=lambda x: x["confidence"],
                    reverse=True,
                )
                response = ClassificationInferenceResponse(
                    image=response_image,
                    predictions=results,
                    top=results[0]["class"],
                    confidence=results[0]["confidence"],
                )
            responses.append(response)
        return responses

    @staticmethod
    def softmax(x):
        """Compute softmax values for each set of scores in x.

        The maximum is subtracted before exponentiating for numerical stability.
        Max and sum are taken over the whole array, so pass a 1-D score vector.

        Args:
            x (np.array): The input array containing the scores.

        Returns:
            np.array: The softmax values for each set of scores.
        """
        e_x = np.exp(x - np.max(x))
        return e_x / e_x.sum()

    def get_model_output_shape(self) -> Tuple[int, ...]:
        """Return the shape of ``np.array(self.predict(...))`` for a random 1024x1024 RGB image."""
        test_image = (np.random.rand(1024, 1024, 3) * 255).astype(np.uint8)
        test_image, _ = self.preprocess(test_image)
        output = np.array(self.predict(test_image))
        return output.shape

    def validate_model_classes(self) -> None:
        """Check that the model's class count matches the environment's.

        Raises:
            ValueError: If the model's output class count differs from ``self.num_classes``.
        """
        output_shape = self.get_model_output_shape()
        # Stacking predict's (outputs,) tuple presumably gives
        # (1, n_outputs, batch, num_classes), so axis 3 is the class axis — TODO confirm.
        num_classes = output_shape[3]
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if num_classes != self.num_classes:
            raise ValueError(
                f"Number of classes in model ({num_classes}) does not match the number of classes in the environment ({self.num_classes})"
            )