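"""Resolution and initialisation of Active Learning configuration for Roboflow projects."""
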
import hashlib
from dataclasses import asdict
from typing import Any, Dict, List, Optional

from inference.core import logger
from inference.core.active_learning.entities import (
    ActiveLearningConfiguration,
    RoboflowProjectMetadata,
    SamplingMethod,
)
from inference.core.active_learning.samplers.close_to_threshold import (
    initialize_close_to_threshold_sampling,
)
from inference.core.active_learning.samplers.contains_classes import (
    initialize_classes_based_sampling,
)
from inference.core.active_learning.samplers.number_of_detections import (
    initialize_detections_number_based_sampling,
)
from inference.core.active_learning.samplers.random import initialize_random_sampling
from inference.core.cache.base import BaseCache
from inference.core.exceptions import (
    ActiveLearningConfigurationDecodingError,
    ActiveLearningConfigurationError,
    RoboflowAPINotAuthorizedError,
    RoboflowAPINotNotFoundError,
)
from inference.core.roboflow_api import (
    get_roboflow_active_learning_configuration,
    get_roboflow_dataset_type,
    get_roboflow_workspace,
)
from inference.core.utils.roboflow import get_model_id_chunks

TYPE2SAMPLING_INITIALIZERS = {
    "random": initialize_random_sampling,
    "close_to_threshold": initialize_close_to_threshold_sampling,
    "classes_based": initialize_classes_based_sampling,
    "detections_number_based": initialize_detections_number_based_sampling,
}
ACTIVE_LEARNING_CONFIG_CACHE_EXPIRE = 900  # 15 min


def prepare_active_learning_configuration(
    api_key: str,
    model_id: str,
    cache: BaseCache,
) -> Optional[ActiveLearningConfiguration]:
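    """Resolve the Active Learning configuration for a model via the Roboflow API (with caching).

    Returns None when Active Learning is disabled for the project.
    """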
    project_metadata = get_roboflow_project_metadata(
        api_key=api_key,
        model_id=model_id,
        cache=cache,
    )
    if not project_metadata.active_learning_configuration.get("enabled", False):
        return None
    logger.info(
        f"Configuring active learning for workspace: {project_metadata.workspace_id}, "
        f"project: {project_metadata.dataset_id} of type: {project_metadata.dataset_type}. "
        f"AL configuration: {project_metadata.active_learning_configuration}"
    )
    return initialise_active_learning_configuration(
        project_metadata=project_metadata,
    )


def prepare_active_learning_configuration_inplace(
    api_key: str,
    model_id: str,
    active_learning_configuration: Optional[dict],
) -> Optional[ActiveLearningConfiguration]:
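    """Build an Active Learning configuration from an explicitly provided config dict.

    Skips the config cache and the Roboflow AL-configuration endpoint; only the
    workspace and dataset type are resolved via the API. Returns None when the
    provided configuration is missing or disabled.
    """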
    if (
        active_learning_configuration is None
        or active_learning_configuration.get("enabled", False) is False
    ):
        return None
    dataset_id, version_id = get_model_id_chunks(model_id=model_id)
    workspace_id = get_roboflow_workspace(api_key=api_key)
    dataset_type = get_roboflow_dataset_type(
        api_key=api_key,
        workspace_id=workspace_id,
        dataset_id=dataset_id,
    )
    project_metadata = RoboflowProjectMetadata(
        dataset_id=dataset_id,
        version_id=version_id,
        workspace_id=workspace_id,
        dataset_type=dataset_type,
        active_learning_configuration=active_learning_configuration,
    )
    return initialise_active_learning_configuration(
        project_metadata=project_metadata,
    )


def get_roboflow_project_metadata(
    api_key: str,
    model_id: str,
    cache: BaseCache,
) -> RoboflowProjectMetadata:
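    """Fetch project metadata (workspace, dataset type, AL configuration) for a model.

    Results are cached for ACTIVE_LEARNING_CONFIG_CACHE_EXPIRE seconds to limit
    calls to the Roboflow API.
    """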
    logger.info(f"Fetching active learning configuration.")
    config_cache_key = construct_cache_key_for_active_learning_config(
        api_key=api_key, model_id=model_id
    )
    cached_config = cache.get(config_cache_key)
    if cached_config is not None:
        logger.info("Found Active Learning configuration in cache.")
        return parse_cached_roboflow_project_metadata(cached_config=cached_config)
    dataset_id, version_id = get_model_id_chunks(model_id=model_id)
    workspace_id = get_roboflow_workspace(api_key=api_key)
    dataset_type = get_roboflow_dataset_type(
        api_key=api_key,
        workspace_id=workspace_id,
        dataset_id=dataset_id,
    )
    try:
        roboflow_api_configuration = get_roboflow_active_learning_configuration(
            api_key=api_key, workspace_id=workspace_id, dataset_id=dataset_id
        )
    except (RoboflowAPINotAuthorizedError, RoboflowAPINotNotFoundError):
        # Currently the backend returns HTTP 404 both when the dataset does not
        # exist and when the workspace_id derived from the api_key indicates a
        # different owner - e.g. when we query for a Universe dataset. We want
        # the owner of a public dataset to be able to set AL configs and use
        # them, but not other people. At this point HTTP 404 effectively means
        # "not authorised" (which will probably change in a future iteration of
        # the backend) - so on both NotAuthorized and NotFound errors we assume
        # that AL simply cannot be used with this model and this api_key.
        roboflow_api_configuration = {"enabled": False}
    configuration = RoboflowProjectMetadata(
        dataset_id=dataset_id,
        version_id=version_id,
        workspace_id=workspace_id,
        dataset_type=dataset_type,
        active_learning_configuration=roboflow_api_configuration,
    )
    cache.set(
        key=config_cache_key,
        value=asdict(configuration),
        expire=ACTIVE_LEARNING_CONFIG_CACHE_EXPIRE,
    )
    return configuration


def construct_cache_key_for_active_learning_config(api_key: str, model_id: str) -> str:
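    """Build a cache key scoped to the hashed api_key and the dataset id."""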
    dataset_id = model_id.split("/")[0]
    api_key_hash = hashlib.md5(api_key.encode("utf-8")).hexdigest()
    return f"active_learning:configurations:{api_key_hash}:{dataset_id}"


def parse_cached_roboflow_project_metadata(
    cached_config: dict,
) -> RoboflowProjectMetadata:
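    """Re-create RoboflowProjectMetadata from its cached dict representation."""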
    try:
        return RoboflowProjectMetadata(**cached_config)
    except Exception as error:
        raise ActiveLearningConfigurationDecodingError(
            f"Failed to initialise Active Learning configuration. Cause: {str(error)}"
        ) from error


def initialise_active_learning_configuration(
    project_metadata: RoboflowProjectMetadata,
) -> ActiveLearningConfiguration:
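    """Turn project metadata into an ActiveLearningConfiguration.

    Initialises the configured sampling methods and honours optional
    `target_workspace` / `target_project` overrides from the AL configuration.
    """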
    sampling_methods = initialize_sampling_methods(
        sampling_strategies_configs=project_metadata.active_learning_configuration[
            "sampling_strategies"
        ],
    )
    target_workspace_id = project_metadata.active_learning_configuration.get(
        "target_workspace", project_metadata.workspace_id
    )
    target_dataset_id = project_metadata.active_learning_configuration.get(
        "target_project", project_metadata.dataset_id
    )
    return ActiveLearningConfiguration.init(
        roboflow_api_configuration=project_metadata.active_learning_configuration,
        sampling_methods=sampling_methods,
        workspace_id=target_workspace_id,
        dataset_id=target_dataset_id,
        model_id=f"{project_metadata.dataset_id}/{project_metadata.version_id}",
    )


def initialize_sampling_methods(
    sampling_strategies_configs: List[Dict[str, Any]]
) -> List[SamplingMethod]:
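    """Instantiate sampling methods from the strategy configs.

    Unknown strategy types are skipped with a warning; duplicated strategy
    names raise ActiveLearningConfigurationError.
    """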
    result = []
    for sampling_strategy_config in sampling_strategies_configs:
        sampling_type = sampling_strategy_config["type"]
        if sampling_type not in TYPE2SAMPLING_INITIALIZERS:
            logger.warning(
                f"Could not identify sampling method `{sampling_type}` - skipping initialisation."
            )
            continue
        initializer = TYPE2SAMPLING_INITIALIZERS[sampling_type]
        result.append(initializer(sampling_strategy_config))
    names = {m.name for m in result}
    if len(names) != len(result):
        raise ActiveLearningConfigurationError(
            "Detected duplication of Active Learning strategies names."
        )
    return result