Spaces:
Runtime error
Runtime error
ramhemanth580
commited on
Commit
•
8d66574
1
Parent(s):
f34d870
Upload 8 files
Browse files- app.py +56 -0
- database_schema.png +0 -0
- database_table_descriptions.csv +9 -0
- examples.py +149 -0
- langchain_utils.py +64 -0
- prompts.py +38 -0
- requirements.txt +27 -0
- table_details.py +114 -0
app.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import os
|
3 |
+
from dotenv import load_dotenv
|
4 |
+
import google.generativeai as genai
|
5 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
6 |
+
from langchain_utils import get_chain
|
7 |
+
from langchain.memory import ChatMessageHistory
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
st.title("Langchain NL2SQL Chatbot")
|
11 |
+
|
12 |
+
# Load the Google GenAI API key from the environment (populated from .env by load_dotenv)
|
13 |
+
#client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])  # key redacted: never commit API keys; rotate the one that was exposed here
|
14 |
+
|
15 |
+
load_dotenv()
|
16 |
+
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
|
17 |
+
llm = ChatGoogleGenerativeAI(model="gemini-pro",temperature=0,convert_system_message_to_human=True)
|
18 |
+
|
19 |
+
# Set a default model
|
20 |
+
if "Gemini_model" not in st.session_state:
|
21 |
+
st.session_state["Gemini_model"] = "gemini-pro"
|
22 |
+
|
23 |
+
history = ChatMessageHistory()
|
24 |
+
|
25 |
+
if "messages" not in st.session_state:
|
26 |
+
# print("Creating session state")
|
27 |
+
st.session_state.messages = []
|
28 |
+
|
29 |
+
def invoke_chain(question, messages):
    """Run the NL2SQL chain for ``question`` and return its response.

    Args:
        question: The natural-language question typed by the user.
        messages: Prior chat turns as dicts with ``"role"`` and ``"content"``
            keys; a role of ``"user"`` is treated as a human turn and any
            other role as an AI turn.

    Returns:
        Whatever ``chain.invoke`` returns for the assembled input.
    """
    chain = get_chain()

    # Build the history from the ``messages`` argument instead of relying on a
    # module-level object: the script body re-creates that object empty on each
    # run, so the parameter was previously ignored entirely.
    chat_history = ChatMessageHistory()
    for message in messages:
        if message["role"] == "user":
            chat_history.add_user_message(message["content"])
        else:
            chat_history.add_ai_message(message["content"])

    return chain.invoke(
        {"question": question, "top_k": 3, "messages": chat_history.messages}
    )
|
36 |
+
|
37 |
+
question = st.text_input("Ask a Question about the database")
|
38 |
+
|
39 |
+
|
40 |
+
# if question :
|
41 |
+
# st.session_state.messages.append({"role": "user", "content": question})
|
42 |
+
# history.add_user_message(question)
|
43 |
+
# response = invoke_chain(question, st.session_state.messages)
|
44 |
+
# history.add_ai_message(response)
|
45 |
+
# st.session_state.messages.append({"role": "assistant", "content": response})
|
46 |
+
if st.button("submit") :
|
47 |
+
if question :
|
48 |
+
response = invoke_chain(question, st.session_state.messages)
|
49 |
+
st.markdown(response)
|
50 |
+
|
51 |
+
# Set up the sidebar with a button
|
52 |
+
st.sidebar.title("Database Info")
|
53 |
+
if st.sidebar.button('Show Database Schema'):
    # Display the database schema image when the button is clicked.
    # The asset shipped with the app is named `database_schema.png` (lower-case
    # extension); the previous `.PNG` spelling does not resolve on a
    # case-sensitive filesystem.
    image = Image.open('database_schema.png')
    st.image(image, caption='Database Schema', use_column_width=True)
