# text2sql / app.py
# (Hugging Face Spaces page residue removed: "Balaprime's picture",
#  "Update app.py", commit "cfb9d4e verified" — kept here as a comment
#  so the file is valid Python.)
from dotenv import load_dotenv
import os
from sentence_transformers import SentenceTransformer
import gradio as gr
from sklearn.metrics.pairwise import cosine_similarity
from groq import Groq
import sqlite3
import pandas as pd
load_dotenv()
api = os.getenv("groq_api_key")
# 🔹 STEP 1: Create a sample in-memory SQLite database with mock data
def setup_database():
conn = sqlite3.connect("college.db")
cursor = conn.cursor()
# Drop existing tables
cursor.execute("DROP TABLE IF EXISTS student;")
cursor.execute("DROP TABLE IF EXISTS employee;")
cursor.execute("DROP TABLE IF EXISTS course_info;")
# Student table
cursor.execute("""
CREATE TABLE student (
student_id INTEGER,
first_name TEXT,
last_name TEXT,
date_of_birth TEXT,
email TEXT,
phone_number TEXT,
major TEXT,
year_of_enrollment INTEGER
);
""")
cursor.execute("INSERT INTO student VALUES (1, 'Alice', 'Smith', '2000-05-01', 'alice@example.com', '1234567890', 'Computer Science', 2019);")
# Employee table
cursor.execute("""
CREATE TABLE employee (
employee_id INTEGER,
first_name TEXT,
last_name TEXT,
email TEXT,
department TEXT,
position TEXT,
salary REAL,
date_of_joining TEXT
);
""")
cursor.execute("INSERT INTO employee VALUES (101, 'John', 'Doe', 'john@college.edu', 'CSE', 'Professor', 80000, '2015-08-20');")
# Course table
cursor.execute("""
CREATE TABLE course_info (
course_id INTEGER,
course_name TEXT,
course_code TEXT,
instructor_id INTEGER,
department TEXT,
credits INTEGER,
semester TEXT
);
""")
cursor.execute("INSERT INTO course_info VALUES (501, 'AI Basics', 'CS501', 101, 'CSE', 4, 'Fall');")
conn.commit()
conn.close()
# Call it once to setup
setup_database()
# 🔹 STEP 2: Embedding & LLM logic (unchanged mostly)
def create_metadata_embeddings():
student = """Table: student...""" # (same as your original metadata)
employee = """Table: employee..."""
course = """Table: course_info..."""
metadata_list = [student, employee, course]
model = SentenceTransformer('all-MiniLM-L6-v2')
embeddings = model.encode(metadata_list)
return embeddings, model, student, employee, course
def find_best_fit(embeddings, model, user_query, student, employee, course):
query_embedding = model.encode([user_query])
similarities = cosine_similarity(query_embedding, embeddings)
best_match_table = similarities.argmax()
return [student, employee, course][best_match_table]
def create_prompt(user_query, table_metadata):
system_prompt = """You are a SQL query generator specialized in generating SQL queries for a single table at a time. Your task is to accurately convert natural language queries into SQL statements based on the user's intent and the provided table metadata.
Rules:
- Multi-Table Queries Allowed: You can generate queries involving multiple tables using appropriate SQL JOIN operations, based on the provided metadata.
- Join Logic: Use INNER JOIN, LEFT JOIN, or other appropriate joins based on logical relationships (e.g., foreign keys like `student_id`, `instructor_id`, etc.) inferred from the metadata.
- Metadata-Based Validation: Always ensure the generated query matches the table names, columns, and data types as described in the metadata.
- User Intent: Accurately capture the user's requirements such as filters, sorting, aggregations, and selections across one or more tables.
- SQL Syntax: Use standard SQL syntax that is compatible with most relational database systems.
- Output Format: Provide only the SQL query in a single line. Do not include explanations or any extra text.
Input Format:
User Query: The user's natural language request.
Table Metadata: The structure of the relevant table, including the table name, column names, and data types.
Output Format:
SQL Query: A valid SQL query formatted for readability.
Do not output anything else except the SQL query.Not even a single word extra.Ouput the whole query in a single line only.
You are ready to generate SQL queries based on the user input and table metadata."""
user_prompt = f"User Query: {user_query}\nTable Metadata: {table_metadata}"
return system_prompt, user_prompt
def generate_sql(system_prompt, user_prompt):
client = Groq(api_key=api)
chat_completion = client.chat.completions.create(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
model="llama3-70b-8192",
)
res = chat_completion.choices[0].message.content.strip()
if res.lower().startswith("select"):
return res
else:
return None
# 🔹 STEP 3: Execute SQL and return results
def execute_sql(sql_query):
try:
conn = sqlite3.connect("college.db")
df = pd.read_sql_query(sql_query, conn)
conn.close()
return df
except Exception as e:
return str(e)
# 🔹 STEP 4: Final combined response
def response(user_query):
embeddings, model, student, employee, course = create_metadata_embeddings()
table_metadata = find_best_fit(embeddings, model, user_query, student, employee, course)
system_prompt, user_prompt = create_prompt(user_query, table_metadata)
sql_query = generate_output(system_prompt, user_prompt)
# Try running the query against the SQLite database
try:
conn = sqlite3.connect("college.db") # Make sure college.db is present in your repo
cursor = conn.cursor()
cursor.execute(sql_query)
result = cursor.fetchall()
conn.close()
return f"SQL Query:\n{sql_query}\n\nQuery Result:\n{result}"
except Exception as e:
return f"SQL Query:\n{sql_query}\n\nQuery Result:\nError: {str(e)}"
# 🔹 Gradio UI
desc = """Ask a natural language question about students, employees, or courses. I'll generate and run a SQL query for you."""
demo = gr.Interface(
fn=response,
inputs=gr.Textbox(label="Your Question"),
outputs=gr.Textbox(label="SQL + Result"),
title="Natural Language to SQL + Result",
description="Ask a natural language question about students, employees, or courses. I'll generate and run a SQL query for you."
)
demo.launch(share=True)