# Source: maya-persistence/src/llm/llmFactory.py
# Revision info from the hosting page: anubhav77, "v0.1.1", 37419af
import logging
from baseInfra.dbInterface import DbInterface
from llm.hostedLLM import HostedLLM
from llm.togetherLLM import TogetherLLM
from llm.palmLLM import PalmLLM
from llm.geminiLLM import GeminiLLM
class LLMFactory:
    """
    Factory class for creating LLM objects.

    Each LLM is described by a config fetched with ``get_config(path)`` from
    a DB interface. The config must contain an ``llm_type`` selector and an
    ``llm_config`` value that is unpacked as keyword arguments into the
    selected LLM class.
    """

    def __init__(self, db_interface=None):
        """
        Constructor for the LLMFactory class.

        Args:
            db_interface: Optional object providing ``get_config(path)``,
                used to look up LLM configs. When omitted, a new
                ``DbInterface()`` is created, matching the previous
                no-argument behaviour.
        """
        if db_interface is None:
            db_interface = DbInterface()
        self._db_interface = db_interface

    def get_llm(self, llm_path: str) -> object:
        """
        Gets an LLM object of the type named in the config at ``llm_path``.

        Supported ``llm_type`` values are ``"hostedLLM"``, ``"togetherLLM"``,
        ``"palmLLM"`` and ``"geminiLLM"``.

        Args:
            llm_path: The path to the LLM config.

        Returns:
            The LLM object, constructed as ``LLMClass(**llm_config)``.

        Raises:
            ValueError: If the config cannot be loaded or lacks the
                ``llm_type`` / ``llm_config`` keys (the original error is
                chained as ``__cause__``), or if ``llm_type`` is not one of
                the supported values.
        """
        logger = logging.getLogger(__name__)
        try:
            config = self._db_interface.get_config(llm_path)
            logger.debug("LLM config for %s: %s", llm_path, config)
            llm_type = config["llm_type"]
            llm_config = config["llm_config"]
        except Exception as ex:
            # Callers have always received a ValueError when loading fails;
            # keep that contract, but report the real cause instead of an
            # empty "Invalid LLM type".
            logger.exception("Exception in getLLM")
            raise ValueError(
                f"Could not load LLM config from path: {llm_path}"
            ) from ex

        # Plain if/elif so each LLM class name is only resolved when selected.
        if llm_type == "hostedLLM":
            return HostedLLM(**llm_config)
        elif llm_type == "togetherLLM":
            return TogetherLLM(**llm_config)
        elif llm_type == "palmLLM":
            return PalmLLM(**llm_config)
        elif llm_type == "geminiLLM":
            return GeminiLLM(**llm_config)
        else:
            logger.error(f"Invalid LLM type: {llm_type}")
            raise ValueError(f"Invalid LLM type: {llm_type}")