Spaces:
Running
on
Zero
Running
on
Zero
import os | |
from typing import Optional, Tuple, Union | |
from inference.core.cache import cache | |
from inference.core.devices.utils import GLOBAL_DEVICE_ID | |
from inference.core.entities.types import DatasetID, ModelType, TaskType, VersionID | |
from inference.core.env import LAMBDA, MODEL_CACHE_DIR | |
from inference.core.exceptions import ( | |
MissingApiKeyError, | |
ModelArtefactError, | |
ModelNotRecognisedError, | |
) | |
from inference.core.logger import logger | |
from inference.core.models.base import Model | |
from inference.core.registries.base import ModelRegistry | |
from inference.core.roboflow_api import ( | |
MODEL_TYPE_DEFAULTS, | |
MODEL_TYPE_KEY, | |
PROJECT_TASK_TYPE_KEY, | |
ModelEndpointType, | |
get_roboflow_dataset_type, | |
get_roboflow_model_data, | |
get_roboflow_workspace, | |
) | |
from inference.core.utils.file_system import dump_json, read_json | |
from inference.core.utils.roboflow import get_model_id_chunks | |
from inference.models.aliases import resolve_roboflow_model_alias | |
# Dataset ids that are answered locally by `get_model_type` (no cache or Roboflow
# API lookup), mapped to the (task type, model type) pair returned for them.
GENERIC_MODELS = dict(
    clip=("embed", "clip"),
    sam=("embed", "sam"),
    gaze=("gaze", "l2cs"),
    doctr=("ocr", "doctr"),
    grounding_dino=("object-detection", "grounding-dino"),
    cogvlm=("llm", "cogvlm"),
    yolo_world=("object-detection", "yolo-world"),
)
# Version id that `get_model_type` treats as a stub model version
# (resolved through the workspace/dataset endpoints, model type "stub").
STUB_VERSION_ID = "0"
# Value passed as `expire=` to `cache.lock(...)` around metadata cache reads and
# writes; presumably seconds — confirm against the cache implementation.
CACHE_METADATA_LOCK_TIMEOUT = 1.0
class RoboflowModelRegistry(ModelRegistry):
    """Registry that maps Roboflow model ids to model classes.

    The (task type, model type) pair computed by `get_model_type` for a model id
    is used as the lookup key into `self.registry_dict` (presumably populated
    by the `ModelRegistry` base class — confirm there).
    """

    def get_model(self, model_id: str, api_key: str) -> Model:
        """Look up the model class registered for a model id.

        Args:
            model_id (str): Roboflow model id (aliases are resolved by `get_model_type`).
            api_key (str): API key forwarded to `get_model_type`.

        Returns:
            Model: The model class registered under the resolved
                (task type, model type) key.

        Raises:
            ModelNotRecognisedError: If no class is registered for the resolved key.
        """
        registry_key = get_model_type(model_id, api_key)
        if registry_key in self.registry_dict:
            return self.registry_dict[registry_key]
        raise ModelNotRecognisedError(f"Model type not supported: {registry_key}")
def get_model_type(
    model_id: str,
    api_key: Optional[str] = None,
) -> Tuple[TaskType, ModelType]:
    """Resolve the (task type, model type) pair for a model id.

    Resolution order:
      1. The id is passed through `resolve_roboflow_model_alias`; if its dataset
         part is a key of `GENERIC_MODELS`, that entry is returned directly.
      2. Metadata previously stored in the local cache is reused if present.
      3. For the stub version (`STUB_VERSION_ID`), the task type is fetched via
         the workspace/dataset endpoints and the model type is "stub".
      4. Otherwise the ORT model endpoint is queried; a missing or "ort"
         modelType falls back to `MODEL_TYPE_DEFAULTS` for the task type.
    Results of steps 3 and 4 are written back to the cache.

    Args:
        model_id (str): The ID (or alias) of the model.
        api_key (Optional[str]): Roboflow API key; required for stub versions.

    Returns:
        tuple: The project task type and the model type.

    Raises:
        MissingApiKeyError: A stub version was requested without an API key.
        ModelArtefactError: The API response has no "ort" entry, or no task type /
            model type could be determined from it.
        Errors raised by the Roboflow API helpers also propagate (previously
        documented as WorkspaceLoadError, DatasetLoadError,
        MissingDefaultModelError and MalformedRoboflowAPIResponseError).
    """
    model_id = resolve_roboflow_model_alias(model_id=model_id)
    dataset_id, version_id = get_model_id_chunks(model_id=model_id)
    if dataset_id in GENERIC_MODELS:
        logger.debug(f"Loading generic model: {dataset_id}.")
        return GENERIC_MODELS[dataset_id]
    cached = get_model_metadata_from_cache(dataset_id=dataset_id, version_id=version_id)
    if cached is not None:
        return cached[0], cached[1]
    if version_id == STUB_VERSION_ID:
        task_type, architecture = _resolve_stub_model_type(
            dataset_id=dataset_id, api_key=api_key
        )
    else:
        task_type, architecture = _resolve_model_type_from_api(
            model_id=model_id, api_key=api_key
        )
    save_model_metadata_in_cache(
        dataset_id=dataset_id,
        version_id=version_id,
        project_task_type=task_type,
        model_type=architecture,
    )
    return task_type, architecture


