import random
from functools import partial
from typing import Any, Dict, Optional, Set

import numpy as np

from inference.core.active_learning.entities import (
Prediction,
PredictionType,
SamplingMethod,
)
from inference.core.constants import (
CLASSIFICATION_TASK,
INSTANCE_SEGMENTATION_TASK,
KEYPOINTS_DETECTION_TASK,
OBJECT_DETECTION_TASK,
)
from inference.core.exceptions import ActiveLearningConfigurationError

ELIGIBLE_PREDICTION_TYPES = {
CLASSIFICATION_TASK,
INSTANCE_SEGMENTATION_TASK,
KEYPOINTS_DETECTION_TASK,
OBJECT_DETECTION_TASK,
}
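

# Builds the SamplingMethod for the `close_to_threshold` strategy from its
# configuration dict; a missing required key is re-raised as
# ActiveLearningConfigurationError.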
def initialize_close_to_threshold_sampling(
strategy_config: Dict[str, Any]
) -> SamplingMethod:
try:
selected_class_names = strategy_config.get("selected_class_names")
if selected_class_names is not None:
selected_class_names = set(selected_class_names)
sample_function = partial(
sample_close_to_threshold,
selected_class_names=selected_class_names,
threshold=strategy_config["threshold"],
epsilon=strategy_config["epsilon"],
only_top_classes=strategy_config.get("only_top_classes", True),
minimum_objects_close_to_threshold=strategy_config.get(
"minimum_objects_close_to_threshold",
1,
),
probability=strategy_config["probability"],
)
return SamplingMethod(
name=strategy_config["name"],
sample=sample_function,
)
except KeyError as error:
raise ActiveLearningConfigurationError(
f"In configuration of `close_to_threshold_sampling` missing key detected: {error}."
) from error
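

# Decides whether a single prediction should be registered: stubs and
# non-eligible task types are rejected, the prediction must be close to the
# threshold, and the final decision is gated by the configured probability.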
def sample_close_to_threshold(
image: np.ndarray,
prediction: Prediction,
prediction_type: PredictionType,
selected_class_names: Optional[Set[str]],
threshold: float,
epsilon: float,
only_top_classes: bool,
minimum_objects_close_to_threshold: int,
probability: float,
) -> bool:
if is_prediction_a_stub(prediction=prediction):
return False
if prediction_type not in ELIGIBLE_PREDICTION_TYPES:
return False
close_to_threshold = prediction_is_close_to_threshold(
prediction=prediction,
prediction_type=prediction_type,
selected_class_names=selected_class_names,
threshold=threshold,
epsilon=epsilon,
only_top_classes=only_top_classes,
minimum_objects_close_to_threshold=minimum_objects_close_to_threshold,
)
if not close_to_threshold:
return False
return random.random() < probability
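

# Stub predictions carry no real model output and are never sampled.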
def is_prediction_a_stub(prediction: Prediction) -> bool:
return prediction.get("is_stub", False)
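

# Routes the closeness check: detection-style tasks go to
# detections_are_close_to_threshold, classification tasks to the multi-class
# or multi-label checker depending on whether a "top" key is present.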
def prediction_is_close_to_threshold(
prediction: Prediction,
prediction_type: PredictionType,
selected_class_names: Optional[Set[str]],
threshold: float,
epsilon: float,
only_top_classes: bool,
minimum_objects_close_to_threshold: int,
) -> bool:
if CLASSIFICATION_TASK not in prediction_type:
return detections_are_close_to_threshold(
prediction=prediction,
selected_class_names=selected_class_names,
threshold=threshold,
epsilon=epsilon,
minimum_objects_close_to_threshold=minimum_objects_close_to_threshold,
)
checker = multi_label_classification_prediction_is_close_to_threshold
if "top" in prediction:
checker = multi_class_classification_prediction_is_close_to_threshold
return checker(
prediction=prediction,
selected_class_names=selected_class_names,
threshold=threshold,
epsilon=epsilon,
only_top_classes=only_top_classes,
)
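

# Multi-class classification: inspect only the top class when
# `only_top_classes` is set, otherwise any class in `predictions` may qualify.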
def multi_class_classification_prediction_is_close_to_threshold(
prediction: Prediction,
selected_class_names: Optional[Set[str]],
threshold: float,
epsilon: float,
only_top_classes: bool,
) -> bool:
if only_top_classes:
return (
multi_class_classification_prediction_is_close_to_threshold_for_top_class(
prediction=prediction,
selected_class_names=selected_class_names,
threshold=threshold,
epsilon=epsilon,
)
)
for prediction_details in prediction["predictions"]:
if class_to_be_excluded(
class_name=prediction_details["class"],
selected_class_names=selected_class_names,
):
continue
if is_close_to_threshold(
value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon
):
return True
return False
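

# Top-class variant: qualifies only if the top class passes the optional class
# filter and its confidence lies within epsilon of the threshold.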
def multi_class_classification_prediction_is_close_to_threshold_for_top_class(
prediction: Prediction,
selected_class_names: Optional[Set[str]],
threshold: float,
epsilon: float,
) -> bool:
if (
selected_class_names is not None
and prediction["top"] not in selected_class_names
):
return False
return abs(prediction["confidence"] - threshold) < epsilon
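

# Multi-label classification: iterates per-class confidences, optionally
# restricted to classes actually predicted and to the selected class names.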
def multi_label_classification_prediction_is_close_to_threshold(
prediction: Prediction,
selected_class_names: Optional[Set[str]],
threshold: float,
epsilon: float,
only_top_classes: bool,
) -> bool:
predicted_classes = set(prediction["predicted_classes"])
for class_name, prediction_details in prediction["predictions"].items():
if only_top_classes and class_name not in predicted_classes:
continue
if class_to_be_excluded(
class_name=class_name, selected_class_names=selected_class_names
):
continue
if is_close_to_threshold(
value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon
):
return True
return False
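

# Detection-style predictions qualify when at least
# `minimum_objects_close_to_threshold` detections are close to the threshold.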
def detections_are_close_to_threshold(
prediction: Prediction,
selected_class_names: Optional[Set[str]],
threshold: float,
epsilon: float,
minimum_objects_close_to_threshold: int,
) -> bool:
detections_close_to_threshold = count_detections_close_to_threshold(
prediction=prediction,
selected_class_names=selected_class_names,
threshold=threshold,
epsilon=epsilon,
)
return detections_close_to_threshold >= minimum_objects_close_to_threshold
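

# Counts detections whose confidence lies within epsilon of the threshold,
# skipping classes excluded by the optional class-name filter.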
def count_detections_close_to_threshold(
prediction: Prediction,
selected_class_names: Optional[Set[str]],
threshold: float,
epsilon: float,
) -> int:
counter = 0
for prediction_details in prediction["predictions"]:
if class_to_be_excluded(
class_name=prediction_details["class"],
selected_class_names=selected_class_names,
):
continue
if is_close_to_threshold(
value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon
):
counter += 1
return counter
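

# A class is excluded only when a class filter is configured and the class is
# not part of it.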
def class_to_be_excluded(
class_name: str, selected_class_names: Optional[Set[str]]
) -> bool:
return selected_class_names is not None and class_name not in selected_class_names
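

# "Close to threshold" means |value - threshold| < epsilon.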
def is_close_to_threshold(value: float, threshold: float, epsilon: float) -> bool:
return abs(value - threshold) < epsilon