import json
from typing import List, Tuple
from inference.core.active_learning.entities import (
Prediction,
PredictionFileType,
PredictionType,
SerialisedPrediction,
)
from inference.core.constants import (
CLASSIFICATION_TASK,
INSTANCE_SEGMENTATION_TASK,
OBJECT_DETECTION_TASK,
)
from inference.core.exceptions import PredictionFormatNotSupported


def adjust_prediction_to_client_scaling_factor(
prediction: dict, scaling_factor: float, prediction_type: PredictionType
) -> dict:
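    # Map the prediction back to the client's coordinate space by dividing the
    # reported image size and all coordinates by the client scaling factor.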
if abs(scaling_factor - 1.0) < 1e-5:
return prediction
if "image" in prediction:
prediction["image"] = {
"width": round(prediction["image"]["width"] / scaling_factor),
"height": round(prediction["image"]["height"] / scaling_factor),
}
if predictions_should_not_be_post_processed(
prediction=prediction, prediction_type=prediction_type
):
return prediction
if prediction_type == INSTANCE_SEGMENTATION_TASK:
prediction["predictions"] = (
adjust_prediction_with_bbox_and_points_to_client_scaling_factor(
predictions=prediction["predictions"],
scaling_factor=scaling_factor,
points_key="points",
)
)
if prediction_type == OBJECT_DETECTION_TASK:
prediction["predictions"] = (
adjust_object_detection_predictions_to_client_scaling_factor(
predictions=prediction["predictions"],
scaling_factor=scaling_factor,
)
)
return prediction


def predictions_should_not_be_post_processed(
prediction: dict, prediction_type: PredictionType
) -> bool:
    # Exclude classification outputs, stub outputs and empty predictions from post-processing.
return (
"is_stub" in prediction
or "predictions" not in prediction
or CLASSIFICATION_TASK in prediction_type
or len(prediction["predictions"]) == 0
)


def adjust_object_detection_predictions_to_client_scaling_factor(
predictions: List[dict],
scaling_factor: float,
) -> List[dict]:
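    # Rescale the bounding box of each detection.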
result = []
for prediction in predictions:
prediction = adjust_bbox_coordinates_to_client_scaling_factor(
bbox=prediction,
scaling_factor=scaling_factor,
)
result.append(prediction)
return result


def adjust_prediction_with_bbox_and_points_to_client_scaling_factor(
predictions: List[dict],
scaling_factor: float,
points_key: str,
) -> List[dict]:
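    # Rescale both the bounding box and the polygon points of each prediction
    # (used for instance segmentation output).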
result = []
for prediction in predictions:
prediction = adjust_bbox_coordinates_to_client_scaling_factor(
bbox=prediction,
scaling_factor=scaling_factor,
)
prediction[points_key] = adjust_points_coordinates_to_client_scaling_factor(
points=prediction[points_key],
scaling_factor=scaling_factor,
)
result.append(prediction)
return result


def adjust_bbox_coordinates_to_client_scaling_factor(
bbox: dict,
scaling_factor: float,
) -> dict:
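    # Mutates the bbox dict in place and returns it.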
bbox["x"] = bbox["x"] / scaling_factor
bbox["y"] = bbox["y"] / scaling_factor
bbox["width"] = bbox["width"] / scaling_factor
bbox["height"] = bbox["height"] / scaling_factor
return bbox


def adjust_points_coordinates_to_client_scaling_factor(
points: List[dict],
scaling_factor: float,
) -> List[dict]:
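    # Each point dict is mutated in place.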
result = []
for point in points:
point["x"] = point["x"] / scaling_factor
point["y"] = point["y"] / scaling_factor
result.append(point)
return result


def encode_prediction(
prediction: Prediction,
prediction_type: PredictionType,
) -> Tuple[SerialisedPrediction, PredictionFileType]:
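    # Non-classification predictions are serialised to JSON; classification
    # predictions are stored as the plain-text top class.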
if CLASSIFICATION_TASK not in prediction_type:
return json.dumps(prediction), "json"
if "top" in prediction:
return prediction["top"], "txt"
    raise PredictionFormatNotSupported(
        "Prediction type or prediction format not supported."
    )
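

if __name__ == "__main__":
    # Minimal usage sketch: the sample values below are made up and assume an
    # object-detection response for an image the client downscaled by a factor
    # of 0.5 before inference (requires the `inference` package to be installed).
    sample_prediction = {
        "image": {"width": 320, "height": 240},
        "predictions": [
            {"x": 100.0, "y": 80.0, "width": 40.0, "height": 30.0, "class": "car"}
        ],
    }
    rescaled = adjust_prediction_to_client_scaling_factor(
        prediction=sample_prediction,
        scaling_factor=0.5,
        prediction_type=OBJECT_DETECTION_TASK,
    )
    serialised, file_type = encode_prediction(
        prediction=rescaled, prediction_type=OBJECT_DETECTION_TASK
    )
    print(file_type, serialised)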