Ari
Update app.py
82bfc51 verified
raw
history blame
8.16 kB
import os
import streamlit as st
import pandas as pd
import sqlite3
import numpy as np # For numerical operations
from langchain import OpenAI, LLMChain, PromptTemplate
import sqlparse
import logging
from sklearn.linear_model import LinearRegression # For machine learning tasks
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
# --- Session / environment setup -------------------------------------------
# Conversation history lives in Streamlit's session state so it survives the
# script re-execution that follows every widget interaction.
if 'history' not in st.session_state:
    st.session_state.history = []

# Only errors are logged.
logging.basicConfig(level=logging.ERROR)

# OpenAI API key, read from the environment (never hard-code it).
openai_api_key = os.getenv("OPENAI_API_KEY")

# Expose the key on the langchain OpenAI wrapper class.
from langchain.llms import OpenAI as LangchainOpenAI
LangchainOpenAI.api_key = openai_api_key

# --- Step 1: Upload CSV data file (or use default) -------------------------
st.title("Data Science Chatbot")
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file is None:
    # No upload yet: fall back to the bundled default data set.
    data = pd.read_csv("default_data.csv")
    st.write("Using default_data.csv file.")
else:
    data = pd.read_csv(csv_file)
    st.write(f"Data Preview ({csv_file.name}):")
    # NOTE(review): the preview is assumed to belong to the upload branch
    # (it follows the "Data Preview" caption) -- confirm against the original.
    st.dataframe(data.head())

# --- Step 2: Load CSV data into a persistent SQLite database ---------------
db_file = 'my_database.db'
# check_same_thread=False: the connection is created during one script run
# but queried from the text-input callback, which Streamlit can execute on a
# different thread; with the default setting sqlite3 raises ProgrammingError.
conn = sqlite3.connect(db_file, check_same_thread=False)

# Derive the table name from the part of the file name before the first dot,
# then reduce it to a plain identifier. The name is handed to the LLM, which
# writes it unquoted into SQL, so spaces, hyphens, a leading digit or an empty
# stem would otherwise yield queries that cannot run.
raw_table_name = csv_file.name.split('.')[0] if csv_file else "default_table"
table_name = ''.join(
    ch if (ch.isalnum() or ch == '_') else '_' for ch in raw_table_name
)
if not table_name or table_name[0].isdigit():
    table_name = f"t_{table_name}"

# Replace any table left over from a previous run/upload with the same name.
data.to_sql(table_name, conn, index=False, if_exists='replace')

# SQL table metadata (for validation and schema)
valid_columns = list(data.columns)
st.write(f"Valid columns: {valid_columns}")
# Step 3: Define helper functions
def extract_code(response):
    """Extract the code enclosed between <CODE> and </CODE> tags.

    Only the first tag pair is considered. The tags are matched
    case-insensitively, and a Markdown code fence wrapping the whole snippet
    (three backticks, with an optional language label) is removed, because
    language models frequently add one even when asked for bare code.

    Args:
        response: The raw text returned by the language model.

    Returns:
        The stripped code string (possibly empty), or ``None`` when
        ``response`` is not a string or contains no tag pair.
    """
    import re

    if not isinstance(response, str):
        return None

    match = re.search(r"<CODE>(.*?)</CODE>", response, re.DOTALL | re.IGNORECASE)
    if match is None:
        return None
    code = match.group(1).strip()

    # Unwrap a fence only when it spans the entire snippet, so backticks that
    # legitimately appear inside the code are left alone.
    fenced = re.fullmatch(r"```[A-Za-z0-9_+-]*[ \t]*\n(.*?)\n?```", code, re.DOTALL)
    if fenced:
        code = fenced.group(1).strip()
    return code
