import os
from typing import Optional, Tuple, Union

from inference.core.cache import cache
from inference.core.devices.utils import GLOBAL_DEVICE_ID
from inference.core.entities.types import DatasetID, ModelType, TaskType, VersionID
from inference.core.env import LAMBDA, MODEL_CACHE_DIR
from inference.core.exceptions import (
MissingApiKeyError,
ModelArtefactError,
ModelNotRecognisedError,
)
from inference.core.logger import logger
from inference.core.models.base import Model
from inference.core.registries.base import ModelRegistry
from inference.core.roboflow_api import (
MODEL_TYPE_DEFAULTS,
MODEL_TYPE_KEY,
PROJECT_TASK_TYPE_KEY,
ModelEndpointType,
get_roboflow_dataset_type,
get_roboflow_model_data,
get_roboflow_workspace,
)
from inference.core.utils.file_system import dump_json, read_json
from inference.core.utils.roboflow import get_model_id_chunks
from inference.models.aliases import resolve_roboflow_model_alias


GENERIC_MODELS = {
"clip": ("embed", "clip"),
"sam": ("embed", "sam"),
"gaze": ("gaze", "l2cs"),
"doctr": ("ocr", "doctr"),
"grounding_dino": ("object-detection", "grounding-dino"),
"cogvlm": ("llm", "cogvlm"),
"yolo_world": ("object-detection", "yolo-world"),
}
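# Generic (non-project) models are matched on the dataset chunk of the model id
# alone, so they resolve without any Roboflow API call. Illustrative lookup
# (the mapping above is the source of truth):
#
#     GENERIC_MODELS["gaze"]  # -> ("gaze", "l2cs")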
STUB_VERSION_ID = "0"
CACHE_METADATA_LOCK_TIMEOUT = 1.0
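# Reads and writes of the cached metadata file are serialised with
# `cache.lock(...)` using this expiry (in seconds); the LAMBDA guards in the
# helpers below skip the lock entirely in that environment.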


class RoboflowModelRegistry(ModelRegistry):
"""A Roboflow-specific model registry which gets the model type using the model id,
then returns a model class based on the model type.
"""

    def get_model(self, model_id: str, api_key: str) -> Model:
"""Returns the model class based on the given model id and API key.
Args:
model_id (str): The ID of the model to be retrieved.
api_key (str): The API key used to authenticate.
Returns:
Model: The model class corresponding to the given model ID and type.
Raises:
ModelNotRecognisedError: If the model type is not supported or found.
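
        Example (illustrative only - the registry contents and SomeDetectorClass
        are hypothetical; real mappings are assembled elsewhere in the
        inference package):
            registry = RoboflowModelRegistry(
                registry_dict={("object-detection", "yolov8n"): SomeDetectorClass}
            )
            model_class = registry.get_model("some-project/1", api_key="<API_KEY>")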
"""
model_type = get_model_type(model_id, api_key)
if model_type not in self.registry_dict:
raise ModelNotRecognisedError(f"Model type not supported: {model_type}")
return self.registry_dict[model_type]


def get_model_type(
model_id: str,
api_key: Optional[str] = None,
) -> Tuple[TaskType, ModelType]:
"""Retrieves the model type based on the given model ID and API key.
Args:
model_id (str): The ID of the model.
        api_key (Optional[str]): The API key used to authenticate; not required
            for generic models or when the metadata is already cached.
Returns:
tuple: The project task type and the model type.
Raises:
WorkspaceLoadError: If the workspace could not be loaded or if the API key is invalid.
DatasetLoadError: If the dataset could not be loaded due to invalid ID, workspace ID or version ID.
        MissingDefaultModelError: If the default model is not configured and the API does not provide this info.
        MalformedRoboflowAPIResponseError: If the Roboflow API responds in an invalid format.
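
    Example (illustrative values - the returned pair depends on the project
    configuration in Roboflow):
        task_type, model_type = get_model_type("some-project/1", api_key="<API_KEY>")
        # e.g. ("object-detection", "yolov8n")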
"""
model_id = resolve_roboflow_model_alias(model_id=model_id)
dataset_id, version_id = get_model_id_chunks(model_id=model_id)
if dataset_id in GENERIC_MODELS:
logger.debug(f"Loading generic model: {dataset_id}.")
return GENERIC_MODELS[dataset_id]
cached_metadata = get_model_metadata_from_cache(
dataset_id=dataset_id, version_id=version_id
)
if cached_metadata is not None:
return cached_metadata[0], cached_metadata[1]
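    # Version "0" denotes a stub model: there are no trained weights behind it,
    # but the project task type still has to be resolved via the Roboflow API.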
if version_id == STUB_VERSION_ID:
if api_key is None:
raise MissingApiKeyError(
"Stub model version provided but no API key was provided. API key is required to load stub models."
)
workspace_id = get_roboflow_workspace(api_key=api_key)
project_task_type = get_roboflow_dataset_type(
api_key=api_key, workspace_id=workspace_id, dataset_id=dataset_id
)
model_type = "stub"
save_model_metadata_in_cache(
dataset_id=dataset_id,
version_id=version_id,
project_task_type=project_task_type,
model_type=model_type,
)
return project_task_type, model_type
api_data = get_roboflow_model_data(
api_key=api_key,
model_id=model_id,
endpoint_type=ModelEndpointType.ORT,
device_id=GLOBAL_DEVICE_ID,
).get("ort")
if api_data is None:
raise ModelArtefactError("Error loading model artifacts from Roboflow API.")
    # some older projects do not have a type field - hence the default
project_task_type = api_data.get("type", "object-detection")
model_type = api_data.get("modelType")
if model_type is None or model_type == "ort":
        # some very old model versions do not have modelType reported - and the API
        # responds in a generic way - in that case we fall back to the default model
        # for the given task type
model_type = MODEL_TYPE_DEFAULTS.get(project_task_type)
if model_type is None or project_task_type is None:
raise ModelArtefactError("Error loading model artifacts from Roboflow API.")
save_model_metadata_in_cache(
dataset_id=dataset_id,
version_id=version_id,
project_task_type=project_task_type,
model_type=model_type,
)
return project_task_type, model_type


def get_model_metadata_from_cache(
dataset_id: str, version_id: str
) -> Optional[Tuple[TaskType, ModelType]]:
if LAMBDA:
return _get_model_metadata_from_cache(
dataset_id=dataset_id, version_id=version_id
)
with cache.lock(
f"lock:metadata:{dataset_id}:{version_id}", expire=CACHE_METADATA_LOCK_TIMEOUT
):
return _get_model_metadata_from_cache(
dataset_id=dataset_id, version_id=version_id
)


def _get_model_metadata_from_cache(
dataset_id: str, version_id: str
) -> Optional[Tuple[TaskType, ModelType]]:
model_type_cache_path = construct_model_type_cache_path(
dataset_id=dataset_id, version_id=version_id
)
if not os.path.isfile(model_type_cache_path):
return None
try:
model_metadata = read_json(path=model_type_cache_path)
if model_metadata_content_is_invalid(content=model_metadata):
return None
return model_metadata[PROJECT_TASK_TYPE_KEY], model_metadata[MODEL_TYPE_KEY]
except ValueError as e:
logger.warning(
f"Could not load model description from cache under path: {model_type_cache_path} - decoding issue: {e}."
)
return None


def model_metadata_content_is_invalid(content: Optional[Union[list, dict]]) -> bool:
if content is None:
logger.warning("Empty model metadata file encountered in cache.")
return True
    if not isinstance(content, dict):
        logger.warning("Malformed model metadata file encountered in cache.")
return True
if PROJECT_TASK_TYPE_KEY not in content or MODEL_TYPE_KEY not in content:
logger.warning(
f"Could not find one of required keys {PROJECT_TASK_TYPE_KEY} or {MODEL_TYPE_KEY} in cache."
)
return True
return False


def save_model_metadata_in_cache(
dataset_id: DatasetID,
version_id: VersionID,
project_task_type: TaskType,
model_type: ModelType,
) -> None:
if LAMBDA:
_save_model_metadata_in_cache(
dataset_id=dataset_id,
version_id=version_id,
project_task_type=project_task_type,
model_type=model_type,
)
return None
with cache.lock(
f"lock:metadata:{dataset_id}:{version_id}", expire=CACHE_METADATA_LOCK_TIMEOUT
):
_save_model_metadata_in_cache(
dataset_id=dataset_id,
version_id=version_id,
project_task_type=project_task_type,
model_type=model_type,
)
return None


def _save_model_metadata_in_cache(
dataset_id: DatasetID,
version_id: VersionID,
project_task_type: TaskType,
model_type: ModelType,
) -> None:
model_type_cache_path = construct_model_type_cache_path(
dataset_id=dataset_id, version_id=version_id
)
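    # The persisted file is a small JSON document, e.g. (actual key names are
    # whatever PROJECT_TASK_TYPE_KEY and MODEL_TYPE_KEY resolve to):
    #     {"project_task_type": "object-detection", "model_type": "yolov8n"}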
metadata = {
PROJECT_TASK_TYPE_KEY: project_task_type,
MODEL_TYPE_KEY: model_type,
}
dump_json(
path=model_type_cache_path, content=metadata, allow_override=True, indent=4
)


def construct_model_type_cache_path(dataset_id: str, version_id: str) -> str:
cache_dir = os.path.join(MODEL_CACHE_DIR, dataset_id, version_id)
return os.path.join(cache_dir, "model_type.json")
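

# For example, with MODEL_CACHE_DIR="/tmp/cache" and model id "some-project/3",
# the metadata above would be cached at
# "/tmp/cache/some-project/3/model_type.json".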