File size: 3,492 Bytes
57a7aa0
5a43bd7
 
 
151824c
e741efb
 
 
5a43bd7
151824c
8505d54
 
 
57a7aa0
8505d54
5a43bd7
 
84b8f8b
e741efb
84b8f8b
 
 
5a43bd7
e741efb
 
 
 
 
84b8f8b
e741efb
 
 
 
 
8505d54
 
 
 
 
e741efb
8505d54
 
 
e741efb
8505d54
 
 
 
 
 
 
 
84b8f8b
e741efb
 
 
8505d54
 
 
 
 
 
e741efb
8505d54
 
 
 
 
e741efb
8505d54
 
e741efb
8505d54
 
 
 
 
 
 
 
e741efb
 
8505d54
84b8f8b
9de8217
e741efb
151824c
e741efb
8505d54
84b8f8b
e741efb
8505d54
5b88edd
84b8f8b
 
 
 
 
5b88edd
 
84b8f8b
 
8505d54
84b8f8b
 
e741efb
84b8f8b
 
5b88edd
 
8505d54
5b88edd
84b8f8b
e741efb
84b8f8b
e741efb
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gc
import logging
import os
import re


import spaces

import torch
from cleantext import clean
import gradio as gr
from tqdm.auto import tqdm
from transformers import pipeline
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Configure root logging at INFO level and record the installed torch version
# once at startup.
logging.basicConfig(level=logging.INFO)
logging.info(f"torch version:\t{torch.__version__}")

# Model names: identifiers handed to transformers.pipeline() below.
checker_model_name = "textattack/roberta-base-CoLA"
corrector_model_name = "pszemraj/flan-t5-large-grammar-synthesis"


# Text-classification pipeline that correct_text() consults to decide whether
# a passage needs rewriting. Built once at import time; device_map="cuda"
# requests placement on the CUDA device.
checker = pipeline(
    "text-classification",
    checker_model_name,
    device_map="cuda",
)

# Text2text-generation pipeline that correct_text() uses to produce the
# rewritten passage. Also built once at import time with device_map="cuda".
corrector = pipeline(
    "text2text-generation",
    corrector_model_name,
    device_map="cuda",
)

# A boundary is a run of spaces that follows ". " / "? "-style punctuation and
# precedes an uppercase letter. The look-behind also requires that the
# character two positions before the punctuation is not uppercase, which
# avoids splitting after short capitalised tokens such as initials.
_SENTENCE_BOUNDARY = re.compile(r"(?<=[^A-Z].[.?]) +(?=[A-Z])")


def split_text(text: str, batch_size: int = 2) -> list:
    """Split *text* into sentences and group them into consecutive batches.

    Sentences are detected with a regex heuristic (see ``_SENTENCE_BOUNDARY``)
    and then chunked in order, ``batch_size`` sentences per batch; the final
    batch holds whatever remains and may be shorter.

    Args:
        text: The raw text to split.
        batch_size: Maximum number of sentences per batch. Defaults to 2.

    Returns:
        A list of batches, each batch being a list of sentence strings.
        Empty or whitespace-only input yields an empty list.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # Nothing to process: avoid emitting a batch that contains only "".
    if not text or not text.strip():
        return []

    sentences = _SENTENCE_BOUNDARY.split(text)

    # Chunk by position rather than comparing each sentence with the last one
    # by value, so repeated sentences cannot trigger a premature flush.
    return [
        sentences[start:start + batch_size]
        for start in range(0, len(sentences), batch_size)
    ]


@spaces.GPU(duration=60)
def correct_text(text: str, separator: str = " ") -> str:
    """Run the check-then-correct pass over *text* and return the result.

    The text is split into sentence batches with ``split_text``. Each batch is
    joined with single spaces and scored by ``checker``; a batch is passed to
    ``corrector`` unless the checker's top result is ``LABEL_1`` with a score
    of at least 0.9, in which case it is kept verbatim.

    Args:
        text: The text to process.
        separator: String placed between the processed batches in the output.

    Returns:
        The processed batches joined with ``separator``.
    """
    chunks = split_text(text)
    pieces = []

    for sentences in tqdm(chunks, total=len(chunks), desc="correcting text.."):
        passage = " ".join(sentences)

        # Only the top classification result is consulted.
        # NOTE(review): LABEL_1 is presumably the "acceptable" class of the
        # checker model - confirm against its model card.
        verdict = checker(passage)[0]
        needs_fix = verdict["label"] != "LABEL_1" or verdict["score"] < 0.9

        if not needs_fix:
            pieces.append(passage)
            continue

        rewritten = corrector(passage)
        pieces.append(rewritten[0]["generated_text"])

    return separator.join(pieces)


def update(text: str):
    """Handle a button click: truncate, clean and correct the submitted text.

    Only the first 4000 characters of *text* are kept. They are passed through
    ``clean`` with ``lower=False`` and the cleaned text is then handed to
    ``correct_text``, whose result is returned.
    """
    truncated = text[:4000]
    cleaned = clean(truncated, lower=False)
    return correct_text(cleaned)


# Create the Gradio interface. Components are declared inside the Blocks
# context in the order written here: title, instructions, model list, an
# input/output row, the trigger button and a footer link.
with gr.Blocks() as demo:
    gr.Markdown("# <center>Robust Grammar Correction with FLAN-T5</center>")
    # The 4000-character limit mentioned here is the truncation applied in update().
    gr.Markdown(
        "**Instructions:** Enter the text you want to correct in the textbox below (_text will be truncated to 4000 characters_). Click 'Process' to run."
    )
    # Lists the same two model names that are assigned to checker_model_name
    # and corrector_model_name.
    gr.Markdown(
        """Models:
    - `textattack/roberta-base-CoLA` for grammar quality detection
    - `pszemraj/flan-t5-large-grammar-synthesis` for grammar correction
    """
    )
    # Input and output textboxes are declared together in one Row.
    with gr.Row():
        # Pre-filled with a deliberately misspelled sample so the demo can be
        # run without typing anything.
        inp = gr.Textbox(
            label="input",
            placeholder="Enter text to check & correct",
            value="I wen to the store yesturday to bye some food. I needd milk, bread, and a few otter things. The store was really crowed and I had a hard time finding everyting I needed. I finaly made it to the check out line and payed for my stuff.",
        )
        # Read-only: only populated by the click handler below.
        out = gr.Textbox(label="output", interactive=False)
    btn = gr.Button("Process")
    # Clicking the button sends the input text through update() and writes the
    # returned string to the output box.
    btn.click(fn=update, inputs=inp, outputs=out)
    gr.Markdown("---")
    gr.Markdown(
        "- See the [model card](https://huggingface.co/pszemraj/flan-t5-large-grammar-synthesis) for more info"
    )

# Launch the demo. This runs at import time (there is no __main__ guard).
demo.launch(debug=True)