# Storeapp / be.py
# rengaraj's picture
# Update be.py
# 4f1d0b7 verified
from example import example
from datetime import datetime
#from datetime import time
import time
from time import sleep
import pandas as pd
import huggingface_hub
from huggingface_hub import Repository
import os
# agent will directly create query and run the query in DB
from langchain.agents import create_sql_agent
# Simple chain to create the SQL statements, it doesn't execute the query
from langchain.chains import create_sql_query_chain
# to execute the query
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain.sql_database import SQLDatabase
from langchain.agents import AgentExecutor
from langchain.agents.agent_types import AgentType
from langchain_experimental.sql import SQLDatabaseChain
from langchain_community.vectorstores import Chroma
from langchain.prompts import SemanticSimilarityExampleSelector
# Prompt input for MYSQL
from langchain.chains.sql_database.prompt import PROMPT_SUFFIX, _mysql_prompt
# Create the prompt template for creating the prompt for mysqlprompt
from langchain.prompts.prompt import PromptTemplate
from langchain.prompts import FewShotPromptTemplate
# to create the tools to be used by agent
from langchain.agents import Tool
import sqlparse
# create the agent prompts
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
# Huggingface embeddings using Langchain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts import HumanMessagePromptTemplate
from langchain_core.output_parsers import StrOutputParser
# Load Env parameters
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from sqlalchemy import create_engine, text, URL
from openai import OpenAI
import openai
import logging
import re
import pandas as pd
import matplotlib.pyplot as plt
import streamlit as st
from datetime import datetime
from datetime import date
# Module-level OpenAI client, created when this module is imported.
# data_visualization() uses it for its chat-completion request.
client = OpenAI()
def config():
    """Load the environment settings and return the chat model.

    Returns:
        A ``ChatOpenAI`` instance for ``gpt-3.5-turbo`` with temperature 0.5.
    """
    # Pull settings from the .env file into the process environment first.
    load_dotenv()
    chat_model = ChatOpenAI(temperature=0.5, model="gpt-3.5-turbo")
    return chat_model
