import itertools
import json
import os
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from time import perf_counter
from typing import Any, Dict, List, Optional, Tuple, Union

import cv2
import numpy as np
import onnxruntime
from PIL import Image

from inference.core.cache import cache
from inference.core.cache.model_artifacts import (
    are_all_files_cached,
    clear_cache,
    get_cache_dir,
    get_cache_file_path,
    initialise_cache,
    load_json_from_cache,
    load_text_file_from_cache,
    save_bytes_in_cache,
    save_json_in_cache,
    save_text_lines_in_cache,
)
from inference.core.devices.utils import GLOBAL_DEVICE_ID
from inference.core.entities.requests.inference import (
    InferenceRequest,
    InferenceRequestImage,
)
from inference.core.entities.responses.inference import InferenceResponse
from inference.core.env import (
    API_KEY,
    API_KEY_ENV_NAMES,
    AWS_ACCESS_KEY_ID,
    AWS_SECRET_ACCESS_KEY,
    CORE_MODEL_BUCKET,
    DISABLE_PREPROC_AUTO_ORIENT,
    INFER_BUCKET,
    LAMBDA,
    MAX_BATCH_SIZE,
    MODEL_CACHE_DIR,
    ONNXRUNTIME_EXECUTION_PROVIDERS,
    REQUIRED_ONNX_PROVIDERS,
    TENSORRT_CACHE_PATH,
)
from inference.core.exceptions import (
    MissingApiKeyError,
    ModelArtefactError,
    OnnxProviderNotAvailable,
)
from inference.core.logger import logger
from inference.core.models.base import Model
from inference.core.models.utils.batching import (
    calculate_input_elements,
    create_batches,
)
from inference.core.roboflow_api import (
    ModelEndpointType,
    get_from_url,
    get_roboflow_model_data,
)
from inference.core.utils.image_utils import load_image
from inference.core.utils.onnx import get_onnxruntime_execution_providers
from inference.core.utils.preprocess import letterbox_image, prepare
from inference.core.utils.visualisation import draw_detection_predictions
from inference.models.aliases import resolve_roboflow_model_alias

NUM_S3_RETRY = 5
SLEEP_SECONDS_BETWEEN_RETRIES = 3
MODEL_METADATA_CACHE_EXPIRATION_TIMEOUT = 3600  # 1 hour

# The S3 client is only created when BOTH AWS credentials are configured.
# `download_s3_files_to_directory` is imported in the same guarded block, so it
# is only defined when the client could be created.
S3_CLIENT = None
if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY:
    try:
        import boto3
        from botocore.config import Config

        from inference.core.utils.s3 import download_s3_files_to_directory

        config = Config(retries={"max_attempts": NUM_S3_RETRY, "mode": "standard"})
        S3_CLIENT = boto3.client("s3", config=config)
    except Exception as error:
        # Best effort: without a client, artefacts are fetched from the Roboflow API.
        logger.debug(f"Error loading boto3: {error}")

DEFAULT_COLOR_PALETTE = [
    "#4892EA",
    "#00EEC3",
    "#FE4EF0",
    "#F4004E",
    "#FA7200",
    "#EEEE17",
    "#90FF00",
    "#78C1D2",
    "#8C29FF",
]


