Spaces:
Sleeping
Sleeping
# Standard library
import functools
import logging
import os
import re
import time
from datetime import date
from datetime import datetime
from datetime import datetime
#from datetime import time
from time import sleep

# Third-party
import matplotlib.pyplot as plt
import openai
import pandas as pd
import pandas as pd
import sqlparse
import streamlit as st
# Load Env parameters
from dotenv import load_dotenv
# agent will directly create query and run the query in DB
from langchain.agents import create_sql_agent
from langchain.agents import AgentExecutor
# to create the tools to be used by agent
from langchain.agents import Tool
from langchain.agents.agent_types import AgentType
# Simple chain to create the SQL statements, it doesn't execute the query
from langchain.chains import create_sql_query_chain
# Prompt input for MYSQL
from langchain.chains.sql_database.prompt import PROMPT_SUFFIX, _mysql_prompt
from langchain.prompts import SemanticSimilarityExampleSelector
from langchain.prompts import FewShotPromptTemplate
# Create the prompt template for creating the prompt for mysqlprompt
from langchain.prompts.prompt import PromptTemplate
from langchain.sql_database import SQLDatabase
from langchain_community.agent_toolkits import SQLDatabaseToolkit
# Huggingface embeddings using Langchain
from langchain_community.embeddings import HuggingFaceEmbeddings
# to execute the query
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_community.vectorstores import Chroma
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_core.output_parsers import StrOutputParser
# create the agent prompts
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts import HumanMessagePromptTemplate
from langchain_experimental.sql import SQLDatabaseChain
from langchain_openai import ChatOpenAI
from openai import OpenAI
from sqlalchemy import create_engine, text, URL

# Local
from example import example
# Module-level OpenAI client used by data_visualization().
# The client picks up OPENAI_API_KEY from the environment when it is
# constructed, so populate the environment from .env first. config() also
# calls load_dotenv(), but only when something invokes it, which can be after
# this line has already run.
load_dotenv()
client = OpenAI()
def config():
    """Load .env settings into the process environment and return the chat model.

    Returns:
        ChatOpenAI: a ``gpt-3.5-turbo`` chat model with temperature 0.5.
    """
    load_dotenv()
    chat_model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.5)
    return chat_model
# Few-shot text-to-SQL over the MySQL `retail` database.
@functools.lru_cache(maxsize=1)
def _get_example_selector():
    """Build the few-shot example selector once per process and reuse it.

    Every entry of the imported ``example`` list is flattened to a single
    string (its values joined by spaces), embedded with the
    ``all-MiniLM-L6-v2`` sentence-transformer, and stored in a Chroma vector
    store (no persist directory is passed). The returned selector picks the
    2 stored examples most similar to the incoming question. The result is
    cached because loading the embedding model and re-embedding every example
    on each question is slow.
    """
    embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
    # `ex`, not `example`: reusing the imported name as the loop variable was legal but confusing.
    to_vectorize = [" ".join(ex.values()) for ex in example]
    vectorstore = Chroma.from_texts(to_vectorize, embedding=embeddings, metadatas=example)
    return SemanticSimilarityExampleSelector(vectorstore=vectorstore, k=2)