|
database_schema.png
ADDED
database_table_descriptions.csv
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Table,Description
|
2 |
+
productlines,"Stores information about the different product lines offered by the company, including a unique name, textual description, HTML description, and image. Categorizes products into different lines."
|
3 |
+
products,"Contains details of each product sold by the company, including code, name, product line, scale, vendor, description, stock quantity, buy price, and MSRP. Linked to the productlines table."
|
4 |
+
offices,"Holds data on the company's sales offices, including office code, city, phone number, address, state, country, postal code, and territory. Each office is uniquely identified by its office code."
|
5 |
+
employees,"Stores information about employees, including number, last name, first name, job title, contact info, and office code. Links to offices and maps organizational structure through the reportsTo attribute."
|
6 |
+
customers,"Captures data on customers, including customer number, name, contact details, address, assigned sales rep, and credit limit. Central to managing customer relationships and sales processes."
|
7 |
+
payments,"Records payments made by customers, tracking the customer number, check number, payment date, and amount. Linked to the customers table for financial tracking and account management."
|
8 |
+
orders,"Details each sales order placed by customers, including order number, dates, status, comments, and customer number. Linked to the customers table, tracking sales transactions."
|
9 |
+
orderdetails,"Describes individual line items for each sales order, including order number, product code, quantity, price, and order line number. Links orders to products, detailing the items sold."
|
examples.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
examples = [
|
2 |
+
{
|
3 |
+
"input": "List all customers in France with a credit limit over 20,000.",
|
4 |
+
"query": "SELECT * FROM customers WHERE country = 'France' AND creditLimit > 20000;"
|
5 |
+
},
|
6 |
+
{
|
7 |
+
"input": "Get the highest payment amount made by any customer.",
|
8 |
+
"query": "SELECT MAX(amount) FROM payments;"
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"input": "Show product details for products in the 'Motorcycles' product line.",
|
12 |
+
"query": "SELECT * FROM products WHERE productLine = 'Motorcycles';"
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"input": "Retrieve the names of employees who report to employee number 1002.",
|
16 |
+
"query": "SELECT firstName, lastName FROM employees WHERE reportsTo = 1002;"
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"input": "List all products with a stock quantity less than 7000.",
|
20 |
+
"query": "SELECT productName, quantityInStock FROM products WHERE quantityInStock < 7000;"
|
21 |
+
},
|
22 |
+
{
|
23 |
+
'input':"what is price of `1968 Ford Mustang`",
|
24 |
+
"query": "SELECT `buyPrice`, `MSRP` FROM products WHERE `productName` = '1968 Ford Mustang' LIMIT 1;"
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"input": "List products sold by order date.",
|
28 |
+
"query": "SELECT productName , orderDate , DAYNAME(orderDate) AS 'DayName' FROM products INNER JOIN orderdetails ON products.productCode = orderdetails.productCode INNER JOIN Orders ON orderdetails.orderNumber = orders.orderNumber WHERE DAYNAME(Orders.orderDate) = 'MONDAY';"
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"input": "List the order dates in descending order for orders for the 1940 Ford Pickup Truck.",
|
32 |
+
"query": "SELECT DISTINCT(products.productName), orders.orderDate FROM orders JOIN orderdetails ON orderdetails.orderNumber = orders.orderNumber JOIN products ON orderdetails.productCode = products.productCode WHERE productName = '1940 Ford Pickup Truck' ORDER BY orderDate DESC;"
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"input": "List the names of customers and their corresponding order number where a particular order from that customer has a value greater than $25,000.",
|
36 |
+
"query": "SELECT customers.customerName, orders.orderNumber, SUM(orderdetails.priceEach * orderdetails.quantityOrdered) AS tot_value FROM customers JOIN orders ON customers.customerNumber = orders.customerNumber JOIN orderdetails ON orders.orderNumber = orderdetails.orderNumber GROUP BY customers.customerName, orders.orderNumber HAVING tot_value > 25000 ORDER BY customers.customerName;"
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"input": "For orders containing more than two products, report those products that constitute more than 50% of the value of the order.",
|
40 |
+
"query": "SELECT orderNumber, productName, ProductsCount ,contribution FROM (SELECT orderNumber, productCode, (SELECT Count(*) FROM orderdetails WHERE OrderNumber = Main.orderNumber) As 'ProductsCount', quantityOrdered*priceEach As 'Product Value', (quantityOrdered*priceEach / (SELECT SUM(quantityOrdered*priceEach) FROM orderdetails WHERE orderNumber = Main.orderNumber ))*100 As 'Contribution' FROM orderdetails Main ORDER BY orderNumber) DataTable INNER JOIN Products ON Products.productCode = DataTable.productCode WHERE ProductsCount > 2 AND Contribution > 50;"
|
41 |
+
},
|
42 |
+
{
|
43 |
+
"input": "List all the products purchased by Herkku Gifts.",
|
44 |
+
"query": "SELECT productName FROM products INNER JOIN orderdetails od on products.productCode = od.productCode INNER JOIN orders o on od.orderNumber = o.orderNumber INNER JOIN customers c on o.customerNumber = c.customerNumber WHERE c.customerName = 'Herkku Gifts';"
|
45 |
+
},
|
46 |
+
{
|
47 |
+
"input": "Find products containing the name 'Ford'.",
|
48 |
+
"query": "SELECT productName AS 'Products' FROM Products WHERE productName LIKE '%Ford%';"
|
49 |
+
},
|
50 |
+
{
|
51 |
+
"input": "List products ending in 'ship'.",
|
52 |
+
"query": "SELECT productName FROM products WHERE productName LIKE '%ship';"
|
53 |
+
},
|
54 |
+
{
|
55 |
+
"input": "Report the number of customers in Denmark, Norway, and Sweden.",
|
56 |
+
"query": "SELECT customerName FROM Customers WHERE country IN ('Denmark','Norway','Sweden');"
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"input": "What are the products with a product code in the range S700_1000 to S700_1499",
|
60 |
+
"query": "SELECT productCode,productName FROM Products WHERE RIGHT(productCode,4) BETWEEN 1000 AND 1499 ORDER BY RIGHT(productCode,4);"
|
61 |
+
},
|
62 |
+
{
|
63 |
+
"input": "Which customers have a digit in their name?",
|
64 |
+
"query": "SELECT customerName FROM Customers WHERE customerName RLIKE '[0-9]';"
|
65 |
+
},
|
66 |
+
{
|
67 |
+
"input": "List the names of employees called Dianne or Diane.",
|
68 |
+
"query": "SELECT CONCAT(firstName,' ',lastName) AS 'Employee Name' FROM Employees WHERE lastName RLIKE 'Dianne|Diane' OR firstName RLIKE 'Dianne|Diane';"
|
69 |
+
},
|
70 |
+
{
|
71 |
+
"input": "List the products containing ship or boat in their product name.",
|
72 |
+
"query": "SELECT productName FROM Products WHERE productName RLIKE 'ship|boat';"
|
73 |
+
},
|
74 |
+
{
|
75 |
+
"input": "List the products with a product code beginning with S700.",
|
76 |
+
"query": "SELECT productCode, productName FROM Products WHERE productCode LIKE 'S700%';"
|
77 |
+
},
|
78 |
+
{
|
79 |
+
"input": "Find products containing the name 'Ford'.",
|
80 |
+
"query": "SELECT productName As 'Products' FROM Products WHERE productName LIKE '%Ford%';"
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"input": "List products ending in 'ship'.",
|
84 |
+
"query": "SELECT productName FROM products WHERE productName LIKE '%ship';"
|
85 |
+
},
|
86 |
+
{
|
87 |
+
"input": "Report the number of customers in Denmark, Norway, and Sweden.",
|
88 |
+
"query": "SELECT customerName FROM Customers WHERE country IN ('Denmark','Norway','Sweden');"
|
89 |
+
},
|
90 |
+
{
|
91 |
+
"input": "what is the minimum payment received ?",
|
92 |
+
"query": "SELECT min(amount) As 'Minimum Payment' FROM payments;"
|
93 |
+
}
|
94 |
+
]
|
95 |
+
|
96 |
+
from langchain_community.vectorstores import Chroma
|
97 |
+
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
|
98 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
99 |
+
import google.generativeai as genai
|
100 |
+
import streamlit as st
|
101 |
+
import os
|
102 |
+
from dotenv import load_dotenv
|
103 |
+
|
104 |
+
load_dotenv()
|
105 |
+
load_dotenv()
|
106 |
+
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
|
107 |
+
# Access the value of Huggingface_API_KEY
|
108 |
+
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
|
109 |
+
#embeddings = HuggingFaceEmbeddings(huggingfacehub_api_token=HF_API_TOKEN,model_name="sentence-transformers/all-MiniLM-L6-v2")
|
110 |
+
#embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
111 |
+
|
112 |
+
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
113 |
+
|
114 |
+
class CustomGoogleGenerativeAIEmbeddings:
    """Thin wrapper that forwards embedding calls to ``GoogleGenerativeAIEmbeddings``.

    The wrapped object is stored on ``self.embeddings``. The instance is also
    callable; calling it forwards to the wrapped object's ``embed_query``.
    """

    def __init__(self, model, task_type=None):
        """Create the wrapped embedder for ``model`` with the optional ``task_type``."""
        backend = GoogleGenerativeAIEmbeddings(model=model, task_type=task_type)
        self.embeddings = backend

    def __call__(self, input):
        """Embed a single input by forwarding to the wrapped ``embed_query``."""
        embedded = self.embeddings.embed_query(input)
        return embedded

    def embed_query(self, text):
        """Return the wrapped embedder's embedding for one piece of text."""
        embedded = self.embeddings.embed_query(text)
        return embedded

    def embed_documents(self, documents):
        """Return the wrapped embedder's embeddings for several pieces of text."""
        embedded = self.embeddings.embed_documents(documents)
        return embedded
