# NOTE: page-banner residue from the hosting site ("Spaces: Sleeping"); not part of the program.
import os | |
import streamlit as st | |
import pandas as pd | |
import sqlite3 | |
import numpy as np # For numerical operations | |
from langchain import OpenAI, LLMChain, PromptTemplate | |
import sqlparse | |
import logging | |
from sklearn.linear_model import LinearRegression # For machine learning tasks | |
from sklearn.model_selection import train_test_split | |
from sklearn.metrics import mean_squared_error, r2_score | |
# Initialize conversation history (kept in session state so it survives reruns)
if 'history' not in st.session_state:
    st.session_state.history = []

# Set up logging
logging.basicConfig(level=logging.ERROR)

# OpenAI API key (ensure it is securely stored)
openai_api_key = os.getenv("OPENAI_API_KEY")

# Set OpenAI API key for langchain
# NOTE(review): LangChain's OpenAI wrapper normally receives the key through
# its `openai_api_key` argument or the OPENAI_API_KEY environment variable;
# confirm that this class-attribute assignment has any effect.
from langchain.llms import OpenAI as LangchainOpenAI
LangchainOpenAI.api_key = openai_api_key

# Step 1: Upload CSV data file (or use default)
st.title("Data Science Chatbot")
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file is None:
    data = pd.read_csv("default_data.csv")  # Use default CSV if no file is uploaded
    st.write("Using default_data.csv file.")
else:
    data = pd.read_csv(csv_file)
    st.write(f"Data Preview ({csv_file.name}):")
    st.dataframe(data.head())

# Step 2: Load CSV data into a persistent SQLite database
db_file = 'my_database.db'
# The connection is later used from the `on_change` callback, which Streamlit
# may run on a different thread than the one that created it. With the default
# check_same_thread=True sqlite3 would raise ProgrammingError there.
conn = sqlite3.connect(db_file, check_same_thread=False)

# Derive the table name from the uploaded file name. Only the final extension
# is dropped, and every character that is not alphanumeric or "_" is replaced,
# so the generated SQL can reference the table without quoting.
if csv_file is not None:
    file_stem = os.path.splitext(csv_file.name)[0]
    table_name = "".join(
        ch if ch.isalnum() or ch == "_" else "_" for ch in file_stem
    )
    if not table_name or table_name[0].isdigit():
        # Unquoted SQL identifiers may not be empty or start with a digit.
        table_name = f"table_{table_name}"
else:
    table_name = "default_table"
data.to_sql(table_name, conn, index=False, if_exists='replace')

# SQL table metadata (for validation and schema)
valid_columns = list(data.columns)
st.write(f"Valid columns: {valid_columns}")
# Step 3: Define helper functions | |
def extract_code(response):
    """Extract the code enclosed between <CODE> and </CODE> tags.

    Args:
        response: Raw text returned by the language model.

    Returns:
        The text between the first pair of tags with surrounding whitespace
        stripped, or None when `response` is not a string or contains no
        complete tag pair. An empty tag pair yields an empty string.
    """
    import re

    # Guard against a missing or non-text response instead of letting
    # re.search raise TypeError.
    if not isinstance(response, str):
        return None

    # DOTALL lets the code span several lines; IGNORECASE also accepts
    # lower-case <code> tags. The non-greedy group stops at the first
    # closing tag.
    match = re.search(
        r"<CODE>(.*?)</CODE>", response, re.DOTALL | re.IGNORECASE
    )
    if match is None:
        return None
    return match.group(1).strip()
# Step 4: Set up the LLM Chain to generate SQL queries or Python code
# The Python-related instructions mirror how the generated snippet is executed
# by this app: the DataFrame is exposed as `data`, a fixed set of helper names
# is pre-bound, and the answer is read back from a variable called `result`.
# Do not add literal curly braces to the template; PromptTemplate treats them
# as input variables.
template = """
You are an expert data scientist assistant. Given a natural language question, the name of the table, and a list of valid columns, decide whether to generate a SQL query to retrieve data, perform statistical analysis, or build a simple machine learning model.
Instructions:
- If the question involves data retrieval or simple aggregations, generate a SQL query.
- If the question requires statistical analysis, generate a Python code snippet using pandas and numpy.
- If the question involves predictions or modeling, generate a Python code snippet using scikit-learn.
- Ensure that you only use the columns provided.
- Do not include any import statements in the code.
- In Python code, the table is already loaded as a pandas DataFrame named `data`, and the names `pd`, `np`, `LinearRegression`, `train_test_split`, `mean_squared_error` and `r2_score` are already available.
- In Python code, assign the final answer to a variable named `result`.
- For case-insensitive string comparisons in SQL, use either 'LOWER(column) = LOWER(value)' or 'column = value COLLATE NOCASE', but do not use both together.
- Provide the code between <CODE> and </CODE> tags.
Question: {question}
Table name: {table_name}
Valid columns: {columns}
Response:
"""
prompt = PromptTemplate(
    template=template,
    input_variables=['question', 'table_name', 'columns'],
)
# temperature=0 is passed to keep generated code as repeatable as possible.
llm = LangchainOpenAI(temperature=0)
sql_generation_chain = LLMChain(llm=llm, prompt=prompt)
# Define the callback function | |
def _record_assistant(content):
    """Append one assistant message (text or a result object) to the history."""
    st.session_state.history.append({"role": "assistant", "content": content})


