pszemraj commited on
Commit
8505d54
1 Parent(s): e329fbe

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import gradio as gr
3
+ from tqdm.auto import tqdm
4
+ from transformers import pipeline
5
+
6
+ # pipelines
7
+ checker = pipeline("text-classification", "textattack/roberta-base-CoLA")
8
+ corrector = pipeline("text2text-generation", "pszemraj/flan-t5-large-grammar-synthesis")
9
+
10
+ def split_text(text: str) -> list:
11
+ # Split the text into sentences using regex
12
+ sentences = re.split(r"(?<=[^A-Z].[.?]) +(?=[A-Z])", text)
13
+
14
+ # Initialize a list to store the sentence batches
15
+ sentence_batches = []
16
+
17
+ # Initialize a temporary list to store the current batch of sentences
18
+ temp_batch = []
19
+
20
+ # Iterate through the sentences
21
+ for sentence in sentences:
22
+ # Add the sentence to the temporary batch
23
+ temp_batch.append(sentence)
24
+
25
+ # If the length of the temporary batch is between 2 and 3 sentences, or if it is the last batch, add it to the list of sentence batches
26
+ if len(temp_batch) >= 2 and len(temp_batch) <= 3 or sentence == sentences[-1]:
27
+ sentence_batches.append(temp_batch)
28
+ temp_batch = []
29
+
30
+ return sentence_batches
31
+
32
+ def correct_text(text: str, checker, corrector, separator: str = " ") -> str:
33
+ # Split the text into sentence batches
34
+ sentence_batches = split_text(text)
35
+
36
+ # Initialize a list to store the corrected text
37
+ corrected_text = []
38
+
39
+ # Iterate through the sentence batches
40
+ for batch in tqdm(
41
+ sentence_batches, total=len(sentence_batches), desc="correcting text.."
42
+ ):
43
+ # Join the sentences in the batch into a single string
44
+ raw_text = " ".join(batch)
45
+
46
+ # Check the grammar quality of the text using the text-classification pipeline
47
+ results = checker(raw_text)
48
+
49
+ # Only correct the text if the results of the text-classification are not LABEL_1 or are LABEL_1 with a score below 0.9
50
+ if results[0]["label"] != "LABEL_1" or (
51
+ results[0]["label"] == "LABEL_1" and results[0]["score"] < 0.9
52
+ ):
53
+ # Correct the text using the text-generation pipeline
54
+ corrected_batch = corrector(raw_text)
55
+ corrected_text.append(corrected_batch[0]["generated_text"])
56
+ else:
57
+ corrected_text.append(raw_text)
58
+
59
+ # Join the corrected text into a single string
60
+ corrected_text = separator.join(corrected_text)
61
+
62
+ return corrected_text
63
+
64
+ def update(text: str, checker, corrector):
65
+ text = text[:4000]
66
+ return correct_text(text, checker, corrector)
67
+
68
+ with gr.Blocks() as demo:
69
+ gr.Markdown("<center># Robust Grammar Correction with FLAN-T5</center>")
70
+ gr.Markdown("Enter the text you want to correct in the textbox below. The text will be truncated to 4000 characters. Then click Run to see the corrected text.")
71
+ with gr.Row():
72
+ inp = gr.Textbox(placeholder="I wen to the store yesturday to bye some food. I needd milk, bread, and a few otter things. The store was really crowed and I had a hard time finding everyting I needed. I finaly made it to the check out line and payed for my stuff.")
73
+ out = gr.Textbox()
74
+ btn = gr.Button("Process Text")
75
+ btn.click(fn=update, inputs=inp, outputs=out)
76
+
77
+ demo.launch()