import os
import tempfile
import time

import gradio as gr
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

from schema_extractor import SQLiteSchemaExtractor


# Load model and tokenizer
def load_model():
    config = PeftConfig.from_pretrained("Rajan/training_run")
    tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-350M")
    base_model = AutoModelForCausalLM.from_pretrained("NumbersStation/nsql-350M")
    model = PeftModel.from_pretrained(base_model, "Rajan/training_run")
    return model, tokenizer


# Extract and correct SQL from generated text
def extract_and_correct_sql(text, correct=False):
    lines = text.splitlines()
    start_index = 0
    for i, line in enumerate(lines):
        if line.strip().upper().startswith("SELECT"):
            start_index = i
            break
    generated_sql = "\n".join(lines[start_index:])
    if correct:
        if not generated_sql.strip().endswith(";"):
            generated_sql = generated_sql.strip() + ";"
    return generated_sql


# Function to handle file upload and schema extraction
def upload_and_extract_schema(db_file):
    if db_file is None:
        return "Please upload a database file", None
    try:
        # Get the file path directly from Gradio
        temp_db_path = db_file.name
        extractor = SQLiteSchemaExtractor(temp_db_path)
        schema = extractor.get_schema()
        return schema, temp_db_path
    except Exception as e:
        return f"Error extracting schema: {str(e)}", None


# Function to handle chat interaction (non-streaming variant; the UI below uses stream_sql)
def generate_sql(question, schema, db_path, chat_history):
    if db_path is None or not schema:
        return (
            chat_history
            + [
                {"role": "user", "content": question},
                {"role": "assistant", "content": "Please upload a database file first"},
            ],
            None,
        )
    try:
        # Load model
        model, tokenizer = load_model()

        # Format prompt
        prompt_format = """{}

-- Using valid SQLite, answer the following questions for the tables provided above.

{}

SELECT"""

        # Format the prompt with schema and question
        prompt = prompt_format.format(schema, question)

        # Generate SQL
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        generated_ids = model.generate(input_ids, max_length=500)
        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        # Extract SQL
        sql_query = extract_and_correct_sql(generated_text, correct=True)

        # Update history using dictionary format
        new_history = chat_history + [
            {"role": "user", "content": question},
            {"role": "assistant", "content": sql_query},
        ]
        return new_history, sql_query
    except Exception as e:
        error_message = f"Error: {str(e)}"
        return (
            chat_history
            + [
                {"role": "user", "content": question},
                {"role": "assistant", "content": error_message},
            ],
            None,
        )
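# --- Optional sketch: cache the model between requests (not wired in above) ------------
# Both generate_sql() and stream_sql() call load_model() on every request, which reloads
# the base model and PEFT adapter each time. A minimal caching wrapper, assuming
# load_model() as defined above, could be swapped in if per-request loading becomes a
# bottleneck; load_model_cached is a hypothetical helper name, not part of the original app.
from functools import lru_cache


@lru_cache(maxsize=1)
def load_model_cached():
    """Load the model and tokenizer once and reuse them for later requests."""
    return load_model()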
# Function for streaming SQL generation effect
def stream_sql(question, schema, db_path, chat_history):
    # First add the user message to chat
    yield chat_history + [{"role": "user", "content": question}], ""

    if db_path is None or not schema:
        yield chat_history + [
            {"role": "user", "content": question},
            {"role": "assistant", "content": "Please upload a database file first"},
        ], "Please upload a database file first"
        return

    try:
        # Load model
        model, tokenizer = load_model()

        # Format prompt
        prompt_format = """{}

-- Using valid SQLite, answer the following questions for the tables provided above.

{}

SELECT"""

        # Format the prompt with schema and question
        prompt = prompt_format.format(schema, question)

        # Generate SQL
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        generated_ids = model.generate(input_ids, max_length=500)
        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        # Extract SQL
        sql_query = extract_and_correct_sql(generated_text, correct=True)

        # Fixed medium speed (0.03 seconds delay)
        delay = 0.03  # 30 ms - normal typing speed

        # Stream the SQL query character by character for effect
        partial_sql = ""
        for char in sql_query:
            partial_sql += char
            # Update chat history and SQL display with each character
            yield chat_history + [
                {"role": "user", "content": question},
                {"role": "assistant", "content": partial_sql},
            ], partial_sql
            time.sleep(delay)  # Medium speed typing effect
    except Exception as e:
        error_message = f"Error: {str(e)}"
        yield chat_history + [
            {"role": "user", "content": question},
            {"role": "assistant", "content": error_message},
        ], error_message


# Main application
def create_app():
    with gr.Blocks(title="SQL Query Generator", theme=gr.themes.Soft()) as app:
        gr.Markdown("# SQL Query Generator")
        gr.Markdown(
            "Upload a SQLite database, ask questions, and get SQL queries automatically generated"
        )

        # Store database path
        db_path = gr.State(value=None)

        with gr.Row():
            with gr.Column(scale=1):
                # File upload section
                file_input = gr.File(label="Upload SQLite Database (.db file)")
                extract_btn = gr.Button("Extract Schema", variant="primary")

                # Schema display
                schema_output = gr.Textbox(
                    label="Database Schema", lines=10, interactive=False
                )

            with gr.Column(scale=2):
                # Chat interface
                chatbot = gr.Chatbot(
                    label="Query Conversation", height=400, type="messages"
                )

                with gr.Row():
                    question_input = gr.Textbox(
                        label="Ask a question about your data",
                        placeholder="e.g., Show me the top 10 most sold items",
                    )
                    submit_btn = gr.Button("Generate SQL", variant="primary")

                # SQL output display
                sql_output = gr.Code(
                    language="sql", label="Generated SQL Query", interactive=False
                )

        # Event handlers
        extract_btn.click(
            fn=upload_and_extract_schema,
            inputs=[file_input],
            outputs=[schema_output, db_path],
        )
        submit_btn.click(
            fn=stream_sql,
            inputs=[question_input, schema_output, db_path, chatbot],
            outputs=[chatbot, sql_output],
            api_name="generate",
            queue=True,
        )
        # Also trigger on enter key
        question_input.submit(
            fn=stream_sql,
            inputs=[question_input, schema_output, db_path, chatbot],
            outputs=[chatbot, sql_output],
            api_name="generate_on_submit",
            queue=True,
        )

    return app


# Launch the app
if __name__ == "__main__":
    app = create_app()
    app.launch(share=True)
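# --- Reference sketch: assumed shape of schema_extractor.SQLiteSchemaExtractor ---------
# The real implementation ships separately in schema_extractor.py; this commented-out
# sketch only illustrates the interface the app relies on (a constructor that takes a
# database path and a get_schema() method assumed to return the CREATE TABLE text the
# prompt expects). It stays commented out so it never shadows the imported class.
#
# import sqlite3
#
# class SQLiteSchemaExtractor:
#     def __init__(self, db_path):
#         self.db_path = db_path
#
#     def get_schema(self):
#         # Read the stored CREATE TABLE statements from sqlite_master.
#         with sqlite3.connect(self.db_path) as conn:
#             rows = conn.execute(
#                 "SELECT sql FROM sqlite_master WHERE type = 'table' AND sql IS NOT NULL"
#             ).fetchall()
#         return "\n\n".join(sql for (sql,) in rows)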