Fucius commited on
Commit
2eafbc4
·
verified ·
1 Parent(s): e47a2d2

Upload 422 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. inference/__init__.py +3 -0
  2. inference/__pycache__/__init__.cpython-310.pyc +0 -0
  3. inference/core/__init__.py +52 -0
  4. inference/core/__pycache__/__init__.cpython-310.pyc +0 -0
  5. inference/core/__pycache__/constants.cpython-310.pyc +0 -0
  6. inference/core/__pycache__/env.cpython-310.pyc +0 -0
  7. inference/core/__pycache__/exceptions.cpython-310.pyc +0 -0
  8. inference/core/__pycache__/logger.cpython-310.pyc +0 -0
  9. inference/core/__pycache__/nms.cpython-310.pyc +0 -0
  10. inference/core/__pycache__/roboflow_api.cpython-310.pyc +0 -0
  11. inference/core/__pycache__/usage.cpython-310.pyc +0 -0
  12. inference/core/__pycache__/version.cpython-310.pyc +0 -0
  13. inference/core/active_learning/__init__.py +0 -0
  14. inference/core/active_learning/__pycache__/__init__.cpython-310.pyc +0 -0
  15. inference/core/active_learning/__pycache__/accounting.cpython-310.pyc +0 -0
  16. inference/core/active_learning/__pycache__/batching.cpython-310.pyc +0 -0
  17. inference/core/active_learning/__pycache__/cache_operations.cpython-310.pyc +0 -0
  18. inference/core/active_learning/__pycache__/configuration.cpython-310.pyc +0 -0
  19. inference/core/active_learning/__pycache__/core.cpython-310.pyc +0 -0
  20. inference/core/active_learning/__pycache__/entities.cpython-310.pyc +0 -0
  21. inference/core/active_learning/__pycache__/middlewares.cpython-310.pyc +0 -0
  22. inference/core/active_learning/__pycache__/post_processing.cpython-310.pyc +0 -0
  23. inference/core/active_learning/__pycache__/utils.cpython-310.pyc +0 -0
  24. inference/core/active_learning/accounting.py +96 -0
  25. inference/core/active_learning/batching.py +26 -0
  26. inference/core/active_learning/cache_operations.py +293 -0
  27. inference/core/active_learning/configuration.py +203 -0
  28. inference/core/active_learning/core.py +219 -0
  29. inference/core/active_learning/entities.py +141 -0
  30. inference/core/active_learning/middlewares.py +307 -0
  31. inference/core/active_learning/post_processing.py +128 -0
  32. inference/core/active_learning/samplers/__init__.py +0 -0
  33. inference/core/active_learning/samplers/__pycache__/__init__.cpython-310.pyc +0 -0
  34. inference/core/active_learning/samplers/__pycache__/close_to_threshold.cpython-310.pyc +0 -0
  35. inference/core/active_learning/samplers/__pycache__/contains_classes.cpython-310.pyc +0 -0
  36. inference/core/active_learning/samplers/__pycache__/number_of_detections.cpython-310.pyc +0 -0
  37. inference/core/active_learning/samplers/__pycache__/random.cpython-310.pyc +0 -0
  38. inference/core/active_learning/samplers/close_to_threshold.py +227 -0
  39. inference/core/active_learning/samplers/contains_classes.py +58 -0
  40. inference/core/active_learning/samplers/number_of_detections.py +107 -0
  41. inference/core/active_learning/samplers/random.py +37 -0
  42. inference/core/active_learning/utils.py +16 -0
  43. inference/core/cache/__init__.py +22 -0
  44. inference/core/cache/__pycache__/__init__.cpython-310.pyc +0 -0
  45. inference/core/cache/__pycache__/base.cpython-310.pyc +0 -0
  46. inference/core/cache/__pycache__/memory.cpython-310.pyc +0 -0
  47. inference/core/cache/__pycache__/model_artifacts.cpython-310.pyc +0 -0
  48. inference/core/cache/__pycache__/redis.cpython-310.pyc +0 -0
  49. inference/core/cache/__pycache__/serializers.cpython-310.pyc +0 -0
  50. inference/core/cache/base.py +130 -0
inference/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from inference.core.interfaces.stream.stream import Stream # isort:skip
2
+ from inference.core.interfaces.stream.inference_pipeline import InferencePipeline
3
+ from inference.models.utils import get_roboflow_model
inference/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (399 Bytes). View file
 
inference/core/__init__.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+ import time
3
+
4
+ import requests
5
+
6
+ from inference.core.env import DISABLE_VERSION_CHECK, VERSION_CHECK_MODE
7
+ from inference.core.logger import logger
8
+ from inference.core.version import __version__
9
+
10
+ latest_release = None
11
+ last_checked = 0
12
+ cache_duration = 86400 # 24 hours
13
+ log_frequency = 300 # 5 minutes
14
+
15
+
16
+ def get_latest_release_version():
17
+ global latest_release, last_checked
18
+ now = time.time()
19
+ if latest_release is None or now - last_checked > cache_duration:
20
+ try:
21
+ logger.debug("Checking for latest inference release version...")
22
+ response = requests.get(
23
+ "https://api.github.com/repos/roboflow/inference/releases/latest"
24
+ )
25
+ response.raise_for_status()
26
+ latest_release = response.json()["tag_name"].lstrip("v")
27
+ last_checked = now
28
+ except requests.exceptions.RequestException:
29
+ pass
30
+
31
+
32
+ def check_latest_release_against_current():
33
+ get_latest_release_version()
34
+ if latest_release is not None and latest_release != __version__:
35
+ logger.warning(
36
+ f"Your inference package version {__version__} is out of date! Please upgrade to version {latest_release} of inference for the latest features and bug fixes by running `pip install --upgrade inference`."
37
+ )
38
+
39
+
40
+ def check_latest_release_against_current_continuous():
41
+ while True:
42
+ check_latest_release_against_current()
43
+ time.sleep(log_frequency)
44
+
45
+
46
+ if not DISABLE_VERSION_CHECK:
47
+ if VERSION_CHECK_MODE == "continuous":
48
+ t = threading.Thread(target=check_latest_release_against_current_continuous)
49
+ t.daemon = True
50
+ t.start()
51
+ else:
52
+ check_latest_release_against_current()
inference/core/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.73 kB). View file
 
inference/core/__pycache__/constants.cpython-310.pyc ADDED
Binary file (371 Bytes). View file
 
inference/core/__pycache__/env.cpython-310.pyc ADDED
Binary file (6.87 kB). View file
 
inference/core/__pycache__/exceptions.cpython-310.pyc ADDED
Binary file (6.17 kB). View file
 
inference/core/__pycache__/logger.cpython-310.pyc ADDED
Binary file (551 Bytes). View file
 
inference/core/__pycache__/nms.cpython-310.pyc ADDED
Binary file (4.74 kB). View file
 
inference/core/__pycache__/roboflow_api.cpython-310.pyc ADDED
Binary file (10.1 kB). View file
 
inference/core/__pycache__/usage.cpython-310.pyc ADDED
Binary file (1.85 kB). View file
 
inference/core/__pycache__/version.cpython-310.pyc ADDED
Binary file (250 Bytes). View file
 
inference/core/active_learning/__init__.py ADDED
File without changes
inference/core/active_learning/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (192 Bytes). View file
 
inference/core/active_learning/__pycache__/accounting.cpython-310.pyc ADDED
Binary file (2.76 kB). View file
 
inference/core/active_learning/__pycache__/batching.cpython-310.pyc ADDED
Binary file (921 Bytes). View file
 
inference/core/active_learning/__pycache__/cache_operations.cpython-310.pyc ADDED
Binary file (5.9 kB). View file
 
inference/core/active_learning/__pycache__/configuration.cpython-310.pyc ADDED
Binary file (5.3 kB). View file
 
inference/core/active_learning/__pycache__/core.cpython-310.pyc ADDED
Binary file (5.2 kB). View file
 
inference/core/active_learning/__pycache__/entities.cpython-310.pyc ADDED
Binary file (4.72 kB). View file
 
inference/core/active_learning/__pycache__/middlewares.cpython-310.pyc ADDED
Binary file (8.68 kB). View file
 
inference/core/active_learning/__pycache__/post_processing.cpython-310.pyc ADDED
Binary file (2.94 kB). View file
 
inference/core/active_learning/__pycache__/utils.cpython-310.pyc ADDED
Binary file (852 Bytes). View file
 
