Spaces:
Runtime error
Runtime error
| import os | |
| import sys | |
| from langchain_anthropic import ChatAnthropic | |
| from langchain_fireworks import ChatFireworks | |
| from langchain_google_vertexai import ChatVertexAI | |
| from langchain_openai import ChatOpenAI | |
| sys.path.append(os.getcwd()) | |
| import KEYS | |
| from research_assistant.app_logging import app_logger | |
def set_api_key(env_var: str, api_key: str):
    """Publish ``api_key`` in this process's environment under ``env_var``.

    The value is written to ``os.environ``, so code in the current process,
    and any child process it starts afterwards, can read it back by name.

    Args:
        env_var: Name of the environment variable to set.
        api_key: Value to store; ``os.environ`` only accepts ``str``.
    """
    os.environ.update({env_var: api_key})
class Agent:
    """Build a LangChain chat model chosen by a substring of its model name.

    The provider is the first entry of ``_providers()`` whose family substring
    ("gpt", "claude", "gemini", "fireworks") occurs in ``model_name`` and whose
    API key is defined (and not ``None``) in the project's ``KEYS`` module.

    Attributes:
        model: The instantiated chat model (created with ``temperature=0.5``).
        max_tokens: Token budget from ``_MAX_TOKENS``; ``_DEFAULT_MAX_TOKENS``
            for models not listed there.
    """

    # Token budget per bare model name. A model is routed to ChatFireworks
    # only if its id contains "fireworks" (a path such as
    # "accounts/fireworks/models/<name>"), so lookups use the last path
    # segment; an exact-name lookup could never reach the Fireworks entries.
    _MAX_TOKENS = {
        "gpt-3.5": 16000,
        "gpt-4": 8000,
        "gpt-4o-mini": 8000,
        "llama-v3p2-1b-instruct": 128000,
        "llama-v3p2-3b-instruct": 128000,
        "llama-v3p1-8b-instruct": 128000,
        "llama-v3p1-70b-instruct": 128000,
        "llama-v3p1-405b-instruct": 128000,
        "mixtral-8x22b-instruct": 64000,
        "mixtral-8x7b-instruct": 32000,
        "mixtral-8x7b-instruct-hf": 32000,
        "qwen2p5-72b-instruct": 32000,
        "gemma2-9b-it": 8000,
        "llama-v3-8b-instruct": 8000,
        "llama-v3-70b-instruct": 8000,
        "llama-v3-70b-instruct-hf": 8000,
    }
    _DEFAULT_MAX_TOKENS = 128000

    @staticmethod
    def _providers():
        """Return (family substring, chat class, env var, KEYS attr) rows in match order."""
        # Built at call time so the class can be defined without the
        # provider packages being importable.
        return (
            ("gpt", ChatOpenAI, "OPENAI_API_KEY", "OPENAI"),
            ("claude", ChatAnthropic, "ANTHROPIC_API_KEY", "ANTHROPIC"),
            # NOTE(review): ChatVertexAI usually authenticates with Google
            # Cloud credentials; confirm GOOGLE_API_KEY is what it reads.
            ("gemini", ChatVertexAI, "GOOGLE_API_KEY", "VERTEX_AI"),
            ("fireworks", ChatFireworks, "FIREWORKS_API_KEY", "FIREWORKS_AI"),
        )

    @classmethod
    def _resolve_max_tokens(cls, model_name: str) -> int:
        """Return the token budget for ``model_name``, ignoring any path prefix."""
        short_name = model_name.rsplit("/", 1)[-1]
        return cls._MAX_TOKENS.get(short_name, cls._DEFAULT_MAX_TOKENS)

    def __init__(self, model_name: str):
        """Instantiate the chat model for ``model_name``.

        Args:
            model_name: Provider model id; it must contain one of the family
                substrings listed in ``_providers()``.

        Raises:
            ValueError: If no configured provider matches ``model_name``. The
                message names the missing ``KEYS`` attribute when a provider
                matched but had no API key.
        """
        missing_keys = []
        for family, chat_class, env_var, key_attr in self._providers():
            if family not in model_name:
                continue
            api_key = getattr(KEYS, key_attr, None)
            if api_key is None:
                # Skip unconfigured providers (as before), but remember why
                # so the error below is actionable.
                missing_keys.append(key_attr)
                continue
            set_api_key(env_var, api_key)
            model = chat_class(model=model_name, temperature=0.5)
            max_tokens = self._resolve_max_tokens(model_name)
            app_logger.info(f"Model {model_name} is initialized successfully")
            self.model = model
            self.max_tokens = max_tokens
            return
        if missing_keys:
            raise ValueError(
                f"Model {model_name} not supported: API key "
                f"KEYS.{missing_keys[0]} is not defined"
            )
        raise ValueError(f"Model {model_name} not supported")

    def get_model(self):
        """Return the chat model built in ``__init__``."""
        return self.model