cheesyFishes committed on
Commit
3c8ea82
1 Parent(s): 4599370
Files changed (8) hide show
  1. .gitattributes +1 -0
  2. README.md +4 -4
  3. app.py +137 -0
  4. constants.py +24 -0
  5. requirements.txt +4 -0
  6. sfscores.sqlite +3 -0
  7. sql_index.json +1 -0
  8. utils.py +26 -0
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ *sqlite filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,12 @@
1
  ---
2
  title: Llama Index Sql Sandbox
3
- emoji:
4
  colorFrom: blue
5
- colorTo: green
6
  sdk: streamlit
7
- sdk_version: 1.17.0
8
  app_file: app.py
9
- pinned: false
10
  license: mit
11
  ---
12
 
 
1
  ---
2
  title: Llama Index Sql Sandbox
3
+ emoji: 🦙
4
  colorFrom: blue
5
+ colorTo: pink
6
  sdk: streamlit
7
+ sdk_version: 1.19.0
8
  app_file: app.py
9
+ pinned: true
10
  license: mit
11
  ---
12
 
app.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from streamlit_chat import message as st_message
4
+ from sqlalchemy import create_engine
5
+
6
+ from langchain.agents import Tool, initialize_agent
7
+ from langchain.chains.conversation.memory import ConversationBufferMemory
8
+
9
+ from llama_index import GPTSQLStructStoreIndex, LLMPredictor, ServiceContext
10
+ from llama_index import SQLDatabase as llama_SQLDatabase
11
+ from llama_index.indices.struct_store import SQLContextContainerBuilder
12
+
13
+ from constants import (
14
+ DEFAULT_SQL_PATH,
15
+ DEFAULT_BUSINESS_TABLE_DESCRP,
16
+ DEFAULT_VIOLATIONS_TABLE_DESCRP,
17
+ DEFAULT_INSPECTIONS_TABLE_DESCRP,
18
+ DEFAULT_LC_TOOL_DESCRP
19
+ )
20
+ from utils import get_sql_index_tool, get_llm
21
+
22
+
23
@st.cache_resource
def initialize_index(llm_name, model_temperature, table_context_dict, api_key, sql_path=DEFAULT_SQL_PATH):
    """Build the GPTSQLStructStoreIndex used for text-to-SQL queries.

    Args:
        llm_name: Model name forwarded to ``get_llm``.
        model_temperature: Temperature forwarded to ``get_llm``.
        table_context_dict: Mapping of table name to description text, or
            ``None`` to build the index without a context container.
        api_key: API key forwarded to ``get_llm``.
        sql_path: Database URL handed to SQLAlchemy's ``create_engine``.

    Returns:
        The constructed ``GPTSQLStructStoreIndex``.
    """
    language_model = get_llm(llm_name, model_temperature, api_key)
    database = llama_SQLDatabase(create_engine(sql_path))

    # The context container is only built when descriptions were supplied.
    if table_context_dict is None:
        container = None
    else:
        container = SQLContextContainerBuilder(
            database, context_dict=table_context_dict
        ).build_context_container()

    return GPTSQLStructStoreIndex(
        [],
        sql_database=database,
        sql_context_container=container,
        service_context=ServiceContext.from_defaults(
            llm_predictor=LLMPredictor(llm=language_model)
        ),
    )
43
+
44
+
45
def initialize_chain(llm_name, model_temperature, lc_descrp, api_key, _sql_index):
    """Create a (rather hacky) custom agent that exposes the SQL index as a tool.

    This function is intentionally NOT wrapped in ``st.cache_resource``:

    * the agent owns a mutable ``ConversationBufferMemory``; a cached resource
      is shared by every session, so conversations would bleed into each other
      and re-initializing the agent would not reset its memory;
    * ``_sql_index`` starts with an underscore, which Streamlit excludes from
      the cache key, so a rebuilt index would still yield the stale agent.

    Args:
        llm_name: Model name forwarded to ``get_llm``.
        model_temperature: Temperature forwarded to ``get_llm``.
        lc_descrp: Description text attached to the LangChain tool.
        api_key: API key forwarded to ``get_llm``.
        _sql_index: Index object queried by the tool.

    Returns:
        The agent returned by ``initialize_agent``.
    """
    # A missing container or a ``None`` context dict is passed on as an empty
    # mapping so that the tool factory never receives ``None``.
    container = getattr(_sql_index, "sql_context_container", None)
    table_context = getattr(container, "context_dict", None) or {}

    sql_tool = Tool(name="SQL Index",
                    func=get_sql_index_tool(_sql_index, table_context),
                    description=lc_descrp)

    llm = get_llm(llm_name, model_temperature, api_key=api_key)

    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

    agent_chain = initialize_agent([sql_tool], llm, agent="chat-conversational-react-description", verbose=True, memory=memory)

    return agent_chain
