Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import re | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| MODEL = "jinesh90/qwen2.5-coder-sql-generator" | |
| print("Loading model...") | |
| tokenizer = AutoTokenizer.from_pretrained(MODEL) | |
| model = AutoModelForCausalLM.from_pretrained( | |
| MODEL, | |
| torch_dtype = torch.float16, | |
| device_map = "auto", | |
| low_cpu_mem_usage = True, | |
| ) | |
| model.eval() | |
| print("Ready!") | |
def clean_sql(text):
    """Strip model chatter from a raw generation, keeping only the SQL.

    Drops everything from the first non-ASCII character to the end of
    each line, then truncates at the first stop marker ("###",
    "assistant", or a blank line) and returns the trimmed result.
    """
    cleaned = re.sub(r'[^\x00-\x7F].*', '', text.strip()).strip()
    for marker in ("###", "assistant", "\n\n"):
        # partition() leaves the string untouched when the marker is absent.
        head, _, _ = cleaned.partition(marker)
        cleaned = head.strip()
    return cleaned
def build_prompt(question, schema):
    """Format the instruction prompt fed to the fine-tuned model.

    Mirrors the "### Schema / ### Question / ### SQL" layout used
    during fine-tuning on Spider.
    """
    template = (
        "You are a SQL expert. Generate the simplest and most direct SQL query.\n"
        "Use JOINs only when multiple tables are needed.\n"
        "### Schema:\n"
        "{schema}\n"
        "### Question:\n"
        "{question}\n"
        "### SQL:"
    )
    return template.format(schema=schema, question=question)
def generate(question, schema):
    """Generate a SQL query answering *question* against *schema*.

    Args:
        question: Natural-language question about the data.
        schema: CREATE TABLE statements describing the database.

    Returns:
        The cleaned generated SQL string, or a help message when either
        input is missing or whitespace-only.
    """
    # Treat whitespace-only input the same as missing input.
    if not question or not question.strip() or not schema or not schema.strip():
        return "Please provide both a question and schema!"

    messages = [{"role": "user", "content": build_prompt(question, schema)}]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=1024,
    ).to(model.device)

    # Stop on EOS or the Qwen chat end-of-turn marker. Drop ids the
    # tokenizer cannot resolve (convert_tokens_to_ids returns None for
    # unknown tokens, which would break generate()'s eos handling).
    stop_tokens = [
        tid
        for tid in (
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|im_end|>"),
        )
        if tid is not None
    ]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=False,  # greedy decoding for reproducible SQL
            # NOTE: temperature removed — with do_sample=False it is ignored,
            # and temperature=0 trips transformers' generation-config
            # validation warning (temperature must be strictly positive).
            repetition_penalty=1.3,
            eos_token_id=stop_tokens,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, skipping the echoed prompt.
    input_len = inputs["input_ids"].shape[1]
    raw = tokenizer.decode(outputs[0, input_len:], skip_special_tokens=True)
    return clean_sql(raw)
# Example schemas for demo
# Default schema pre-filled in the UI so the demo works without any setup.
example_schema = """CREATE TABLE employees (
id INTEGER,
name VARCHAR,
salary REAL,
department VARCHAR,
age INTEGER
);"""
| with gr.Blocks(title="SQL Query Generator") as demo: | |
| gr.Markdown("# ποΈ SQL Query Generator") | |
| gr.Markdown("Fine-tuned Qwen2.5-Coder 7B on Spider dataset | 42% execution accuracy") | |
| with gr.Row(): | |
| with gr.Column(): | |
| schema = gr.Textbox( | |
| label = "Database Schema (CREATE TABLE statements)", | |
| value = example_schema, | |
| lines = 10 | |
| ) | |
| question = gr.Textbox( | |
| label = "Question", | |
| placeholder = "How many employees have salary > 50000?", | |
| lines = 2 | |
| ) | |
| btn = gr.Button("π Generate SQL", variant="primary") | |
| with gr.Column(): | |
| output = gr.Code( | |
| label = "Generated SQL", | |
| language = "sql" | |
| ) | |
| gr.Markdown(""" | |
| ### π Model Stats | |
| - **Base model**: Qwen2.5-Coder-7B | |
| - **Training data**: Spider dataset (7.9k samples) | |
| - **Simple queries**: 64.2% accuracy | |
| - **Complex queries**: 17.0% accuracy | |
| - **Overall**: 42% execution accuracy | |
| """) | |
| btn.click(fn=generate, inputs=[question, schema], outputs=output) | |
| gr.Examples( | |
| examples=[ | |
| ["How many employees are there?", example_schema], | |
| ["Find all employees with salary greater than 50000", example_schema], | |
| ["What is the average salary by department?", example_schema], | |
| ], | |
| inputs=[question, schema] | |
| ) | |
| demo.launch() | |