def get_db_chain(question):
    '''
    This tool will take the input as user question and query the SQL database to extract the output
    Use this tool first to query any data from the database to answer user question.
    you can generate multiple sql queries to extract two distinct information.
    '''
    # Maintainer notes. The docstring above reads like an LLM tool description,
    # so its wording is kept verbatim.
    # - Runs a SQLDatabaseChain with a few-shot MySQL prompt on `question` and
    #   returns the chain's response dict, which includes "intermediate_steps".
    # - Side effects: writes the generated SQL (flattened to one line) and the
    #   raw result string to sql_query.txt, which data_visualization() reads
    #   back line by line. Appends the same two lines to applog.txt.

    # TODO(security): the fallback credentials below are committed to source
    # control. Set DB_USER / DB_PASSWORD / DB_HOST / DB_NAME in the environment,
    # rotate the password, then remove the fallbacks.
    # URL.create escapes special characters in the password; the former
    # f-string URI broke on passwords containing '@', '/' or ':'.
    db_url = URL.create(
        "mysql",
        username=os.getenv("DB_USER", "admin"),
        password=os.getenv("DB_PASSWORD", "Epperson"),
        host=os.getenv("DB_HOST", "retail.cd6uaise2moh.us-west-2.rds.amazonaws.com"),
        database=os.getenv("DB_NAME", "retail"),
    )
    llm = ChatOpenAI(temperature=0.5, model="gpt-3.5-turbo")
    db = SQLDatabase.from_uri(db_url)

    example_selector = _get_example_selector()
    # Layout of each retrieved example inside the prompt. The old template had
    # a stray "?" before "Answer" that leaked into every prompt.
    example_prompt = PromptTemplate(
        input_variables=["Question", "SQLQuery", "SQLResult", "Answer"],
        template="\nQuestion: {Question}\nSQLQuery: {SQLQuery}\nSQLResult: {SQLResult}\nAnswer: {Answer}",
    )
    top_k = 30
    # LangChain's stock MySQL instructions, then the selected examples, then the stock suffix.
    few_shot_prompt = FewShotPromptTemplate(
        example_selector=example_selector,
        example_prompt=example_prompt,
        prefix=_mysql_prompt,
        suffix=PROMPT_SUFFIX,
        input_variables=["input", "table_info", "top_k"],
    )
    logging.debug("example selector: %s", example_selector)
    logging.debug("prompt prefix: %s", _mysql_prompt)
    logging.debug("prompt suffix: %s", PROMPT_SUFFIX)

    chain = SQLDatabaseChain.from_llm(
        llm,
        db,
        verbose=True,
        prompt=few_shot_prompt,
        return_intermediate_steps=True,
        top_k=top_k,
    )
    response = chain.invoke(question)

    # Assumes SQLDatabaseChain's step layout: [1] is the generated SQL and
    # [3] is the database output string. TODO confirm on langchain upgrades.
    intermediate_steps = response["intermediate_steps"]
    # One line only: data_visualization() reads the SQL with a single readline().
    sql_query_cleaned = " ".join(intermediate_steps[1].splitlines())
    sql_query_output = intermediate_steps[3]
    record = sql_query_cleaned + '\n' + sql_query_output
    with open('sql_query.txt', 'w', encoding='utf-8') as file:
        file.write(record)
    # Append, so the log keeps a history instead of only the latest query.
    with open('applog.txt', 'a', encoding='utf-8') as file:
        file.write(record + '\n')
    return response
def get_store_address(store):
    """Look up one store's number and address in the local `retail` MySQL DB.

    Parameters:
    - store: the store number to match against STORES.STORE_NUMBER. It is
      sent as a bound query parameter, so either a str or an int works.

    Returns:
        str: ``DataFrame.to_string()`` of the matching STORE_NUMBER /
        STORE_ADDRESS rows. If nothing matches, this is pandas' rendering of
        an empty frame.
    """
    # TODO(security): hard-coded root credentials; move to configuration.
    url_object = URL.create(
        "mysql",
        username="root",
        password="root",  # plain (unescaped) text
        host="localhost",
        database="retail",
    )
    engine = create_engine(url_object)
    # `store` is not validated here, so bind it as a parameter. The old string
    # concatenation allowed SQL injection and raised TypeError for int input.
    sql_query = text(
        "SELECT STORE_NUMBER, STORE_ADDRESS FROM STORES WHERE STORE_NUMBER = :store"
    )
    try:
        # The `with` block closes the connection; previously one was opened,
        # left unused, and never closed.
        with engine.connect() as connection:
            df = pd.read_sql(sql_query, con=connection, params={"store": store})
    finally:
        engine.dispose()
    return df.to_string()
# Prompt for the appointment-outreach SMS. The two sentences are separate
# literals joined by implicit concatenation, so the space after "words." must
# be explicit. Without it the model saw "35 words.With ...".
OUTREACH_SMS_TEMPLATE = (
    "You are an expert in writing a text message for appointment setup with less than 35 words. "
    "With {outreach_input}, generate a text message for appointment to be sent to customer"
)


def outreach_sms_message(outreach_input):
    """Ask the chat model to draft a short appointment-setup text message.

    Parameters:
    - outreach_input: customer/appointment details, substituted into
      OUTREACH_SMS_TEMPLATE.

    Returns:
        str: the model's reply, parsed to plain text.
    """
    llm = ChatOpenAI(temperature=0.5, model="gpt-3.5-turbo", verbose=True)
    prompt = ChatPromptTemplate.from_template(OUTREACH_SMS_TEMPLATE)
    chain = prompt | llm | StrOutputParser()
    return chain.invoke({"outreach_input": outreach_input})
