import os
import sys

from langchain_anthropic import ChatAnthropic
from langchain_fireworks import ChatFireworks
from langchain_google_vertexai import ChatVertexAI
from langchain_openai import ChatOpenAI

# Make the local KEYS module (which holds API keys such as KEYS.OPENAI)
# importable from the working directory.
sys.path.append(os.getcwd())
import KEYS
from research_assistant.app_logging import app_logger


def set_api_key(env_var: str, api_key: str):
    """Expose an API key to the underlying client via its environment variable."""
    os.environ[env_var] = api_key


class Agent:
    def __init__(self, model_name: str):
        # Map a provider marker in the model name to the chat-model class, the
        # environment variable its client reads, and the key stored in KEYS.
        # Providers whose key is missing from KEYS fall back to (None, None, None).
        model_classes = {
            "gpt": (
                (ChatOpenAI, "OPENAI_API_KEY", KEYS.OPENAI)  # type: ignore
                if "OPENAI" in KEYS.__dict__
                else (None, None, None)
            ),
            "claude": (
                (ChatAnthropic, "ANTHROPIC_API_KEY", KEYS.ANTHROPIC)  # type: ignore
                if "ANTHROPIC" in KEYS.__dict__
                else (None, None, None)
            ),
            "gemini": (
                (ChatVertexAI, "GOOGLE_API_KEY", KEYS.VERTEX_AI)  # type: ignore
                if "VERTEX_AI" in KEYS.__dict__
                else (None, None, None)
            ),
            "fireworks": (
                (ChatFireworks, "FIREWORKS_API_KEY", KEYS.FIREWORKS_AI)  # type: ignore
                if "FIREWORKS_AI" in KEYS.__dict__
                else (None, None, None)
            ),
        }
        # Token budgets for known models; names not listed here default to
        # 128000 below.
        max_tokens_map = {
            "gpt-3.5": 16000,
            "gpt-4": 8000,
            "gpt-4o-mini": 8000,
            "llama-v3p2-1b-instruct": 128000,
            "llama-v3p2-3b-instruct": 128000,
            "llama-v3p1-8b-instruct": 128000,
            "llama-v3p1-70b-instruct": 128000,
            "llama-v3p1-405b-instruct": 128000,
            "mixtral-8x22b-instruct": 64000,
            "mixtral-8x7b-instruct": 32000,
            "mixtral-8x7b-instruct-hf": 32000,
            "qwen2p5-72b-instruct": 32000,
            "gemma2-9b-it": 8000,
            "llama-v3-8b-instruct": 8000,
            "llama-v3-70b-instruct": 8000,
            "llama-v3-70b-instruct-hf": 8000,
        }
        # Pick the first provider whose marker appears in the model name,
        # export its API key, and instantiate the chat model.
        for key, (model_class, env_var, api_key) in model_classes.items():
            if model_class is not None and key in model_name:
                set_api_key(env_var, api_key)  # type: ignore
                model = model_class(model=model_name, temperature=0.5)  # type: ignore
                max_tokens = max_tokens_map.get(model_name, 128000)
                break
        else:
            raise ValueError(f"Model {model_name} is not supported")
        app_logger.info(f"Model {model_name} initialized successfully")
        self.model = model
        self.max_tokens = max_tokens
    def get_model(self):
        """Return the initialized chat model instance."""
        return self.model
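

# Minimal usage sketch (an illustration, not part of the original module): it
# assumes KEYS defines OPENAI and that the standard LangChain invoke() call is
# used downstream; the model name is only an example from max_tokens_map.
#
#     agent = Agent("gpt-4o-mini")
#     llm = agent.get_model()
#     reply = llm.invoke("Summarize the experiment in one sentence.")
#     print(reply.content, agent.max_tokens)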