Spaces:
Runtime error
Runtime error
import os | |
import sys | |
from langchain_anthropic import ChatAnthropic | |
from langchain_fireworks import ChatFireworks | |
from langchain_google_vertexai import ChatVertexAI | |
from langchain_openai import ChatOpenAI | |
sys.path.append(os.getcwd()) | |
import KEYS | |
from research_assistant.app_logging import app_logger | |
def set_api_key(env_var: str, api_key: str):
    """Store ``api_key`` in the current process environment under ``env_var``.

    Any value already present for ``env_var`` is overwritten. The write goes
    through ``os.environ``, so non-string arguments are rejected exactly as a
    direct ``os.environ[...] = ...`` assignment would reject them.
    """
    os.environ.update({env_var: api_key})
class Agent:
    """Construct a LangChain chat model selected by a substring of its name.

    Provider routing (checked in this order, first eligible match wins):

    * ``"gpt"``       -> ``ChatOpenAI``,    key ``KEYS.OPENAI``
    * ``"claude"``    -> ``ChatAnthropic``, key ``KEYS.ANTHROPIC``
    * ``"gemini"``    -> ``ChatVertexAI``,  key ``KEYS.VERTEX_AI``
    * ``"fireworks"`` -> ``ChatFireworks``, key ``KEYS.FIREWORKS_AI``

    A provider is only eligible when its key attribute is defined in the
    project's ``KEYS`` module. The chosen key is exported to the process
    environment via ``set_api_key`` before the model is built.

    Attributes:
        model: The constructed chat model (``temperature=0.5``).
        max_tokens: Token budget for the model (see ``_MAX_TOKENS``).
    """

    # Budget used for any model not listed in ``_MAX_TOKENS``.
    _DEFAULT_MAX_TOKENS = 128000

    # Per-model token budgets keyed by model id.
    # NOTE(review): whether these are context-window or completion limits is
    # not visible here — confirm against the consumers of ``max_tokens``.
    _MAX_TOKENS = {
        "gpt-3.5": 16000,
        "gpt-4": 8000,
        "gpt-4o-mini": 8000,
        "llama-v3p2-1b-instruct": 128000,
        "llama-v3p2-3b-instruct": 128000,
        "llama-v3p1-8b-instruct": 128000,
        "llama-v3p1-70b-instruct": 128000,
        "llama-v3p1-405b-instruct": 128000,
        "mixtral-8x22b-instruct": 64000,
        "mixtral-8x7b-instruct": 32000,
        "mixtral-8x7b-instruct-hf": 32000,
        "qwen2p5-72b-instruct": 32000,
        "gemma2-9b-it": 8000,
        "llama-v3-8b-instruct": 8000,
        "llama-v3-70b-instruct": 8000,
        "llama-v3-70b-instruct-hf": 8000,
    }

    def __init__(self, model_name: str) -> None:
        """Pick a provider for ``model_name``, export its key, build the model.

        Args:
            model_name: Provider model id; must contain one of the routing
                substrings listed in the class docstring.

        Raises:
            ValueError: If no provider with a configured key matches
                ``model_name``. When the name matched a provider whose key is
                missing from ``KEYS``, the message names that attribute.
        """
        # routing substring -> (chat model class, env var to export, KEYS attribute)
        providers = {
            "gpt": (ChatOpenAI, "OPENAI_API_KEY", "OPENAI"),
            "claude": (ChatAnthropic, "ANTHROPIC_API_KEY", "ANTHROPIC"),
            # NOTE(review): ChatVertexAI usually authenticates with Google Cloud
            # credentials; confirm GOOGLE_API_KEY is actually what it reads.
            "gemini": (ChatVertexAI, "GOOGLE_API_KEY", "VERTEX_AI"),
            "fireworks": (ChatFireworks, "FIREWORKS_API_KEY", "FIREWORKS_AI"),
        }
        configured_keys = vars(KEYS)  # same mapping as KEYS.__dict__
        missing_keys = []

        for key, (model_class, env_var, key_attr) in providers.items():
            if key not in model_name:
                continue
            if key_attr not in configured_keys:
                # Matched by name but no key configured: keep scanning, since a
                # later provider may still match (same order as before).
                missing_keys.append(key_attr)
                continue
            set_api_key(env_var, configured_keys[key_attr])
            model = model_class(model=model_name, temperature=0.5)  # type: ignore
            break
        else:
            if missing_keys:
                missing = ", ".join(f"KEYS.{name}" for name in missing_keys)
                raise ValueError(
                    f"Model {model_name} not supported: {missing} is not configured"
                )
            raise ValueError(f"Model {model_name} not supported")

        max_tokens = self._resolve_max_tokens(model_name)
        app_logger.info(f"Model {model_name} is initialized successfully")
        self.model = model
        self.max_tokens = max_tokens

    @classmethod
    def _resolve_max_tokens(cls, model_name: str) -> int:
        """Return the token budget for ``model_name``.

        The full name is tried first, then its last ``/``-separated segment.
        The second step is what makes the Fireworks ids in ``_MAX_TOKENS``
        reachable. Routing to Fireworks requires ``"fireworks"`` to appear in
        the name, presumably as a path-style id such as
        ``accounts/fireworks/models/<id>``. An exact lookup would therefore
        never hit those bare ids. Unknown models get ``_DEFAULT_MAX_TOKENS``.
        """
        if model_name in cls._MAX_TOKENS:
            return cls._MAX_TOKENS[model_name]
        base_name = model_name.rsplit("/", 1)[-1]
        return cls._MAX_TOKENS.get(base_name, cls._DEFAULT_MAX_TOKENS)

    def get_model(self):
        """Return the chat model built in ``__init__`` (same object as ``self.model``)."""
        return self.model