thechaiexperiment committed on
Commit
3ddc773
·
verified ·
1 Parent(s): 5566670

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -52
app.py CHANGED
@@ -1,76 +1,107 @@
1
  import gradio as gr
2
  import openai
 
 
 
 
3
  import os
4
 
5
# OpenRouter credentials and model selection.
# The key is read from the OPENROUTER_API_KEY environment variable so the
# secret never lives in the repository. It falls back to an empty string, which
# lets the client be constructed; requests made without a valid key are
# expected to fail at call time rather than at import time.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
OPENROUTER_MODEL = "sophosympatheia/rogue-rose-103b-v0.2:free"

# Initialize OpenAI client with OpenRouter base URL.
# Only report whether a key is present -- never echo the key itself to logs.
if not OPENROUTER_API_KEY:
    print("Warning: OPENROUTER_API_KEY is not set; OpenRouter requests will fail.")
openai_client = openai.OpenAI(
    api_key=OPENROUTER_API_KEY,
    base_url="https://openrouter.ai/api/v1",  # OpenRouter API endpoint
)

# Few-shot examples for text-to-SQL conversion.
# Each entry pairs a natural-language request ("input") with the SQL the model
# should produce for it ("output").
few_shot_examples = [
    {
        "input": "Show all customers from the USA.",
        "output": "SELECT * FROM customers WHERE country = 'USA';",
    },
    {
        "input": "Find the total sales for each product category.",
        "output": "SELECT product_category, SUM(sales) AS total_sales FROM sales GROUP BY product_category;",
    },
    {
        "input": "List all orders placed in 2023.",
        "output": "SELECT * FROM orders WHERE YEAR(order_date) = 2023;",
    },
]
31
 
 
def text_to_sql(query):
    """Translate a natural-language request into SQL via the OpenRouter model.

    The prompt is a fixed instruction line, followed by every few-shot
    example rendered as an ``Input:``/``Output:`` pair, followed by the
    caller's query with an open ``Output:`` slot for the model to fill.

    Args:
        query: The natural-language request to convert.

    Returns:
        The text of the first completion choice, or a string of the form
        ``"Error: <exception>"`` if the API call (or reading its response)
        raises.
    """
    # Render the few-shot section in one pass instead of appending in a loop.
    rendered_examples = "".join(
        f"Input: {example['input']}\nOutput: {example['output']}\n\n"
        for example in few_shot_examples
    )
    prompt = (
        "Convert the following natural language queries to SQL:\n\n"
        + rendered_examples
        + f"Input: {query}\nOutput:"
    )

    system_message = {
        "role": "system",
        "content": "You are a helpful assistant. Your task is to convert natural language queries into SQL queries. "
                   "Use the provided examples as a guide. If the query cannot be converted into SQL, say 'I cannot convert this query into SQL.'",
    }
    user_message = {"role": "user", "content": prompt}

    print("Sending query to OpenRouter API...")
    try:
        response = openai_client.chat.completions.create(
            model=OPENROUTER_MODEL,
            messages=[system_message, user_message],
        )
        print("Received response from OpenRouter API.")
        return response.choices[0].message.content
    except Exception as e:
        # Surface the failure to the UI as text rather than crashing the app.
        print(f"Error calling OpenRouter API: {e}")
        return f"Error: {e}"
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  # Gradio UI
62
def gradio_ui():
    """Build and return the Gradio Blocks app for the text-to-SQL converter.

    The layout is a heading, a textbox for the request, a button, and a
    textbox for the result. Clicking the button passes the request text to
    ``text_to_sql`` and writes its return value to the result textbox.

    Returns:
        The constructed ``gr.Blocks`` object (not yet launched).
    """
    with gr.Blocks() as app:
        gr.Markdown("## Text-to-SQL Converter. Enter a natural language query and get the corresponding SQL query!")
        question_box = gr.Textbox(label="Enter your query")
        convert_button = gr.Button("Convert to SQL")
        sql_box = gr.Textbox(label="SQL Query")

        # Wire the button: request text in, generated SQL out.
        convert_button.click(
            text_to_sql,
            inputs=[question_box],
            outputs=[sql_box],
        )

    return app
72
-
73
# Build the interface once at import time and keep it in the module-level
# name `demo`.
demo = gradio_ui()

# Announce start-up on stdout, then start serving the app.
print("Launching Gradio UI...")
demo.launch()
 
 
 
 
 
