ofermend commited on
Commit
49b0a2d
1 Parent(s): f8d2846
Files changed (2) hide show
  1. agent.py +112 -0
  2. app.py +17 -155
agent.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from omegaconf import OmegaConf
2
+ import os
3
+
4
+ from typing import Optional
5
+ from pydantic import Field, BaseModel
6
+
7
+ from dotenv import load_dotenv
8
+ load_dotenv(override=True)
9
+
10
+ from vectara_agent.agent import Agent
11
+ from vectara_agent.tools import ToolsFactory, VectaraToolFactory
12
+
13
+ def create_assistant_tools(cfg):
14
+
15
+ class QueryElectricCars(BaseModel):
16
+ query: str = Field(description="The user query.")
17
+
18
+ vec_factory_1 = VectaraToolFactory(vectara_api_key=cfg.api_keys[0],
19
+ vectara_customer_id=cfg.customer_id,
20
+ vectara_corpus_id=cfg.corpus_ids[0])
21
+
22
+ ask_vehicles = vec_factory_1.create_rag_tool(
23
+ tool_name = "ask_vehicles",
24
+ tool_description = """
25
+ Given a user query,
26
+ returns a response (str) to a user question about electric vehicles based on online resources.
27
+ You can ask this tool any question about electric cars, including the different types of EVs, how they work, the pros and cons of different models, the environmental impact, and more.
28
+ """,
29
+ tool_args_schema = QueryElectricCars,
30
+ reranker = "multilingual_reranker_v1", rerank_k = 100,
31
+ n_sentences_before = 2, n_sentences_after = 2, lambda_val = 0.005,
32
+ summary_num_results = 10,
33
+ vectara_summarizer = 'vectara-summary-ext-24-05-sml',
34
+ include_citations = False,
35
+ )
36
+
37
+ vec_factory_2 = VectaraToolFactory(vectara_api_key=cfg.api_keys[1],
38
+ vectara_customer_id=cfg.customer_id,
39
+ vectara_corpus_id=cfg.corpus_ids[1])
40
+
41
+
42
+ class QueryEVLaws(BaseModel):
43
+ query: str = Field(description="The user query")
44
+ state: Optional[str] = Field(default=None,
45
+ description="The two digit state code. Optional.",
46
+ examples=['CA', 'US', 'WA'])
47
+ type: Optional[str] = Field(default=None,
48
+ description="The type of policy. Optional",
49
+ examples = ['Laws and Regulations', 'State Incentives', 'Incentives', 'Utility / Private Incentives', 'Programs'])
50
+
51
+
52
+
53
+ ask_policies = vec_factory_2.create_rag_tool(
54
+ tool_name = "ask_policies",
55
+ tool_description = """
56
+ Given a user query,
57
+ returns a response (str) to a user question about incentives and regulations about electric vehicles in the United States.
58
+ You can ask this tool any question about laws passed by states or the federal government related to electric vehicles.
59
+ """,
60
+ tool_args_schema = QueryEVLaws,
61
+ reranker = "multilingual_reranker_v1", rerank_k = 100,
62
+ n_sentences_before = 2, n_sentences_after = 2, lambda_val = 0.005,
63
+ summary_num_results = 10,
64
+ vectara_summarizer = 'vectara-summary-ext-24-05-sml',
65
+ include_citations = False,
66
+ )
67
+
68
+ tools_factory = ToolsFactory()
69
+
70
+ return (tools_factory.standard_tools() +
71
+ tools_factory.guardrail_tools() +
72
+ tools_factory.database_tools(
73
+ content_description = 'Electric Vehicles',
74
+ scheme = 'postgresql',
75
+ host = 'localhost', port = '5432',
76
+ user = 'ofer',
77
+ password = 'noanoa',
78
+ dbname = 'ev_database'
79
+ ) +
80
+ [ask_vehicles, ask_policies]
81
+ )
82
+
83
+ def initialize_agent(_cfg, update_func):
84
+ electric_vehicle_bot_instructions = """
85
+ - You are a helpful research assistant, with expertise in electric vehicles, in conversation with a user.
86
+ - For a query with multiple sub-questions, break down the query into the sub-questions,
87
+ and make separate calls to the ask_vehicles or ask_policies tool to answer each sub-question,
88
+ then combine the answers to provide a complete response.
89
+ - Never discuss politics, and always respond politely.
90
+ """
91
+
92
+ agent = Agent(
93
+ tools=create_assistant_tools(_cfg),
94
+ topic="Electric vehicles in the United States",
95
+ custom_instructions=electric_vehicle_bot_instructions,
96
+ update_func=update_func
97
+ )
98
+ return agent
99
+
100
+
101
+ def get_agent_config() -> OmegaConf:
102
+ cfg = OmegaConf.create({
103
+ 'customer_id': str(os.environ['VECTARA_CUSTOMER_ID']),
104
+ 'corpus_ids': str(os.environ['VECTARA_CORPUS_IDS']).split(','),
105
+ 'api_keys': str(os.environ['VECTARA_API_KEYS']).split(','),
106
+ 'examples': os.environ.get('QUERY_EXAMPLES', None),
107
+ 'title': "Electric Vehicles in the United States",
108
+ 'demo_welcome': "Welcome to the EV Assistant demo.",
109
+ 'demo_description': "This assistant can help you learn about electric vehicles in the United States, including how they work, the advantages of purchasing them, and reviews on the top choices.",
110
+ })
111
+ return cfg
112
+
app.py CHANGED
@@ -1,148 +1,14 @@
1
-
2
- import os
3
  from PIL import Image
