"""Gradio demo that checks text for grammatical acceptability and fixes it.

Input text is split into small sentence batches. Each batch is scored by a
text-classification pipeline (the "checker"); batches that the checker does not
confidently accept are rewritten by a text2text-generation pipeline (the
"corrector"). The processed batches are joined back into one string.
"""

import gc
import logging
import os
import re
from typing import List

# NOTE(review): `spaces` is deliberately kept ahead of `torch`, as in the
# original import order -- ZeroGPU is understood to require being imported
# before CUDA is initialised; confirm before reordering.
import spaces
import torch
from cleantext import clean
import gradio as gr
from tqdm.auto import tqdm
from transformers import pipeline
from transformers import AutoModelForSequenceClassification, AutoTokenizer

logging.basicConfig(level=logging.INFO)
logging.info("torch version:\t%s", torch.__version__)

# Model names
checker_model_name = "textattack/roberta-base-CoLA"
corrector_model_name = "pszemraj/flan-t5-large-grammar-synthesis"

# Checker label for which a batch may be kept unchanged (given enough confidence).
ACCEPTABLE_LABEL = "LABEL_1"
# Minimum checker score on ACCEPTABLE_LABEL for a batch to skip correction.
MIN_ACCEPTABLE_SCORE = 0.9
# Inputs are truncated to this many characters before cleaning; keep in sync
# with the number quoted in the on-page instructions below.
MAX_INPUT_CHARS = 4000
# Number of sentences joined together for one checker/corrector call.
SENTENCES_PER_BATCH = 2

# A boundary is a run of spaces that follows "<non-uppercase char><any char><. or ?>"
# and precedes an uppercase letter. The leading non-uppercase requirement keeps
# short capitalised tokens such as "Dr." or "U.S." from ending a sentence.
_SENTENCE_BOUNDARY = re.compile(r"(?<=[^A-Z].[.?]) +(?=[A-Z])")

checker = pipeline(
    "text-classification",
    checker_model_name,
    device_map="cuda",
)
corrector = pipeline(
    "text2text-generation",
    corrector_model_name,
    device_map="cuda",
)


def split_text(text: str, batch_size: int = SENTENCES_PER_BATCH) -> List[List[str]]:
    """Split *text* into sentences and group them into consecutive batches.

    Args:
        text: The text to split.
        batch_size: Maximum number of sentences per batch. The final batch may
            hold fewer sentences when the count is not a multiple of it.

    Returns:
        A list of batches, each a list of sentence strings in original order.
        Empty or whitespace-only input yields an empty list.

    Raises:
        ValueError: If *batch_size* is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # Drop blank fragments so empty input never reaches the models.
    sentences = [s for s in _SENTENCE_BOUNDARY.split(text) if s.strip()]

    # Chunk by position rather than comparing each sentence with the last one
    # by value: a repeated sentence must not close a batch early.
    return [
        sentences[start : start + batch_size]
        for start in range(0, len(sentences), batch_size)
    ]


@spaces.GPU(duration=60)
def correct_text(text: str, separator: str = " ") -> str:
    """Return *text* with every questionable sentence batch rewritten.

    Each batch from ``split_text`` is joined with single spaces and scored by
    ``checker``. The batch is passed through ``corrector`` unless the checker's
    top result is ``ACCEPTABLE_LABEL`` with a score of at least
    ``MIN_ACCEPTABLE_SCORE``; otherwise it is kept verbatim.

    Args:
        text: The text to process.
        separator: String placed between the processed batches in the result.

    Returns:
        The processed batches joined by *separator* (``""`` for empty input).
    """
    sentence_batches = split_text(text)
    corrected_text = []

    for batch in tqdm(
        sentence_batches, total=len(sentence_batches), desc="correcting text.."
    ):
        raw_text = " ".join(batch)

        # Only the first classification result is consulted.
        verdict = checker(raw_text)[0]
        needs_correction = (
            verdict["label"] != ACCEPTABLE_LABEL
            or verdict["score"] < MIN_ACCEPTABLE_SCORE
        )

        if needs_correction:
            corrected_text.append(corrector(raw_text)[0]["generated_text"])
        else:
            corrected_text.append(raw_text)

    return separator.join(corrected_text)


def update(text: str) -> str:
    """Gradio callback: truncate, clean and correct the submitted text.

    Args:
        text: Raw contents of the input textbox.

    Returns:
        The corrected text, or ``""`` when nothing was submitted.
    """
    if not text:
        # Nothing to do; avoid spinning up a GPU call for empty input.
        return ""

    # Truncate first so cleaning never has to process oversized input.
    text = clean(text[:MAX_INPUT_CHARS], lower=False)
    return correct_text(text)


# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Robust Grammar Correction with FLAN-T5")
    gr.Markdown(
        "**Instructions:** Enter the text you want to correct in the textbox below (_text will be truncated to 4000 characters_). Click 'Process' to run."
    )
    gr.Markdown(
        "Models:\n"
        "- `textattack/roberta-base-CoLA` for grammar quality detection\n"
        "- `pszemraj/flan-t5-large-grammar-synthesis` for grammar correction\n"
    )
    with gr.Row():
        inp = gr.Textbox(
            label="input",
            placeholder="Enter text to check & correct",
            # The misspellings below are intentional: they are the demo input.
            value="I wen to the store yesturday to bye some food. I needd milk, bread, and a few otter things. The store was really crowed and I had a hard time finding everyting I needed. I finaly made it to the check out line and payed for my stuff.",
        )
        out = gr.Textbox(label="output", interactive=False)
    btn = gr.Button("Process")
    btn.click(fn=update, inputs=inp, outputs=out)
    gr.Markdown("---")
    gr.Markdown(
        "- See the [model card](https://huggingface.co/pszemraj/flan-t5-large-grammar-synthesis) for more info"
    )

# Launch the demo
demo.launch(debug=True)