def _resolve_stub_model_type(
    dataset_id: str, api_key: Optional[str]
) -> Tuple[TaskType, ModelType]:
    """Return (dataset task type, "stub") for a stub model version; needs an API key."""
    if api_key is None:
        raise MissingApiKeyError(
            "Stub model version provided but no API key was provided. API key is required to load stub models."
        )
    workspace_id = get_roboflow_workspace(api_key=api_key)
    task_type = get_roboflow_dataset_type(
        api_key=api_key, workspace_id=workspace_id, dataset_id=dataset_id
    )
    return task_type, "stub"


def _resolve_model_type_from_api(
    model_id: str, api_key: Optional[str]
) -> Tuple[TaskType, ModelType]:
    """Return (task type, model type) read from the "ort" entry of the model API response."""
    response = get_roboflow_model_data(
        api_key=api_key,
        model_id=model_id,
        endpoint_type=ModelEndpointType.ORT,
        device_id=GLOBAL_DEVICE_ID,
    )
    ort_payload = response.get("ort")
    if ort_payload is None:
        raise ModelArtefactError("Error loading model artifacts from Roboflow API.")
    # Some older projects do not have a "type" field, hence the default.
    task_type = ort_payload.get("type", "object-detection")
    architecture = ort_payload.get("modelType")
    if architecture in (None, "ort"):
        # Very old model versions report no modelType (or the generic "ort"),
        # so fall back to the default model for the task type.
        architecture = MODEL_TYPE_DEFAULTS.get(task_type)
    if task_type is None or architecture is None:
        raise ModelArtefactError("Error loading model artifacts from Roboflow API.")
    return task_type, architecture
def get_model_metadata_from_cache(
    dataset_id: str, version_id: str
) -> Optional[Tuple[TaskType, ModelType]]:
    """Read cached (task type, model type) metadata for a model version.

    Outside of LAMBDA the read is performed while holding the
    `lock:metadata:<dataset_id>:<version_id>` lock from `cache` (expiry
    `CACHE_METADATA_LOCK_TIMEOUT`); on LAMBDA it is done without a lock.

    Returns:
        The (task type, model type) pair, or None when nothing usable is cached.
    """
    if not LAMBDA:
        lock_key = f"lock:metadata:{dataset_id}:{version_id}"
        with cache.lock(lock_key, expire=CACHE_METADATA_LOCK_TIMEOUT):
            return _get_model_metadata_from_cache(
                dataset_id=dataset_id, version_id=version_id
            )
    return _get_model_metadata_from_cache(dataset_id=dataset_id, version_id=version_id)
