thechaiexperiment committed on
Commit
08e4afd
·
verified ·
1 Parent(s): 4aa996b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -31
app.py CHANGED
@@ -10,33 +10,26 @@ import os
10
  OPENROUTER_API_KEY = "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"
11
  OPENROUTER_MODEL = "sophosympatheia/rogue-rose-103b-v0.2:free"
12
 
13
-
14
- # Database Path
15
- db_path = "ecommerce.db"
16
 
17
  # Ensure dataset exists
18
- if not os.path.exists(db_path):
19
- print("Database file not found! Please upload ecommerce.db.")
20
 
21
  # Initialize OpenAI client
22
  openai_client = openai.OpenAI(api_key=OPENROUTER_API_KEY, base_url="https://openrouter.ai/api/v1")
23
 
24
- # Updated Few-Shot Examples with SQLite-Compatible Queries
25
  few_shot_examples = [
26
- {"input": "Find the busiest months for orders.",
27
- "output": "SELECT strftime('%m', order_purchase_timestamp) AS month, COUNT(*) AS order_count FROM orders GROUP BY month ORDER BY order_count DESC;"},
28
- {"input": "Show all customers from São Paulo.",
29
- "output": "SELECT * FROM customers WHERE customer_state = 'SP';"},
30
- {"input": "Find the total sales per product.",
31
- "output": "SELECT product_id, SUM(price) FROM order_items GROUP BY product_id;"},
32
- {"input": "List all orders placed in 2017.",
33
- "output": "SELECT * FROM orders WHERE order_purchase_timestamp LIKE '2017%';"}
34
  ]
35
 
36
- # Function: Convert Text to SQL
37
-
38
  def text_to_sql(query):
39
- prompt = "Convert the following queries into SQLite-compatible SQL:\n\n"
40
  for example in few_shot_examples:
41
  prompt += f"Input: {example['input']}\nOutput: {example['output']}\n\n"
42
  prompt += f"Input: {query}\nOutput:"
@@ -44,31 +37,32 @@ def text_to_sql(query):
44
  try:
45
  response = openai_client.chat.completions.create(
46
  model=OPENROUTER_MODEL,
47
- messages=[{"role": "system", "content": "You are an SQLite expert."},
48
- {"role": "user", "content": prompt}]
49
  )
50
  sql_query = response.choices[0].message.content.strip()
51
- return sql_query if sql_query.lower().startswith("select") else f"Error: Invalid SQL generated - {sql_query}"
 
 
 
52
  except Exception as e:
53
  return f"Error: {e}"
54
 
55
- # Function: Execute SQL on SQLite Database
56
-
57
  def execute_sql(sql_query):
58
  try:
59
- conn = sqlite3.connect(db_path)
60
  df = pd.read_sql_query(sql_query, conn)
61
  conn.close()
62
  return df
63
  except Exception as e:
64
  return f"SQL Execution Error: {e}"
65
 
66
- # Function: Generate Data Visualization
67
-
68
  def visualize_data(df):
69
  if df.empty or df.shape[1] < 2:
70
  return None
71
 
 
72
  numeric_cols = df.select_dtypes(include=['number']).columns
73
  if len(numeric_cols) < 1:
74
  return None
@@ -76,16 +70,17 @@ def visualize_data(df):
76
  plt.figure(figsize=(6, 4))
77
  sns.set_theme(style="darkgrid")
78
 
79
- if len(numeric_cols) == 1:
 
80
  sns.histplot(df[numeric_cols[0]], bins=10, kde=True, color="teal")
81
  plt.title(f"Distribution of {numeric_cols[0]}")
82
- elif len(numeric_cols) == 2:
83
  sns.scatterplot(x=df[numeric_cols[0]], y=df[numeric_cols[1]], color="blue")
84
  plt.title(f"{numeric_cols[0]} vs {numeric_cols[1]}")
85
- elif df.shape[0] < 10:
86
  plt.pie(df[numeric_cols[0]], labels=df.iloc[:, 0], autopct='%1.1f%%', colors=sns.color_palette("pastel"))
87
  plt.title(f"Proportion of {numeric_cols[0]}")
88
- else:
89
  sns.barplot(x=df.iloc[:, 0], y=df[numeric_cols[0]], palette="coolwarm")
