# ClariDoc: app/utils/model_loader.py
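"""Model loader utilities for ClariDoc.

Wraps config loading (ConfigLoader) and instantiates a chat model for the
configured provider (Groq, Gemini, OpenAI, OpenRouter) or a HuggingFace
embedding model, with API keys read from the environment / .env file.
"""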
# from config_loader import
import os
from typing import Any, Literal, Optional

from dotenv import load_dotenv
from pydantic import BaseModel, ConfigDict, Field

from app.utils.config_loader import load_config
from langchain.chat_models import init_chat_model
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
class ConfigLoader:
    """Thin wrapper around the project config with dict-style access."""

    def __init__(self):
        print("Loading config....")
        self.config = load_config()

    def __getitem__(self, key):
        # Allows dictionary-style access, e.g. loader["llm"]["groq"]["model_name"].
        return self.config[key]
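
# The config lookups in ModelLoader.load_llm imply a config file roughly shaped
# like the sketch below (illustrative only; the actual config.yaml may differ
# or contain additional fields):
#
#   llm:
#     groq:        {model_name: ...}
#     gemini:      {model_name: ...}
#     gemini_lite: {model_name: ...}
#     openai:      {model_name: ...}
#     openrouter:  {model_name: ...}
#   embedding_model:
#     huggingface: {model_name: ...}
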
class ModelLoader(BaseModel):
    """Instantiates the model for the configured provider."""

    model_provider: Literal["groq", "gemini", "openai", "gemini_lite", "huggingface", "openrouter"] = "openrouter"
    # Either a ConfigLoader instance or None; excluded from serialized output.
    config: Optional[ConfigLoader] = Field(default=None, exclude=True)

    # Allows ConfigLoader (a plain, non-Pydantic class) to be used as a field type.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    def model_post_init(self, __context: Any) -> None:
        # Pydantic v2 hook that runs after model creation: load the config
        # automatically so every ModelLoader instance starts configured.
        self.config = ConfigLoader()
    def load_llm(self):
        """
        Load and return the model for the configured provider.
        """
        print("LLM loading...")
        print(f"Loading model from provider: {self.model_provider}")
        # Load API keys from a local .env file once, for all provider branches.
        load_dotenv()

        if self.model_provider == "groq":
            print("Loading model from GROQ:")
            groq_api_key = os.getenv("GROQ_API_KEY")
            model_name = self.config["llm"]["groq"]["model_name"]
            llm = ChatGroq(model=model_name, api_key=groq_api_key)
elif self.model_provider =="gemini":
print("Loading model from gemini:")
load_dotenv()
gemini_api_key = os.getenv("GEMINI_API_KEY")
model_name = self.config["llm"]["gemini"]["model_name"]
llm = ChatGoogleGenerativeAI(
model=model_name,
google_api_key= gemini_api_key
)
elif self.model_provider =="gemini_lite":
print("Loading model from gemini-flash-lite:")
load_dotenv()
gemini_api_key = os.getenv("GEMINI_API_KEY")
model_name = self.config["llm"]["gemini_lite"]["model_name"]
llm = ChatGoogleGenerativeAI(
model=model_name,
google_api_key= gemini_api_key
)
elif self.model_provider =="openai":
load_dotenv()
print("Loading model from openai:")
api_key = os.getenv("OPENAI_API_KEY")
model_name = self.config["embedding_model"]["openai"]["model_name"]
llm = OpenAIEmbeddings(model=model_name, api_key = api_key)
elif self.model_provider == "openrouter":
load_dotenv()
api_key = os.getenv("OPENROUTER_API_KEY")
model_name = self.config["llm"]["openrouter"]["model_name"]
llm = init_chat_model(
model=model_name,
model_provider="openai",
base_url="https://openrouter.ai/api/v1",
api_key=api_key
)
elif self.model_provider =="huggingface":
load_dotenv()
print("Loading model from huggingface:")
print("HF_TOKEN in env:", os.getenv("HF_TOKEN"))
api_key = os.getenv("HF_TOKEN")
print(f"HF api key {api_key}")
os.environ["HF_TOKEN"] = api_key # Ensure the token is set in the environment
model_name = self.config["embedding_model"]["huggingface"]["model_name"]
llm = HuggingFaceEmbeddings(model=model_name)
else:
raise ValueError(f"Unsupported model provider: {self.model_provider}")
return llm
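
# Minimal usage sketch (illustrative only): the provider name must match the
# Literal above, and the corresponding API key must be available in .env.
if __name__ == "__main__":
    loader = ModelLoader(model_provider="groq")
    llm = loader.load_llm()
    response = llm.invoke("Say hello in one short sentence.")
    print(response.content)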