Spaces:
Running
Running
from example import example | |
from datetime import datetime | |
#from datetime import time | |
import time | |
from time import sleep | |
import pandas as pd | |
import huggingface_hub | |
from huggingface_hub import Repository | |
import os | |
# agent will directly create query and run the query in DB | |
from langchain.agents import create_sql_agent | |
# Simple chain to create the SQL statements, it doesn't execute the query | |
from langchain.chains import create_sql_query_chain | |
# to execute the query | |
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool | |
from langchain_community.agent_toolkits import SQLDatabaseToolkit | |
from langchain.sql_database import SQLDatabase | |
from langchain.agents import AgentExecutor | |
from langchain.agents.agent_types import AgentType | |
from langchain_experimental.sql import SQLDatabaseChain | |
from langchain_community.vectorstores import Chroma | |
from langchain.prompts import SemanticSimilarityExampleSelector | |
# Prompt input for MYSQL | |
from langchain.chains.sql_database.prompt import PROMPT_SUFFIX, _mysql_prompt | |
# Create the prompt template for creating the prompt for mysqlprompt | |
from langchain.prompts.prompt import PromptTemplate | |
from langchain.prompts import FewShotPromptTemplate | |
# to create the tools to be used by agent | |
from langchain.agents import Tool | |
import sqlparse | |
# create the agent prompts | |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
# Huggingface embeddings using Langchain | |
from langchain_community.embeddings import HuggingFaceEmbeddings | |
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_core.prompts import HumanMessagePromptTemplate | |
from langchain_core.output_parsers import StrOutputParser | |
# Load Env parameters | |
from dotenv import load_dotenv | |
from langchain_openai import ChatOpenAI | |
from sqlalchemy import create_engine, text, URL | |
from openai import OpenAI | |
import openai | |
import logging | |
import re | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import streamlit as st | |
from datetime import datetime | |
from datetime import date | |
# Load variables from a local .env file before building the client: OpenAI()
# looks up its API key in the environment at construction time, and this
# statement runs at import, before any function in this module is called.
load_dotenv()
# Module-level OpenAI client used for direct chat-completion calls.
client = OpenAI()
def config():
    """Load environment settings and build the chat model.

    Returns:
        A ``ChatOpenAI`` instance configured for ``gpt-3.5-turbo`` with a
        temperature of 0.5.
    """
    # Pull variables from a .env file into the process environment first.
    load_dotenv()
    return ChatOpenAI(temperature=0.5, model="gpt-3.5-turbo")
# Setting up URL parameter to connect to MySQL Database | |
def get_db_chain(question):
    '''
    This tool will take the input as user question and query the SQL database to extract the output
    Use this tool first to query any data from the database to answer user question.
    you can generate multiple sql queries to extract two distinct information.
    '''
    # NOTE(review): the docstring above reads like a tool description aimed at
    # an LLM agent, so its wording is left untouched - confirm against caller.
    #
    # Maintainer summary: builds a few-shot MySQL prompt from the examples most
    # similar to `question`, runs it through SQLDatabaseChain, stores the
    # generated SQL and its raw result in st.session_state, and returns the
    # chain response (a mapping that includes "intermediate_steps").

    # SECURITY: credentials can now be supplied through the environment. The
    # literal fallbacks keep existing deployments working, but they should be
    # removed from source control (and the password rotated).
    db_user = os.getenv("RETAIL_DB_USER", "admin")
    db_password = os.getenv("RETAIL_DB_PASSWORD", "Epperson")
    db_host = os.getenv(
        "RETAIL_DB_HOST", "retail.cd6uaise2moh.us-west-2.rds.amazonaws.com"
    )
    db_name = os.getenv("RETAIL_DB_NAME", "retail")

    # Maximum number of rows the chain is told to request.
    top_k = 30

    llm = ChatOpenAI(temperature=0.5, model="gpt-3.5-turbo")
    db = SQLDatabase.from_uri(
        f"mysql://{db_user}:{db_password}@{db_host}/{db_name}"
    )

    embeddings = HuggingFaceEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2'
    )
    # One text per example: all of its values joined by spaces.
    to_vectorize = [" ".join(shot.values()) for shot in example]
    vectorstore = Chroma.from_texts(
        to_vectorize, embedding=embeddings, metadatas=example
    )
    # Select the two stored examples closest to the user question.
    example_selector = SemanticSimilarityExampleSelector(
        vectorstore=vectorstore,
        k=2,
    )
    example_prompt = PromptTemplate(
        input_variables=["Question", "SQLQuery", "SQLResult", "Answer"],
        template="\nQuestion: {Question}\nSQLQuery: {SQLQuery}\nSQLResult: {SQLResult}\n?Answer: {Answer}",
    )
    few_shot_prompt = FewShotPromptTemplate(
        example_selector=example_selector,
        example_prompt=example_prompt,
        prefix=_mysql_prompt,
        suffix=PROMPT_SUFFIX,
        input_variables=["input", "table_info", "top_k"],
    )
    print(example_selector)
    print(_mysql_prompt)
    print(PROMPT_SUFFIX)

    chain = SQLDatabaseChain.from_llm(
        llm,
        db,
        verbose=True,
        prompt=few_shot_prompt,
        return_intermediate_steps=True,
        top_k=top_k,
    )
    response = chain.invoke(question)

    # Positions 1 and 3 are read as the generated SQL text and the database
    # output respectively - presumably the layout SQLDatabaseChain produces;
    # verify if the langchain_experimental version changes.
    intermediate_steps = response["intermediate_steps"]
    sql_query_cleaned = intermediate_steps[1].replace('\n', " ")
    sql_query_output = intermediate_steps[3]

    # Shared through the Streamlit session so data_visualization can read them.
    st.session_state.sql_query_cleaned = sql_query_cleaned
    st.session_state.sql_query_output = sql_query_output
    return response