59
+
60
+
61
# Page header and introduction text.
st.title("🦙 Llama Index SQL Sandbox 🦙")
# Adjacent string literals are concatenated verbatim, so every sentence that
# continues a paragraph must carry its own trailing space.
st.markdown((
    "This sandbox uses a sqlite database by default, powered by [Llama Index](https://gpt-index.readthedocs.io/en/latest/index.html), ChatGPT, and LangChain.\n\n"
    "The database contains information on health violations and inspections at restaurants in San Francisco. "
    "This data is spread across three tables - businesses, inspections, and violations.\n\n"
    "Using the setup page, you can adjust LLM settings, change the context for the SQL tables, and change the tool description for Langchain. "
    "The other tabs will perform chatbot and text2sql operations.\n\n"
    "Read more about LlamaIndexes structured data support [here!](https://gpt-index.readthedocs.io/en/latest/guides/tutorials/sql_guide.html)"
))
70
+
71
+
72
setup_tab, llama_tab, lc_tab = st.tabs(["Setup", "Llama Index", "Langchain+Llama Index"])

# Setup tab: collects the LLM settings, the per-table descriptions and the
# LangChain tool description that the other two tabs read.
with setup_tab:
    st.subheader("LLM Setup")
    api_key = st.text_input("Enter your OpenAI API key here", type="password")
    llm_name = st.selectbox('Which LLM?', ["text-davinci-003", "gpt-3.5-turbo", "gpt-4"])
    model_temperature = st.slider("LLM Temperature", min_value=0.0, max_value=1.0, step=0.1)

    st.subheader("Table Setup")
    # Each text area is labelled after the table it actually describes.
    business_table_descrp = st.text_area("Business table description", value=DEFAULT_BUSINESS_TABLE_DESCRP)
    violations_table_descrp = st.text_area("Violations table description", value=DEFAULT_VIOLATIONS_TABLE_DESCRP)
    inspections_table_descrp = st.text_area("Inspections table description", value=DEFAULT_INSPECTIONS_TABLE_DESCRP)

    # Keys are the table names the descriptions are attached to.
    table_context_dict = {"businesses": business_table_descrp,
                          "inspections": inspections_table_descrp,
                          "violations": violations_table_descrp}

    use_table_descrp = st.checkbox("Use table descriptions?", value=True)
    lc_descrp = st.text_area("LangChain Tool Description", value=DEFAULT_LC_TOOL_DESCRP)
91
+
92
# Llama Index tab: build the SQL index on demand, then run one-off
# text-to-SQL queries against it.
with llama_tab:
    st.subheader("Text2SQL with Llama Index")

    init_clicked = st.button("Initialize Index", key="init_index_1")
    if init_clicked:
        # Table descriptions are only passed along when the checkbox is set.
        chosen_context = table_context_dict if use_table_descrp else None
        st.session_state['llama_index'] = initialize_index(llm_name, model_temperature, chosen_context, api_key)

    if "llama_index" in st.session_state:
        query_text = st.text_input("Query:", value="Which restaurant has the most violations?")
        run_clicked = st.button("Run Query")
        if run_clicked and query_text:
            with st.spinner("Getting response..."):
                try:
                    response = st.session_state['llama_index'].query(query_text)
                    response_text, response_sql = str(response), response.extra_info['sql_query']
                except Exception as e:
                    # On failure the error text is shown in place of the SQL.
                    response_text, response_sql = "Error running SQL Query.", str(e)

            result_col, query_col = st.columns(2)

            with result_col:
                st.text("SQL Result:")
                st.markdown(response_text)

            with query_col:
                st.text("SQL Query:")
                st.markdown(response_sql)
