Spaces:
Sleeping
Sleeping
| # from config_loader import | |
| import os | |
| from dotenv import load_dotenv | |
| from pydantic import BaseModel, Field | |
| from typing import Literal, Optional,Any | |
| from app.utils.config_loader import load_config | |
| from langchain_groq import ChatGroq | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from dotenv import load_dotenv | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| # from langchain_openai import OpenAIEmbeddings | |
| from langchain_community.embeddings import OpenAIEmbeddings | |
| from langchain.chat_models import init_chat_model | |
class ConfigLoader:
    """Thin wrapper that loads the project configuration once and exposes it by key.

    Constructing an instance prints a progress message, calls ``load_config()``
    and stores its result on ``self.config``. Values are then read with
    subscription, e.g. ``loader["llm"]``.
    """

    def __init__(self):
        banner = "Loading config...."
        print(banner)
        self.config = load_config()

    def __getitem__(self, key):
        """Return ``self.config[key]``.

        Lookup is delegated unchanged to the stored config object, so a missing
        key raises whatever that object raises (``KeyError`` for a plain dict;
        presumably ``load_config`` returns a dict — confirm).
        """
        settings = self.config
        return settings[key]
class ModelLoader(BaseModel):
    """Factory that builds a LangChain model object for one configured provider.

    ``model_provider`` selects the backend. The model name for that backend is
    read from the project configuration (``self.config``), and the API key is
    read from environment variables after ``load_dotenv()`` has been called.
    """

    # Backend that ``load_llm`` builds. NOTE: "openai" and "huggingface" return
    # *embedding* models (see ``load_llm``); the other providers return chat models.
    model_provider: Literal[
        "groq", "gemini", "openai", "gemini_lite", "huggingface", "openrouter"
    ] = "openrouter"
    # Project configuration, excluded from serialization. Either a ConfigLoader
    # or None; when left as None it is populated in ``model_post_init``.
    config: Optional[ConfigLoader] = Field(default=None, exclude=True)

    # Pydantic V2 configuration (replaces the deprecated inner ``class Config``):
    # allow the non-Pydantic ConfigLoader class to be used as a field type.
    model_config = {"arbitrary_types_allowed": True}

    def model_post_init(self, __context: Any) -> None:
        """Pydantic V2 hook that runs after model creation: ensure a config is loaded.

        A ConfigLoader passed explicitly to the constructor is kept. Otherwise a
        new one is created, so every ModelLoader has a usable ``self.config``.
        """
        if self.config is None:
            self.config = ConfigLoader()

    def load_llm(self):
        """Build and return the model object for ``self.model_provider``.

        ``load_dotenv()`` is called first (for every provider) so keys defined in
        a ``.env`` file are visible through ``os.getenv``. Model names come from
        ``self.config``:

        * ``groq``, ``gemini``, ``gemini_lite``, ``openrouter``: read
          ``config["llm"][<provider>]["model_name"]`` and return a chat model.
        * ``openai`` and ``huggingface``: read
          ``config["embedding_model"][<provider>]["model_name"]`` and return an
          *embeddings* object (``OpenAIEmbeddings`` / ``HuggingFaceEmbeddings``).

        Returns:
            The constructed LangChain chat-model or embeddings object.

        Raises:
            ValueError: if ``model_provider`` is not a supported value.
        """
        provider = self.model_provider
        print("LLM loading...")
        print(f"Loading model from provider: {provider}")
        # Load .env once for all providers. Previously the groq branch skipped
        # this and could miss a GROQ_API_KEY defined only in .env.
        load_dotenv()

        if provider == "groq":
            print("Loading model from GROQ:")
            model_name = self.config["llm"]["groq"]["model_name"]
            llm = ChatGroq(model=model_name, api_key=os.getenv("GROQ_API_KEY"))
        elif provider in ("gemini", "gemini_lite"):
            # Both Gemini variants share a key and client; only the config
            # section and the log label differ.
            label = "gemini" if provider == "gemini" else "gemini-flash-lite"
            print(f"Loading model from {label}:")
            model_name = self.config["llm"][provider]["model_name"]
            llm = ChatGoogleGenerativeAI(
                model=model_name,
                google_api_key=os.getenv("GEMINI_API_KEY"),
            )
        elif provider == "openai":
            print("Loading model from openai:")
            # NOTE: returns an embeddings model, configured under "embedding_model".
            model_name = self.config["embedding_model"]["openai"]["model_name"]
            llm = OpenAIEmbeddings(model=model_name, api_key=os.getenv("OPENAI_API_KEY"))
        elif provider == "openrouter":
            print("Loading model from openrouter:")
            model_name = self.config["llm"]["openrouter"]["model_name"]
            # OpenRouter exposes an OpenAI-compatible API, hence provider="openai"
            # with a custom base_url.
            llm = init_chat_model(
                model=model_name,
                model_provider="openai",
                base_url="https://openrouter.ai/api/v1",
                api_key=os.getenv("OPENROUTER_API_KEY"),
            )
        elif provider == "huggingface":
            print("Loading model from huggingface:")
            # Report only whether the token is present, never its value. The token
            # is already in os.environ (set by the shell or load_dotenv), so it is
            # not re-assigned; re-assigning raised TypeError when it was unset.
            print("HF_TOKEN in env:", bool(os.getenv("HF_TOKEN")))
            # NOTE: returns an embeddings model, configured under "embedding_model".
            model_name = self.config["embedding_model"]["huggingface"]["model_name"]
            llm = HuggingFaceEmbeddings(model=model_name)
        else:
            # Reachable if model_provider is reassigned after construction
            # (assignment is not validated by default).
            raise ValueError(f"Unsupported model provider: {self.model_provider}")
        return llm