|
import torch |
|
from transformers import T5Tokenizer, T5ForConditionalGeneration |
|
import gradio as gr |
|
|
|
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The tokenizer is taken from the base 't5-small' checkpoint rather than the
# fine-tuned repo; presumably the two share a vocabulary — confirm if the
# checkpoint changes.
tokenizer = T5Tokenizer.from_pretrained('t5-small')

# Load the fine-tuned text-to-SQL weights, move them to the chosen device and
# switch to evaluation mode (Module.to / Module.eval both return the module).
model = (
    T5ForConditionalGeneration.from_pretrained('cssupport/t5-small-awesome-text-to-sql')
    .to(device)
    .eval()
)
|
|
|
def generate_sql(input_prompt, *, max_length=512):
    """Generate a SQL string from a text prompt using the loaded T5 model.

    Args:
        input_prompt: Prompt text fed to the tokenizer. In this script it is
            built as "tables:\\n<table definitions>\\nquery for:<description>".
        max_length: Maximum length of the generated sequence, forwarded to
            ``model.generate``. Defaults to 512, the value previously
            hard-coded here, so existing callers are unaffected.

    Returns:
        The first generated sequence decoded to text, with special tokens
        removed.
    """
    # padding is a no-op for a single prompt; truncation caps over-long
    # prompts at the tokenizer's maximum input length instead of erroring.
    inputs = tokenizer(
        input_prompt, padding=True, truncation=True, return_tensors="pt"
    ).to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=max_length)

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
def gradio_interface(tables, query):
    """Gradio callback: assemble the model prompt and return generated SQL.

    Args:
        tables: Contents of the "Context Tables" textbox (table definitions).
        query: Contents of the "Query Description" textbox.

    Returns:
        Whatever ``generate_sql`` produces for the assembled prompt.
    """
    prompt_template = "tables:\n{}\nquery for:{}"
    prompt = prompt_template.format(tables, query)
    return generate_sql(prompt)
|
|
|
# Web UI: the two input textboxes are passed positionally to
# gradio_interface(tables, query); its return value fills the output box.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(lines=5, label="Context Tables", placeholder="Enter table definitions here..."),
        # The model expects a natural-language description here, not SQL.
        gr.Textbox(lines=2, label="Query Description", placeholder="Describe the query in plain English..."),
    ],
    outputs=gr.Textbox(label="Generated SQL Query", placeholder=""),
    title="Text to SQL Generator",
    examples=[
        [
            "CREATE TABLE student_course_attendance (student_id VARCHAR); CREATE TABLE students (student_id VARCHAR);",
            "List the id of students who never attends courses?",
        ],
    ],
)

# Only start the server when executed as a script, so importing this module
# does not launch the web app as a side effect.
if __name__ == "__main__":
    iface.launch()
|
|