|
|
|
|
|
|
|
import logging
import re

import gradio as gr
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
|
|
|
MODEL_NAME = "michaperki/SQLToText"

# Pick the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the tokenizer and model once, at import time, as module-level objects.
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME).to(device)
model.eval()  # inference mode for layers such as dropout
|
|
|
# SQL aggregate calls -> plain-English phrases.
_AGGREGATE_RULES = (
    (re.compile(r'\bAVG\((.*?)\)', re.IGNORECASE), r'average \1'),
    (re.compile(r'\bCOUNT\(\s*\*\s*\)', re.IGNORECASE), 'number of records'),
    (re.compile(r'\bSUM\((.*?)\)', re.IGNORECASE), r'sum of \1'),
    (re.compile(r'\bMAX\((.*?)\)', re.IGNORECASE), r'maximum \1'),
    (re.compile(r'\bMIN\((.*?)\)', re.IGNORECASE), r'minimum \1'),
)

# "<word> <op> <number>" -> "<word> <phrase> <number>".
# Two-character operators come first so ">=" / "<=" are rewritten as a whole
# rather than being matched piecewise by the one-character rules.
_COMPARISON_RULES = tuple(
    (
        re.compile(r'(\b\w+\b)\s*' + re.escape(op) + r'\s*(-?\d+(?:\.\d+)?)'),
        r'\1 ' + phrase + r' \2',
    )
    for op, phrase in (
        ('>=', 'greater than or equal to'),
        ('<=', 'less than or equal to'),
        ('>', 'greater than'),
        ('<', 'less than'),
        ('=', 'equal to'),
    )
)

_TERMINAL_PUNCTUATION = frozenset('.?!')


def postprocess_output(text):
    """
    Clean up generated text for display.

    Rewrites SQL aggregate calls (AVG, COUNT(*), SUM, MAX, MIN) and numeric
    comparisons such as 'price > 100' into plain English, ensures the text
    ends with punctuation, and upper-cases its first character.

    Args:
        text: Decoded model output.

    Returns:
        The cleaned text; an empty string if ``text`` is empty or blank.
    """
    for pattern, replacement in _AGGREGATE_RULES:
        text = pattern.sub(replacement, text)
    for pattern, replacement in _COMPARISON_RULES:
        text = pattern.sub(replacement, text)

    # Strip first so trailing whitespace cannot hide existing punctuation.
    text = text.strip()
    if not text:
        return text

    if text[-1] not in _TERMINAL_PUNCTUATION:
        # A '?' anywhere in the text marks it as a question even if the model
        # omitted the final mark; otherwise close it as a statement.
        text += '?' if '?' in text else '.'

    # Upper-case only the first character: str.capitalize() would lower-case
    # the rest and mangle identifiers/acronyms such as "HR" or "Engineering".
    return text[0].upper() + text[1:]
|
|
|
def translate_sql_to_text(sql_query):
    """
    Translate a SQL query into a natural-language question.

    Used as the Gradio callback. It does not raise; empty input and any
    failure during inference are reported as user-facing strings instead.

    Args:
        sql_query: SQL text from the input textbox. ``None`` or a blank
            string is treated as empty input.

    Returns:
        The post-processed question; "Please enter a valid SQL query." for
        empty input; or "Error during translation: <details>" on failure.
    """
    try:
        # Treat None like blank text instead of failing on .strip().
        query = (sql_query or "").strip()
        if not query:
            return "Please enter a valid SQL query."

        # NOTE(review): this prompt prefix presumably matches the one used
        # when the model was fine-tuned — confirm before changing it.
        input_text = f"Convert the following SQL query to a natural language question: {query}"
        input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

        # Inference only: disable autograd tracking.
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=input_ids,
                max_length=128,
                num_beams=5,
                early_stopping=True,
            )

        translated_question = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        return postprocess_output(translated_question)
    except Exception as e:
        # Top-level UI boundary: keep the full traceback in the server log
        # and show only a short message to the user.
        logging.getLogger(__name__).exception("SQL-to-text translation failed")
        return f"Error during translation: {e}"
|
|
|
|
|
# Gradio UI: a multi-line SQL textbox in, the translated question out.
iface = gr.Interface(
    fn=translate_sql_to_text,
    inputs=gr.Textbox(
        lines=5,
        placeholder="Enter your SQL query here...",
        label="SQL Query",
    ),
    outputs=gr.Textbox(label="Translated Question"),
    title="SQL to Text Translator",
    description="Enter an SQL query, and the model will translate it into a natural language question.",
    # One inner list per example, holding a value for the single input.
    # (A duplicated COUNT(*) example was removed.)
    examples=[
        ["SELECT name FROM employees WHERE department = 'HR';"],
        ["SELECT COUNT(*) FROM orders WHERE status = 'shipped';"],
        ["SELECT AVG(salary) FROM employees WHERE department = 'Engineering';"],
        ["SELECT product_name, price FROM products WHERE price > 100 ORDER BY price DESC LIMIT 5;"],
    ],
    theme="default",
)

if __name__ == "__main__":
    iface.launch()
|
|