class RoboflowInferenceModel(Model):
    """Base Roboflow inference model."""

    def __init__(
        self,
        model_id: str,
        cache_dir_root=MODEL_CACHE_DIR,
        api_key=None,
        load_weights=True,
    ):
        """
        Initialize the RoboflowInferenceModel object.

        Args:
            model_id (str): The unique identifier for the model, in the form
                "<dataset_id>/<version_id>" (after alias resolution).
            cache_dir_root (str, optional): The root directory for the cache. Defaults to MODEL_CACHE_DIR.
            api_key (str, optional): API key for authentication. Falls back to API_KEY when not given.
            load_weights (bool, optional): Stored on the instance; subclasses use it to decide
                whether model weights are kept loaded. Defaults to True.
        """
        super().__init__()
        self.load_weights = load_weights
        self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}
        self.api_key = api_key if api_key else API_KEY
        model_id = resolve_roboflow_model_alias(model_id=model_id)
        self.dataset_id, self.version_id = model_id.split("/")
        self.endpoint = model_id
        self.device_id = GLOBAL_DEVICE_ID
        self.cache_dir = os.path.join(cache_dir_root, self.endpoint)
        self.keypoints_metadata: Optional[dict] = None
        initialise_cache(model_id=self.endpoint)

    def cache_file(self, f: str) -> str:
        """Get the cache file path for a given file.

        Args:
            f (str): Filename.

        Returns:
            str: Full path to the cached file.
        """
        return get_cache_file_path(file=f, model_id=self.endpoint)

    def clear_cache(self) -> None:
        """Clear the cache directory."""
        clear_cache(model_id=self.endpoint)

    def draw_predictions(
        self,
        inference_request: InferenceRequest,
        inference_response: InferenceResponse,
    ) -> bytes:
        """Draw predictions from an inference response onto the original image provided by an inference request

        Args:
            inference_request (ObjectDetectionInferenceRequest): The inference request containing the image on which to draw predictions
            inference_response (ObjectDetectionInferenceResponse): The inference response containing predictions to be drawn

        Returns:
            bytes: The rendered image, exactly as returned by `draw_detection_predictions`.
        """
        return draw_detection_predictions(
            inference_request=inference_request,
            inference_response=inference_response,
            colors=self.colors,
        )

    @property
    def get_class_names(self):
        """Class names loaded for this model (property kept under its original name)."""
        return self.class_names

    def get_device_id(self) -> str:
        """
        Get the device identifier on which the model is deployed.

        Returns:
            str: Device identifier.
        """
        return self.device_id

    def get_infer_bucket_file_list(self) -> List[str]:
        """Get a list of inference bucket files.

        Raises:
            NotImplementedError: If the method is not implemented.

        Returns:
            List[str]: A list of inference bucket files.
        """
        raise NotImplementedError(
            self.__class__.__name__ + ".get_infer_bucket_file_list"
        )

    @property
    def cache_key(self):
        """Key under which this model's metadata is stored in the shared cache."""
        return f"metadata:{self.endpoint}"

    @staticmethod
    def model_metadata_from_memcache_endpoint(endpoint):
        """Read model metadata from the shared cache for an arbitrary endpoint."""
        model_metadata = cache.get(f"metadata:{endpoint}")
        return model_metadata

    def model_metadata_from_memcache(self):
        """Read this model's metadata from the shared cache (as returned by `cache.get`)."""
        model_metadata = cache.get(self.cache_key)
        return model_metadata

    def write_model_metadata_to_memcache(self, metadata):
        """Store this model's metadata in the shared cache with the module-level expiry."""
        cache.set(
            self.cache_key, metadata, expire=MODEL_METADATA_CACHE_EXPIRATION_TIMEOUT
        )

    @property
    def has_model_metadata(self):
        """True when metadata for this model is currently present in the shared cache."""
        return self.model_metadata_from_memcache() is not None

    def get_model_artifacts(self) -> None:
        """Fetch or load the model artifacts.

        Downloads the model artifacts from S3 or the Roboflow API if they are not already cached.
        """
        self.cache_model_artefacts()
        self.load_model_artifacts_from_cache()

    def cache_model_artefacts(self) -> None:
        """Ensure all required artefacts are cached, downloading them if needed.

        The S3 bucket is used when it is available; otherwise the Roboflow API.
        """
        infer_bucket_files = self.get_all_required_infer_bucket_file()
        if are_all_files_cached(files=infer_bucket_files, model_id=self.endpoint):
            return None
        if is_model_artefacts_bucket_available():
            self.download_model_artefacts_from_s3()
            return None
        self.download_model_artifacts_from_roboflow_api()

    def get_all_required_infer_bucket_file(self) -> List[str]:
        """List every file needed to load the model (bucket files plus weights).

        Returns:
            List[str]: The required filenames, with `None` entries removed
                (`weights_file` may be `None` for some subclasses).
        """
        # Build a new list instead of appending to the one returned by the
        # subclass, so a shared/cached list is never mutated.
        infer_bucket_files = [*self.get_infer_bucket_file_list(), self.weights_file]
        logger.debug(f"List of files required to load model: {infer_bucket_files}")
        return [f for f in infer_bucket_files if f is not None]

    def download_model_artefacts_from_s3(self) -> None:
        """Download all required artefacts from the model artefact bucket into the cache.

        Raises:
            ModelArtefactError: If anything goes wrong while resolving or downloading the files.
        """
        # Initialised up-front so the error message below can always be built,
        # even when the failure happens before the keys are computed.
        s3_keys: List[str] = []
        try:
            logger.debug("Downloading model artifacts from S3")
            infer_bucket_files = self.get_all_required_infer_bucket_file()
            cache_directory = get_cache_dir()
            s3_keys = [f"{self.endpoint}/{file}" for file in infer_bucket_files]
            download_s3_files_to_directory(
                bucket=self.model_artifact_bucket,
                keys=s3_keys,
                target_dir=cache_directory,
                s3_client=S3_CLIENT,
            )
        except Exception as error:
            raise ModelArtefactError(
                f"Could not obtain model artefacts from S3 with keys {s3_keys}. Cause: {error}"
            ) from error

    @property
    def model_artifact_bucket(self):
        """Bucket from which artefacts for this model type are downloaded."""
        return INFER_BUCKET

    def download_model_artifacts_from_roboflow_api(self) -> None:
        """Download artefacts via the Roboflow API and store them in the cache.

        Saves class names (when provided), the weights file, the environment
        (with `COLORS` merged in when provided) and keypoints metadata (when provided).

        Raises:
            ModelArtefactError: If the API response lacks the `ort`, `model` or `environment` key.
        """
        logger.debug("Downloading model artifacts from Roboflow API")
        api_data = get_roboflow_model_data(
            api_key=self.api_key,
            model_id=self.endpoint,
            endpoint_type=ModelEndpointType.ORT,
            device_id=self.device_id,
        )
        if "ort" not in api_data:
            raise ModelArtefactError(
                "Could not find `ort` key in roboflow API model description response."
            )
        api_data = api_data["ort"]
        if "classes" in api_data:
            save_text_lines_in_cache(
                content=api_data["classes"],
                file="class_names.txt",
                model_id=self.endpoint,
            )
        if "model" not in api_data:
            raise ModelArtefactError(
                "Could not find `model` key in roboflow API model description response."
            )
        if "environment" not in api_data:
            raise ModelArtefactError(
                "Could not find `environment` key in roboflow API model description response."
            )
        environment = get_from_url(api_data["environment"])
        model_weights_response = get_from_url(api_data["model"], json_response=False)
        save_bytes_in_cache(
            content=model_weights_response.content,
            file=self.weights_file,
            model_id=self.endpoint,
        )
        if "colors" in api_data:
            environment["COLORS"] = api_data["colors"]
        save_json_in_cache(
            content=environment,
            file="environment.json",
            model_id=self.endpoint,
        )
        if "keypoints_metadata" in api_data:
            # TODO: make sure backend provides that
            save_json_in_cache(
                content=api_data["keypoints_metadata"],
                file="keypoints_metadata.json",
                model_id=self.endpoint,
            )

    def load_model_artifacts_from_cache(self) -> None:
        """Load environment, class names, colors and preprocessing settings from the cache.

        Sets `environment`, `class_names`, `colors`, `keypoints_metadata` (when listed),
        `num_classes`, `preproc`, `resize_method` and `multiclass` on the instance.

        Raises:
            ModelArtefactError: If the environment has no `PREPROCESSING` key, or
                class names cannot be derived from the environment.
        """
        logger.debug("Model artifacts already downloaded, loading model from cache")
        infer_bucket_files = self.get_all_required_infer_bucket_file()
        if "environment.json" in infer_bucket_files:
            self.environment = load_json_from_cache(
                file="environment.json",
                model_id=self.endpoint,
                object_pairs_hook=OrderedDict,
            )
        if "class_names.txt" in infer_bucket_files:
            self.class_names = load_text_file_from_cache(
                file="class_names.txt",
                model_id=self.endpoint,
                split_lines=True,
                strip_white_chars=True,
            )
        else:
            self.class_names = get_class_names_from_environment_file(
                environment=self.environment
            )
        self.colors = get_color_mapping_from_environment(
            environment=self.environment,
            class_names=self.class_names,
        )
        if "keypoints_metadata.json" in infer_bucket_files:
            self.keypoints_metadata = parse_keypoints_metadata(
                load_json_from_cache(
                    file="keypoints_metadata.json",
                    model_id=self.endpoint,
                    object_pairs_hook=OrderedDict,
                )
            )
        self.num_classes = len(self.class_names)
        if "PREPROCESSING" not in self.environment:
            raise ModelArtefactError(
                "Could not find `PREPROCESSING` key in environment file."
            )
        # PREPROCESSING is stored either as a dict or as a JSON-encoded string.
        preprocessing = self.environment["PREPROCESSING"]
        if isinstance(preprocessing, dict):
            self.preproc = preprocessing
        else:
            self.preproc = json.loads(preprocessing)
        if self.preproc.get("resize"):
            self.resize_method = self.preproc["resize"].get("format", "Stretch to")
            # Unknown resize formats fall back to plain stretching.
            if self.resize_method not in [
                "Stretch to",
                "Fit (black edges) in",
                "Fit (white edges) in",
            ]:
                self.resize_method = "Stretch to"
        else:
            self.resize_method = "Stretch to"
        logger.debug(f"Resize method is '{self.resize_method}'")
        self.multiclass = self.environment.get("MULTICLASS", False)

    def initialize_model(self) -> None:
        """Initialize the model.

        Raises:
            NotImplementedError: If the method is not implemented.
        """
        raise NotImplementedError(self.__class__.__name__ + ".initialize_model")

    def preproc_image(
        self,
        image: Union[Any, InferenceRequestImage],
        disable_preproc_auto_orient: bool = False,
        disable_preproc_contrast: bool = False,
        disable_preproc_grayscale: bool = False,
        disable_preproc_static_crop: bool = False,
    ) -> Tuple[np.ndarray, Tuple[int, int]]:
        """
        Preprocesses an inference request image by loading it, then applying any pre-processing specified by the Roboflow platform, then scaling it to the inference input dimensions.

        Args:
            image (Union[Any, InferenceRequestImage]): An object containing information necessary to load the image for inference.
            disable_preproc_auto_orient (bool, optional): If true, the auto orient preprocessing step is disabled for this call. Default is False.
            disable_preproc_contrast (bool, optional): If true, the contrast preprocessing step is disabled for this call. Default is False.
            disable_preproc_grayscale (bool, optional): If true, the grayscale preprocessing step is disabled for this call. Default is False.
            disable_preproc_static_crop (bool, optional): If true, the static crop preprocessing step is disabled for this call. Default is False.

        Returns:
            Tuple[np.ndarray, Tuple[int, int]]: A tuple containing a numpy array of the preprocessed image pixel data and a tuple of the images original size.
        """
        # Auto-orient is skipped when disabled per call, not configured for the
        # model, or disabled globally through the environment.
        np_image, is_bgr = load_image(
            image,
            disable_preproc_auto_orient=disable_preproc_auto_orient
            or "auto-orient" not in self.preproc.keys()
            or DISABLE_PREPROC_AUTO_ORIENT,
        )
        preprocessed_image, img_dims = self.preprocess_image(
            np_image,
            disable_preproc_contrast=disable_preproc_contrast,
            disable_preproc_grayscale=disable_preproc_grayscale,
            disable_preproc_static_crop=disable_preproc_static_crop,
        )

        if self.resize_method == "Stretch to":
            # NOTE(review): the third positional parameter of cv2.resize is `dst`,
            # not `interpolation`, so this flag may not select cubic interpolation.
            # Left unchanged because switching interpolation alters model inputs —
            # confirm intended behaviour before changing.
            resized = cv2.resize(
                preprocessed_image, (self.img_size_w, self.img_size_h), cv2.INTER_CUBIC
            )
        elif self.resize_method == "Fit (black edges) in":
            resized = letterbox_image(
                preprocessed_image, (self.img_size_w, self.img_size_h)
            )
        elif self.resize_method == "Fit (white edges) in":
            resized = letterbox_image(
                preprocessed_image,
                (self.img_size_w, self.img_size_h),
                color=(255, 255, 255),
            )

        if is_bgr:
            resized = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
        # (H, W, C) -> (1, C, H, W) float32
        img_in = np.transpose(resized, (2, 0, 1))
        img_in = img_in.astype(np.float32)
        img_in = np.expand_dims(img_in, axis=0)

        return img_in, img_dims

    def preprocess_image(
        self,
        image: np.ndarray,
        disable_preproc_contrast: bool = False,
        disable_preproc_grayscale: bool = False,
        disable_preproc_static_crop: bool = False,
    ) -> Tuple[np.ndarray, Tuple[int, int]]:
        """
        Preprocesses the given image using specified preprocessing steps.

        Args:
            image (np.ndarray): The image to preprocess.
            disable_preproc_contrast (bool, optional): If true, the contrast preprocessing step is disabled for this call. Default is False.
            disable_preproc_grayscale (bool, optional): If true, the grayscale preprocessing step is disabled for this call. Default is False.
            disable_preproc_static_crop (bool, optional): If true, the static crop preprocessing step is disabled for this call. Default is False.

        Returns:
            Tuple[np.ndarray, Tuple[int, int]]: The result of `prepare` — the preprocessed image and its dimensions.
        """
        return prepare(
            image,
            self.preproc,
            disable_preproc_contrast=disable_preproc_contrast,
            disable_preproc_grayscale=disable_preproc_grayscale,
            disable_preproc_static_crop=disable_preproc_static_crop,
        )

    @property
    def weights_file(self) -> str:
        """Abstract property representing the file containing the model weights.

        Raises:
            NotImplementedError: This property must be implemented in subclasses.

        Returns:
            str: The file path to the weights file.
        """
        raise NotImplementedError(self.__class__.__name__ + ".weights_file")


