import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from functools import lru_cache
import json
import mysql.connector
from mysql.connector import Error
import os
import sys
from datetime import datetime
import time
import logging
import threading
# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Enable GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Database configuration
DB_CONFIG = {
    'host': 'sql12.freemysqlhosting.net',
    'database': 'sql12740625',
    'user': 'sql12740625',
    'password': 'QGG9kdrE4g',
    'port': 3306,
    'pool_size': 5,
    'pool_reset_session': True
}
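
# Note: credentials are hardcoded above, as in the original configuration.
# A minimal sketch of reading them from the environment instead (the variable
# names DB_HOST/DB_NAME/DB_USER/DB_PASSWORD/DB_PORT are illustrative, not part
# of this app) would look like:
#
#   DB_CONFIG = {
#       'host': os.environ.get('DB_HOST', 'sql12.freemysqlhosting.net'),
#       'database': os.environ.get('DB_NAME', 'sql12740625'),
#       'user': os.environ.get('DB_USER', 'sql12740625'),
#       'password': os.environ.get('DB_PASSWORD', ''),
#       'port': int(os.environ.get('DB_PORT', '3306')),
#       'pool_size': 5,
#       'pool_reset_session': True,
#   }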
# Global variables for model and tokenizer
GLOBAL_MODEL = None
GLOBAL_TOKENIZER = None
db_connection_status = False
def initialize_model():
    """Initialize model and tokenizer globally"""
    global GLOBAL_MODEL, GLOBAL_TOKENIZER
    logging.info("Initializing model and tokenizer...")
    st.write("Initializing model and tokenizer...")
    start_time = time.time()
    model_name_sql = "premai-io/prem-1B-SQL"
    GLOBAL_TOKENIZER = AutoTokenizer.from_pretrained(model_name_sql)
    GLOBAL_MODEL = AutoModelForCausalLM.from_pretrained(
        model_name_sql,
        torch_dtype=torch.float32,  # Use float32 for CPU
    ).to(device)
    # Set model to evaluation mode
    GLOBAL_MODEL.eval()
    logging.info(f"Model initialization took {time.time() - start_time:.2f} seconds")
def test_db_connection():
    """Test database connection with timeout"""
    global db_connection_status
    try:
        logging.info("Testing database connection...")
        connection = mysql.connector.connect(
            **DB_CONFIG,
            connect_timeout=10
        )
        if connection.is_connected():
            db_info = connection.get_server_info()
            cursor = connection.cursor()
            cursor.execute("SELECT DATABASE();")
            db_name = cursor.fetchone()[0]
            cursor.close()
            connection.close()
            db_connection_status = True
            logging.info(f"Successfully connected to MySQL Server version {db_info} - Database: {db_name}")
            return True, f"Successfully connected to MySQL Server version {db_info}\nDatabase: {db_name}"
    except Error as e:
        db_connection_status = False
        logging.error(f"Error connecting to MySQL database: {e}")
        return False, f"Error connecting to MySQL database: {e}"
    return False, "Unable to establish database connection"
def get_db_connection():
    """Get database connection from pool"""
    logging.info("Getting database connection from pool...")
    return mysql.connector.connect(**DB_CONFIG)
def execute_query(query):
    """Execute SQL query with timeout and connection pooling"""
    logging.info(f"Executing query: {query}")
    connection = None
    cursor = None
    try:
        connection = get_db_connection()
        cursor = connection.cursor(dictionary=True, buffered=True)
        cursor.execute(query)
        results = cursor.fetchall()
        logging.info(f"Query executed successfully, retrieved {len(results)} records.")
        return results
    except Error as e:
        logging.error(f"Error executing query: {e}")
        return f"Error executing query: {e}"
    finally:
        if connection and connection.is_connected():
            if cursor:
                cursor.close()
            connection.close()
            logging.info("Database connection closed.")
def generate_sql(natural_language_query):
    """Generate SQL query with performance optimizations"""
    logging.info(f"Generating SQL for query: {natural_language_query}")
    try:
        start_time = time.time()
        schema_info = """
        CREATE TABLE sales (
            pizza_id DECIMAL(8,2) PRIMARY KEY,
            order_id DECIMAL(8,2),
            pizza_name_id VARCHAR(14),
            quantity DECIMAL(4,2),
            order_date DATE,
            order_time VARCHAR(8),
            unit_price DECIMAL(5,2),
            total_price DECIMAL(5,2),
            pizza_size VARCHAR(3),
            pizza_category VARCHAR(7),
            pizza_ingredients VARCHAR(97),
            pizza_name VARCHAR(42)
        );
        """
        prompt = f"""### Task: Generate a SQL query to answer the following question.
### Database Schema:
{schema_info}
### Question: {natural_language_query}
### SQL Query:"""
        inputs = GLOBAL_TOKENIZER(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
            return_attention_mask=True
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = GLOBAL_MODEL.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=256,  # cap newly generated tokens; max_length=256 would be shorter than the prompt itself
                temperature=0.1,
                do_sample=True,
                top_p=0.95,
                num_return_sequences=1,
                pad_token_id=GLOBAL_TOKENIZER.eos_token_id,
            )
        generated_query = GLOBAL_TOKENIZER.decode(outputs[0], skip_special_tokens=True)
        sql_query = generated_query.split("### SQL Query:")[-1].strip()
        logging.info(f"SQL generation took {time.time() - start_time:.2f} seconds")
        return sql_query
    except Exception as e:
        logging.error(f"Error generating SQL query: {str(e)}")
        return f"Error generating SQL query: {str(e)}"
def format_result(query_result):
    """Format query results efficiently"""
    if isinstance(query_result, str) and "Error" in query_result:
        logging.warning(f"Query result contains an error: {query_result}")
        return query_result
    if not query_result:
        logging.info("No results found.")
        return "No results found."
    # Single-row results are printed as "key: value" lines
    if len(query_result) == 1:
        return "\n".join(f"{k}: {v}" for k, v in query_result[0].items())
    # Multi-row results: show at most the first 5 rows
    results = [f"Found {len(query_result)} results:\n"]
    for i, row in enumerate(query_result[:5], 1):
        results.append(f"Result {i}:")
        results.extend(f"{k}: {v}" for k, v in row.items())
        results.append("")
    if len(query_result) > 5:
        results.append(f"(Showing first 5 of {len(query_result)} results)")
    return "\n".join(results)
def check_live_connection():
    """Check the database connection status periodically."""
    while True:
        test_db_connection()
        time.sleep(10)  # Check every 10 seconds
def main():
    """Main function with Streamlit UI components"""
    st.title("Natural Language to SQL Query")
    st.write("Ask questions about pizza sales data in plain English.")
    # Start the live connection check in a separate thread, only once per session
    # so Streamlit reruns do not spawn duplicate threads
    if not st.session_state.get("connection_checker_started"):
        threading.Thread(target=check_live_connection, daemon=True).start()
        st.session_state["connection_checker_started"] = True
    # Test and display database connection status (checked directly so the first
    # render does not race the background thread)
    connected, _ = test_db_connection()
    if connected:
        st.success("Database connection is live.")
    else:
        st.error("Database connection is not live.")
    # Initialize model
    initialize_model()
    # Input field for natural language query
    natural_language_query = st.text_input("Enter your question", placeholder="e.g., What were the total sales for each pizza category?")
    if st.button("Generate and Execute Query"):
        if natural_language_query:
            # Generate SQL query
            sql_query = generate_sql(natural_language_query)
            st.write("Generated SQL Query:", sql_query)
            # Execute the generated query
            query_result = execute_query(sql_query)
            formatted_result = format_result(query_result)
            st.write("Query Result:")
            # default=str handles Decimal/date values returned by MySQL, which
            # json.dumps cannot serialize on its own
            st.code(json.dumps(query_result, indent=2, default=str))
            st.write("Human-Readable Response:")
            st.text(formatted_result)
        else:
            logging.warning("User did not enter a query.")
            st.write("Please enter a query.")


if __name__ == "__main__":
    main()
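
# To run locally (assuming this file is saved as app.py, the usual entry point
# for a Streamlit app / Hugging Face Space):
#
#   streamlit run app.py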