1
  import gradio as gr
2
  import openai
3
+ import sqlite3
4
+ import pandas as pd
5
+ import matplotlib.pyplot as plt
6
+ import seaborn as sns
7
  import os
8
 
9
# OpenRouter credentials and model selection.
# The key is read from the OPENROUTER_API_KEY environment variable so the
# secret never lives in the repository. It falls back to an empty string, which
# lets the client be constructed; requests made without a valid key are
# expected to fail at call time rather than at import time.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
OPENROUTER_MODEL = "sophosympatheia/rogue-rose-103b-v0.2:free"

# Hugging Face Space path
DB_PATH = "ecommerce.db"
DB_URL = "https://your-dataset-link.com/ecommerce.db"  # Replace with actual dataset link

# Ensure dataset exists.
# Download with the standard library instead of shelling out to wget, and write
# to a temporary name first: a failed or partial download must not leave a file
# at DB_PATH, otherwise the existence check above would never retry.
if not os.path.exists(DB_PATH):
    import shutil
    import urllib.request

    partial_path = DB_PATH + ".part"
    try:
        with urllib.request.urlopen(DB_URL, timeout=60) as remote, open(partial_path, "wb") as local:
            shutil.copyfileobj(remote, local)
        os.replace(partial_path, DB_PATH)
    except Exception as e:
        # Best effort, like the original shell call: report and keep starting
        # up so the UI can still surface SQL errors to the user.
        print(f"Could not download {DB_PATH} from {DB_URL}: {e}")
        if os.path.exists(partial_path):
            os.remove(partial_path)

# Initialize OpenAI client (pointed at the OpenRouter endpoint)
openai_client = openai.OpenAI(api_key=OPENROUTER_API_KEY, base_url="https://openrouter.ai/api/v1")

# Few-shot examples for text-to-SQL.
# Each entry pairs a natural-language request ("input") with the SQL the model
# should produce for it ("output").
few_shot_examples = [
    {"input": "Show all customers from São Paulo.", "output": "SELECT * FROM customers WHERE customer_state = 'SP';"},
    {"input": "Find the total sales per product.", "output": "SELECT product_id, SUM(price) FROM order_items GROUP BY product_id;"},
    {"input": "List all orders placed in 2017.", "output": "SELECT * FROM orders WHERE order_purchase_timestamp LIKE '2017%';"},
]
29
 
