File size: 3,488 Bytes
df6c67d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import random
from functools import partial
from typing import Any, Dict, Optional, Set

import numpy as np

from inference.core.active_learning.entities import (
    Prediction,
    PredictionType,
    SamplingMethod,
)
from inference.core.active_learning.samplers.close_to_threshold import (
    count_detections_close_to_threshold,
    is_prediction_a_stub,
)
from inference.core.constants import (
    INSTANCE_SEGMENTATION_TASK,
    KEYPOINTS_DETECTION_TASK,
    OBJECT_DETECTION_TASK,
)
from inference.core.exceptions import ActiveLearningConfigurationError

# Prediction types for which detections-number-based sampling is attempted;
# `sample_based_on_detections_number` returns False for any other type.
ELIGIBLE_PREDICTION_TYPES = {
    OBJECT_DETECTION_TASK,
    INSTANCE_SEGMENTATION_TASK,
    KEYPOINTS_DETECTION_TASK,
}


def initialize_detections_number_based_sampling(
    strategy_config: Dict[str, Any]
) -> SamplingMethod:
    """Build a detections-number-based ``SamplingMethod`` from its configuration.

    Keys read from ``strategy_config``:
        * ``more_than`` / ``less_than`` (optional) - exclusive bounds passed to
          ``sample_based_on_detections_number``; either may be absent.
        * ``selected_class_names`` (optional) - iterable converted to a ``set``
          before being bound into the sampler.
        * ``probability`` (required) - sampling probability bound into the sampler.
        * ``name`` (required) - name of the returned ``SamplingMethod``.

    Raises:
        ActiveLearningConfigurationError: if ``probability`` or ``name`` is
            missing, or (via ``ensure_range_configuration_is_valid``) if both
            bounds are given and ``more_than >= less_than``.
    """
    try:
        lower_bound = strategy_config.get("more_than")
        upper_bound = strategy_config.get("less_than")
        # The range is validated before any required key is looked up, so a bad
        # range is reported ahead of a missing key.
        ensure_range_configuration_is_valid(
            more_than=lower_bound, less_than=upper_bound
        )
        raw_class_names = strategy_config.get("selected_class_names")
        class_filter = None if raw_class_names is None else set(raw_class_names)
        # Required keys: `probability` is read first, then `name`.
        sampling_probability = strategy_config["probability"]
        method_name = strategy_config["name"]
    except KeyError as error:
        raise ActiveLearningConfigurationError(
            f"In configuration of `detections_number_based_sampling` missing key detected: {error}."
        ) from error
    sampler = partial(
        sample_based_on_detections_number,
        less_than=upper_bound,
        more_than=lower_bound,
        selected_class_names=class_filter,
        probability=sampling_probability,
    )
    return SamplingMethod(name=method_name, sample=sampler)


def sample_based_on_detections_number(
    image: np.ndarray,
    prediction: Prediction,
    prediction_type: PredictionType,
    more_than: Optional[int],
    less_than: Optional[int],
    selected_class_names: Optional[Set[str]],
    probability: float,
) -> bool:
    """Decide whether a datapoint should be sampled based on its detection count.

    Returns False for stub predictions and for prediction types outside
    ``ELIGIBLE_PREDICTION_TYPES``. Otherwise the count returned by
    ``count_detections_close_to_threshold`` must satisfy ``is_in_range`` with the
    given exclusive bounds (a ``None`` bound is ignored). In that case the
    datapoint is sampled with chance ``probability``; otherwise False is returned.

    ``image`` is not read here; presumably it is kept so every sampler shares one
    call signature - confirm against the caller.
    """
    if is_prediction_a_stub(prediction=prediction) or (
        prediction_type not in ELIGIBLE_PREDICTION_TYPES
    ):
        return False
    # NOTE(review): a window of 0.5 +/- 1.0 presumably spans every possible
    # confidence value, turning this helper into a plain (class-filtered)
    # detection counter - confirm against count_detections_close_to_threshold.
    detections_count = count_detections_close_to_threshold(
        prediction=prediction,
        selected_class_names=selected_class_names,
        threshold=0.5,
        epsilon=1.0,
    )
    count_matches = is_in_range(
        value=detections_count, less_than=less_than, more_than=more_than
    )
    if not count_matches:
        return False
    # The random draw only happens once the count condition holds.
    return random.random() < probability


def is_in_range(
    value: int,
    more_than: Optional[int],
    less_than: Optional[int],
) -> bool:
    """Return True when ``more_than < value < less_than``.

    Both bounds are exclusive. A ``None`` bound leaves that side of the range
    open, so with both bounds ``None`` every value is accepted. The result is
    always a plain Python ``bool``.
    """
    below_upper = True if less_than is None else bool(value < less_than)
    above_lower = True if more_than is None else bool(value > more_than)
    return below_upper and above_lower


def ensure_range_configuration_is_valid(
    more_than: Optional[int],
    less_than: Optional[int],
) -> None:
    """Validate the exclusive ``(more_than, less_than)`` bounds of the sampler.

    The check only applies when both bounds are given; a missing bound leaves
    that side of the range open and is always accepted.

    Raises:
        ActiveLearningConfigurationError: if both bounds are set and
            ``more_than >= less_than``.
    """
    both_bounds_given = more_than is not None and less_than is not None
    if both_bounds_given and more_than >= less_than:
        raise ActiveLearningConfigurationError(
            f"Misconfiguration of detections number sampling: "
            f"`more_than` parameter ({more_than}) >= `less_than` ({less_than})."
        )