|
130 |
+
|
131 |
+
# Usage
|
132 |
+
model = "models/embedding-001" # Replace with your actual model name
|
133 |
+
task_type = "retrieval_document" # Replace with your actual task type if needed
|
134 |
+
|
135 |
+
embeddings = CustomGoogleGenerativeAIEmbeddings(model=model, task_type=task_type)
|
136 |
+
|
137 |
+
vectorstore = Chroma()
|
138 |
+
vectorstore.delete_collection()
|
139 |
+
|
140 |
+
@st.cache_resource
def get_example_selector():
    """Build the few-shot example selector from the module-level ``examples``.

    Forwards ``examples``, ``embeddings`` and ``vectorstore`` to
    ``SemanticSimilarityExampleSelector.from_examples`` with ``k=4`` and
    ``input_keys=["input"]``. The function is wrapped in ``st.cache_resource``.

    Returns:
        The selector produced by ``from_examples``.
    """
    return SemanticSimilarityExampleSelector.from_examples(
        examples,
        embeddings,
        vectorstore,
        k=4,
        input_keys=["input"],
    )
|
langchain_utils.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from dotenv import load_dotenv
|
3 |
+
from operator import itemgetter
|
4 |
+
load_dotenv()
|
5 |
+
|
6 |
+
db_user = os.getenv("db_user")
|
7 |
+
db_password = os.getenv("db_password")
|
8 |
+
db_host = os.getenv("db_host")
|
9 |
+
db_name = os.getenv("db_name")
|
10 |
+
|
11 |
+
|
12 |
+
import google.generativeai as genai
|
13 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
14 |
+
|
15 |
+
from langchain_community.utilities.sql_database import SQLDatabase
|
16 |
+
from langchain.chains import create_sql_query_chain
|
17 |
+
from langchain_openai import ChatOpenAI
|
18 |
+
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
|
19 |
+
from langchain.memory import ChatMessageHistory
|
20 |
+
from langchain_core.output_parsers import StrOutputParser
|
21 |
+
from langchain_core.runnables import RunnablePassthrough
|
22 |
+
|
23 |
+
from table_details import table_chain as select_table
|
24 |
+
from prompts import final_prompt, answer_prompt
|
25 |
+
|
26 |
+
import streamlit as st
|
27 |
+
|
28 |
+
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
|
29 |
+
llm = ChatGoogleGenerativeAI(model="gemini-pro",temperature=0,convert_system_message_to_human=True)
|
30 |
+
|
31 |
+
@st.cache_resource
def get_chain():
    """Assemble the question -> SQL -> result -> answer runnable.

    Steps, in pipeline order:
      1. add ``table_names_to_use`` by running ``select_table``;
      2. add ``query`` by running the SQL query chain built from ``llm``, the
         database and ``final_prompt``;
      3. add ``result`` by feeding ``query`` to the SQL execution tool;
      4. pipe everything into ``answer_prompt | llm | StrOutputParser()``.

    The database connection is created from the module-level ``db_user``,
    ``db_password``, ``db_host`` and ``db_name`` values. The function is wrapped
    in ``st.cache_resource``.

    Returns:
        The composed runnable.
    """
    database = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}")

    sql_writer = create_sql_query_chain(llm, database, final_prompt)
    sql_runner = QuerySQLDataBaseTool(db=database)
    answer_writer = answer_prompt | llm | StrOutputParser()

    # Stage 1: pick the tables that may be relevant to the question.
    with_tables = RunnablePassthrough.assign(table_names_to_use=select_table)
    # Stages 2-3: generate the query, then execute it and keep its output.
    with_query_and_result = RunnablePassthrough.assign(query=sql_writer).assign(
        result=itemgetter("query") | sql_runner
    )

    return with_tables | with_query_and_result | answer_writer