# Call the LLM with the question and the fewshotprompt | |
# write_query = create_sql_query_chain(llm=llm,db=db, prompt=few_shot_prompt) | |
#print(write_query) | |
# Execute the Query using QuerySQLDataBaseTool | |
#execute_query = QuerySQLDataBaseTool(db=db) | |
# Chain to combine write SQL and Execute SQL | |
#chain = write_query | execute_query | llm | |
#response = chain.invoke("Question") | |
def get_store_address(store):
    """Look up the number and address of a single store.

    Args:
        store: Store number to search for. It is passed to the database as a
            bound parameter, never spliced into the SQL text.

    Returns:
        The matching rows rendered as text by ``DataFrame.to_string()``.
    """
    url_object = URL.create(
        "mysql",
        username="root",
        password="root",  # plain (unescaped) text
        host="localhost",
        database="retail",
    )
    engine = create_engine(url_object)
    # A bound parameter avoids SQL injection and also accepts non-string
    # store numbers, which the previous string concatenation rejected.
    query = text(
        "SELECT STORE_NUMBER, STORE_ADDRESS FROM STORES "
        "WHERE STORE_NUMBER = :store"
    )
    try:
        # The context manager returns the connection to the pool on exit.
        with engine.connect() as connection:
            df = pd.read_sql(query, con=connection, params={"store": store})
    finally:
        # The engine is created per call, so release its pool when done.
        engine.dispose()
    return df.to_string()
def outreach_sms_message(outreach_input):
    """Ask the chat model to draft an appointment text message.

    Args:
        outreach_input: Text substituted into the prompt's
            ``{outreach_input}`` placeholder.

    Returns:
        The model reply as a plain string (parsed by ``StrOutputParser``).
    """
    template = (
        "You are a expert in writing a text message for appointment setup with less than 35 words."
        "With {outreach_input}, generate a text message for appointment to be sent to customer"
    )
    chat_model = ChatOpenAI(temperature=0.5, model="gpt-3.5-turbo", verbose=True)
    # prompt -> model -> string parser, composed as a single runnable.
    pipeline = (
        ChatPromptTemplate.from_template(template)
        | chat_model
        | StrOutputParser()
    )
    return pipeline.invoke({"outreach_input": outreach_input})
### Function to extract the column names from SQL Query #### | |
def find_selected_columns(query: str) -> list[str]:
    """Return the output column names of the first SELECT in ``query``.

    For each selected identifier the last space-separated word is used (so an
    alias wins over the expression), surrounding backticks are stripped and
    any ``table.`` qualifier is dropped.

    NOTE: this function is defined a second time later in this file; the
    later definition replaces this one at import time.

    Args:
        query: SQL text to inspect.

    Returns:
        Column names in select-list order.

    Raises:
        Exception: if the query is empty, or no identifier is found between
            SELECT and FROM.
    """

    def _column_name(identifier):
        return identifier.value.split(" ")[-1].strip("`").rpartition('.')[-1]

    statements = sqlparse.parse(query)
    # An empty query parses to no statements; fail with the documented error
    # instead of an IndexError.
    tokens = statements[0].tokens if statements else []
    found_select = False
    for token in tokens:
        if not found_select:
            found_select = token.match(sqlparse.tokens.Keyword.DML, ["select", "SELECT"])
            continue
        if isinstance(token, sqlparse.sql.IdentifierList):
            return [_column_name(col)
                    for col in token.tokens if isinstance(col, sqlparse.sql.Identifier)]
        if isinstance(token, sqlparse.sql.Identifier):
            # A single selected column is a bare Identifier, not a list.
            return [_column_name(token)]
        if token.match(sqlparse.tokens.Keyword, ["FROM"]):
            # Past the select list: never mistake a table name for a column.
            break
    raise Exception("Could not find a select statement. Weird query :)")