# Step 4: Set up the LLM Chain to generate SQL queries or Python code
#
# The prompt text below is sent to the model as-is. It has three placeholders
# ({question}, {table_name}, {columns}) and asks the model to return its code
# wrapped in <CODE>...</CODE> tags, with no import statements.
template = """
You are an expert data scientist assistant. Given a natural language question, the name of the table, and a list of valid columns, decide whether to generate a SQL query to retrieve data, perform statistical analysis, or build a simple machine learning model.
Instructions:
- If the question involves data retrieval or simple aggregations, generate a SQL query.
- If the question requires statistical analysis, generate a Python code snippet using pandas and numpy.
- If the question involves predictions or modeling, generate a Python code snippet using scikit-learn.
- Ensure that you only use the columns provided.
- Do not include any import statements in the code.
- For case-insensitive string comparisons in SQL, use either 'LOWER(column) = LOWER(value)' or 'column = value COLLATE NOCASE', but do not use both together.
- Provide the code between <CODE> and </CODE> tags.
Question: {question}
Table name: {table_name}
Valid columns: {columns}
Response:
"""
# Bind the three placeholders used in the template text.
prompt = PromptTemplate(template=template, input_variables=['question', 'table_name', 'columns'])
# The model wrapper is created with temperature=0.
llm = LangchainOpenAI(temperature=0)
# Despite its name, this chain is asked for either SQL or Python code,
# according to the instructions in the template above.
sql_generation_chain = LLMChain(llm=llm, prompt=prompt)
# Define the callback function
def process_input():
    """Handle a submitted chat message (``on_change`` callback of the input).

    Reads the text stored under ``st.session_state['user_input']``, records it
    in the conversation history and appends the assistant's reply:

    * if the message mentions "columns", the list of valid columns;
    * otherwise the LLM chain is asked for code, which is run as SQL when it
      starts with ``select`` and as Python otherwise;
    * if the model's reply contains no extractable code, the reply text itself.

    Any exception is logged and recorded in the history as an error message
    instead of propagating. The input field is cleared afterwards. An empty
    message is ignored.
    """
    user_prompt = st.session_state['user_input']
    if not user_prompt:
        return

    try:
        # Append user message to history
        st.session_state.history.append({"role": "user", "content": user_prompt})

        if "columns" in user_prompt.lower():
            _append_assistant(f"The columns are: {', '.join(valid_columns)}")
        else:
            response = sql_generation_chain.run({
                'question': user_prompt,
                'table_name': table_name,
                'columns': ', '.join(valid_columns),
            })
            code = extract_code(response)
            if not code:
                # No code block found: show the model's reply verbatim.
                _append_assistant(response.strip())
            elif code.strip().lower().startswith('select'):
                _execute_sql(code)
            else:
                _execute_python(code)
    except Exception as e:
        logging.error(f"An error occurred: {e}")
        _append_assistant(f"Error: {e}")

    # Reset the user_input in session state
    st.session_state['user_input'] = ''


def _append_assistant(content):
    """Append one assistant entry (text or a result object) to the history."""
    st.session_state.history.append({"role": "assistant", "content": content})


def _execute_sql(query):
    """Run a generated SELECT query on the SQLite connection and record the result."""
    st.write(f"Generated SQL Query:\n{query}")
    try:
        result = pd.read_sql_query(query, conn)
    except Exception as e:
        logging.error(f"An error occurred during SQL execution: {e}")
        _append_assistant(f"Error executing SQL query: {e}")
        return
    _append_assistant(f"Generated SQL Query:\n{query}")
    _append_assistant(result)


def _execute_python(code):
    """Execute a generated Python snippet and record the value it leaves in ``result``."""
    st.write(f"Generated Python Code:\n{code}")
    # Names made available to the snippet (the prompt forbids import statements).
    namespace = {
        'pd': pd,
        'np': np,
        'data': data.copy(),  # work on a copy so the snippet cannot alter the loaded data
        'result': None,
        'LinearRegression': LinearRegression,
        'train_test_split': train_test_split,
        'mean_squared_error': mean_squared_error,
        'r2_score': r2_score,
    }
    try:
        # NOTE(security): this runs model-generated code with the full
        # privileges of the app process; it is not sandboxed. Review before
        # exposing the app to untrusted users.
        #
        # A single dict serves as both globals and locals. With separate
        # dicts, functions, lambdas and comprehensions inside the snippet
        # resolve names through the (empty) globals and fail with NameError
        # on `data`, `pd`, etc.
        exec(code, namespace)
    except Exception as e:
        logging.error(f"An error occurred during code execution: {e}")
        _append_assistant(f"Error executing code: {e}")
        return

    result = namespace.get('result')
    if result is not None:
        _append_assistant("Result:")
        _append_assistant(result)
    else:
        _append_assistant("Code executed successfully.")
# Render every stored message in order. User messages and plain assistant
# replies are shown as Markdown; DataFrame and dict results get a caption
# followed by a dedicated widget. Entries with any other role are skipped.
for entry in st.session_state.history:
    speaker = entry['role']
    if speaker == 'user':
        st.markdown(f"**User:** {entry['content']}")
        continue
    if speaker != 'assistant':
        continue

    payload = entry['content']
    if isinstance(payload, pd.DataFrame):
        st.markdown("**Assistant:** Here are the results:")
        st.dataframe(payload)
    elif not isinstance(payload, (int, float)) and isinstance(payload, dict):
        st.markdown("**Assistant:** Here are the results:")
        st.json(payload)
    else:
        # Numbers, strings and every other object are interpolated as text.
        st.markdown(f"**Assistant:** {payload}")

# The input box is created last so it appears below the conversation;
# submitting it triggers the process_input callback.
st.text_input("Enter your message:", key='user_input', on_change=process_input)