|
48 |
+
|
49 |
+
def create_history(messages):
    """Convert role/content dicts into a ``ChatMessageHistory``.

    Args:
        messages: Iterable of dicts with ``"role"`` and ``"content"`` keys.
            A role equal to ``"user"`` is recorded as a user message; every
            other role is recorded as an AI message.

    Returns:
        The populated ``ChatMessageHistory``.
    """
    chat_log = ChatMessageHistory()
    for entry in messages:
        # Choose the recorder by role first, then hand it the content.
        record = (
            chat_log.add_user_message
            if entry["role"] == "user"
            else chat_log.add_ai_message
        )
        record(entry["content"])
    return chat_log
|
57 |
+
|
58 |
+
def invoke_chain(question, messages):
    """Run the chain for ``question`` using ``messages`` as prior chat history.

    Args:
        question: The user's natural-language question.
        messages: Role/content dicts converted with ``create_history``.

    Returns:
        The value returned by ``chain.invoke``.
    """
    runnable = get_chain()
    chat_log = create_history(messages)
    payload = {"question": question, "top_k": 3, "messages": chat_log.messages}
    answer = runnable.invoke(payload)
    # NOTE(review): ``chat_log`` is local to this call, so the two additions
    # below are not visible to the caller once the function returns.
    chat_log.add_user_message(question)
    chat_log.add_ai_message(answer)
    return answer
