import gradio as gr
import mysql.connector
from mysql.connector import Error
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face Hub checkpoint used for text-to-SQL generation.
model_name = "premai-io/prem-1B-SQL"
# Tokenizer and weights come from the same checkpoint so they agree on the
# vocabulary. Both are loaded once at import time (tokenizer first); from_pretrained
# may download the files on the first run.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForCausalLM.from_pretrained(model_name),
)

def generate_sql(natural_language_query):
    """Generate a SQL query for the ``sales`` table from a natural-language question.

    The question is placed in a fixed prompt together with the ``sales`` table
    DDL. That prompt is run through the module-level ``tokenizer`` / ``model``.

    Args:
        natural_language_query: The user's question, e.g. the text typed into
            the Gradio textbox. ``None`` or whitespace-only input is accepted.

    Returns:
        The generated SQL text with surrounding whitespace removed, or ``""``
        when the question is empty. The query is neither executed nor validated.
    """
    # An empty textbox arrives as "" - don't spend a model call on it.
    if not (natural_language_query or "").strip():
        return ""

    # DDL of the target table, shown to the model as context.
    # Keep it in sync with the real `sales` table.
    schema_info = """
    CREATE TABLE sales (
      pizza_id DECIMAL(8,2) PRIMARY KEY,
      order_id DECIMAL(8,2),
      pizza_name_id VARCHAR(14),
      quantity DECIMAL(4,2),
      order_date DATE,
      order_time VARCHAR(8),
      unit_price DECIMAL(5,2),
      total_price DECIMAL(5,2),
      pizza_size VARCHAR(3),
      pizza_category VARCHAR(7),
      pizza_ingredients VARCHAR(97),
      pizza_name VARCHAR(42)
    );
    """

    # Construct the prompt; the model is expected to continue after the final marker.
    prompt = f"""### Task: Generate a SQL query to answer the following question.

### Database Schema:
{schema_info}

### Question: {natural_language_query}

### SQL Query:"""

    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

    # Some tokenizers define no pad token; fall back to EOS so generate() always has one.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        # Budget for the answer only. The former max_length=512 also counted the
        # prompt tokens, so long questions left little or no room for the SQL.
        max_new_tokens=256,
        temperature=0.1,
        do_sample=True,
        top_p=0.95,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=pad_token_id,
    )

    # For a causal LM, generate() returns prompt + continuation. Decode only the
    # continuation, so prompt text (including a question that itself contains
    # "### SQL Query:") cannot leak into the result.
    prompt_length = inputs["input_ids"].shape[-1]
    completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    # The model may keep writing further "### ..." sections; keep only the SQL before them.
    sql_query = completion.split("###", 1)[0].strip()
    return sql_query

def main():
    """Serve ``generate_sql`` through a simple text-in / text-out Gradio web UI."""
    # Static labels displayed on the page.
    page_text = {
        "title": "Natural Language to SQL Query Generator",
        "description": "Enter a natural language query to generate the corresponding SQL query.",
    }
    app = gr.Interface(fn=generate_sql, inputs="text", outputs="text", **page_text)
    # When run as a plain script this blocks and serves the UI until interrupted.
    app.launch()

# Start the web UI only when executed directly, not when imported.
# Importing still loads the model, because that happens at module level above.
if __name__ == "__main__":
    main()