import gc
import logging
import os
import re
import spaces
import torch
from cleantext import clean
import gradio as gr
from tqdm.auto import tqdm
from transformers import pipeline
from transformers import AutoModelForSequenceClassification, AutoTokenizer
logging.basicConfig(level=logging.INFO)
logging.info(f"torch version:\t{torch.__version__}")
# Model names
checker_model_name = "textattack/roberta-base-CoLA"
corrector_model_name = "pszemraj/flan-t5-large-grammar-synthesis"
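# The checker is a binary acceptability (CoLA) classifier used to screen each
# batch of sentences; the corrector is a seq2seq model that rewrites batches
# flagged as ungrammatical.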
checker = pipeline(
    "text-classification",
    checker_model_name,
    device_map="cuda",
)

corrector = pipeline(
    "text2text-generation",
    corrector_model_name,
    device_map="cuda",
)
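# Note (assumption): this Space runs on ZeroGPU ("Running on Zero"), where the
# `spaces` package manages CUDA allocation, so requesting device_map="cuda" at
# import time is expected to be deferred until a @spaces.GPU-decorated
# function actually runs rather than requiring a persistent GPU.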

def split_text(text: str) -> list:
    # Split the text into sentences using regex
    sentences = re.split(r"(?<=[^A-Z].[.?]) +(?=[A-Z])", text)

    # Initialize lists for batching
    sentence_batches = []
    temp_batch = []

    # Group sentences into small batches; a batch is flushed once it holds at
    # least two sentences, or at the final sentence
    for sentence in sentences:
        temp_batch.append(sentence)
        if (2 <= len(temp_batch) <= 3) or sentence == sentences[-1]:
            sentence_batches.append(temp_batch)
            temp_batch = []

    return sentence_batches
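
# Illustrative example (not from the original source):
#   split_text("First. Second. Third.")
# finds the sentences ["First.", "Second.", "Third."] and batches them as
# [["First.", "Second."], ["Third."]], since a batch is flushed as soon as it
# holds two sentences.

# The @spaces.GPU decorator requests a ZeroGPU slot for each call to
# correct_text; `duration` declares the expected maximum runtime in seconds.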
@spaces.GPU(duration=60)
def correct_text(text: str, separator: str = " ") -> str:
    # Split the text into sentence batches
    sentence_batches = split_text(text)

    # Initialize a list to store the corrected text
    corrected_text = []

    # Process each batch
    for batch in tqdm(
        sentence_batches, total=len(sentence_batches), desc="correcting text.."
    ):
        raw_text = " ".join(batch)

        # Check grammar quality
        results = checker(raw_text)

        # Correct text if needed
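        # For this CoLA checkpoint, LABEL_1 is assumed to be the "acceptable"
        # class, so the batch is rewritten unless the checker is confident
        # (score >= 0.9) that it is already grammatical.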
        if results[0]["label"] != "LABEL_1" or (
            results[0]["label"] == "LABEL_1" and results[0]["score"] < 0.9
        ):
            corrected_batch = corrector(raw_text)
            corrected_text.append(corrected_batch[0]["generated_text"])
        else:
            corrected_text.append(raw_text)

    # Join the corrected text
    return separator.join(corrected_text)

def update(text: str):
    # Clean and truncate input text
    text = clean(text[:4000], lower=False)
    return correct_text(text)
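
# Example usage outside the UI (illustrative; runs the same path the
# "Process" button triggers):
#   print(update("I wen to the store yesturday to bye some food."))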
# Create the Gradio interface
with gr.Blocks() as demo:
gr.Markdown("# <center>Robust Grammar Correction with FLAN-T5</center>")
gr.Markdown(
"**Instructions:** Enter the text you want to correct in the textbox below (_text will be truncated to 4000 characters_). Click 'Process' to run."
)
gr.Markdown(
"""Models:
- `textattack/roberta-base-CoLA` for grammar quality detection
- `pszemraj/flan-t5-large-grammar-synthesis` for grammar correction
"""
)
with gr.Row():
inp = gr.Textbox(
label="input",
placeholder="Enter text to check & correct",
value="I wen to the store yesturday to bye some food. I needd milk, bread, and a few otter things. The store was really crowed and I had a hard time finding everyting I needed. I finaly made it to the check out line and payed for my stuff.",
)
out = gr.Textbox(label="output", interactive=False)
btn = gr.Button("Process")
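    # Clicking the button sends the textbox contents through update(), which in
    # turn calls the GPU-decorated correct_text().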
    btn.click(fn=update, inputs=inp, outputs=out)

    gr.Markdown("---")
    gr.Markdown(
        "- See the [model card](https://huggingface.co/pszemraj/flan-t5-large-grammar-synthesis) for more info"
    )
# Launch the demo
demo.launch(debug=True)