|
prompts.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from examples import get_example_selector
|
3 |
+
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder,FewShotChatMessagePromptTemplate,PromptTemplate
|
4 |
+
|
5 |
+
example_prompt = ChatPromptTemplate.from_messages(
|
6 |
+
[
|
7 |
+
("human", "{input}\nSQLQuery:"),
|
8 |
+
("ai", "{query}"),
|
9 |
+
]
|
10 |
+
)
|
11 |
+
few_shot_prompt = FewShotChatMessagePromptTemplate(
|
12 |
+
example_prompt=example_prompt,
|
13 |
+
example_selector=get_example_selector(),
|
14 |
+
input_variables=["input","top_k"],
|
15 |
+
)
|
16 |
+
|
17 |
+
final_prompt = ChatPromptTemplate.from_messages(
|
18 |
+
[
|
19 |
+
(
|
20 |
+
"system", """You are a MySQL expert, Given an input question ,create a syntactically correct MySQL query to run.Unless otherwise specificed.\n\n
|
21 |
+
Here is the relevant table info: {table_info}\n\n
|
22 |
+
Below are a number of examples of questions and their corresponding SQL queries. Return the syntactically correct SQL query only and nothing else.\n\n
|
23 |
+
"""
|
24 |
+
),
|
25 |
+
few_shot_prompt,
|
26 |
+
MessagesPlaceholder(variable_name="messages"),
|
27 |
+
("human", "{input}"),
|
28 |
+
]
|
29 |
+
)
|
30 |
+
|
31 |
+
answer_prompt = PromptTemplate.from_template(
|
32 |
+
"""Given the following user question, corresponding SQL query, and SQL result, answer the user question.
|
33 |
+
|
34 |
+
Question: {question}
|
35 |
+
SQL Query: {query}
|
36 |
+
SQL Result: {result}
|
37 |
+
Answer: """
|
38 |
+
)
|
requirements.txt
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
|
2 |
+
streamlit_chat
|
3 |
+
|
4 |
+
python-dotenv
|
5 |
+
chromadb
|
6 |
+
faiss-cpu
|
7 |
+
|
8 |
+
# logging is part of the Python standard library; it is not a pip requirement
|
9 |
+
# warnings is part of the Python standard library; it is not a pip requirement
|
10 |
+
# operator is part of the Python standard library; it is not a pip requirement
|
11 |
+
# typing is part of the Python standard library; it is not a pip requirement
|
12 |
+
# ast is part of the Python standard library; it is not a pip requirement
|
13 |
+
Pillow
|
14 |
+
|
15 |
+
pandas
|
16 |
+
numpy
|
17 |
+
|
18 |
+
google-generativeai
|
19 |
+
langchain
|
20 |
+
langchain_community
|
21 |
+
langchain_core
|
22 |
+
langchain_google_genai
langchain_openai
|
23 |
+
|
24 |
+
sentence-transformers==2.2.2
|
25 |
+
|
26 |
+
mysql-connector-python
|
27 |
+
pymysql
|
table_details.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from dotenv import load_dotenv
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
import streamlit as st
|
6 |
+
from operator import itemgetter
|
7 |
+
#from langchain.chains.openai_tools import create_extraction_chain_pydantic
|
8 |
+
from langchain_core.pydantic_v1 import BaseModel, Field
|
9 |
+
#from langchain_openai import ChatOpenAI
|
10 |
+
from langchain.chains import LLMChain
|
11 |
+
from langchain_core.prompts import ChatPromptTemplate
|
12 |
+
|
13 |
+
import google.generativeai as genai
|
14 |
+
|
15 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
16 |
+
|
17 |
+
|
18 |
+
from typing import List
|
19 |
+
|
20 |
+
load_dotenv()
|
21 |
+
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
|
22 |
+
llm = ChatGoogleGenerativeAI(model="gemini-pro",temperature=0,convert_system_message_to_human=True)
|
23 |
+
|
24 |
+
@st.cache_data
def get_table_details():
    """Render ``database_table_descriptions.csv`` as one text block.

    The CSV is read from the current working directory and must provide
    ``Table`` and ``Description`` columns. Each row contributes::

        Table Name:<Table>
        Table Description:<Description>
        <blank line>

    The function is wrapped in ``st.cache_data``.

    Returns:
        str: The concatenation of all rendered rows (empty when the CSV has
        no data rows).
    """
    table_description = pd.read_csv("database_table_descriptions.csv")

    # Join once instead of growing a string with ``+=`` inside the loop, and
    # walk the two columns directly rather than materialising a Series per row.
    return "".join(
        "Table Name:" + name + "\n" + "Table Description:" + description + "\n\n"
        for name, description in zip(
            table_description['Table'], table_description['Description']
        )
    )