30
# Function: Convert text to SQL
def text_to_sql(query):
    """Translate a natural-language request into SQL via the OpenRouter model.

    The prompt is a fixed instruction line, followed by every few-shot
    example rendered as an ``Input:``/``Output:`` pair, followed by the
    caller's query with an open ``Output:`` slot for the model to fill.

    Args:
        query: The natural-language request to convert.

    Returns:
        The first completion choice with surrounding whitespace stripped, or
        a string of the form ``"Error: <exception>"`` if the API call (or
        reading its response) raises.
    """
    # Render the few-shot section in one pass instead of appending in a loop.
    rendered_examples = "".join(
        f"Input: {example['input']}\nOutput: {example['output']}\n\n"
        for example in few_shot_examples
    )
    prompt = (
        "Convert the following queries into SQL:\n\n"
        + rendered_examples
        + f"Input: {query}\nOutput:"
    )

    conversation = [
        {"role": "system", "content": "You are an SQL expert."},
        {"role": "user", "content": prompt},
    ]

    try:
        response = openai_client.chat.completions.create(
            model=OPENROUTER_MODEL,
            messages=conversation,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        # Surface the failure to the UI as text rather than crashing the app.
        return f"Error: {e}"
45
 
46
# Function: Execute SQL on SQLite database
def execute_sql(sql_query, db_path=None):
    """Run a SQL statement against a SQLite database and return the rows.

    Args:
        sql_query: The SQL text to execute.
        db_path: Path of the SQLite database to open. When omitted (``None``)
            the module-level ``DB_PATH`` is used, which is the original
            behaviour.

    Returns:
        A ``pandas.DataFrame`` with the query result on success, or a string
        of the form ``"SQL Execution Error: <exception>"`` if connecting or
        executing raises.
    """
    try:
        conn = sqlite3.connect(DB_PATH if db_path is None else db_path)
        try:
            return pd.read_sql_query(sql_query, conn)
        finally:
            # Close the connection even when the query raises, so a failing
            # statement does not leak an open handle.
            conn.close()
    except Exception as e:
        # Errors are reported as text so the caller can display them.
        return f"SQL Execution Error: {e}"
55
+
56
# Function: Generate Dynamic Visualization
def visualize_data(df, output_path="chart.png"):
    """Render a chart for a query result and return the image path.

    The chart type is picked from the shape of the data:

    * exactly one numeric column: histogram (with KDE) of that column;
    * exactly two numeric columns: scatter plot of the first against the
      second;
    * three or more numeric columns and fewer than 10 rows: pie chart of the
      first numeric column, labelled by the first column of the frame;
    * otherwise: bar chart of the first column against the first numeric
      column.

    Args:
        df: The ``pandas.DataFrame`` to plot.
        output_path: Where to write the PNG. Defaults to ``"chart.png"``,
            the path the original implementation always used.

    Returns:
        ``output_path`` once the image has been written, or ``None`` when the
        frame is empty, has fewer than two columns, or has no numeric column.
    """
    if df.empty or df.shape[1] < 2:
        return None

    # Detect numeric columns
    numeric_cols = df.select_dtypes(include=['number']).columns
    if len(numeric_cols) < 1:
        return None

    fig = plt.figure(figsize=(6, 4))
    try:
        sns.set_theme(style="darkgrid")

        # Choose visualization type dynamically
        if len(numeric_cols) == 1:  # Single numeric column, assume it's a count metric
            sns.histplot(df[numeric_cols[0]], bins=10, kde=True, color="teal")
            plt.title(f"Distribution of {numeric_cols[0]}")
        elif len(numeric_cols) == 2:  # Two numeric columns, assume X-Y plot
            sns.scatterplot(x=df[numeric_cols[0]], y=df[numeric_cols[1]], color="blue")
            plt.title(f"{numeric_cols[0]} vs {numeric_cols[1]}")
        elif df.shape[0] < 10:  # If rows are few, prefer pie chart
            plt.pie(df[numeric_cols[0]], labels=df.iloc[:, 0], autopct='%1.1f%%', colors=sns.color_palette("pastel"))
            plt.title(f"Proportion of {numeric_cols[0]}")
        else:  # Default: Bar chart for categories + values
            sns.barplot(x=df.iloc[:, 0], y=df[numeric_cols[0]], palette="coolwarm")
            plt.xticks(rotation=45)
            plt.title(f"{df.columns[0]} vs {numeric_cols[0]}")

        plt.tight_layout()
        plt.savefig(output_path)
    finally:
        # pyplot keeps every figure alive until it is closed; without this a
        # long-running app accumulates one figure per request.
        plt.close(fig)
    return output_path
87
+
88
  # Gradio UI
89
def gradio_ui(query):
    """Run the full pipeline for one request: generate SQL, execute, chart.

    Args:
        query: The natural-language request entered by the user.

    Returns:
        A 3-tuple ``(sql, results_text, chart)``:

        * ``sql`` is whatever ``text_to_sql`` returned;
        * ``results_text`` is the result frame rendered with
          ``to_string(index=False)`` when ``execute_sql`` returned a
          DataFrame, otherwise the value it returned unchanged;
        * ``chart`` is the value of ``visualize_data`` for a DataFrame
          result, otherwise ``None``.
    """
    generated_sql = text_to_sql(query)
    outcome = execute_sql(generated_sql)

    # Anything that is not a DataFrame is passed through as-is with no chart.
    if not isinstance(outcome, pd.DataFrame):
        return generated_sql, outcome, None

    chart = visualize_data(outcome)
    return generated_sql, outcome.to_string(index=False), chart
95
+
96
# Page layout: a heading, the request box and button, then three outputs
# (generated SQL, textual results, chart image).
with gr.Blocks() as demo:
    gr.Markdown("## SQL Explorer: Text-to-SQL with Real Execution & Visualization")
    query_input = gr.Textbox(
        label="Enter your query",
        placeholder="e.g., Show all products sold in 2018.",
    )
    submit_btn = gr.Button("Convert & Execute")
    sql_output = gr.Textbox(label="Generated SQL Query")
    table_output = gr.Textbox(label="Query Results")
    chart_output = gr.Image(label="Data Visualization")

    # The three values returned by gradio_ui fill the outputs in this order.
    submit_btn.click(
        gradio_ui,
        inputs=[query_input],
        outputs=[sql_output, table_output, chart_output],
    )

# Launch
demo.launch()