Spaces:
Sleeping
Sleeping
# app.py | |
import gradio as gr | |
from transformers import T5ForConditionalGeneration, T5Tokenizer | |
import torch | |
import re | |
# Hugging Face Hub repository id used for both the tokenizer and the model.
MODEL_NAME = "michaperki/SQLToText"  # Replace with your actual model repository

# Use the GPU when torch reports one is available; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the tokenizer and model from the Hub, place the model on the chosen
# device, and switch it to evaluation mode.
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME).to(device)
model.eval()
# Ordered (pattern, replacement) rewrites applied by postprocess_output.
# Every pattern is matched case-insensitively, in the order listed.
_SQL_REWRITES = (
    # SQL aggregate functions.
    (r'\bAVG\((.*?)\)', r'average \1'),
    (r'\bCOUNT\(\*\)', 'number of records'),
    (r'\bSUM\((.*?)\)', r'sum of \1'),
    (r'\bMAX\((.*?)\)', r'maximum \1'),
    (r'\bMIN\((.*?)\)', r'minimum \1'),
    # SQL-like numeric comparisons (e.g. 'price > 100').
    (r'(\b\w+\b) > (\d+)', r'\1 greater than \2'),
    (r'(\b\w+\b) < (\d+)', r'\1 less than \2'),
    (r'(\b\w+\b) >= (\d+)', r'\1 greater than or equal to \2'),
    (r'(\b\w+\b) <= (\d+)', r'\1 less than or equal to \2'),
    (r'(\b\w+\b) = (\d+)', r'\1 equal to \2'),
)


def postprocess_output(text):
    """
    Clean up generated text so it reads as a natural-language sentence.

    Leftover SQL aggregate calls (AVG, COUNT(*), SUM, MAX, MIN) and simple
    numeric comparisons such as 'price > 100' are rewritten into words.
    If the result does not already end with '.' or '?', a '?' is appended
    when the text contains a question mark anywhere, otherwise a '.'.
    Finally the first character is upper-cased.

    Args:
        text: The decoded model output.

    Returns:
        The rewritten text. An empty string is returned unchanged.
    """
    for pattern, replacement in _SQL_REWRITES:
        text = re.sub(pattern, replacement, text, flags=re.IGNORECASE)

    # Ensure the sentence ends with terminal punctuation.
    if text and text[-1] not in {'.', '?'}:
        text += '?' if '?' in text else '.'

    # Upper-case only the first character. str.capitalize() would also
    # lowercase the rest of the sentence, mangling acronyms, proper nouns
    # and identifiers (e.g. 'HR' -> 'hr').
    return text[:1].upper() + text[1:]
def translate_sql_to_text(sql_query):
    """
    Translates a given SQL query into a natural language question.

    Args:
        sql_query: The SQL text entered by the user.

    Returns:
        The post-processed question produced by the model, a prompt asking
        for a valid query when the input is missing or blank, or an
        "Error during translation: ..." message if generation fails.
    """
    # Validate before entering the try block so that missing input (None or
    # any non-string value) gets the friendly prompt instead of surfacing as
    # an AttributeError message from the generic handler below.
    if not isinstance(sql_query, str) or not sql_query.strip():
        return "Please enter a valid SQL query."

    try:
        # Prepare the input text for the model
        input_text = f"Convert the following SQL query to a natural language question: {sql_query}"
        input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

        # Generate the translated question; no gradients are needed here
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=input_ids,
                max_length=128,
                num_beams=5,
                early_stopping=True,
            )

        # Decode the generated tokens to a string, then apply postprocessing
        translated_question = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        return postprocess_output(translated_question)
    except Exception as e:
        # UI boundary: report the failure as text rather than raising.
        return f"Error during translation: {str(e)}"
# Define the Gradio interface: one multi-line textbox for the SQL query in,
# one textbox for the translated question out.
iface = gr.Interface(
    fn=translate_sql_to_text,
    inputs=gr.Textbox(
        lines=5,
        placeholder="Enter your SQL query here...",
        label="SQL Query"
    ),
    outputs=gr.Textbox(label="Translated Question"),
    title="SQL to Text Translator",
    description="Enter an SQL query, and the model will translate it into a natural language question.",
    # Each example is listed once; the COUNT(*) query previously appeared twice.
    examples=[
        ["SELECT name FROM employees WHERE department = 'HR';"],
        ["SELECT COUNT(*) FROM orders WHERE status = 'shipped';"],
        ["SELECT AVG(salary) FROM employees WHERE department = 'Engineering';"],
        ["SELECT product_name, price FROM products WHERE price > 100 ORDER BY price DESC LIMIT 5;"],
    ],
    theme="default"
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()