File size: 1,675 Bytes
ebd06cc
 
 
 
 
37419af
ebd06cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37419af
 
ebd06cc
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import logging
from baseInfra.dbInterface import DbInterface
from llm.hostedLLM import HostedLLM
from llm.togetherLLM import TogetherLLM
from llm.palmLLM import PalmLLM
from llm.geminiLLM import GeminiLLM


class LLMFactory:
    """
    Factory class for creating LLM objects from stored configs.

    A config is fetched by path via ``get_config`` on the DB interface and
    must contain an ``llm_type`` key (one of "hostedLLM", "togetherLLM",
    "palmLLM", "geminiLLM") and an ``llm_config`` key whose contents are
    passed as keyword arguments to the matching LLM class constructor.
    """

    def __init__(self, db_interface=None):
        """
        Constructor for the LLMFactory class.

        Args:
            db_interface: Optional object exposing ``get_config(path)`` used
                for getting LLM configs. When omitted, a new ``DbInterface()``
                is created.
        """
        # Only build the default DbInterface when none is injected, so callers
        # can supply their own config source.
        self._db_interface = DbInterface() if db_interface is None else db_interface

    def get_llm(self, llm_path: str) -> object:
        """
        Gets an LLM object of the type named in the config at ``llm_path``.

        Args:
            llm_path: The path to the LLM config.

        Returns:
            A HostedLLM, TogetherLLM, PalmLLM or GeminiLLM instance constructed
            with ``**config["llm_config"]``.

        Raises:
            ValueError: If the config cannot be loaded or lacks the
                ``llm_type`` / ``llm_config`` keys (the underlying error is
                chained as ``__cause__``), or if ``llm_type`` is not one of the
                supported types.
        """
        logger = logging.getLogger(__name__)
        try:
            config = self._db_interface.get_config(llm_path)
            logger.debug("LLM config for %s: %s", llm_path, config)
            llm_type = config["llm_type"]
            llm_config = config["llm_config"]
        except Exception as ex:
            # Any failure reading the config (lookup error, missing key,
            # non-subscriptable result) is reported once and surfaced as
            # ValueError -- the same exception type this method raised before --
            # but with the real cause attached instead of a misleading
            # "Invalid LLM type" message.
            logger.exception("Exception in getLLM while loading config %s", llm_path)
            raise ValueError(f"Could not load LLM config from {llm_path!r}") from ex

        if llm_type == "hostedLLM":
            return HostedLLM(**llm_config)
        if llm_type == "togetherLLM":
            return TogetherLLM(**llm_config)
        if llm_type == "palmLLM":
            return PalmLLM(**llm_config)
        if llm_type == "geminiLLM":
            return GeminiLLM(**llm_config)

        logger.error("Invalid LLM type: %s", llm_type)
        raise ValueError(f"Invalid LLM type: {llm_type}")