Spaces:
Sleeping
Sleeping
| import os | |
| import streamlit as st | |
| import pandas as pd | |
| import sqlite3 | |
| from transformers import pipeline | |
| import sqlparse | |
| import logging | |
# --- App setup: session state, model, data source, and SQLite staging ---

# Initialize conversation history (survives Streamlit reruns).
if 'history' not in st.session_state:
    st.session_state.history = []

# Load a smaller and faster pre-trained model (distilgpt2) from Hugging Face.
llm = pipeline('text-generation', model='distilgpt2')  # Using a smaller model for faster inference

# Step 1: Upload CSV data file (or use default)
st.title("Natural Language to SQL Query App with Enhanced Insights")
st.write("Upload a CSV file to get started, or use the default dataset.")
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file is None:
    data = pd.read_csv("default_data.csv")  # Ensure this file exists in your working directory
    st.write("Using default_data.csv file.")
    table_name = "default_table"
else:
    data = pd.read_csv(csv_file)
    # os.path.splitext keeps dotted base names intact ("sales.2024.csv" ->
    # "sales.2024"); split('.')[0] would wrongly truncate them to "sales".
    table_name = os.path.splitext(csv_file.name)[0]
    # Preview belongs in this branch: csv_file.name is only valid when a
    # file was actually uploaded.
    st.write(f"Data Preview ({csv_file.name}):")
    st.dataframe(data.head())

# Step 2: Load CSV data into a persistent SQLite database.
db_file = 'my_database.db'
# check_same_thread=False: Streamlit callbacks may run on a different thread
# than the one that opened the connection.
conn = sqlite3.connect(db_file, check_same_thread=False)
data.to_sql(table_name, conn, index=False, if_exists='replace')

# SQL table metadata (for validation and schema).
valid_columns = list(data.columns)
st.write(f"Valid columns: {valid_columns}")
# Function to generate SQL query using Hugging Face model
def generate_sql_query(question, table_name, columns):
    """Ask the LLM for a SQL query answering *question*.

    Args:
        question: The user's natural-language question.
        table_name: Name of the SQLite table to query.
        columns: Comma-separated column names available in the table.

    Returns:
        The model's generated text, stripped of surrounding whitespace.
    """
    # Bug fix: table_name was accepted but never placed in the prompt, so the
    # model had no way to know which table to reference in the query.
    prompt = f"""
You are a SQL expert. Generate a SQL query for the table "{table_name}" using the columns:
{columns}.
Question: {question}
Respond only with the SQL query.
"""
    # return_full_text=False: text-generation pipelines otherwise echo the
    # prompt at the start of generated_text, so the "SQL" would begin with
    # the instruction text and never validate.
    response = llm(prompt, max_new_tokens=50, truncation=True, return_full_text=False)
    return response[0]['generated_text'].strip()
# Function to generate insights using Hugging Face model
def generate_insights(question, result):
    """Generate concise insights for *result* in the context of *question*.

    Args:
        question: The user's original natural-language question.
        result: String rendering of the SQL result (or dataset summary).

    Returns:
        The model's generated insight text, stripped of whitespace.
    """
    # Bug fix: the question parameter was never interpolated into the prompt,
    # even though the prompt claims to be "Based on the user's question".
    prompt = f"""
Based on the user's question and the SQL query result below, generate concise data insights:
Question: {question}
Result:
{result}
"""
    # return_full_text=False strips the echoed prompt from the output.
    response = llm(prompt, max_new_tokens=100, truncation=True, return_full_text=False)
    return response[0]['generated_text'].strip()
# Function to classify user query as SQL or Insights
def classify_query(question):
    """Classify *question* as 'SQL' (answerable by a query) or 'INSIGHTS'.

    Returns:
        'SQL' if the model's continuation mentions SQL, else 'INSIGHTS'.
    """
    prompt = f"""
Classify the following question as 'SQL' or 'INSIGHTS':
"{question}"
"""
    # return_full_text=False is essential here: without it the echoed prompt
    # itself contains the word 'SQL', so the 'SQL' in category test below
    # succeeded for EVERY question and nothing was ever routed to INSIGHTS.
    response = llm(prompt, max_new_tokens=10, truncation=True, return_full_text=False)
    category = response[0]['generated_text'].strip().upper()
    return 'SQL' if 'SQL' in category else 'INSIGHTS'
# Function to generate dataset summary
def generate_dataset_summary(data):
    """Ask the LLM for a brief summary of *data* based on its first rows.

    Args:
        data: pandas DataFrame to summarize (only head() is shown to the model).

    Returns:
        The model's generated summary text, stripped of whitespace.
    """
    summary_template = f"""
Provide a brief summary of the dataset:
{data.head().to_string(index=False)}
"""
    # return_full_text=False: otherwise the "summary" shown to the user would
    # begin with the echoed instruction and raw table dump.
    response = llm(summary_template, max_new_tokens=100, truncation=True, return_full_text=False)
    return response[0]['generated_text'].strip()
