File size: 8,151 Bytes
e37eda0
3eb59a4
 
75829f5
bb31796
6a2a63a
e37eda0
82bfc51
 
 
 
 
8328a6d
 
b9a3a14
cd60664
ec5af14
0bb1965
 
cd60664
 
0bb1965
82bfc51
5189e45
9e9d1c1
cd60664
5189e45
cd60664
 
 
82bfc51
 
1c40c30
cd60664
 
 
 
82bfc51
 
bb31796
 
45afb27
bb31796
45afb27
1c40c30
bb31796
45afb27
bb31796
1c40c30
bb31796
 
 
 
 
1c40c30
bb31796
 
1c40c30
bb31796
1d00adc
bb31796
 
 
1c40c30
 
bb31796
1c40c30
bb31796
 
1d00adc
bb31796
 
 
1c40c30
bb31796
 
1c40c30
bb31796
1d00adc
1c40c30
 
 
 
0bb1965
82bfc51
 
 
 
 
 
 
 
 
1d00adc
 
 
 
 
82bfc51
 
1d00adc
82bfc51
bb31796
82bfc51
1d00adc
ec5af14
 
 
 
 
 
 
bb31796
ec5af14
 
 
1d00adc
1c40c30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d00adc
1c40c30
6a9b2eb
 
1d00adc
ec5af14
 
82bfc51
ec5af14
bb31796
e69e246
ec5af14
 
1d00adc
82bfc51
 
 
 
 
 
 
a3c9c61
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import os
import streamlit as st
import pandas as pd
import sqlite3
from transformers import pipeline
import sqlparse
import logging

# Initialize conversation history
# Streamlit re-runs the whole script on every interaction; session_state
# persists the chat transcript across those re-runs.
if 'history' not in st.session_state:
    st.session_state.history = []

# Load a smaller and faster pre-trained model (distilgpt2) from Hugging Face
# NOTE(review): this runs on every script re-run — consider @st.cache_resource
# to avoid reloading the model on each interaction.
llm = pipeline('text-generation', model='distilgpt2')  # Using a smaller model for faster inference

# Step 1: Upload CSV data file (or use default)
st.title("Natural Language to SQL Query App with Enhanced Insights")
st.write("Upload a CSV file to get started, or use the default dataset.")

csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file is None:
    # Fallback dataset; pd.read_csv raises FileNotFoundError if it is absent.
    data = pd.read_csv("default_data.csv")  # Ensure this file exists in your working directory
    st.write("Using default_data.csv file.")
    table_name = "default_table"
else:
    data = pd.read_csv(csv_file)
    # Table name is derived from the filename up to the first dot
    # (e.g. "sales.data.csv" -> "sales").
    table_name = csv_file.name.split('.')[0]
    st.write(f"Data Preview ({csv_file.name}):")
    st.dataframe(data.head())

# Step 2: Load CSV data into a persistent SQLite database
# The connection is module-level and shared by process_input(); it is never
# explicitly closed — it lives for the process lifetime.
db_file = 'my_database.db'
conn = sqlite3.connect(db_file, check_same_thread=False)  # Allow connection across threads
data.to_sql(table_name, conn, index=False, if_exists='replace')

# SQL table metadata (for validation and schema)
valid_columns = list(data.columns)
st.write(f"Valid columns: {valid_columns}")

# Function to generate SQL query using Hugging Face model
def generate_sql_query(question, table_name, columns):
    """Generate a SQL query for *question* against *table_name*.

    Args:
        question: Natural-language question from the user.
        table_name: Name of the SQLite table to query (now actually included
            in the prompt — previously this parameter was accepted but unused,
            so the model never knew which table to target).
        columns: Comma-separated column names available in the table.

    Returns:
        The model's continuation, stripped of whitespace. The echoed prompt is
        removed: HF text-generation pipelines return prompt + continuation in
        'generated_text', so returning it verbatim embedded the instruction
        text in the "SQL".
    """
    # Simplified and direct prompt to focus on generating valid SQL
    prompt = f"""
    You are a SQL expert. Generate a SQL query for the table '{table_name}' using the columns:
    {columns}.
    Question: {question}
    Respond only with the SQL query.
    """
    response = llm(prompt, max_new_tokens=50, truncation=True)  # Ensure max tokens are reasonable
    generated = response[0]['generated_text']
    # Keep only the model's continuation, not the echoed prompt.
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    return generated.strip()

