thechaiexperiment committed on
Commit
4aa996b
·
verified ·
1 Parent(s): 350e55d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -25
app.py CHANGED
@@ -10,26 +10,33 @@ import os
10
  OPENROUTER_API_KEY = "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"
11
  OPENROUTER_MODEL = "sophosympatheia/rogue-rose-103b-v0.2:free"
12
 
13
- # Hugging Face Space path
14
- DB_PATH = "ecommerce.db"
 
15
 
16
  # Ensure dataset exists
17
- if not os.path.exists(DB_PATH):
18
- os.system("wget https://your-dataset-link.com/ecommerce.db -O ecommerce.db") # Replace with actual dataset link
19
 
20
  # Initialize OpenAI client
21
  openai_client = openai.OpenAI(api_key=OPENROUTER_API_KEY, base_url="https://openrouter.ai/api/v1")
22
 
23
- # Few-shot examples for text-to-SQL
24
  few_shot_examples = [
25
- {"input": "Show all customers from São Paulo.", "output": "SELECT * FROM customers WHERE customer_state = 'SP';"},
26
- {"input": "Find the total sales per product.", "output": "SELECT product_id, SUM(price) FROM order_items GROUP BY product_id;"},
27
- {"input": "List all orders placed in 2017.", "output": "SELECT * FROM orders WHERE order_purchase_timestamp LIKE '2017%';"}
 
 
 
 
 
28
  ]
29
 
30
- # Function: Convert text to SQL
 
31
  def text_to_sql(query):
32
- prompt = "Convert the following queries into SQL:\n\n"
33
  for example in few_shot_examples:
34
  prompt += f"Input: {example['input']}\nOutput: {example['output']}\n\n"
35
  prompt += f"Input: {query}\nOutput:"
@@ -37,33 +44,31 @@ def text_to_sql(query):
37
  try:
38
  response = openai_client.chat.completions.create(
39
  model=OPENROUTER_MODEL,
40
- messages=[{"role": "system", "content": "You are an SQL expert."}, {"role": "user", "content": prompt}]
 
41
  )
42
  sql_query = response.choices[0].message.content.strip()
43
- sql_query = sql_query.split("\n")[0] # Take only the first line if multiple lines exist
44
- sql_query = sql_query.replace("mathchar", "").rstrip(";") # Remove unwanted text
45
- return sql_query
46
  except Exception as e:
47
  return f"Error: {e}"
48
 
49
- # Function: Execute SQL on SQLite database
 
50
  def execute_sql(sql_query):
51
  try:
52
- sql_query = sql_query.strip().rstrip(";") # Remove trailing semicolons
53
- sql_query = sql_query.replace("mathchar", "") # Remove any bad tokens
54
- conn = sqlite3.connect(DB_PATH)
55
  df = pd.read_sql_query(sql_query, conn)
56
  conn.close()
57
  return df
58
  except Exception as e:
59
  return f"SQL Execution Error: {e}"
60
 
61
- # Function: Generate Dynamic Visualization
 
62
  def visualize_data(df):
63
  if df.empty or df.shape[1] < 2:
64
  return None
65
 
66
- # Detect numeric columns
67
  numeric_cols = df.select_dtypes(include=['number']).columns
68
  if len(numeric_cols) < 1:
69
  return None
@@ -71,17 +76,16 @@ def visualize_data(df):
71
  plt.figure(figsize=(6, 4))
72
  sns.set_theme(style="darkgrid")
73
 
74
- # Choose visualization type dynamically
75
- if len(numeric_cols) == 1: # Single numeric column, assume it's a count metric
76
  sns.histplot(df[numeric_cols[0]], bins=10, kde=True, color="teal")
77
  plt.title(f"Distribution of {numeric_cols[0]}")
78
- elif len(numeric_cols) == 2: # Two numeric columns, assume X-Y plot
79
  sns.scatterplot(x=df[numeric_cols[0]], y=df[numeric_cols[1]], color="blue")
80
  plt.title(f"{numeric_cols[0]} vs {numeric_cols[1]}")
81
- elif df.shape[0] < 10: # If rows are few, prefer pie chart
82
  plt.pie(df[numeric_cols[0]], labels=df.iloc[:, 0], autopct='%1.1f%%', colors=sns.color_palette("pastel"))
83
  plt.title(f"Proportion of {numeric_cols[0]}")
84
- else: # Default: Bar chart for categories + values
85
  sns.barplot(x=df.iloc[:, 0], y=df[numeric_cols[0]], palette="coolwarm")
86
  plt.xticks(rotation=45)
87
  plt.title(f"{df.columns[0]} vs {numeric_cols[0]}")
 
10
import os

