# TextToSQL / app.py
# thechaiexperiment's picture
# Update app.py
# 08e4afd verified
# raw
# history blame
# 4.41 kB
import gradio as gr
import openai
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
# OpenRouter credentials and model selection.
# The API key is read from the OPENROUTER_API_KEY environment variable (for
# example a Hugging Face Space secret) so that no credential lives in the
# source file.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
OPENROUTER_MODEL = "sophosympatheia/rogue-rose-103b-v0.2:free"
if not OPENROUTER_API_KEY:
    # Keep the app importable, but make the misconfiguration visible in logs.
    print("Warning: OPENROUTER_API_KEY is not set; requests to OpenRouter will not authenticate.")

# Hugging Face Space path
DB_PATH = "ecommerce.db"

# Ensure dataset exists: download it on first start if the file is missing.
if not os.path.exists(DB_PATH):
    # TODO: replace the placeholder URL with the actual dataset link.
    os.system("wget https://your-dataset-link.com/ecommerce.db -O ecommerce.db")

# Initialize an OpenAI-compatible client pointed at the OpenRouter endpoint.
openai_client = openai.OpenAI(api_key=OPENROUTER_API_KEY, base_url="https://openrouter.ai/api/v1")
# Few-shot examples for text-to-SQL
# Each entry pairs a natural-language request ("input") with the SQL
# statement ("output") shown to the model as a worked example.
few_shot_examples = [
    {
        "input": "Show all customers from São Paulo.",
        "output": "SELECT * FROM customers WHERE customer_state = 'SP';",
    },
    {
        "input": "Find the total sales per product.",
        "output": "SELECT product_id, SUM(price) FROM order_items GROUP BY product_id;",
    },
    {
        "input": "List all orders placed in 2017.",
        "output": "SELECT * FROM orders WHERE order_purchase_timestamp LIKE '2017%';",
    },
]
# Function: Convert text to SQL
def _extract_sql(raw):
    """Return the first usable line of a raw model reply.

    Blank lines and Markdown code-fence lines (lines starting with three
    backticks) are skipped, so a reply that wraps the statement in a fence
    yields the statement rather than the fence marker. A ``None`` or empty
    reply yields an empty string.
    """
    for line in (raw or "").splitlines():
        candidate = line.strip()
        if candidate and not candidate.startswith("```"):
            return candidate
    return ""


def text_to_sql(query):
    """Translate a natural-language request into a single-line SQL statement.

    Builds a few-shot prompt from ``few_shot_examples``, sends it to the model
    named by ``OPENROUTER_MODEL`` through ``openai_client``, and keeps only
    the first usable line of the reply.

    Args:
        query: The user's request in natural language.

    Returns:
        The SQL text on success. On any failure (request error, empty reply)
        a string starting with ``"Error: "`` is returned instead of raising.
    """
    examples = "".join(
        f"Input: {example['input']}\nOutput: {example['output']}\n\n"
        for example in few_shot_examples
    )
    prompt = (
        "Convert the following queries into SQL:\n\n"
        f"{examples}"
        f"Input: {query}\nOutput:"
    )
    try:
        response = openai_client.chat.completions.create(
            model=OPENROUTER_MODEL,
            messages=[
                {"role": "system", "content": "You are an SQL expert."},
                {"role": "user", "content": prompt},
            ],
        )
        # Only one query is wanted; any text after the first usable line is dropped.
        sql_query = _extract_sql(response.choices[0].message.content)
    except Exception as e:
        # Callers expect an error string rather than an exception.
        return f"Error: {e}"
    if not sql_query:
        return "Error: the model returned an empty response."
    return sql_query
# Function: Execute SQL on SQLite database
def execute_sql(sql_query):
    """Run a SQL statement against the SQLite database at ``DB_PATH``.

    Args:
        sql_query: The SQL text to execute.

    Returns:
        A ``pandas.DataFrame`` holding the result set on success, or a string
        starting with ``"SQL Execution Error: "`` if connecting or executing
        fails.
    """
    # NOTE(review): sql_query is model-generated text and is executed as-is.
    try:
        conn = sqlite3.connect(DB_PATH)
        try:
            return pd.read_sql_query(sql_query, conn)
        finally:
            # Close the connection even when the query raises, so failed
            # requests do not leak database handles.
            conn.close()
    except Exception as e:
        return f"SQL Execution Error: {e}"