def _run_sql(code):
    """Execute a generated SELECT statement and record the query and its rows."""
    st.write(f"Generated SQL Query:\n{code}")
    try:
        result = pd.read_sql_query(code, conn)
    except Exception as e:
        logging.error(f"An error occurred during SQL execution: {e}")
        _record_assistant(f"Error executing SQL query: {e}")
        return
    _record_assistant(f"Generated SQL Query:\n{code}")
    _record_assistant(result)


def _run_python(code):
    """Execute a generated Python snippet and record the value of `result`."""
    st.write(f"Generated Python Code:\n{code}")
    try:
        # One dict serves as both globals and locals. With separate dicts,
        # functions, lambdas and (before Python 3.12) comprehensions inside
        # the snippet cannot see `data`, `pd`, etc. and fail with NameError.
        namespace = {
            'pd': pd,
            'np': np,
            'data': data.copy(),  # copy so the snippet cannot mutate the app's frame
            'result': None,
            'LinearRegression': LinearRegression,
            'train_test_split': train_test_split,
            'mean_squared_error': mean_squared_error,
            'r2_score': r2_score,
        }
        # SECURITY: this runs model-generated code with full interpreter
        # privileges. It is not sandboxed; do not expose this app to
        # untrusted users without isolating execution.
        exec(code, namespace)
    except Exception as e:
        logging.error(f"An error occurred during code execution: {e}")
        _record_assistant(f"Error executing code: {e}")
        return

    result = namespace.get('result')
    if result is None:
        _record_assistant("Code executed successfully.")
        return
    _record_assistant("Result:")
    _record_assistant(result)


def _answer_prompt(user_prompt):
    """Record the user message, then answer it via a shortcut, SQL or Python."""
    st.session_state.history.append({"role": "user", "content": user_prompt})

    # Shortcut: any message mentioning "columns" is answered without the LLM.
    if "columns" in user_prompt.lower():
        _record_assistant(f"The columns are: {', '.join(valid_columns)}")
        return

    response = sql_generation_chain.run({
        'question': user_prompt,
        'table_name': table_name,
        'columns': ', '.join(valid_columns),
    })

    code = extract_code(response)
    if not code:
        # No tagged code: show the model's raw answer instead.
        _record_assistant(response.strip())
        return

    # Snippets starting with SELECT are treated as SQL; everything else is
    # treated as Python.
    if code.strip().lower().startswith('select'):
        _run_sql(code)
    else:
        _run_python(code)


def process_input():
    """Handle a submitted chat message (the text input's `on_change` callback).

    Reads the message from `st.session_state['user_input']`, appends the
    user message and the assistant's reply (text, SQL result or Python
    result) to `st.session_state.history`, and finally clears the input
    field. Failures are logged and recorded in the history as an error
    message rather than raised.
    """
    user_prompt = st.session_state['user_input']
    if user_prompt:
        try:
            _answer_prompt(user_prompt)
        except Exception as e:
            # UI boundary: report the failure in the chat instead of crashing.
            logging.error(f"An error occurred: {e}")
            _record_assistant(f"Error: {e}")
    # Reset the user_input in session state
    st.session_state['user_input'] = ''
# Display the conversation history
for message in st.session_state.history:
    role = message['role']
    content = message['content']
    if role == 'user':
        st.markdown(f"**User:** {content}")
    elif role == 'assistant':
        if isinstance(content, (pd.DataFrame, pd.Series)):
            # Tabular results, including the Series that pandas aggregations
            # commonly return, are shown as a table. Markdown would collapse
            # their multi-line text form.
            st.markdown("**Assistant:** Here are the results:")
            st.dataframe(content)
        elif isinstance(content, dict):
            st.markdown("**Assistant:** Here are the results:")
            st.json(content)
        else:
            # Plain text and scalar values (int, float, ...) are rendered inline.
            st.markdown(f"**Assistant:** {content}")

# Place the input field at the bottom with the callback
st.text_input("Enter your message:", key='user_input', on_change=process_input)