import os
import streamlit as st
import pandas as pd
import sqlite3
from transformers import pipeline
import sqlparse
import logging
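# logging is imported and used below but never configured; basicConfig (stdlib)
# makes the logging.info/logging.error calls actually emit output
logging.basicConfig(level=logging.INFO)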
# Initialize conversation history
if 'history' not in st.session_state:
    st.session_state.history = []
# Load a smaller, faster pre-trained model (distilgpt2) from Hugging Face
llm = pipeline('text-generation', model='distilgpt2')
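# A minimal sketch (not in the original): Streamlit reruns the whole script on every
# interaction, so wrapping the load in st.cache_resource (a real Streamlit API)
# would avoid reloading the model each time:
#
#   @st.cache_resource
#   def load_llm():
#       return pipeline('text-generation', model='distilgpt2')
#   llm = load_llm()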
# Step 1: Upload CSV data file (or use default)
st.title("Natural Language to SQL Query App with Enhanced Insights")
st.write("Upload a CSV file to get started, or use the default dataset.")
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file is None:
    data = pd.read_csv("default_data.csv")  # Ensure this file exists in your working directory
    st.write("Using default_data.csv file.")
    table_name = "default_table"
else:
    data = pd.read_csv(csv_file)
    table_name = csv_file.name.split('.')[0]
# Use table_name here: csv_file is None on the default path, so csv_file.name would raise
st.write(f"Data Preview ({table_name}):")
st.dataframe(data.head())
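# Hypothetical hardening sketch: the raw file stem can contain characters that are
# unsafe in a SQL table name. Sanitizing it (assuming alphanumeric names suffice):
#
#   import re
#   table_name = re.sub(r'\W+', '_', os.path.splitext(csv_file.name)[0])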
# Step 2: Load CSV data into a persistent SQLite database
db_file = 'my_database.db'
conn = sqlite3.connect(db_file, check_same_thread=False) # Allow connection across threads
data.to_sql(table_name, conn, index=False, if_exists='replace')
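# Note: if_exists='replace' drops and recreates the table on every run, and the
# connection is deliberately left open so process_input() can query it later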
# SQL table metadata (for validation and schema)
valid_columns = list(data.columns)
st.write(f"Valid columns: {valid_columns}")
# Function to generate SQL query using Hugging Face model
def generate_sql_query(question, table_name, columns):
    # Simple, direct prompt focused on producing a single valid SQL statement;
    # the table name was previously accepted but never used in the prompt
    prompt = f"""
You are a SQL expert. Generate a SQL query against the table {table_name},
using only the columns: {columns}.
Question: {question}
Respond only with the SQL query.
"""
    # return_full_text=False keeps only the continuation, not the echoed prompt
    response = llm(prompt, max_new_tokens=50, truncation=True, return_full_text=False)
    return response[0]['generated_text'].strip()
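# sqlparse is imported above but otherwise unused; a minimal sketch of using it to
# normalize generated queries before execution (sqlparse.format is its real API):
#
#   def normalize_sql(query):
#       return sqlparse.format(query, keyword_case='upper', strip_comments=True).strip()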
# Function to generate insights using Hugging Face model
def generate_insights(question, result):
    # Include the question itself; it was previously accepted but never used
    prompt = f"""
Based on the user's question and the SQL query result below, generate concise data insights:
Question: {question}
Result:
{result}
"""
    response = llm(prompt, max_new_tokens=100, truncation=True, return_full_text=False)
    return response[0]['generated_text'].strip()
# Function to classify user query as SQL or Insights
def classify_query(question):
    prompt = f"""
Classify the following question as 'SQL' or 'INSIGHTS':
"{question}"
"""
    # Without return_full_text=False the echoed prompt itself contains 'SQL',
    # which would make every query classify as SQL
    response = llm(prompt, max_new_tokens=10, truncation=True, return_full_text=False)
    category = response[0]['generated_text'].strip().upper()
    return 'SQL' if 'SQL' in category else 'INSIGHTS'
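# Caveat: distilgpt2 is a small generative model, not a trained classifier, so this
# routing is heuristic; the substring check above simply falls back to 'INSIGHTS'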
# Function to generate dataset summary
def generate_dataset_summary(data):
    summary_template = f"""
Provide a brief summary of the dataset:
{data.head().to_string(index=False)}
"""
    response = llm(summary_template, max_new_tokens=100, truncation=True, return_full_text=False)
    return response[0]['generated_text'].strip()
# Function to validate if the generated SQL query is valid
def is_valid_sql(query):
    sql_keywords = ["SELECT", "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER"]
    return any(query.strip().upper().startswith(keyword) for keyword in sql_keywords)
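# A stricter (still heuristic) sketch using sqlparse's parser; Statement.get_type()
# returns 'UNKNOWN' for text that does not parse as a statement. For a read-only app
# it may also be safer to accept only SELECT:
#
#   def is_valid_sql_strict(query):
#       parsed = sqlparse.parse(query)
#       return bool(parsed) and parsed[0].get_type() == 'SELECT'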
# Define the callback function
def process_input():
    user_prompt = st.session_state['user_input']
    if user_prompt:
        try:
            # Append user message to history
            st.session_state.history.append({"role": "user", "content": user_prompt})
            # Classify the user query
            category = classify_query(user_prompt)
            logging.info(f"User query classified as: {category}")
            if "COLUMNS" in user_prompt.upper():
                assistant_response = f"The columns are: {', '.join(valid_columns)}"
                st.session_state.history.append({"role": "assistant", "content": assistant_response})
            elif category == 'SQL':
                columns = ', '.join(valid_columns)
                generated_sql = generate_sql_query(user_prompt, table_name, columns)
                if generated_sql.upper() == "NO_SQL":
                    # Handle cases where no SQL should be generated
                    assistant_response = "Sure, let's discuss some general insights and recommendations based on the data."
                    st.session_state.history.append({"role": "assistant", "content": assistant_response})
                    # Generate a dataset summary, then general insights and recommendations
                    dataset_summary = generate_dataset_summary(data)
                    general_insights = generate_insights(user_prompt, dataset_summary)
                    st.session_state.history.append({"role": "assistant", "content": general_insights})
                else:
                    # Validate the SQL query before executing it
                    if is_valid_sql(generated_sql):
                        # Attempt to execute the SQL query and handle exceptions
                        try:
                            result = pd.read_sql_query(generated_sql, conn)
                            if result.empty:
                                assistant_response = "The query returned no results. Please try a different question."
                                st.session_state.history.append({"role": "assistant", "content": assistant_response})
                            else:
                                # Convert the result to a string for the insights prompt,
                                # limited to the first 10 rows
                                result_str = result.head(10).to_string(index=False)
                                # Generate insights and recommendations based on the query result
                                insights = generate_insights(user_prompt, result_str)
                                st.session_state.history.append({"role": "assistant", "content": insights})
                                # Append the result DataFrame so it renders as a table
                                st.session_state.history.append({"role": "assistant", "content": result})
                        except Exception as e:
                            logging.error(f"An error occurred during SQL execution: {e}")
                            assistant_response = f"Error executing SQL query: {e}"
                            st.session_state.history.append({"role": "assistant", "content": assistant_response})
                    else:
                        # The generated text is not valid SQL; ask the user to rephrase
                        st.session_state.history.append({"role": "assistant", "content": "Generated text is not a valid SQL query. Please try rephrasing your question."})
            else:  # INSIGHTS category
                # Generate a dataset summary, then general insights and recommendations
                dataset_summary = generate_dataset_summary(data)
                general_insights = generate_insights(user_prompt, dataset_summary)
                st.session_state.history.append({"role": "assistant", "content": general_insights})
        except Exception as e:
            logging.error(f"An error occurred: {e}")
            assistant_response = f"Error: {e}"
            st.session_state.history.append({"role": "assistant", "content": assistant_response})
        # Reset the user input so the text box clears after processing
        st.session_state['user_input'] = ''
# Display the conversation history
for message in st.session_state.history:
    if message['role'] == 'user':
        st.markdown(f"**User:** {message['content']}")
    elif message['role'] == 'assistant':
        if isinstance(message['content'], pd.DataFrame):
            st.markdown("**Assistant:** Query Results:")
            st.dataframe(message['content'])
        else:
            st.markdown(f"**Assistant:** {message['content']}")
# Place the input field at the bottom with the callback
st.text_input("Enter your message:", key='user_input', on_change=process_input)
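# Run with `streamlit run app.py` (assuming this script is saved as app.py)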