# Setting up URL parameter to connect to MySQL Database
def get_db_chain(question):
    '''
    This tool will take the input as user question and query the SQL database to extract the output
    Use this tool first to query any data from the database to answer user question.
    you can generate multiple sql queries to extract two distinct information.
    '''
    # NOTE(review): the docstring above reads like an agent tool description and
    # is kept verbatim in case a caller forwards it to the LLM -- confirm.
    #
    # What this function does:
    #   1. Connects to the MySQL "retail" database through LangChain.
    #   2. Builds a few-shot prompt whose examples are picked by semantic
    #      similarity from the imported ``example`` list.
    #   3. Runs SQLDatabaseChain on ``question``.
    #   4. Stores the generated SQL and the raw DB output in st.session_state
    #      (keys ``sql_query_cleaned`` / ``sql_query_output``).
    # Returns the chain's response mapping (it includes "intermediate_steps").
    from urllib.parse import quote_plus

    log = logging.getLogger(__name__)

    # SECURITY NOTE(review): these credentials were hard-coded in the source.
    # They remain as fallbacks so existing deployments keep working, but each
    # value can now be overridden from the environment. Rotate the password
    # and remove the literal fallbacks once the environment is configured.
    db_user = os.getenv("RETAIL_DB_USER", "admin")
    db_password = os.getenv("RETAIL_DB_PASSWORD", "Epperson")
    db_host = os.getenv("RETAIL_DB_HOST", "retail.cd6uaise2moh.us-west-2.rds.amazonaws.com")
    db_name = os.getenv("RETAIL_DB_NAME", "retail")

    # create LLM
    llm = ChatOpenAI(temperature=0.5, model="gpt-3.5-turbo")

    # Initialize SQL DB using Langchain. User and password are URL-escaped so
    # characters such as '@', ':' or '/' cannot corrupt the connection URI.
    db_uri = f"mysql://{quote_plus(db_user)}:{quote_plus(db_password)}@{db_host}/{db_name}"
    db = SQLDatabase.from_uri(db_uri)

    embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')

    # One text blob per example (all dict values joined), ready to be vectorized.
    to_vectorize = [" ".join(entry.values()) for entry in example]

    # Setup the Chroma database and vectorize.
    # NOTE(review): the embeddings model and the vector store are rebuilt on
    # every call; caching them may be worthwhile -- confirm.
    vectorstore = Chroma.from_texts(to_vectorize, embedding=embeddings, metadatas=example)

    # For a user question, pick the k most similar stored examples.
    example_selector = SemanticSimilarityExampleSelector(
        vectorstore=vectorstore,
        k=2,
    )

    example_prompt = PromptTemplate(
        input_variables=["Question", "SQLQuery", "SQLResult", "Answer"],
        template="\nQuestion: {Question}\nSQLQuery: {SQLQuery}\nSQLResult: {SQLResult}\n?Answer: {Answer}",
    )

    top_k = 30
    few_shot_prompt = FewShotPromptTemplate(
        example_selector=example_selector,  # examples retrieved from the vector DB
        example_prompt=example_prompt,      # how each example is rendered
        prefix=_mysql_prompt,               # prompt prefix
        suffix=PROMPT_SUFFIX,               # prompt suffix
        input_variables=["input", "table_info", "top_k"],
    )
    log.debug("example selector: %s", example_selector)

    chain = SQLDatabaseChain.from_llm(
        llm,
        db,
        verbose=True,
        prompt=few_shot_prompt,
        return_intermediate_steps=True,
        top_k=top_k,
    )
    response = chain.invoke(question)

    # Keep the generated SQL and the raw DB output for data_visualization().
    # NOTE(review): positions 1 and 3 are taken to be the SQL text and the
    # stringified result; this depends on the step layout produced by
    # SQLDatabaseChain -- confirm when upgrading the library.
    intermediate_steps = response["intermediate_steps"]
    sql_query = intermediate_steps[1]
    sql_query_output = intermediate_steps[3]
    st.session_state.sql_query_cleaned = sql_query.replace('\n', " ")
    st.session_state.sql_query_output = sql_query_output

    return response
# Call the LLM with the question and the fewshotprompt
# write_query = create_sql_query_chain(llm=llm,db=db, prompt=few_shot_prompt)
#print(write_query)
# Execute the Query using QuerySQLDataBaseTool
#execute_query = QuerySQLDataBaseTool(db=db)
# Chain to combine write SQL and Execute SQL
#chain = write_query | execute_query | llm
#response = chain.invoke("Question")
def get_store_address(store):
    """Look up one store's number and address and return them as text.

    Args:
        store: Store number to look up. It is passed to the database as a
            bound parameter, so any value type the driver accepts works.

    Returns:
        The matching rows (columns STORE_NUMBER and STORE_ADDRESS) rendered
        with ``DataFrame.to_string()``.
    """
    # NOTE(review): connection settings are hard-coded (local MySQL, root/root)
    # and differ from the database used by get_db_chain -- confirm this is
    # intended and move the credentials out of the source.
    url_object = URL.create(
        "mysql",
        username="root",
        password="root",  # plain (unescaped) text
        host="localhost",
        database="retail",
    )
    engine = create_engine(url_object)

    # Bound parameter instead of string concatenation: prevents SQL injection
    # and no longer fails when ``store`` is not a string.
    query = text(
        "SELECT STORE_NUMBER, STORE_ADDRESS FROM STORES WHERE STORE_NUMBER = :store_number"
    )
    try:
        # The context manager closes the connection even if the query fails.
        with engine.connect() as connection:
            df = pd.read_sql(query, con=connection, params={"store_number": store})
    finally:
        # The engine is created per call, so release its pool when done.
        engine.dispose()

    return df.to_string()