90
  plt.xticks(rotation=45)
91
  plt.title(f"{df.columns[0]} vs {numeric_cols[0]}")
@@ -103,8 +98,8 @@ def gradio_ui(query):
103
  return sql_query, results.to_string(index=False) if isinstance(results, pd.DataFrame) else results, visualization
104
 
105
  with gr.Blocks() as demo:
106
- gr.Markdown("## SQL Explorer: Text to SQL with a Simple Visualization")
107
- query_input = gr.Textbox(label="Enter your query", placeholder="Enter your query in English.")
108
  submit_btn = gr.Button("Convert & Execute")
109
  sql_output = gr.Textbox(label="Generated SQL Query")
110
  table_output = gr.Textbox(label="Query Results")
 
# OpenRouter credentials and model selection.
# SECURITY: the fallback literal below is a credential hard-coded in
# version-controlled source and should be treated as compromised. Rotate it and
# provide the replacement through the OPENROUTER_API_KEY environment variable;
# the literal is kept only so that existing deployments keep working.
OPENROUTER_API_KEY = os.environ.get(
    "OPENROUTER_API_KEY",
    "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",
)
OPENROUTER_MODEL = "sophosympatheia/rogue-rose-103b-v0.2:free"

# Hugging Face Space path
DB_PATH = "ecommerce.db"
DB_URL = "https://your-dataset-link.com/ecommerce.db"  # Replace with actual dataset link


def _download_database(url, destination):
    """Download *url* to *destination* and return True on success.

    The payload is first written to a temporary sibling file and only moved
    into place once the transfer has finished, so a failed download never
    leaves an empty or partial database at *destination*.
    """
    # Local imports: the file's import block is not modified by this change.
    import shutil
    import urllib.request

    partial = destination + ".part"
    try:
        with urllib.request.urlopen(url, timeout=60) as response, open(partial, "wb") as out:
            shutil.copyfileobj(response, out)
        os.replace(partial, destination)
        return True
    except (OSError, ValueError) as exc:
        # URLError/HTTPError are OSError subclasses; ValueError covers a malformed URL.
        print(f"Database file not found and download from {url} failed: {exc}")
        if os.path.exists(partial):
            os.remove(partial)
        return False


# Ensure dataset exists
if not os.path.exists(DB_PATH):
    _download_database(DB_URL, DB_PATH)

# Initialize OpenAI client
openai_client = openai.OpenAI(api_key=OPENROUTER_API_KEY, base_url="https://openrouter.ai/api/v1")
22
 
# Few-shot examples for text-to-SQL: each entry pairs a natural-language
# request ("input") with the SQL statement expected for it ("output").
few_shot_examples = [
    {"input": question, "output": statement}
    for question, statement in (
        (
            "Show all customers from São Paulo.",
            "SELECT * FROM customers WHERE customer_state = 'SP';",
        ),
        (
            "Find the total sales per product.",
            "SELECT product_id, SUM(price) FROM order_items GROUP BY product_id;",
        ),
        (
            "List all orders placed in 2017.",
            "SELECT * FROM orders WHERE order_purchase_timestamp LIKE '2017%';",
        ),
    )
]
29
 
def _extract_sql(raw):
    """Return the first SQL statement contained in a model reply, or "" if none.

    Handles replies wrapped in a Markdown code fence (```sql ... ```), keeps a
    statement that spans several lines intact up to its terminating semicolon,
    and falls back to the first non-empty line when no semicolon is present.
    """
    import re  # local import: the file's import block is not modified here

    text = (raw or "").strip()

    # Keep only the body of the first fenced block, tolerating a missing
    # closing fence (e.g. a truncated reply).
    if "```" in text:
        fenced = re.search(r"```(?:[\w+-]*\n)?(.*?)(?:```|\Z)", text, re.DOTALL)
        if fenced:
            text = fenced.group(1).strip()

    head, terminator, _ = text.partition(";")
    if terminator:
        return (head + terminator).strip()

    # No terminator: keep the previous behaviour of returning a single line.
    for line in text.splitlines():
        if line.strip():
            return line.strip()
    return ""


