import logging

from baseInfra.dbInterface import DbInterface
from llm.geminiLLM import GeminiLLM
from llm.hostedLLM import HostedLLM
from llm.palmLLM import PalmLLM
from llm.togetherLLM import TogetherLLM

logger = logging.getLogger(__name__)


class LLMFactory:
    """
    Factory class for creating LLM objects from stored configs.

    A config is looked up by path and must contain an ``"llm_type"`` key
    (selecting the LLM class) and an ``"llm_config"`` key (keyword arguments
    passed to that class's constructor).
    """

    # Maps the "llm_type" value of a stored config to the class to instantiate.
    _LLM_CLASSES = {
        "hostedLLM": HostedLLM,
        "togetherLLM": TogetherLLM,
        "palmLLM": PalmLLM,
        "geminiLLM": GeminiLLM,
    }

    def __init__(self, db_interface=None):
        """
        Constructor for the LLMFactory class.

        Args:
            db_interface: Optional object used to look up LLM configs through
                its ``get_config(path)`` method. When omitted, a new
                ``DbInterface()`` is created.
        """
        self._db_interface = (
            db_interface if db_interface is not None else DbInterface()
        )

    def get_llm(self, llm_path: str) -> object:
        """
        Gets an LLM object of the type named in the config at ``llm_path``.

        Args:
            llm_path: The path to the LLM config.

        Returns:
            An instance of the LLM class selected by the config's
            ``"llm_type"``, constructed with ``**config["llm_config"]``.

        Raises:
            ValueError: If the config cannot be loaded or lacks the required
                keys (the original error is chained as ``__cause__``), or if
                ``"llm_type"`` is not one of the supported types.
        """
        try:
            config = self._db_interface.get_config(llm_path)
            logger.debug("LLM path: %s", llm_path)
            logger.debug("LLM config: %s", config)
            llm_type = config["llm_type"]
            llm_config = config["llm_config"]
        except Exception as ex:
            # Any failure to fetch or read the config is reported as a
            # ValueError (as before), but now keeps the real cause attached.
            logger.exception("Exception in getLLM")
            raise ValueError(
                f"Could not load LLM config for path: {llm_path}"
            ) from ex

        # The isinstance guard keeps unhashable/non-string values on the
        # "invalid type" path instead of raising TypeError from dict lookup.
        llm_class = (
            self._LLM_CLASSES.get(llm_type) if isinstance(llm_type, str) else None
        )
        if llm_class is None:
            logger.error(f"Invalid LLM type: {llm_type}")
            raise ValueError(f"Invalid LLM type: {llm_type}")
        # Assumes llm_config is a mapping of constructor keyword arguments.
        return llm_class(**llm_config)