thechaiexperiment committed on
Commit
79ceb52
·
verified ·
1 Parent(s): 08e4afd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -31
app.py CHANGED
@@ -5,6 +5,7 @@ import pandas as pd
5
  import matplotlib.pyplot as plt
6
  import seaborn as sns
7
  import os
 
8
 
# OpenRouter API key. Read from the environment so the secret is never
# committed to source control; falls back to an empty string when unset.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
@@ -20,56 +21,65 @@ if not os.path.exists(DB_PATH):
20
  # Initialize OpenAI client
21
  openai_client = openai.OpenAI(api_key=OPENROUTER_API_KEY, base_url="https://openrouter.ai/api/v1")
22
 
23
- # Few-shot examples for text-to-SQL
24
- few_shot_examples = [
25
- {"input": "Show all customers from São Paulo.", "output": "SELECT * FROM customers WHERE customer_state = 'SP';"},
26
- {"input": "Find the total sales per product.", "output": "SELECT product_id, SUM(price) FROM order_items GROUP BY product_id;"},
27
- {"input": "List all orders placed in 2017.", "output": "SELECT * FROM orders WHERE order_purchase_timestamp LIKE '2017%';"}
28
- ]
 
 
 
 
 
 
 
 
 
 
29
 
30
  # Function: Convert text to SQL
31
- def text_to_sql(query):
32
- prompt = "Convert the following queries into SQL:\n\n"
33
- for example in few_shot_examples:
34
- prompt += f"Input: {example['input']}\nOutput: {example['output']}\n\n"
35
- prompt += f"Input: {query}\nOutput:"
36
-
 
 
37
  try:
38
  response = openai_client.chat.completions.create(
39
  model=OPENROUTER_MODEL,
40
  messages=[{"role": "system", "content": "You are an SQL expert."}, {"role": "user", "content": prompt}]
41
  )
42
  sql_query = response.choices[0].message.content.strip()
43
-
44
- # Ensure only one query is returned (remove extra text)
45
- sql_query = sql_query.split("\n")[0].strip()
46
  return sql_query
47
  except Exception as e:
48
  return f"Error: {e}"
49
 
50
  # Function: Execute SQL on SQLite database
51
- def execute_sql(sql_query):
52
  try:
53
  conn = sqlite3.connect(DB_PATH)
54
  df = pd.read_sql_query(sql_query, conn)
55
  conn.close()
56
- return df
57
  except Exception as e:
58
- return f"SQL Execution Error: {e}"
59
 
60
  # Function: Generate Dynamic Visualization
61
- def visualize_data(df):
62
  if df.empty or df.shape[1] < 2:
63
  return None
64
 
 
 
 
65
  # Detect numeric columns
66
  numeric_cols = df.select_dtypes(include=['number']).columns
67
  if len(numeric_cols) < 1:
68
  return None
69
 
70
- plt.figure(figsize=(6, 4))
71
- sns.set_theme(style="darkgrid")
72
-
73
  # Choose visualization type dynamically
74
  if len(numeric_cols) == 1: # Single numeric column, assume it's a count metric
75
  sns.histplot(df[numeric_cols[0]], bins=10, kde=True, color="teal")
@@ -90,13 +100,16 @@ def visualize_data(df):
90
  return "chart.png"
91
 
92
  # Gradio UI
93
- def gradio_ui(query):
94
- sql_query = text_to_sql(query)
95
- results = execute_sql(sql_query)
96
- visualization = visualize_data(results) if isinstance(results, pd.DataFrame) else None
97
-
98
- return sql_query, results.to_string(index=False) if isinstance(results, pd.DataFrame) else results, visualization
99
-
 
 
 
100
  with gr.Blocks() as demo:
101
  gr.Markdown("## SQL Explorer: Text-to-SQL with Real Execution & Visualization")
102
  query_input = gr.Textbox(label="Enter your query", placeholder="e.g., Show all products sold in 2018.")
@@ -107,5 +120,4 @@ with gr.Blocks() as demo:
107
 
108
  submit_btn.click(gradio_ui, inputs=[query_input], outputs=[sql_output, table_output, chart_output])
109
 
110
- # Launch
111
- demo.launch()
 
5
  import matplotlib.pyplot as plt
6
  import seaborn as sns
7
  import os
8
+ from typing import Optional, Tuple
9
 
# OpenRouter API key. Read from the environment so the secret is never
# committed to source control; falls back to an empty string when unset.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
 
21
  # Initialize OpenAI client
22
  openai_client = openai.OpenAI(api_key=OPENROUTER_API_KEY, base_url="https://openrouter.ai/api/v1")
23
 
24
+ # Function: Fetch database schema
25
def fetch_schema(db_path: str) -> str:
    """Return a plain-text description of every table in a SQLite database.

    Args:
        db_path: Filesystem path of the SQLite database to inspect.

    Returns:
        One ``Table: <name>`` line per table, each followed by one
        ``Column: <name>, Type: <type>`` line per column. An empty string
        is returned when the database contains no tables.

    Raises:
        sqlite3.Error: If the database cannot be opened or queried.
    """
    parts = []
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        # fetchall() materialises the table list before the cursor is reused below.
        for (table_name,) in cursor.fetchall():
            # Quote the identifier so names containing spaces, keywords or
            # quotes cannot break (or inject into) the PRAGMA statement.
            quoted_name = '"' + table_name.replace('"', '""') + '"'
            cursor.execute(f"PRAGMA table_info({quoted_name});")
            parts.append(f"Table: {table_name}\n")
            # table_info rows are (cid, name, type, notnull, dflt_value, pk).
            parts.extend(
                f" Column: {column[1]}, Type: {column[2]}\n"
                for column in cursor.fetchall()
            )
    finally:
        # Always release the connection, even when a query above fails.
        conn.close()
    return "".join(parts)