### Function to extract the column names from SQL Query ####
def find_selected_columns(query: str) -> list[str]:
    """Return the column labels selected by the first SQL statement in *query*.

    Only the select list that immediately follows SELECT is inspected, after
    skipping DISTINCT / ALL / DISTINCTROW. Each selected Identifier is
    processed in three steps:

    1. Keep its last space-separated word, so ``SUM(x) AS total`` becomes
       ``total``.
    2. Keep the part after its last ``.``.
    3. Strip backticks.

    For example, ``s.`store_number``` becomes ``store_number``. Entries that
    sqlparse does not group as an Identifier, such as an unaliased
    ``COUNT(*)``, are skipped.

    Raises:
        Exception: if *query* has no SELECT followed by an identifier list.
    """
    # NOTE(review): this module defines find_selected_columns twice with the
    # same body. The later definition is the one bound at import time, so keep
    # both in sync (or delete one).
    statements = sqlparse.parse(query)
    tokens = statements[0].tokens if statements else []
    found_select = False
    for token in tokens:
        if not found_select:
            found_select = token.match(sqlparse.tokens.Keyword.DML, ["select", "SELECT"])
            continue
        if token.is_whitespace or isinstance(token, sqlparse.sql.Comment):
            continue
        if token.ttype in sqlparse.tokens.Keyword and token.normalized in ("DISTINCT", "ALL", "DISTINCTROW"):
            continue
        if isinstance(token, sqlparse.sql.IdentifierList):
            selected = [col for col in token.tokens if isinstance(col, sqlparse.sql.Identifier)]
        elif isinstance(token, sqlparse.sql.Identifier):
            # A single selected column is not wrapped in an IdentifierList.
            selected = [token]
        else:
            # The select list is not identifiers (e.g. `*`). Stop here rather
            # than misreading a later list, such as the FROM tables, as columns.
            break
        # Split off the qualifier before stripping backticks. Stripping first
        # left a stray backtick for `t`.`col`.
        return [col.value.split(" ")[-1].rpartition('.')[-1].strip("`") for col in selected]
    raise Exception("Could not find a select statement. Weird query :)")
def get_query_columns(sql):
    """Return the names of the items selected by the first statement in *sql*.

    Each select-list item is named with sqlparse's ``get_name()``. That is its
    alias when one is given (``SUM(x) AS total`` gives ``total``) and
    otherwise its real name (``s.store_number`` gives ``store_number``).
    Items that sqlparse leaves ungrouped, such as ``*``, are returned as their
    raw text.

    Returns:
        list: one name per selected item. The list is empty when *sql* is
        empty or no select list follows SELECT.
    """
    statements = sqlparse.parse(sql)
    if not statements:
        return []
    in_select = False
    for token in statements[0].tokens:
        if isinstance(token, sqlparse.sql.Comment) or token.is_whitespace:
            continue
        if not in_select:
            in_select = str(token).lower() == 'select'
            continue
        if token.ttype in sqlparse.tokens.Keyword:
            # Modifiers such as DISTINCT sit between SELECT and the select list.
            continue
        if isinstance(token, sqlparse.sql.IdentifierList):
            items = token.get_identifiers()
        else:
            # A single selected item (Identifier, Function, `*`) is not wrapped
            # in an IdentifierList and has no get_identifiers() method. Calling
            # it used to raise AttributeError for one-column queries.
            items = [token]
        return [_select_item_name(item) for item in items]
    return []


def _select_item_name(item):
    """Name one select-list item: sqlparse's alias/real name, else its raw text."""
    name = item.get_name() if isinstance(item, sqlparse.sql.TokenList) else None
    return name if name else item.value
# Helpers for data_visualization().
def _load_saved_sql_result(path='sql_query.txt'):
    """Read the saved SQL query (line 1) and result text (line 2), and parse the rows."""
    with open(path, 'r', encoding='utf-8') as file:
        raw_sql_query = file.readline().strip()
        raw_sql_output = file.readline().strip()
    logging.debug("raw_sql_query: %s", raw_sql_query)
    logging.debug("raw_sql_output: %s", raw_sql_output)
    return raw_sql_query, _parse_sql_output(raw_sql_output)


