nileshhanotia committed on
Commit
39e6004
1 Parent(s): 17ee535

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +208 -0
app.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+ from functools import lru_cache
5
+ import json
6
+ import mysql.connector
7
+ from mysql.connector import Error
8
+ import os
9
+ import sys
10
+ from datetime import datetime
11
+ import time
12
+
13
# Run on the GPU when torch reports CUDA support, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
15
+
16
# Database configuration.
# Connection settings are taken from DB_* environment variables so that
# credentials do not have to live in source control. The literals below are
# only fallbacks that keep the previous behavior when a variable is unset.
# NOTE(review): the fallback password has already been committed; it should be
# rotated and the literal removed once the environment variables are deployed.
DB_CONFIG = {
    'host': os.environ.get('DB_HOST', 'sql12.freemysqlhosting.net'),
    'database': os.environ.get('DB_NAME', 'sql12740625'),
    'user': os.environ.get('DB_USER', 'sql12740625'),
    'password': os.environ.get('DB_PASSWORD', 'QGG9kdrE4g'),
    'port': int(os.environ.get('DB_PORT', '3306')),
    'pool_size': 5,
    'pool_reset_session': True,
}
26
+
27
# Module-level slots for the language model and its tokenizer; both start out
# empty and are assigned by initialize_model().
GLOBAL_MODEL = GLOBAL_TOKENIZER = None
30
+
31
@st.cache_resource(show_spinner=False)
def _load_model_and_tokenizer(model_name):
    """Load the tokenizer and causal LM for *model_name* and return them.

    Streamlit re-executes the whole script on every widget interaction, so
    plain module globals do not survive between interactions. Caching the
    loaded objects with ``st.cache_resource`` keeps a single copy alive for
    the lifetime of the server process instead of reloading the weights on
    every rerun.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is on ``device`` and in
        evaluation mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
    ).to(device)

    # Inference only: switch off dropout and other training-time behavior.
    model.eval()
    return model, tokenizer


def initialize_model():
    """Populate ``GLOBAL_MODEL`` and ``GLOBAL_TOKENIZER``.

    The actual loading is delegated to a cached helper, so only the first
    call in a server process pays the load cost; later calls (one per
    Streamlit rerun) just rebind the globals to the cached objects.
    Progress and timing are reported through ``st.write``.
    """
    global GLOBAL_MODEL, GLOBAL_TOKENIZER
    st.write("Initializing model and tokenizer...")
    start_time = time.time()

    model_name_sql = "premai-io/prem-1B-SQL"
    GLOBAL_MODEL, GLOBAL_TOKENIZER = _load_model_and_tokenizer(model_name_sql)

    st.write(f"Model initialization took {time.time() - start_time:.2f} seconds")
48
+
49
def test_db_connection():
    """Check that the database in ``DB_CONFIG`` is reachable.

    Opens a connection with a 10 second connect timeout, reads the server
    version and the current database name, and always releases the
    connection again, including when the probe query fails.

    Returns:
        A ``(success, message)`` tuple. ``success`` is ``True`` only when the
        probe query ran; ``message`` is a human-readable status or error text.
    """
    connection = None
    try:
        connection = mysql.connector.connect(
            **DB_CONFIG,
            connect_timeout=10
        )
        if not connection.is_connected():
            return False, "Unable to establish database connection"

        db_info = connection.get_server_info()
        cursor = connection.cursor()
        try:
            cursor.execute("SELECT DATABASE();")
            db_name = cursor.fetchone()[0]
        finally:
            cursor.close()
        return True, f"Successfully connected to MySQL Server version {db_info}\nDatabase: {db_name}"
    except Error as e:
        return False, f"Error connecting to MySQL database: {e}"
    finally:
        if connection is not None:
            try:
                connection.close()
            except Error:
                # Best-effort cleanup: the result of the check has already
                # been decided, a failing close must not replace it.
                pass
67
+
68
def get_db_connection():
    """Open and return a MySQL connection built from ``DB_CONFIG``.

    ``DB_CONFIG`` carries ``pool_size``/``pool_reset_session`` keys, so the
    connector is asked for a pooled connection. Any ``mysql.connector`` error
    raised while connecting propagates to the caller.
    """
    connection = mysql.connector.connect(**DB_CONFIG)
    return connection
71
+
72
def execute_query(query):
    """Run *query* against the database and return all rows.

    NOTE(review): in this app *query* is model-generated text that is executed
    verbatim; nothing here restricts it to read-only statements. Consider
    connecting with a database user that only has SELECT privileges.

    Args:
        query: SQL text to execute.

    Returns:
        A list of row dictionaries (column name -> value) on success, or a
        string starting with ``"Error executing query:"`` when the connector
        raises a ``mysql.connector.Error``.
    """
    connection = None
    cursor = None
    try:
        connection = get_db_connection()
        cursor = connection.cursor(dictionary=True, buffered=True)
        cursor.execute(query)
        return cursor.fetchall()
    except Error as e:
        return f"Error executing query: {e}"
    finally:
        # Release whatever was actually acquired. ``cursor`` stays None when
        # creating it failed, so it must be checked separately from the
        # connection. Cleanup failures are ignored on purpose so they cannot
        # replace the result (or error message) computed above.
        if cursor is not None:
            try:
                cursor.close()
            except Error:
                pass
        if connection is not None:
            try:
                connection.close()
            except Error:
                pass
87
+
88
def generate_sql(natural_language_query):
    """Translate a natural-language question into a SQL query string.

    Builds a prompt containing the ``sales`` table schema and the question,
    runs it through ``GLOBAL_MODEL`` and returns the text generated after the
    prompt. The elapsed time is reported through ``st.write``.

    Args:
        natural_language_query: The user's question in plain English.

    Returns:
        The generated SQL text, or a string starting with
        ``"Error generating SQL query:"`` if anything raised along the way
        (for example when the model has not been initialized yet).
    """
    try:
        start_time = time.time()

        schema_info = """
        CREATE TABLE sales (
            pizza_id DECIMAL(8,2) PRIMARY KEY,
            order_id DECIMAL(8,2),
            pizza_name_id VARCHAR(14),
            quantity DECIMAL(4,2),
            order_date DATE,
            order_time VARCHAR(8),
            unit_price DECIMAL(5,2),
            total_price DECIMAL(5,2),
            pizza_size VARCHAR(3),
            pizza_category VARCHAR(7),
            pizza_ingredients VARCHAR(97),
            pizza_name VARCHAR(42)
        );
        """

        prompt = f"""### Task: Generate a SQL query to answer the following question.
