# Spaces: Runtime error
import logging | |
from baseInfra.dbInterface import DbInterface | |
from llm.hostedLLM import HostedLLM | |
from llm.togetherLLM import TogetherLLM | |
from llm.palmLLM import PalmLLM | |
from llm.geminiLLM import GeminiLLM | |
class LLMFactory:
    """
    Factory class for creating LLM objects.

    A config record is fetched by path through a DB interface, and the LLM
    wrapper named by its ``llm_type`` field is instantiated with the
    ``llm_config`` mapping passed as keyword arguments.
    """

    def __init__(self, db_interface=None):
        """
        Constructor for the LLMFactory class.

        Args:
            db_interface: Object used to fetch LLM configs; it must provide a
                ``get_config(path)`` method. When omitted (the original
                zero-argument usage), a new ``DbInterface()`` is created.
        """
        self._db_interface = (
            db_interface if db_interface is not None else DbInterface()
        )

    def get_llm(self, llm_path: str) -> object:
        """
        Gets an LLM object of the type named by the config at ``llm_path``.

        Args:
            llm_path: The path to the LLM config.

        Returns:
            A HostedLLM, TogetherLLM, PalmLLM or GeminiLLM instance constructed
            with ``**config["llm_config"]``.

        Raises:
            ValueError: If the config cannot be loaded or lacks the
                ``llm_type`` / ``llm_config`` keys (the underlying error is
                chained as ``__cause__``), or if ``llm_type`` is not one of
                the supported values.
        """
        logger = logging.getLogger(__name__)
        try:
            config = self._db_interface.get_config(llm_path)
            # NOTE(review): config may carry credentials (e.g. API keys inside
            # llm_config) — confirm it is acceptable to emit at DEBUG level.
            logger.debug("LLM config for %s: %s", llm_path, config)
            llm_type = config["llm_type"]
            llm_config = config["llm_config"]
        except Exception as ex:
            # Report the real failure once and keep its cause, instead of
            # falling through to a misleading "Invalid LLM type: " error.
            logger.exception("Exception in getLLM for path %s", llm_path)
            raise ValueError(
                f"Could not load LLM config from path: {llm_path}"
            ) from ex

        if llm_type == "hostedLLM":
            return HostedLLM(**llm_config)
        if llm_type == "togetherLLM":
            return TogetherLLM(**llm_config)
        if llm_type == "palmLLM":
            return PalmLLM(**llm_config)
        if llm_type == "geminiLLM":
            return GeminiLLM(**llm_config)

        logger.error("Invalid LLM type: %s", llm_type)
        raise ValueError(f"Invalid LLM type: {llm_type}")