117
+
118
# LangChain tab: chat with an agent that uses the SQL index as a tool.
with lc_tab:
    st.subheader("Langchain + Llama Index SQL Demo")

    if st.button("Initialize Agent"):
        st.session_state['llama_index'] = initialize_index(llm_name, model_temperature, table_context_dict if use_table_descrp else None, api_key)
        st.session_state['lc_agent'] = initialize_chain(llm_name, model_temperature, lc_descrp, api_key, st.session_state['llama_index'])
        st.session_state['chat_history'] = []

    model_input = st.text_input("Message:", value="Which restaurant has the most violations?")
    if 'lc_agent' in st.session_state and st.button("Send"):
        # User turns are stored with a "User: " prefix; the rendering loop
        # below relies on that prefix to tell the two speakers apart.
        model_input = "User: " + model_input
        st.session_state['chat_history'].append(model_input)
        with st.spinner("Getting response..."):
            try:
                response = st.session_state['lc_agent'].run(input=model_input)
            except Exception as e:
                # Show the failure as the agent's reply instead of crashing the
                # page; this also keeps the history in user/agent pairs.
                response = f"Error getting response: {e}"
        st.session_state['chat_history'].append(response)

    if 'chat_history' in st.session_state:
        user_prefix = "User: "
        for position, msg in enumerate(st.session_state['chat_history']):
            # Only a leading prefix marks a user turn; the same text appearing
            # later inside a message must not flip the speaker or cut the text.
            from_user = msg.startswith(user_prefix)
            shown_text = msg[len(user_prefix):] if from_user else msg
            # A per-position key keeps identical messages from colliding.
            st_message(shown_text, is_user=from_user, key=f"chat_message_{position}")
137
+
constants.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# SQLAlchemy URL of the bundled sqlite database.
DEFAULT_SQL_PATH = "sqlite:///sfscores.sqlite"

# Closing sentence shared by the violations and inspections descriptions.
_BUSINESS_NAME_HINT = (
    "The user may query about specific businesses, whose names can be found "
    "by mapping the business_id to the 'businesses' table."
)

# Default description text for each table, one sentence per list entry.
DEFAULT_BUSINESS_TABLE_DESCRP = " ".join([
    "This table gives information on the IDs, addresses, and other location information for several restaurants in San Francisco.",
    "This table will need to be referenced when users ask about specific businesses.",
])
DEFAULT_VIOLATIONS_TABLE_DESCRP = " ".join([
    "This table gives information on which business IDs have recorded health violations, including the date, risk, and description of each violation.",
    _BUSINESS_NAME_HINT,
])
DEFAULT_INSPECTIONS_TABLE_DESCRP = " ".join([
    "This table gives information on when each business ID was inspected, including the score, date, and type of inspection.",
    _BUSINESS_NAME_HINT,
])

# Default description attached to the LangChain tool.
DEFAULT_LC_TOOL_DESCRP = "Useful for when you want to answer queries about violations and inspections of businesses."

