Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,488 Bytes
df6c67d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import random
from functools import partial
from typing import Any, Dict, Optional, Set
import numpy as np
from inference.core.active_learning.entities import (
Prediction,
PredictionType,
SamplingMethod,
)
from inference.core.active_learning.samplers.close_to_threshold import (
count_detections_close_to_threshold,
is_prediction_a_stub,
)
from inference.core.constants import (
INSTANCE_SEGMENTATION_TASK,
KEYPOINTS_DETECTION_TASK,
OBJECT_DETECTION_TASK,
)
from inference.core.exceptions import ActiveLearningConfigurationError
# Prediction types for which detections-number sampling applies;
# `sample_based_on_detections_number` returns False for any other type.
ELIGIBLE_PREDICTION_TYPES = {
INSTANCE_SEGMENTATION_TASK,
KEYPOINTS_DETECTION_TASK,
OBJECT_DETECTION_TASK,
}
def initialize_detections_number_based_sampling(
    strategy_config: Dict[str, Any]
) -> SamplingMethod:
    """Build a `SamplingMethod` that samples based on the number of detections.

    Args:
        strategy_config: Strategy configuration. Keys `name` and `probability`
            are required; `more_than`, `less_than` and `selected_class_names`
            are optional and default to `None` when absent.

    Returns:
        A `SamplingMethod` whose `sample` callable is
        `sample_based_on_detections_number` with the configured range,
        class filter and probability already bound.

    Raises:
        ActiveLearningConfigurationError: When a required key is missing, or
            when the configured range is rejected by
            `ensure_range_configuration_is_valid`.
    """
    try:
        lower_bound = strategy_config.get("more_than")
        upper_bound = strategy_config.get("less_than")
        # Validate the range before touching the required keys, so a broken
        # range is reported ahead of a missing key.
        ensure_range_configuration_is_valid(
            more_than=lower_bound, less_than=upper_bound
        )
        configured_classes = strategy_config.get("selected_class_names")
        class_filter = (
            None if configured_classes is None else set(configured_classes)
        )
        # Required keys: `probability` is looked up before `name`, so the
        # first missing one is what ends up in the error message.
        sampling_probability = strategy_config["probability"]
        bound_sampler = partial(
            sample_based_on_detections_number,
            less_than=upper_bound,
            more_than=lower_bound,
            selected_class_names=class_filter,
            probability=sampling_probability,
        )
        strategy_name = strategy_config["name"]
        return SamplingMethod(name=strategy_name, sample=bound_sampler)
    except KeyError as error:
        raise ActiveLearningConfigurationError(
            f"In configuration of `detections_number_based_sampling` missing key detected: {error}."
        ) from error
def sample_based_on_detections_number(
    image: np.ndarray,
    prediction: Prediction,
    prediction_type: PredictionType,
    more_than: Optional[int],
    less_than: Optional[int],
    selected_class_names: Optional[Set[str]],
    probability: float,
) -> bool:
    """Decide whether to sample a datapoint based on its detections count.

    Args:
        image: Not used by this function.
        prediction: Prediction to inspect.
        prediction_type: Type of the prediction; only members of
            `ELIGIBLE_PREDICTION_TYPES` can be sampled.
        more_than: Exclusive lower bound on the count, or `None` for no bound.
        less_than: Exclusive upper bound on the count, or `None` for no bound.
        selected_class_names: Forwarded to the counting helper as its class
            filter; may be `None`.
        probability: Threshold compared against `random.random()` once the
            count is within range.

    Returns:
        `False` for stub predictions, ineligible prediction types, or counts
        outside the range; otherwise the outcome of the random draw.
    """
    not_applicable = (
        is_prediction_a_stub(prediction=prediction)
        or prediction_type not in ELIGIBLE_PREDICTION_TYPES
    )
    if not_applicable:
        return False
    # NOTE(review): threshold=0.5 with epsilon=1.0 presumably spans the whole
    # confidence range, so this counts every matching detection - confirm
    # against `count_detections_close_to_threshold`.
    detections_count = count_detections_close_to_threshold(
        prediction=prediction,
        selected_class_names=selected_class_names,
        threshold=0.5,
        epsilon=1.0,
    )
    count_within_range = is_in_range(
        value=detections_count, less_than=less_than, more_than=more_than
    )
    if not count_within_range:
        return False
    # The random draw happens only for in-range counts.
    return random.random() < probability
def is_in_range(
    value: int,
    more_than: Optional[int],
    less_than: Optional[int],
) -> bool:
    """Check `more_than < value < less_than`, where either bound may be absent.

    Args:
        value: Number to test.
        more_than: Exclusive lower bound; `None` disables the lower check.
        less_than: Exclusive upper bound; `None` disables the upper check.

    Returns:
        `True` when every bound that is provided is strictly satisfied.
    """
    # A missing bound is satisfied by definition; both sides are always
    # evaluated, upper bound first.
    below_upper = less_than is None or bool(value < less_than)
    above_lower = more_than is None or bool(value > more_than)
    return below_upper and above_lower
def ensure_range_configuration_is_valid(
    more_than: Optional[int],
    less_than: Optional[int],
) -> None:
    """Reject a range whose lower bound is not strictly below its upper bound.

    Args:
        more_than: Exclusive lower bound, or `None` when not configured.
        less_than: Exclusive upper bound, or `None` when not configured.

    Raises:
        ActiveLearningConfigurationError: When both bounds are given and
            `more_than >= less_than`.
    """
    # Nothing to verify unless both bounds are configured.
    both_bounds_given = more_than is not None and less_than is not None
    if both_bounds_given and more_than >= less_than:
        message = (
            f"Misconfiguration of detections number sampling: "
            f"`more_than` parameter ({more_than}) >= `less_than` ({less_than})."
        )
        raise ActiveLearningConfigurationError(message)
|