| | import gc |
| | import logging |
| | import os |
| | import re |
| | import torch |
| | from cleantext import clean |
| | import gradio as gr |
| | from tqdm.auto import tqdm |
| | from transformers import pipeline |
| |
|
# Emit INFO-level records and note the torch build at start-up.
logging.basicConfig(level=logging.INFO)
logging.info(f"torch version:\t{torch.__version__}")

# Hub identifiers of the two models: one scores a passage, one rewrites it.
checker_model_name = "textattack/roberta-base-CoLA"
corrector_model_name = "pszemraj/flan-t5-large-grammar-synthesis"

# Pipeline device index: 0 when CUDA is available, otherwise -1.
if torch.cuda.is_available():
    device = 0
else:
    device = -1
logging.info(f"Using device: {'cuda' if device == 0 else 'cpu'}")

# Both pipelines are built once at import time and shared by every request.
checker = pipeline("text-classification", model=checker_model_name, device=device)
corrector = pipeline(
    "text2text-generation", model=corrector_model_name, device=device
)
| |
|
| | |
def split_text(text: str, batch_size: int = 2) -> list:
    """Split text into sentences and group them into consecutive batches.

    A sentence boundary is a ``.`` or ``?`` followed by one or more spaces
    and an uppercase letter. The character two positions before the
    punctuation must not be an uppercase letter, so short capitalised
    abbreviations such as ``Dr.`` do not end a sentence.

    Args:
        text: The text to split.
        batch_size: Maximum number of sentences per batch. The default of 2
            matches the grouping this function has always produced.

    Returns:
        A list of batches in document order, each a list of sentence
        strings. Only the final batch may hold fewer than ``batch_size``
        sentences. An empty list is returned when ``text`` contains no
        non-whitespace characters.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    # Nothing to correct: avoid producing a batch that holds a blank string.
    if not text.strip():
        return []

    sentences = re.split(r"(?<=[^A-Z].[.?]) +(?=[A-Z])", text)
    # Chunk by position. Comparing each sentence with the last one by value
    # would cut a batch short whenever an earlier sentence repeats the text
    # of the final sentence.
    return [
        sentences[start:start + batch_size]
        for start in range(0, len(sentences), batch_size)
    ]
| |
|
def correct_text(text: str, separator: str = " ") -> str:
    """Correct the grammar of ``text`` one sentence batch at a time.

    The text is divided with ``split_text``; the sentences of each batch are
    joined with a single space and scored by the ``checker`` pipeline. A
    batch is rewritten by the ``corrector`` pipeline unless the checker's
    first result is ``LABEL_1`` with a score of at least 0.9, in which case
    the batch is kept verbatim.

    Args:
        text: The text to correct.
        separator: String placed between the processed batches.

    Returns:
        The processed batches joined with ``separator``.
    """
    pieces = []
    for group in tqdm(split_text(text), desc="correcting text.."):
        passage = " ".join(group)
        verdict = checker(passage)[0]
        # NOTE(review): LABEL_1 is presumably the checker's "acceptable"
        # class - confirm against the model card.
        if verdict["label"] != "LABEL_1" or verdict["score"] < 0.9:
            passage = corrector(passage)[0]["generated_text"]
        pieces.append(passage)
    return separator.join(pieces)
| |
|
def update(text: str):
    """Button callback: clip, clean and grammar-correct the submitted text.

    Only the first 4000 characters of ``text`` are processed. The clipped
    text is passed through ``clean`` with lowercasing disabled and the result
    is handed to ``correct_text``.

    Args:
        text: Raw contents of the input textbox.

    Returns:
        Whatever ``correct_text`` returns for the cleaned text.
    """
    clipped = text[:4000]
    normalized = clean(clipped, lower=False)
    return correct_text(normalized)
| |
|
| | |
# Build the single-page UI: a title, input and output boxes side by side,
# and a button that runs `update` on the input.
# NOTE(review): the nesting below was reconstructed from de-indented source;
# confirm that the button belongs inside the Blocks but outside the Row.
with gr.Blocks() as demo:
    gr.Markdown("# <center>Robust Grammar Correction</center>")
    with gr.Row():
        inp = gr.Textbox(
            label="Input",
            placeholder="Enter text here...",
        )
        out = gr.Textbox(
            label="Output",
            interactive=False,
        )
    btn = gr.Button("Process")
    btn.click(fn=update, inputs=inp, outputs=out)

# Start the web server as soon as the module is executed.
demo.launch()