# Function: Convert text to SQL
def text_to_sql(query):
    """Translate a natural-language request into a single SQL statement.

    Builds a few-shot prompt from ``few_shot_examples``, sends it to the
    configured chat model through ``openai_client``, and extracts one SQL
    statement from the reply.

    Args:
        query: The request, in natural language.

    Returns:
        The SQL statement as a string, or a string starting with ``"Error:"``
        when the request fails or the model returns no usable text.
    """
    shots = "".join(
        f"Input: {example['input']}\nOutput: {example['output']}\n\n"
        for example in few_shot_examples
    )
    prompt = f"Convert the following queries into SQL:\n\n{shots}Input: {query}\nOutput:"

    try:
        response = openai_client.chat.completions.create(
            model=OPENROUTER_MODEL,
            messages=[
                {"role": "system", "content": "You are an SQL expert."},
                {"role": "user", "content": prompt},
            ],
        )
        # Ensure only one query is returned (remove fences and extra text)
        sql_query = _extract_sql(response.choices[0].message.content)
    except Exception as e:
        # UI boundary: failures are reported to the caller as text.
        return f"Error: {e}"

    if not sql_query:
        return "Error: the model returned an empty response."
    return sql_query
49
 
# Function: Execute SQL on SQLite database
def execute_sql(sql_query):
    """Run *sql_query* against the SQLite database at ``DB_PATH``.

    Args:
        sql_query: The SQL text to execute.

    Returns:
        A ``pandas.DataFrame`` with the result rows on success, or a string
        starting with ``"SQL Execution Error:"`` describing the failure.
    """
    # NOTE(review): sql_query is executed as given; if it can originate from
    # untrusted input, consider restricting it to read-only statements.
    try:
        conn = sqlite3.connect(DB_PATH)
        try:
            return pd.read_sql_query(sql_query, conn)
        finally:
            # Always release the connection, including when the query raises.
            conn.close()
    except Exception as e:
        return f"SQL Execution Error: {e}"
59
 
60
+ # Function: Generate Dynamic Visualization
 
61
  def visualize_data(df):
62
  if df.empty or df.shape[1] < 2:
63
  return None
64
 
65
+ # Detect numeric columns
66
  numeric_cols = df.select_dtypes(include=['number']).columns
67
  if len(numeric_cols) < 1:
68
  return None
 
70
  plt.figure(figsize=(6, 4))
71
  sns.set_theme(style="darkgrid")
72
 
73
+ # Choose visualization type dynamically
74
+ if len(numeric_cols) == 1: # Single numeric column, assume it's a count metric
75
  sns.histplot(df[numeric_cols[0]], bins=10, kde=True, color="teal")
76
  plt.title(f"Distribution of {numeric_cols[0]}")
77
+ elif len(numeric_cols) == 2: # Two numeric columns, assume X-Y plot
78
  sns.scatterplot(x=df[numeric_cols[0]], y=df[numeric_cols[1]], color="blue")
79
  plt.title(f"{numeric_cols[0]} vs {numeric_cols[1]}")
80
+ elif df.shape[0] < 10: # If rows are few, prefer pie chart
81
  plt.pie(df[numeric_cols[0]], labels=df.iloc[:, 0], autopct='%1.1f%%', colors=sns.color_palette("pastel"))
82
  plt.title(f"Proportion of {numeric_cols[0]}")
83
+ else: # Default: Bar chart for categories + values
84
  sns.barplot(x=df.iloc[:, 0], y=df[numeric_cols[0]], palette="coolwarm")
85
  plt.xticks(rotation=45)
86
  plt.title(f"{df.columns[0]} vs {numeric_cols[0]}")
 
98
  return sql_query, results.to_string(index=False) if isinstance(results, pd.DataFrame) else results, visualization
99
 
100
  with gr.Blocks() as demo:
101
+ gr.Markdown("## SQL Explorer: Text-to-SQL with Real Execution & Visualization")
102
+ query_input = gr.Textbox(label="Enter your query", placeholder="e.g., Show all products sold in 2018.")
103
  submit_btn = gr.Button("Convert & Execute")
104
  sql_output = gr.Textbox(label="Generated SQL Query")
105
  table_output = gr.Textbox(label="Query Results")