# app.py
"""Gradio app that translates SQL queries into natural-language questions
using a fine-tuned T5 model loaded from the Hugging Face Hub."""

import re

import gradio as gr
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

MODEL_NAME = "michaperki/SQLToText"  # Replace with your actual model repository

# Load the tokenizer and model from the Hugging Face Hub once at startup.
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

# (pattern, replacement) pairs applied in order by postprocess_output().
# Compiled once at import time instead of on every call.
_SQL_REWRITES = [
    # SQL aggregate functions (case-insensitive).
    (re.compile(r'\bAVG\((.*?)\)', re.IGNORECASE), r'average \1'),
    (re.compile(r'\bCOUNT\(\*\)', re.IGNORECASE), 'number of records'),
    (re.compile(r'\bSUM\((.*?)\)', re.IGNORECASE), r'sum of \1'),
    (re.compile(r'\bMAX\((.*?)\)', re.IGNORECASE), r'maximum \1'),
    (re.compile(r'\bMIN\((.*?)\)', re.IGNORECASE), r'minimum \1'),
    # SQL-like comparison conditions (e.g. 'price > 100'). The ' > ' and
    # ' < ' patterns require a trailing space, so they cannot accidentally
    # match the '>=' / '<=' forms handled below.
    (re.compile(r'(\b\w+\b) > (\d+)', re.IGNORECASE), r'\1 greater than \2'),
    (re.compile(r'(\b\w+\b) < (\d+)', re.IGNORECASE), r'\1 less than \2'),
    (re.compile(r'(\b\w+\b) >= (\d+)', re.IGNORECASE), r'\1 greater than or equal to \2'),
    (re.compile(r'(\b\w+\b) <= (\d+)', re.IGNORECASE), r'\1 less than or equal to \2'),
    (re.compile(r'(\b\w+\b) = (\d+)', re.IGNORECASE), r'\1 equal to \2'),
]


def postprocess_output(text):
    """Clean up the model's generated text for display.

    Rewrites SQL aggregate functions (AVG, COUNT(*), SUM, MAX, MIN) and
    SQL-like numeric conditions (e.g. 'price > 100') into plain English,
    ensures the sentence ends with punctuation, and capitalizes the first
    letter.

    Args:
        text: Raw decoded model output.

    Returns:
        The post-processed sentence.
    """
    for pattern, replacement in _SQL_REWRITES:
        text = pattern.sub(replacement, text)

    # Ensure terminal punctuation. If the text contains a question mark
    # anywhere, treat it as a question; otherwise end with a period.
    if text and text[-1] not in {'.', '?'}:
        text += '?' if '?' in text else '.'

    # Capitalize only the first character. str.capitalize() would also
    # lowercase the rest of the string, mangling tokens like 'HR'.
    if text:
        text = text[0].upper() + text[1:]

    return text


def translate_sql_to_text(sql_query):
    """Translate a SQL query into a natural-language question.

    Args:
        sql_query: The SQL statement entered by the user.

    Returns:
        The translated question, a prompt asking for valid input when the
        query is blank, or an error message if generation fails.
    """
    try:
        if not sql_query.strip():
            return "Please enter a valid SQL query."

        # Prepare the input text for the model.
        input_text = f"Convert the following SQL query to a natural language question: {sql_query}"
        input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

        # Generate the translated question (beam search, no gradients).
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=input_ids,
                max_length=128,
                num_beams=5,
                early_stopping=True
            )

        # Decode the generated tokens to a string.
        translated_question = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        # Apply postprocessing.
        translated_question = postprocess_output(translated_question)

        return translated_question
    except Exception as e:
        # Surface the failure to the UI rather than crashing the app.
        return f"Error during translation: {str(e)}"


# Define the Gradio interface.
iface = gr.Interface(
    fn=translate_sql_to_text,
    inputs=gr.Textbox(
        lines=5,
        placeholder="Enter your SQL query here...",
        label="SQL Query"
    ),
    outputs=gr.Textbox(label="Translated Question"),
    title="SQL to Text Translator",
    description="Enter an SQL query, and the model will translate it into a natural language question.",
    examples=[
        ["SELECT name FROM employees WHERE department = 'HR';"],
        ["SELECT COUNT(*) FROM orders WHERE status = 'shipped';"],
        ["SELECT AVG(salary) FROM employees WHERE department = 'Engineering';"],
        ["SELECT product_name, price FROM products WHERE price > 100 ORDER BY price DESC LIMIT 5;"]
    ],
    theme="default"
)

if __name__ == "__main__":
    iface.launch()