alejandro commited on
Commit
c796379
·
1 Parent(s): ec96023

feat: create chain that returns sql query

Browse files
Files changed (1) hide show
  1. src/app.py +64 -4
src/app.py CHANGED
@@ -1,13 +1,57 @@
1
  import streamlit as st
2
  from langchain_community.utilities import SQLDatabase
 
 
 
 
 
3
 
4
  def initialize_database(host, port, username, password, database):
5
  db_uri = f"mysql+mysqlconnector://{username}:{password}@{host}:{port}/{database}"
6
  return SQLDatabase.from_uri(db_uri)
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  st.set_page_config(initial_sidebar_state="expanded", page_title="Chat with a MySQL Database", page_icon=":speech_balloon:")
10
 
 
 
 
 
 
 
 
 
11
  with st.sidebar:
12
  st.title("Chat with a MySQL Database")
13
  st.write("This is a simple chat application allows you to chat with a MySQL database.")
@@ -20,16 +64,32 @@ with st.sidebar:
20
 
21
  if st.button("Connect"):
22
  with st.spinner("Connecting to the database..."):
23
- db = initialize_database(
24
  username=st.session_state.username,
25
  password=st.session_state.password,
26
  host=st.session_state.name,
27
  port=st.session_state.port,
28
  database=st.session_state.database
29
  )
30
- st.success("Connected to the database!")
 
31
 
32
  user_query = st.chat_input("Type a message...")
33
 
34
- if user_query != "":
35
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  from langchain_community.utilities import SQLDatabase
3
+ from langchain_core.output_parsers import StrOutputParser
4
+ from langchain_core.runnables import RunnablePassthrough
5
+ from langchain_openai import ChatOpenAI
6
+ from langchain_core.messages import HumanMessage, AIMessage
7
+ from dotenv import load_dotenv
8
 
9
  def initialize_database(host, port, username, password, database):
10
  db_uri = f"mysql+mysqlconnector://{username}:{password}@{host}:{port}/{database}"
11
  return SQLDatabase.from_uri(db_uri)
12
 
13
+ def get_response(user_query, chat_history, db):
14
+
15
+ from langchain_core.prompts import ChatPromptTemplate
16
+
17
+ template = """
18
+ Based on the table schema below, write a SQL query that would answer the user's question.
19
+ {schema}
20
+
21
+ Question: {question}
22
+ SQL Query:
23
+ """
24
+
25
+ prompt = ChatPromptTemplate.from_template(template)
26
+
27
+ llm = ChatOpenAI()
28
+
29
+ def get_schema(_):
30
+ return db.get_table_info()
31
+
32
+ sql_chain = (
33
+ RunnablePassthrough.assign(schema=get_schema)
34
+ | prompt
35
+ | llm.bind(stop="\nSQL Result:")
36
+ | StrOutputParser()
37
+ )
38
+
39
+ return sql_chain.invoke({
40
+ "question": user_query
41
+ })
42
+
43
+ load_dotenv()
44
 
45
  st.set_page_config(initial_sidebar_state="expanded", page_title="Chat with a MySQL Database", page_icon=":speech_balloon:")
46
 
47
+ if 'chat_history' not in st.session_state:
48
+ st.session_state.chat_history = [
49
+ AIMessage(content="")
50
+ ]
51
+
52
+ if 'db' not in st.session_state:
53
+ st.session_state.db = None
54
+
55
  with st.sidebar:
56
  st.title("Chat with a MySQL Database")
57
  st.write("This is a simple chat application allows you to chat with a MySQL database.")
 
64
 
65
  if st.button("Connect"):
66
  with st.spinner("Connecting to the database..."):
67
+ st.session_state.db = initialize_database(
68
  username=st.session_state.username,
69
  password=st.session_state.password,
70
  host=st.session_state.name,
71
  port=st.session_state.port,
72
  database=st.session_state.database
73
  )
74
+ if st.session_state.db is not None:
75
+ st.success("Connected to the database!")
76
 
77
  user_query = st.chat_input("Type a message...")
78
 
79
+ if user_query is not None and user_query != "":
80
+ st.session_state.chat_history.append(HumanMessage(content=user_query))
81
+
82
+ with st.chat_message("Human"):
83
+ st.markdown(user_query)
84
+
85
+ with st.chat_message("AI"):
86
+ response = get_response(
87
+ user_query,
88
+ st.session_state.chat_history,
89
+ st.session_state.db
90
+ )
91
+ print(f"Response generated: {response}")
92
+ st.markdown(response)
93
+
94
+ st.session_state.chat_history.append(AIMessage(content=response))
95
+