Spaces:
Runtime error
Runtime error
File size: 1,675 Bytes
ebd06cc 37419af ebd06cc 37419af ebd06cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import logging
from baseInfra.dbInterface import DbInterface
from llm.hostedLLM import HostedLLM
from llm.togetherLLM import TogetherLLM
from llm.palmLLM import PalmLLM
from llm.geminiLLM import GeminiLLM
class LLMFactory:
    """
    Factory class for creating LLM objects.

    The LLM type and its constructor arguments are read from a config record
    fetched through a DbInterface-like object (anything with ``get_config``).
    """

    def __init__(self, db_interface=None):
        """
        Constructor for the LLMFactory class.

        Args:
            db_interface: Optional object exposing ``get_config(path)`` that is
                used to fetch LLM configs. When omitted, a new
                ``DbInterface()`` is created, as before.
        """
        # Allow injecting a shared/test instance; fall back to the default so
        # existing zero-argument callers keep working unchanged.
        self._db_interface = (
            db_interface if db_interface is not None else DbInterface()
        )

    def get_llm(self, llm_path: str) -> object:
        """
        Gets an LLM object of the type named in the config at ``llm_path``.

        The config must be a mapping with an ``"llm_type"`` key (one of
        ``"hostedLLM"``, ``"togetherLLM"``, ``"palmLLM"``, ``"geminiLLM"``) and
        an ``"llm_config"`` value that is passed as keyword arguments to the
        matching LLM class.

        Args:
            llm_path: The path to the LLM config.

        Returns:
            The LLM object.

        Raises:
            ValueError: If the config cannot be loaded or lacks the required
                keys (the underlying error is chained as ``__cause__``), or if
                ``llm_type`` is not a supported type.
        """
        logger = logging.getLogger(__name__)
        try:
            config = self._db_interface.get_config(llm_path)
            llm_type = config["llm_type"]
            llm_config = config["llm_config"]
        except Exception as ex:
            # Previously this fell through to ValueError("Invalid LLM type: ")
            # and discarded the real cause. Keep the ValueError contract, but
            # log once and chain the actual error so it is diagnosable.
            logger.exception("Exception in getLLM for path %r", llm_path)
            raise ValueError(
                f"Could not load LLM config for {llm_path!r}"
            ) from ex

        # NOTE(review): llm_config is splatted into the LLM constructors and
        # may carry credentials (e.g. API keys), so its values are not logged.
        logger.debug("LLM config %r: llm_type=%r", llm_path, llm_type)

        # Supported llm_type values and the class each one constructs.
        llm_classes = {
            "hostedLLM": HostedLLM,
            "togetherLLM": TogetherLLM,
            "palmLLM": PalmLLM,
            "geminiLLM": GeminiLLM,
        }
        # Guard the dict lookup: an unhashable llm_type from a malformed config
        # would otherwise raise TypeError instead of the documented ValueError.
        llm_class = (
            llm_classes.get(llm_type) if isinstance(llm_type, str) else None
        )
        if llm_class is None:
            logger.error("Invalid LLM type: %s", llm_type)
            raise ValueError(f"Invalid LLM type: {llm_type}")
        return llm_class(**llm_config)
|