# Function to validate if the generated SQL query is valid
def is_valid_sql(query):
    """Cheap syntactic check: does *query* begin with a SQL statement keyword?

    Case-insensitive and tolerant of leading whitespace. Generalized to also
    accept CTE queries starting with WITH, which the model may produce.

    Args:
        query: Candidate SQL text generated by the model.

    Returns:
        True if the text starts with a recognized SQL keyword.
    """
    # str.startswith accepts a tuple of prefixes — one call replaces the
    # any(...) generator over keywords.
    sql_keywords = ("SELECT", "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER", "WITH")
    return query.strip().upper().startswith(sql_keywords)
# Define the callback function
def process_input() -> None:
    """Streamlit on_change callback for the chat input box.

    Routes the user's message through one of three paths and appends all
    user/assistant turns to st.session_state.history:
      1. Question mentions "columns"  -> list the table's columns directly.
      2. Classified 'SQL'             -> generate SQL, run it, show insights
                                         plus the result DataFrame.
      3. Classified 'INSIGHTS'        -> summarize the dataset and generate
                                         general insights.

    Relies on file-level globals: valid_columns, table_name, data, conn, and
    the helper functions classify_query / generate_sql_query /
    generate_insights / generate_dataset_summary / is_valid_sql.
    Always clears the input widget's value at the end.
    """
    user_prompt = st.session_state['user_input']
    if user_prompt:
        try:
            # Append user message to history first so it renders even if a
            # later step fails.
            st.session_state.history.append({"role": "user", "content": user_prompt})
            # Classify the user query as 'SQL' or 'INSIGHTS'
            category = classify_query(user_prompt)
            logging.info(f"User query classified as: {category}")
            # Shortcut: schema questions are answered without the model.
            if "COLUMNS" in user_prompt.upper():
                assistant_response = f"The columns are: {', '.join(valid_columns)}"
                st.session_state.history.append({"role": "assistant", "content": assistant_response})
            elif category == 'SQL':
                columns = ', '.join(valid_columns)
                generated_sql = generate_sql_query(user_prompt, table_name, columns)
                if generated_sql.upper() == "NO_SQL":
                    # Handle cases where no SQL should be generated: fall back
                    # to dataset-level insights instead.
                    assistant_response = "Sure, let's discuss some general insights and recommendations based on the data."
                    # Generate dataset summary
                    dataset_summary = generate_dataset_summary(data)
                    # Generate general insights and recommendations
                    general_insights = generate_insights(user_prompt, dataset_summary)
                    # Append the assistant's insights to the history
                    st.session_state.history.append({"role": "assistant", "content": general_insights})
                else:
                    # Validate the SQL query before touching the database.
                    if is_valid_sql(generated_sql):
                        # Attempt to execute SQL query and handle exceptions
                        try:
                            result = pd.read_sql_query(generated_sql, conn)
                            if result.empty:
                                assistant_response = "The query returned no results. Please try a different question."
                                st.session_state.history.append({"role": "assistant", "content": assistant_response})
                            else:
                                # Convert the result to a string for the insights
                                # prompt; limit to first 10 rows to keep the
                                # prompt short.
                                result_str = result.head(10).to_string(index=False)
                                # Generate insights and recommendations based on
                                # the query result.
                                insights = generate_insights(user_prompt, result_str)
                                # Append the assistant's insights to the history
                                st.session_state.history.append({"role": "assistant", "content": insights})
                                # Append the result DataFrame itself; the render
                                # loop displays DataFrames with st.dataframe.
                                st.session_state.history.append({"role": "assistant", "content": result})
                        except Exception as e:
                            logging.error(f"An error occurred during SQL execution: {e}")
                            assistant_response = f"Error executing SQL query: {e}"
                            st.session_state.history.append({"role": "assistant", "content": assistant_response})
                    else:
                        # If generated text is not valid SQL, provide feedback to the user
                        st.session_state.history.append({"role": "assistant", "content": "Generated text is not a valid SQL query. Please try rephrasing your question."})
            else:  # INSIGHTS category
                # Generate dataset summary
                dataset_summary = generate_dataset_summary(data)
                # Generate general insights and recommendations
                general_insights = generate_insights(user_prompt, dataset_summary)
                # Append the assistant's insights to the history
                st.session_state.history.append({"role": "assistant", "content": general_insights})
        except Exception as e:
            # Top-level boundary: surface any unexpected failure in the chat
            # rather than crashing the Streamlit script.
            logging.error(f"An error occurred: {e}")
            assistant_response = f"Error: {e}"
            st.session_state.history.append({"role": "assistant", "content": assistant_response})
        # Reset the user_input in session state so the box is cleared for the
        # next message.
        st.session_state['user_input'] = ''
# Render the running conversation, oldest turn first, then the input widget.
for entry in st.session_state.history:
    role = entry['role']
    content = entry['content']
    if role == 'user':
        st.markdown(f"**User:** {content}")
    elif role == 'assistant':
        # DataFrame results get a table widget; everything else is markdown.
        if isinstance(content, pd.DataFrame):
            st.markdown("**Assistant:** Query Results:")
            st.dataframe(content)
        else:
            st.markdown(f"**Assistant:** {content}")

# Input field stays at the bottom; process_input fires on every submit.
st.text_input("Enter your message:", key='user_input', on_change=process_input)