class RoboflowCoreModel(RoboflowInferenceModel):
    """Base class for Roboflow core models; downloads its weights on construction."""

    def __init__(
        self,
        model_id: str,
        api_key=None,
    ):
        """Initializes the RoboflowCoreModel instance.

        Args:
            model_id (str): The identifier for the specific model.
            api_key ([type], optional): The API key for authentication. Defaults to None.
        """
        super().__init__(model_id, api_key=api_key)
        self.download_weights()

    def download_weights(self) -> None:
        """Downloads the model weights from the configured source.

        Nothing is downloaded when every file is already cached. Otherwise the S3
        bucket is used when available, falling back to the Roboflow API.
        """
        infer_bucket_files = self.get_infer_bucket_file_list()
        if are_all_files_cached(files=infer_bucket_files, model_id=self.endpoint):
            logger.debug("Model artifacts already downloaded, loading from cache")
            return None
        if is_model_artefacts_bucket_available():
            self.download_model_artefacts_from_s3()
            return None
        self.download_model_from_roboflow_api()

    def download_model_from_roboflow_api(self) -> None:
        """Download every weights file advertised by the Roboflow API into the cache.

        Raises:
            ModelArtefactError: If the API response has no `weights` key.
        """
        api_data = get_roboflow_model_data(
            api_key=self.api_key,
            model_id=self.endpoint,
            endpoint_type=ModelEndpointType.CORE_MODEL,
            device_id=self.device_id,
        )
        if "weights" not in api_data:
            raise ModelArtefactError(
                "`weights` key not available in Roboflow API response while downloading model weights."
            )
        for weights_url_key in api_data["weights"]:
            # Looked up on the *current* api_data, which may have been refreshed below.
            weights_url = api_data["weights"][weights_url_key]
            t1 = perf_counter()
            model_weights_response = get_from_url(weights_url, json_response=False)
            # File name is the last path segment of the URL, without the query string.
            filename = weights_url.split("?")[0].split("/")[-1]
            save_bytes_in_cache(
                content=model_weights_response.content,
                file=filename,
                model_id=self.endpoint,
            )
            if perf_counter() - t1 > 120:
                logger.debug(
                    "Weights download took longer than 120 seconds, refreshing API request"
                )
                api_data = get_roboflow_model_data(
                    api_key=self.api_key,
                    model_id=self.endpoint,
                    endpoint_type=ModelEndpointType.CORE_MODEL,
                    device_id=self.device_id,
                )

    def get_device_id(self) -> str:
        """Returns the device ID associated with this model.

        Returns:
            str: The device ID.
        """
        return self.device_id

    def get_infer_bucket_file_list(self) -> List[str]:
        """Abstract method to get the list of files to be downloaded from the inference bucket.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.

        Returns:
            List[str]: A list of filenames.
        """
        raise NotImplementedError(
            "get_infer_bucket_file_list not implemented for OnnxRoboflowCoreModel"
        )

    def preprocess_image(self, image: Image.Image) -> Image.Image:
        """Abstract method to preprocess an image.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.

        Returns:
            Image.Image: The preprocessed PIL image.
        """
        raise NotImplementedError(self.__class__.__name__ + ".preprocess_image")

    @property
    def weights_file(self) -> str:
        """Abstract property representing the file containing the model weights. For core models, all model artifacts are handled through get_infer_bucket_file_list method."""
        return None

    @property
    def model_artifact_bucket(self):
        """Bucket from which core model artefacts are downloaded."""
        return CORE_MODEL_BUCKET


