|
import os |
|
import streamlit as st |
|
from streamlit_chat import message as st_message |
|
from sqlalchemy import create_engine |
|
|
|
from langchain.agents import Tool, initialize_agent |
|
from langchain.chains.conversation.memory import ConversationBufferMemory |
|
|
|
from llama_index import GPTSQLStructStoreIndex, LLMPredictor, ServiceContext |
|
from llama_index import SQLDatabase as llama_SQLDatabase |
|
from llama_index.indices.struct_store import SQLContextContainerBuilder |
|
|
|
from constants import ( |
|
DEFAULT_SQL_PATH, |
|
DEFAULT_BUSINESS_TABLE_DESCRP, |
|
DEFAULT_VIOLATIONS_TABLE_DESCRP, |
|
DEFAULT_INSPECTIONS_TABLE_DESCRP, |
|
DEFAULT_LC_TOOL_DESCRP, |
|
) |
|
from utils import get_sql_index_tool, get_llm |
|
|
|
|
|
@st.cache_resource
def initialize_index(
    llm_name, model_temperature, table_context_dict, api_key, sql_path=DEFAULT_SQL_PATH
):
    """Build and cache a GPTSQLStructStoreIndex over the sqlite database.

    When ``table_context_dict`` is None the index is built without table
    context; otherwise the per-table descriptions are wrapped in a context
    container so the LLM sees them when generating SQL.
    """
    language_model = get_llm(llm_name, model_temperature, api_key)

    db = llama_SQLDatabase(create_engine(sql_path))

    # Only construct a context container when descriptions were supplied.
    container = None
    if table_context_dict is not None:
        builder = SQLContextContainerBuilder(db, context_dict=table_context_dict)
        container = builder.build_context_container()

    return GPTSQLStructStoreIndex(
        [],
        sql_database=db,
        sql_context_container=container,
        service_context=ServiceContext.from_defaults(
            llm_predictor=LLMPredictor(llm=language_model)
        ),
    )
|
|
|
|
|
@st.cache_resource
def initialize_chain(llm_name, model_temperature, lc_descrp, api_key, _sql_index):
    """Build and cache a LangChain conversational agent around the SQL index.

    ``_sql_index`` is underscore-prefixed so Streamlit's cache does not try
    to hash the (unhashable) index object.
    """
    index_tool = Tool(
        name="SQL Index",
        func=get_sql_index_tool(
            _sql_index, _sql_index.sql_context_container.context_dict
        ),
        description=lc_descrp,
    )

    chat_memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True
    )

    return initialize_agent(
        [index_tool],
        get_llm(llm_name, model_temperature, api_key=api_key),
        agent="chat-conversational-react-description",
        verbose=True,
        memory=chat_memory,
    )
|
|
|
|
|
# Page header and introduction.
# BUG FIX: the adjacent string literals were concatenated without separating
# whitespace, producing run-together sentences ("...San Francisco.This data...");
# also restored the mis-encoded llama emoji and the missing comma after the
# Llama Index link.
st.title("🦙 Llama Index SQL Sandbox 🦙")
st.markdown(
    (
        "This sandbox uses a sqlite database by default, powered by "
        "[Llama Index](https://gpt-index.readthedocs.io/en/latest/index.html), ChatGPT, and LangChain.\n\n"
        "The database contains information on health violations and inspections at restaurants in San Francisco. "
        "This data is spread across three tables - businesses, inspections, and violations.\n\n"
        "Using the setup page, you can adjust LLM settings, change the context for the SQL tables, and change the tool description for Langchain. "
        "The other tabs will perform chatbot and text2sql operations.\n\n"
        "Read more about LlamaIndex's structured data support [here!](https://gpt-index.readthedocs.io/en/latest/guides/tutorials/sql_guide.html)"
    )
)
|
|
|
|
|
# Three-tab layout: configuration, direct text-to-SQL, and the agent demo.
setup_tab, llama_tab, lc_tab = st.tabs(["Setup", "Llama Index", "Langchain+Llama Index"])
|
|
|
with setup_tab:
    # LLM configuration: credentials, model choice, and sampling temperature.
    st.subheader("LLM Setup")
    api_key = st.text_input("Enter your OpenAI API key here", type="password")
    llm_name = st.selectbox(
        "Which LLM?", ["text-davinci-003", "gpt-3.5-turbo", "gpt-4"]
    )
    model_temperature = st.slider(
        "LLM Temperature", min_value=0.0, max_value=1.0, step=0.1
    )

    # Per-table natural-language descriptions fed to the index as context.
    st.subheader("Table Setup")
    business_table_descrp = st.text_area(
        "Business table description", value=DEFAULT_BUSINESS_TABLE_DESCRP
    )
    # BUG FIX: the two text areas below were copy-pasted with the
    # "Business table description" label, mislabeling the widgets in the UI.
    violations_table_descrp = st.text_area(
        "Violations table description", value=DEFAULT_VIOLATIONS_TABLE_DESCRP
    )
    inspections_table_descrp = st.text_area(
        "Inspections table description", value=DEFAULT_INSPECTIONS_TABLE_DESCRP
    )

    table_context_dict = {
        "businesses": business_table_descrp,
        "inspections": inspections_table_descrp,
        "violations": violations_table_descrp,
    }

    use_table_descrp = st.checkbox("Use table descriptions?", value=True)
    lc_descrp = st.text_area("LangChain Tool Description", value=DEFAULT_LC_TOOL_DESCRP)