|
36 |
+
|
37 |
+
|
38 |
+
class Table(BaseModel):
|
39 |
+
"""Table in SQL database."""
|
40 |
+
|
41 |
+
name: str = Field(description="Name of table in SQL database.")
|
42 |
+
|
43 |
+
table_details = get_table_details()
|
44 |
+
|
45 |
+
prompt2 = ChatPromptTemplate.from_template(
|
46 |
+
"""
|
47 |
+
You are a helpful Data science assistant , Your objective is to analyze the following table descriptions and Return the names of ALL the SQL tables that MIGHT be relevant to the question: {question}
|
48 |
+
\n\nRemember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed.and you should return the table names as a list
|
49 |
+
for example question : which customers made the top 5 highest payments
|
50 |
+
the desired answer should be ['customers','payments']
|
51 |
+
\n\nThe tables descriptions are:
|
52 |
+
Table Name:productlines
|
53 |
+
Table Description:Stores information about the different product lines offered by the company, including a unique name, textual description, HTML description, and image. Categorizes products into different lines.
|
54 |
+
|
55 |
+
Table Name:products
|
56 |
+
Table Description:Contains details of each product sold by the company, including code, name, product line, scale, vendor, description, stock quantity, buy price, and MSRP. Linked to the productlines table.
|
57 |
+
|
58 |
+
Table Name:offices
|
59 |
+
Table Description:Holds data on the company's sales offices, including office code, city, phone number, address, state, country, postal code, and territory. Each office is uniquely identified by its office code.
|
60 |
+
|
61 |
+
Table Name:employees
|
62 |
+
Table Description:Stores information about employees, including number, last name, first name, job title, contact info, and office code. Links to offices and maps organizational structure through the reportsTo attribute.
|
63 |
+
|
64 |
+
Table Name:customers
|
65 |
+
Table Description:Captures data on customers, including customer number, name, contact details, address, assigned sales rep, and credit limit. Central to managing customer relationships and sales processes.
|
66 |
+
|
67 |
+
Table Name:payments
|
68 |
+
Table Description:Records payments made by customers, tracking the customer number, check number, payment date, and amount. Linked to the customers table for financial tracking and account management.
|
69 |
+
|
70 |
+
Table Name:orders
|
71 |
+
Table Description:Details each sales order placed by customers, including order number, dates, status, comments, and customer number. Linked to the customers table, tracking sales transactions.
|
72 |
+
|
73 |
+
Table Name:orderdetails
|
74 |
+
Table Description:Describes individual line items for each sales order, including order number, product code, quantity, price, and order line number. Links orders to products, detailing the items sold.
|
75 |
+
|
76 |
+
"""
|
77 |
+
)
|
78 |
+
|
79 |
+
from typing import List, Dict
|
80 |
+
import ast
|
81 |
+
|
82 |
+
# NOTE: this plain class rebinds the name `Table`, shadowing the pydantic `Table(BaseModel)` declared earlier in this file.
|
83 |
+
class Table:
|
84 |
+
name: str
|
85 |
+
|
86 |
+
def get_tables(output: Dict) -> List[str]:
    """Parse the table-name list out of an LLM chain result.

    Args:
        output: Mapping whose ``'text'`` entry holds the model reply, expected
            to be the string form of a Python list such as
            ``"['customers', 'payments']"``.

    Returns:
        The parsed list. An empty list is returned when ``'text'`` is missing,
        is not a string, is not a valid Python literal, or evaluates to
        something other than a list.
    """
    text_output = output.get('text', '')
    if not isinstance(text_output, str):
        return []

    try:
        # literal_eval only builds plain Python literals, so it is safe to run
        # on model output. Surrounding whitespace/newlines are stripped first.
        tables_list = ast.literal_eval(text_output.strip())
    except (ValueError, SyntaxError, TypeError):
        # Not a valid literal (TypeError covers e.g. an unhashable dict key).
        return []

    # Previously a non-list literal fell through and returned None, breaking
    # the declared List[str] contract for downstream consumers.
    if not isinstance(tables_list, list):
        return []

    # literal_eval can never yield project objects, so the elements are
    # returned as parsed without any per-element conversion.
    return list(tables_list)
|
101 |
+
|
102 |
+
table_chain = {"question": itemgetter("question")} | LLMChain(llm=llm, prompt=prompt2) | get_tables
|
103 |
+
|
104 |
+
|
105 |
+
# table_names = "\n".join(db.get_usable_table_names())
|
106 |
+
# table_details = get_table_details()
|
107 |
+
# table_details_prompt = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
|
108 |
+
# The tables are:
|
109 |
+
|
110 |
+
# {table_details}
|
111 |
+
|
112 |
+
# Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""
|
113 |
+
|
114 |
+
# table_chain = {"input": itemgetter("question")} | create_extraction_chain_pydantic(Table, llm, system_message=table_details_prompt) | get_tables
|