class OnnxRoboflowInferenceModel(RoboflowInferenceModel):
    """Roboflow Inference Model that operates using an ONNX model file."""

    def __init__(
        self,
        model_id: str,
        onnxruntime_execution_providers: List[
            str
        ] = get_onnxruntime_execution_providers(ONNXRUNTIME_EXECUTION_PROVIDERS),
        *args,
        **kwargs,
    ):
        """Initializes the OnnxRoboflowInferenceModel instance.

        Args:
            model_id (str): The identifier for the specific ONNX model.
            onnxruntime_execution_providers (List[str], optional): Execution providers in
                priority order. A bare "TensorrtExecutionProvider" entry is expanded with
                engine-cache and fp16 options.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Raises:
            ModelArtefactError: If the model cannot be validated; the cache is cleared first.
        """
        super().__init__(model_id, *args, **kwargs)
        if self.load_weights or not self.has_model_metadata:
            # A new list is built so the TensorRT options are actually applied and
            # the (shared) default argument is never mutated.
            self.onnxruntime_execution_providers = [
                self._with_tensorrt_options(ep)
                for ep in onnxruntime_execution_providers
            ]
        self.initialize_model()
        self.image_loader_threadpool = ThreadPoolExecutor(max_workers=None)
        try:
            self.validate_model()
        except ModelArtefactError as e:
            logger.error(f"Unable to validate model artifacts, clearing cache: {e}")
            self.clear_cache()
            raise ModelArtefactError(str(e)) from e

    def _with_tensorrt_options(self, provider: Any) -> Any:
        """Expand a bare TensorRT provider name into a (name, options) tuple; pass others through."""
        if provider != "TensorrtExecutionProvider":
            return provider
        return (
            "TensorrtExecutionProvider",
            {
                "trt_engine_cache_enable": True,
                "trt_engine_cache_path": os.path.join(
                    TENSORRT_CACHE_PATH, self.endpoint
                ),
                "trt_fp16_enable": True,
            },
        )

    def infer(self, image: Any, **kwargs) -> Any:
        """Run inference, splitting the input into batches when it exceeds the maximum batch size.

        Args:
            image (Any): A single input or a sequence of inputs.
            **kwargs: Forwarded to the parent `infer`.

        Returns:
            Any: The parent's result for a single call, or the merged per-batch results.
        """
        input_elements = calculate_input_elements(input_value=image)
        max_batch_size = MAX_BATCH_SIZE if self.batching_enabled else self.batch_size
        if (input_elements == 1) or (max_batch_size == float("inf")):
            return super().infer(image, **kwargs)
        logger.debug(
            f"Inference will be executed in batches, as there is {input_elements} input elements and "
            f"maximum batch size for a model is set to: {max_batch_size}"
        )
        inference_results = []
        for batch_input in create_batches(sequence=image, batch_size=max_batch_size):
            batch_inference_results = super().infer(batch_input, **kwargs)
            inference_results.append(batch_inference_results)
        return self.merge_inference_results(inference_results=inference_results)

    def merge_inference_results(self, inference_results: List[Any]) -> Any:
        """Flatten a list of per-batch result sequences into a single list."""
        return list(itertools.chain(*inference_results))

    def validate_model(self) -> None:
        """Check that the loaded model is usable.

        Does nothing when weights are not loaded. Otherwise verifies that an ONNX
        session exists, runs a test inference and validates the model classes.

        Raises:
            ModelArtefactError: If any of the checks fails.
        """
        if not self.load_weights:
            return
        # Explicit check instead of `assert`, which is stripped under `python -O`;
        # `getattr` also covers the case where the session attribute is absent.
        if getattr(self, "onnx_session", None) is None:
            raise ModelArtefactError(
                "ONNX session not initialized. Check that the model weights are available."
            )
        try:
            self.run_test_inference()
        except Exception as e:
            raise ModelArtefactError(f"Unable to run test inference. Cause: {e}") from e
        try:
            self.validate_model_classes()
        except Exception as e:
            raise ModelArtefactError(
                f"Unable to validate model classes. Cause: {e}"
            ) from e

    def run_test_inference(self) -> None:
        """Run inference on a random 1024x1024 RGB uint8 image."""
        test_image = (np.random.rand(1024, 1024, 3) * 255).astype(np.uint8)
        return self.infer(test_image)

    def get_model_output_shape(self) -> Tuple[int, int, int]:
        """Return the shape of the first model output for a random 1024x1024 test image."""
        test_image = (np.random.rand(1024, 1024, 3) * 255).astype(np.uint8)
        test_image, _ = self.preprocess(test_image)
        output = self.predict(test_image)[0]
        return output.shape

    def validate_model_classes(self) -> None:
        """Hook for subclasses to validate class configuration; no-op here."""
        pass

    def get_infer_bucket_file_list(self) -> list:
        """Returns the list of files to be downloaded from the inference bucket for ONNX model.

        Returns:
            list: A list of filenames specific to ONNX models.
        """
        return ["environment.json", "class_names.txt"]

    def _configure_batching(self) -> None:
        """Set `batching_enabled` from `batch_size` (a string batch dimension means dynamic)."""
        if isinstance(self.batch_size, str):
            self.batching_enabled = True
            logger.debug(
                f"Model {self.endpoint} is loaded with dynamic batching enabled"
            )
        else:
            self.batching_enabled = False
            logger.debug(
                f"Model {self.endpoint} is loaded with dynamic batching disabled"
            )

    def initialize_model(self) -> None:
        """Initializes the ONNX model, setting up the inference session and other necessary properties.

        When weights must be loaded (or no metadata is cached), an ONNX Runtime
        session is created and input metadata is derived from it and written to the
        shared cache. Otherwise the metadata is read back from the shared cache.

        Raises:
            ModelArtefactError: If the ONNX session cannot be created (cache is cleared first).
            OnnxProviderNotAvailable: If a required execution provider is not available.
            ValueError: If neither weights nor cached metadata are available.
        """
        self.get_model_artifacts()
        logger.debug("Creating inference session")
        if self.load_weights or not self.has_model_metadata:
            t1_session = perf_counter()
            # Create an ONNX Runtime Session with a list of execution providers in priority order. ORT attempts to load providers until one is successful. This keeps the code across devices identical.
            providers = self.onnxruntime_execution_providers
            if not self.load_weights:
                # Session is only needed to read metadata; CPU is sufficient.
                providers = ["CPUExecutionProvider"]
            try:
                self.onnx_session = onnxruntime.InferenceSession(
                    self.cache_file(self.weights_file),
                    providers=providers,
                )
            except Exception as e:
                self.clear_cache()
                raise ModelArtefactError(
                    f"Unable to load ONNX session. Cause: {e}"
                ) from e
            logger.debug(f"Session created in {perf_counter() - t1_session} seconds")

            if REQUIRED_ONNX_PROVIDERS:
                available_providers = onnxruntime.get_available_providers()
                for provider in REQUIRED_ONNX_PROVIDERS:
                    if provider not in available_providers:
                        raise OnnxProviderNotAvailable(
                            f"Required ONNX Execution Provider {provider} is not availble. Check that you are using the correct docker image on a supported device."
                        )

            inputs = self.onnx_session.get_inputs()[0]
            input_shape = inputs.shape
            self.batch_size = input_shape[0]
            self.img_size_h = input_shape[2]
            self.img_size_w = input_shape[3]
            # NOTE(review): the right-hand side was missing in the source; restored
            # as the name of the first session input — confirm.
            self.input_name = inputs.name
            # Symbolic (string) spatial dims: fall back to the configured resize
            # size, or 640x640 when no resize is configured.
            if isinstance(self.img_size_h, str) or isinstance(self.img_size_w, str):
                if "resize" in self.preproc:
                    self.img_size_h = int(self.preproc["resize"]["height"])
                    self.img_size_w = int(self.preproc["resize"]["width"])
                else:
                    self.img_size_h = 640
                    self.img_size_w = 640

            self._configure_batching()

            model_metadata = {
                "batch_size": self.batch_size,
                "img_size_h": self.img_size_h,
                "img_size_w": self.img_size_w,
            }
            logger.debug("Writing model metadata to memcache")
            self.write_model_metadata_to_memcache(model_metadata)
            if not self.load_weights:  # had to load weights to get metadata
                del self.onnx_session
        else:
            if not self.has_model_metadata:
                raise ValueError(
                    "This should be unreachable, should get weights if we don't have model metadata"
                )
            logger.debug("Loading model metadata from memcache")
            metadata = self.model_metadata_from_memcache()
            self.batch_size = metadata["batch_size"]
            self.img_size_h = metadata["img_size_h"]
            self.img_size_w = metadata["img_size_w"]
            self._configure_batching()

    def load_image(
        self,
        image: Any,
        disable_preproc_auto_orient: bool = False,
        disable_preproc_contrast: bool = False,
        disable_preproc_grayscale: bool = False,
        disable_preproc_static_crop: bool = False,
    ) -> Tuple[np.ndarray, Tuple[int, int]]:
        """Load and preprocess one image or a list of images into a single batch array.

        Args:
            image (Any): A single image reference or a list of them.
            disable_preproc_auto_orient (bool, optional): Forwarded to `preproc_image`.
            disable_preproc_contrast (bool, optional): Forwarded to `preproc_image`.
            disable_preproc_grayscale (bool, optional): Forwarded to `preproc_image`.
            disable_preproc_static_crop (bool, optional): Forwarded to `preproc_image`.

        Returns:
            Tuple[np.ndarray, Tuple[int, int]]: The images concatenated along axis 0 and
                the sequence of per-image dimensions returned by `preproc_image`.
        """
        if isinstance(image, list):
            preproc_image = partial(
                self.preproc_image,
                disable_preproc_auto_orient=disable_preproc_auto_orient,
                disable_preproc_contrast=disable_preproc_contrast,
                disable_preproc_grayscale=disable_preproc_grayscale,
                disable_preproc_static_crop=disable_preproc_static_crop,
            )
            # NOTE(review): the callee was missing in the source; restored as a map
            # over the image loader thread pool created in __init__ — confirm.
            imgs_with_dims = self.image_loader_threadpool.map(preproc_image, image)
            imgs, img_dims = zip(*imgs_with_dims)
            img_in = np.concatenate(imgs, axis=0)
        else:
            img_in, img_dims = self.preproc_image(
                image,
                disable_preproc_auto_orient=disable_preproc_auto_orient,
                disable_preproc_contrast=disable_preproc_contrast,
                disable_preproc_grayscale=disable_preproc_grayscale,
                disable_preproc_static_crop=disable_preproc_static_crop,
            )
            img_dims = [img_dims]
        return img_in, img_dims

    @property
    def weights_file(self) -> str:
        """Returns the file containing the ONNX model weights.

        Returns:
            str: The file path to the weights file.
        """
        return "weights.onnx"