4
  import sys
5
- import pandas as pd
6
- import requests
7
 
8
- from omegaconf import OmegaConf
9
  import streamlit as st
10
  from streamlit_pills import pills
11
 
12
- from dotenv import load_dotenv
13
- load_dotenv(override=True)
14
-
15
- from pydantic import Field, BaseModel
16
- from vectara_agent.agent import Agent, AgentStatusType
17
- from vectara_agent.tools import ToolsFactory, VectaraToolFactory
18
 
19
- tickers = {
20
- "AAPL": "Apple Computer",
21
- "GOOG": "Google",
22
- "AMZN": "Amazon",
23
- "SNOW": "Snowflake",
24
- "TEAM": "Atlassian",
25
- "TSLA": "Tesla",
26
- "NVDA": "Nvidia",
27
- "MSFT": "Microsoft",
28
- "AMD": "Advanced Micro Devices",
29
- "INTC": "Intel",
30
- "NFLX": "Netflix",
31
- }
32
- years = [2020, 2021, 2022, 2023, 2024]
33
  initial_prompt = "How can I help you today?"
34
 
35
- def create_assistant_tools(cfg):
36
-
37
- def get_company_info() -> list[str]:
38
- """
39
- Returns a dictionary of companies you can query about. Always check this before using any other tool.
40
- The output is a dictionary of valid ticker symbols mapped to company names.
41
- You can use this to identify the companies you can query about, and their ticker information.
42
- """
43
- return tickers
44
-
45
- def get_valid_years() -> list[str]:
46
- """
47
- Returns a list of the years for which financial reports are available.
48
- Always check this before using any other tool.
49
- """
50
- return years
51
-
52
- # Tool to get the income statement for a given company and year using the FMP API
53
- def get_income_statement(
54
- ticker=Field(description="the ticker symbol of the company."),
55
- year=Field(description="the year for which to get the income statement."),
56
- ) -> str:
57
- """
58
- Get the income statement for a given company and year using the FMP (https://financialmodelingprep.com) API.
59
- Returns a dictionary with the income statement data. All data is in USD, but you can convert it to more compact form like K, M, B.
60
- """
61
- fmp_api_key = os.environ.get("FMP_API_KEY", None)
62
- if fmp_api_key is None:
63
- return "FMP_API_KEY environment variable not set. This tool does not work."
64
- url = f"https://financialmodelingprep.com/api/v3/income-statement/{ticker}?apikey={fmp_api_key}"
65
- response = requests.get(url)
66
- if response.status_code == 200:
67
- data = response.json()
68
- income_statement = pd.DataFrame(data)
69
- income_statement["date"] = pd.to_datetime(income_statement["date"])
70
- income_statement_specific_year = income_statement[
71
- income_statement["date"].dt.year == int(year)
72
- ]
73
- values_dict = income_statement_specific_year.to_dict(orient="records")[0]
74
- return f"Financial results: {', '.join([f'{key}: {value}' for key, value in values_dict.items() if key not in ['date', 'cik', 'link', 'finalLink']])}"
75
- else:
76
- return "FMP API returned error. This tool does not work."
77
-
78
- class QueryTranscriptsArgs(BaseModel):
79
- query: str = Field(..., description="The user query.")
80
- year: int = Field(..., description=f"The year. An integer between {min(years)} and {max(years)}.")
81
- ticker: str = Field(..., description=f"The company ticker. Must be a valid ticket symbol from the list {tickers.keys()}.")
82
-
83
- vec_factory = VectaraToolFactory(vectara_api_key=cfg.api_key,
84
- vectara_customer_id=cfg.customer_id,
85
- vectara_corpus_id=cfg.corpus_id)
86
- tools_factory = ToolsFactory()
87
-
88
- ask_transcripts = vec_factory.create_rag_tool(
89
- tool_name = "ask_transcripts",
90
- tool_description = """
91
- Given a company name and year, responds to a user question about the company, based on analyst call transcripts about the company's financial reports for that year.
92
- You can ask this tool any question about the compaany including risks, opportunities, financial performance, competitors and more.
93
- """,
94
- tool_args_schema = QueryTranscriptsArgs,
95
- reranker = "multilingual_reranker_v1", rerank_k = 100,
96
- n_sentences_before = 2, n_sentences_after = 2, lambda_val = 0.005,
97
- summary_num_results = 10,
98
- vectara_summarizer = 'vectara-summary-ext-24-05-med-omni',
99
- include_citations = False,
100
- )
101
-
102
- return (
103
- [tools_factory.create_tool(tool) for tool in
104
- [
105
- get_company_info,
106
- get_valid_years,
107
- get_income_statement,
108
- ]
109
- ] +
110
- tools_factory.standard_tools() +
111
- tools_factory.financial_tools() +
112
- tools_factory.guardrail_tools() +
113
- [ask_transcripts]
114
- )
115
-
116
- def initialize_agent(_cfg):
117
- if 'agent' in st.session_state:
118
- return st.session_state.agent
119
-
120
- financial_bot_instructions = """
121
- - You are a helpful financial assistant, with expertise in financial reporting, in conversation with a user.
122
- - Respond in a compact format by using appropriate units of measure (e.g., K for thousands, M for millions, B for billions).
123
- Do not report the same number twice (e.g. $100K and 100,000 USD).
124
- - Always check the get_company_info and get_valid_years tools to validate company and year are valid.
125
- - Do not include URLS unless they are from one of the tools.
126
- - When querying a tool for a numeric value or KPI, use a concise and non-ambiguous description of what you are looking for.
127
- - If you calculate a metric, make sure you have all the necessary information to complete the calculation. Don't guess.
128
- """
129
-
130
- def update_func(status_type: AgentStatusType, msg: str):
131
- if status_type != AgentStatusType.AGENT_UPDATE:
132
- output = f"{status_type.value} - {msg}"
133
- st.session_state.log_messages.append(output)
134
-
135
- agent = Agent(
136
- tools=create_assistant_tools(_cfg),
137
- topic="Financial data, annual reports and 10-K filings",
138
- custom_instructions=financial_bot_instructions,
139
- update_func=update_func
140
- )
141
- agent.report()
142
-
143
- return agent
144
-
145
-
146
  def toggle_logs():