# Function to generate insights using Hugging Face model
def generate_insights(question, result):
    """Generate concise insights about *result* in the context of *question*.

    Args:
        question: The user's original question (now included in the prompt —
            previously this parameter was accepted but never used, so insights
            ignored what the user actually asked).
        result: String rendering of the SQL result or dataset summary.

    Returns:
        The model's continuation, stripped; the echoed prompt is removed so the
        instruction text is not presented to the user as "insights".
    """
    prompt = f"""
    Based on the user's question and the SQL query result below, generate concise data insights:
    Question: {question}
    Result:
    {result}
    """
    response = llm(prompt, max_new_tokens=100, truncation=True)
    generated = response[0]['generated_text']
    # HF text-generation pipelines echo the prompt; drop it from the output.
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    return generated.strip()

# Function to classify user query as SQL or Insights
def classify_query(question):
    """Classify *question* as 'SQL' or 'INSIGHTS' using the LLM.

    BUG FIX: HF text-generation pipelines return prompt + continuation in
    'generated_text'. The prompt itself contains the literal word "SQL", so
    checking the full text made `'SQL' in category` always true — every query
    was classified as SQL. We now strip the echoed prompt and inspect only the
    model's continuation.

    Returns:
        'SQL' if the model's continuation mentions SQL, otherwise 'INSIGHTS'.
    """
    prompt = f"""
    Classify the following question as 'SQL' or 'INSIGHTS':
    "{question}"
    """
    response = llm(prompt, max_new_tokens=10, truncation=True)
    generated = response[0]['generated_text']
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    category = generated.strip().upper()
    return 'SQL' if 'SQL' in category else 'INSIGHTS'

# Function to generate dataset summary
def generate_dataset_summary(data):
    """Ask the LLM for a brief summary of *data* (a pandas DataFrame).

    Only the first few rows (data.head()) are shown to the model. The echoed
    prompt is stripped from the response: previously the returned "summary"
    always began with the instruction text and the raw data head, because the
    pipeline returns prompt + continuation.
    """
    summary_template = f"""
    Provide a brief summary of the dataset:
    {data.head().to_string(index=False)}
    """
    response = llm(summary_template, max_new_tokens=100, truncation=True)
    generated = response[0]['generated_text']
    # Keep only the model's continuation.
    if generated.startswith(summary_template):
        generated = generated[len(summary_template):]
    return generated.strip()

# Function to validate if the generated SQL query is valid
def is_valid_sql(query):
    """Cheap sanity check: does *query* start with a known SQL keyword?

    This is not a full parse — it only guards against the model returning
    prose instead of SQL. "WITH" is included so CTE queries
    (``WITH cte AS (...) SELECT ...``), which SQLite executes fine, are not
    rejected. Matching is case-insensitive and ignores leading whitespace.

    Returns:
        True if the trimmed, upper-cased query begins with a SQL keyword.
    """
    sql_keywords = ["SELECT", "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER", "WITH"]
    return any(query.strip().upper().startswith(keyword) for keyword in sql_keywords)