def outreach_sms_message(outreach_input):
    """Ask the chat model to draft an appointment text message.

    Args:
        outreach_input: Details substituted into the prompt template.

    Returns:
        The model's reply as a plain string.
    """
    template = (
        "You are a expert in writing a text message for appointment setup with less than 35 words."
        "With {outreach_input}, generate a text message for appointment to be sent to customer"
    )
    chat_model = ChatOpenAI(temperature=0.5, model="gpt-3.5-turbo", verbose=True)
    sms_prompt = ChatPromptTemplate.from_template(template)
    # prompt -> model -> plain-string parser
    pipeline = sms_prompt | chat_model | StrOutputParser()
    return pipeline.invoke({"outreach_input": outreach_input})
### Function to extract the column names from SQL Query ####
def find_selected_columns(query: str) -> list[str]:
    """Return the column names listed in the SELECT clause of ``query``.

    For each selected identifier the last space-separated word is used (so an
    alias wins over the expression), surrounding backticks are stripped and
    any ``table.`` qualifier is dropped.

    Args:
        query: SQL text; only the first parsed statement is inspected.

    Returns:
        The column names in SELECT order. A single selected column is
        returned as a one-element list.

    Raises:
        Exception: If no SELECT column list is found before the FROM keyword
            (or the query is empty).
    """
    # NOTE(review): a function with this same name is defined again later in
    # this file; the later definition is the one in effect after import.

    def _column_name(identifier) -> str:
        return identifier.value.split(" ")[-1].strip("`").rpartition('.')[-1]

    statements = sqlparse.parse(query)
    if not statements:
        raise Exception("Could not find a select statement. Weird query :)")

    found_select = False
    for token in statements[0].tokens:
        if not found_select:
            found_select = token.match(sqlparse.tokens.Keyword.DML, ["select", "SELECT"])
            continue
        if isinstance(token, sqlparse.sql.IdentifierList):
            return [_column_name(col)
                    for col in token.tokens if isinstance(col, sqlparse.sql.Identifier)]
        if isinstance(token, sqlparse.sql.Identifier):
            # A single selected column is parsed as a bare Identifier.
            return [_column_name(token)]
        if token.match(sqlparse.tokens.Keyword, ["FROM"]):
            # Stop at FROM so table names are never mistaken for columns.
            break
    raise Exception("Could not find a select statement. Weird query :)")
def get_query_columns(sql):
    """Return the output column names of the SELECT clause in ``sql``.

    The first grouped token that follows the ``select`` keyword is treated as
    the column list. Each entry contributes its alias/name via ``get_name()``;
    entries without a usable name fall back to their SQL text.

    Args:
        sql: SQL text; only the first parsed statement is inspected.

    Returns:
        List of column names in SELECT order; empty if the statement has no
        SELECT column list (or ``sql`` is empty).
    """
    statements = sqlparse.parse(sql)
    if not statements:
        return []

    # Locate the first grouped (ttype is None) token after SELECT.
    selected = None
    in_select = False
    for token in statements[0].tokens:
        if isinstance(token, sqlparse.sql.Comment):
            continue
        if str(token).lower() == 'select':
            in_select = True
        elif in_select and token.ttype is None:
            selected = token
            break

    if selected is None:
        return []

    # Several columns are grouped in an IdentifierList; a single column is a
    # lone Identifier/Function, which has no get_identifiers() method.
    if isinstance(selected, sqlparse.sql.IdentifierList):
        identifiers = list(selected.get_identifiers())
    else:
        identifiers = [selected]

    columns = []
    for identifier in identifiers:
        # Plain tokens (e.g. literals) do not provide get_name().
        get_name = getattr(identifier, "get_name", None)
        name = get_name() if callable(get_name) else None
        columns.append(name if name else str(identifier).strip())
    return columns
