|
|
import os |
|
|
import pandas as pd |
|
|
import sqlite3 |
|
|
import numpy as np |
|
|
import json |
|
|
import re |
|
|
from typing import List, Dict, Tuple |
|
|
from groq import Groq |
|
|
import gradio as gr |
|
|
from sklearn.metrics import accuracy_score |
|
|
import warnings |
|
|
# Install a catch-all "ignore" filter: no category or module is given, so every
# warning raised after this point is suppressed for the whole process.
warnings.filterwarnings('ignore')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Groq API key, read from the environment. Surrounding whitespace is stripped so
# that a secret pasted with a trailing newline, or a blank value, is not mistaken
# for a usable key.
GROQ_API_KEY = (os.getenv("GROQ_API_KEY") or "").strip()

if not GROQ_API_KEY:
    print("⚠️ WARNING: GROQ_API_KEY environment variable not set!")
    print("Please add your Groq API key to your Hugging Face Space secrets.")
    print("For demo purposes, the app will continue but API calls will fail.")
    # Placeholder value; EnhancedNL2SQLConverter compares against this exact
    # string to decide whether a Groq client should be created.
    GROQ_API_KEY = "dummy-key-for-demo"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EnhancedNL2SQLConverter:
    """Turn natural-language questions into SQL text via a Groq chat model.

    A Groq client is created only when the module-level ``GROQ_API_KEY`` is set
    to something other than the ``"dummy-key-for-demo"`` placeholder. Otherwise
    ``client`` stays ``None`` and :meth:`generate_sql` returns an
    ``"ERROR: ..."`` string without calling the API.

    Attributes:
        model_name: Model identifier sent with every completion request.
        client: ``Groq`` client instance, or ``None`` when unavailable.
        default_schema: Schema description used when the caller supplies none.
    """

    # How a statement may begin: a CTE header ("WITH name AS") or one of the
    # statement keywords. Matched as whole words so that identifiers such as
    # "created_at" are not mistaken for a statement start.
    _STATEMENT_OPENER = (
        r'(?:WITH\s+(?:RECURSIVE\s+)?\w+\s*(?:\([^)]*\))?\s*AS'
        r'|SELECT|INSERT|UPDATE|DELETE|CREATE|DROP|ALTER)\b'
    )
    _OPENER_AT_LINE_START = re.compile(
        r'^[ \t]*' + _STATEMENT_OPENER, re.IGNORECASE | re.MULTILINE
    )
    _OPENER_ANYWHERE = re.compile(r'\b' + _STATEMENT_OPENER, re.IGNORECASE)
    # A complete markdown code fence, optionally tagged "sql".
    _FENCED_BLOCK = re.compile(r'```(?:sql\b)?(.*?)```', re.IGNORECASE | re.DOTALL)
    # A lone fence marker (e.g. when the closing fence is missing).
    _FENCE_MARKER = re.compile(r'```(?:sql\b)?', re.IGNORECASE)
    # Characters removed from the end of a statement.
    _TRAILING_JUNK = '; \t\r\n'

    def __init__(self, model_name: str = "llama3-70b-8192"):
        """Store the model name and try to create the Groq client.

        Never raises: any error while creating the client is printed and leaves
        ``self.client`` as ``None``.

        Args:
            model_name: Model identifier passed to the chat-completions call.
        """
        self.model_name = model_name
        self.client = None

        try:
            if GROQ_API_KEY and GROQ_API_KEY != "dummy-key-for-demo":
                self.client = Groq(api_key=GROQ_API_KEY)
                print(f"✅ Successfully initialized Groq client with model: {self.model_name}")
            else:
                print("⚠️ Groq client not initialized - API key missing")
        except Exception as e:
            print(f"❌ Error initializing Groq client: {str(e)}")
            self.client = None

        self.default_schema = """
        Table: employees
        Columns:
        - id (INTEGER) PRIMARY KEY
        - name (TEXT) NOT NULL
        - department (TEXT)
        - salary (REAL)
        - hire_date (TEXT)
        - manager_id (INTEGER)
        """

    def generate_sql(self, query: str, schema: str = None) -> str:
        """Ask the model for a SQL statement answering ``query``.

        Args:
            query: The natural-language question.
            schema: Schema description to show the model; when omitted or
                empty, ``self.default_schema`` is used.

        Returns:
            The cleaned SQL text, or a string starting with ``"ERROR"`` when the
            client is missing, the model returns nothing, or the call fails.
            Errors are reported through the return value, never raised.
        """
        try:
            if not self.client:
                return "ERROR: Groq API client not initialized. Please check your API key."

            schema_to_use = schema or self.default_schema

            system_prompt = (
                "You are an expert SQL query generator. Convert natural language "
                "questions to SQL queries based on the provided database schema.\n"
                "\n"
                "Rules:\n"
                "1. Only return the SQL query, nothing else\n"
                "2. Use proper SQL syntax\n"
                "3. Be precise with column names and table names\n"
                "4. Use appropriate WHERE clauses, JOINs, and aggregations as needed\n"
                "5. For date comparisons, use proper date format\n"
                "6. Don't include explanations, just the SQL query"
            )

            user_prompt = (
                "Database Schema:\n"
                f"{schema_to_use}\n"
                "\n"
                f"Natural Language Question: {query}\n"
                "\n"
                "Generate the SQL query:"
            )

            chat_completion = self.client.chat.completions.create(
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                model=self.model_name,
                temperature=0.1,
                max_tokens=200,
            )

            # The message content may be missing; treat that like an empty reply
            # instead of letting ``.strip()`` fail with an opaque AttributeError.
            raw_sql = (chat_completion.choices[0].message.content or "").strip()
            cleaned_sql = self._clean_sql(raw_sql)
            if not cleaned_sql:
                return "ERROR: Could not generate SQL query - empty response from model"
            return cleaned_sql

        except Exception as e:
            print(f"Error generating SQL: {str(e)}")
            return f"ERROR: Could not generate SQL query - {str(e)}"

    def _clean_sql(self, sql: str) -> str:
        """Reduce a raw model reply to a bare SQL statement.

        Steps, in order:
          1. If the reply contains a complete markdown code fence, keep only the
             fenced text; otherwise drop any stray fence markers.
          2. Strip surrounding whitespace and trailing semicolons.
          3. Remove one pair of matching quotes/backticks wrapping the whole
             reply. Only a *wrapping* pair is removed, so a statement that ends
             in a string literal keeps its closing quote.
          4. Drop any preamble before the statement: prefer a line that begins
             with a statement opener, else the earliest opener anywhere.

        Args:
            sql: Raw text returned by the model.

        Returns:
            The cleaned statement; unchanged (apart from trimming) when no
            statement opener is found.
        """
        fenced = self._FENCED_BLOCK.search(sql)
        if fenced:
            sql = fenced.group(1)
        else:
            sql = self._FENCE_MARKER.sub('', sql)

        sql = sql.strip().rstrip(self._TRAILING_JUNK)

        if len(sql) >= 2 and sql[0] in '"\'`' and sql[-1] == sql[0]:
            sql = sql[1:-1].strip().rstrip(self._TRAILING_JUNK)

        opener = (
            self._OPENER_AT_LINE_START.search(sql)
            or self._OPENER_ANYWHERE.search(sql)
        )
        if opener:
            sql = sql[opener.start():].lstrip()

        return sql
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SQLEvaluator:
    """Run SQL statements against a small SQLite database of sample employees.

    The database file is created (or refreshed) on construction. Statements
    passed to :meth:`execute_sql` are executed exactly as given, including
    data-modifying ones (INSERT/UPDATE/DELETE/DROP); no validation or
    read-only restriction is applied here.

    Attributes:
        db_path: Path of the SQLite database file.
    """

    def __init__(self, db_path: str = "test_database.db"):
        """Remember the database path and populate the sample table.

        Args:
            db_path: SQLite file to use; defaults to ``"test_database.db"``.
        """
        self.db_path = db_path
        self.setup_test_database()

    def setup_test_database(self):
        """Create the ``employees`` table if needed and upsert eight sample rows.

        Rows are written with ``INSERT OR REPLACE``, so calling this again
        resets ids 1-8 to their sample values without duplicating them.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()

            cursor.execute('''
                CREATE TABLE IF NOT EXISTS employees (
                    id INTEGER PRIMARY KEY,
                    name TEXT NOT NULL,
                    department TEXT,
                    salary REAL,
                    hire_date TEXT,
                    manager_id INTEGER
                )''')

            sample_data = [
                (1, 'Alice Johnson', 'Engineering', 75000, '2022-01-15', None),
                (2, 'Bob Smith', 'Sales', 65000, '2021-06-20', None),
                (3, 'Charlie Brown', 'Engineering', 80000, '2020-03-10', 1),
                (4, 'Diana Prince', 'HR', 60000, '2023-02-28', None),
                (5, 'Eve Wilson', 'Sales', 70000, '2022-11-05', 2),
                (6, 'Frank Miller', 'Engineering', 85000, '2019-08-12', 1),
                (7, 'Grace Lee', 'Marketing', 55000, '2023-01-20', None),
                (8, 'Henry Davis', 'Engineering', 72000, '2022-07-30', 1),
            ]

            cursor.executemany('''
                INSERT OR REPLACE INTO employees (id, name, department, salary, hire_date, manager_id)
                VALUES (?, ?, ?, ?, ?, ?)''', sample_data)

            conn.commit()
        finally:
            # Close even when table creation or the inserts fail.
            conn.close()
        print("✅ Test database initialized successfully")

    def execute_sql(self, sql_query: str) -> Tuple[bool, object]:
        """Execute one SQL statement and report the outcome.

        Args:
            sql_query: The statement to run.

        Returns:
            ``(True, {'columns': [...], 'data': [...]})`` when the statement
            produces a result set, ``(True, "Query executed successfully")``
            when it does not, and ``(False, message)`` on any error. Errors are
            never raised to the caller.
        """
        if not sql_query or not sql_query.strip():
            return False, "Empty SQL query"

        conn = None
        try:
            conn = sqlite3.connect(self.db_path)
            cursor = conn.cursor()
            cursor.execute(sql_query)

            # ``description`` is set for any statement that yields rows, which
            # also covers CTEs ("WITH ... SELECT") that do not start with SELECT.
            if cursor.description is not None:
                columns = [description[0] for description in cursor.description]
                results = cursor.fetchall()
                # No-op for plain reads; persists row-returning writes.
                conn.commit()
                return True, {'columns': columns, 'data': results}

            conn.commit()
            return True, "Query executed successfully"
        except Exception as e:
            return False, str(e)
        finally:
            # Release the connection on every path, including failures.
            if conn is not None:
                conn.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Build the two application components independently so a failure in one does
# not discard the other, and so the fallback path cannot raise a second time.
try:
    converter = EnhancedNL2SQLConverter()
except Exception as e:
    print(f"❌ Error initializing components: {str(e)}")
    converter = None

try:
    evaluator = SQLEvaluator()
except Exception as e:
    print(f"❌ Error initializing components: {str(e)}")
    evaluator = None

if converter is not None and evaluator is not None:
    print("✅ Application components initialized successfully")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def process_nl_query(nl_query: str) -> Tuple[str, str, str]:
    """Generate SQL for a natural-language question, run it, and format the output.

    Uses the module-level ``converter`` to produce the SQL and the module-level
    ``evaluator`` to execute it.

    Args:
        nl_query: The user's question. ``None``, empty, or whitespace-only input
            is rejected with a prompt to enter a query.

    Returns:
        A ``(generated_sql, results_text, status_message)`` tuple. Failures are
        reported through ``status_message``; nothing is raised.
    """
    # Guard against None as well as blank text so this check cannot raise.
    if not nl_query or not nl_query.strip():
        return "", "", "Please enter a natural language query."

    try:
        if not converter:
            return "", "", "❌ Error: SQL converter not initialized. Please check API configuration."

        generated_sql = converter.generate_sql(nl_query)

        # generate_sql signals failure with an "ERROR..." string; surface it in
        # the SQL box so the user can see the underlying reason.
        if generated_sql.startswith("ERROR"):
            return generated_sql, "", "❌ Failed to generate SQL query. Please check your API key."

        success, result = evaluator.execute_sql(generated_sql)

        if not success:
            return generated_sql, "", f"❌ Error executing query: {result}"

        if isinstance(result, dict):
            # Row-returning statement: render the rows as a plain-text table.
            df = pd.DataFrame(result['data'], columns=result['columns'])
            if df.empty:
                formatted_output = "No results found."
            else:
                formatted_output = df.to_string(index=False)
            return generated_sql, formatted_output, "✅ Query executed successfully!"

        # Statement without a result set: show the evaluator's message.
        return generated_sql, str(result), "✅ Query executed successfully!"

    except Exception as e:
        return "", "", f"❌ Unexpected error: {str(e)}"
|
|
|
|
|
def get_sample_queries():
    """Provide the example questions offered to users as a fresh list."""
    examples = (
        "Show all employees in the Engineering department",
        "Find employees with salary greater than 70000",
        "List all employees hired after 2022",
        "Count employees by department",
        "Show the highest paid employee in each department",
        "Find employees who don't have a manager",
        "Show average salary by department",
    )
    # Hand back a new list on every call so callers may mutate it freely.
    return list(examples)
|
|
|
|
|
def load_sample_query(query):
    """Echo the chosen sample query back unchanged, for use as the input value."""
    selected = query
    return selected
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Custom stylesheet passed to gr.Blocks(css=...) below: caps the page width and
# defines spacing for a "sample-queries" class.
# NOTE(review): no component in this file sets the "sample-queries" class, so
# that rule appears to be unused - confirm before relying on it.
css = """
.gradio-container {
    max-width: 1200px !important;
}
.sample-queries {
    margin: 10px 0;
}
"""
|
|
|
|
|
# Gradio UI: question input and sample-query buttons on top, generated SQL,
# status, and query results below.
with gr.Blocks(css=css, title="NL2SQL with Groq AI", theme=gr.themes.Soft()) as iface:
    gr.Markdown("""
    # 🚀 Natural Language to SQL Converter

    Convert your natural language questions into SQL queries using **Groq AI** and execute them on a sample employee database!

    ### Sample Database Schema:
    **employees** table with columns: `id`, `name`, `department`, `salary`, `hire_date`, `manager_id`
    """)

    with gr.Row():
        with gr.Column(scale=2):
            nl_input = gr.Textbox(
                label="💬 Enter Your Question",
                placeholder="e.g., Show all employees in Engineering department",
                lines=2,
            )

            submit_btn = gr.Button("🚀 Generate & Execute SQL", variant="primary")

        with gr.Column(scale=1):
            gr.Markdown("### 📝 Try These Sample Queries:")
            sample_queries = get_sample_queries()

            # One button per sample; clicking it copies the text into the input
            # box. The default argument binds each button to its own query
            # (a plain closure would see only the last loop value).
            for query in sample_queries:
                gr.Button(
                    query,
                    variant="secondary",
                    size="sm",
                ).click(
                    lambda q=query: q,
                    outputs=nl_input,
                )

    with gr.Row():
        with gr.Column():
            sql_output = gr.Textbox(
                label="🔧 Generated SQL Query",
                lines=3,
                interactive=False,
            )

            status_output = gr.Textbox(
                label="📊 Status",
                lines=1,
                interactive=False,
            )

            results_output = gr.Textbox(
                label="📋 Query Results",
                lines=10,
                interactive=False,
            )

    # Clicking the button and pressing Enter in the input box run the same
    # handler with the same wiring.
    for trigger in (submit_btn.click, nl_input.submit):
        trigger(
            fn=process_nl_query,
            inputs=[nl_input],
            outputs=[sql_output, results_output, status_output],
        )

    gr.Markdown("""
    ### 🔍 About This App:
    - **AI Model**: Groq's Llama3-70B for SQL generation
    - **Database**: SQLite with sample employee data
    - **Features**: Natural language processing, SQL execution, formatted results

    ### 💡 Tips:
    - Be specific in your questions
    - Use clear, simple language
    - Try the sample queries to get started
    """)
|
|
|
|
|
|
|
|
# Start the Gradio server only when this file is run as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    iface.launch()