rengaraj commited on
Commit
087b0d6
1 Parent(s): 0215e86

Upload 5 files

Browse files
Files changed (5) hide show
  1. assist_logo.jpg +0 -0
  2. be.py +115 -0
  3. example.py +48 -0
  4. fe.py +130 -0
  5. requirements.txt +15 -0
assist_logo.jpg ADDED
be.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from example import example
2
+ from datetime import datetime
3
+ import pandas as pd
4
+ # agent will directly create query and run the query in DB
5
+ from langchain.agents import create_sql_agent
6
+ # Simple chain to create the SQL statements, it doesn't execute the query
7
+ from langchain.chains import create_sql_query_chain
8
+ # to execute the query
9
+ from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
10
+ from langchain_community.agent_toolkits import SQLDatabaseToolkit
11
+ from langchain.sql_database import SQLDatabase
12
+ from langchain.agents import AgentExecutor
13
+ from langchain.agents.agent_types import AgentType
14
+ from langchain_experimental.sql import SQLDatabaseChain
15
+ from langchain_community.vectorstores import Chroma
16
+ from langchain.prompts import SemanticSimilarityExampleSelector
17
+ # Prompt input for MYSQL
18
+ from langchain.chains.sql_database.prompt import PROMPT_SUFFIX, _mysql_prompt
19
+ # Create the prompt template for creating the prompt for mysqlprompt
20
+ from langchain.prompts.prompt import PromptTemplate
21
+ from langchain.prompts import FewShotPromptTemplate
22
+ # to create the tools to be used by agent
23
+ from langchain.agents import Tool
24
+
25
+ # create the agent prompts
26
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
27
+ # Huggingface embeddings using Langchain
28
+ from langchain_community.embeddings import HuggingFaceEmbeddings
29
+ from langchain_core.messages import HumanMessage, SystemMessage
30
+ from langchain_core.prompts import ChatPromptTemplate
31
+ from langchain_core.prompts import HumanMessagePromptTemplate
32
+ from langchain_core.output_parsers import StrOutputParser
33
+ # Load Env parameters
34
+ from dotenv import load_dotenv
35
+ from langchain_openai import ChatOpenAI
36
+ from sqlalchemy import create_engine, text, URL
37
+
38
+
39
+ def config():
40
+ load_dotenv() # load env parameters
41
+ llm = ChatOpenAI(temperature=0.5, model="gpt-3.5-turbo") # create LLM
42
+ #llm = OpenAI(temperature=0.5) # create LLM
43
+ return llm
44
+
45
+ # Setting up URL parameter to connect to MySQL Database
46
+ def get_db_chain(question):
47
+ db_user="root"
48
+ db_password="root"
49
+ db_host="localhost"
50
+ db_name="retail"
51
+
52
+ # create LLM
53
+ llm = ChatOpenAI(temperature=0.5, model="gpt-3.5-turbo")
54
+ # Initialize SQL DB using Langchain
55
+ db = SQLDatabase.from_uri(f"mysql://{db_user}:{db_password}@{db_host}/{db_name}")
56
+ toolkit = SQLDatabaseToolkit(db=db, llm=llm)
57
+ embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
58
+ # create the list with only values and ready to be vectorized
59
+ to_vectorize = [" ".join(example.values()) for example in example] # use join to convert Dict to list
60
+ # Setup the Chroma database and vectorize
61
+ vectorstore = Chroma.from_texts(to_vectorize, embedding=embeddings, metadatas=example)
62
+ # Based on the user question, convert them to vector and take the similar looking vectors from Chroma DB
63
+ example_selector = SemanticSimilarityExampleSelector(
64
+ vectorstore = vectorstore,
65
+ k=2)
66
+ example_prompt = PromptTemplate(
67
+ input_variables=["Question", "SQLQuery", "SQLResult", "Answer",],
68
+ template="\nQuestion: {Question}\nSQLQuery: {SQLQuery}\nSQLResult: {SQLResult}\n?Answer: {Answer}",)
69
+ few_shot_prompt = FewShotPromptTemplate(
70
+ example_selector=example_selector, # Hey LLM, if you dont know refer the examples giving in vector DB
71
+ example_prompt=example_prompt, # This is the Prompt template we have created
72
+ prefix=_mysql_prompt, # This is prefix of the prompt
73
+ suffix=PROMPT_SUFFIX, # This is suffix of the prompt
74
+ input_variables=["input", "table_info", "top_k"], # variables used in forming the prompt to LLM
75
+ )
76
+ chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, prompt=few_shot_prompt)
77
+ response = chain.invoke(question)
78
+ return response
79
+ # Call the LLM with the question and the fewshotprompt
80
+ # write_query = create_sql_query_chain(llm=llm,db=db, prompt=few_shot_prompt)
81
+ #print(write_query)
82
+ # Execute the Query using QuerySQLDataBaseTool
83
+ #execute_query = QuerySQLDataBaseTool(db=db)
84
+ # Chain to combine write SQL and Execute SQL
85
+ #chain = write_query | execute_query | llm
86
+ #response = chain.invoke("Question")
87
+ def get_store_address(store):
88
+ url_object = URL.create(
89
+ "mysql",
90
+ username="root",
91
+ password="root", # plain (unescaped) text
92
+ host="localhost",
93
+ database="retail",
94
+ )
95
+ engine = create_engine(url_object)
96
+ #connect to engine
97
+ connection = engine.connect()
98
+ sql_query = "SELECT STORE_NUMBER, STORE_ADDRESS FROM STORES WHERE STORE_NUMBER = " + store
99
+ df = pd.read_sql(sql_query, con=engine)
100
+ response = df.to_string()
101
+ return response
102
+ def outreach_sms_message(outreach_input):
103
+ # create LLM
104
+ llm = ChatOpenAI(temperature=0.5, model="gpt-3.5-turbo", verbose=True)
105
+ prompt = ChatPromptTemplate.from_template("You are a expert in writing a text message for appointment setup with less than 35 words."
106
+ "With {outreach_input}, generate a text message for appointment to be sent to customer")
107
+ output_parser = StrOutputParser()
108
+ chain = prompt | llm | output_parser
109
+ response = chain.invoke({"outreach_input": outreach_input})
110
+ return response
111
+
112
+ #if __name__ == "__main__":
113
+ # chain = get_db_chain()
114
+ # print(chain.run("List of all sales transactions for Trevor Nelson in June 2020"))
115
+ # Setting up URL parameter to connect to MySQL Database
example.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Few Shot Learnings
2
+ example = [
3
+ { 'Question': "Top 10 stores with highest sales and the number of transactions",
4
+ 'SQLQuery': "SELECT store_number, COUNT(store_number) AS num_transactions FROM sales GROUP BY store_number ORDER BY SUM(sale_price) DESC;",
5
+ 'SQLResult': "Result of the SQL query",
6
+ 'Answer': "answer"
7
+ },
8
+ { 'Question': "List of all stores transactions for Sheri Williams",
9
+ 'SQLQuery': "SELECT customer_id, product, quantity, sale_price, sales_channel FROM sales WHERE name = 'Sheri Williams' AND sales_channel = 'st'",
10
+ 'SQLResult': "Result of the SQL Query",
11
+ 'Answer': "answer"
12
+ },
13
+ { 'Question': "Top customer who bought product7 the most in store",
14
+ 'SQLQuery': "SELECT customer_id, name, count(product) as total_product7_purchases from sales where product = 'Product7' and store_number = 1365 group by customer_id, name ORDER BY total_product7_purchases DESC LIMIT 10",
15
+ 'SQLResult': "Result of the SQL Query",
16
+ 'Answer': "answer"
17
+ },
18
+ {
19
+ 'Question': "List of all online transactions for Sheri Williams",
20
+ 'SQLQuery': "SELECT customer_id, product, quantity, sale_price, sales_channel FROM sales WHERE name = 'Sheri Williams' AND sales_channel = 'ol'",
21
+ 'SQLResult': "Result of the SQL query",
22
+ 'Answer': "answer"
23
+ },
24
+ {
25
+ 'Question': "find the product with the highest sales in store 4057",
26
+ 'SQLQuery': "SELECT product, SUM(sale_price) as total_sales_amount FROM sales WHERE store_number = 4057 GROUP BY product ORDER BY total_sales_amount DESC LIMIT 1;",
27
+ 'SQLResult': "Result of the SQL query",
28
+ 'Answer': "answer"
29
+ },
30
+ {
31
+ 'Question': "List of all sales transactions in the last 1 week",
32
+ 'SQLQuery': "SELECT customer_id, product, quantity, sale_price, sales_channel, date FROM sales WHERE date >= CURDATE() - INTERVAL 1 WEEK;",
33
+ 'SQLResult': "Result of the SQL query",
34
+ 'Answer': "answer"
35
+ },
36
+ {
37
+ 'Question': "List of all sales transactions this week",
38
+ 'SQLQuery': "SELECT customer_id, product, quantity, sale_price, sales_channel, date FROM sales WHERE date >= CURDATE() - INTERVAL DAYOFWEEK(CURDATE())-1 DAY AND date < CURDATE() + INTERVAL 1 DAY;",
39
+ 'SQLResult': "Result of the SQL query",
40
+ 'Answer': "answer"
41
+ },
42
+ {
43
+ 'Question': " Top 10 stores with highest sales and the number of transactions with sales amount & store associate name",
44
+ 'SQLQuery': "SELECT s.store_number, COUNT(s.store_number) AS num_transactions, SUM(s.sale_price) AS total_sales, st.store_associate FROM sales s JOIN stores st ON s.store_number = st.store_number GROUP BY s.store_number ORDER BY total_sales DESC;",
45
+ 'SQLResult': "Result of the SQL query",
46
+ 'Answer': "answer"
47
+ }
48
+ ]
fe.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.tools import DuckDuckGoSearchRun
2
+ from langchain.agents.tools import tool
3
+ from langchain import OpenAI
4
+ from langchain.agents import Tool, load_tools, initialize_agent, AgentType
5
+ from langchain.prompts import PromptTemplate
6
+ from langchain.chains import LLMChain
7
+ # Streamlit UI Callback
8
+ from langchain.callbacks import StreamlitCallbackHandler
9
+ from langchain.chains import LLMMathChain
10
+ from langchain.memory import ConversationBufferMemory
11
+ from langchain_openai import ChatOpenAI
12
+ import streamlit as st
13
+ from dotenv import load_dotenv
14
+ from sqlalchemy import create_engine, text, URL
15
+ from htmlTemplates import css, bot_template, user_template
16
+
17
+ import openai
18
+ import os
19
+ import time
20
+
21
+ from be import config, get_db_chain, outreach_sms_message, get_store_address
22
+ from PIL import Image
23
+ def conversation_agent(question):
24
+ llm = ChatOpenAI(temperature=0.5, model="gpt-3.5-turbo") # create LLM
25
+ # search = DuckDuckGoSearchRun()
26
+ llm_math_chain = LLMMathChain(llm=llm, verbose=True)
27
+ tools = [
28
+ Tool(
29
+ name="Calculator",
30
+ func=llm_math_chain.run,
31
+ description="useful when you need to answer questions with math"
32
+ )
33
+ ]
34
+ ######### CREATING ALL THE TOOLS FOR THE AGENT TO USE #####################
35
+ # Create the final SMS message
36
+ outreach_sms_message_tool = Tool(
37
+ name="Outreach SMS Message",
38
+ func=outreach_sms_message,
39
+ description="Create an outreach SMS message for the customer. Pass both user input and the Store Address as ONE SINGLE INPUT STRING. Use this Tool only to create an outreach SMS or Text message. At the end always include the Store Address for appointment confirmation messages"
40
+ )
41
+ #tools.append(outreach_sms_message_tool)
42
+ # Creating a Query Tool - to generate SQL statements and query the database
43
+ get_db_chain_tool = Tool(
44
+ name='Query Generation Tool',
45
+ func=get_db_chain,
46
+ description="ONLY use this tool for query generation and to fetch any information. Use this to Create MYSQL Query with the ORIGINAL User Input question to pull customer, store, product information. MySQL database connections are established within this tool. Use this tool first"
47
+ "During SQL Query Generation, make sure the SELECT list of columns are in GROUP BY clause"
48
+ "Use this to get the store address from the database"
49
+ )
50
+ # create the tool for finding the store details
51
+ get_store_address_tool = Tool(
52
+ name="Get Store Address",
53
+ func=get_store_address,
54
+ description="Use this tool with store number to get the store address. INPUT to this tool is Store number. Do not use this tool if you don't have Store number as input"
55
+ )
56
+ #tools.append(get_db_chain_tool)
57
+ # List all the tools for the agent to use
58
+ tools = [get_db_chain_tool, get_store_address_tool, outreach_sms_message_tool]
59
+ conversational_agent = initialize_agent(
60
+ agent="conversational-react-description",
61
+ tools=tools,
62
+ llm=llm,
63
+ verbose=True,
64
+ max_iterations=10,
65
+ memory=st.session_state.memory,
66
+ handle_parsing_errors=True
67
+ )
68
+ response = conversational_agent.invoke(question)
69
+ return response
70
+
71
+ def main():
72
+ img = Image.open('assist_logo.jpg')
73
+ user_avatar = Image.open('renga_profile.jpg')
74
+ ai_avatar = Image.open('Designer.png')
75
+ load_dotenv() # load env parameters
76
+ st.set_page_config(page_title="Assist", page_icon=img)
77
+ st.write(css, unsafe_allow_html=True)
78
+ # Logo and image next to each other with a space column separating them out for rendering in small devices
79
+ st.title(':blue[Assist] Store Associates')
80
+ with st.sidebar:
81
+ st.image('assist_logo.jpg', width=120)
82
+ st.sidebar.header("Assist App for Store Associates")
83
+ st.write("Assist store associates to get information on Customers, Stores, Product, Sales Analytics, Inventory Management and help with customer outreach")
84
+ st.write(" ")
85
+ st.write("Tasks I can help with:")
86
+ st.write("a. Extract Data/info")
87
+ st.write("b. Outreach message ")
88
+ st.write("c. Send Text to Customers")
89
+ st.write("d. Search websites and look up Product prices & other info")
90
+ st.write("e. Generate charts for greater visualization")
91
+
92
+ if "chat_history" not in st.session_state:
93
+ st.session_state.memory = ConversationBufferMemory(memory_key="chat_history")
94
+ # ini chat history
95
+ if "messages" not in st.session_state:
96
+ st.session_state.messages = []
97
+
98
+ user_question = st.chat_input("Type your question")
99
+ if user_question:
100
+
101
+ assistant_response = conversation_agent(user_question)
102
+ st.session_state.chat_history = assistant_response['chat_history']
103
+ chistory = assistant_response["chat_history"]
104
+ # Process the chat history
105
+ messages = chistory.split("Human: ")
106
+ l = len(messages)
107
+ # Step 3: Print the output
108
+ for i in range(1, l):
109
+ response = messages[i].strip()
110
+ response = response.split("AI:")
111
+ # Print the Human message from history
112
+ with st.chat_message("human"):
113
+ st.markdown(response[0])
114
+ # Print the AI message from history
115
+ with st.chat_message("ai"):
116
+ st.markdown(response[1])
117
+ #Print the last question from user
118
+ with st.chat_message("human"):
119
+ st.markdown(user_question)
120
+ #st.write(user_template.replace("{{MSG}}", user_question), unsafe_allow_html=True)
121
+ # Print the last answer from user
122
+ assistant_response_output = assistant_response["output"]
123
+ with st.chat_message("ai"):
124
+ st.write(assistant_response_output)
125
+ #st.write(bot_template.replace("{{MSG}}", assistant_response_output), unsafe_allow_html=True)
126
+
127
+
128
+ if __name__== '__main__':
129
+ main()
130
+
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ google-generativeai
3
+ python-dotenv
4
+ langchain-community
5
+ langchain-experimental
6
+ mysqlclient
7
+ sqlparse
8
+ mysql-connector-python
9
+ PyMySQL
10
+ pandasai
11
+ langchain
12
+ transformers
13
+ sentence-transformers
14
+ chromadb
15
+ langchain-openai