cheesyFishes committed
Commit 3eeb9d5
Parent: 59f5daa

update to llamaindex v0.6.13

Files changed (4)
  1. app.py +100 -48
  2. constants.py +1 -1
  3. requirements.txt +3 -2
  4. utils.py +6 -5
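The bulk of this commit is the llama-index 0.5.x to 0.6.x API migration: an index is no longer queried directly with index.query(...); a query engine is built from it first. A minimal sketch of the pattern the updated code adopts (the helper function below is illustrative only; the real wiring is in the app.py and utils.py diffs):

# Minimal sketch of the 0.6.x query pattern adopted below; the helper name and
# signature are illustrative, not part of this repository.
from llama_index import GPTSQLStructStoreIndex


def run_structured_query(index: GPTSQLStructStoreIndex, query_text: str, use_nl: bool = False):
    # llama-index 0.5.x (removed below):  response = index.query(query_text)
    # llama-index 0.6.x: build a query engine from the index, then query it
    query_engine = index.as_query_engine(synthesize_response=use_nl)
    response = query_engine.query(query_text)
    # the executed SQL statement is still exposed on the response metadata
    return str(response), response.extra_info["sql_query"]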
app.py CHANGED
@@ -15,13 +15,15 @@ from constants import (
     DEFAULT_BUSINESS_TABLE_DESCRP,
     DEFAULT_VIOLATIONS_TABLE_DESCRP,
     DEFAULT_INSPECTIONS_TABLE_DESCRP,
-    DEFAULT_LC_TOOL_DESCRP
+    DEFAULT_LC_TOOL_DESCRP,
 )
 from utils import get_sql_index_tool, get_llm


 @st.cache_resource
-def initialize_index(llm_name, model_temperature, table_context_dict, api_key, sql_path=DEFAULT_SQL_PATH):
+def initialize_index(
+    llm_name, model_temperature, table_context_dict, api_key, sql_path=DEFAULT_SQL_PATH
+):
     """Create the GPTSQLStructStoreIndex object."""
     llm = get_llm(llm_name, model_temperature, api_key)

@@ -30,14 +32,18 @@ def initialize_index(llm_name, model_temperature, table_context_dict, api_key, s

     context_container = None
     if table_context_dict is not None:
-        context_builder = SQLContextContainerBuilder(sql_database, context_dict=table_context_dict)
+        context_builder = SQLContextContainerBuilder(
+            sql_database, context_dict=table_context_dict
+        )
         context_container = context_builder.build_context_container()
-
+
     service_context = ServiceContext.from_defaults(llm_predictor=LLMPredictor(llm=llm))
-    index = GPTSQLStructStoreIndex([],
-                                   sql_database=sql_database,
-                                   sql_context_container=context_container,
-                                   service_context=service_context)
+    index = GPTSQLStructStoreIndex(
+        [],
+        sql_database=sql_database,
+        sql_context_container=context_container,
+        service_context=service_context,
+    )

     return index

@@ -45,63 +51,97 @@ def initialize_index(llm_name, model_temperature, table_context_dict, api_key, s
 @st.cache_resource
 def initialize_chain(llm_name, model_temperature, lc_descrp, api_key, _sql_index):
     """Create a (rather hacky) custom agent and sql_index tool."""
-    sql_tool = Tool(name="SQL Index",
-                    func=get_sql_index_tool(_sql_index, _sql_index.sql_context_container.context_dict),
-                    description=lc_descrp)
+    sql_tool = Tool(
+        name="SQL Index",
+        func=get_sql_index_tool(
+            _sql_index, _sql_index.sql_context_container.context_dict
+        ),
+        description=lc_descrp,
+    )

     llm = get_llm(llm_name, model_temperature, api_key=api_key)

     memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

-    agent_chain = initialize_agent([sql_tool], llm, agent="chat-conversational-react-description", verbose=True, memory=memory)
+    agent_chain = initialize_agent(
+        [sql_tool],
+        llm,
+        agent="chat-conversational-react-description",
+        verbose=True,
+        memory=memory,
+    )

     return agent_chain


 st.title("🦙 Llama Index SQL Sandbox 🦙")
-st.markdown((
-    "This sandbox uses a sqlite database by default, powered by [Llama Index](https://gpt-index.readthedocs.io/en/latest/index.html) ChatGPT, and LangChain.\n\n"
-    "The database contains information on health violations and inspections at restaurants in San Francisco."
-    "This data is spread across three tables - businesses, inspections, and violations.\n\n"
-    "Using the setup page, you can adjust LLM settings, change the context for the SQL tables, and change the tool description for Langchain."
-    "The other tabs will perform chatbot and text2sql operations.\n\n"
-    "Read more about LlamaIndexes structured data support [here!](https://gpt-index.readthedocs.io/en/latest/guides/tutorials/sql_guide.html)"
-))
+st.markdown(
+    (
+        "This sandbox uses a sqlite database by default, powered by [Llama Index](https://gpt-index.readthedocs.io/en/latest/index.html) ChatGPT, and LangChain.\n\n"
+        "The database contains information on health violations and inspections at restaurants in San Francisco."
+        "This data is spread across three tables - businesses, inspections, and violations.\n\n"
+        "Using the setup page, you can adjust LLM settings, change the context for the SQL tables, and change the tool description for Langchain."
+        "The other tabs will perform chatbot and text2sql operations.\n\n"
+        "Read more about LlamaIndexes structured data support [here!](https://gpt-index.readthedocs.io/en/latest/guides/tutorials/sql_guide.html)"
+    )
+)


-setup_tab, llama_tab, lc_tab = st.tabs(["Setup", "Llama Index", "Langchain+Llama Index"])
+setup_tab, llama_tab, lc_tab = st.tabs(
+    ["Setup", "Llama Index", "Langchain+Llama Index"]
+)

 with setup_tab:
     st.subheader("LLM Setup")
     api_key = st.text_input("Enter your OpenAI API key here", type="password")
-    llm_name = st.selectbox('Which LLM?', ["text-davinci-003", "gpt-3.5-turbo", "gpt-4"])
-    model_temperature = st.slider("LLM Temperature", min_value=0.0, max_value=1.0, step=0.1)
+    llm_name = st.selectbox(
+        "Which LLM?", ["text-davinci-003", "gpt-3.5-turbo", "gpt-4"]
+    )
+    model_temperature = st.slider(
+        "LLM Temperature", min_value=0.0, max_value=1.0, step=0.1
+    )

     st.subheader("Table Setup")
-    business_table_descrp = st.text_area("Business table description", value=DEFAULT_BUSINESS_TABLE_DESCRP)
-    violations_table_descrp = st.text_area("Business table description", value=DEFAULT_VIOLATIONS_TABLE_DESCRP)
-    inspections_table_descrp = st.text_area("Business table description", value=DEFAULT_INSPECTIONS_TABLE_DESCRP)
-
-    table_context_dict = {"businesses": business_table_descrp,
-                          "inspections": inspections_table_descrp,
-                          "violations": violations_table_descrp}
-
+    business_table_descrp = st.text_area(
+        "Business table description", value=DEFAULT_BUSINESS_TABLE_DESCRP
+    )
+    violations_table_descrp = st.text_area(
+        "Business table description", value=DEFAULT_VIOLATIONS_TABLE_DESCRP
+    )
+    inspections_table_descrp = st.text_area(
+        "Business table description", value=DEFAULT_INSPECTIONS_TABLE_DESCRP
+    )
+
+    table_context_dict = {
+        "businesses": business_table_descrp,
+        "inspections": inspections_table_descrp,
+        "violations": violations_table_descrp,
+    }
+
     use_table_descrp = st.checkbox("Use table descriptions?", value=True)
     lc_descrp = st.text_area("LangChain Tool Description", value=DEFAULT_LC_TOOL_DESCRP)

 with llama_tab:
     st.subheader("Text2SQL with Llama Index")
     if st.button("Initialize Index", key="init_index_1"):
-        st.session_state['llama_index'] = initialize_index(llm_name, model_temperature, table_context_dict if use_table_descrp else None, api_key)
-
+        st.session_state["llama_index"] = initialize_index(
+            llm_name,
+            model_temperature,
+            table_context_dict if use_table_descrp else None,
+            api_key,
+        )
+
     if "llama_index" in st.session_state:
-        query_text = st.text_input("Query:", value="Which restaurant has the most violations?")
+        query_text = st.text_input(
+            "Query:", value="Which restaurant has the most violations?"
+        )
+        use_nl = st.checkbox("Return natural language response?")
         if st.button("Run Query") and query_text:
             with st.spinner("Getting response..."):
                 try:
-                    response = st.session_state['llama_index'].query(query_text)
+                    response = st.session_state["llama_index"].as_query_engine(synthesize_response=use_nl).query(query_text)
                     response_text = str(response)
-                    response_sql = response.extra_info['sql_query']
+                    response_sql = response.extra_info["sql_query"]
                 except Exception as e:
                     response_text = "Error running SQL Query."
                     response_sql = str(e)
@@ -119,19 +159,31 @@ with lc_tab:
     st.subheader("Langchain + Llama Index SQL Demo")

     if st.button("Initialize Agent"):
-        st.session_state['llama_index'] = initialize_index(llm_name, model_temperature, table_context_dict if use_table_descrp else None, api_key)
-        st.session_state['lc_agent'] = initialize_chain(llm_name, model_temperature, lc_descrp, api_key, st.session_state['llama_index'])
-        st.session_state['chat_history'] = []
-
-    model_input = st.text_input("Message:", value="Which restaurant has the most violations?")
-    if 'lc_agent' in st.session_state and st.button("Send"):
+        st.session_state["llama_index"] = initialize_index(
+            llm_name,
+            model_temperature,
+            table_context_dict if use_table_descrp else None,
+            api_key,
+        )
+        st.session_state["lc_agent"] = initialize_chain(
+            llm_name,
+            model_temperature,
+            lc_descrp,
+            api_key,
+            st.session_state["llama_index"],
+        )
+        st.session_state["chat_history"] = []
+
+    model_input = st.text_input(
+        "Message:", value="Which restaurant has the most violations?"
+    )
+    if "lc_agent" in st.session_state and st.button("Send"):
         model_input = "User: " + model_input
-        st.session_state['chat_history'].append(model_input)
+        st.session_state["chat_history"].append(model_input)
         with st.spinner("Getting response..."):
-            response = st.session_state['lc_agent'].run(input=model_input)
-            st.session_state['chat_history'].append(response)
+            response = st.session_state["lc_agent"].run(input=model_input)
+            st.session_state["chat_history"].append(response)

-    if 'chat_history' in st.session_state:
-        for msg in st.session_state['chat_history']:
+    if "chat_history" in st.session_state:
+        for msg in st.session_state["chat_history"]:
             st_message(msg.split("User: ")[-1], is_user="User: " in msg)
-
 
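The new use_nl checkbox above maps straight onto synthesize_response: left unchecked the query engine returns the raw SQL result, and checked it has the LLM phrase a natural-language answer over that result. A short sketch of the two modes, assuming index is the object returned by initialize_index above:

# Sketch only: `index` is assumed to be the GPTSQLStructStoreIndex built by
# initialize_index() in app.py.
question = "Which restaurant has the most violations?"

raw = index.as_query_engine(synthesize_response=False).query(question)  # raw result rows
nl = index.as_query_engine(synthesize_response=True).query(question)    # prose answer

print(raw.extra_info["sql_query"])  # the generated SQL is available in both modes
print(str(raw))
print(str(nl))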
constants.py CHANGED
@@ -21,4 +21,4 @@ DEFAULT_INGEST_DOCUMENT = (
     "The restaurant KING-KONG had an routine unscheduled inspection on 2023/12/31. "
     "The business achieved a score of 50. We two violations, a high risk "
     "vermin infestation as well as a high risk food holding temperatures."
-)
+)
requirements.txt CHANGED
@@ -1,4 +1,5 @@
-langchain==0.0.123
-llama-index==0.5.1
+altair==4.2.2
+langchain==0.0.154
+llama-index==0.6.13
 streamlit==1.19.0
 streamlit-chat==0.0.2.2
utils.py CHANGED
@@ -5,21 +5,22 @@ from langchain.chat_models import ChatOpenAI

 def get_sql_index_tool(sql_index, table_context_dict):
     table_context_str = "\n".join(table_context_dict.values())
+
     def run_sql_index_query(query_text):
         try:
-            response = sql_index.query(query_text)
+            response = sql_index.as_query_engine(synthesize_response=False).query(query_text)
         except Exception as e:
             return f"Error running SQL {e}.\nNot able to retrieve answer."
         text = str(response)
-        sql = response.extra_info['sql_query']
+        sql = response.extra_info["sql_query"]
         return f"Here are the details on the SQL table: {table_context_str}\nSQL Query Used: {sql}\nSQL Result: {text}\n"
-        #return f"SQL Query Used: {sql}\nSQL Result: {text}\n"
-    return run_sql_index_query
+        # return f"SQL Query Used: {sql}\nSQL Result: {text}\n"

+    return run_sql_index_query


 def get_llm(llm_name, model_temperature, api_key):
-    os.environ['OPENAI_API_KEY'] = api_key
+    os.environ["OPENAI_API_KEY"] = api_key
     if llm_name == "text-davinci-003":
         return OpenAI(temperature=model_temperature, model_name=llm_name)
     else:
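For context, the closure returned by get_sql_index_tool is what app.py wraps in a LangChain Tool; called directly it produces one formatted string the agent can read. A usage sketch, assuming index comes from initialize_index in app.py:

# Sketch only: `index` and its context dict are assumed to come from
# initialize_index() in app.py.
query_fn = get_sql_index_tool(index, index.sql_context_container.context_dict)
result = query_fn("Which restaurant has the most violations?")
# result is a single string of the form:
#   Here are the details on the SQL table: <table descriptions>
#   SQL Query Used: <generated SQL>
#   SQL Result: <query output>
print(result)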