def get_query_columns(sql):
    """Return the column names (aliases preferred) selected by ``sql``.

    Only the first statement is inspected. The first grouped token after the
    SELECT keyword is taken as the select list; scanning stops at FROM so a
    table name is never reported as a column.

    Args:
        sql: SQL text to inspect.

    Returns:
        A list with ``get_name()`` of every selected item, or an empty list
        when the SQL is empty or has no grouped select item (for example
        ``SELECT * FROM t``).
    """
    statements = sqlparse.parse(sql)
    if not statements:
        return []

    in_select = False
    for token in statements[0].tokens:
        if isinstance(token, sqlparse.sql.Comment):
            continue
        if str(token).lower() == 'select':
            in_select = True
        elif in_select and token.match(sqlparse.tokens.Keyword, ["FROM"]):
            break
        elif in_select and token.ttype is None:
            # Grouped tokens have ttype None. Only IdentifierList provides
            # get_identifiers(); a lone column or function call is a single
            # group and is treated as a one-item select list.
            if isinstance(token, sqlparse.sql.IdentifierList):
                column_identifiers = list(token.get_identifiers())
            else:
                column_identifiers = [token]
            return [identifier.get_name() for identifier in column_identifiers]
    return []
# Function to write the Python code to generate charts | |
def data_visualization(question):
    '''This tool will receive the user question as input to generate the chart for data visualization using Pandas and Matplotlib. Use this tool only when the user question is to 'visualize' or 'summarize' the data or info.
    Parameters:
    - question: User question to generate the charts
    - Column Names: Column names for the data frame
    - Data Frame: Dataframe with the actual data for the chart visualization
    '''
    # NOTE(review): the docstring above reads like a tool description aimed at
    # an LLM agent, so its wording is left untouched - confirm against caller.
    #
    # Maintainer summary: reads the last SQL query and its textual result from
    # st.session_state, rebuilds a DataFrame named `df`, asks the chat model
    # for plotting code, streams the reply into the page, executes the code
    # and renders the resulting figure. Returns the code that was executed.
    raw_sql_query = st.session_state.sql_query_cleaned
    raw_sql_output = st.session_state.sql_query_output
    print("raw_sql_query:", raw_sql_query)
    print("raw_sql_output:", raw_sql_output)

    # Rewrite the repr so it evaluates with the names imported by this module:
    # "datetime.date(" becomes "date(" (and, as a side effect,
    # "datetime.datetime(" becomes "datetime(").
    cleaned_sql_output = re.sub(r"datetime\.date", "date", raw_sql_output)
    # Unwrap Decimal('...') into a bare numeric literal. Signed values and
    # exponents are accepted as well; they used to be left wrapped, which made
    # the eval below fail because Decimal is not defined in this module.
    cleaned_sql_output = re.sub(
        r"Decimal\('(-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)'\)",
        r"\1",
        cleaned_sql_output,
    )
    print("cleaned sql output:", cleaned_sql_output)

    # SECURITY: eval() runs whatever text came back from the database. Kept
    # as-is for compatibility; replace with a real parser if the data can be
    # influenced by untrusted parties.
    data = eval(cleaned_sql_output)
    print("data", data)

    column_names = get_query_columns(raw_sql_query)
    print("col names:", column_names)

    # Column i of every row goes under the i-th selected column name.
    data_dict = {
        name: [row[position] for row in data]
        for position, name in enumerate(column_names)
    }
    df = pd.DataFrame(data_dict)
    if 'sale_date' in df:
        df = df.sort_values(by=['sale_date'])
    if 'store_number' in df:
        # Store numbers are labels, not quantities: make them strings.
        df['store_number'] = [str(i) for i in df['store_number']]
    print("df:", df)

    model_name = 'gpt-3.5-turbo'
    prompt = f"""
The dataset is ALREADY loaded into a DataFrame name 'df' DO NOT load the data again.
Dataframe has the following columns: {column_names}
DO NOT create Sample or Example Dataframe.
Before plotting, ensure the data is ready:
1. DO NOT create sample or example dataframe. Dataframe is available in df {df}.
2. Check if columns that are supposed to be numeric are recognized as such. If not attempt to convert them.
Use Package Pandas and Matplotlib ONLY.
Provide SINGLE CODE BLOCK with a solution using Pandas and Matplotlib plots in a single figure to address the following query:
{question}
- DO NOT EXPLAIN the code
- DO NOT COMMENT the code.
- DO NOT create Sample or Example DataFrame
- ALWAYS WRAP UP THE CODE IN A SINGLE CODE BLOCK.
- The code block must start and end with '''
- ALWAYS include Legend and data labels in the chart but skip them if it is overlapping with other charts when creating subplots
- Always import matplotlib.pyplot as plt and import pandas as pd
- format the x & y axis labels so that it doesn't overlap with each other. IF the label is long, position them as slantic
- DO NOT sum the column names in the dataframe unless asked by the user in the question
- Display full date in the chart if available in YYYY-MM-DD format
- Colors to use for background and axes of the figure : #F0F0F6
- Try to use the following color palette for coloring the plots : #8f63ee #ced5ce #a27bf6 #3d3b41
"""
    messages = [
        {
            "role": "system",
            "content": "You are a helpful Data Visualization assitant for store associates working in a retail store. Use the question, dataframe df and column names to visualize the data easily, you have to give a single block of code without explaining or commenting the code"
        },
        {"role": "user", "content": prompt},
    ]
    result = client.chat.completions.create(model=model_name, messages=messages)
    raw_code = result.choices[0].message.content
    print("raw code", raw_code)

    # Accept the ''' fence requested in the prompt and the Markdown ``` fence
    # models commonly emit instead; the closing fence must match the opening.
    fenced_blocks = re.finditer(
        r"(?P<fence>'''|```)(?:python)?(?P<body>.*?)(?P=fence)",
        raw_code,
        re.DOTALL,
    )
    code = "\n".join(block.group("body").strip() for block in fenced_blocks)
    if not code:
        # No fence at all: the prompt forbids explanations, so treat the whole
        # reply as code. If it is not valid Python, the handler below reports it.
        code = raw_code.strip()
    print("code:", code)

    def generate_content():
        # Yield the reply one character at a time for st.write_stream.
        for item in raw_code:
            yield item

    with st.status("Generating Python code for data visualization...π"):
        st.write_stream(generate_content())

    try:
        print("****************code executed****************************")
        # SECURITY: exec() runs model-generated code inside this process with
        # access to this function's names (including `df`). Kept because it is
        # the feature itself; sandbox it if prompts can come from untrusted users.
        exec(code)
        # Pass the figure explicitly (calling st.pyplot() with no figure is
        # deprecated) and clear it so the next chart starts from a blank figure.
        st.pyplot(plt.gcf(), clear_figure=True)
    except Exception as e:
        error_message = str(e)
        st.error(
            f"π Apologies, failed to execute the code due to the error: {error_message}"
        )
        st.warning(
            """
π Check the error message and the code executed above to investigate further.
Pro tips:
- Tweak your prompts to overcome the error
- Use the words 'Plot'/ 'Subplot'
- Use simpler, concise words
"""
        )
    return code