inference/core/active_learning/accounting.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+
3
+ from inference.core.entities.types import DatasetID, WorkspaceID
4
+ from inference.core.roboflow_api import (
5
+ get_roboflow_labeling_batches,
6
+ get_roboflow_labeling_jobs,
7
+ )
8
+
9
+
10
+ def image_can_be_submitted_to_batch(
11
+ batch_name: str,
12
+ workspace_id: WorkspaceID,
13
+ dataset_id: DatasetID,
14
+ max_batch_images: Optional[int],
15
+ api_key: str,
16
+ ) -> bool:
17
+ """Check if an image can be submitted to a batch.
18
+
19
+ Args:
20
+ batch_name: Name of the batch.
21
+ workspace_id: ID of the workspace.
22
+ dataset_id: ID of the dataset.
23
+ max_batch_images: Maximum number of images allowed in the batch.
24
+ api_key: API key to use for the request.
25
+
26
+ Returns:
27
+ True if the image can be submitted to the batch, False otherwise.
28
+ """
29
+ if max_batch_images is None:
30
+ return True
31
+ labeling_batches = get_roboflow_labeling_batches(
32
+ api_key=api_key,
33
+ workspace_id=workspace_id,
34
+ dataset_id=dataset_id,
35
+ )
36
+ matching_labeling_batch = get_matching_labeling_batch(
37
+ all_labeling_batches=labeling_batches["batches"],
38
+ batch_name=batch_name,
39
+ )
40
+ if matching_labeling_batch is None:
41
+ return max_batch_images > 0
42
+ batch_images_under_labeling = 0
43
+ if matching_labeling_batch["numJobs"] > 0:
44
+ labeling_jobs = get_roboflow_labeling_jobs(
45
+ api_key=api_key, workspace_id=workspace_id, dataset_id=dataset_id
46
+ )
47
+ batch_images_under_labeling = get_images_in_labeling_jobs_of_specific_batch(
48
+ all_labeling_jobs=labeling_jobs["jobs"],
49
+ batch_id=matching_labeling_batch["id"],
50
+ )
51
+ total_batch_images = matching_labeling_batch["images"] + batch_images_under_labeling
52
+ return max_batch_images > total_batch_images
53
+
54
+
55
+ def get_matching_labeling_batch(
56
+ all_labeling_batches: List[dict],
57
+ batch_name: str,
58
+ ) -> Optional[dict]:
59
+ """Get the matching labeling batch.
60
+
61
+ Args:
62
+ all_labeling_batches: All labeling batches.
63
+ batch_name: Name of the batch.
64
+
65
+ Returns:
66
+ The matching labeling batch if found, None otherwise.
67
+
68
+ """
69
+ matching_batch = None
70
+ for labeling_batch in all_labeling_batches:
71
+ if labeling_batch["name"] == batch_name:
72
+ matching_batch = labeling_batch
73
+ break
74
+ return matching_batch
75
+
76
+
77
+ def get_images_in_labeling_jobs_of_specific_batch(
78
+ all_labeling_jobs: List[dict],
79
+ batch_id: str,
80
+ ) -> int:
81
+ """Get the number of images in labeling jobs of a specific batch.
82
+
83
+ Args:
84
+ all_labeling_jobs: All labeling jobs.
85
+ batch_id: ID of the batch.
86
+
87
+ Returns:
88
+ The number of images in labeling jobs of the batch.
89
+
90
+ """
91
+
92
+ matching_jobs = []
93
+ for labeling_job in all_labeling_jobs:
94
+ if batch_id in labeling_job["sourceBatch"]:
95
+ matching_jobs.append(labeling_job)
96
+ return sum(job["numImages"] for job in matching_jobs)
inference/core/active_learning/batching.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from inference.core.active_learning.entities import (
2
+ ActiveLearningConfiguration,
3
+ BatchReCreationInterval,
4
+ )
5
+ from inference.core.active_learning.utils import (
6
+ generate_start_timestamp_for_this_month,
7
+ generate_start_timestamp_for_this_week,
8
+ generate_today_timestamp,
9
+ )
10
+
11
+ RECREATION_INTERVAL2TIMESTAMP_GENERATOR = {
12
+ BatchReCreationInterval.DAILY: generate_today_timestamp,
13
+ BatchReCreationInterval.WEEKLY: generate_start_timestamp_for_this_week,
14
+ BatchReCreationInterval.MONTHLY: generate_start_timestamp_for_this_month,
15
+ }
16
+
17
+
18
+ def generate_batch_name(configuration: ActiveLearningConfiguration) -> str:
19
+ batch_name = configuration.batches_name_prefix
20
+ if configuration.batch_recreation_interval is BatchReCreationInterval.NEVER:
21
+ return batch_name
22
+ timestamp_generator = RECREATION_INTERVAL2TIMESTAMP_GENERATOR[
23
+ configuration.batch_recreation_interval
24
+ ]
25
+ timestamp = timestamp_generator()
26
+ return f"{batch_name}_{timestamp}"
inference/core/active_learning/cache_operations.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+ from contextlib import contextmanager
3
+ from datetime import datetime
4
+ from typing import Generator, List, Optional, OrderedDict, Union
5
+
6
+ import redis.lock
7
+
8
+ from inference.core import logger
9
+ from inference.core.active_learning.entities import StrategyLimit, StrategyLimitType
10
+ from inference.core.active_learning.utils import TIMESTAMP_FORMAT
11
+ from inference.core.cache.base import BaseCache
12
+
13
+ MAX_LOCK_TIME = 5
14
+ SECONDS_IN_HOUR = 60 * 60
15
+ USAGE_KEY = "usage"
16
+
17
+ LIMIT_TYPE2KEY_INFIX_GENERATOR = {
18
+ StrategyLimitType.MINUTELY: lambda: f"minute_{datetime.utcnow().minute}",
19
+ StrategyLimitType.HOURLY: lambda: f"hour_{datetime.utcnow().hour}",
20
+ StrategyLimitType.DAILY: lambda: f"day_{datetime.utcnow().strftime(TIMESTAMP_FORMAT)}",
21
+ }
22
+ LIMIT_TYPE2KEY_EXPIRATION = {
23
+ StrategyLimitType.MINUTELY: 120,
24
+ StrategyLimitType.HOURLY: 2 * SECONDS_IN_HOUR,
25
+ StrategyLimitType.DAILY: 25 * SECONDS_IN_HOUR,
26
+ }
27
+
28
+
29
+ def use_credit_of_matching_strategy(
30
+ cache: BaseCache,
31
+ workspace: str,
32
+ project: str,
33
+ matching_strategies_limits: OrderedDict[str, List[StrategyLimit]],
34
+ ) -> Optional[str]:
35
+ # In scope of this function, cache keys updates regarding usage limits for
36
+ # specific :workspace and :project are locked - to ensure increment to be done atomically
37
+ # Limits are accounted at the moment of registration - which may introduce inaccuracy
38
+ # given that registration is postponed from prediction
39
+ # Returns: strategy with spare credit if found - else None
40
+ with lock_limits(cache=cache, workspace=workspace, project=project):
41
+ strategy_with_spare_credit = find_strategy_with_spare_usage_credit(
42
+ cache=cache,
43
+ workspace=workspace,
44
+ project=project,
45
+ matching_strategies_limits=matching_strategies_limits,
46
+ )
47
+ if strategy_with_spare_credit is None:
48
+ return None
49
+ consume_strategy_limits_usage_credit(
50
+ cache=cache,
51
+ workspace=workspace,
52
+ project=project,
53
+ strategy_name=strategy_with_spare_credit,
54
+ )
55
+ return strategy_with_spare_credit
56
+
57
+
58
+ def return_strategy_credit(
59
+ cache: BaseCache,
60
+ workspace: str,
61
+ project: str,
62
+ strategy_name: str,
63
+ ) -> None:
64
+ # In scope of this function, cache keys updates regarding usage limits for
65
+ # specific :workspace and :project are locked - to ensure decrement to be done atomically
66
+ # Returning strategy is a bit naive (we may add to a pool of credits from the next period - but only
67
+ # if we have previously taken from the previous one and some credits are used in the new pool) -
68
+ # in favour of easier implementation.
69
+ with lock_limits(cache=cache, workspace=workspace, project=project):
70
+ return_strategy_limits_usage_credit(
71
+ cache=cache,
72
+ workspace=workspace,
73
+ project=project,
74
+ strategy_name=strategy_name,
75
+ )
76
+
77
+
78
+ @contextmanager
79
+ def lock_limits(
80
+ cache: BaseCache,
81
+ workspace: str,
82
+ project: str,
83
+ ) -> Generator[Union[threading.Lock, redis.lock.Lock], None, None]:
84
+ limits_lock_key = generate_cache_key_for_active_learning_usage_lock(
85
+ workspace=workspace,
86
+ project=project,
87
+ )
88
+ with cache.lock(key=limits_lock_key, expire=MAX_LOCK_TIME) as lock:
89
+ yield lock
90
+
91
+
92
+ def find_strategy_with_spare_usage_credit(
93
+ cache: BaseCache,
94
+ workspace: str,
95
+ project: str,
96
+ matching_strategies_limits: OrderedDict[str, List[StrategyLimit]],
97
+ ) -> Optional[str]:
98
+ for strategy_name, strategy_limits in matching_strategies_limits.items():
99
+ rejected_by_strategy = (
100
+ datapoint_should_be_rejected_based_on_strategy_usage_limits(
101
+ cache=cache,
102
+ workspace=workspace,
103
+ project=project,
104
+ strategy_name=strategy_name,
105
+ strategy_limits=strategy_limits,
106
+ )
107
+ )
108
+ if not rejected_by_strategy:
109
+ return strategy_name
110
+ return None
111
+
112
+
113
+ def datapoint_should_be_rejected_based_on_strategy_usage_limits(
114
+ cache: BaseCache,
115
+ workspace: str,
116
+ project: str,
117
+ strategy_name: str,
118
+ strategy_limits: List[StrategyLimit],
119
+ ) -> bool:
120
+ for strategy_limit in strategy_limits:
121
+ limit_reached = datapoint_should_be_rejected_based_on_limit_usage(
122
+ cache=cache,
123
+ workspace=workspace,
124
+ project=project,
125
+ strategy_name=strategy_name,
126
+ strategy_limit=strategy_limit,
127
+ )
128
+ if limit_reached:
129
+ logger.debug(
130
+ f"Violated Active Learning strategy limit: {strategy_limit.limit_type.name} "
131
+ f"with value {strategy_limit.value} for sampling strategy: {strategy_name}."
132
+ )
133
+ return True
134
+ return False
135
+
136
+
137
+ def datapoint_should_be_rejected_based_on_limit_usage(
138
+ cache: BaseCache,
139
+ workspace: str,
140
+ project: str,
141
+ strategy_name: str,
142
+ strategy_limit: StrategyLimit,
143
+ ) -> bool:
144
+ current_usage = get_current_strategy_limit_usage(
145
+ cache=cache,
146
+ workspace=workspace,
147
+ project=project,
148
+ strategy_name=strategy_name,
149
+ limit_type=strategy_limit.limit_type,
150
+ )
151
+ if current_usage is None:
152
+ current_usage = 0
153
+ return current_usage >= strategy_limit.value
154
+
155
+
156
+ def consume_strategy_limits_usage_credit(
157
+ cache: BaseCache,
158
+ workspace: str,
159
+ project: str,
160
+ strategy_name: str,
161
+ ) -> None:
162
+ for limit_type in StrategyLimitType:
163
+ consume_strategy_limit_usage_credit(
164
+ cache=cache,
165
+ workspace=workspace,
166
+ project=project,
167
+ strategy_name=strategy_name,
168
+ limit_type=limit_type,
169
+ )
170
+
171
+
172
+ def consume_strategy_limit_usage_credit(
173
+ cache: BaseCache,
174
+ workspace: str,
175
+ project: str,
176
+ strategy_name: str,
177
+ limit_type: StrategyLimitType,
178
+ ) -> None:
179
+ current_value = get_current_strategy_limit_usage(
180
+ cache=cache,
181
+ limit_type=limit_type,
182
+ workspace=workspace,
183
+ project=project,
184
+ strategy_name=strategy_name,
185
+ )
186
+ if current_value is None:
187
+ current_value = 0
188
+ current_value += 1
189
+ set_current_strategy_limit_usage(
190
+ current_value=current_value,
191
+ cache=cache,
192
+ limit_type=limit_type,
193
+ workspace=workspace,
194
+ project=project,
195
+ strategy_name=strategy_name,
196
+ )
197
+
198
+
199
+ def return_strategy_limits_usage_credit(
200
+ cache: BaseCache,
201
+ workspace: str,
202
+ project: str,
203
+ strategy_name: str,
204
+ ) -> None:
205
+ for limit_type in StrategyLimitType:
206
+ return_strategy_limit_usage_credit(
207
+ cache=cache,
208
+ workspace=workspace,
209
+ project=project,
210
+ strategy_name=strategy_name,
211
+ limit_type=limit_type,
212
+ )
213
+
214
+
215
+ def return_strategy_limit_usage_credit(
216
+ cache: BaseCache,
217
+ workspace: str,
218
+ project: str,
219
+ strategy_name: str,
220
+ limit_type: StrategyLimitType,
221
+ ) -> None:
222
+ current_value = get_current_strategy_limit_usage(
223
+ cache=cache,
224
+ limit_type=limit_type,
225
+ workspace=workspace,
226
+ project=project,
227
+ strategy_name=strategy_name,
228
+ )
229
+ if current_value is None:
230
+ return None
231
+ current_value = max(current_value - 1, 0)
232
+ set_current_strategy_limit_usage(
233
+ current_value=current_value,
234
+ cache=cache,
235
+ limit_type=limit_type,
236
+ workspace=workspace,
237
+ project=project,
238
+ strategy_name=strategy_name,
239
+ )
240
+
241
+
242
+ def get_current_strategy_limit_usage(
243
+ cache: BaseCache,
244
+ workspace: str,
245
+ project: str,
246
+ strategy_name: str,
247
+ limit_type: StrategyLimitType,
248
+ ) -> Optional[int]:
249
+ usage_key = generate_cache_key_for_active_learning_usage(
250
+ limit_type=limit_type,
251
+ workspace=workspace,
252
+ project=project,
253
+ strategy_name=strategy_name,
254
+ )
255
+ value = cache.get(usage_key)
256
+ if value is None:
257
+ return value
258
+ return value[USAGE_KEY]
259
+
260
+
261
+ def set_current_strategy_limit_usage(
262
+ current_value: int,
263
+ cache: BaseCache,
264
+ workspace: str,
265
+ project: str,
266
+ strategy_name: str,
267
+ limit_type: StrategyLimitType,
268
+ ) -> None:
269
+ usage_key = generate_cache_key_for_active_learning_usage(
270
+ limit_type=limit_type,
271
+ workspace=workspace,
272
+ project=project,
273
+ strategy_name=strategy_name,
274
+ )
275
+ expire = LIMIT_TYPE2KEY_EXPIRATION[limit_type]
276
+ cache.set(key=usage_key, value={USAGE_KEY: current_value}, expire=expire) # type: ignore
277
+
278
+
279
+ def generate_cache_key_for_active_learning_usage_lock(
280
+ workspace: str,
281
+ project: str,
282
+ ) -> str:
283
+ return f"active_learning:usage:{workspace}:{project}:usage:lock"
284
+
285
+
286
+ def generate_cache_key_for_active_learning_usage(
287
+ limit_type: StrategyLimitType,
288
+ workspace: str,
289
+ project: str,
290
+ strategy_name: str,
291
+ ) -> str:
292
+ time_infix = LIMIT_TYPE2KEY_INFIX_GENERATOR[limit_type]()
293
+ return f"active_learning:usage:{workspace}:{project}:{strategy_name}:{time_infix}"
inference/core/active_learning/configuration.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ from dataclasses import asdict
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ from inference.core import logger
6
+ from inference.core.active_learning.entities import (
7
+ ActiveLearningConfiguration,
8
+ RoboflowProjectMetadata,
9
+ SamplingMethod,
10
+ )
11
+ from inference.core.active_learning.samplers.close_to_threshold import (
12
+ initialize_close_to_threshold_sampling,
13
+ )
14
+ from inference.core.active_learning.samplers.contains_classes import (
15
+ initialize_classes_based_sampling,
16
+ )
17
+ from inference.core.active_learning.samplers.number_of_detections import (
18
+ initialize_detections_number_based_sampling,
19
+ )
20
+ from inference.core.active_learning.samplers.random import initialize_random_sampling
21
+ from inference.core.cache.base import BaseCache
22
+ from inference.core.exceptions import (
23
+ ActiveLearningConfigurationDecodingError,
24
+ ActiveLearningConfigurationError,
25
+ RoboflowAPINotAuthorizedError,
26
+ RoboflowAPINotNotFoundError,
27
+ )
28
+ from inference.core.roboflow_api import (
29
+ get_roboflow_active_learning_configuration,
30
+ get_roboflow_dataset_type,
31
+ get_roboflow_workspace,
32
+ )
33
+ from inference.core.utils.roboflow import get_model_id_chunks
34
+
35
+ TYPE2SAMPLING_INITIALIZERS = {
36
+ "random": initialize_random_sampling,
37
+ "close_to_threshold": initialize_close_to_threshold_sampling,
38
+ "classes_based": initialize_classes_based_sampling,
39
+ "detections_number_based": initialize_detections_number_based_sampling,
40
+ }
41
+ ACTIVE_LEARNING_CONFIG_CACHE_EXPIRE = 900 # 15 min
42
+
43
+
44
+ def prepare_active_learning_configuration(
45
+ api_key: str,
46
+ model_id: str,
47
+ cache: BaseCache,
48
+ ) -> Optional[ActiveLearningConfiguration]:
49
+ project_metadata = get_roboflow_project_metadata(
50
+ api_key=api_key,
51
+ model_id=model_id,
52
+ cache=cache,
53
+ )
54
+ if not project_metadata.active_learning_configuration.get("enabled", False):
55
+ return None
56
+ logger.info(
57
+ f"Configuring active learning for workspace: {project_metadata.workspace_id}, "
58
+ f"project: {project_metadata.dataset_id} of type: {project_metadata.dataset_type}. "
59
+ f"AL configuration: {project_metadata.active_learning_configuration}"
60
+ )
61
+ return initialise_active_learning_configuration(
62
+ project_metadata=project_metadata,
63
+ )
64
+
65
+
66
+ def prepare_active_learning_configuration_inplace(
67
+ api_key: str,
68
+ model_id: str,
69
+ active_learning_configuration: Optional[dict],
70
+ ) -> Optional[ActiveLearningConfiguration]:
71
+ if (
72
+ active_learning_configuration is None
73
+ or active_learning_configuration.get("enabled", False) is False
74
+ ):
75
+ return None
76
+ dataset_id, version_id = get_model_id_chunks(model_id=model_id)
77
+ workspace_id = get_roboflow_workspace(api_key=api_key)
78
+ dataset_type = get_roboflow_dataset_type(
79
+ api_key=api_key,
80
+ workspace_id=workspace_id,
81
+ dataset_id=dataset_id,
82
+ )
83
+ project_metadata = RoboflowProjectMetadata(
84
+ dataset_id=dataset_id,
85
+ version_id=version_id,
86
+ workspace_id=workspace_id,
87
+ dataset_type=dataset_type,
88
+ active_learning_configuration=active_learning_configuration,
89
+ )
90
+ return initialise_active_learning_configuration(
91
+ project_metadata=project_metadata,
92
+ )
93
+
94
+
95
+ def get_roboflow_project_metadata(
96
+ api_key: str,
97
+ model_id: str,
98
+ cache: BaseCache,
99
+ ) -> RoboflowProjectMetadata:
100
+ logger.info(f"Fetching active learning configuration.")
101
+ config_cache_key = construct_cache_key_for_active_learning_config(
102
+ api_key=api_key, model_id=model_id
103
+ )
104
+ cached_config = cache.get(config_cache_key)
105
+ if cached_config is not None:
106
+ logger.info("Found Active Learning configuration in cache.")
107
+ return parse_cached_roboflow_project_metadata(cached_config=cached_config)
108
+ dataset_id, version_id = get_model_id_chunks(model_id=model_id)
109
+ workspace_id = get_roboflow_workspace(api_key=api_key)
110
+ dataset_type = get_roboflow_dataset_type(
111
+ api_key=api_key,
112
+ workspace_id=workspace_id,
113
+ dataset_id=dataset_id,
114
+ )
115
+ try:
116
+ roboflow_api_configuration = get_roboflow_active_learning_configuration(
117
+ api_key=api_key, workspace_id=workspace_id, dataset_id=dataset_id
118
+ )
119
+ except (RoboflowAPINotAuthorizedError, RoboflowAPINotNotFoundError):
120
+ # currently backend returns HTTP 404 if dataset does not exist
121
+ # or workspace_id from api_key indicate that the owner is different,
122
+ # so in the situation when we query for Universe dataset.
123
+ # We want the owner of public dataset to be able to set AL configs
124
+ # and use them, but not other people. At this point it's known
125
+ # that HTTP 404 means not authorised (which will probably change
126
+ # in future iteration of backend) - so on both NotAuth and NotFound
127
+ # errors we assume that we simply cannot use AL with this model and
128
+ # this api_key.
129
+ roboflow_api_configuration = {"enabled": False}
130
+ configuration = RoboflowProjectMetadata(
131
+ dataset_id=dataset_id,
132
+ version_id=version_id,
133
+ workspace_id=workspace_id,
134
+ dataset_type=dataset_type,
135
+ active_learning_configuration=roboflow_api_configuration,
136
+ )
137
+ cache.set(
138
+ key=config_cache_key,
139
+ value=asdict(configuration),
140
+ expire=ACTIVE_LEARNING_CONFIG_CACHE_EXPIRE,
141
+ )
142
+ return configuration
143
+
144
+
145
+ def construct_cache_key_for_active_learning_config(api_key: str, model_id: str) -> str:
146
+ dataset_id = model_id.split("/")[0]
147
+ api_key_hash = hashlib.md5(api_key.encode("utf-8")).hexdigest()
148
+ return f"active_learning:configurations:{api_key_hash}:{dataset_id}"
149
+
150
+
151
+ def parse_cached_roboflow_project_metadata(
152
+ cached_config: dict,
153
+ ) -> RoboflowProjectMetadata:
154
+ try:
155
+ return RoboflowProjectMetadata(**cached_config)
156
+ except Exception as error:
157
+ raise ActiveLearningConfigurationDecodingError(
158
+ f"Failed to initialise Active Learning configuration. Cause: {str(error)}"
159
+ ) from error
160
+
161
+
162
+ def initialise_active_learning_configuration(
163
+ project_metadata: RoboflowProjectMetadata,
164
+ ) -> ActiveLearningConfiguration:
165
+ sampling_methods = initialize_sampling_methods(
166
+ sampling_strategies_configs=project_metadata.active_learning_configuration[
167
+ "sampling_strategies"
168
+ ],
169
+ )
170
+ target_workspace_id = project_metadata.active_learning_configuration.get(
171
+ "target_workspace", project_metadata.workspace_id
172
+ )
173
+ target_dataset_id = project_metadata.active_learning_configuration.get(
174
+ "target_project", project_metadata.dataset_id
175
+ )
176
+ return ActiveLearningConfiguration.init(
177
+ roboflow_api_configuration=project_metadata.active_learning_configuration,
178
+ sampling_methods=sampling_methods,
179
+ workspace_id=target_workspace_id,
180
+ dataset_id=target_dataset_id,
181
+ model_id=f"{project_metadata.dataset_id}/{project_metadata.version_id}",
182
+ )
183
+
184
+
185
+ def initialize_sampling_methods(
186
+ sampling_strategies_configs: List[Dict[str, Any]]
187
+ ) -> List[SamplingMethod]:
188
+ result = []
189
+ for sampling_strategy_config in sampling_strategies_configs:
190
+ sampling_type = sampling_strategy_config["type"]
191
+ if sampling_type not in TYPE2SAMPLING_INITIALIZERS:
192
+ logger.warn(
193
+ f"Could not identify sampling method `{sampling_type}` - skipping initialisation."
194
+ )
195
+ continue
196
+ initializer = TYPE2SAMPLING_INITIALIZERS[sampling_type]
197
+ result.append(initializer(sampling_strategy_config))
198
+ names = set(m.name for m in result)
199
+ if len(names) != len(result):
200
+ raise ActiveLearningConfigurationError(
201
+ "Detected duplication of Active Learning strategies names."
202
+ )
203
+ return result
inference/core/active_learning/core.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from typing import List, Optional, Tuple
3
+ from uuid import uuid4
4
+
5
+ import numpy as np
6
+
7
+ from inference.core import logger
8
+ from inference.core.active_learning.cache_operations import (
9
+ return_strategy_credit,
10
+ use_credit_of_matching_strategy,
11
+ )
12
+ from inference.core.active_learning.entities import (
13
+ ActiveLearningConfiguration,
14
+ ImageDimensions,
15
+ Prediction,
16
+ PredictionType,
17
+ SamplingMethod,
18
+ )
19
+ from inference.core.active_learning.post_processing import (
20
+ adjust_prediction_to_client_scaling_factor,
21
+ encode_prediction,
22
+ )
23
+ from inference.core.cache.base import BaseCache
24
+ from inference.core.env import ACTIVE_LEARNING_TAGS
25
+ from inference.core.roboflow_api import (
26
+ annotate_image_at_roboflow,
27
+ register_image_at_roboflow,
28
+ )
29
+ from inference.core.utils.image_utils import encode_image_to_jpeg_bytes
30
+ from inference.core.utils.preprocess import downscale_image_keeping_aspect_ratio
31
+
32
+
33
+ def execute_sampling(
34
+ image: np.ndarray,
35
+ prediction: Prediction,
36
+ prediction_type: PredictionType,
37
+ sampling_methods: List[SamplingMethod],
38
+ ) -> List[str]:
39
+ matching_strategies = []
40
+ for method in sampling_methods:
41
+ sampling_result = method.sample(image, prediction, prediction_type)
42
+ if sampling_result:
43
+ matching_strategies.append(method.name)
44
+ return matching_strategies
45
+
46
+
47
+ def execute_datapoint_registration(
48
+ cache: BaseCache,
49
+ matching_strategies: List[str],
50
+ image: np.ndarray,
51
+ prediction: Prediction,
52
+ prediction_type: PredictionType,
53
+ configuration: ActiveLearningConfiguration,
54
+ api_key: str,
55
+ batch_name: str,
56
+ ) -> None:
57
+ local_image_id = str(uuid4())
58
+ encoded_image, scaling_factor = prepare_image_to_registration(
59
+ image=image,
60
+ desired_size=configuration.max_image_size,
61
+ jpeg_compression_level=configuration.jpeg_compression_level,
62
+ )
63
+ prediction = adjust_prediction_to_client_scaling_factor(
64
+ prediction=prediction,
65
+ scaling_factor=scaling_factor,
66
+ prediction_type=prediction_type,
67
+ )
68
+ matching_strategies_limits = OrderedDict(
69
+ (strategy_name, configuration.strategies_limits[strategy_name])
70
+ for strategy_name in matching_strategies
71
+ )
72
+ strategy_with_spare_credit = use_credit_of_matching_strategy(
73
+ cache=cache,
74
+ workspace=configuration.workspace_id,
75
+ project=configuration.dataset_id,
76
+ matching_strategies_limits=matching_strategies_limits,
77
+ )
78
+ if strategy_with_spare_credit is None:
79
+ logger.debug(f"Limit on Active Learning strategy reached.")
80
+ return None
81
+ register_datapoint_at_roboflow(
82
+ cache=cache,
83
+ strategy_with_spare_credit=strategy_with_spare_credit,
84
+ encoded_image=encoded_image,
85
+ local_image_id=local_image_id,
86
+ prediction=prediction,
87
+ prediction_type=prediction_type,
88
+ configuration=configuration,
89
+ api_key=api_key,
90
+ batch_name=batch_name,
91
+ )
92
+
93
+
94
+ def prepare_image_to_registration(
95
+ image: np.ndarray,
96
+ desired_size: Optional[ImageDimensions],
97
+ jpeg_compression_level: int,
98
+ ) -> Tuple[bytes, float]:
99
+ scaling_factor = 1.0
100
+ if desired_size is not None:
101
+ height_before_scale = image.shape[0]
102
+ image = downscale_image_keeping_aspect_ratio(
103
+ image=image,
104
+ desired_size=desired_size.to_wh(),
105
+ )
106
+ scaling_factor = image.shape[0] / height_before_scale
107
+ return (
108
+ encode_image_to_jpeg_bytes(image=image, jpeg_quality=jpeg_compression_level),
109
+ scaling_factor,
110
+ )
111
+
112
+
113
+ def register_datapoint_at_roboflow(
114
+ cache: BaseCache,
115
+ strategy_with_spare_credit: str,
116
+ encoded_image: bytes,
117
+ local_image_id: str,
118
+ prediction: Prediction,
119
+ prediction_type: PredictionType,
120
+ configuration: ActiveLearningConfiguration,
121
+ api_key: str,
122
+ batch_name: str,
123
+ ) -> None:
124
+ tags = collect_tags(
125
+ configuration=configuration,
126
+ sampling_strategy=strategy_with_spare_credit,
127
+ )
128
+ roboflow_image_id = safe_register_image_at_roboflow(
129
+ cache=cache,
130
+ strategy_with_spare_credit=strategy_with_spare_credit,
131
+ encoded_image=encoded_image,
132
+ local_image_id=local_image_id,
133
+ configuration=configuration,
134
+ api_key=api_key,
135
+ batch_name=batch_name,
136
+ tags=tags,
137
+ )
138
+ if is_prediction_registration_forbidden(
139
+ prediction=prediction,
140
+ persist_predictions=configuration.persist_predictions,
141
+ roboflow_image_id=roboflow_image_id,
142
+ ):
143
+ return None
144
+ encoded_prediction, prediction_file_type = encode_prediction(
145
+ prediction=prediction, prediction_type=prediction_type
146
+ )
147
+ _ = annotate_image_at_roboflow(
148
+ api_key=api_key,
149
+ dataset_id=configuration.dataset_id,
150
+ local_image_id=local_image_id,
151
+ roboflow_image_id=roboflow_image_id,
152
+ annotation_content=encoded_prediction,
153
+ annotation_file_type=prediction_file_type,
154
+ is_prediction=True,
155
+ )
156
+
157
+
158
+ def collect_tags(
159
+ configuration: ActiveLearningConfiguration, sampling_strategy: str
160
+ ) -> List[str]:
161
+ tags = ACTIVE_LEARNING_TAGS if ACTIVE_LEARNING_TAGS is not None else []
162
+ tags.extend(configuration.tags)
163
+ tags.extend(configuration.strategies_tags[sampling_strategy])
164
+ if configuration.persist_predictions:
165
+ # this replacement is needed due to backend input validation
166
+ tags.append(configuration.model_id.replace("/", "-"))
167
+ return tags
168
+
169
+
170
+ def safe_register_image_at_roboflow(
171
+ cache: BaseCache,
172
+ strategy_with_spare_credit: str,
173
+ encoded_image: bytes,
174
+ local_image_id: str,
175
+ configuration: ActiveLearningConfiguration,
176
+ api_key: str,
177
+ batch_name: str,
178
+ tags: List[str],
179
+ ) -> Optional[str]:
180
+ credit_to_be_returned = False
181
+ try:
182
+ registration_response = register_image_at_roboflow(
183
+ api_key=api_key,
184
+ dataset_id=configuration.dataset_id,
185
+ local_image_id=local_image_id,
186
+ image_bytes=encoded_image,
187
+ batch_name=batch_name,
188
+ tags=tags,
189
+ )
190
+ image_duplicated = registration_response.get("duplicate", False)
191
+ if image_duplicated:
192
+ credit_to_be_returned = True
193
+ logger.warning(f"Image duplication detected: {registration_response}.")
194
+ return None
195
+ return registration_response["id"]
196
+ except Exception as error:
197
+ credit_to_be_returned = True
198
+ raise error
199
+ finally:
200
+ if credit_to_be_returned:
201
+ return_strategy_credit(
202
+ cache=cache,
203
+ workspace=configuration.workspace_id,
204
+ project=configuration.dataset_id,
205
+ strategy_name=strategy_with_spare_credit,
206
+ )
207
+
208
+
209
+ def is_prediction_registration_forbidden(
210
+ prediction: Prediction,
211
+ persist_predictions: bool,
212
+ roboflow_image_id: Optional[str],
213
+ ) -> bool:
214
+ return (
215
+ roboflow_image_id is None
216
+ or persist_predictions is False
217
+ or prediction.get("is_stub", False) is True
218
+ or (len(prediction.get("predictions", [])) == 0 and "top" not in prediction)
219
+ )
inference/core/active_learning/entities.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from enum import Enum
3
+ from typing import Any, Callable, Dict, List, Optional, Tuple
4
+
5
+ import numpy as np
6
+
7
+ from inference.core.entities.types import DatasetID, WorkspaceID
8
+ from inference.core.exceptions import ActiveLearningConfigurationDecodingError
9
+
10
+ LocalImageIdentifier = str
11
+ PredictionType = str
12
+ Prediction = dict
13
+ SerialisedPrediction = str
14
+ PredictionFileType = str
15
+
16
+
17
+ @dataclass(frozen=True)
18
+ class ImageDimensions:
19
+ height: int
20
+ width: int
21
+
22
+ def to_hw(self) -> Tuple[int, int]:
23
+ return self.height, self.width
24
+
25
+ def to_wh(self) -> Tuple[int, int]:
26
+ return self.width, self.height
27
+
28
+
29
+ @dataclass(frozen=True)
30
+ class SamplingMethod:
31
+ name: str
32
+ sample: Callable[[np.ndarray, Prediction, PredictionType], bool]
33
+
34
+
35
+ class BatchReCreationInterval(Enum):
36
+ NEVER = "never"
37
+ DAILY = "daily"
38
+ WEEKLY = "weekly"
39
+ MONTHLY = "monthly"
40
+
41
+
42
+ class StrategyLimitType(Enum):
43
+ MINUTELY = "minutely"
44
+ HOURLY = "hourly"
45
+ DAILY = "daily"
46
+
47
+
48
+ @dataclass(frozen=True)
49
+ class StrategyLimit:
50
+ limit_type: StrategyLimitType
51
+ value: int
52
+
53
+ @classmethod
54
+ def from_dict(cls, specification: dict) -> "StrategyLimit":
55
+ return cls(
56
+ limit_type=StrategyLimitType(specification["type"]),
57
+ value=specification["value"],
58
+ )
59
+
60
+
61
+ @dataclass(frozen=True)
62
+ class ActiveLearningConfiguration:
63
+ max_image_size: Optional[ImageDimensions]
64
+ jpeg_compression_level: int
65
+ persist_predictions: bool
66
+ sampling_methods: List[SamplingMethod]
67
+ batches_name_prefix: str
68
+ batch_recreation_interval: BatchReCreationInterval
69
+ max_batch_images: Optional[int]
70
+ workspace_id: WorkspaceID
71
+ dataset_id: DatasetID
72
+ model_id: str
73
+ strategies_limits: Dict[str, List[StrategyLimit]]
74
+ tags: List[str]
75
+ strategies_tags: Dict[str, List[str]]
76
+
77
+ @classmethod
78
+ def init(
79
+ cls,
80
+ roboflow_api_configuration: Dict[str, Any],
81
+ sampling_methods: List[SamplingMethod],
82
+ workspace_id: WorkspaceID,
83
+ dataset_id: DatasetID,
84
+ model_id: str,
85
+ ) -> "ActiveLearningConfiguration":
86
+ try:
87
+ max_image_size = roboflow_api_configuration.get("max_image_size")
88
+ if max_image_size is not None:
89
+ max_image_size = ImageDimensions(
90
+ height=roboflow_api_configuration["max_image_size"][0],
91
+ width=roboflow_api_configuration["max_image_size"][1],
92
+ )
93
+ strategies_limits = {
94
+ strategy["name"]: [
95
+ StrategyLimit.from_dict(specification=specification)
96
+ for specification in strategy.get("limits", [])
97
+ ]
98
+ for strategy in roboflow_api_configuration["sampling_strategies"]
99
+ }
100
+ strategies_tags = {
101
+ strategy["name"]: strategy.get("tags", [])
102
+ for strategy in roboflow_api_configuration["sampling_strategies"]
103
+ }
104
+ return cls(
105
+ max_image_size=max_image_size,
106
+ jpeg_compression_level=roboflow_api_configuration.get(
107
+ "jpeg_compression_level", 95
108
+ ),
109
+ persist_predictions=roboflow_api_configuration["persist_predictions"],
110
+ sampling_methods=sampling_methods,
111
+ batches_name_prefix=roboflow_api_configuration["batching_strategy"][
112
+ "batches_name_prefix"
113
+ ],
114
+ batch_recreation_interval=BatchReCreationInterval(
115
+ roboflow_api_configuration["batching_strategy"][
116
+ "recreation_interval"
117
+ ]
118
+ ),
119
+ max_batch_images=roboflow_api_configuration["batching_strategy"].get(
120
+ "max_batch_images"
121
+ ),
122
+ workspace_id=workspace_id,
123
+ dataset_id=dataset_id,
124
+ model_id=model_id,
125
+ strategies_limits=strategies_limits,
126
+ tags=roboflow_api_configuration.get("tags", []),
127
+ strategies_tags=strategies_tags,
128
+ )
129
+ except (KeyError, ValueError) as e:
130
+ raise ActiveLearningConfigurationDecodingError(
131
+ f"Failed to initialise Active Learning configuration. Cause: {str(e)}"
132
+ ) from e
133
+
134
+
135
+ @dataclass(frozen=True)
136
+ class RoboflowProjectMetadata:
137
+ dataset_id: DatasetID
138
+ version_id: str
139
+ workspace_id: WorkspaceID
140
+ dataset_type: str
141
+ active_learning_configuration: dict
inference/core/active_learning/middlewares.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import queue
2
+ from queue import Queue
3
+ from threading import Thread
4
+ from typing import Any, List, Optional
5
+
6
+ from inference.core import logger
7
+ from inference.core.active_learning.accounting import image_can_be_submitted_to_batch
8
+ from inference.core.active_learning.batching import generate_batch_name
9
+ from inference.core.active_learning.configuration import (
10
+ prepare_active_learning_configuration,
11
+ prepare_active_learning_configuration_inplace,
12
+ )
13
+ from inference.core.active_learning.core import (
14
+ execute_datapoint_registration,
15
+ execute_sampling,
16
+ )
17
+ from inference.core.active_learning.entities import (
18
+ ActiveLearningConfiguration,
19
+ Prediction,
20
+ PredictionType,
21
+ )
22
+ from inference.core.cache.base import BaseCache
23
+ from inference.core.utils.image_utils import load_image
24
+
25
+ MAX_REGISTRATION_QUEUE_SIZE = 512
26
+
27
+
28
+ class NullActiveLearningMiddleware:
29
+ def register_batch(
30
+ self,
31
+ inference_inputs: List[Any],
32
+ predictions: List[Prediction],
33
+ prediction_type: PredictionType,
34
+ disable_preproc_auto_orient: bool = False,
35
+ ) -> None:
36
+ pass
37
+
38
+ def register(
39
+ self,
40
+ inference_input: Any,
41
+ prediction: dict,
42
+ prediction_type: PredictionType,
43
+ disable_preproc_auto_orient: bool = False,
44
+ ) -> None:
45
+ pass
46
+
47
+ def start_registration_thread(self) -> None:
48
+ pass
49
+
50
+ def stop_registration_thread(self) -> None:
51
+ pass
52
+
53
+ def __enter__(self) -> "NullActiveLearningMiddleware":
54
+ return self
55
+
56
+ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
57
+ pass
58
+
59
+
60
+ class ActiveLearningMiddleware:
61
+ @classmethod
62
+ def init(
63
+ cls, api_key: str, model_id: str, cache: BaseCache
64
+ ) -> "ActiveLearningMiddleware":
65
+ configuration = prepare_active_learning_configuration(
66
+ api_key=api_key,
67
+ model_id=model_id,
68
+ cache=cache,
69
+ )
70
+ return cls(
71
+ api_key=api_key,
72
+ configuration=configuration,
73
+ cache=cache,
74
+ )
75
+
76
+ @classmethod
77
+ def init_from_config(
78
+ cls, api_key: str, model_id: str, cache: BaseCache, config: Optional[dict]
79
+ ) -> "ActiveLearningMiddleware":
80
+ configuration = prepare_active_learning_configuration_inplace(
81
+ api_key=api_key,
82
+ model_id=model_id,
83
+ active_learning_configuration=config,
84
+ )
85
+ return cls(
86
+ api_key=api_key,
87
+ configuration=configuration,
88
+ cache=cache,
89
+ )
90
+
91
+ def __init__(
92
+ self,
93
+ api_key: str,
94
+ configuration: Optional[ActiveLearningConfiguration],
95
+ cache: BaseCache,
96
+ ):
97
+ self._api_key = api_key
98
+ self._configuration = configuration
99
+ self._cache = cache
100
+
101
+ def register_batch(
102
+ self,
103
+ inference_inputs: List[Any],
104
+ predictions: List[Prediction],
105
+ prediction_type: PredictionType,
106
+ disable_preproc_auto_orient: bool = False,
107
+ ) -> None:
108
+ for inference_input, prediction in zip(inference_inputs, predictions):
109
+ self.register(
110
+ inference_input=inference_input,
111
+ prediction=prediction,
112
+ prediction_type=prediction_type,
113
+ disable_preproc_auto_orient=disable_preproc_auto_orient,
114
+ )
115
+
116
+ def register(
117
+ self,
118
+ inference_input: Any,
119
+ prediction: dict,
120
+ prediction_type: PredictionType,
121
+ disable_preproc_auto_orient: bool = False,
122
+ ) -> None:
123
+ self._execute_registration(
124
+ inference_input=inference_input,
125
+ prediction=prediction,
126
+ prediction_type=prediction_type,
127
+ disable_preproc_auto_orient=disable_preproc_auto_orient,
128
+ )
129
+
130
+ def _execute_registration(
131
+ self,
132
+ inference_input: Any,
133
+ prediction: dict,
134
+ prediction_type: PredictionType,
135
+ disable_preproc_auto_orient: bool = False,
136
+ ) -> None:
137
+ if self._configuration is None:
138
+ return None
139
+ image, is_bgr = load_image(
140
+ value=inference_input,
141
+ disable_preproc_auto_orient=disable_preproc_auto_orient,
142
+ )
143
+ if not is_bgr:
144
+ image = image[:, :, ::-1]
145
+ matching_strategies = execute_sampling(
146
+ image=image,
147
+ prediction=prediction,
148
+ prediction_type=prediction_type,
149
+ sampling_methods=self._configuration.sampling_methods,
150
+ )
151
+ if len(matching_strategies) == 0:
152
+ return None
153
+ batch_name = generate_batch_name(configuration=self._configuration)
154
+ if not image_can_be_submitted_to_batch(
155
+ batch_name=batch_name,
156
+ workspace_id=self._configuration.workspace_id,
157
+ dataset_id=self._configuration.dataset_id,
158
+ max_batch_images=self._configuration.max_batch_images,
159
+ api_key=self._api_key,
160
+ ):
161
+ logger.debug(f"Limit on Active Learning batch size reached.")
162
+ return None
163
+ execute_datapoint_registration(
164
+ cache=self._cache,
165
+ matching_strategies=matching_strategies,
166
+ image=image,
167
+ prediction=prediction,
168
+ prediction_type=prediction_type,
169
+ configuration=self._configuration,
170
+ api_key=self._api_key,
171
+ batch_name=batch_name,
172
+ )
173
+
174
+
175
+ class ThreadingActiveLearningMiddleware(ActiveLearningMiddleware):
176
+ @classmethod
177
+ def init(
178
+ cls,
179
+ api_key: str,
180
+ model_id: str,
181
+ cache: BaseCache,
182
+ max_queue_size: int = MAX_REGISTRATION_QUEUE_SIZE,
183
+ ) -> "ThreadingActiveLearningMiddleware":
184
+ configuration = prepare_active_learning_configuration(
185
+ api_key=api_key,
186
+ model_id=model_id,
187
+ cache=cache,
188
+ )
189
+ task_queue = Queue(max_queue_size)
190
+ return cls(
191
+ api_key=api_key,
192
+ configuration=configuration,
193
+ cache=cache,
194
+ task_queue=task_queue,
195
+ )
196
+
197
+ @classmethod
198
+ def init_from_config(
199
+ cls,
200
+ api_key: str,
201
+ model_id: str,
202
+ cache: BaseCache,
203
+ config: Optional[dict],
204
+ max_queue_size: int = MAX_REGISTRATION_QUEUE_SIZE,
205
+ ) -> "ThreadingActiveLearningMiddleware":
206
+ configuration = prepare_active_learning_configuration_inplace(
207
+ api_key=api_key,
208
+ model_id=model_id,
209
+ active_learning_configuration=config,
210
+ )
211
+ task_queue = Queue(max_queue_size)
212
+ return cls(
213
+ api_key=api_key,
214
+ configuration=configuration,
215
+ cache=cache,
216
+ task_queue=task_queue,
217
+ )
218
+
219
+ def __init__(
220
+ self,
221
+ api_key: str,
222
+ configuration: ActiveLearningConfiguration,
223
+ cache: BaseCache,
224
+ task_queue: Queue,
225
+ ):
226
+ super().__init__(api_key=api_key, configuration=configuration, cache=cache)
227
+ self._task_queue = task_queue
228
+ self._registration_thread: Optional[Thread] = None
229
+
230
+ def register(
231
+ self,
232
+ inference_input: Any,
233
+ prediction: dict,
234
+ prediction_type: PredictionType,
235
+ disable_preproc_auto_orient: bool = False,
236
+ ) -> None:
237
+ logger.debug(f"Putting registration task into queue")
238
+ try:
239
+ self._task_queue.put_nowait(
240
+ (
241
+ inference_input,
242
+ prediction,
243
+ prediction_type,
244
+ disable_preproc_auto_orient,
245
+ )
246
+ )
247
+ except queue.Full:
248
+ logger.warning(
249
+ f"Dropping datapoint registered in Active Learning due to insufficient processing "
250
+ f"capabilities."
251
+ )
252
+
253
+ def start_registration_thread(self) -> None:
254
+ if self._registration_thread is not None:
255
+ logger.warning(f"Registration thread already started.")
256
+ return None
257
+ logger.debug("Staring registration thread")
258
+ self._registration_thread = Thread(target=self._consume_queue)
259
+ self._registration_thread.start()
260
+
261
+ def stop_registration_thread(self) -> None:
262
+ if self._registration_thread is None:
263
+ logger.warning("Registration thread is already stopped.")
264
+ return None
265
+ logger.debug("Stopping registration thread")
266
+ self._task_queue.put(None)
267
+ self._registration_thread.join()
268
+ if self._registration_thread.is_alive():
269
+ logger.warning(f"Registration thread stopping was unsuccessful.")
270
+ self._registration_thread = None
271
+
272
+ def _consume_queue(self) -> None:
273
+ queue_closed = False
274
+ while not queue_closed:
275
+ queue_closed = self._consume_queue_task()
276
+
277
+ def _consume_queue_task(self) -> bool:
278
+ logger.debug("Consuming registration task")
279
+ task = self._task_queue.get()
280
+ logger.debug("Received registration task")
281
+ if task is None:
282
+ logger.debug("Terminating registration thread")
283
+ self._task_queue.task_done()
284
+ return True
285
+ inference_input, prediction, prediction_type, disable_preproc_auto_orient = task
286
+ try:
287
+ self._execute_registration(
288
+ inference_input=inference_input,
289
+ prediction=prediction,
290
+ prediction_type=prediction_type,
291
+ disable_preproc_auto_orient=disable_preproc_auto_orient,
292
+ )
293
+ except Exception as error:
294
+ # Error handling to be decided
295
+ logger.warning(
296
+ f"Error in datapoint registration for Active Learning. Details: {error}. "
297
+ f"Error is suppressed in favour of normal operations of registration thread."
298
+ )
299
+ self._task_queue.task_done()
300
+ return False
301
+
302
+ def __enter__(self) -> "ThreadingActiveLearningMiddleware":
303
+ self.start_registration_thread()
304
+ return self
305
+
306
+ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
307
+ self.stop_registration_thread()
inference/core/active_learning/post_processing.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import List, Tuple
3
+
4
+ from inference.core.active_learning.entities import (
5
+ Prediction,
6
+ PredictionFileType,
7
+ PredictionType,
8
+ SerialisedPrediction,
9
+ )
10
+ from inference.core.constants import (
11
+ CLASSIFICATION_TASK,
12
+ INSTANCE_SEGMENTATION_TASK,
13
+ OBJECT_DETECTION_TASK,
14
+ )
15
+ from inference.core.exceptions import PredictionFormatNotSupported
16
+
17
+
18
+ def adjust_prediction_to_client_scaling_factor(
19
+ prediction: dict, scaling_factor: float, prediction_type: PredictionType
20
+ ) -> dict:
21
+ if abs(scaling_factor - 1.0) < 1e-5:
22
+ return prediction
23
+ if "image" in prediction:
24
+ prediction["image"] = {
25
+ "width": round(prediction["image"]["width"] / scaling_factor),
26
+ "height": round(prediction["image"]["height"] / scaling_factor),
27
+ }
28
+ if predictions_should_not_be_post_processed(
29
+ prediction=prediction, prediction_type=prediction_type
30
+ ):
31
+ return prediction
32
+ if prediction_type == INSTANCE_SEGMENTATION_TASK:
33
+ prediction["predictions"] = (
34
+ adjust_prediction_with_bbox_and_points_to_client_scaling_factor(
35
+ predictions=prediction["predictions"],
36
+ scaling_factor=scaling_factor,
37
+ points_key="points",
38
+ )
39
+ )
40
+ if prediction_type == OBJECT_DETECTION_TASK:
41
+ prediction["predictions"] = (
42
+ adjust_object_detection_predictions_to_client_scaling_factor(
43
+ predictions=prediction["predictions"],
44
+ scaling_factor=scaling_factor,
45
+ )
46
+ )
47
+ return prediction
48
+
49
+
50
+ def predictions_should_not_be_post_processed(
51
+ prediction: dict, prediction_type: PredictionType
52
+ ) -> bool:
53
+ # excluding from post-processing classification output, stub-output and empty predictions
54
+ return (
55
+ "is_stub" in prediction
56
+ or "predictions" not in prediction
57
+ or CLASSIFICATION_TASK in prediction_type
58
+ or len(prediction["predictions"]) == 0
59
+ )
60
+
61
+
62
+ def adjust_object_detection_predictions_to_client_scaling_factor(
63
+ predictions: List[dict],
64
+ scaling_factor: float,
65
+ ) -> List[dict]:
66
+ result = []
67
+ for prediction in predictions:
68
+ prediction = adjust_bbox_coordinates_to_client_scaling_factor(
69
+ bbox=prediction,
70
+ scaling_factor=scaling_factor,
71
+ )
72
+ result.append(prediction)
73
+ return result
74
+
75
+
76
+ def adjust_prediction_with_bbox_and_points_to_client_scaling_factor(
77
+ predictions: List[dict],
78
+ scaling_factor: float,
79
+ points_key: str,
80
+ ) -> List[dict]:
81
+ result = []
82
+ for prediction in predictions:
83
+ prediction = adjust_bbox_coordinates_to_client_scaling_factor(
84
+ bbox=prediction,
85
+ scaling_factor=scaling_factor,
86
+ )
87
+ prediction[points_key] = adjust_points_coordinates_to_client_scaling_factor(
88
+ points=prediction[points_key],
89
+ scaling_factor=scaling_factor,
90
+ )
91
+ result.append(prediction)
92
+ return result
93
+
94
+
95
+ def adjust_bbox_coordinates_to_client_scaling_factor(
96
+ bbox: dict,
97
+ scaling_factor: float,
98
+ ) -> dict:
99
+ bbox["x"] = bbox["x"] / scaling_factor
100
+ bbox["y"] = bbox["y"] / scaling_factor
101
+ bbox["width"] = bbox["width"] / scaling_factor
102
+ bbox["height"] = bbox["height"] / scaling_factor
103
+ return bbox
104
+
105
+
106
+ def adjust_points_coordinates_to_client_scaling_factor(
107
+ points: List[dict],
108
+ scaling_factor: float,
109
+ ) -> List[dict]:
110
+ result = []
111
+ for point in points:
112
+ point["x"] = point["x"] / scaling_factor
113
+ point["y"] = point["y"] / scaling_factor
114
+ result.append(point)
115
+ return result
116
+
117
+
118
+ def encode_prediction(
119
+ prediction: Prediction,
120
+ prediction_type: PredictionType,
121
+ ) -> Tuple[SerialisedPrediction, PredictionFileType]:
122
+ if CLASSIFICATION_TASK not in prediction_type:
123
+ return json.dumps(prediction), "json"
124
+ if "top" in prediction:
125
+ return prediction["top"], "txt"
126
+ raise PredictionFormatNotSupported(
127
+ f"Prediction type or prediction format not supported."
128
+ )
inference/core/active_learning/samplers/__init__.py ADDED
File without changes
inference/core/active_learning/samplers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (201 Bytes). View file
 