147
  st.session_state.show_logs = not st.session_state.show_logs
148
 
@@ -155,6 +21,11 @@ def show_example_questions():
155
  return True
156
  return False
157
 
 
 
 
 
 
158
  def launch_bot():
159
  def reset():
160
  st.session_state.messages = [{"role": "assistant", "content": initial_prompt, "avatar": "🦖"}]
@@ -162,17 +33,13 @@ def launch_bot():
162
  st.session_state.log_messages = []
163
  st.session_state.prompt = None
164
  st.session_state.ex_prompt = None
165
- st.session_state.show_logs = False
166
  st.session_state.first_turn = True
 
 
 
167
 
168
- st.set_page_config(page_title="Financial Assistant", layout="wide")
169
  if 'cfg' not in st.session_state:
170
- cfg = OmegaConf.create({
171
- 'customer_id': str(os.environ['VECTARA_CUSTOMER_ID']),
172
- 'corpus_id': str(os.environ['VECTARA_CORPUS_ID']),
173
- 'api_key': str(os.environ['VECTARA_API_KEY']),
174
- 'examples': os.environ.get('QUERY_EXAMPLES', None)
175
- })
176
  st.session_state.cfg = cfg
177
  st.session_state.ex_prompt = None
178
  example_messages = [example.strip() for example in cfg.examples.split(",")] if cfg.examples else []
