HMC-demo / agent.py
ofermend's picture
updated
3a7ce4d
import os
from typing import Optional

from dotenv import load_dotenv
from omegaconf import DictConfig, OmegaConf
from pydantic import BaseModel, Field
from vectara_agentic.agent import Agent
from vectara_agentic.agent_config import AgentConfig
from vectara_agentic.tools import ToolsFactory, VectaraToolFactory
from vectara_agentic.types import AgentType, ModelProvider
# Load variables from a .env file into the process environment (python-dotenv).
# override=True: values from .env replace variables that are already set, so the
# os.getenv / os.environ lookups in this module see the .env values.
load_dotenv(override=True)
# Opening prompt text for the assistant.
# NOTE(review): not referenced in this module -- presumably read by the UI layer.
initial_prompt = "How can I help you today?"

# Ticker symbol -> company name for the companies the assistant can answer
# about. The get_company_info tool returns this mapping as-is, so the names
# here are what the LLM sees; use each company's official name.
tickers: dict[str, str] = {
    "GOOG": "Google",
    "NVDA": "Nvidia",
    "META": "Meta",
    "BKNG": "Booking Holdings",
}
def create_assistant_tools(cfg):
    """Build the tools exposed to the HMC agent.

    Args:
        cfg: Config object exposing ``api_key`` and ``corpus_key`` attributes,
            used to create the ``VectaraToolFactory`` (e.g. the result of
            ``get_agent_config()``).

    Returns:
        list: ``[ask_hmc, company_info_tool]`` -- the Vectara RAG tool followed
        by the tool created from ``get_company_info``.
    """

    def get_company_info() -> dict[str, str]:
        # Returns the module-level ``tickers`` mapping (ticker -> company name).
        # The docstring text is kept flush-left so the string is exactly the
        # original one; NOTE(review): ToolsFactory presumably uses it as the
        # tool description shown to the LLM -- confirm before reformatting it.
        """
Returns a dictionary of companies you can query about. Always check this before using any other tool.
The output is a dictionary of valid ticker symbols mapped to company names.
You can use this to identify the companies you can query about, and their ticker information.
"""
        return tickers

    # Argument schema for the ``ask_hmc`` tool (passed as ``tool_args_schema``).
    # Every field defaults to None. No class docstring on purpose: pydantic
    # copies a model docstring into the generated JSON schema's description.
    class QueryHMC(BaseModel):
        ticker: Optional[str] = Field(
            default=None,
            description="The company ticker.",
            examples=['GOOG', 'META'],
        )
        # Optional like the other fields, since the default is None.
        year: Optional[int | str] = Field(
            default=None,
            description="The year of the report, or a string specifying a condition on the year",
            examples=[2020, '>2021', '<2023', '>=2021', '<=2023', '[2021, 2023]', '[2021, 2023)'],
        )
        quarter: Optional[int] = Field(
            default=None,
            description="The quarter of the report.",
            examples=[1, 2, 3, 4],
        )
        filing_type: Optional[str] = Field(
            default=None,
            description="The type of filing.",
            examples=['10K', '10Q'],
        )

    vec_factory = VectaraToolFactory(
        vectara_api_key=cfg.api_key,
        vectara_corpus_key=cfg.corpus_key,
    )
    summarizer = 'vectara-summary-table-md-query-ext-jan-2025-gpt-4o'
    ask_hmc = vec_factory.create_rag_tool(
        tool_name="ask_hmc",
        tool_description=(
            "\n"
            "Given a user query,\n"
            "returns a response to a user question about fund management companies.\n"
        ),
        tool_args_schema=QueryHMC,
        # Chained reranking (rerank_k=100): slingshot with cutoff 0.2, then
        # MMR with diversity_bias 0.05 and limit 50.
        reranker="chain",
        rerank_k=100,
        rerank_chain=[
            {
                "type": "slingshot",
                "cutoff": 0.2,
            },
            {
                "type": "mmr",
                "diversity_bias": 0.05,
                "limit": 50,
            },
        ],
        n_sentences_before=2,
        n_sentences_after=2,
        lambda_val=0.005,
        vectara_summarizer=summarizer,
        summary_num_results=10,
        max_tokens=4096,
        max_response_chars=8192,
        include_citations=True,
        verbose=True,
        save_history=True,
    )
    tools_factory = ToolsFactory()
    return [ask_hmc, tools_factory.create_tool(get_company_info)]