inference/core/active_learning/samplers/__pycache__/close_to_threshold.cpython-310.pyc ADDED
Binary file (4.68 kB). View file
 
inference/core/active_learning/samplers/__pycache__/contains_classes.cpython-310.pyc ADDED
Binary file (1.71 kB). View file
 
inference/core/active_learning/samplers/__pycache__/number_of_detections.cpython-310.pyc ADDED
Binary file (2.74 kB). View file
 
inference/core/active_learning/samplers/__pycache__/random.cpython-310.pyc ADDED
Binary file (1.22 kB). View file
 
inference/core/active_learning/samplers/close_to_threshold.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from functools import partial
3
+ from typing import Any, Dict, Optional, Set
4
+
5
+ import numpy as np
6
+
7
+ from inference.core.active_learning.entities import (
8
+ Prediction,
9
+ PredictionType,
10
+ SamplingMethod,
11
+ )
12
+ from inference.core.constants import (
13
+ CLASSIFICATION_TASK,
14
+ INSTANCE_SEGMENTATION_TASK,
15
+ KEYPOINTS_DETECTION_TASK,
16
+ OBJECT_DETECTION_TASK,
17
+ )
18
+ from inference.core.exceptions import ActiveLearningConfigurationError
19
+
20
+ ELIGIBLE_PREDICTION_TYPES = {
21
+ CLASSIFICATION_TASK,
22
+ INSTANCE_SEGMENTATION_TASK,
23
+ KEYPOINTS_DETECTION_TASK,
24
+ OBJECT_DETECTION_TASK,
25
+ }
26
+
27
+
28
+ def initialize_close_to_threshold_sampling(
29
+ strategy_config: Dict[str, Any]
30
+ ) -> SamplingMethod:
31
+ try:
32
+ selected_class_names = strategy_config.get("selected_class_names")
33
+ if selected_class_names is not None:
34
+ selected_class_names = set(selected_class_names)
35
+ sample_function = partial(
36
+ sample_close_to_threshold,
37
+ selected_class_names=selected_class_names,
38
+ threshold=strategy_config["threshold"],
39
+ epsilon=strategy_config["epsilon"],
40
+ only_top_classes=strategy_config.get("only_top_classes", True),
41
+ minimum_objects_close_to_threshold=strategy_config.get(
42
+ "minimum_objects_close_to_threshold",
43
+ 1,
44
+ ),
45
+ probability=strategy_config["probability"],
46
+ )
47
+ return SamplingMethod(
48
+ name=strategy_config["name"],
49
+ sample=sample_function,
50
+ )
51
+ except KeyError as error:
52
+ raise ActiveLearningConfigurationError(
53
+ f"In configuration of `close_to_threshold_sampling` missing key detected: {error}."
54
+ ) from error
55
+
56
+
57
+ def sample_close_to_threshold(
58
+ image: np.ndarray,
59
+ prediction: Prediction,
60
+ prediction_type: PredictionType,
61
+ selected_class_names: Optional[Set[str]],
62
+ threshold: float,
63
+ epsilon: float,
64
+ only_top_classes: bool,
65
+ minimum_objects_close_to_threshold: int,
66
+ probability: float,
67
+ ) -> bool:
68
+ if is_prediction_a_stub(prediction=prediction):
69
+ return False
70
+ if prediction_type not in ELIGIBLE_PREDICTION_TYPES:
71
+ return False
72
+ close_to_threshold = prediction_is_close_to_threshold(
73
+ prediction=prediction,
74
+ prediction_type=prediction_type,
75
+ selected_class_names=selected_class_names,
76
+ threshold=threshold,
77
+ epsilon=epsilon,
78
+ only_top_classes=only_top_classes,
79
+ minimum_objects_close_to_threshold=minimum_objects_close_to_threshold,
80
+ )
81
+ if not close_to_threshold:
82
+ return False
83
+ return random.random() < probability
84
+
85
+
86
+ def is_prediction_a_stub(prediction: Prediction) -> bool:
87
+ return prediction.get("is_stub", False)
88
+
89
+
90
+ def prediction_is_close_to_threshold(
91
+ prediction: Prediction,
92
+ prediction_type: PredictionType,
93
+ selected_class_names: Optional[Set[str]],
94
+ threshold: float,
95
+ epsilon: float,
96
+ only_top_classes: bool,
97
+ minimum_objects_close_to_threshold: int,
98
+ ) -> bool:
99
+ if CLASSIFICATION_TASK not in prediction_type:
100
+ return detections_are_close_to_threshold(
101
+ prediction=prediction,
102
+ selected_class_names=selected_class_names,
103
+ threshold=threshold,
104
+ epsilon=epsilon,
105
+ minimum_objects_close_to_threshold=minimum_objects_close_to_threshold,
106
+ )
107
+ checker = multi_label_classification_prediction_is_close_to_threshold
108
+ if "top" in prediction:
109
+ checker = multi_class_classification_prediction_is_close_to_threshold
110
+ return checker(
111
+ prediction=prediction,
112
+ selected_class_names=selected_class_names,
113
+ threshold=threshold,
114
+ epsilon=epsilon,
115
+ only_top_classes=only_top_classes,
116
+ )
117
+
118
+
119
+ def multi_class_classification_prediction_is_close_to_threshold(
120
+ prediction: Prediction,
121
+ selected_class_names: Optional[Set[str]],
122
+ threshold: float,
123
+ epsilon: float,
124
+ only_top_classes: bool,
125
+ ) -> bool:
126
+ if only_top_classes:
127
+ return (
128
+ multi_class_classification_prediction_is_close_to_threshold_for_top_class(
129
+ prediction=prediction,
130
+ selected_class_names=selected_class_names,
131
+ threshold=threshold,
132
+ epsilon=epsilon,
133
+ )
134
+ )
135
+ for prediction_details in prediction["predictions"]:
136
+ if class_to_be_excluded(
137
+ class_name=prediction_details["class"],
138
+ selected_class_names=selected_class_names,
139
+ ):
140
+ continue
141
+ if is_close_to_threshold(
142
+ value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon
143
+ ):
144
+ return True
145
+ return False
146
+
147
+
148
+ def multi_class_classification_prediction_is_close_to_threshold_for_top_class(
149
+ prediction: Prediction,
150
+ selected_class_names: Optional[Set[str]],
151
+ threshold: float,
152
+ epsilon: float,
153
+ ) -> bool:
154
+ if (
155
+ selected_class_names is not None
156
+ and prediction["top"] not in selected_class_names
157
+ ):
158
+ return False
159
+ return abs(prediction["confidence"] - threshold) < epsilon
160
+
161
+
162
+ def multi_label_classification_prediction_is_close_to_threshold(
163
+ prediction: Prediction,
164
+ selected_class_names: Optional[Set[str]],
165
+ threshold: float,
166
+ epsilon: float,
167
+ only_top_classes: bool,
168
+ ) -> bool:
169
+ predicted_classes = set(prediction["predicted_classes"])
170
+ for class_name, prediction_details in prediction["predictions"].items():
171
+ if only_top_classes and class_name not in predicted_classes:
172
+ continue
173
+ if class_to_be_excluded(
174
+ class_name=class_name, selected_class_names=selected_class_names
175
+ ):
176
+ continue
177
+ if is_close_to_threshold(
178
+ value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon
179
+ ):
180
+ return True
181
+ return False
182
+
183
+
184
+ def detections_are_close_to_threshold(
185
+ prediction: Prediction,
186
+ selected_class_names: Optional[Set[str]],
187
+ threshold: float,
188
+ epsilon: float,
189
+ minimum_objects_close_to_threshold: int,
190
+ ) -> bool:
191
+ detections_close_to_threshold = count_detections_close_to_threshold(
192
+ prediction=prediction,
193
+ selected_class_names=selected_class_names,
194
+ threshold=threshold,
195
+ epsilon=epsilon,
196
+ )
197
+ return detections_close_to_threshold >= minimum_objects_close_to_threshold
198
+
199
+
200
+ def count_detections_close_to_threshold(
201
+ prediction: Prediction,
202
+ selected_class_names: Optional[Set[str]],
203
+ threshold: float,
204
+ epsilon: float,
205
+ ) -> int:
206
+ counter = 0
207
+ for prediction_details in prediction["predictions"]:
208
+ if class_to_be_excluded(
209
+ class_name=prediction_details["class"],
210
+ selected_class_names=selected_class_names,
211
+ ):
212
+ continue
213
+ if is_close_to_threshold(
214
+ value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon
215
+ ):
216
+ counter += 1
217
+ return counter
218
+
219
+
220
+ def class_to_be_excluded(
221
+ class_name: str, selected_class_names: Optional[Set[str]]
222
+ ) -> bool:
223
+ return selected_class_names is not None and class_name not in selected_class_names
224
+
225
+
226
+ def is_close_to_threshold(value: float, threshold: float, epsilon: float) -> bool:
227
+ return abs(value - threshold) < epsilon
inference/core/active_learning/samplers/contains_classes.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from typing import Any, Dict, Set
3
+
4
+ import numpy as np
5
+
6
+ from inference.core.active_learning.entities import (
7
+ Prediction,
8
+ PredictionType,
9
+ SamplingMethod,
10
+ )
11
+ from inference.core.active_learning.samplers.close_to_threshold import (
12
+ sample_close_to_threshold,
13
+ )
14
+ from inference.core.constants import CLASSIFICATION_TASK
15
+ from inference.core.exceptions import ActiveLearningConfigurationError
16
+
17
+ ELIGIBLE_PREDICTION_TYPES = {CLASSIFICATION_TASK}
18
+
19
+
20
+ def initialize_classes_based_sampling(
21
+ strategy_config: Dict[str, Any]
22
+ ) -> SamplingMethod:
23
+ try:
24
+ sample_function = partial(
25
+ sample_based_on_classes,
26
+ selected_class_names=set(strategy_config["selected_class_names"]),
27
+ probability=strategy_config["probability"],
28
+ )
29
+ return SamplingMethod(
30
+ name=strategy_config["name"],
31
+ sample=sample_function,
32
+ )
33
+ except KeyError as error:
34
+ raise ActiveLearningConfigurationError(
35
+ f"In configuration of `classes_based_sampling` missing key detected: {error}."
36
+ ) from error
37
+
38
+
39
+ def sample_based_on_classes(
40
+ image: np.ndarray,
41
+ prediction: Prediction,
42
+ prediction_type: PredictionType,
43
+ selected_class_names: Set[str],
44
+ probability: float,
45
+ ) -> bool:
46
+ if prediction_type not in ELIGIBLE_PREDICTION_TYPES:
47
+ return False
48
+ return sample_close_to_threshold(
49
+ image=image,
50
+ prediction=prediction,
51
+ prediction_type=prediction_type,
52
+ selected_class_names=selected_class_names,
53
+ threshold=0.5,
54
+ epsilon=1.0,
55
+ only_top_classes=True,
56
+ minimum_objects_close_to_threshold=1,
57
+ probability=probability,
58
+ )
inference/core/active_learning/samplers/number_of_detections.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from functools import partial
3
+ from typing import Any, Dict, Optional, Set
4
+
5
+ import numpy as np
6
+
7
+ from inference.core.active_learning.entities import (
8
+ Prediction,
9
+ PredictionType,
10
+ SamplingMethod,
11
+ )
12
+ from inference.core.active_learning.samplers.close_to_threshold import (
13
+ count_detections_close_to_threshold,
14
+ is_prediction_a_stub,
15
+ )
16
+ from inference.core.constants import (
17
+ INSTANCE_SEGMENTATION_TASK,
18
+ KEYPOINTS_DETECTION_TASK,
19
+ OBJECT_DETECTION_TASK,
20
+ )
21
+ from inference.core.exceptions import ActiveLearningConfigurationError
22
+
23
+ ELIGIBLE_PREDICTION_TYPES = {
24
+ INSTANCE_SEGMENTATION_TASK,
25
+ KEYPOINTS_DETECTION_TASK,
26
+ OBJECT_DETECTION_TASK,
27
+ }
28
+
29
+
30
+ def initialize_detections_number_based_sampling(
31
+ strategy_config: Dict[str, Any]
32
+ ) -> SamplingMethod:
33
+ try:
34
+ more_than = strategy_config.get("more_than")
35
+ less_than = strategy_config.get("less_than")
36
+ ensure_range_configuration_is_valid(more_than=more_than, less_than=less_than)
37
+ selected_class_names = strategy_config.get("selected_class_names")
38
+ if selected_class_names is not None:
39
+ selected_class_names = set(selected_class_names)
40
+ sample_function = partial(
41
+ sample_based_on_detections_number,
42
+ less_than=less_than,
43
+ more_than=more_than,
44
+ selected_class_names=selected_class_names,
45
+ probability=strategy_config["probability"],
46
+ )
47
+ return SamplingMethod(
48
+ name=strategy_config["name"],
49
+ sample=sample_function,
50
+ )
51
+ except KeyError as error:
52
+ raise ActiveLearningConfigurationError(
53
+ f"In configuration of `detections_number_based_sampling` missing key detected: {error}."
54
+ ) from error
55
+
56
+
57
+ def sample_based_on_detections_number(
58
+ image: np.ndarray,
59
+ prediction: Prediction,
60
+ prediction_type: PredictionType,
61
+ more_than: Optional[int],
62
+ less_than: Optional[int],
63
+ selected_class_names: Optional[Set[str]],
64
+ probability: float,
65
+ ) -> bool:
66
+ if is_prediction_a_stub(prediction=prediction):
67
+ return False
68
+ if prediction_type not in ELIGIBLE_PREDICTION_TYPES:
69
+ return False
70
+ detections_close_to_threshold = count_detections_close_to_threshold(
71
+ prediction=prediction,
72
+ selected_class_names=selected_class_names,
73
+ threshold=0.5,
74
+ epsilon=1.0,
75
+ )
76
+ if is_in_range(
77
+ value=detections_close_to_threshold, less_than=less_than, more_than=more_than
78
+ ):
79
+ return random.random() < probability
80
+ return False
81
+
82
+
83
+ def is_in_range(
84
+ value: int,
85
+ more_than: Optional[int],
86
+ less_than: Optional[int],
87
+ ) -> bool:
88
+ # calculates value > more_than and value < less_than, with optional borders of range
89
+ less_than_satisfied, more_than_satisfied = less_than is None, more_than is None
90
+ if less_than is not None and value < less_than:
91
+ less_than_satisfied = True
92
+ if more_than is not None and value > more_than:
93
+ more_than_satisfied = True
94
+ return less_than_satisfied and more_than_satisfied
95
+
96
+
97
+ def ensure_range_configuration_is_valid(
98
+ more_than: Optional[int],
99
+ less_than: Optional[int],
100
+ ) -> None:
101
+ if more_than is None or less_than is None:
102
+ return None
103
+ if more_than >= less_than:
104
+ raise ActiveLearningConfigurationError(
105
+ f"Misconfiguration of detections number sampling: "
106
+ f"`more_than` parameter ({more_than}) >= `less_than` ({less_than})."
107
+ )
inference/core/active_learning/samplers/random.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from functools import partial
3
+ from typing import Any, Dict
4
+
5
+ import numpy as np
6
+
7
+ from inference.core.active_learning.entities import (
8
+ Prediction,
9
+ PredictionType,
10
+ SamplingMethod,
11
+ )
12
+ from inference.core.exceptions import ActiveLearningConfigurationError
13
+
14
+
15
+ def initialize_random_sampling(strategy_config: Dict[str, Any]) -> SamplingMethod:
16
+ try:
17
+ sample_function = partial(
18
+ sample_randomly,
19
+ traffic_percentage=strategy_config["traffic_percentage"],
20
+ )
21
+ return SamplingMethod(
22
+ name=strategy_config["name"],
23
+ sample=sample_function,
24
+ )
25
+ except KeyError as error:
26
+ raise ActiveLearningConfigurationError(
27
+ f"In configuration of `random_sampling` missing key detected: {error}."
28
+ ) from error
29
+
30
+
31
+ def sample_randomly(
32
+ image: np.ndarray,
33
+ prediction: Prediction,
34
+ prediction_type: PredictionType,
35
+ traffic_percentage: float,
36
+ ) -> bool:
37
+ return random.random() < traffic_percentage
inference/core/active_learning/utils.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime, timedelta
2
+
3
+ TIMESTAMP_FORMAT = "%Y_%m_%d"
4
+
5
+
6
+ def generate_today_timestamp() -> str:
7
+ return datetime.today().strftime(TIMESTAMP_FORMAT)
8
+
9
+
10
+ def generate_start_timestamp_for_this_week() -> str:
11
+ today = datetime.today()
12
+ return (today - timedelta(days=today.weekday())).strftime(TIMESTAMP_FORMAT)
13
+
14
+
15
+ def generate_start_timestamp_for_this_month() -> str:
16
+ return datetime.today().replace(day=1).strftime(TIMESTAMP_FORMAT)
inference/core/cache/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from redis.exceptions import ConnectionError, TimeoutError
2
+
3
+ from inference.core import logger
4
+ from inference.core.cache.memory import MemoryCache
5
+ from inference.core.cache.redis import RedisCache
6
+ from inference.core.env import REDIS_HOST, REDIS_PORT, REDIS_SSL, REDIS_TIMEOUT
7
+
8
+ if REDIS_HOST is not None:
9
+ try:
10
+ cache = RedisCache(
11
+ host=REDIS_HOST, port=REDIS_PORT, ssl=REDIS_SSL, timeout=REDIS_TIMEOUT
12
+ )
13
+ logger.info(f"Redis Cache initialised")
14
+ except (TimeoutError, ConnectionError):
15
+ logger.error(
16
+ f"Could not connect to Redis under {REDIS_HOST}:{REDIS_PORT}. MemoryCache to be used."
17
+ )
18
+ cache = MemoryCache()
19
+ logger.info(f"Memory Cache initialised")
20
+ else:
21
+ cache = MemoryCache()
22
+ logger.info(f"Memory Cache initialised")
inference/core/cache/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (864 Bytes). View file
 