@@ -180,18 +47,14 @@ def launch_bot():
180
  reset()
181
 
182
  cfg = st.session_state.cfg
183
- if 'agent' not in st.session_state:
184
- st.session_state.agent = initialize_agent(cfg)
185
 
186
  # left side content
187
  with st.sidebar:
188
  image = Image.open('Vectara-logo.png')
189
  st.image(image, width=175)
190
- st.markdown("## Welcome to the financial assistant demo.\n\n\n")
191
- companies = ", ".join(tickers.values())
192
- st.markdown(
193
- f"This assistant can help you with any questions about the financials of several companies:\n\n **{companies}**.\n"
194
- )
195
 
196
  st.markdown("\n\n")
197
  bc1, _ = st.columns([1, 1])
@@ -206,7 +69,6 @@ def launch_bot():
206
  "This app was built with [Vectara](https://vectara.com).\n\n"
207
  "It demonstrates the use of Agentic RAG functionality with Vectara"
208
  )
209
- st.markdown("---")
210
 
211
  if "messages" not in st.session_state.keys():
212
  reset()
@@ -249,8 +111,9 @@ def launch_bot():
249
  st.markdown(res)
250
  st.session_state.ex_prompt = None
251
  st.session_state.prompt = None
 
252
  st.rerun()
253
-
254
  log_placeholder = st.empty()
255
  with log_placeholder.container():
256
  if st.session_state.show_logs:
@@ -264,5 +127,4 @@ def launch_bot():
264
  sys.stdout.flush()
265
 
266
  if __name__ == "__main__":
267
- launch_bot()
268
-
 
 
 
1
  from PIL import Image
2
  import sys
 
 
3
 
 
4
  import streamlit as st
5
  from streamlit_pills import pills
6
 
7
+ from vectara_agent.agent import AgentStatusType
8
+ from agent import initialize_agent, get_agent_config
 
 
 
 
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  initial_prompt = "How can I help you today?"
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def toggle_logs():
13
  st.session_state.show_logs = not st.session_state.show_logs
14
 
 
21
  return True
22
  return False
23
 
24
+ def update_func(status_type: AgentStatusType, msg: str):
25
+ if status_type != AgentStatusType.AGENT_UPDATE:
26
+ output = f"{status_type.value} - {msg}"
27
+ st.session_state.log_messages.append(output)
28
+
29
  def launch_bot():
30
  def reset():
31
  st.session_state.messages = [{"role": "assistant", "content": initial_prompt, "avatar": "🦖"}]
 
33
  st.session_state.log_messages = []
34
  st.session_state.prompt = None
35
  st.session_state.ex_prompt = None
 
36
  st.session_state.first_turn = True
37
+ st.session_state.show_logs = False
38
+ if 'agent' not in st.session_state:
39
+ st.session_state.agent = initialize_agent(cfg, update_func=update_func)
40
 
 
41
  if 'cfg' not in st.session_state:
42
+ cfg = get_agent_config()
 
 
 
 
 
43
  st.session_state.cfg = cfg
44
  st.session_state.ex_prompt = None
45
  example_messages = [example.strip() for example in cfg.examples.split(",")] if cfg.examples else []
 
47
  reset()
48
 
49
  cfg = st.session_state.cfg
50
+ st.set_page_config(page_title=cfg['title'], layout="wide")
 
51
 
52
  # left side content
53
  with st.sidebar:
54
  image = Image.open('Vectara-logo.png')
55
  st.image(image, width=175)
56
+ st.markdown(f"## {cfg['demo_welcome']}")
57
+ st.markdown(f"{cfg['demo_description']}")
 
 
 
58
 
59
  st.markdown("\n\n")
60
  bc1, _ = st.columns([1, 1])
 
69
  "This app was built with [Vectara](https://vectara.com).\n\n"
70
  "It demonstrates the use of Agentic RAG functionality with Vectara"
71
  )
 
72
 
73
  if "messages" not in st.session_state.keys():
74
  reset()
 
111
  st.markdown(res)
112
  st.session_state.ex_prompt = None
113
  st.session_state.prompt = None
114
+ st.session_state.first_turn = False
115
  st.rerun()
116
+
117
  log_placeholder = st.empty()
118
  with log_placeholder.container():
119
  if st.session_state.show_logs:
 
127
  sys.stdout.flush()
128
 
129
  if __name__ == "__main__":
130
+ launch_bot()