|
|
|
with llama_tab:
    st.subheader("Text2SQL with Llama Index")
    if st.button("Initialize Index", key="init_index_1"):
        # Cache the (possibly context-augmented) index across reruns.
        st.session_state["llama_index"] = initialize_index(
            llm_name,
            model_temperature,
            table_context_dict if use_table_descrp else None,
            api_key,
        )

    if "llama_index" in st.session_state:
        query_text = st.text_input(
            "Query:", value="Which restaurant has the most violations?"
        )
        use_nl = st.checkbox("Return natural language response?")
        if st.button("Run Query") and query_text:
            with st.spinner("Getting response..."):
                try:
                    query_engine = st.session_state["llama_index"].as_query_engine(
                        synthesize_response=use_nl
                    )
                    response = query_engine.query(query_text)
                    response_text = str(response)
                    response_sql = response.extra_info["sql_query"]
                except Exception as e:
                    # Best-effort display: surface the error as the SQL column.
                    response_text = "Error running SQL Query."
                    response_sql = str(e)

            # Show the answer and the generated SQL side by side.
            result_col, sql_col = st.columns(2)
            with result_col:
                st.text("SQL Result:")
                st.markdown(response_text)

            with sql_col:
                st.text("SQL Query:")
                st.markdown(response_sql)
|
|
|
with lc_tab:
    st.subheader("Langchain + Llama Index SQL Demo")

    if st.button("Initialize Agent"):
        # Rebuild both the index and the agent, and reset the conversation.
        st.session_state["llama_index"] = initialize_index(
            llm_name,
            model_temperature,
            table_context_dict if use_table_descrp else None,
            api_key,
        )
        st.session_state["lc_agent"] = initialize_chain(
            llm_name,
            model_temperature,
            lc_descrp,
            api_key,
            st.session_state["llama_index"],
        )
        st.session_state["chat_history"] = []

    model_input = st.text_input(
        "Message:", value="Which restaurant has the most violations?"
    )
    if "lc_agent" in st.session_state and st.button("Send"):
        # The "User: " prefix tags which side of the conversation each
        # history entry belongs to when rendering below.
        model_input = "User: " + model_input
        st.session_state["chat_history"].append(model_input)
        with st.spinner("Getting response..."):
            response = st.session_state["lc_agent"].run(input=model_input)
            st.session_state["chat_history"].append(response)

    if "chat_history" in st.session_state:
        # BUG FIX: pass an explicit, unique key per message — without it,
        # streamlit_chat derives the widget key from the message text, so
        # two identical messages in the history raise DuplicateWidgetID.
        for i, msg in enumerate(st.session_state["chat_history"]):
            st_message(
                msg.split("User: ")[-1], is_user="User: " in msg, key=f"chat_msg_{i}"
            )
|
|