### Function to extract the column names from SQL Query #### | |
def find_selected_columns(query: str) -> list[str]:
    """Return the output column names of the first SELECT in ``query``.

    For each selected identifier the last space-separated word is used (so an
    alias wins over the expression), surrounding backticks are stripped and
    any ``table.`` qualifier is dropped.

    NOTE: a function with this name is also defined earlier in this file;
    this later definition is the one that remains bound after import.

    Args:
        query: SQL text to inspect.

    Returns:
        Column names in select-list order.

    Raises:
        Exception: if the query is empty, or no identifier is found between
            SELECT and FROM.
    """

    def _column_name(identifier):
        return identifier.value.split(" ")[-1].strip("`").rpartition('.')[-1]

    statements = sqlparse.parse(query)
    # An empty query parses to no statements; fail with the documented error
    # instead of an IndexError.
    tokens = statements[0].tokens if statements else []
    found_select = False
    for token in tokens:
        if not found_select:
            found_select = token.match(sqlparse.tokens.Keyword.DML, ["select", "SELECT"])
            continue
        if isinstance(token, sqlparse.sql.IdentifierList):
            return [_column_name(col)
                    for col in token.tokens if isinstance(col, sqlparse.sql.Identifier)]
        if isinstance(token, sqlparse.sql.Identifier):
            # A single selected column is a bare Identifier, not a list.
            return [_column_name(token)]
        if token.match(sqlparse.tokens.Keyword, ["FROM"]):
            # Past the select list: never mistake a table name for a column.
            break
    raise Exception("Could not find a select statement. Weird query :)")