40
 
41
  # Function: Convert text to SQL
42
def _strip_code_fence(text: str) -> str:
    """Remove a surrounding Markdown code fence (``` or ```sql) from *text*."""
    cleaned = text.strip()
    if not cleaned.startswith("```"):
        return cleaned
    inner = cleaned[3:]
    if inner.endswith("```"):
        inner = inner[:-3]
    first_line, newline, rest = inner.partition("\n")
    # The first line after the opening fence is either empty or a language tag.
    if newline and first_line.strip().lower() in ("", "sql", "sqlite"):
        inner = rest
    return inner.strip()


def text_to_sql(query: str, schema: str) -> str:
    """Ask the chat model to translate a natural-language request into SQL.

    Args:
        query: The user's request in natural language.
        schema: Text description of the database schema, embedded in the
            prompt so the model can reference real table and column names.

    Returns:
        The SQL text from the model's reply with any Markdown code fence
        removed. On any failure the function does not raise; it returns a
        string starting with ``"Error: "`` instead.
    """
    prompt = (
        "You are an SQL expert. Given the following database schema:\n\n"
        f"{schema}\n\n"
        "Convert the following query into SQL:\n\n"
        f"Query: {query}\n"
        "SQL:"
    )
    try:
        response = openai_client.chat.completions.create(
            model=OPENROUTER_MODEL,
            messages=[
                {"role": "system", "content": "You are an SQL expert."},
                {"role": "user", "content": prompt},
            ],
        )
        content = response.choices[0].message.content
        if not content or not content.strip():
            return "Error: model returned an empty response"
        # Chat models often wrap SQL in ```sql ... ``` fences, which would
        # otherwise be handed to SQLite verbatim and fail to parse.
        return _strip_code_fence(content)
    except Exception as e:
        # Top-level boundary: report the failure as text rather than raising.
        return f"Error: {e}"
59
 
60
  # Function: Execute SQL on SQLite database
61
def execute_sql(sql_query: str) -> Tuple[Optional[pd.DataFrame], Optional[str]]:
    """Run a SQL statement against the SQLite database at ``DB_PATH``.

    Args:
        sql_query: The SQL text to execute.

    Returns:
        ``(df, None)`` with the result set as a DataFrame on success, or
        ``(None, message)`` where *message* starts with
        ``"SQL Execution Error: "`` on failure. The function never raises.
    """
    if not sql_query or not sql_query.strip():
        # Nothing to run; avoid opening a connection for a blank statement.
        return None, "SQL Execution Error: empty query"

    conn = None
    try:
        conn = sqlite3.connect(DB_PATH)
        df = pd.read_sql_query(sql_query, conn)
        return df, None
    except Exception as e:
        return None, f"SQL Execution Error: {e}"
    finally:
        # Close the connection on both the success and the failure path.
        if conn is not None:
            conn.close()
69
 
70
  # Function: Generate Dynamic Visualization
71
+ def visualize_data(df: pd.DataFrame) -> Optional[str]:
72
  if df.empty or df.shape[1] < 2:
73
  return None
74
 
75
+ plt.figure(figsize=(6, 4))
76
+ sns.set_theme(style="darkgrid")
77
+
78
  # Detect numeric columns
79
  numeric_cols = df.select_dtypes(include=['number']).columns
80
  if len(numeric_cols) < 1:
81
  return None
82
 
 
 
 
83
  # Choose visualization type dynamically
84
  if len(numeric_cols) == 1: # Single numeric column, assume it's a count metric
85
  sns.histplot(df[numeric_cols[0]], bins=10, kde=True, color="teal")
 
100
  return "chart.png"
101
 
102
  # Gradio UI
103
def gradio_ui(query: str) -> Tuple[str, str, Optional[str]]:
    """Handle one submission: generate SQL, execute it, and build a chart.

    Args:
        query: The natural-language request typed by the user.

    Returns:
        A 3-tuple for the SQL, table and chart outputs: the generated SQL,
        the result table rendered as text (or an error message), and the
        value returned by ``visualize_data`` (``None`` on any error).
    """
    if not query or not query.strip():
        # Skip the model call entirely for blank input.
        return "", "Please enter a query.", None

    schema = fetch_schema(DB_PATH)
    sql_query = text_to_sql(query, schema)
    if sql_query.startswith("Error:"):
        # text_to_sql reports failures as an "Error: ..." string; show it
        # directly instead of trying to execute it as SQL.
        return "", sql_query, None

    df, error = execute_sql(sql_query)
    if error or df is None:
        return sql_query, error or "SQL Execution Error: no result returned", None

    return sql_query, df.to_string(index=False), visualize_data(df)
111
+
112
+ # Launch Gradio App
113
  with gr.Blocks() as demo:
114
  gr.Markdown("## SQL Explorer: Text-to-SQL with Real Execution & Visualization")
115
  query_input = gr.Textbox(label="Enter your query", placeholder="e.g., Show all products sold in 2018.")
 
120
 
121
  submit_btn.click(gradio_ui, inputs=[query_input], outputs=[sql_output, table_output, chart_output])
122
 
123
+ demo.launch()