# Function: Generate Dynamic Visualization
def visualize_data(df):
    """Render a chart for a query result and return the image path.

    The chart type depends on how many numeric columns ``df`` has:

    * exactly one  -> histogram (with KDE) of that column;
    * exactly two  -> scatter plot of the first against the second;
    * three or more, fewer than 10 rows -> pie chart of the first numeric
      column, labelled by the first column of ``df``;
    * three or more, 10 rows or more -> bar chart of the first numeric column
      against the first column of ``df``.

    Args:
        df: The query result as a ``pandas.DataFrame``.

    Returns:
        ``"chart.png"`` (the file the chart was saved to), or ``None`` when
        ``df`` is empty, has fewer than two columns, or has no numeric column.
    """
    if df.empty or df.shape[1] < 2:
        return None

    # Detect numeric columns
    numeric_cols = df.select_dtypes(include=['number']).columns
    if len(numeric_cols) < 1:
        return None

    fig = plt.figure(figsize=(6, 4))
    try:
        sns.set_theme(style="darkgrid")

        # Choose visualization type dynamically
        if len(numeric_cols) == 1:  # Single numeric column, assume it's a count metric
            sns.histplot(df[numeric_cols[0]], bins=10, kde=True, color="teal")
            plt.title(f"Distribution of {numeric_cols[0]}")
        elif len(numeric_cols) == 2:  # Two numeric columns, assume X-Y plot
            sns.scatterplot(x=df[numeric_cols[0]], y=df[numeric_cols[1]], color="blue")
            plt.title(f"{numeric_cols[0]} vs {numeric_cols[1]}")
        elif df.shape[0] < 10:  # If rows are few, prefer pie chart
            plt.pie(df[numeric_cols[0]], labels=df.iloc[:, 0], autopct='%1.1f%%', colors=sns.color_palette("pastel"))
            plt.title(f"Proportion of {numeric_cols[0]}")
        else:  # Default: Bar chart for categories + values
            sns.barplot(x=df.iloc[:, 0], y=df[numeric_cols[0]], palette="coolwarm")
            plt.xticks(rotation=45)
            plt.title(f"{df.columns[0]} vs {numeric_cols[0]}")

        plt.tight_layout()
        plt.savefig("chart.png")
    finally:
        # pyplot keeps a figure registered until it is explicitly closed;
        # without this, every request would leave another figure in memory.
        plt.close(fig)
    return "chart.png"
# Gradio UI
def gradio_ui(query):
    """Handle one UI request: generate SQL, run it, and chart the result.

    Args:
        query: The natural-language request typed by the user.

    Returns:
        A 3-tuple ``(sql_text, result_text, chart_path)``:
        the generated SQL (or the generation error message), the result table
        rendered as text (or an error message), and the chart image path or
        ``None`` when no chart was produced.
    """
    sql_query = text_to_sql(query)
    # text_to_sql reports failures as an "Error: ..." string. Executing that
    # text as SQL would only bury the real cause under a syntax error.
    if sql_query.startswith("Error:"):
        return sql_query, "SQL generation failed; no query was executed.", None

    results = execute_sql(sql_query)
    if not isinstance(results, pd.DataFrame):
        # execute_sql returned its error string; show it as the result text.
        return sql_query, results, None

    visualization = visualize_data(results)
    return sql_query, results.to_string(index=False), visualization
# Page layout: one text input, a trigger button, and three outputs
# (generated SQL, result table as text, chart image).
with gr.Blocks() as demo:
    gr.Markdown("## SQL Explorer: Text-to-SQL with Real Execution & Visualization")
    query_input = gr.Textbox(label="Enter your query", placeholder="e.g., Show all products sold in 2018.")
    submit_btn = gr.Button("Convert & Execute")
    sql_output = gr.Textbox(label="Generated SQL Query")
    table_output = gr.Textbox(label="Query Results")
    chart_output = gr.Image(label="Data Visualization")
    # gradio_ui returns (sql, results, chart), matching the outputs order.
    submit_btn.click(gradio_ui, inputs=[query_input], outputs=[sql_output, table_output, chart_output])

# Launch only when executed as a script, so importing this module does not
# start a web server as a side effect.
if __name__ == "__main__":
    demo.launch()