# app.py
import re

import gradio as gr
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

MODEL_NAME = "michaperki/SQLToText"  # Replace with your actual model repository

# Load the tokenizer and model from the Hugging Face Hub
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

# Run on GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
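
# Note: the slow T5Tokenizer above requires the sentencepiece package to be installed;
# AutoTokenizer could be used instead if the repository ships fast-tokenizer files.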

def postprocess_output(text):
    """
    Postprocess the generated text: expand SQL aggregate functions such as AVG() and
    COUNT(*), spell out comparison operators (e.g., 'price > 100'), and fix terminal
    punctuation and capitalization.
    """
    # Expand SQL aggregate functions (case-insensitive)
    text = re.sub(r'\bAVG\((.*?)\)', r'average \1', text, flags=re.IGNORECASE)
    text = re.sub(r'\bCOUNT\(\*\)', 'number of records', text, flags=re.IGNORECASE)
    text = re.sub(r'\bSUM\((.*?)\)', r'sum of \1', text, flags=re.IGNORECASE)
    text = re.sub(r'\bMAX\((.*?)\)', r'maximum \1', text, flags=re.IGNORECASE)
    text = re.sub(r'\bMIN\((.*?)\)', r'minimum \1', text, flags=re.IGNORECASE)

    # Spell out comparison operators (handle >= and <= before > and <)
    text = re.sub(r'(\b\w+\b) >= (\d+)', r'\1 greater than or equal to \2', text)
    text = re.sub(r'(\b\w+\b) <= (\d+)', r'\1 less than or equal to \2', text)
    text = re.sub(r'(\b\w+\b) > (\d+)', r'\1 greater than \2', text)
    text = re.sub(r'(\b\w+\b) < (\d+)', r'\1 less than \2', text)
    text = re.sub(r'(\b\w+\b) = (\d+)', r'\1 equal to \2', text)

    # Ensure terminal punctuation: questions get '?', anything else gets '.'
    # (simple heuristic based on a leading question word)
    if text and text[-1] not in {'.', '?'}:
        question_words = ('what', 'which', 'who', 'how', 'when', 'where', 'why')
        if text.lower().startswith(question_words):
            text += '?'
        else:
            text += '.'

    # Capitalize only the first letter; str.capitalize() would lowercase the rest
    if text:
        text = text[0].upper() + text[1:]
    return text
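
# Illustrative example (not executed): given model output such as
# "what is the AVG(salary) for employees where salary > 50000",
# postprocess_output should return roughly
# "What is the average salary for employees where salary greater than 50000?"
# (aggregate expanded, operator spelled out, '?' and capitalization applied).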

def translate_sql_to_text(sql_query):
    """
    Translate a SQL query into a natural language question.
    """
    try:
        if not sql_query.strip():
            return "Please enter a valid SQL query."

        # Prepare the input text for the model
        input_text = f"Convert the following SQL query to a natural language question: {sql_query}"
        input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

        # Generate the translated question with beam search
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=input_ids,
                max_length=128,
                num_beams=5,
                early_stopping=True,
            )

        # Decode the generated tokens and apply postprocessing
        translated_question = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        return postprocess_output(translated_question)
    except Exception as e:
        return f"Error during translation: {e}"

# Define the Gradio interface
iface = gr.Interface(
    fn=translate_sql_to_text,
    inputs=gr.Textbox(
        lines=5,
        placeholder="Enter your SQL query here...",
        label="SQL Query",
    ),
    outputs=gr.Textbox(label="Translated Question"),
    title="SQL to Text Translator",
    description="Enter an SQL query, and the model will translate it into a natural language question.",
    examples=[
        ["SELECT name FROM employees WHERE department = 'HR';"],
        ["SELECT COUNT(*) FROM orders WHERE status = 'shipped';"],
        ["SELECT AVG(salary) FROM employees WHERE department = 'Engineering';"],
        ["SELECT product_name, price FROM products WHERE price > 100 ORDER BY price DESC LIMIT 5;"],
["SELECT COUNT(*) FROM orders WHERE status = 'shipped';"]
    ],
    theme="default",
)
if __name__ == "__main__":
iface.launch()
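
# Note: launch() starts a local server by default; Gradio also supports
# iface.launch(share=True) for a temporary public link when running locally.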