azenabeel commited on
Commit
b64b1fd
·
verified ·
1 Parent(s): 335a865

Upload 6 files

Browse files

deploying on cloud

Files changed (6) hide show
  1. .gitignore +1 -0
  2. Pipfile +11 -0
  3. app.py +141 -0
  4. image.png +0 -0
  5. project_workflow.png +0 -0
  6. requirements.txt +8 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .env
Pipfile ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [[source]]
2
+ url = "https://pypi.org/simple"
3
+ verify_ssl = true
4
+ name = "pypi"
5
+
6
+ [packages]
7
+
8
+ [dev-packages]
9
+
10
+ [requires]
11
+ python_version = "3.11"
app.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dotenv import load_dotenv
2
+ from langchain_core.messages import AIMessage, HumanMessage
3
+ from langchain_core.prompts import ChatPromptTemplate
4
+ from langchain_core.runnables import RunnablePassthrough
5
+ from langchain_community.utilities import SQLDatabase
6
+ from langchain_core.output_parsers import StrOutputParser
7
+ from langchain_openai import ChatOpenAI
8
+ import streamlit as st
9
+
10
+ load_dotenv()
11
+
12
+ st.set_page_config(page_title="Chat with SQL", page_icon=":speech_ballon")
13
+ st.title("Chat with my MySQL")
14
+
15
+ # session state variable
16
+ if "chat_history" not in st.session_state:
17
+ st.session_state.chat_history = [AIMessage(content="Hello! I'm a SQL assistant. ASk me anything about your database."),]
18
+
19
+
20
+ def init_database(user: str, password: str, host: str, port: str, database: str) -> SQLDatabase:
21
+ # connceting to mysql db using mysql-connector-python driver
22
+ db_uri = f"mysql+mysqlconnector://{user}:{password}@{host}:{port}/{database}"
23
+ return SQLDatabase.from_uri(db_uri)
24
+
25
+
26
+ def get_sql_chain(db):
27
+ template = """
28
+ You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
29
+ Based on the table schema below, write a SQL query that would answer the user's question. Take the conversation history into account.
30
+
31
+ <SCHEMA>{schema}</SCHEMA>
32
+
33
+ Conversation History: {chat_history}
34
+
35
+ Write only the SQL query and nothing else. Do not wrap the SQL query in any other text, not even backticks.
36
+
37
+ For example(few shot learning):
38
+ Question: which 3 artists have the most tracks?
39
+ SQL Query: SELECT ArtistId, COUNT(*) as track_count FROM Track GROUP BY ArtistId ORDER BY track_count DESC LIMIT 3;
40
+ Question: Name 10 artists
41
+ SQL Query: SELECT Name FROM Artist LIMIT 10;
42
+
43
+ Your turn:
44
+
45
+ Question: {question}
46
+ SQL Query:
47
+ """
48
+
49
+ prompt = ChatPromptTemplate.from_template(template)
50
+ llm = ChatOpenAI(model="gpt-4")
51
+
52
+ def get_schema(_):
53
+ return db.get_table_info()
54
+
55
+ sql_chain = RunnablePassthrough.assign(schema=get_schema) | prompt | llm | StrOutputParser()
56
+
57
+ return sql_chain
58
+
59
+
60
+ def get_response(user_query: str, db: SQLDatabase, chat_history: list):
61
+ sql_chain = get_sql_chain(db)
62
+ template = """
63
+ You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
64
+ Based on the table schema below, question, sql query, and sql response, write a natural language response.
65
+ <SCHEMA>{schema}</SCHEMA>
66
+
67
+ Conversation History: {chat_history}
68
+ SQL Query: <SQL>{query}</SQL>
69
+ User Question: {question}
70
+ SQL Response: {response}
71
+ """
72
+ prompt = ChatPromptTemplate.from_template(template)
73
+ llm = ChatOpenAI()
74
+
75
+ response_chain = (
76
+ RunnablePassthrough.assign(query=sql_chain)
77
+ .assign(schema=lambda _: db.get_table_info(),
78
+ # response=lambda vars: print("variables: ", vars),
79
+ response=lambda vars: db.run(vars["query"]),
80
+ )
81
+ | prompt
82
+ | llm
83
+ | StrOutputParser()
84
+ )
85
+
86
+ return response_chain.invoke({"question": user_query, "chat_history": chat_history})
87
+
88
+
89
+
90
+ with st.sidebar:
91
+ st.subheader("Settings")
92
+ st.write("This is a simple chat application using LLM and MySQL")
93
+ st.write("Connect to the databse and satrt chatting.")
94
+
95
+ st.text_input("Host", value="localhost", key="Host")
96
+ st.text_input("Port", value="3306", key="Port")
97
+ st.text_input("User", value="root", key="User")
98
+ st.text_input("Password", type="password", value="admin", key="Password")
99
+ st.text_input("Database", value="Chinook", key="Database")
100
+
101
+ if st.button("Connect"):
102
+ with st.spinner("Connecting to database..."):
103
+ db = init_database(
104
+ st.session_state["User"],
105
+ st.session_state["Password"],
106
+ st.session_state["Host"],
107
+ st.session_state["Port"],
108
+ st.session_state["Database"],
109
+ )
110
+ st.session_state.db = db
111
+ st.success("Connected to database!")
112
+
113
+ # printing out messages/ chat
114
+ for message in st.session_state.chat_history:
115
+ if isinstance(message, AIMessage):
116
+ with st.chat_message("AI"):
117
+ st.markdown(message.content)
118
+ elif isinstance(message, HumanMessage):
119
+ with st.chat_message("Human"):
120
+ st.markdown(message.content)
121
+
122
+
123
+ user_query = st.chat_input("Type a message...")
124
+ if user_query is not None and user_query.strip() != "":
125
+ # adding to chat history
126
+ st.session_state.chat_history.append(HumanMessage(content=user_query))
127
+
128
+ # displaying user query// with manages the lifecycle of an object
129
+ with st.chat_message("Human"):
130
+ st.markdown(user_query)
131
+
132
+ with st.chat_message("AI"):
133
+ response = get_response(user_query, st.session_state.db, st.session_state.chat_history)
134
+ # sql_chain = get_sql_chain(st.session_state.db)
135
+ # response = sql_chain.invoke({
136
+ # "chat_history": st.session_state.chat_history, # scheam has already been populated in func getsqlchain
137
+ # "question" : user_query
138
+ # })
139
+ st.markdown(response)
140
+
141
+ st.session_state.chat_history.append(AIMessage(content=response))
image.png ADDED
project_workflow.png ADDED
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ langchain
3
+ langchain-openai
4
+ langchain_community
5
+ langchain-groq
6
+ mysql-connector-python
7
+ mysql
8
+ python-dotenv