# Function to write the Python code to generate charts
def data_visualization(question):
    '''This tool will receive the user question as input to generate the chart for data visualization using Pandas and Matplotlib. Use this tool only when the user question is to 'visualize' or 'summarize' the data or info.
    Parameters:
    - question: User question to generate the charts
    - Column Names: Column names for the data frame
    - Data Frame: Dataframe with the actual data for the chart visualization
    '''
    # NOTE(review): the docstring above reads like an agent tool description and
    # is kept verbatim in case a caller forwards it to the LLM -- confirm.
    #
    # What this function does:
    #   1. Reads the last SQL query and its raw output from st.session_state
    #      (stored there by get_db_chain).
    #   2. Rebuilds a DataFrame ``df`` from that output.
    #   3. Asks the chat model for Pandas/Matplotlib code answering ``question``.
    #   4. Streams the model reply into the UI, executes the extracted code and
    #      renders the current Matplotlib figure.
    # Returns the extracted code as a string ("" when nothing could be plotted).
    log = logging.getLogger(__name__)

    raw_sql_query = st.session_state.sql_query_cleaned
    raw_sql_output = st.session_state.sql_query_output
    log.debug("raw_sql_query: %s", raw_sql_query)
    log.debug("raw_sql_output: %s", raw_sql_output)

    data = _parse_sql_rows(raw_sql_output)
    if not data:
        st.warning("The last SQL query returned no rows, so there is nothing to visualize.")
        return ""

    # Column names come from the SELECT clause, in the same order as the row values.
    column_names = get_query_columns(raw_sql_query)
    data_dict = {
        name: [row[position] for row in data]
        for position, name in enumerate(column_names)
    }

    # dataframe creation and formatting
    df = pd.DataFrame(data_dict)
    if 'sale_date' in df:
        df = df.sort_values(by=['sale_date'])
    if 'store_number' in df:
        # Store numbers are labels, not quantities: keep them as strings.
        df['store_number'] = [str(value) for value in df['store_number']]
    log.debug("df: %s", df)

    model_name = 'gpt-3.5-turbo'
    prompt = f"""
The dataset is ALREADY loaded into a DataFrame name 'df' DO NOT load the data again.
Dataframe has the following columns: {column_names}
DO NOT create Sample or Example Dataframe.
Before plotting, ensure the data is ready:
1. DO NOT create sample or example dataframe. Dataframe is available in df {df}.
2. Check if columns that are supposed to be numeric are recognized as such. If not attempt to convert them.
Use Package Pandas and Matplotlib ONLY.
Provide SINGLE CODE BLOCK with a solution using Pandas and Matplotlib plots in a single figure to address the following query:
{question}
- DO NOT EXPLAIN the code
- DO NOT COMMENT the code.
- DO NOT create Sample or Example DataFrame
- ALWAYS WRAP UP THE CODE IN A SINGLE CODE BLOCK.
- The code block must start and end with '''
- ALWAYS include Legend and data labels in the chart but skip them if it is overlapping with other charts when creating subplots
- Always import matplotlib.pyplot as plt and import pandas as pd
- format the x & y axis labels so that it doesn't overlap with each other. IF the label is long, position them as slantic
- DO NOT sum the column names in the dataframe unless asked by the user in the question
- Display full date in the chart if available in YYYY-MM-DD format
- Colors to use for background and axes of the figure : #F0F0F6
- Try to use the following color palette for coloring the plots : #8f63ee #ced5ce #a27bf6 #3d3b41
"""
    messages = [
        {
            "role": "system",
            "content": "You are a helpful Data Visualization assitant for store associates working in a retail store. Use the question, dataframe df and column names to visualize the data easily, you have to give a single block of code without explaining or commenting the code"
        },
        {"role": "user", "content": prompt},
    ]

    result = client.chat.completions.create(model=model_name, messages=messages)
    # The message content can be None; normalise to a string.
    raw_code = result.choices[0].message.content or ""
    code = _extract_code_blocks(raw_code)
    log.debug("raw model reply: %s", raw_code)
    log.debug("extracted code: %s", code)

    # Stream the model reply character by character into a status panel.
    with st.status("Generating Python code for data visualization...📊"):
        st.write_stream(character for character in raw_code)

    if not code:
        # Nothing executable was returned: skip exec/pyplot instead of drawing
        # an empty figure.
        st.warning("No Python code block was found in the model response, so no chart was generated.")
        return code

    try:
        # SECURITY NOTE(review): this executes model-generated code in-process.
        # One namespace dict is used for both globals and locals so that names
        # such as ``df`` are also visible inside functions the generated code
        # defines. It is a copy, so the generated code cannot rebind this
        # module's globals.
        exec(code, {**globals(), "df": df})
        # Pass the figure explicitly (the argument-less call is deprecated) and
        # clear it afterwards so the next chart starts from a blank figure.
        st.pyplot(plt.gcf(), clear_figure=True)
    except Exception as e:
        log.exception("generated visualization code failed")
        error_message = str(e)
        st.error(
            f"📟 Apologies, failed to execute the code due to the error: {error_message}"
        )
        st.warning(
            """
📟 Check the error message and the code executed above to investigate further.
Pro tips:
- Tweak your prompts to overcome the error
- Use the words 'Plot'/ 'Subplot'
- Use simpler, concise words
"""
        )
    return code