def initialize_agent(_cfg, agent_progress_callback=None):
    """Create, report and return the HMC assistant ``Agent``.

    The primary agent settings come from ``VECTARA_AGENTIC_*`` environment
    variables and the fallback agent settings from ``VECTARA_AGENTIC_FALLBACK_*``.
    Both default to OpenAI for the agent type and the LLM providers, and to an
    empty model name. Both share ``VECTARA_AGENTIC_OBSERVER_TYPE``
    (default ``"NO_OBSERVER"``).

    Args:
        _cfg: Config handed to ``create_assistant_tools`` (needs ``api_key``
            and ``corpus_key``).
        agent_progress_callback: Optional callback forwarded to ``Agent``.

    Returns:
        The constructed ``Agent``, after ``agent.report()`` has been called.
    """
    # Instruction text is assembled piecewise; the resulting string is the
    # same newline-separated list of directives, with leading/trailing newlines.
    bot_instructions = (
        "\n"
        "- You are a helpful assistant, with expertise in management of public company stock portfolios.\n"
        "- Use the 'ask_hmc' tool to answer questions about public company performance, risks, and other financial metrics.\n"
        "If the tool responds with \"I don't have enough information to answer\", try rephrasing the question.\n"
        "- Use the year, quarter, filing_type and ticker arguments to the 'ask_hmc' tool to get more specific answers.\n"
        "- Note that 10Q reports exist for quarters 1, 2, 3 and for the 4th quarter there is a 10K report.\n"
        "- If the 'ask_hmc' tool does not return any results, check the year and ticker and try calling it again with the right values.\n"
    )

    # Defaults shared by the primary and fallback configurations.
    default_agent_type = AgentType.OPENAI.value
    default_provider = ModelProvider.OPENAI.value
    observer_type = os.getenv("VECTARA_AGENTIC_OBSERVER_TYPE", "NO_OBSERVER")

    primary_config = AgentConfig(
        agent_type=os.getenv("VECTARA_AGENTIC_AGENT_TYPE", default_agent_type),
        main_llm_provider=os.getenv("VECTARA_AGENTIC_MAIN_LLM_PROVIDER", default_provider),
        main_llm_model_name=os.getenv("VECTARA_AGENTIC_MAIN_MODEL_NAME", ""),
        tool_llm_provider=os.getenv("VECTARA_AGENTIC_TOOL_LLM_PROVIDER", default_provider),
        tool_llm_model_name=os.getenv("VECTARA_AGENTIC_TOOL_MODEL_NAME", ""),
        observer=observer_type,
    )
    backup_config = AgentConfig(
        agent_type=os.getenv("VECTARA_AGENTIC_FALLBACK_AGENT_TYPE", default_agent_type),
        main_llm_provider=os.getenv("VECTARA_AGENTIC_FALLBACK_MAIN_LLM_PROVIDER", default_provider),
        main_llm_model_name=os.getenv("VECTARA_AGENTIC_FALLBACK_MAIN_MODEL_NAME", ""),
        tool_llm_provider=os.getenv("VECTARA_AGENTIC_FALLBACK_TOOL_LLM_PROVIDER", default_provider),
        tool_llm_model_name=os.getenv("VECTARA_AGENTIC_FALLBACK_TOOL_MODEL_NAME", ""),
        observer=observer_type,
    )

    assistant = Agent(
        tools=create_assistant_tools(_cfg),
        topic="Endowment fund management",
        custom_instructions=bot_instructions,
        agent_progress_callback=agent_progress_callback,
        verbose=True,
        agent_config=primary_config,
        fallback_agent_config=backup_config,
    )
    assistant.report()
    return assistant
def get_agent_config() -> DictConfig:
    """Build the demo configuration from environment variables.

    Reads ``VECTARA_CORPUS_KEY`` and ``VECTARA_API_KEY`` (both required) and
    ``QUERY_EXAMPLES`` (optional; ``None`` when unset), and adds fixed demo
    display strings.

    Returns:
        An OmegaConf ``DictConfig`` with keys ``corpus_key``, ``api_key``,
        ``examples``, ``demo_name``, ``demo_welcome`` and ``demo_description``.

    Raises:
        KeyError: If ``VECTARA_CORPUS_KEY`` or ``VECTARA_API_KEY`` is not set.
    """
    # os.environ values are already str, so no str() conversion is needed.
    return OmegaConf.create({
        'corpus_key': os.environ['VECTARA_CORPUS_KEY'],
        'api_key': os.environ['VECTARA_API_KEY'],
        'examples': os.environ.get('QUERY_EXAMPLES'),
        'demo_name': "Harvard Management Company",
        'demo_welcome': "Harvard Management Company.",
        'demo_description': "AI Assistant.",
    })