Upload 422 files
This view is limited to 50 files because it contains too many changes. See raw diff.
- inference/__init__.py +3 -0
- inference/__pycache__/__init__.cpython-310.pyc +0 -0
- inference/core/__init__.py +52 -0
- inference/core/__pycache__/__init__.cpython-310.pyc +0 -0
- inference/core/__pycache__/constants.cpython-310.pyc +0 -0
- inference/core/__pycache__/env.cpython-310.pyc +0 -0
- inference/core/__pycache__/exceptions.cpython-310.pyc +0 -0
- inference/core/__pycache__/logger.cpython-310.pyc +0 -0
- inference/core/__pycache__/nms.cpython-310.pyc +0 -0
- inference/core/__pycache__/roboflow_api.cpython-310.pyc +0 -0
- inference/core/__pycache__/usage.cpython-310.pyc +0 -0
- inference/core/__pycache__/version.cpython-310.pyc +0 -0
- inference/core/active_learning/__init__.py +0 -0
- inference/core/active_learning/__pycache__/__init__.cpython-310.pyc +0 -0
- inference/core/active_learning/__pycache__/accounting.cpython-310.pyc +0 -0
- inference/core/active_learning/__pycache__/batching.cpython-310.pyc +0 -0
- inference/core/active_learning/__pycache__/cache_operations.cpython-310.pyc +0 -0
- inference/core/active_learning/__pycache__/configuration.cpython-310.pyc +0 -0
- inference/core/active_learning/__pycache__/core.cpython-310.pyc +0 -0
- inference/core/active_learning/__pycache__/entities.cpython-310.pyc +0 -0
- inference/core/active_learning/__pycache__/middlewares.cpython-310.pyc +0 -0
- inference/core/active_learning/__pycache__/post_processing.cpython-310.pyc +0 -0
- inference/core/active_learning/__pycache__/utils.cpython-310.pyc +0 -0
- inference/core/active_learning/accounting.py +96 -0
- inference/core/active_learning/batching.py +26 -0
- inference/core/active_learning/cache_operations.py +293 -0
- inference/core/active_learning/configuration.py +203 -0
- inference/core/active_learning/core.py +219 -0
- inference/core/active_learning/entities.py +141 -0
- inference/core/active_learning/middlewares.py +307 -0
- inference/core/active_learning/post_processing.py +128 -0
- inference/core/active_learning/samplers/__init__.py +0 -0
- inference/core/active_learning/samplers/__pycache__/__init__.cpython-310.pyc +0 -0
- inference/core/active_learning/samplers/__pycache__/close_to_threshold.cpython-310.pyc +0 -0
- inference/core/active_learning/samplers/__pycache__/contains_classes.cpython-310.pyc +0 -0
- inference/core/active_learning/samplers/__pycache__/number_of_detections.cpython-310.pyc +0 -0
- inference/core/active_learning/samplers/__pycache__/random.cpython-310.pyc +0 -0
- inference/core/active_learning/samplers/close_to_threshold.py +227 -0
- inference/core/active_learning/samplers/contains_classes.py +58 -0
- inference/core/active_learning/samplers/number_of_detections.py +107 -0
- inference/core/active_learning/samplers/random.py +37 -0
- inference/core/active_learning/utils.py +16 -0
- inference/core/cache/__init__.py +22 -0
- inference/core/cache/__pycache__/__init__.cpython-310.pyc +0 -0
- inference/core/cache/__pycache__/base.cpython-310.pyc +0 -0
- inference/core/cache/__pycache__/memory.cpython-310.pyc +0 -0
- inference/core/cache/__pycache__/model_artifacts.cpython-310.pyc +0 -0
- inference/core/cache/__pycache__/redis.cpython-310.pyc +0 -0
- inference/core/cache/__pycache__/serializers.cpython-310.pyc +0 -0
- inference/core/cache/base.py +130 -0
inference/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from inference.core.interfaces.stream.stream import Stream  # isort:skip
+from inference.core.interfaces.stream.inference_pipeline import InferencePipeline
+from inference.models.utils import get_roboflow_model
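Usage note (not part of the diff): the three exports above are the package's public entry points. A minimal sketch of loading a model through get_roboflow_model, assuming it accepts a model id and an API key (both values below are placeholders):

from inference.models.utils import get_roboflow_model

# Hypothetical identifiers - substitute a real Roboflow model id and API key.
model = get_roboflow_model(model_id="my-dataset/1", api_key="<ROBOFLOW_API_KEY>")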
inference/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (399 Bytes).
inference/core/__init__.py
ADDED
@@ -0,0 +1,52 @@
+import threading
+import time
+
+import requests
+
+from inference.core.env import DISABLE_VERSION_CHECK, VERSION_CHECK_MODE
+from inference.core.logger import logger
+from inference.core.version import __version__
+
+latest_release = None
+last_checked = 0
+cache_duration = 86400  # 24 hours
+log_frequency = 300  # 5 minutes
+
+
+def get_latest_release_version():
+    global latest_release, last_checked
+    now = time.time()
+    if latest_release is None or now - last_checked > cache_duration:
+        try:
+            logger.debug("Checking for latest inference release version...")
+            response = requests.get(
+                "https://api.github.com/repos/roboflow/inference/releases/latest"
+            )
+            response.raise_for_status()
+            latest_release = response.json()["tag_name"].lstrip("v")
+            last_checked = now
+        except requests.exceptions.RequestException:
+            pass
+
+
+def check_latest_release_against_current():
+    get_latest_release_version()
+    if latest_release is not None and latest_release != __version__:
+        logger.warning(
+            f"Your inference package version {__version__} is out of date! Please upgrade to version {latest_release} of inference for the latest features and bug fixes by running `pip install --upgrade inference`."
+        )
+
+
+def check_latest_release_against_current_continuous():
+    while True:
+        check_latest_release_against_current()
+        time.sleep(log_frequency)
+
+
+if not DISABLE_VERSION_CHECK:
+    if VERSION_CHECK_MODE == "continuous":
+        t = threading.Thread(target=check_latest_release_against_current_continuous)
+        t.daemon = True
+        t.start()
+    else:
+        check_latest_release_against_current()
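Usage note (not part of the diff): the module above queries GitHub at most once per cache_duration (24 hours) and, in continuous mode, re-checks from a daemon thread every log_frequency seconds. A minimal sketch of invoking the check directly:

from inference.core import check_latest_release_against_current

# Performs at most one GitHub API request per 24-hour cache window and
# logs a warning when the installed version lags the latest release.
check_latest_release_against_current()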
inference/core/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.73 kB).

inference/core/__pycache__/constants.cpython-310.pyc
ADDED
Binary file (371 Bytes).

inference/core/__pycache__/env.cpython-310.pyc
ADDED
Binary file (6.87 kB).

inference/core/__pycache__/exceptions.cpython-310.pyc
ADDED
Binary file (6.17 kB).

inference/core/__pycache__/logger.cpython-310.pyc
ADDED
Binary file (551 Bytes).

inference/core/__pycache__/nms.cpython-310.pyc
ADDED
Binary file (4.74 kB).

inference/core/__pycache__/roboflow_api.cpython-310.pyc
ADDED
Binary file (10.1 kB).

inference/core/__pycache__/usage.cpython-310.pyc
ADDED
Binary file (1.85 kB).

inference/core/__pycache__/version.cpython-310.pyc
ADDED
Binary file (250 Bytes).

inference/core/active_learning/__init__.py
ADDED
File without changes.

inference/core/active_learning/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (192 Bytes).

inference/core/active_learning/__pycache__/accounting.cpython-310.pyc
ADDED
Binary file (2.76 kB).

inference/core/active_learning/__pycache__/batching.cpython-310.pyc
ADDED
Binary file (921 Bytes).

inference/core/active_learning/__pycache__/cache_operations.cpython-310.pyc
ADDED
Binary file (5.9 kB).

inference/core/active_learning/__pycache__/configuration.cpython-310.pyc
ADDED
Binary file (5.3 kB).

inference/core/active_learning/__pycache__/core.cpython-310.pyc
ADDED
Binary file (5.2 kB).

inference/core/active_learning/__pycache__/entities.cpython-310.pyc
ADDED
Binary file (4.72 kB).

inference/core/active_learning/__pycache__/middlewares.cpython-310.pyc
ADDED
Binary file (8.68 kB).

inference/core/active_learning/__pycache__/post_processing.cpython-310.pyc
ADDED
Binary file (2.94 kB).

inference/core/active_learning/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (852 Bytes).
inference/core/active_learning/accounting.py
ADDED
@@ -0,0 +1,96 @@
+from typing import List, Optional
+
+from inference.core.entities.types import DatasetID, WorkspaceID
+from inference.core.roboflow_api import (
+    get_roboflow_labeling_batches,
+    get_roboflow_labeling_jobs,
+)
+
+
+def image_can_be_submitted_to_batch(
+    batch_name: str,
+    workspace_id: WorkspaceID,
+    dataset_id: DatasetID,
+    max_batch_images: Optional[int],
+    api_key: str,
+) -> bool:
+    """Check if an image can be submitted to a batch.
+
+    Args:
+        batch_name: Name of the batch.
+        workspace_id: ID of the workspace.
+        dataset_id: ID of the dataset.
+        max_batch_images: Maximum number of images allowed in the batch.
+        api_key: API key to use for the request.
+
+    Returns:
+        True if the image can be submitted to the batch, False otherwise.
+    """
+    if max_batch_images is None:
+        return True
+    labeling_batches = get_roboflow_labeling_batches(
+        api_key=api_key,
+        workspace_id=workspace_id,
+        dataset_id=dataset_id,
+    )
+    matching_labeling_batch = get_matching_labeling_batch(
+        all_labeling_batches=labeling_batches["batches"],
+        batch_name=batch_name,
+    )
+    if matching_labeling_batch is None:
+        return max_batch_images > 0
+    batch_images_under_labeling = 0
+    if matching_labeling_batch["numJobs"] > 0:
+        labeling_jobs = get_roboflow_labeling_jobs(
+            api_key=api_key, workspace_id=workspace_id, dataset_id=dataset_id
+        )
+        batch_images_under_labeling = get_images_in_labeling_jobs_of_specific_batch(
+            all_labeling_jobs=labeling_jobs["jobs"],
+            batch_id=matching_labeling_batch["id"],
+        )
+    total_batch_images = matching_labeling_batch["images"] + batch_images_under_labeling
+    return max_batch_images > total_batch_images
+
+
+def get_matching_labeling_batch(
+    all_labeling_batches: List[dict],
+    batch_name: str,
+) -> Optional[dict]:
+    """Get the matching labeling batch.
+
+    Args:
+        all_labeling_batches: All labeling batches.
+        batch_name: Name of the batch.
+
+    Returns:
+        The matching labeling batch if found, None otherwise.
+
+    """
+    matching_batch = None
+    for labeling_batch in all_labeling_batches:
+        if labeling_batch["name"] == batch_name:
+            matching_batch = labeling_batch
+            break
+    return matching_batch
+
+
+def get_images_in_labeling_jobs_of_specific_batch(
+    all_labeling_jobs: List[dict],
+    batch_id: str,
+) -> int:
+    """Get the number of images in labeling jobs of a specific batch.
+
+    Args:
+        all_labeling_jobs: All labeling jobs.
+        batch_id: ID of the batch.
+
+    Returns:
+        The number of images in labeling jobs of the batch.
+
+    """
+
+    matching_jobs = []
+    for labeling_job in all_labeling_jobs:
+        if batch_id in labeling_job["sourceBatch"]:
+            matching_jobs.append(labeling_job)
+    return sum(job["numImages"] for job in matching_jobs)
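Usage note (not part of the diff): the two helper functions are pure, so their behavior can be shown with toy payloads; the dictionary fields below mirror the ones the diff accesses:

from inference.core.active_learning.accounting import (
    get_images_in_labeling_jobs_of_specific_batch,
    get_matching_labeling_batch,
)

batches = [{"name": "al_batch", "id": "batch-1", "images": 10, "numJobs": 1}]
jobs = [{"sourceBatch": ["batch-1"], "numImages": 5}]

matching = get_matching_labeling_batch(all_labeling_batches=batches, batch_name="al_batch")
under_labeling = get_images_in_labeling_jobs_of_specific_batch(
    all_labeling_jobs=jobs, batch_id="batch-1"
)
# 10 images already in the batch plus 5 under labeling are compared against max_batch_images.
print(matching["images"] + under_labeling)  # 15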
inference/core/active_learning/batching.py
ADDED
@@ -0,0 +1,26 @@
+from inference.core.active_learning.entities import (
+    ActiveLearningConfiguration,
+    BatchReCreationInterval,
+)
+from inference.core.active_learning.utils import (
+    generate_start_timestamp_for_this_month,
+    generate_start_timestamp_for_this_week,
+    generate_today_timestamp,
+)
+
+RECREATION_INTERVAL2TIMESTAMP_GENERATOR = {
+    BatchReCreationInterval.DAILY: generate_today_timestamp,
+    BatchReCreationInterval.WEEKLY: generate_start_timestamp_for_this_week,
+    BatchReCreationInterval.MONTHLY: generate_start_timestamp_for_this_month,
+}
+
+
+def generate_batch_name(configuration: ActiveLearningConfiguration) -> str:
+    batch_name = configuration.batches_name_prefix
+    if configuration.batch_recreation_interval is BatchReCreationInterval.NEVER:
+        return batch_name
+    timestamp_generator = RECREATION_INTERVAL2TIMESTAMP_GENERATOR[
+        configuration.batch_recreation_interval
+    ]
+    timestamp = timestamp_generator()
+    return f"{batch_name}_{timestamp}"
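Usage note (not part of the diff): generate_batch_name only reads batches_name_prefix and batch_recreation_interval, so a duck-typed stand-in is enough to illustrate it; SimpleNamespace here is a hypothetical substitute for a full ActiveLearningConfiguration:

from types import SimpleNamespace

from inference.core.active_learning.batching import generate_batch_name
from inference.core.active_learning.entities import BatchReCreationInterval

# Stand-in exposing only the two attributes generate_batch_name reads.
config = SimpleNamespace(
    batches_name_prefix="active_learning",
    batch_recreation_interval=BatchReCreationInterval.DAILY,
)
print(generate_batch_name(configuration=config))  # e.g. "active_learning_<today's timestamp>"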
inference/core/active_learning/cache_operations.py
ADDED
@@ -0,0 +1,293 @@
+import threading
+from contextlib import contextmanager
+from datetime import datetime
+from typing import Generator, List, Optional, OrderedDict, Union
+
+import redis.lock
+
+from inference.core import logger
+from inference.core.active_learning.entities import StrategyLimit, StrategyLimitType
+from inference.core.active_learning.utils import TIMESTAMP_FORMAT
+from inference.core.cache.base import BaseCache
+
+MAX_LOCK_TIME = 5
+SECONDS_IN_HOUR = 60 * 60
+USAGE_KEY = "usage"
+
+LIMIT_TYPE2KEY_INFIX_GENERATOR = {
+    StrategyLimitType.MINUTELY: lambda: f"minute_{datetime.utcnow().minute}",
+    StrategyLimitType.HOURLY: lambda: f"hour_{datetime.utcnow().hour}",
+    StrategyLimitType.DAILY: lambda: f"day_{datetime.utcnow().strftime(TIMESTAMP_FORMAT)}",
+}
+LIMIT_TYPE2KEY_EXPIRATION = {
+    StrategyLimitType.MINUTELY: 120,
+    StrategyLimitType.HOURLY: 2 * SECONDS_IN_HOUR,
+    StrategyLimitType.DAILY: 25 * SECONDS_IN_HOUR,
+}
+
+
+def use_credit_of_matching_strategy(
+    cache: BaseCache,
+    workspace: str,
+    project: str,
+    matching_strategies_limits: OrderedDict[str, List[StrategyLimit]],
+) -> Optional[str]:
+    # In scope of this function, cache keys updates regarding usage limits for
+    # specific :workspace and :project are locked - to ensure increment to be done atomically
+    # Limits are accounted at the moment of registration - which may introduce inaccuracy
+    # given that registration is postponed from prediction
+    # Returns: strategy with spare credit if found - else None
+    with lock_limits(cache=cache, workspace=workspace, project=project):
+        strategy_with_spare_credit = find_strategy_with_spare_usage_credit(
+            cache=cache,
+            workspace=workspace,
+            project=project,
+            matching_strategies_limits=matching_strategies_limits,
+        )
+        if strategy_with_spare_credit is None:
+            return None
+        consume_strategy_limits_usage_credit(
+            cache=cache,
+            workspace=workspace,
+            project=project,
+            strategy_name=strategy_with_spare_credit,
+        )
+        return strategy_with_spare_credit
+
+
+def return_strategy_credit(
+    cache: BaseCache,
+    workspace: str,
+    project: str,
+    strategy_name: str,
+) -> None:
+    # In scope of this function, cache keys updates regarding usage limits for
+    # specific :workspace and :project are locked - to ensure decrement to be done atomically
+    # Returning strategy is a bit naive (we may add to a pool of credits from the next period - but only
+    # if we have previously taken from the previous one and some credits are used in the new pool) -
+    # in favour of easier implementation.
+    with lock_limits(cache=cache, workspace=workspace, project=project):
+        return_strategy_limits_usage_credit(
+            cache=cache,
+            workspace=workspace,
+            project=project,
+            strategy_name=strategy_name,
+        )
+
+
+@contextmanager
+def lock_limits(
+    cache: BaseCache,
+    workspace: str,
+    project: str,
+) -> Generator[Union[threading.Lock, redis.lock.Lock], None, None]:
+    limits_lock_key = generate_cache_key_for_active_learning_usage_lock(
+        workspace=workspace,
+        project=project,
+    )
+    with cache.lock(key=limits_lock_key, expire=MAX_LOCK_TIME) as lock:
+        yield lock
+
+
+def find_strategy_with_spare_usage_credit(
+    cache: BaseCache,
+    workspace: str,
+    project: str,
+    matching_strategies_limits: OrderedDict[str, List[StrategyLimit]],
+) -> Optional[str]:
+    for strategy_name, strategy_limits in matching_strategies_limits.items():
+        rejected_by_strategy = (
+            datapoint_should_be_rejected_based_on_strategy_usage_limits(
+                cache=cache,
+                workspace=workspace,
+                project=project,
+                strategy_name=strategy_name,
+                strategy_limits=strategy_limits,
+            )
+        )
+        if not rejected_by_strategy:
+            return strategy_name
+    return None
+
+
+def datapoint_should_be_rejected_based_on_strategy_usage_limits(
+    cache: BaseCache,
+    workspace: str,
+    project: str,
+    strategy_name: str,
+    strategy_limits: List[StrategyLimit],
+) -> bool:
+    for strategy_limit in strategy_limits:
+        limit_reached = datapoint_should_be_rejected_based_on_limit_usage(
+            cache=cache,
+            workspace=workspace,
+            project=project,
+            strategy_name=strategy_name,
+            strategy_limit=strategy_limit,
+        )
+        if limit_reached:
+            logger.debug(
+                f"Violated Active Learning strategy limit: {strategy_limit.limit_type.name} "
+                f"with value {strategy_limit.value} for sampling strategy: {strategy_name}."
+            )
+            return True
+    return False
+
+
+def datapoint_should_be_rejected_based_on_limit_usage(
+    cache: BaseCache,
+    workspace: str,
+    project: str,
+    strategy_name: str,
+    strategy_limit: StrategyLimit,
+) -> bool:
+    current_usage = get_current_strategy_limit_usage(
+        cache=cache,
+        workspace=workspace,
+        project=project,
+        strategy_name=strategy_name,
+        limit_type=strategy_limit.limit_type,
+    )
+    if current_usage is None:
+        current_usage = 0
+    return current_usage >= strategy_limit.value
+
+
+def consume_strategy_limits_usage_credit(
+    cache: BaseCache,
+    workspace: str,
+    project: str,
+    strategy_name: str,
+) -> None:
+    for limit_type in StrategyLimitType:
+        consume_strategy_limit_usage_credit(
+            cache=cache,
+            workspace=workspace,
+            project=project,
+            strategy_name=strategy_name,
+            limit_type=limit_type,
+        )
+
+
+def consume_strategy_limit_usage_credit(
+    cache: BaseCache,
+    workspace: str,
+    project: str,
+    strategy_name: str,
+    limit_type: StrategyLimitType,
+) -> None:
+    current_value = get_current_strategy_limit_usage(
+        cache=cache,
+        limit_type=limit_type,
+        workspace=workspace,
+        project=project,
+        strategy_name=strategy_name,
+    )
+    if current_value is None:
+        current_value = 0
+    current_value += 1
+    set_current_strategy_limit_usage(
+        current_value=current_value,
+        cache=cache,
+        limit_type=limit_type,
+        workspace=workspace,
+        project=project,
+        strategy_name=strategy_name,
+    )
+
+
+def return_strategy_limits_usage_credit(
+    cache: BaseCache,
+    workspace: str,
+    project: str,
+    strategy_name: str,
+) -> None:
+    for limit_type in StrategyLimitType:
+        return_strategy_limit_usage_credit(
+            cache=cache,
+            workspace=workspace,
+            project=project,
+            strategy_name=strategy_name,
+            limit_type=limit_type,
+        )
+
+
+def return_strategy_limit_usage_credit(
+    cache: BaseCache,
+    workspace: str,
+    project: str,
+    strategy_name: str,
+    limit_type: StrategyLimitType,
+) -> None:
+    current_value = get_current_strategy_limit_usage(
+        cache=cache,
+        limit_type=limit_type,
+        workspace=workspace,
+        project=project,
+        strategy_name=strategy_name,
+    )
+    if current_value is None:
+        return None
+    current_value = max(current_value - 1, 0)
+    set_current_strategy_limit_usage(
+        current_value=current_value,
+        cache=cache,
+        limit_type=limit_type,
+        workspace=workspace,
+        project=project,
+        strategy_name=strategy_name,
+    )
+
+
+def get_current_strategy_limit_usage(
+    cache: BaseCache,
+    workspace: str,
+    project: str,
+    strategy_name: str,
+    limit_type: StrategyLimitType,
+) -> Optional[int]:
+    usage_key = generate_cache_key_for_active_learning_usage(
+        limit_type=limit_type,
+        workspace=workspace,
+        project=project,
+        strategy_name=strategy_name,
+    )
+    value = cache.get(usage_key)
+    if value is None:
+        return value
+    return value[USAGE_KEY]
+
+
+def set_current_strategy_limit_usage(
+    current_value: int,
+    cache: BaseCache,
+    workspace: str,
+    project: str,
+    strategy_name: str,
+    limit_type: StrategyLimitType,
+) -> None:
+    usage_key = generate_cache_key_for_active_learning_usage(
+        limit_type=limit_type,
+        workspace=workspace,
+        project=project,
+        strategy_name=strategy_name,
+    )
+    expire = LIMIT_TYPE2KEY_EXPIRATION[limit_type]
+    cache.set(key=usage_key, value={USAGE_KEY: current_value}, expire=expire)  # type: ignore
+
+
+def generate_cache_key_for_active_learning_usage_lock(
+    workspace: str,
+    project: str,
+) -> str:
+    return f"active_learning:usage:{workspace}:{project}:usage:lock"
+
+
+def generate_cache_key_for_active_learning_usage(
+    limit_type: StrategyLimitType,
+    workspace: str,
+    project: str,
+    strategy_name: str,
+) -> str:
+    time_infix = LIMIT_TYPE2KEY_INFIX_GENERATOR[limit_type]()
+    return f"active_learning:usage:{workspace}:{project}:{strategy_name}:{time_infix}"
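Usage note (not part of the diff): the two key-generation functions at the bottom of the module are pure string formatting, so the cache key layout can be shown directly:

from inference.core.active_learning.cache_operations import (
    generate_cache_key_for_active_learning_usage,
    generate_cache_key_for_active_learning_usage_lock,
)
from inference.core.active_learning.entities import StrategyLimitType

lock_key = generate_cache_key_for_active_learning_usage_lock(
    workspace="my-workspace", project="my-project"
)
usage_key = generate_cache_key_for_active_learning_usage(
    limit_type=StrategyLimitType.HOURLY,
    workspace="my-workspace",
    project="my-project",
    strategy_name="default_strategy",
)
# lock_key  == "active_learning:usage:my-workspace:my-project:usage:lock"
# usage_key == "active_learning:usage:my-workspace:my-project:default_strategy:hour_<current UTC hour>"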
inference/core/active_learning/configuration.py
ADDED
@@ -0,0 +1,203 @@
+import hashlib
+from dataclasses import asdict
+from typing import Any, Dict, List, Optional
+
+from inference.core import logger
+from inference.core.active_learning.entities import (
+    ActiveLearningConfiguration,
+    RoboflowProjectMetadata,
+    SamplingMethod,
+)
+from inference.core.active_learning.samplers.close_to_threshold import (
+    initialize_close_to_threshold_sampling,
+)
+from inference.core.active_learning.samplers.contains_classes import (
+    initialize_classes_based_sampling,
+)
+from inference.core.active_learning.samplers.number_of_detections import (
+    initialize_detections_number_based_sampling,
+)
+from inference.core.active_learning.samplers.random import initialize_random_sampling
+from inference.core.cache.base import BaseCache
+from inference.core.exceptions import (
+    ActiveLearningConfigurationDecodingError,
+    ActiveLearningConfigurationError,
+    RoboflowAPINotAuthorizedError,
+    RoboflowAPINotNotFoundError,
+)
+from inference.core.roboflow_api import (
+    get_roboflow_active_learning_configuration,
+    get_roboflow_dataset_type,
+    get_roboflow_workspace,
+)
+from inference.core.utils.roboflow import get_model_id_chunks
+
+TYPE2SAMPLING_INITIALIZERS = {
+    "random": initialize_random_sampling,
+    "close_to_threshold": initialize_close_to_threshold_sampling,
+    "classes_based": initialize_classes_based_sampling,
+    "detections_number_based": initialize_detections_number_based_sampling,
+}
+ACTIVE_LEARNING_CONFIG_CACHE_EXPIRE = 900  # 15 min
+
+
+def prepare_active_learning_configuration(
+    api_key: str,
+    model_id: str,
+    cache: BaseCache,
+) -> Optional[ActiveLearningConfiguration]:
+    project_metadata = get_roboflow_project_metadata(
+        api_key=api_key,
+        model_id=model_id,
+        cache=cache,
+    )
+    if not project_metadata.active_learning_configuration.get("enabled", False):
+        return None
+    logger.info(
+        f"Configuring active learning for workspace: {project_metadata.workspace_id}, "
+        f"project: {project_metadata.dataset_id} of type: {project_metadata.dataset_type}. "
+        f"AL configuration: {project_metadata.active_learning_configuration}"
+    )
+    return initialise_active_learning_configuration(
+        project_metadata=project_metadata,
+    )
+
+
+def prepare_active_learning_configuration_inplace(
+    api_key: str,
+    model_id: str,
+    active_learning_configuration: Optional[dict],
+) -> Optional[ActiveLearningConfiguration]:
+    if (
+        active_learning_configuration is None
+        or active_learning_configuration.get("enabled", False) is False
+    ):
+        return None
+    dataset_id, version_id = get_model_id_chunks(model_id=model_id)
+    workspace_id = get_roboflow_workspace(api_key=api_key)
+    dataset_type = get_roboflow_dataset_type(
+        api_key=api_key,
+        workspace_id=workspace_id,
+        dataset_id=dataset_id,
+    )
+    project_metadata = RoboflowProjectMetadata(
+        dataset_id=dataset_id,
+        version_id=version_id,
+        workspace_id=workspace_id,
+        dataset_type=dataset_type,
+        active_learning_configuration=active_learning_configuration,
+    )
+    return initialise_active_learning_configuration(
+        project_metadata=project_metadata,
+    )
+
+
+def get_roboflow_project_metadata(
+    api_key: str,
+    model_id: str,
+    cache: BaseCache,
+) -> RoboflowProjectMetadata:
+    logger.info(f"Fetching active learning configuration.")
+    config_cache_key = construct_cache_key_for_active_learning_config(
+        api_key=api_key, model_id=model_id
+    )
+    cached_config = cache.get(config_cache_key)
+    if cached_config is not None:
+        logger.info("Found Active Learning configuration in cache.")
+        return parse_cached_roboflow_project_metadata(cached_config=cached_config)
+    dataset_id, version_id = get_model_id_chunks(model_id=model_id)
+    workspace_id = get_roboflow_workspace(api_key=api_key)
+    dataset_type = get_roboflow_dataset_type(
+        api_key=api_key,
+        workspace_id=workspace_id,
+        dataset_id=dataset_id,
+    )
+    try:
+        roboflow_api_configuration = get_roboflow_active_learning_configuration(
+            api_key=api_key, workspace_id=workspace_id, dataset_id=dataset_id
+        )
+    except (RoboflowAPINotAuthorizedError, RoboflowAPINotNotFoundError):
+        # currently backend returns HTTP 404 if dataset does not exist
+        # or workspace_id from api_key indicate that the owner is different,
+        # so in the situation when we query for Universe dataset.
+        # We want the owner of public dataset to be able to set AL configs
+        # and use them, but not other people. At this point it's known
+        # that HTTP 404 means not authorised (which will probably change
+        # in future iteration of backend) - so on both NotAuth and NotFound
+        # errors we assume that we simply cannot use AL with this model and
+        # this api_key.
+        roboflow_api_configuration = {"enabled": False}
+    configuration = RoboflowProjectMetadata(
+        dataset_id=dataset_id,
+        version_id=version_id,
+        workspace_id=workspace_id,
+        dataset_type=dataset_type,
+        active_learning_configuration=roboflow_api_configuration,
+    )
+    cache.set(
+        key=config_cache_key,
+        value=asdict(configuration),
+        expire=ACTIVE_LEARNING_CONFIG_CACHE_EXPIRE,
+    )
+    return configuration
+
+
+def construct_cache_key_for_active_learning_config(api_key: str, model_id: str) -> str:
+    dataset_id = model_id.split("/")[0]
+    api_key_hash = hashlib.md5(api_key.encode("utf-8")).hexdigest()
+    return f"active_learning:configurations:{api_key_hash}:{dataset_id}"
+
+
+def parse_cached_roboflow_project_metadata(
+    cached_config: dict,
+) -> RoboflowProjectMetadata:
+    try:
+        return RoboflowProjectMetadata(**cached_config)
+    except Exception as error:
+        raise ActiveLearningConfigurationDecodingError(
+            f"Failed to initialise Active Learning configuration. Cause: {str(error)}"
+        ) from error
+
+
+def initialise_active_learning_configuration(
+    project_metadata: RoboflowProjectMetadata,
+) -> ActiveLearningConfiguration:
+    sampling_methods = initialize_sampling_methods(
+        sampling_strategies_configs=project_metadata.active_learning_configuration[
+            "sampling_strategies"
+        ],
+    )
+    target_workspace_id = project_metadata.active_learning_configuration.get(
+        "target_workspace", project_metadata.workspace_id
+    )
+    target_dataset_id = project_metadata.active_learning_configuration.get(
+        "target_project", project_metadata.dataset_id
+    )
+    return ActiveLearningConfiguration.init(
+        roboflow_api_configuration=project_metadata.active_learning_configuration,
+        sampling_methods=sampling_methods,
+        workspace_id=target_workspace_id,
+        dataset_id=target_dataset_id,
+        model_id=f"{project_metadata.dataset_id}/{project_metadata.version_id}",
+    )
+
+
+def initialize_sampling_methods(
+    sampling_strategies_configs: List[Dict[str, Any]]
+) -> List[SamplingMethod]:
+    result = []
+    for sampling_strategy_config in sampling_strategies_configs:
+        sampling_type = sampling_strategy_config["type"]
+        if sampling_type not in TYPE2SAMPLING_INITIALIZERS:
+            logger.warn(
+                f"Could not identify sampling method `{sampling_type}` - skipping initialisation."
+            )
+            continue
+        initializer = TYPE2SAMPLING_INITIALIZERS[sampling_type]
+        result.append(initializer(sampling_strategy_config))
+    names = set(m.name for m in result)
+    if len(names) != len(result):
+        raise ActiveLearningConfigurationError(
+            "Detected duplication of Active Learning strategies names."
+        )
+    return result
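Usage note (not part of the diff): configurations are cached per API key and dataset; the key builder hashes the API key with MD5 so the raw secret never appears in cache keys:

from inference.core.active_learning.configuration import (
    construct_cache_key_for_active_learning_config,
)

key = construct_cache_key_for_active_learning_config(
    api_key="dummy-key", model_id="my-dataset/3"
)
# "active_learning:configurations:<md5 hex digest of the api key>:my-dataset"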
inference/core/active_learning/core.py
ADDED
@@ -0,0 +1,219 @@
+from collections import OrderedDict
+from typing import List, Optional, Tuple
+from uuid import uuid4
+
+import numpy as np
+
+from inference.core import logger
+from inference.core.active_learning.cache_operations import (
+    return_strategy_credit,
+    use_credit_of_matching_strategy,
+)
+from inference.core.active_learning.entities import (
+    ActiveLearningConfiguration,
+    ImageDimensions,
+    Prediction,
+    PredictionType,
+    SamplingMethod,
+)
+from inference.core.active_learning.post_processing import (
+    adjust_prediction_to_client_scaling_factor,
+    encode_prediction,
+)
+from inference.core.cache.base import BaseCache
+from inference.core.env import ACTIVE_LEARNING_TAGS
+from inference.core.roboflow_api import (
+    annotate_image_at_roboflow,
+    register_image_at_roboflow,
+)
+from inference.core.utils.image_utils import encode_image_to_jpeg_bytes
+from inference.core.utils.preprocess import downscale_image_keeping_aspect_ratio
+
+
+def execute_sampling(
+    image: np.ndarray,
+    prediction: Prediction,
+    prediction_type: PredictionType,
+    sampling_methods: List[SamplingMethod],
+) -> List[str]:
+    matching_strategies = []
+    for method in sampling_methods:
+        sampling_result = method.sample(image, prediction, prediction_type)
+        if sampling_result:
+            matching_strategies.append(method.name)
+    return matching_strategies
+
+
+def execute_datapoint_registration(
+    cache: BaseCache,
+    matching_strategies: List[str],
+    image: np.ndarray,
+    prediction: Prediction,
+    prediction_type: PredictionType,
+    configuration: ActiveLearningConfiguration,
+    api_key: str,
+    batch_name: str,
+) -> None:
+    local_image_id = str(uuid4())
+    encoded_image, scaling_factor = prepare_image_to_registration(
+        image=image,
+        desired_size=configuration.max_image_size,
+        jpeg_compression_level=configuration.jpeg_compression_level,
+    )
+    prediction = adjust_prediction_to_client_scaling_factor(
+        prediction=prediction,
+        scaling_factor=scaling_factor,
+        prediction_type=prediction_type,
+    )
+    matching_strategies_limits = OrderedDict(
+        (strategy_name, configuration.strategies_limits[strategy_name])
+        for strategy_name in matching_strategies
+    )
+    strategy_with_spare_credit = use_credit_of_matching_strategy(
+        cache=cache,
+        workspace=configuration.workspace_id,
+        project=configuration.dataset_id,
+        matching_strategies_limits=matching_strategies_limits,
+    )
+    if strategy_with_spare_credit is None:
+        logger.debug(f"Limit on Active Learning strategy reached.")
+        return None
+    register_datapoint_at_roboflow(
+        cache=cache,
+        strategy_with_spare_credit=strategy_with_spare_credit,
+        encoded_image=encoded_image,
+        local_image_id=local_image_id,
+        prediction=prediction,
+        prediction_type=prediction_type,
+        configuration=configuration,
+        api_key=api_key,
+        batch_name=batch_name,
+    )
+
+
+def prepare_image_to_registration(
+    image: np.ndarray,
+    desired_size: Optional[ImageDimensions],
+    jpeg_compression_level: int,
+) -> Tuple[bytes, float]:
+    scaling_factor = 1.0
+    if desired_size is not None:
+        height_before_scale = image.shape[0]
+        image = downscale_image_keeping_aspect_ratio(
+            image=image,
+            desired_size=desired_size.to_wh(),
+        )
+        scaling_factor = image.shape[0] / height_before_scale
+    return (
+        encode_image_to_jpeg_bytes(image=image, jpeg_quality=jpeg_compression_level),
+        scaling_factor,
+    )
+
+
+def register_datapoint_at_roboflow(
+    cache: BaseCache,
+    strategy_with_spare_credit: str,
+    encoded_image: bytes,
+    local_image_id: str,
+    prediction: Prediction,
+    prediction_type: PredictionType,
+    configuration: ActiveLearningConfiguration,
+    api_key: str,
+    batch_name: str,
+) -> None:
+    tags = collect_tags(
+        configuration=configuration,
+        sampling_strategy=strategy_with_spare_credit,
+    )
+    roboflow_image_id = safe_register_image_at_roboflow(
+        cache=cache,
+        strategy_with_spare_credit=strategy_with_spare_credit,
+        encoded_image=encoded_image,
+        local_image_id=local_image_id,
+        configuration=configuration,
+        api_key=api_key,
+        batch_name=batch_name,
+        tags=tags,
+    )
+    if is_prediction_registration_forbidden(
+        prediction=prediction,
+        persist_predictions=configuration.persist_predictions,
+        roboflow_image_id=roboflow_image_id,
+    ):
+        return None
+    encoded_prediction, prediction_file_type = encode_prediction(
+        prediction=prediction, prediction_type=prediction_type
+    )
+    _ = annotate_image_at_roboflow(
+        api_key=api_key,
+        dataset_id=configuration.dataset_id,
+        local_image_id=local_image_id,
+        roboflow_image_id=roboflow_image_id,
+        annotation_content=encoded_prediction,
+        annotation_file_type=prediction_file_type,
+        is_prediction=True,
+    )
+
+
+def collect_tags(
+    configuration: ActiveLearningConfiguration, sampling_strategy: str
+) -> List[str]:
+    tags = ACTIVE_LEARNING_TAGS if ACTIVE_LEARNING_TAGS is not None else []
+    tags.extend(configuration.tags)
+    tags.extend(configuration.strategies_tags[sampling_strategy])
+    if configuration.persist_predictions:
+        # this replacement is needed due to backend input validation
+        tags.append(configuration.model_id.replace("/", "-"))
+    return tags
+
+
+def safe_register_image_at_roboflow(
+    cache: BaseCache,
+    strategy_with_spare_credit: str,
+    encoded_image: bytes,
+    local_image_id: str,
+    configuration: ActiveLearningConfiguration,
+    api_key: str,
+    batch_name: str,
+    tags: List[str],
+) -> Optional[str]:
+    credit_to_be_returned = False
+    try:
+        registration_response = register_image_at_roboflow(
+            api_key=api_key,
+            dataset_id=configuration.dataset_id,
+            local_image_id=local_image_id,
+            image_bytes=encoded_image,
+            batch_name=batch_name,
+            tags=tags,
+        )
+        image_duplicated = registration_response.get("duplicate", False)
+        if image_duplicated:
+            credit_to_be_returned = True
+            logger.warning(f"Image duplication detected: {registration_response}.")
+            return None
+        return registration_response["id"]
+    except Exception as error:
+        credit_to_be_returned = True
+        raise error
+    finally:
+        if credit_to_be_returned:
+            return_strategy_credit(
+                cache=cache,
+                workspace=configuration.workspace_id,
+                project=configuration.dataset_id,
+                strategy_name=strategy_with_spare_credit,
+            )
+
+
+def is_prediction_registration_forbidden(
+    prediction: Prediction,
+    persist_predictions: bool,
+    roboflow_image_id: Optional[str],
+) -> bool:
+    return (
+        roboflow_image_id is None
+        or persist_predictions is False
+        or prediction.get("is_stub", False) is True
+        or (len(prediction.get("predictions", [])) == 0 and "top" not in prediction)
+    )
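Usage note (not part of the diff): is_prediction_registration_forbidden is a pure predicate, so the skip conditions can be demonstrated with toy predictions:

from inference.core.active_learning.core import is_prediction_registration_forbidden

# Stub predictions and empty detection lists (with no classification "top") are skipped.
print(
    is_prediction_registration_forbidden(
        prediction={"is_stub": True},
        persist_predictions=True,
        roboflow_image_id="img-1",
    )
)  # True
print(
    is_prediction_registration_forbidden(
        prediction={"predictions": [{"class": "cat", "confidence": 0.9}]},
        persist_predictions=True,
        roboflow_image_id="img-1",
    )
)  # False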
inference/core/active_learning/entities.py
ADDED
@@ -0,0 +1,141 @@
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import numpy as np
+
+from inference.core.entities.types import DatasetID, WorkspaceID
+from inference.core.exceptions import ActiveLearningConfigurationDecodingError
+
+LocalImageIdentifier = str
+PredictionType = str
+Prediction = dict
+SerialisedPrediction = str
+PredictionFileType = str
+
+
+@dataclass(frozen=True)
+class ImageDimensions:
+    height: int
+    width: int
+
+    def to_hw(self) -> Tuple[int, int]:
+        return self.height, self.width
+
+    def to_wh(self) -> Tuple[int, int]:
+        return self.width, self.height
+
+
+@dataclass(frozen=True)
+class SamplingMethod:
+    name: str
+    sample: Callable[[np.ndarray, Prediction, PredictionType], bool]
+
+
+class BatchReCreationInterval(Enum):
+    NEVER = "never"
+    DAILY = "daily"
+    WEEKLY = "weekly"
+    MONTHLY = "monthly"
+
+
+class StrategyLimitType(Enum):
+    MINUTELY = "minutely"
+    HOURLY = "hourly"
+    DAILY = "daily"
+
+
+@dataclass(frozen=True)
+class StrategyLimit:
+    limit_type: StrategyLimitType
+    value: int
+
+    @classmethod
+    def from_dict(cls, specification: dict) -> "StrategyLimit":
+        return cls(
+            limit_type=StrategyLimitType(specification["type"]),
+            value=specification["value"],
+        )
+
+
+@dataclass(frozen=True)
+class ActiveLearningConfiguration:
+    max_image_size: Optional[ImageDimensions]
+    jpeg_compression_level: int
+    persist_predictions: bool
+    sampling_methods: List[SamplingMethod]
+    batches_name_prefix: str
+    batch_recreation_interval: BatchReCreationInterval
+    max_batch_images: Optional[int]
+    workspace_id: WorkspaceID
+    dataset_id: DatasetID
+    model_id: str
+    strategies_limits: Dict[str, List[StrategyLimit]]
+    tags: List[str]
+    strategies_tags: Dict[str, List[str]]
+
+    @classmethod
+    def init(
+        cls,
+        roboflow_api_configuration: Dict[str, Any],
+        sampling_methods: List[SamplingMethod],
+        workspace_id: WorkspaceID,
+        dataset_id: DatasetID,
+        model_id: str,
+    ) -> "ActiveLearningConfiguration":
+        try:
+            max_image_size = roboflow_api_configuration.get("max_image_size")
+            if max_image_size is not None:
+                max_image_size = ImageDimensions(
+                    height=roboflow_api_configuration["max_image_size"][0],
+                    width=roboflow_api_configuration["max_image_size"][1],
+                )
+            strategies_limits = {
+                strategy["name"]: [
+                    StrategyLimit.from_dict(specification=specification)
+                    for specification in strategy.get("limits", [])
+                ]
+                for strategy in roboflow_api_configuration["sampling_strategies"]
+            }
+            strategies_tags = {
+                strategy["name"]: strategy.get("tags", [])
+                for strategy in roboflow_api_configuration["sampling_strategies"]
+            }
+            return cls(
+                max_image_size=max_image_size,
+                jpeg_compression_level=roboflow_api_configuration.get(
+                    "jpeg_compression_level", 95
+                ),
+                persist_predictions=roboflow_api_configuration["persist_predictions"],
+                sampling_methods=sampling_methods,
+                batches_name_prefix=roboflow_api_configuration["batching_strategy"][
+                    "batches_name_prefix"
+                ],
+                batch_recreation_interval=BatchReCreationInterval(
+                    roboflow_api_configuration["batching_strategy"][
+                        "recreation_interval"
+                    ]
+                ),
+                max_batch_images=roboflow_api_configuration["batching_strategy"].get(
+                    "max_batch_images"
+                ),
+                workspace_id=workspace_id,
+                dataset_id=dataset_id,
+                model_id=model_id,
+                strategies_limits=strategies_limits,
+                tags=roboflow_api_configuration.get("tags", []),
+                strategies_tags=strategies_tags,
+            )
+        except (KeyError, ValueError) as e:
+            raise ActiveLearningConfigurationDecodingError(
+                f"Failed to initialise Active Learning configuration. Cause: {str(e)}"
+            ) from e
+
+
+@dataclass(frozen=True)
+class RoboflowProjectMetadata:
+    dataset_id: DatasetID
+    version_id: str
+    workspace_id: WorkspaceID
+    dataset_type: str
+    active_learning_configuration: dict
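Usage note (not part of the diff): the small value objects in this module behave as shown below; to_wh matches the (width, height) order expected by the downscaling utility in core.py:

from inference.core.active_learning.entities import (
    ImageDimensions,
    StrategyLimit,
    StrategyLimitType,
)

dims = ImageDimensions(height=720, width=1280)
print(dims.to_wh())  # (1280, 720)

limit = StrategyLimit.from_dict({"type": "daily", "value": 100})
print(limit.limit_type is StrategyLimitType.DAILY, limit.value)  # True 100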
inference/core/active_learning/middlewares.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import queue
from queue import Queue
from threading import Thread
from typing import Any, List, Optional

from inference.core import logger
from inference.core.active_learning.accounting import image_can_be_submitted_to_batch
from inference.core.active_learning.batching import generate_batch_name
from inference.core.active_learning.configuration import (
    prepare_active_learning_configuration,
    prepare_active_learning_configuration_inplace,
)
from inference.core.active_learning.core import (
    execute_datapoint_registration,
    execute_sampling,
)
from inference.core.active_learning.entities import (
    ActiveLearningConfiguration,
    Prediction,
    PredictionType,
)
from inference.core.cache.base import BaseCache
from inference.core.utils.image_utils import load_image

MAX_REGISTRATION_QUEUE_SIZE = 512


class NullActiveLearningMiddleware:
    def register_batch(
        self,
        inference_inputs: List[Any],
        predictions: List[Prediction],
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        pass

    def register(
        self,
        inference_input: Any,
        prediction: dict,
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        pass

    def start_registration_thread(self) -> None:
        pass

    def stop_registration_thread(self) -> None:
        pass

    def __enter__(self) -> "NullActiveLearningMiddleware":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        pass


class ActiveLearningMiddleware:
    @classmethod
    def init(
        cls, api_key: str, model_id: str, cache: BaseCache
    ) -> "ActiveLearningMiddleware":
        configuration = prepare_active_learning_configuration(
            api_key=api_key,
            model_id=model_id,
            cache=cache,
        )
        return cls(
            api_key=api_key,
            configuration=configuration,
            cache=cache,
        )

    @classmethod
    def init_from_config(
        cls, api_key: str, model_id: str, cache: BaseCache, config: Optional[dict]
    ) -> "ActiveLearningMiddleware":
        configuration = prepare_active_learning_configuration_inplace(
            api_key=api_key,
            model_id=model_id,
            active_learning_configuration=config,
        )
        return cls(
            api_key=api_key,
            configuration=configuration,
            cache=cache,
        )

    def __init__(
        self,
        api_key: str,
        configuration: Optional[ActiveLearningConfiguration],
        cache: BaseCache,
    ):
        self._api_key = api_key
        self._configuration = configuration
        self._cache = cache

    def register_batch(
        self,
        inference_inputs: List[Any],
        predictions: List[Prediction],
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        for inference_input, prediction in zip(inference_inputs, predictions):
            self.register(
                inference_input=inference_input,
                prediction=prediction,
                prediction_type=prediction_type,
                disable_preproc_auto_orient=disable_preproc_auto_orient,
            )

    def register(
        self,
        inference_input: Any,
        prediction: dict,
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        self._execute_registration(
            inference_input=inference_input,
            prediction=prediction,
            prediction_type=prediction_type,
            disable_preproc_auto_orient=disable_preproc_auto_orient,
        )

    def _execute_registration(
        self,
        inference_input: Any,
        prediction: dict,
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        if self._configuration is None:
            return None
        image, is_bgr = load_image(
            value=inference_input,
            disable_preproc_auto_orient=disable_preproc_auto_orient,
        )
        if not is_bgr:
            image = image[:, :, ::-1]
        matching_strategies = execute_sampling(
            image=image,
            prediction=prediction,
            prediction_type=prediction_type,
            sampling_methods=self._configuration.sampling_methods,
        )
        if len(matching_strategies) == 0:
            return None
        batch_name = generate_batch_name(configuration=self._configuration)
        if not image_can_be_submitted_to_batch(
            batch_name=batch_name,
            workspace_id=self._configuration.workspace_id,
            dataset_id=self._configuration.dataset_id,
            max_batch_images=self._configuration.max_batch_images,
            api_key=self._api_key,
        ):
            logger.debug("Limit on Active Learning batch size reached.")
            return None
        execute_datapoint_registration(
            cache=self._cache,
            matching_strategies=matching_strategies,
            image=image,
            prediction=prediction,
            prediction_type=prediction_type,
            configuration=self._configuration,
            api_key=self._api_key,
            batch_name=batch_name,
        )


class ThreadingActiveLearningMiddleware(ActiveLearningMiddleware):
    @classmethod
    def init(
        cls,
        api_key: str,
        model_id: str,
        cache: BaseCache,
        max_queue_size: int = MAX_REGISTRATION_QUEUE_SIZE,
    ) -> "ThreadingActiveLearningMiddleware":
        configuration = prepare_active_learning_configuration(
            api_key=api_key,
            model_id=model_id,
            cache=cache,
        )
        task_queue = Queue(max_queue_size)
        return cls(
            api_key=api_key,
            configuration=configuration,
            cache=cache,
            task_queue=task_queue,
        )

    @classmethod
    def init_from_config(
        cls,
        api_key: str,
        model_id: str,
        cache: BaseCache,
        config: Optional[dict],
        max_queue_size: int = MAX_REGISTRATION_QUEUE_SIZE,
    ) -> "ThreadingActiveLearningMiddleware":
        configuration = prepare_active_learning_configuration_inplace(
            api_key=api_key,
            model_id=model_id,
            active_learning_configuration=config,
        )
        task_queue = Queue(max_queue_size)
        return cls(
            api_key=api_key,
            configuration=configuration,
            cache=cache,
            task_queue=task_queue,
        )

    def __init__(
        self,
        api_key: str,
        configuration: ActiveLearningConfiguration,
        cache: BaseCache,
        task_queue: Queue,
    ):
        super().__init__(api_key=api_key, configuration=configuration, cache=cache)
        self._task_queue = task_queue
        self._registration_thread: Optional[Thread] = None

    def register(
        self,
        inference_input: Any,
        prediction: dict,
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        logger.debug("Putting registration task into queue")
        try:
            self._task_queue.put_nowait(
                (
                    inference_input,
                    prediction,
                    prediction_type,
                    disable_preproc_auto_orient,
                )
            )
        except queue.Full:
            logger.warning(
                "Dropping datapoint registered in Active Learning due to insufficient processing "
                "capabilities."
            )

    def start_registration_thread(self) -> None:
        if self._registration_thread is not None:
            logger.warning("Registration thread already started.")
            return None
        logger.debug("Starting registration thread")
        self._registration_thread = Thread(target=self._consume_queue)
        self._registration_thread.start()

    def stop_registration_thread(self) -> None:
        if self._registration_thread is None:
            logger.warning("Registration thread is already stopped.")
            return None
        logger.debug("Stopping registration thread")
        self._task_queue.put(None)  # None is the sentinel that terminates the consumer loop
        self._registration_thread.join()
        if self._registration_thread.is_alive():
            logger.warning("Registration thread stopping was unsuccessful.")
        self._registration_thread = None

    def _consume_queue(self) -> None:
        queue_closed = False
        while not queue_closed:
            queue_closed = self._consume_queue_task()

    def _consume_queue_task(self) -> bool:
        logger.debug("Consuming registration task")
        task = self._task_queue.get()
        logger.debug("Received registration task")
        if task is None:
            logger.debug("Terminating registration thread")
            self._task_queue.task_done()
            return True
        inference_input, prediction, prediction_type, disable_preproc_auto_orient = task
        try:
            self._execute_registration(
                inference_input=inference_input,
                prediction=prediction,
                prediction_type=prediction_type,
                disable_preproc_auto_orient=disable_preproc_auto_orient,
            )
        except Exception as error:
            # Error handling to be decided
            logger.warning(
                f"Error in datapoint registration for Active Learning. Details: {error}. "
                f"Error is suppressed in favour of normal operations of registration thread."
            )
        self._task_queue.task_done()
        return False

    def __enter__(self) -> "ThreadingActiveLearningMiddleware":
        self.start_registration_thread()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.stop_registration_thread()
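
A minimal usage sketch of the threaded middleware above, run as a context manager so the registration thread is started and stopped automatically. The API key, model ID and image path are placeholders, not values from this upload; the in-memory cache is simply the most lightweight BaseCache implementation shipped here:

from inference.core.active_learning.middlewares import ThreadingActiveLearningMiddleware
from inference.core.cache.memory import MemoryCache
from inference.core.constants import OBJECT_DETECTION_TASK

middleware = ThreadingActiveLearningMiddleware.init(
    api_key="YOUR_API_KEY",     # placeholder
    model_id="your-project/1",  # placeholder
    cache=MemoryCache(),
)
with middleware:  # __enter__ starts the registration thread, __exit__ stops it
    middleware.register(
        inference_input="path/to/image.jpg",  # anything load_image accepts
        prediction={"predictions": []},       # a model's raw prediction dict
        prediction_type=OBJECT_DETECTION_TASK,
    )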
inference/core/active_learning/post_processing.py
ADDED
@@ -0,0 +1,128 @@
import json
from typing import List, Tuple

from inference.core.active_learning.entities import (
    Prediction,
    PredictionFileType,
    PredictionType,
    SerialisedPrediction,
)
from inference.core.constants import (
    CLASSIFICATION_TASK,
    INSTANCE_SEGMENTATION_TASK,
    OBJECT_DETECTION_TASK,
)
from inference.core.exceptions import PredictionFormatNotSupported


def adjust_prediction_to_client_scaling_factor(
    prediction: dict, scaling_factor: float, prediction_type: PredictionType
) -> dict:
    if abs(scaling_factor - 1.0) < 1e-5:
        return prediction
    if "image" in prediction:
        prediction["image"] = {
            "width": round(prediction["image"]["width"] / scaling_factor),
            "height": round(prediction["image"]["height"] / scaling_factor),
        }
    if predictions_should_not_be_post_processed(
        prediction=prediction, prediction_type=prediction_type
    ):
        return prediction
    if prediction_type == INSTANCE_SEGMENTATION_TASK:
        prediction["predictions"] = (
            adjust_prediction_with_bbox_and_points_to_client_scaling_factor(
                predictions=prediction["predictions"],
                scaling_factor=scaling_factor,
                points_key="points",
            )
        )
    if prediction_type == OBJECT_DETECTION_TASK:
        prediction["predictions"] = (
            adjust_object_detection_predictions_to_client_scaling_factor(
                predictions=prediction["predictions"],
                scaling_factor=scaling_factor,
            )
        )
    return prediction


def predictions_should_not_be_post_processed(
    prediction: dict, prediction_type: PredictionType
) -> bool:
    # excluding from post-processing: classification output, stub output and empty predictions
    return (
        "is_stub" in prediction
        or "predictions" not in prediction
        or CLASSIFICATION_TASK in prediction_type
        or len(prediction["predictions"]) == 0
    )


def adjust_object_detection_predictions_to_client_scaling_factor(
    predictions: List[dict],
    scaling_factor: float,
) -> List[dict]:
    result = []
    for prediction in predictions:
        prediction = adjust_bbox_coordinates_to_client_scaling_factor(
            bbox=prediction,
            scaling_factor=scaling_factor,
        )
        result.append(prediction)
    return result


def adjust_prediction_with_bbox_and_points_to_client_scaling_factor(
    predictions: List[dict],
    scaling_factor: float,
    points_key: str,
) -> List[dict]:
    result = []
    for prediction in predictions:
        prediction = adjust_bbox_coordinates_to_client_scaling_factor(
            bbox=prediction,
            scaling_factor=scaling_factor,
        )
        prediction[points_key] = adjust_points_coordinates_to_client_scaling_factor(
            points=prediction[points_key],
            scaling_factor=scaling_factor,
        )
        result.append(prediction)
    return result


def adjust_bbox_coordinates_to_client_scaling_factor(
    bbox: dict,
    scaling_factor: float,
) -> dict:
    bbox["x"] = bbox["x"] / scaling_factor
    bbox["y"] = bbox["y"] / scaling_factor
    bbox["width"] = bbox["width"] / scaling_factor
    bbox["height"] = bbox["height"] / scaling_factor
    return bbox


def adjust_points_coordinates_to_client_scaling_factor(
    points: List[dict],
    scaling_factor: float,
) -> List[dict]:
    result = []
    for point in points:
        point["x"] = point["x"] / scaling_factor
        point["y"] = point["y"] / scaling_factor
        result.append(point)
    return result


def encode_prediction(
    prediction: Prediction,
    prediction_type: PredictionType,
) -> Tuple[SerialisedPrediction, PredictionFileType]:
    if CLASSIFICATION_TASK not in prediction_type:
        return json.dumps(prediction), "json"
    if "top" in prediction:
        return prediction["top"], "txt"
    raise PredictionFormatNotSupported(
        "Prediction type or prediction format not supported."
    )
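
As an illustration of the rescaling above (toy numbers, not taken from this upload): with scaling_factor=2.0, image dimensions and box coordinates are divided by the factor to map the prediction back to the client's original resolution:

from inference.core.active_learning.post_processing import (
    adjust_prediction_to_client_scaling_factor,
)
from inference.core.constants import OBJECT_DETECTION_TASK

prediction = {
    "image": {"width": 640, "height": 480},
    "predictions": [
        {"x": 100.0, "y": 80.0, "width": 50.0, "height": 40.0,
         "confidence": 0.9, "class": "dog"},  # toy detection
    ],
}
adjusted = adjust_prediction_to_client_scaling_factor(
    prediction=prediction,
    scaling_factor=2.0,  # inference ran on an image the client upscaled 2x
    prediction_type=OBJECT_DETECTION_TASK,
)
# adjusted["image"] == {"width": 320, "height": 240}
# the box becomes x=50.0, y=40.0, width=25.0, height=20.0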
inference/core/active_learning/samplers/__init__.py
ADDED
File without changes

inference/core/active_learning/samplers/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (201 Bytes)

inference/core/active_learning/samplers/__pycache__/close_to_threshold.cpython-310.pyc
ADDED
Binary file (4.68 kB)

inference/core/active_learning/samplers/__pycache__/contains_classes.cpython-310.pyc
ADDED
Binary file (1.71 kB)

inference/core/active_learning/samplers/__pycache__/number_of_detections.cpython-310.pyc
ADDED
Binary file (2.74 kB)

inference/core/active_learning/samplers/__pycache__/random.cpython-310.pyc
ADDED
Binary file (1.22 kB)
inference/core/active_learning/samplers/close_to_threshold.py
ADDED
@@ -0,0 +1,227 @@
import random
from functools import partial
from typing import Any, Dict, Optional, Set

import numpy as np

from inference.core.active_learning.entities import (
    Prediction,
    PredictionType,
    SamplingMethod,
)
from inference.core.constants import (
    CLASSIFICATION_TASK,
    INSTANCE_SEGMENTATION_TASK,
    KEYPOINTS_DETECTION_TASK,
    OBJECT_DETECTION_TASK,
)
from inference.core.exceptions import ActiveLearningConfigurationError

ELIGIBLE_PREDICTION_TYPES = {
    CLASSIFICATION_TASK,
    INSTANCE_SEGMENTATION_TASK,
    KEYPOINTS_DETECTION_TASK,
    OBJECT_DETECTION_TASK,
}


def initialize_close_to_threshold_sampling(
    strategy_config: Dict[str, Any]
) -> SamplingMethod:
    try:
        selected_class_names = strategy_config.get("selected_class_names")
        if selected_class_names is not None:
            selected_class_names = set(selected_class_names)
        sample_function = partial(
            sample_close_to_threshold,
            selected_class_names=selected_class_names,
            threshold=strategy_config["threshold"],
            epsilon=strategy_config["epsilon"],
            only_top_classes=strategy_config.get("only_top_classes", True),
            minimum_objects_close_to_threshold=strategy_config.get(
                "minimum_objects_close_to_threshold",
                1,
            ),
            probability=strategy_config["probability"],
        )
        return SamplingMethod(
            name=strategy_config["name"],
            sample=sample_function,
        )
    except KeyError as error:
        raise ActiveLearningConfigurationError(
            f"In configuration of `close_to_threshold_sampling` missing key detected: {error}."
        ) from error


def sample_close_to_threshold(
    image: np.ndarray,
    prediction: Prediction,
    prediction_type: PredictionType,
    selected_class_names: Optional[Set[str]],
    threshold: float,
    epsilon: float,
    only_top_classes: bool,
    minimum_objects_close_to_threshold: int,
    probability: float,
) -> bool:
    if is_prediction_a_stub(prediction=prediction):
        return False
    if prediction_type not in ELIGIBLE_PREDICTION_TYPES:
        return False
    close_to_threshold = prediction_is_close_to_threshold(
        prediction=prediction,
        prediction_type=prediction_type,
        selected_class_names=selected_class_names,
        threshold=threshold,
        epsilon=epsilon,
        only_top_classes=only_top_classes,
        minimum_objects_close_to_threshold=minimum_objects_close_to_threshold,
    )
    if not close_to_threshold:
        return False
    return random.random() < probability


def is_prediction_a_stub(prediction: Prediction) -> bool:
    return prediction.get("is_stub", False)


def prediction_is_close_to_threshold(
    prediction: Prediction,
    prediction_type: PredictionType,
    selected_class_names: Optional[Set[str]],
    threshold: float,
    epsilon: float,
    only_top_classes: bool,
    minimum_objects_close_to_threshold: int,
) -> bool:
    if CLASSIFICATION_TASK not in prediction_type:
        return detections_are_close_to_threshold(
            prediction=prediction,
            selected_class_names=selected_class_names,
            threshold=threshold,
            epsilon=epsilon,
            minimum_objects_close_to_threshold=minimum_objects_close_to_threshold,
        )
    checker = multi_label_classification_prediction_is_close_to_threshold
    if "top" in prediction:
        checker = multi_class_classification_prediction_is_close_to_threshold
    return checker(
        prediction=prediction,
        selected_class_names=selected_class_names,
        threshold=threshold,
        epsilon=epsilon,
        only_top_classes=only_top_classes,
    )


def multi_class_classification_prediction_is_close_to_threshold(
    prediction: Prediction,
    selected_class_names: Optional[Set[str]],
    threshold: float,
    epsilon: float,
    only_top_classes: bool,
) -> bool:
    if only_top_classes:
        return (
            multi_class_classification_prediction_is_close_to_threshold_for_top_class(
                prediction=prediction,
                selected_class_names=selected_class_names,
                threshold=threshold,
                epsilon=epsilon,
            )
        )
    for prediction_details in prediction["predictions"]:
        if class_to_be_excluded(
            class_name=prediction_details["class"],
            selected_class_names=selected_class_names,
        ):
            continue
        if is_close_to_threshold(
            value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon
        ):
            return True
    return False


def multi_class_classification_prediction_is_close_to_threshold_for_top_class(
    prediction: Prediction,
    selected_class_names: Optional[Set[str]],
    threshold: float,
    epsilon: float,
) -> bool:
    if (
        selected_class_names is not None
        and prediction["top"] not in selected_class_names
    ):
        return False
    return abs(prediction["confidence"] - threshold) < epsilon


def multi_label_classification_prediction_is_close_to_threshold(
    prediction: Prediction,
    selected_class_names: Optional[Set[str]],
    threshold: float,
    epsilon: float,
    only_top_classes: bool,
) -> bool:
    predicted_classes = set(prediction["predicted_classes"])
    for class_name, prediction_details in prediction["predictions"].items():
        if only_top_classes and class_name not in predicted_classes:
            continue
        if class_to_be_excluded(
            class_name=class_name, selected_class_names=selected_class_names
        ):
            continue
        if is_close_to_threshold(
            value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon
        ):
            return True
    return False


def detections_are_close_to_threshold(
    prediction: Prediction,
    selected_class_names: Optional[Set[str]],
    threshold: float,
    epsilon: float,
    minimum_objects_close_to_threshold: int,
) -> bool:
    detections_close_to_threshold = count_detections_close_to_threshold(
        prediction=prediction,
        selected_class_names=selected_class_names,
        threshold=threshold,
        epsilon=epsilon,
    )
    return detections_close_to_threshold >= minimum_objects_close_to_threshold


def count_detections_close_to_threshold(
    prediction: Prediction,
    selected_class_names: Optional[Set[str]],
    threshold: float,
    epsilon: float,
) -> int:
    counter = 0
    for prediction_details in prediction["predictions"]:
        if class_to_be_excluded(
            class_name=prediction_details["class"],
            selected_class_names=selected_class_names,
        ):
            continue
        if is_close_to_threshold(
            value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon
        ):
            counter += 1
    return counter


def class_to_be_excluded(
    class_name: str, selected_class_names: Optional[Set[str]]
) -> bool:
    return selected_class_names is not None and class_name not in selected_class_names


def is_close_to_threshold(value: float, threshold: float, epsilon: float) -> bool:
    return abs(value - threshold) < epsilon
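
A sketch of a strategy_config accepted by initialize_close_to_threshold_sampling, based only on the keys read above; the name and numeric values are illustrative, not from this upload:

strategy_config = {
    "name": "uncertain_detections",           # illustrative strategy name
    "selected_class_names": ["dog", "cat"],   # optional; omit to consider all classes
    "threshold": 0.5,
    "epsilon": 0.25,                          # accepts confidences in (0.25, 0.75)
    "only_top_classes": True,                 # optional, defaults to True
    "minimum_objects_close_to_threshold": 1,  # optional, defaults to 1
    "probability": 0.2,                       # of eligible datapoints, sample ~20% at random
}
method = initialize_close_to_threshold_sampling(strategy_config)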
inference/core/active_learning/samplers/contains_classes.py
ADDED
@@ -0,0 +1,58 @@
from functools import partial
from typing import Any, Dict, Set

import numpy as np

from inference.core.active_learning.entities import (
    Prediction,
    PredictionType,
    SamplingMethod,
)
from inference.core.active_learning.samplers.close_to_threshold import (
    sample_close_to_threshold,
)
from inference.core.constants import CLASSIFICATION_TASK
from inference.core.exceptions import ActiveLearningConfigurationError

ELIGIBLE_PREDICTION_TYPES = {CLASSIFICATION_TASK}


def initialize_classes_based_sampling(
    strategy_config: Dict[str, Any]
) -> SamplingMethod:
    try:
        sample_function = partial(
            sample_based_on_classes,
            selected_class_names=set(strategy_config["selected_class_names"]),
            probability=strategy_config["probability"],
        )
        return SamplingMethod(
            name=strategy_config["name"],
            sample=sample_function,
        )
    except KeyError as error:
        raise ActiveLearningConfigurationError(
            f"In configuration of `classes_based_sampling` missing key detected: {error}."
        ) from error


def sample_based_on_classes(
    image: np.ndarray,
    prediction: Prediction,
    prediction_type: PredictionType,
    selected_class_names: Set[str],
    probability: float,
) -> bool:
    if prediction_type not in ELIGIBLE_PREDICTION_TYPES:
        return False
    return sample_close_to_threshold(
        image=image,
        prediction=prediction,
        prediction_type=prediction_type,
        selected_class_names=selected_class_names,
        threshold=0.5,
        epsilon=1.0,
        only_top_classes=True,
        minimum_objects_close_to_threshold=1,
        probability=probability,
    )
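
Note how the sampler above reuses sample_close_to_threshold with threshold=0.5 and epsilon=1.0: every confidence in [0, 1] then counts as "close to threshold", so the check effectively reduces to "the top class is one of the selected classes". An illustrative config (values invented):

strategy_config = {
    "name": "contains_selected_classes",     # illustrative
    "selected_class_names": ["cat", "dog"],  # required for this strategy
    "probability": 0.1,
}
method = initialize_classes_based_sampling(strategy_config)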
inference/core/active_learning/samplers/number_of_detections.py
ADDED
@@ -0,0 +1,107 @@
import random
from functools import partial
from typing import Any, Dict, Optional, Set

import numpy as np

from inference.core.active_learning.entities import (
    Prediction,
    PredictionType,
    SamplingMethod,
)
from inference.core.active_learning.samplers.close_to_threshold import (
    count_detections_close_to_threshold,
    is_prediction_a_stub,
)
from inference.core.constants import (
    INSTANCE_SEGMENTATION_TASK,
    KEYPOINTS_DETECTION_TASK,
    OBJECT_DETECTION_TASK,
)
from inference.core.exceptions import ActiveLearningConfigurationError

ELIGIBLE_PREDICTION_TYPES = {
    INSTANCE_SEGMENTATION_TASK,
    KEYPOINTS_DETECTION_TASK,
    OBJECT_DETECTION_TASK,
}


def initialize_detections_number_based_sampling(
    strategy_config: Dict[str, Any]
) -> SamplingMethod:
    try:
        more_than = strategy_config.get("more_than")
        less_than = strategy_config.get("less_than")
        ensure_range_configuration_is_valid(more_than=more_than, less_than=less_than)
        selected_class_names = strategy_config.get("selected_class_names")
        if selected_class_names is not None:
            selected_class_names = set(selected_class_names)
        sample_function = partial(
            sample_based_on_detections_number,
            less_than=less_than,
            more_than=more_than,
            selected_class_names=selected_class_names,
            probability=strategy_config["probability"],
        )
        return SamplingMethod(
            name=strategy_config["name"],
            sample=sample_function,
        )
    except KeyError as error:
        raise ActiveLearningConfigurationError(
            f"In configuration of `detections_number_based_sampling` missing key detected: {error}."
        ) from error


def sample_based_on_detections_number(
    image: np.ndarray,
    prediction: Prediction,
    prediction_type: PredictionType,
    more_than: Optional[int],
    less_than: Optional[int],
    selected_class_names: Optional[Set[str]],
    probability: float,
) -> bool:
    if is_prediction_a_stub(prediction=prediction):
        return False
    if prediction_type not in ELIGIBLE_PREDICTION_TYPES:
        return False
    detections_close_to_threshold = count_detections_close_to_threshold(
        prediction=prediction,
        selected_class_names=selected_class_names,
        threshold=0.5,
        epsilon=1.0,
    )
    if is_in_range(
        value=detections_close_to_threshold, less_than=less_than, more_than=more_than
    ):
        return random.random() < probability
    return False


def is_in_range(
    value: int,
    more_than: Optional[int],
    less_than: Optional[int],
) -> bool:
    # checks more_than < value < less_than, where either bound may be omitted (None)
    less_than_satisfied, more_than_satisfied = less_than is None, more_than is None
    if less_than is not None and value < less_than:
        less_than_satisfied = True
    if more_than is not None and value > more_than:
        more_than_satisfied = True
    return less_than_satisfied and more_than_satisfied


def ensure_range_configuration_is_valid(
    more_than: Optional[int],
    less_than: Optional[int],
) -> None:
    if more_than is None or less_than is None:
        return None
    if more_than >= less_than:
        raise ActiveLearningConfigurationError(
            f"Misconfiguration of detections number sampling: "
            f"`more_than` parameter ({more_than}) >= `less_than` ({less_than})."
        )
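
An illustrative config for the detections-count strategy above; both bounds are strict and optional, as implemented in is_in_range (values invented):

strategy_config = {
    "name": "crowded_scenes",      # illustrative
    "more_than": 3,                # optional; strict, so 4+ detections qualify
    "less_than": 20,               # optional; strict upper bound
    "selected_class_names": None,  # optional; None counts detections of any class
    "probability": 0.05,
}
method = initialize_detections_number_based_sampling(strategy_config)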
inference/core/active_learning/samplers/random.py
ADDED
@@ -0,0 +1,37 @@
import random
from functools import partial
from typing import Any, Dict

import numpy as np

from inference.core.active_learning.entities import (
    Prediction,
    PredictionType,
    SamplingMethod,
)
from inference.core.exceptions import ActiveLearningConfigurationError


def initialize_random_sampling(strategy_config: Dict[str, Any]) -> SamplingMethod:
    try:
        sample_function = partial(
            sample_randomly,
            traffic_percentage=strategy_config["traffic_percentage"],
        )
        return SamplingMethod(
            name=strategy_config["name"],
            sample=sample_function,
        )
    except KeyError as error:
        raise ActiveLearningConfigurationError(
            f"In configuration of `random_sampling` missing key detected: {error}."
        ) from error


def sample_randomly(
    image: np.ndarray,
    prediction: Prediction,
    prediction_type: PredictionType,
    traffic_percentage: float,
) -> bool:
    return random.random() < traffic_percentage
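
An illustrative config for the random sampler; traffic_percentage is compared directly against random.random(), so 0.01 samples roughly 1% of predictions (values invented):

strategy_config = {
    "name": "baseline_random",  # illustrative
    "traffic_percentage": 0.01,
}
method = initialize_random_sampling(strategy_config)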
inference/core/active_learning/utils.py
ADDED
@@ -0,0 +1,16 @@
from datetime import datetime, timedelta

TIMESTAMP_FORMAT = "%Y_%m_%d"


def generate_today_timestamp() -> str:
    return datetime.today().strftime(TIMESTAMP_FORMAT)


def generate_start_timestamp_for_this_week() -> str:
    today = datetime.today()
    return (today - timedelta(days=today.weekday())).strftime(TIMESTAMP_FORMAT)


def generate_start_timestamp_for_this_month() -> str:
    return datetime.today().replace(day=1).strftime(TIMESTAMP_FORMAT)
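
For example, if today were Wednesday 2024-01-17, the helpers above would return:

generate_today_timestamp()                 # "2024_01_17"
generate_start_timestamp_for_this_week()   # "2024_01_15" (the Monday of that week)
generate_start_timestamp_for_this_month()  # "2024_01_01"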
inference/core/cache/__init__.py
ADDED
@@ -0,0 +1,22 @@
from redis.exceptions import ConnectionError, TimeoutError

from inference.core import logger
from inference.core.cache.memory import MemoryCache
from inference.core.cache.redis import RedisCache
from inference.core.env import REDIS_HOST, REDIS_PORT, REDIS_SSL, REDIS_TIMEOUT

if REDIS_HOST is not None:
    try:
        cache = RedisCache(
            host=REDIS_HOST, port=REDIS_PORT, ssl=REDIS_SSL, timeout=REDIS_TIMEOUT
        )
        logger.info("Redis Cache initialised")
    except (TimeoutError, ConnectionError):
        logger.error(
            f"Could not connect to Redis under {REDIS_HOST}:{REDIS_PORT}. MemoryCache to be used."
        )
        cache = MemoryCache()
        logger.info("Memory Cache initialised")
else:
    cache = MemoryCache()
    logger.info("Memory Cache initialised")
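
The module-level cache singleton can then be imported anywhere in the package; a minimal sketch, assuming the MemoryCache fallback implements the BaseCache get/set interface below (key and value are placeholders):

from inference.core.cache import cache  # RedisCache if REDIS_HOST is set, else MemoryCache

cache.set("example:key", "example-value", expire=60.0)  # placeholder key/value
value = cache.get("example:key")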
inference/core/cache/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (864 Bytes)

inference/core/cache/__pycache__/base.cpython-310.pyc
ADDED
Binary file (4.93 kB)

inference/core/cache/__pycache__/memory.cpython-310.pyc
ADDED
Binary file (6.56 kB)

inference/core/cache/__pycache__/model_artifacts.cpython-310.pyc
ADDED
Binary file (3.17 kB)

inference/core/cache/__pycache__/redis.cpython-310.pyc
ADDED
Binary file (7.3 kB)

inference/core/cache/__pycache__/serializers.cpython-310.pyc
ADDED
Binary file (1.91 kB)
inference/core/cache/base.py
ADDED
@@ -0,0 +1,130 @@
from contextlib import contextmanager
from typing import Any, Optional

from inference.core import logger


class BaseCache:
    """
    BaseCache is an abstract base class that defines the interface for a cache.
    """

    def get(self, key: str):
        """
        Gets the value associated with the given key.

        Args:
            key (str): The key to retrieve the value.

        Raises:
            NotImplementedError: This method must be implemented by subclasses.
        """
        raise NotImplementedError()

    def set(self, key: str, value: str, expire: float = None):
        """
        Sets a value for a given key with an optional expire time.

        Args:
            key (str): The key to store the value.
            value (str): The value to store.
            expire (float, optional): The time, in seconds, after which the key will expire. Defaults to None.

        Raises:
            NotImplementedError: This method must be implemented by subclasses.
        """
        raise NotImplementedError()

    def zadd(self, key: str, value: str, score: float, expire: float = None):
        """
        Adds a member with the specified score to the sorted set stored at key.

        Args:
            key (str): The key of the sorted set.
            value (str): The value to add to the sorted set.
            score (float): The score associated with the value.
            expire (float, optional): The time, in seconds, after which the key will expire. Defaults to None.

        Raises:
            NotImplementedError: This method must be implemented by subclasses.
        """
        raise NotImplementedError()

    def zrangebyscore(
        self,
        key: str,
        min: Optional[float] = -1,
        max: Optional[float] = float("inf"),
        withscores: bool = False,
    ):
        """
        Retrieves a range of members from a sorted set.

        Args:
            key (str): The key of the sorted set.
            min (float, optional): The minimum score of the range. Defaults to -1.
            max (float, optional): The maximum score of the range. Defaults to float("inf").
            withscores (bool, optional): Whether to return the scores along with the values. Defaults to False.

        Raises:
            NotImplementedError: This method must be implemented by subclasses.
        """
        raise NotImplementedError()

    def zremrangebyscore(
        self,
        key: str,
        start: Optional[int] = -1,
        stop: Optional[int] = float("inf"),
    ):
        """
        Removes all members in a sorted set within the given scores.

        Args:
            key (str): The key of the sorted set.
            start (int, optional): The minimum score of the range. Defaults to -1.
            stop (int, optional): The maximum score of the range. Defaults to float("inf").

        Raises:
            NotImplementedError: This method must be implemented by subclasses.
        """
        raise NotImplementedError()

    def acquire_lock(self, key: str, expire: float = None) -> Any:
        raise NotImplementedError()

    @contextmanager
    def lock(self, key: str, expire: float = None) -> Any:
        logger.debug(f"Acquiring lock at cache key: {key}")
        l = self.acquire_lock(key, expire=expire)
        try:
            yield l
        finally:
            logger.debug(f"Releasing lock at cache key: {key}")
            l.release()

    def set_numpy(self, key: str, value: Any, expire: float = None):
        """
        Caches a numpy array.

        Args:
            key (str): The key to store the value.
            value (Any): The value to store.
            expire (float, optional): The time, in seconds, after which the key will expire. Defaults to None.

        Raises:
            NotImplementedError: This method must be implemented by subclasses.
        """
        raise NotImplementedError()

    def get_numpy(self, key: str) -> Any:
        """
        Retrieves a numpy array from the cache.

        Args:
            key (str): The key of the value to retrieve.

        Raises:
            NotImplementedError: This method must be implemented by subclasses.
        """
        raise NotImplementedError()
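
A toy subclass sketch showing how the interface above is meant to be filled in; this dict-backed cache is illustrative only (it implements just get/set with lazy expiry and is not part of this upload):

import time

from inference.core.cache.base import BaseCache


class DictCache(BaseCache):
    """Illustrative BaseCache subclass backed by a plain dict."""

    def __init__(self):
        self._store = {}  # key -> (value, expiry unix timestamp or None)

    def set(self, key: str, value: str, expire: float = None):
        deadline = time.time() + expire if expire is not None else None
        self._store[key] = (value, deadline)

    def get(self, key: str):
        entry = self._store.get(key)
        if entry is None:
            return None
        value, deadline = entry
        if deadline is not None and time.time() > deadline:
            del self._store[key]  # lazily evict expired entries
            return None
        return value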