# jpatel
# adding sql generator app
# 22654ec
import gradio as gr
import torch
import re
from transformers import AutoTokenizer, AutoModelForCausalLM
# Hub repository id handed to both from_pretrained() calls below.
MODEL = "jinesh90/qwen2.5-coder-sql-generator"

# The tokenizer and model are loaded once, at import time; the two prints
# bracket the load so progress is visible on stdout.
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL)
# nn.Module.eval() returns the module itself, so chaining it leaves `model`
# bound to the loaded network, already switched to evaluation mode.
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    low_cpu_mem_usage=True,
    device_map="auto",
    torch_dtype=torch.float16,
).eval()
print("Ready!")
def clean_sql(text):
    """Trim raw model output down to the leading SQL statement.

    The following heuristics are applied in order:

    1. Surrounding whitespace is stripped.
    2. On every line, everything from the first non-ASCII character to the
       end of that line is removed.
    3. The text is cut at the first ``###`` (the section marker used by the
       prompt).
    4. The text is cut at the first line that consists solely of the word
       ``assistant`` (a leaked chat-role marker).
    5. The text is cut at the first blank line (two consecutive newlines).

    Args:
        text: Decoded text produced by the model.

    Returns:
        The cleaned string, stripped of surrounding whitespace. May be empty.
    """
    # `.` does not match a newline, so this truncates each line separately
    # rather than discarding the whole remainder of the text.
    clean = re.sub(r'[^\x00-\x7F].*', '', text.strip()).strip()

    # str.split()[0] is the whole string when the marker is absent, so no
    # membership test is needed first.
    clean = clean.split("###")[0].strip()

    # Only a stand-alone "assistant" line is treated as a role marker.
    # Splitting on the bare substring would also cut legitimate SQL such as
    # "SELECT * FROM teaching_assistants".
    clean = re.split(
        r'^[ \t]*assistant[ \t]*$', clean, maxsplit=1, flags=re.MULTILINE
    )[0].strip()

    clean = clean.split("\n\n")[0].strip()
    return clean
def build_prompt(question, schema):
    """Assemble the instruction prompt sent to the model as the user turn.

    Args:
        question: Natural-language question to be answered with SQL.
        schema: Database schema text (e.g. CREATE TABLE statements).

    Returns:
        A newline-joined prompt: two instruction lines, then the schema and
        the question under their ``###`` headings, ending with an open
        ``### SQL:`` heading for the model to complete.
    """
    sections = (
        "You are a SQL expert. Generate the simplest and most direct SQL query.",
        "Use JOINs only when multiple tables are needed.",
        "### Schema:",
        f"{schema}",
        "### Question:",
        f"{question}",
        "### SQL:",
    )
    return "\n".join(sections)
def generate(question, schema):
    """Generate a SQL query for ``question`` against ``schema``.

    The prompt from ``build_prompt`` is wrapped in the tokenizer's chat
    template, decoded greedily by the module-level ``model``, and the newly
    generated text is post-processed with ``clean_sql``.

    Args:
        question: Natural-language question.
        schema: Database schema text (e.g. CREATE TABLE statements).

    Returns:
        The cleaned SQL string, or a fixed help message when either input is
        missing, empty or whitespace-only.
    """
    # Reject None, empty and whitespace-only input alike, so a blank textbox
    # never reaches the model.
    if not (question and question.strip()) or not (schema and schema.strip()):
        return "Please provide both a question and schema!"

    messages = [{"role": "user", "content": build_prompt(question, schema)}]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # NOTE(review): the prompt is truncated to 1024 tokens. If the tokenizer
    # truncates on the right (presumably the default - confirm), a very long
    # schema would cut off the question and the generation prompt.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=1024,
    ).to(model.device)

    # Stop on either end-of-sequence id. Duplicates (when both names map to
    # the same id) and None (an id the tokenizer cannot resolve) are dropped.
    candidate_ids = (
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|im_end|>"),
    )
    stop_tokens = list(
        dict.fromkeys(tid for tid in candidate_ids if tid is not None)
    )

    with torch.no_grad():
        # Greedy decoding: `temperature` is only consulted when sampling, so
        # it is not passed here (passing 0 merely triggered a warning).
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=False,
            repetition_penalty=1.3,
            eos_token_id=stop_tokens,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the tokens produced after the prompt.
    input_len = inputs["input_ids"].shape[1]
    raw = tokenizer.decode(outputs[0, input_len:], skip_special_tokens=True)
    return clean_sql(raw)
# Sample single-table schema: pre-filled into the schema textbox and reused
# by every bundled example.
example_schema = "\n".join(
    (
        "CREATE TABLE employees (",
        "id INTEGER,",
        "name VARCHAR,",
        "salary REAL,",
        "department VARCHAR,",
        "age INTEGER",
        ");",
    )
)
# Gradio UI: schema and question inputs on one side, generated SQL and model
# statistics on the other, plus clickable examples.
#
# Emoji are written as \N{...} escapes so the source stays pure ASCII; the
# previous literals had been corrupted by an encoding mismatch and rendered
# as garbage characters in the page.
#
# NOTE(review): the layout nesting below (which widgets sit in which column)
# was reconstructed from an unindented source - confirm against the deployed
# app.
with gr.Blocks(title="SQL Query Generator") as demo:
    gr.Markdown("# \N{FILE CABINET}\N{VARIATION SELECTOR-16} SQL Query Generator")
    gr.Markdown("Fine-tuned Qwen2.5-Coder 7B on Spider dataset | 42% execution accuracy")

    with gr.Row():
        with gr.Column():
            schema = gr.Textbox(
                label="Database Schema (CREATE TABLE statements)",
                value=example_schema,
                lines=10,
            )
            question = gr.Textbox(
                label="Question",
                placeholder="How many employees have salary > 50000?",
                lines=2,
            )
            btn = gr.Button("\N{ROCKET} Generate SQL", variant="primary")

        with gr.Column():
            output = gr.Code(
                label="Generated SQL",
                language="sql",
            )
            gr.Markdown(
                "\n"
                "### \N{BAR CHART} Model Stats\n"
                "- **Base model**: Qwen2.5-Coder-7B\n"
                "- **Training data**: Spider dataset (7.9k samples)\n"
                "- **Simple queries**: 64.2% accuracy\n"
                "- **Complex queries**: 17.0% accuracy\n"
                "- **Overall**: 42% execution accuracy\n"
            )

    # The input order must match generate(question, schema).
    btn.click(fn=generate, inputs=[question, schema], outputs=output)

    gr.Examples(
        examples=[
            ["How many employees are there?", example_schema],
            ["Find all employees with salary greater than 50000", example_schema],
            ["What is the average salary by department?", example_schema],
        ],
        inputs=[question, schema],
    )

demo.launch()