inference/core/cache/__pycache__/base.cpython-310.pyc ADDED
Binary file (4.93 kB). View file
 
inference/core/cache/__pycache__/memory.cpython-310.pyc ADDED
Binary file (6.56 kB). View file
 
inference/core/cache/__pycache__/model_artifacts.cpython-310.pyc ADDED
Binary file (3.17 kB). View file
 
inference/core/cache/__pycache__/redis.cpython-310.pyc ADDED
Binary file (7.3 kB). View file
 
inference/core/cache/__pycache__/serializers.cpython-310.pyc ADDED
Binary file (1.91 kB). View file
 
inference/core/cache/base.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import contextmanager
2
+ from typing import Any, Optional
3
+
4
+ from inference.core import logger
5
+
6
+
7
+ class BaseCache:
8
+ """
9
+ BaseCache is an abstract base class that defines the interface for a cache.
10
+ """
11
+
12
+ def get(self, key: str):
13
+ """
14
+ Gets the value associated with the given key.
15
+
16
+ Args:
17
+ key (str): The key to retrieve the value.
18
+
19
+ Raises:
20
+ NotImplementedError: This method must be implemented by subclasses.
21
+ """
22
+ raise NotImplementedError()
23
+
24
+ def set(self, key: str, value: str, expire: float = None):
25
+ """
26
+ Sets a value for a given key with an optional expire time.
27
+
28
+ Args:
29
+ key (str): The key to store the value.
30
+ value (str): The value to store.
31
+ expire (float, optional): The time, in seconds, after which the key will expire. Defaults to None.
32
+
33
+ Raises:
34
+ NotImplementedError: This method must be implemented by subclasses.
35
+ """
36
+ raise NotImplementedError()
37
+
38
+ def zadd(self, key: str, value: str, score: float, expire: float = None):
39
+ """
40
+ Adds a member with the specified score to the sorted set stored at key.
41
+
42
+ Args:
43
+ key (str): The key of the sorted set.
44
+ value (str): The value to add to the sorted set.
45
+ score (float): The score associated with the value.
46
+ expire (float, optional): The time, in seconds, after which the key will expire. Defaults to None.
47
+
48
+ Raises:
49
+ NotImplementedError: This method must be implemented by subclasses.
50
+ """
51
+ raise NotImplementedError()
52
+
53
+ def zrangebyscore(
54
+ self,
55
+ key: str,
56
+ min: Optional[float] = -1,
57
+ max: Optional[float] = float("inf"),
58
+ withscores: bool = False,
59
+ ):
60
+ """
61
+ Retrieves a range of members from a sorted set.
62
+
63
+ Args:
64
+ key (str): The key of the sorted set.
65
+ start (int, optional): The starting index of the range. Defaults to -1.
66
+ stop (int, optional): The ending index of the range. Defaults to float("inf").
67
+ withscores (bool, optional): Whether to return the scores along with the values. Defaults to False.
68
+
69
+ Raises:
70
+ NotImplementedError: This method must be implemented by subclasses.
71
+ """
72
+ raise NotImplementedError()
73
+
74
+ def zremrangebyscore(
75
+ self,
76
+ key: str,
77
+ start: Optional[int] = -1,
78
+ stop: Optional[int] = float("inf"),
79
+ ):
80
+ """
81
+ Removes all members in a sorted set within the given scores.
82
+
83
+ Args:
84
+ key (str): The key of the sorted set.
85
+ start (int, optional): The minimum score of the range. Defaults to -1.
86
+ stop (int, optional): The maximum score of the range. Defaults to float("inf").
87
+
88
+ Raises:
89
+ NotImplementedError: This method must be implemented by subclasses.
90
+ """
91
+ raise NotImplementedError()
92
+
93
+ def acquire_lock(self, key: str, expire: float = None) -> Any:
94
+ raise NotImplementedError()
95
+
96
+ @contextmanager
97
+ def lock(self, key: str, expire: float = None) -> Any:
98
+ logger.debug(f"Acquiring lock at cache key: {key}")
99
+ l = self.acquire_lock(key, expire=expire)
100
+ try:
101
+ yield l
102
+ finally:
103
+ logger.debug(f"Releasing lock at cache key: {key}")
104
+ l.release()
105
+
106
+ def set_numpy(self, key: str, value: Any, expire: float = None):
107
+ """
108
+ Caches a numpy array.
109
+
110
+ Args:
111
+ key (str): The key to store the value.
112
+ value (Any): The value to store.
113
+ expire (float, optional): The time, in seconds, after which the key will expire. Defaults to None.
114
+
115
+ Raises:
116
+ NotImplementedError: This method must be implemented by subclasses.
117
+ """
118
+ raise NotImplementedError()
119
+
120
+ def get_numpy(self, key: str) -> Any:
121
+ """
122
+ Retrieves a numpy array from the cache.
123
+
124
+ Args:
125
+ key (str): The key of the value to retrieve.
126
+
127
+ Raises:
128
+ NotImplementedError: This method must be implemented by subclasses.
129
+ """
130
+ raise NotImplementedError()