def _get_model_metadata_from_cache(
    dataset_id: str, version_id: str
) -> Optional[Tuple[TaskType, ModelType]]:
    """Read cached (task type, model type) metadata without taking any lock.

    The cache is best-effort: a missing file, a read failure, a decoding
    failure or content rejected by `model_metadata_content_is_invalid` all
    yield None (a cache miss) so the caller can fall back to other sources.

    Args:
        dataset_id (str): Dataset part of the model id.
        version_id (str): Version part of the model id.

    Returns:
        The (task type, model type) pair, or None on a cache miss.
    """
    model_type_cache_path = construct_model_type_cache_path(
        dataset_id=dataset_id, version_id=version_id
    )
    if not os.path.isfile(model_type_cache_path):
        return None
    try:
        model_metadata = read_json(path=model_type_cache_path)
    except ValueError as e:
        # Covers decoding failures (json.JSONDecodeError and UnicodeDecodeError
        # are both ValueError subclasses, assuming read_json uses json/text decoding).
        logger.warning(
            f"Could not load model description from cache under path: {model_type_cache_path} - decoding issue: {e}."
        )
        return None
    except OSError as e:
        # The file can disappear or become unreadable between the isfile() check
        # above and this read (e.g. a concurrent cache cleanup); previously this
        # escaped to the caller instead of being treated as a cache miss.
        logger.warning(
            f"Could not load model description from cache under path: {model_type_cache_path} - I/O issue: {e}."
        )
        return None
    if model_metadata_content_is_invalid(content=model_metadata):
        return None
    return model_metadata[PROJECT_TASK_TYPE_KEY], model_metadata[MODEL_TYPE_KEY]
def model_metadata_content_is_invalid(content: Optional[Union[list, dict]]) -> bool:
    """Tell whether decoded cache metadata is unusable.

    Content is rejected when it is None, is not a dict, or lacks either the
    PROJECT_TASK_TYPE_KEY or the MODEL_TYPE_KEY entry. A warning is logged
    for each rejection reason.

    Args:
        content: Decoded contents of the metadata cache file.

    Returns:
        bool: True if the content must not be used, False otherwise.
    """
    if content is None:
        logger.warning("Empty model metadata file encountered in cache.")
        return True
    if not isinstance(content, dict):
        logger.warning("Malformed file encountered in cache.")
        return True
    if PROJECT_TASK_TYPE_KEY not in content or MODEL_TYPE_KEY not in content:
        logger.warning(
            f"Could not find one of required keys {PROJECT_TASK_TYPE_KEY} or {MODEL_TYPE_KEY} in cache."
        )
        return True
    return False
def save_model_metadata_in_cache(
    dataset_id: DatasetID,
    version_id: VersionID,
    project_task_type: TaskType,
    model_type: ModelType,
) -> None:
    """Store the (task type, model type) pair for a model version in the cache.

    Outside of LAMBDA the write is performed while holding the
    `lock:metadata:<dataset_id>:<version_id>` lock from `cache` (expiry
    `CACHE_METADATA_LOCK_TIMEOUT`); on LAMBDA it is done without a lock.
    """
    metadata_fields = dict(
        dataset_id=dataset_id,
        version_id=version_id,
        project_task_type=project_task_type,
        model_type=model_type,
    )
    if not LAMBDA:
        with cache.lock(
            f"lock:metadata:{dataset_id}:{version_id}",
            expire=CACHE_METADATA_LOCK_TIMEOUT,
        ):
            _save_model_metadata_in_cache(**metadata_fields)
        return None
    _save_model_metadata_in_cache(**metadata_fields)
    return None
def _save_model_metadata_in_cache(
    dataset_id: DatasetID,
    version_id: VersionID,
    project_task_type: TaskType,
    model_type: ModelType,
) -> None:
    """Write the metadata JSON file for a model version without taking any lock.

    The file is written to `construct_model_type_cache_path(dataset_id, version_id)`
    as an object with PROJECT_TASK_TYPE_KEY and MODEL_TYPE_KEY entries, indented
    by 4; `allow_override=True` presumably lets an existing file be replaced.
    """
    dump_json(
        path=construct_model_type_cache_path(
            dataset_id=dataset_id, version_id=version_id
        ),
        content={
            PROJECT_TASK_TYPE_KEY: project_task_type,
            MODEL_TYPE_KEY: model_type,
        },
        allow_override=True,
        indent=4,
    )
def construct_model_type_cache_path(dataset_id: str, version_id: str) -> str:
    """Return the path of the cached metadata file for a model version.

    Layout: `<MODEL_CACHE_DIR>/<dataset_id>/<version_id>/model_type.json`.
    """
    return os.path.join(MODEL_CACHE_DIR, dataset_id, version_id, "model_type.json")