# OpenRouter credentials and model selection.
# The key is read from the OPENROUTER_API_KEY environment variable when it is
# set, so deployments can supply the secret without editing the source.
# NOTE(review): the literal fallback below is kept only so existing
# deployments keep working; it has been committed to version control and
# should be treated as leaked -- rotate it and remove the fallback.
OPENROUTER_API_KEY = os.environ.get(
    "OPENROUTER_API_KEY",
    "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",
)
# Model identifier passed as `model=` to the chat-completions call.
OPENROUTER_MODEL = "sophosympatheia/rogue-rose-103b-v0.2:free"
12
 
13
+
14
+ # Database Path
15
+ db_path = "ecommerce.db"
16
 
17
  # Ensure dataset exists
18
+ if not os.path.exists(db_path):
19
+ print("Database file not found! Please upload ecommerce.db.")
20
 
21
  # Initialize OpenAI client
22
  openai_client = openai.OpenAI(api_key=OPENROUTER_API_KEY, base_url="https://openrouter.ai/api/v1")
23
 
24
+ # Updated Few-Shot Examples with SQLite-Compatible Queries
25
  few_shot_examples = [
26
+ {"input": "Find the busiest months for orders.",
27
+ "output": "SELECT strftime('%m', order_purchase_timestamp) AS month, COUNT(*) AS order_count FROM orders GROUP BY month ORDER BY order_count DESC;"},
28
+ {"input": "Show all customers from São Paulo.",
29
+ "output": "SELECT * FROM customers WHERE customer_state = 'SP';"},
30
+ {"input": "Find the total sales per product.",
31
+ "output": "SELECT product_id, SUM(price) FROM order_items GROUP BY product_id;"},
32
+ {"input": "List all orders placed in 2017.",
33
+ "output": "SELECT * FROM orders WHERE order_purchase_timestamp LIKE '2017%';"}
34
  ]
35
 
36
+ # Function: Convert Text to SQL
37
+
38
  def text_to_sql(query):
39
+ prompt = "Convert the following queries into SQLite-compatible SQL:\n\n"
40
  for example in few_shot_examples:
41
  prompt += f"Input: {example['input']}\nOutput: {example['output']}\n\n"
42
  prompt += f"Input: {query}\nOutput:"
 
44
  try:
45
  response = openai_client.chat.completions.create(
46
  model=OPENROUTER_MODEL,
47
+ messages=[{"role": "system", "content": "You are an SQLite expert."},
48
+ {"role": "user", "content": prompt}]
49
  )
50
  sql_query = response.choices[0].message.content.strip()
51
+ return sql_query if sql_query.lower().startswith("select") else f"Error: Invalid SQL generated - {sql_query}"
 
 
52
  except Exception as e:
53
  return f"Error: {e}"
54
 
55
+ # Function: Execute SQL on SQLite Database
56
+
57
def execute_sql(sql_query):
    """Run ``sql_query`` against the SQLite database located at ``db_path``.

    Args:
        sql_query: SQL text handed unchanged to ``pd.read_sql_query``.

    Returns:
        The query result as a pandas DataFrame on success. On any failure
        (opening the database or executing the query) the exception is not
        propagated; instead the string ``"SQL Execution Error: <message>"``
        is returned, so callers must check the return type.
    """
    try:
        conn = sqlite3.connect(db_path)
        try:
            return pd.read_sql_query(sql_query, conn)
        finally:
            # Close on both the success and the error path; previously a
            # failing query skipped conn.close() and leaked the connection.
            conn.close()
    except Exception as e:
        return f"SQL Execution Error: {e}"
65
 
66
+ # Function: Generate Data Visualization
67
+
68
  def visualize_data(df):
69
  if df.empty or df.shape[1] < 2:
70
  return None
71
 
 
72
  numeric_cols = df.select_dtypes(include=['number']).columns
73
  if len(numeric_cols) < 1:
74
  return None
 
76
  plt.figure(figsize=(6, 4))
77
  sns.set_theme(style="darkgrid")
78
 
79
+ if len(numeric_cols) == 1:
 
80
  sns.histplot(df[numeric_cols[0]], bins=10, kde=True, color="teal")
81
  plt.title(f"Distribution of {numeric_cols[0]}")
82
+ elif len(numeric_cols) == 2:
83
  sns.scatterplot(x=df[numeric_cols[0]], y=df[numeric_cols[1]], color="blue")
84
  plt.title(f"{numeric_cols[0]} vs {numeric_cols[1]}")
85
+ elif df.shape[0] < 10:
86
  plt.pie(df[numeric_cols[0]], labels=df.iloc[:, 0], autopct='%1.1f%%', colors=sns.color_palette("pastel"))
87
  plt.title(f"Proportion of {numeric_cols[0]}")
88
+ else:
89
  sns.barplot(x=df.iloc[:, 0], y=df[numeric_cols[0]], palette="coolwarm")
90
  plt.xticks(rotation=45)
91
  plt.title(f"{df.columns[0]} vs {numeric_cols[0]}")