class OnnxRoboflowCoreModel(RoboflowCoreModel):
    """Roboflow Inference Model that operates using an ONNX model file."""

    pass


def get_class_names_from_environment_file(environment: Optional[dict]) -> List[str]:
    """Build the ordered class name list from the environment's `CLASS_MAP`.

    `CLASS_MAP` is read with string keys "0", "1", ... up to its length.

    Args:
        environment (Optional[dict]): The loaded environment.

    Returns:
        List[str]: Class names ordered by their index.

    Raises:
        ModelArtefactError: If the environment is missing, or `CLASS_MAP` is absent or not a dict.
    """
    if environment is None:
        raise ModelArtefactError(
            "Missing environment while attempting to get model class names."
        )
    if class_mapping_not_available_in_environment(environment=environment):
        raise ModelArtefactError(
            "Missing `CLASS_MAP` in environment or `CLASS_MAP` is not dict."
        )
    class_map = environment["CLASS_MAP"]
    return [class_map[str(i)] for i in range(len(class_map))]


def class_mapping_not_available_in_environment(environment: dict) -> bool:
    """Return True when `CLASS_MAP` is absent from the environment or is not a dict."""
    return "CLASS_MAP" not in environment or not isinstance(
        environment["CLASS_MAP"], dict
    )


def get_color_mapping_from_environment(
    environment: Optional[dict], class_names: List[str]
) -> Dict[str, str]:
    """Return the class-to-color mapping for a model.

    Uses the environment's `COLORS` dict when available; otherwise assigns colors
    from DEFAULT_COLOR_PALETTE, cycling through it by class index.
    """
    if color_mapping_available_in_environment(environment=environment):
        return environment["COLORS"]
    return {
        class_name: DEFAULT_COLOR_PALETTE[i % len(DEFAULT_COLOR_PALETTE)]
        for i, class_name in enumerate(class_names)
    }


def color_mapping_available_in_environment(environment: Optional[dict]) -> bool:
    """Return True when the environment exists and holds a dict under `COLORS`."""
    return (
        environment is not None
        and "COLORS" in environment
        and isinstance(environment["COLORS"], dict)
    )


def is_model_artefacts_bucket_available() -> bool:
    """Return True when both AWS credentials are set, LAMBDA is truthy and an S3 client exists."""
    return bool(
        AWS_ACCESS_KEY_ID is not None
        and AWS_SECRET_ACCESS_KEY is not None
        and LAMBDA
        and S3_CLIENT is not None
    )


def parse_keypoints_metadata(metadata: list) -> dict:
    """Index keypoints metadata by `object_class_id`, converting keypoint keys to ints.

    Args:
        metadata (list): Entries with `object_class_id` and a `keypoints` mapping.

    Returns:
        dict: `{object_class_id: {int(keypoint_key): value}}`.
    """
    return {
        e["object_class_id"]: {int(key): value for key, value in e["keypoints"].items()}
        for e in metadata
    }