# Spaces: Running
import os
import sqlite3
from contextlib import closing
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import openai
import pandas as pd
import seaborn as sns
# OpenRouter API key. It is read from the environment so that no secret lives
# in the source; configure OPENROUTER_API_KEY as a secret of the deployment.
# NOTE(review): a literal key used to be committed on this line -- treat that
# key as leaked and rotate it.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
if not OPENROUTER_API_KEY:
    # Keep the app bootable (the UI still loads), but make the cause of the
    # upcoming API failures obvious in the logs.
    print("WARNING: OPENROUTER_API_KEY is not set; text-to-SQL requests will fail.")
OPENROUTER_MODEL = "sophosympatheia/rogue-rose-103b-v0.2:free"

# Hugging Face Space path
DB_PATH = "ecommerce.db"

# Ensure dataset exists
if not os.path.exists(DB_PATH):
    os.system("wget https://your-dataset-link.com/ecommerce.db -O ecommerce.db")  # Replace with actual dataset link

# Initialize OpenAI client (pointed at the OpenRouter endpoint)
openai_client = openai.OpenAI(api_key=OPENROUTER_API_KEY, base_url="https://openrouter.ai/api/v1")
# Few-shot examples for text-to-SQL: (natural-language request, expected SQL)
# pairs, expanded into the {"input": ..., "output": ...} records that the
# prompt builder iterates over.
few_shot_examples = [
    {"input": question, "output": sql}
    for question, sql in (
        (
            "Show all customers from São Paulo.",
            "SELECT * FROM customers WHERE customer_state = 'SP';",
        ),
        (
            "Find the total sales per product.",
            "SELECT product_id, SUM(price) FROM order_items GROUP BY product_id;",
        ),
        (
            "List all orders placed in 2017.",
            "SELECT * FROM orders WHERE order_purchase_timestamp LIKE '2017%';",
        ),
    )
]
# Function: Convert text to SQL
def _extract_sql(reply):
    """Return the first SQL statement contained in a raw model reply.

    Handles replies wrapped in a markdown code fence and statements that span
    several lines. When no terminating semicolon is found, falls back to the
    first line of the reply.
    """
    text = reply.strip()

    # Unwrap a fenced block, dropping an optional language tag on its
    # first line.
    if "```" in text:
        fenced = text.split("```")[1]
        tag, newline, remainder = fenced.partition("\n")
        if newline and tag.strip().lower() in ("sql", "sqlite", "sqlite3"):
            fenced = remainder
        text = fenced.strip()

    # Cut at the first semicolon that is not inside a quoted literal, so a
    # multi-line statement is kept whole and trailing chatter is dropped.
    quote = None
    for index, char in enumerate(text):
        if quote is not None:
            if char == quote:
                quote = None
        elif char in "'\"":
            quote = char
        elif char == ";":
            return text[: index + 1].strip()

    # No terminator found: keep only the first line.
    return text.split("\n")[0].strip()


def text_to_sql(query):
    """Translate a natural-language request into a SQL statement.

    Builds a few-shot prompt from ``few_shot_examples``, sends it to the
    configured chat model and extracts a single statement from the reply.

    Args:
        query: The user's request in plain language.

    Returns:
        The generated SQL string, or ``"Error: <message>"`` when the API call
        or the reply handling raises.
    """
    shots = "".join(
        f"Input: {example['input']}\nOutput: {example['output']}\n\n"
        for example in few_shot_examples
    )
    prompt = f"Convert the following queries into SQL:\n\n{shots}Input: {query}\nOutput:"

    try:
        response = openai_client.chat.completions.create(
            model=OPENROUTER_MODEL,
            messages=[
                {"role": "system", "content": "You are an SQL expert."},
                {"role": "user", "content": prompt},
            ],
        )
        return _extract_sql(response.choices[0].message.content)
    except Exception as e:
        # Errors are reported as text because the caller shows the return
        # value directly in the UI.
        return f"Error: {e}"
# Function: Execute SQL on SQLite database
def execute_sql(sql_query):
    """Run ``sql_query`` against the SQLite file at ``DB_PATH``.

    The SQL comes from a language model, so the database is opened read-only:
    a generated statement that tries to modify or drop data fails instead of
    altering the dataset. The connection is closed on every path, including
    when the query raises.

    Args:
        sql_query: SQL text to execute.

    Returns:
        A ``pandas.DataFrame`` with the result set, or the string
        ``"SQL Execution Error: <message>"`` if anything goes wrong.
    """
    try:
        # as_uri() percent-encodes the absolute path, so unusual characters in
        # DB_PATH cannot be misread as URI syntax.
        db_uri = f"{Path(DB_PATH).resolve().as_uri()}?mode=ro"
        with closing(sqlite3.connect(db_uri, uri=True)) as conn:
            return pd.read_sql_query(sql_query, conn)
    except Exception as e:
        return f"SQL Execution Error: {e}"