def _parse_sql_output(raw_sql_output):
    """Turn result text such as ``[(datetime.date(2024, 5, 9), Decimal('1.5'), 'x')]`` into rows."""
    if not raw_sql_output:
        # No result text: treat it as zero rows instead of failing inside eval().
        return []
    # "datetime.date(" becomes "date(", and "datetime.datetime(" becomes "datetime(".
    cleaned = re.sub(r"datetime\.date", "date", raw_sql_output)
    # Decimal('12.50') and Decimal('-3') become 12.50 and -3. The old patterns
    # skipped negative values, leaving a Decimal(...) call eval() could not resolve.
    cleaned = re.sub(r"Decimal\('(-?\d+(?:\.\d+)?)'\)", r"\1", cleaned)
    # SECURITY NOTE(review): eval() of text read from disk. Builtins are
    # removed and only the date constructors are exposed. That blocks casual
    # misuse but is NOT a sandbox; an ast-based parser would be safer.
    return eval(cleaned, {"__builtins__": {}}, {"date": date, "datetime": datetime})


def _rows_to_dataframe(rows, column_names):
    """Build the DataFrame `df` that the generated plotting code operates on."""
    # Column i takes the i-th value of every row.
    data_dict = {name: [row[i] for row in rows] for i, name in enumerate(column_names)}
    df = pd.DataFrame(data_dict)
    if 'sale_date' in df:
        df = df.sort_values(by=['sale_date'])
    if 'store_number' in df:
        # Presumably so store numbers plot as labels rather than on a numeric axis.
        df['store_number'] = df['store_number'].map(str)
    return df


def _extract_fenced_code(raw_code):
    """Join the bodies of all ''' or ``` fenced blocks (optional "python" tag) in *raw_code*."""
    blocks = re.findall(r"('''|```)(?:python)?(.*?)\1", raw_code, re.DOTALL)
    return "\n".join(body.strip() for _fence, body in blocks)


def data_visualization(question):
    '''This tool will receive the user question as input to generate the chart for data visualization using Pandas and Matplotlib
    Parameters:
    - question: User question to generate the charts
    - Column Names: Column names for the data frame
    - Data Frame: Dataframe with the actual data for the chart visualization
    '''
    # Maintainer notes. The docstring above reads like an LLM tool description,
    # so its wording is kept verbatim.
    # Steps:
    #   1. Rebuild the last SQL result saved in sql_query.txt as a DataFrame `df`.
    #   2. Ask the OpenAI chat model for Matplotlib code.
    #   3. Stream the reply into the Streamlit UI.
    #   4. Run the fenced code and render the current figure.
    # Returns the extracted code, or "" when the reply had no fenced block.
    raw_sql_query, data = _load_saved_sql_result()
    column_names = get_query_columns(raw_sql_query)
    logging.debug("col names: %s", column_names)
    # `df` must exist before it is referenced. Its construction used to be
    # commented out, which made every call fail with UnboundLocalError.
    df = _rows_to_dataframe(data, column_names)

    model_name = 'gpt-3.5-turbo'
    prompt = f"""
The dataset is ALREADY loaded into a DataFrame name 'df' DO NOT load the data again.
Dataframe has the following columns: {column_names}
DO NOT create Sample or Example Dataframe.
Before plotting, ensure the data is ready:
1. DO NOT create sample or example dataframe. Dataframe is available in df {df}.
2. Check if columns that are supposed to be numeric are recognized as such. If not attempt to convert them.
Use Package Pandas, numpy, Seaborn and Matplotlib ONLY.
Provide SINGLE CODE BLOCK with a solution using Pandas and Matplotlib plots in a single figure to address the following query:
{question}
- DO NOT EXPLAIN the code
- DO NOT COMMENT the code.
- DO NOT create Sample or Example DataFrame
- ALWAYS WRAP UP THE CODE IN A SINGLE CODE BLOCK.
- The code block must start and end with '''
- ALWAYS include Legend and data labels in the chart but skip them if it is overlapping with other charts when creating subplots
- Always import matplotlib.pyplot as plt and import pandas as pd
- format the x & y axis labels so that it doesn't overlap with each other. IF the label is long, position them as slantic
- DO NOT sum the column names in the dataframe unless asked by the user in the question
- Display full date in the chart if available in YYYY-MM-DD format
- Colors to use for background and axes of the figure : #F0F0F6
- Try to use the following color palette for coloring the plots : #8f63ee #ced5ce #a27bf6 #3d3b41
"""
    messages = [
        {
            "role": "system",
            "content": "You are a helpful Data Visualization assitant for store associates working in a retail store. Use the question, dataframe df and column names to visualize the data easily, you have to give a single block of code without explaining or commenting the code"
        },
        {"role": "user", "content": prompt},
    ]
    result = client.chat.completions.create(model=model_name, messages=messages)
    raw_code = result.choices[0].message.content or ""
    # The prompt asks for ''' fences, but models often answer with ``` fences; accept both.
    code = _extract_fenced_code(raw_code)

    # Replay the model's full reply into the UI one character at a time.
    with st.status("Generating Python code for data visualization...π"):
        st.write_stream(char for char in raw_code)
    logging.debug("code: %s", code)

    if not code:
        st.warning("No code block was found in the model response, so no chart was generated.")
        return code
    try:
        # SECURITY NOTE(review): this executes LLM-generated code in-process
        # with full privileges (files, network, environment). Sandbox it
        # before exposing the app to untrusted users.
        # A single namespace dict (module globals plus `df`) serves as the
        # generated code's globals. With separate globals/locals, names such as
        # `df` are invisible inside its comprehensions and functions.
        exec(code, {**globals(), "df": df})
        fig = plt.gcf()
        # Pass the figure explicitly: calling st.pyplot() with no figure is deprecated.
        st.pyplot(fig)
        # Close the figure so it does not linger in pyplot's global state.
        plt.close(fig)
    except Exception as e:
        logging.exception("Generated visualization code failed")
        error_message = str(e)
        st.error(
            f"π Apologies, failed to execute the code due to the error: {error_message}"
        )
        st.warning(
            """
π Check the error message and the code executed above to investigate further.
Pro tips:
- Tweak your prompts to overcome the error
- Use the words 'Plot'/ 'Subplot'
- Use simpler, concise words
"""
        )
    return code
