from inference.core.exceptions import ModelNotRecognisedError
from inference.core.models.base import Model


class ModelRegistry:
    """Resolves a model type string to the model class registered for it.

    Attributes:
        registry_dict (dict): Mapping of model type to model class.
    """

    def __init__(self, registry_dict) -> None:
        """Create a registry backed by the supplied mapping.

        Args:
            registry_dict (dict): Mapping of model type to model class. The
                mapping is stored as given (not copied).
        """
        self.registry_dict = registry_dict

    def get_model(self, model_type: str, model_id: str) -> Model:
        """Look up the entry registered under ``model_type``.

        Args:
            model_type (str): Key to look up in ``registry_dict``.
            model_id (str): ID of the requested model. Accepted for interface
                compatibility; this implementation does not use it.

        Returns:
            Model: The entry stored in ``registry_dict`` under ``model_type``
                (per the registry's contract, a model class).

        Raises:
            ModelNotRecognisedError: If ``model_type`` is not a key of
                ``registry_dict``.
        """
        registry = self.registry_dict
        # Only index once the key is known to be present; every other case is
        # reported as an unrecognised model type.
        if model_type in registry:
            return registry[model_type]
        raise ModelNotRecognisedError(
            f"Could not find model of type: {model_type} in configured registry."
        )