OMG / inference /core /active_learning /post_processing.py
Fucius's picture
Upload 422 files
df6c67d verified
raw
history blame
4.04 kB
import json
from typing import List, Tuple
from inference.core.active_learning.entities import (
Prediction,
PredictionFileType,
PredictionType,
SerialisedPrediction,
)
from inference.core.constants import (
CLASSIFICATION_TASK,
INSTANCE_SEGMENTATION_TASK,
OBJECT_DETECTION_TASK,
)
from inference.core.exceptions import PredictionFormatNotSupported
def adjust_prediction_to_client_scaling_factor(
    prediction: dict, scaling_factor: float, prediction_type: PredictionType
) -> dict:
    """Undo a client-side scaling of ``prediction`` by dividing its coordinates.

    The dict is modified in place and also returned. Presumably
    ``scaling_factor`` is the factor the client image was resized by before
    inference; confirm against the caller.

    Steps:
      * If ``scaling_factor`` is within 1e-5 of 1.0, nothing is changed.
      * If an ``"image"`` entry exists, it is replaced by a new dict holding
        only ``width``/``height``, each divided by ``scaling_factor`` and
        rounded with ``round``.
      * Stub, empty, classification, or ``"predictions"``-less outputs are
        then returned without touching ``"predictions"``.
      * For instance segmentation, boxes and ``"points"`` are rescaled. For
        object detection, only boxes are rescaled.

    Args:
        prediction: Raw prediction dict (mutated in place).
        scaling_factor: Divisor applied to every coordinate.
        prediction_type: Task type compared against the task constants.

    Returns:
        The same ``prediction`` object.
    """
    scale_is_identity = abs(scaling_factor - 1.0) < 1e-5
    if scale_is_identity:
        return prediction

    if "image" in prediction:
        image_meta = prediction["image"]
        scaled_width = round(image_meta["width"] / scaling_factor)
        scaled_height = round(image_meta["height"] / scaling_factor)
        prediction["image"] = {"width": scaled_width, "height": scaled_height}

    skip_detections = predictions_should_not_be_post_processed(
        prediction=prediction, prediction_type=prediction_type
    )
    if skip_detections:
        return prediction

    is_instance_segmentation = prediction_type == INSTANCE_SEGMENTATION_TASK
    if is_instance_segmentation:
        rescaled_instances = (
            adjust_prediction_with_bbox_and_points_to_client_scaling_factor(
                predictions=prediction["predictions"],
                scaling_factor=scaling_factor,
                points_key="points",
            )
        )
        prediction["predictions"] = rescaled_instances

    is_object_detection = prediction_type == OBJECT_DETECTION_TASK
    if is_object_detection:
        rescaled_boxes = adjust_object_detection_predictions_to_client_scaling_factor(
            predictions=prediction["predictions"],
            scaling_factor=scaling_factor,
        )
        prediction["predictions"] = rescaled_boxes

    return prediction
def predictions_should_not_be_post_processed(
    prediction: dict, prediction_type: PredictionType
) -> bool:
    """Return True when the ``"predictions"`` list must not be rescaled.

    Post-processing is skipped for:
      * stub output (an ``"is_stub"`` key is present),
      * output without a ``"predictions"`` key,
      * classification tasks (``CLASSIFICATION_TASK`` contained in
        ``prediction_type``; this assumes ``prediction_type`` is a string,
        which makes it a substring check — TODO confirm),
      * an empty ``"predictions"`` collection.

    The checks run in exactly that order. For example, ``len`` is only called
    once the key is known to exist and the task is not classification.
    """
    if "is_stub" in prediction:
        return True
    if "predictions" not in prediction:
        return True
    if CLASSIFICATION_TASK in prediction_type:
        return True
    # Explicit length test rather than truthiness, to mirror the original check.
    return len(prediction["predictions"]) == 0
