Fucius's picture
Upload 422 files
2eafbc4 verified
raw history blame
No virus
3.49 kB
import random
from functools import partial
from typing import Any, Dict, Optional, Set
import numpy as np
from inference.core.active_learning.entities import (
Prediction,
PredictionType,
SamplingMethod,
)
from inference.core.active_learning.samplers.close_to_threshold import (
count_detections_close_to_threshold,
is_prediction_a_stub,
)
from inference.core.constants import (
INSTANCE_SEGMENTATION_TASK,
KEYPOINTS_DETECTION_TASK,
OBJECT_DETECTION_TASK,
)
from inference.core.exceptions import ActiveLearningConfigurationError
# Prediction types that the detections-number sampler is willing to sample from;
# `sample_based_on_detections_number` returns False for any other prediction type.
ELIGIBLE_PREDICTION_TYPES = {
INSTANCE_SEGMENTATION_TASK,
KEYPOINTS_DETECTION_TASK,
OBJECT_DETECTION_TASK,
}
def initialize_detections_number_based_sampling(
    strategy_config: Dict[str, Any]
) -> SamplingMethod:
    """Build a `SamplingMethod` that samples based on the number of detections.

    Args:
        strategy_config: Strategy configuration. Required keys: `name` and
            `probability`. Optional keys: `more_than`, `less_than` and
            `selected_class_names` (each treated as None when absent).

    Returns:
        `SamplingMethod` whose `sample` callable is
        `sample_based_on_detections_number` with the configured bounds,
        class filter and probability already bound.

    Raises:
        ActiveLearningConfigurationError: when a required key is missing, or
            (via `ensure_range_configuration_is_valid`) when the configured
            range is invalid.
    """
    try:
        lower_bound = strategy_config.get("more_than")
        upper_bound = strategy_config.get("less_than")
        ensure_range_configuration_is_valid(
            more_than=lower_bound, less_than=upper_bound
        )
        configured_classes = strategy_config.get("selected_class_names")
        # Normalise the optional class filter to a set; None means "not configured".
        class_filter = (
            None if configured_classes is None else set(configured_classes)
        )
        # `probability` is looked up before `name`, so that a config missing
        # both reports `probability` first.
        bound_arguments = {
            "less_than": upper_bound,
            "more_than": lower_bound,
            "selected_class_names": class_filter,
            "probability": strategy_config["probability"],
        }
        return SamplingMethod(
            name=strategy_config["name"],
            sample=partial(sample_based_on_detections_number, **bound_arguments),
        )
    except KeyError as error:
        raise ActiveLearningConfigurationError(
            f"In configuration of `detections_number_based_sampling` missing key detected: {error}."
        ) from error
def sample_based_on_detections_number(
    image: np.ndarray,
    prediction: Prediction,
    prediction_type: PredictionType,
    more_than: Optional[int],
    less_than: Optional[int],
    selected_class_names: Optional[Set[str]],
    probability: float,
) -> bool:
    """Decide whether a datapoint should be sampled given its detections count.

    Args:
        image: Not read by this sampler; present in the signature only.
        prediction: Model prediction to inspect.
        prediction_type: Type of the prediction; must be one of
            `ELIGIBLE_PREDICTION_TYPES` for sampling to be considered.
        more_than: Optional exclusive lower bound for the detections count.
        less_than: Optional exclusive upper bound for the detections count.
        selected_class_names: Optional class filter forwarded to the counting helper.
        probability: Chance of sampling once the count falls within the range.

    Returns:
        False for stub predictions, ineligible prediction types, or counts
        outside the range; otherwise the outcome of a random draw that
        succeeds when `random.random() < probability`.
    """
    if (
        is_prediction_a_stub(prediction=prediction)
        or prediction_type not in ELIGIBLE_PREDICTION_TYPES
    ):
        return False
    # NOTE(review): threshold=0.5 with epsilon=1.0 presumably makes the
    # "close to threshold" window wide enough to cover every confidence value,
    # turning the helper into a plain detections counter - confirm against
    # `count_detections_close_to_threshold`.
    detections_number = count_detections_close_to_threshold(
        prediction=prediction,
        selected_class_names=selected_class_names,
        threshold=0.5,
        epsilon=1.0,
    )
    count_within_range = is_in_range(
        value=detections_number, less_than=less_than, more_than=more_than
    )
    if not count_within_range:
        return False
    # The random draw happens only for in-range counts.
    return random.random() < probability
def is_in_range(
    value: int,
    more_than: Optional[int],
    less_than: Optional[int],
) -> bool:
    """Check `more_than < value < less_than`, where either bound may be absent.

    Args:
        value: Number to test.
        more_than: Exclusive lower bound; None disables the lower check.
        less_than: Exclusive upper bound; None disables the upper check.

    Returns:
        True when every configured bound is strictly satisfied (trivially
        True when both bounds are None), False otherwise.
    """
    # A missing bound counts as satisfied; a present one needs a strict comparison.
    upper_bound_ok = True if less_than is None else bool(value < less_than)
    lower_bound_ok = True if more_than is None else bool(value > more_than)
    return upper_bound_ok and lower_bound_ok
def ensure_range_configuration_is_valid(
    more_than: Optional[int],
    less_than: Optional[int],
) -> None:
    """Validate that the configured detections-number range is not inverted.

    Args:
        more_than: Optional exclusive lower bound of the range.
        less_than: Optional exclusive upper bound of the range.

    Raises:
        ActiveLearningConfigurationError: when both bounds are given and
            `more_than >= less_than`.
    """
    # With at most one bound configured there is nothing to cross-check.
    both_bounds_given = more_than is not None and less_than is not None
    if both_bounds_given and more_than >= less_than:
        raise ActiveLearningConfigurationError(
            "Misconfiguration of detections number sampling: "
            f"`more_than` parameter ({more_than}) >= `less_than` ({less_than})."
        )
    return None