### Function to extract the column names from SQL Query ####
def find_selected_columns(query: str) -> list[str]:
    """Return the column labels selected by the first SQL statement in *query*.

    Only the select list that immediately follows SELECT is inspected, after
    skipping DISTINCT / ALL / DISTINCTROW. Each selected Identifier is
    processed in three steps:

    1. Keep its last space-separated word, so ``SUM(x) AS total`` becomes
       ``total``.
    2. Keep the part after its last ``.``.
    3. Strip backticks.

    For example, ``s.`store_number``` becomes ``store_number``. Entries that
    sqlparse does not group as an Identifier, such as an unaliased
    ``COUNT(*)``, are skipped.

    Raises:
        Exception: if *query* has no SELECT followed by an identifier list.
    """
    # NOTE(review): this module defines find_selected_columns twice with the
    # same body. This later definition is the one bound at import time, so
    # keep both in sync (or delete one).
    statements = sqlparse.parse(query)
    tokens = statements[0].tokens if statements else []
    found_select = False
    for token in tokens:
        if not found_select:
            found_select = token.match(sqlparse.tokens.Keyword.DML, ["select", "SELECT"])
            continue
        if token.is_whitespace or isinstance(token, sqlparse.sql.Comment):
            continue
        if token.ttype in sqlparse.tokens.Keyword and token.normalized in ("DISTINCT", "ALL", "DISTINCTROW"):
            continue
        if isinstance(token, sqlparse.sql.IdentifierList):
            selected = [col for col in token.tokens if isinstance(col, sqlparse.sql.Identifier)]
        elif isinstance(token, sqlparse.sql.Identifier):
            # A single selected column is not wrapped in an IdentifierList.
            selected = [token]
        else:
            # The select list is not identifiers (e.g. `*`). Stop here rather
            # than misreading a later list, such as the FROM tables, as columns.
            break
        # Split off the qualifier before stripping backticks. Stripping first
        # left a stray backtick for `t`.`col`.
        return [col.value.split(" ")[-1].rpartition('.')[-1].strip("`") for col in selected]
    raise Exception("Could not find a select statement. Weird query :)")