def _parse_sql_rows(raw_sql_output):
    """Convert the stringified SQL result into a list of row tuples."""
    if not raw_sql_output or not raw_sql_output.strip():
        return []
    # "datetime.date(...)" -> "date(...)"; as a side effect
    # "datetime.datetime(...)" becomes "datetime(...)".
    cleaned = re.sub(r"datetime\.date", "date", raw_sql_output)
    # "Decimal('12.50')" -> "12.50", including negative values.
    cleaned = re.sub(r"Decimal\('(-?\d+(?:\.\d+)?)'\)", r"\1", cleaned)
    # SECURITY NOTE(review): eval() on database output is risky. Builtins are
    # disabled and only the constructors that can appear in the result are
    # exposed; replacing eval with a real parser would be safer still.
    allowed_names = {
        "__builtins__": {},
        "date": date,
        "datetime": datetime,
        "Decimal": float,  # fallback for Decimal forms the regex does not rewrite
    }
    return eval(cleaned, allowed_names)


def _extract_code_blocks(raw_code):
    """Join the bodies of all fenced code blocks (''' or ```) in the model reply."""
    blocks = re.findall(r"('''|```)(?:python)?(.*?)\1", raw_code, re.DOTALL)
    return "\n".join(body.strip() for _fence, body in blocks)
### Function to extract the column names from SQL Query ####
def find_selected_columns(query: str) -> list[str]:
    """Return the column names listed in the SELECT clause of ``query``.

    For each selected identifier the last space-separated word is used (so an
    alias wins over the expression), surrounding backticks are stripped and
    any ``table.`` qualifier is dropped.

    Args:
        query: SQL text; only the first parsed statement is inspected.

    Returns:
        The column names in SELECT order. A single selected column is
        returned as a one-element list.

    Raises:
        Exception: If no SELECT column list is found before the FROM keyword
            (or the query is empty).
    """
    # NOTE(review): a function with this same name is also defined earlier in
    # this file; being the later one, this definition is the one in effect.

    def _column_name(identifier) -> str:
        return identifier.value.split(" ")[-1].strip("`").rpartition('.')[-1]

    statements = sqlparse.parse(query)
    if not statements:
        raise Exception("Could not find a select statement. Weird query :)")

    found_select = False
    for token in statements[0].tokens:
        if not found_select:
            found_select = token.match(sqlparse.tokens.Keyword.DML, ["select", "SELECT"])
            continue
        if isinstance(token, sqlparse.sql.IdentifierList):
            return [_column_name(col)
                    for col in token.tokens if isinstance(col, sqlparse.sql.Identifier)]
        if isinstance(token, sqlparse.sql.Identifier):
            # A single selected column is parsed as a bare Identifier.
            return [_column_name(token)]
        if token.match(sqlparse.tokens.Keyword, ["FROM"]):
            # Stop at FROM so table names are never mistaken for columns.
            break
    raise Exception("Could not find a select statement. Weird query :)")