### Database Schema:
{schema_info}
### Question: {natural_language_query}
### SQL Query:"""

        inputs = GLOBAL_TOKENIZER(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
            return_attention_mask=True
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        prompt_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = GLOBAL_MODEL.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                # Budget for *generated* tokens only. ``max_length`` would also
                # count the prompt, which already contains the whole schema and
                # would leave little or no room for the answer.
                max_new_tokens=256,
                temperature=0.1,
                do_sample=True,
                top_p=0.95,
                num_return_sequences=1,
                pad_token_id=GLOBAL_TOKENIZER.eos_token_id,
            )

        # Decode only what the model appended to the prompt, so the result
        # does not depend on the prompt round-tripping through the tokenizer.
        generated_tokens = outputs[0][prompt_length:]
        generated_query = GLOBAL_TOKENIZER.decode(generated_tokens, skip_special_tokens=True)
        # Keep the marker split as a safety net in case the model echoes it.
        sql_query = generated_query.split("### SQL Query:")[-1].strip()

        st.write(f"SQL generation took {time.time() - start_time:.2f} seconds")
        return sql_query

    except Exception as e:
        return f"Error generating SQL query: {str(e)}"
146
+
147
def format_result(query_result, max_rows=5):
    """Render a query result as human-readable text.

    Args:
        query_result: Either a list of row dictionaries or a message string
            (such as an error text); strings are passed through unchanged.
        max_rows: Maximum number of rows to spell out when the result has
            more than one row. Defaults to 5.

    Returns:
        ``"No results found."`` for an empty result, the string itself for a
        message, ``key: value`` lines for a single row, and otherwise a
        numbered listing of the first ``max_rows`` rows with a trailing note
        when rows were left out.
    """
    if not query_result:
        return "No results found."

    # Any remaining string is a message, not row data; returning it as-is
    # avoids treating its characters as rows.
    if isinstance(query_result, str):
        return query_result

    if len(query_result) == 1:
        return "\n".join(f"{k}: {v}" for k, v in query_result[0].items())

    total = len(query_result)
    lines = [f"Found {total} results:\n"]
    for index, row in enumerate(query_result[:max_rows], 1):
        lines.append(f"Result {index}:")
        lines.extend(f"{k}: {v}" for k, v in row.items())
        lines.append("")

    if total > max_rows:
        lines.append(f"(Showing first {max_rows} of {total} results)")

    return "\n".join(lines)
169
+
170
def main():
    """Render the Streamlit page and handle one script run.

    Flow: show the database connection status (stopping if it failed),
    initialize the model, read the user's question and, when the button is
    pressed, generate a SQL query, execute it and display both the raw and
    the formatted result.
    """
    st.title("Natural Language to SQL Query")
    st.write("Ask questions about pizza sales data in plain English.")

    # Test and display database connection status
    db_success, db_message = test_db_connection()
    st.write(db_message)

    if not db_success:
        st.write("Could not connect to the database. Exiting.")
        return

    # Initialize model
    initialize_model()

    # Input field for natural language query
    natural_language_query = st.text_input("Enter your question", placeholder="e.g., What were the total sales for each pizza category?")

    if not st.button("Generate and Execute Query"):
        return

    if not natural_language_query:
        st.write("Please enter a query.")
        return

    # Generate SQL query
    sql_query = generate_sql(natural_language_query)
    if sql_query.startswith("Error generating SQL query:"):
        # generate_sql reports failures as text; never send that to the database.
        st.write(sql_query)
        return
    st.write("Generated SQL Query:", sql_query)

    # Execute the generated query
    query_result = execute_query(sql_query)
    formatted_result = format_result(query_result)

    st.write("Query Result:")
    # Rows can hold Decimal and date values (see the sales schema), which the
    # json module cannot encode natively; fall back to their string form.
    st.code(json.dumps(query_result, indent=2, default=str))

    st.write("Human-Readable Response:")
    st.text(formatted_result)


if __name__ == "__main__":
    main()