# Define the callback function
def process_input():
    """Streamlit on_change callback: process one user message.

    Reads the text from st.session_state['user_input'], classifies it as a
    SQL request or a general-insights request, runs the matching pipeline
    (SQL generation + execution against the module-level `conn`, or dataset
    summary + insights), and appends user/assistant turns to
    st.session_state.history. Always clears the input box afterwards.

    Errors from SQL execution or the model are caught, logged, and surfaced
    to the user as assistant messages rather than crashing the app.
    """
    user_prompt = st.session_state['user_input']

    if user_prompt:
        try:
            # Append user message to history
            st.session_state.history.append({"role": "user", "content": user_prompt})

            # Classify the user query
            category = classify_query(user_prompt)
            logging.info(f"User query classified as: {category}")

            if "COLUMNS" in user_prompt.upper():
                # Shortcut: the user asked which columns exist — answer
                # directly without invoking the model.
                assistant_response = f"The columns are: {', '.join(valid_columns)}"
                st.session_state.history.append({"role": "assistant", "content": assistant_response})
            elif category == 'SQL':
                columns = ', '.join(valid_columns)
                generated_sql = generate_sql_query(user_prompt, table_name, columns)

                if generated_sql.upper() == "NO_SQL":
                    # Handle cases where no SQL should be generated
                    assistant_response = "Sure, let's discuss some general insights and recommendations based on the data."
                    # BUG FIX: this intro message was assigned but never
                    # appended to the history, so it was silently dropped.
                    st.session_state.history.append({"role": "assistant", "content": assistant_response})

                    # Generate dataset summary
                    dataset_summary = generate_dataset_summary(data)

                    # Generate general insights and recommendations
                    general_insights = generate_insights(user_prompt, dataset_summary)

                    # Append the assistant's insights to the history
                    st.session_state.history.append({"role": "assistant", "content": general_insights})
                else:
                    # Validate the SQL query
                    if is_valid_sql(generated_sql):
                        # Attempt to execute SQL query and handle exceptions
                        try:
                            result = pd.read_sql_query(generated_sql, conn)

                            if result.empty:
                                assistant_response = "The query returned no results. Please try a different question."
                                st.session_state.history.append({"role": "assistant", "content": assistant_response})
                            else:
                                # Convert the result to a string for the insights prompt
                                result_str = result.head(10).to_string(index=False)  # Limit to first 10 rows

                                # Generate insights and recommendations based on the query result
                                insights = generate_insights(user_prompt, result_str)

                                # Append the assistant's insights to the history
                                st.session_state.history.append({"role": "assistant", "content": insights})
                                # Append the result DataFrame to the history
                                # (the renderer shows DataFrames with st.dataframe).
                                st.session_state.history.append({"role": "assistant", "content": result})
                        except Exception as e:
                            logging.error(f"An error occurred during SQL execution: {e}")
                            assistant_response = f"Error executing SQL query: {e}"
                            st.session_state.history.append({"role": "assistant", "content": assistant_response})
                    else:
                        # If generated text is not valid SQL, provide feedback to the user
                        st.session_state.history.append({"role": "assistant", "content": "Generated text is not a valid SQL query. Please try rephrasing your question."})
            else:  # INSIGHTS category
                # Generate dataset summary
                dataset_summary = generate_dataset_summary(data)

                # Generate general insights and recommendations
                general_insights = generate_insights(user_prompt, dataset_summary)

                # Append the assistant's insights to the history
                st.session_state.history.append({"role": "assistant", "content": general_insights})

        except Exception as e:
            # Top-level guard: report any unexpected failure in the chat
            # instead of letting the callback crash the app.
            logging.error(f"An error occurred: {e}")
            assistant_response = f"Error: {e}"
            st.session_state.history.append({"role": "assistant", "content": assistant_response})

        # Reset the user_input in session state
        st.session_state['user_input'] = ''

# Display the conversation history
# Each entry is {"role": "user"|"assistant", "content": str | pd.DataFrame};
# DataFrames (query results) are rendered as tables, everything else as text.
for message in st.session_state.history:
    if message['role'] == 'user':
        st.markdown(f"**User:** {message['content']}")
    elif message['role'] == 'assistant':
        if isinstance(message['content'], pd.DataFrame):
            st.markdown("**Assistant:** Query Results:")
            st.dataframe(message['content'])
        else:
            st.markdown(f"**Assistant:** {message['content']}")

# Place the input field at the bottom with the callback
# on_change fires process_input when the user submits; the callback clears
# st.session_state['user_input'] itself after handling the message.
st.text_input("Enter your message:", key='user_input', on_change=process_input)