# Function: Generate Dynamic Visualization
def visualize_data(df):
    """Render a chart for a query result and return the image path.

    The chart type is chosen from the number of numeric columns and rows:
    one numeric column -> histogram, two -> scatter plot, more than two ->
    pie chart for fewer than 10 rows, otherwise a bar chart of the first
    column against the first numeric column.

    Args:
        df: Query result as a ``pandas.DataFrame``.

    Returns:
        ``"chart.png"`` once the figure has been written, or ``None`` when the
        frame is empty, has fewer than two columns, or has no numeric column.
    """
    if df.empty or df.shape[1] < 2:
        return None

    # Detect numeric columns
    numeric_cols = df.select_dtypes(include=['number']).columns
    if len(numeric_cols) < 1:
        return None

    fig = plt.figure(figsize=(6, 4))
    try:
        sns.set_theme(style="darkgrid")

        # Choose visualization type dynamically
        if len(numeric_cols) == 1:  # Single numeric column, assume it's a count metric
            sns.histplot(df[numeric_cols[0]], bins=10, kde=True, color="teal")
            plt.title(f"Distribution of {numeric_cols[0]}")
        elif len(numeric_cols) == 2:  # Two numeric columns, assume X-Y plot
            sns.scatterplot(x=df[numeric_cols[0]], y=df[numeric_cols[1]], color="blue")
            plt.title(f"{numeric_cols[0]} vs {numeric_cols[1]}")
        elif df.shape[0] < 10:  # If rows are few, prefer pie chart
            plt.pie(df[numeric_cols[0]], labels=df.iloc[:, 0], autopct='%1.1f%%', colors=sns.color_palette("pastel"))
            plt.title(f"Proportion of {numeric_cols[0]}")
        else:  # Default: Bar chart for categories + values
            sns.barplot(x=df.iloc[:, 0], y=df[numeric_cols[0]], palette="coolwarm")
            plt.xticks(rotation=45)
            plt.title(f"{df.columns[0]} vs {numeric_cols[0]}")

        plt.tight_layout()
        plt.savefig("chart.png")
    finally:
        # pyplot keeps every figure alive until it is closed; without this a
        # new figure would accumulate on each request.
        plt.close(fig)
    return "chart.png"
# Gradio UI
def gradio_ui(query):
    """Handle one UI request: generate SQL, execute it, and chart the result.

    Args:
        query: The natural-language request typed by the user.

    Returns:
        A 3-tuple ``(sql_text, results_text, chart_path)`` matching the three
        output components: the generated SQL (or the generation error), the
        result table rendered as text (or an error/status message), and the
        chart image path or ``None``.
    """
    sql_query = text_to_sql(query)

    # text_to_sql reports failures as an "Error: ..." string. Running that
    # string as SQL would only add a misleading syntax error, so stop here.
    if sql_query.startswith("Error:"):
        return sql_query, "Query was not executed because SQL generation failed.", None

    results = execute_sql(sql_query)
    if not isinstance(results, pd.DataFrame):
        # execute_sql returned its error message instead of a frame.
        return sql_query, results, None

    visualization = visualize_data(results)
    return sql_query, results.to_string(index=False), visualization
# Page layout: one input box and button, followed by the three outputs that
# gradio_ui fills (SQL text, result table as text, chart image).
with gr.Blocks() as demo:
    gr.Markdown("## SQL Explorer: Text-to-SQL with Real Execution & Visualization")

    query_input = gr.Textbox(
        label="Enter your query",
        placeholder="e.g., Show all products sold in 2018.",
    )
    submit_btn = gr.Button("Convert & Execute")

    sql_output = gr.Textbox(label="Generated SQL Query")
    table_output = gr.Textbox(label="Query Results")
    chart_output = gr.Image(label="Data Visualization")

    # Wire the button to the request handler.
    submit_btn.click(
        gradio_ui,
        inputs=[query_input],
        outputs=[sql_output, table_output, chart_output],
    )

# Launch
demo.launch()