def adjust_object_detection_predictions_to_client_scaling_factor(
    predictions: List[dict],
    scaling_factor: float,
) -> List[dict]:
    """Rescale the bounding box of every object-detection prediction.

    Each element is passed through
    ``adjust_bbox_coordinates_to_client_scaling_factor``, and whatever that
    helper returns is collected into a new list. The input list itself is not
    modified, but its element dicts may be.

    Args:
        predictions: Detection dicts carrying ``x``/``y``/``width``/``height``.
        scaling_factor: Divisor applied to each box coordinate.

    Returns:
        A new list with the rescaled detections, in input order.
    """
    return [
        adjust_bbox_coordinates_to_client_scaling_factor(
            bbox=detection,
            scaling_factor=scaling_factor,
        )
        for detection in predictions
    ]
def adjust_prediction_with_bbox_and_points_to_client_scaling_factor(
    predictions: List[dict],
    scaling_factor: float,
    points_key: str,
) -> List[dict]:
    """Rescale both the bounding box and the point list of each prediction.

    For every element, the box is rescaled first via
    ``adjust_bbox_coordinates_to_client_scaling_factor``. The list stored
    under ``points_key`` is then replaced by the result of
    ``adjust_points_coordinates_to_client_scaling_factor``.

    Args:
        predictions: Prediction dicts holding a box and a point list.
        scaling_factor: Divisor applied to every coordinate.
        points_key: Name of the key holding the point list (e.g. ``"points"``).

    Returns:
        A new list with the rescaled predictions, in input order.
    """

    def _rescale_instance(instance: dict) -> dict:
        # Box first, then the points of the (possibly returned-anew) instance.
        instance = adjust_bbox_coordinates_to_client_scaling_factor(
            bbox=instance,
            scaling_factor=scaling_factor,
        )
        instance[points_key] = adjust_points_coordinates_to_client_scaling_factor(
            points=instance[points_key],
            scaling_factor=scaling_factor,
        )
        return instance

    return [_rescale_instance(instance) for instance in predictions]
def adjust_bbox_coordinates_to_client_scaling_factor(
    bbox: dict,
    scaling_factor: float,
) -> dict:
    """Divide a box's ``x``, ``y``, ``width`` and ``height`` by ``scaling_factor``.

    The dict is updated in place, and the same object is returned. Keys other
    than the four coordinates are left untouched.

    Args:
        bbox: Dict containing the four coordinate keys.
        scaling_factor: Divisor applied to each coordinate.

    Returns:
        ``bbox`` itself, after rescaling.
    """
    for coordinate_key in ("x", "y", "width", "height"):
        bbox[coordinate_key] = bbox[coordinate_key] / scaling_factor
    return bbox
def adjust_points_coordinates_to_client_scaling_factor(
    points: List[dict],
    scaling_factor: float,
) -> List[dict]:
    """Divide the ``x`` and ``y`` of every point by ``scaling_factor``.

    Each point dict is updated in place. The points are collected into a new
    list, so the input container is not modified, and any iterable of dicts
    is accepted.

    Args:
        points: Point dicts with ``x`` and ``y`` keys.
        scaling_factor: Divisor applied to each coordinate.

    Returns:
        A new list holding the same (now rescaled) point objects, in order.
    """
    scaled_points = []
    for point in points:
        for axis in ("x", "y"):
            point[axis] = point[axis] / scaling_factor
        scaled_points.append(point)
    return scaled_points
def encode_prediction(
    prediction: Prediction,
    prediction_type: PredictionType,
) -> Tuple[SerialisedPrediction, PredictionFileType]:
    """Serialise a prediction and report the file type it should be stored as.

    * If ``CLASSIFICATION_TASK`` is not contained in ``prediction_type``, the
      whole prediction is dumped with ``json.dumps`` and tagged ``"json"``.
    * For classification, the value under ``"top"`` is returned as-is and
      tagged ``"txt"``. It is presumably the top class label (a string);
      confirm against the model output schema.

    Args:
        prediction: The prediction payload to serialise.
        prediction_type: Task type. The classification test is a containment
            check (a substring check if this is a string).

    Returns:
        A ``(serialised_prediction, file_type)`` tuple.

    Raises:
        PredictionFormatNotSupported: For a classification prediction that has
            no ``"top"`` key.
    """
    if CLASSIFICATION_TASK not in prediction_type:
        return json.dumps(prediction), "json"
    if "top" in prediction:
        return prediction["top"], "txt"
    # Plain string literal: the message interpolates nothing.
    raise PredictionFormatNotSupported(
        "Prediction type or prediction format not supported."
    )