# Sample free-text document.
# NOTE(review): wording is reproduced verbatim, including "an routine" and
# "We two violations" - confirm before correcting, since this is runtime text.
DEFAULT_INGEST_DOCUMENT = " ".join([
    "The restaurant KING-KONG had an routine unscheduled inspection on 2023/12/31.",
    "The business achieved a score of 50.",
    "We two violations, a high risk vermin infestation as well as a high risk food holding temperatures.",
])
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ langchain==0.0.123
2
+ llama-index==0.5.1
3
+ streamlit==1.19.0
4
+ streamlit-chat==0.0.2.2
sfscores.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:240deebf58f54606266cdd4a5dca14c48d58f8b530941c0249a9b23f00589afa
3
+ size 9639936
sql_index.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"index_struct_id": "b52fad59-0c00-4392-b775-f9cd3fdb6deb", "docstore": {"docs": {"b52fad59-0c00-4392-b775-f9cd3fdb6deb": {"text": null, "doc_id": "b52fad59-0c00-4392-b775-f9cd3fdb6deb", "embedding": null, "doc_hash": "08a14830cef184731c6b6a0bdd67fa351d923556941aa99027b276bd839a07a4", "extra_info": null, "context_dict": {}, "__type__": "sql"}}, "ref_doc_info": {"b52fad59-0c00-4392-b775-f9cd3fdb6deb": {"doc_hash": "08a14830cef184731c6b6a0bdd67fa351d923556941aa99027b276bd839a07a4"}}}, "sql_context_container": {"context_dict": {"violations": "Schema of table violations:\nTable 'violations' has columns: business_id (TEXT), date (TEXT), ViolationTypeID (TEXT), risk_category (TEXT), description (TEXT) and foreign keys: .\nContext of table violations:\nThis table gives information on which business IDs have recorded health violations, including the date, risk, and description of each violation. The user may query about specific businesses, whose names can be found by mapping the business_id to the 'businesses' table.", "businesses": "Schema of table businesses:\nTable 'businesses' has columns: business_id (INTEGER), name (VARCHAR(64)), address (VARCHAR(50)), city (VARCHAR(23)), postal_code (VARCHAR(9)), latitude (FLOAT), longitude (FLOAT), phone_number (BIGINT), TaxCode (VARCHAR(4)), business_certificate (INTEGER), application_date (DATE), owner_name (VARCHAR(99)), owner_address (VARCHAR(74)), owner_city (VARCHAR(22)), owner_state (VARCHAR(14)), owner_zip (VARCHAR(15)) and foreign keys: .\nContext of table businesses:\nThis table gives information on the IDs, addresses, and other location information for several restaurants in San Francisco. 
This table will need to be referenced when users ask about specific businesses.", "inspections": "Schema of table inspections:\nTable 'inspections' has columns: business_id (TEXT), Score (INTEGER), date (TEXT), type (VARCHAR(33)) and foreign keys: .\nContext of table inspections:\nThis table gives information on when each business ID was inspected, including the score, date, and type of inspection. The user may query about specific businesses, whose names can be found by mapping the business_id to the 'businesses' table."}, "context_str": null}}
utils.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from langchain import OpenAI
3
+ from langchain.chat_models import ChatOpenAI
4
+
5
+
6
def get_sql_index_tool(sql_index, table_context_dict):
    """Return a callable that answers a text query through ``sql_index``.

    Args:
        sql_index: Object exposing ``query(query_text)``. The response is
            turned into text with ``str()`` and its ``extra_info`` mapping is
            read for the ``'sql_query'`` entry.
        table_context_dict: Mapping of table name to context text whose values
            are prepended to every successful answer. ``None`` or an empty
            mapping is treated as "no table context".

    Returns:
        A function ``run_sql_index_query(query_text) -> str``. It never raises
        for query failures; an error description is returned instead.
    """
    # Computed once: the context does not change between queries.
    table_context_str = "\n".join((table_context_dict or {}).values())

    def run_sql_index_query(query_text):
        """Run ``query_text`` and format the SQL used together with its result."""
        try:
            response = sql_index.query(query_text)
        except Exception as e:
            # Report the failure as text rather than raising to the caller.
            return f"Error running SQL {e}.\nNot able to retrieve answer."
        text = str(response)
        # ``extra_info`` may be absent or None; fall back to a placeholder
        # instead of raising after the query itself already succeeded.
        extra_info = getattr(response, "extra_info", None) or {}
        sql = extra_info.get('sql_query', "unavailable")
        return f"Here are the details on the SQL table: {table_context_str}\nSQL Query Used: {sql}\nSQL Result: {text}\n"

    return run_sql_index_query
18
+
19
+
20
+
21
def get_llm(llm_name, model_temperature, api_key):
    """Instantiate the LangChain model wrapper matching ``llm_name``.

    Side effect: ``api_key`` is written to the ``OPENAI_API_KEY`` environment
    variable of the current process before the model object is created.

    Args:
        llm_name: Model name; ``"text-davinci-003"`` selects ``OpenAI``, any
            other value selects ``ChatOpenAI``.
        model_temperature: Temperature passed to the model constructor.
        api_key: Value stored in ``OPENAI_API_KEY``.

    Returns:
        An ``OpenAI`` or ``ChatOpenAI`` instance.
    """
    os.environ['OPENAI_API_KEY'] = api_key
    # Only the completion-style model uses the plain OpenAI wrapper.
    model_cls = OpenAI if llm_name == "text-davinci-003" else ChatOpenAI
    return model_cls(temperature=model_temperature, model_name=llm_name)