File size: 1,395 Bytes
2eafbc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from typing import Type

from inference.core.exceptions import ModelNotRecognisedError
from inference.core.models.base import Model


class ModelRegistry:
    """An object which is able to return model classes based on given model types.

    The lookup is keyed on the model type alone; the model ID passed to
    ``get_model`` is accepted but not used by this implementation.

    Attributes:
        registry_dict (dict): A dictionary mapping model types to model classes.
    """

    def __init__(self, registry_dict) -> None:
        """Initializes the ModelRegistry with the given dictionary of registered models.

        Args:
            registry_dict (dict): A dictionary mapping model types to model classes.
                It is stored by reference (not copied), so later changes to the
                passed-in dictionary are visible to this registry.
        """
        self.registry_dict = registry_dict

    def get_model(self, model_type: str, model_id: str) -> Type[Model]:
        """Returns the model class registered for the given model type.

        Args:
            model_type (str): The type of the model to be retrieved; used as the
                lookup key into ``registry_dict``.
            model_id (str): The ID of the model to be retrieved. Not used by this
                implementation; it is only part of the method signature.

        Returns:
            Type[Model]: The model class (not an instance) registered under
            ``model_type``.

        Raises:
            ModelNotRecognisedError: If ``model_type`` is not a key of ``registry_dict``.
        """
        # Explicit membership test (rather than catching KeyError from indexing)
        # so mappings with a default factory, e.g. collections.defaultdict, do not
        # silently create a new entry for an unknown model type.
        if model_type not in self.registry_dict:
            raise ModelNotRecognisedError(
                f"Could not find model of type: {model_type} in configured registry."
            )
        return self.registry_dict[model_type]