File size: 15,919 Bytes
74b8229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d868fb
 
74b8229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b6c220
0d868fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74b8229
6b6c220
74b8229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
"""
grammar_improve.py - this .py script contains functions to improve the grammar of a user's input or the model's output.

"""

from datetime import datetime
import os
import pprint as pp
from neuspell import BertChecker, SclstmChecker
import neuspell
import math
from cleantext import clean
import time
import re
import sys
from symspellpy.symspellpy import SymSpell
import transformers
from transformers import pipeline
from utils import suppress_stdout


def detect_propers(text: str):
    """
    detect_propers - heuristically detect if a string contains proper nouns

    A word is treated as a likely proper noun when it starts with an uppercase
    letter immediately followed by a lowercase letter (e.g. "Alice", "Paris")
    AND it is not the first word of the text or of a sentence (i.e. the nearest
    preceding non-space character is not ".", "!" or "?"), because such words
    are capitalized regardless of whether they are names. Single capitals such
    as "I" and all-caps tokens such as "NASA" are not counted.

    Args:
        text (str): [string to be checked]

    Returns:
        [bool]: [True if the string appears to contain proper nouns]
    """
    word_pat = re.compile(r"(?:\w+['’])?\w+(?:-(?:\w+['’])?\w+)*")
    for match in word_pat.finditer(text):
        word = match.group(0)
        if len(word) < 2 or not (word[0].isupper() and word[1].islower()):
            continue
        # walk back over whitespace to find the character preceding this word
        pos = match.start() - 1
        while pos >= 0 and text[pos].isspace():
            pos -= 1
        if pos < 0 or text[pos] in ".!?":
            # sentence-initial capitalization says nothing about proper nouns
            continue
        return True
    return False


def fix_punct_spaces(string):
    """
    fix_punct_spaces - replace spaces around punctuation with punctuation. For example, "hello , there" -> "hello, there"

    Whitespace before a run of "?", "!", "." or "," is removed, spaces inside the
    run are removed ("wait . . . what" -> "wait... what"), and exactly one space is
    placed after it. Punctuation sitting directly between two digits with no
    surrounding whitespace (decimal points / thousands separators such as "3.50"
    or "1,000") is left untouched. The result is stripped of surrounding whitespace.

    Parameters
    ----------
    string : str, required, input string to be corrected

    Returns
    -------
    str, corrected string
    """
    fix_spaces = re.compile(r"\s*([?!.,]+(?:\s+[?!.,]+)*)\s*")

    def _tighten(match):
        punct = match.group(1)
        start, end = match.span()
        src = match.string
        # keep numbers like "3.50" / "1,000" intact instead of turning them into "3. 50"
        if (
            match.group(0) == punct
            and start > 0
            and end < len(src)
            and src[start - 1].isdigit()
            and src[end].isdigit()
        ):
            return punct
        return "{} ".format(punct.replace(" ", ""))

    string = fix_spaces.sub(_tighten, string)
    return string.strip()


def split_sentences(text: str):
    """
    split_sentences - split a string into a list of sentences that keep their ending punctuation. powered by regex witchcraft

    A split happens on the whitespace that follows ".", "?" or "!", except after
    a letter-dot-letter-dot pattern such as "e.g." or a capitalized two-letter
    abbreviation such as "Mr." / "Dr.". The whole run of whitespace between two
    sentences is consumed, so no piece starts with stray whitespace.

    Args:
        text (str): [string to be split]

    Returns:
        [list]: [list of strings]
    """
    return re.split(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=[.?!])\s+", text)


def remove_repeated_words(bot_response):
    """
    remove_repeated_words - remove repeated words from a string, returning only the first instance of each word

    Words are obtained with str.split() and compared exactly, so the check is
    case- and punctuation-sensitive ("The" vs "the", "word" vs "word," are
    distinct). The surviving words keep their original order and are joined
    with single spaces.

    Parameters
    ----------
    bot_response : str
        string to remove repeated words from

    Returns
    -------
    str
        string containing the first instance of each word
    """
    # dict preserves insertion order, giving an O(n) order-preserving de-duplication
    return " ".join(dict.fromkeys(bot_response.split()))


def remove_trailing_punctuation(text: str, fuLL_strip=False):
    """
    remove_trailing_punctuation - remove trailing punctuation from a string. Purpose is to seem more natural to end users

    Only the END of the string is touched, so leading punctuation such as an
    opening "..." is preserved.

    Args:
        text (str): [string to be cleaned]
        fuLL_strip (bool, optional): if True, also remove trailing "?" and "!".
            Otherwise only ".", ",", ";" and ":" are removed. Defaults to False.

    Returns:
        [str]: [cleaned string]
    """
    trailing_chars = "?!.,;:" if fuLL_strip else ".,;:"
    # rstrip, not strip: the function must not eat punctuation at the start
    return text.rstrip(trailing_chars)


def fix_punct_spacing(text: str):
    """
    fix_punct_spacing - normalize spacing around punctuation and collapse repeated non-word characters

    Step 1: whitespace before a run of "?", "!", "." or "," is removed, spaces
    inside the run are removed, and exactly one space is placed after it
    ("hello , there" -> "hello, there"). Punctuation sitting directly between
    two digits with no surrounding whitespace ("3.50", "1,000") is left untouched.
    Step 2: any non-word character that is immediately repeated is collapsed to a
    single occurrence ("!!" -> "!", "..." -> ".", double spaces -> one space).

    The result is NOT stripped, so text ending in punctuation keeps a trailing space.

    Args:
        text (str): string to be cleaned

    Returns:
        str: cleaned string
    """
    fix_spaces = re.compile(r"\s*([?!.,]+(?:\s+[?!.,]+)*)\s*")

    def _tighten(match):
        punct = match.group(1)
        start, end = match.span()
        src = match.string
        # keep numbers like "3.50" / "1,000" intact instead of turning them into "3. 50"
        if (
            match.group(0) == punct
            and start > 0
            and end < len(src)
            and src[start - 1].isdigit()
            and src[end].isdigit()
        ):
            return punct
        return "{} ".format(punct.replace(" ", ""))

    spc_text = fix_spaces.sub(_tighten, text)
    # drop a non-word char when the next char is the same one
    cln_text = re.sub(r"(\W)(?=\1)", "", spc_text)

    return cln_text


def synthesize_grammar(
    corrector: transformers.pipeline,
    message: str,
    num_beams=4,
    length_penalty=0.9,
    repetition_penalty=1.5,
    no_repeat_ngram_size=4,
    verbose=False,
):
    """
    synthesize_grammar - use a SyntaxSynthesizer model to generate a string from a message

    The message is first passed through cleantext's clean() (keeping case), then
    run through the pipeline with beam search. Generation lengths are derived from
    the CHARACTER count of the cleaned text as a rough proxy for token count.

    Parameters
    ----------
    corrector : transformers.pipeline, required, which is the SyntaxSynthesizer model already loaded
    message : str, required, which is the message to be corrected
    num_beams : int, optional, by default 4, which is the number of beams to use for the model
    length_penalty : float, optional, by default 0.9, which is the length penalty to use for the model
    repetition_penalty : float, optional, by default 1.5, which is the repetition penalty to use for the model
    no_repeat_ngram_size : int, optional, by default 4, which is the n-gram size to use for the model
    verbose : bool, optional, by default False, which is whether to print the runtime of the model

    Returns
    -------
    str, the "generated_text" of the first pipeline result, stripped of surrounding whitespace
    """
    st = time.perf_counter()
    input_text = clean(message, lower=False)
    n_chars = len(input_text)
    min_length = 2 if n_chars < 64 else int(0.2 * n_chars)
    # never let max_length fall below min_length (happens for 0-1 char inputs),
    # which would hand the generator contradictory length constraints
    max_length = max(int(1.1 * n_chars), min_length)
    results = corrector(
        input_text,
        max_length=max_length,
        min_length=min_length,
        num_beams=num_beams,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        early_stopping=True,
        do_sample=False,
        clean_up_tokenization_spaces=True,
    )
    corrected_text = results[0]["generated_text"]
    if verbose:
        rt = round(time.perf_counter() - st, 2)
        print(f"synthesizing took {rt} seconds")
    return corrected_text.strip()


"""
start of SymSpell code
"""


def symspeller(
    my_string: str,
    sym_checker=None,
    max_dist: int = 2,
    prefix_length: int = 7,
    ignore_non_words=True,
    dictionary_path: str = None,
    bigram_path: str = None,
    verbose=False,
):
    """
    symspeller - a wrapper around SymSpell.lookup_compound from symspellpy

    Parameters
    ----------
        my_string : str, required, the string to be checked (must be non-empty)
        sym_checker : SymSpell, optional, default=None, the SymSpell object to use. If None, a new one
            is built with build_symspell_obj(), which loads the dictionary files on every such call
        max_dist : int, optional, default=2, the maximum edit distance per single word (not per whole string)
        prefix_length : int, optional, default=7, prefix length; only used when a new SymSpell object is built
        ignore_non_words : bool, optional, default=True, forwarded to lookup_compound
        dictionary_path : str, optional, default=None, the path to the dictionary file; only used when building
        bigram_path : str, optional, default=None, the path to the bigram dictionary file; only used when building
        verbose : bool, optional, default=False, whether to print the results

    Returns
    -------
        str, the text of the top suggestion, or clean(my_string) if there are no suggestions

    Raises
    ------
        AssertionError, if my_string is empty (this is an assert, so it is skipped under `python -O`)
    """

    assert len(my_string) > 0, "entered string for correction is empty"

    if sym_checker is None:
        # need to create a new class object. user can specify their own dictionary and bigram files
        if verbose:
            print("creating new SymSpell object")
        sym_checker = build_symspell_obj(
            edit_dist=max_dist,
            prefix_length=prefix_length,
            dictionary_path=dictionary_path,
            bigram_path=bigram_path,
        )
    elif verbose:
        print("using existing SymSpell object")

    # max edit distance per lookup (per single word, not per whole input string)
    suggestions = sym_checker.lookup_compound(
        my_string,
        max_edit_distance=max_dist,
        ignore_non_words=ignore_non_words,
        ignore_term_with_digits=True,
        transfer_casing=True,
    )

    if verbose:
        print(f"{len(suggestions)} suggestions found")
        print(f"the original string is:\n\t{my_string}")
        sug_list = [sug.term for sug in suggestions]
        print(f"suggestions:\n\t{sug_list}\n")

    if not suggestions:
        # no correction because no suggestions
        # NOTE(review): clean() is called with cleantext defaults here, which may lowercase - confirm intended
        return clean(my_string)
    # first result is the most likely; use the public `term` accessor, not the private `_term` attribute
    return suggestions[0].term


def build_symspell_obj(
    edit_dist=2,
    prefix_length=7,
    dictionary_path=None,
    bigram_path=None,
):
    """
    build_symspell_obj - create a SymSpell object and load its word and bigram frequency dictionaries

    Args:
        edit_dist (int, optional): the per-word edit distance callers intend to use for lookups.
            The SymSpell dictionary is built with a max edit distance of edit_dist + 2. Defaults to 2.
        prefix_length (int, optional): prefix length passed to SymSpell. Defaults to 7.
        dictionary_path (str, optional): unigram frequency file. Defaults to
            "symspell_rsc/frequency_dictionary_en_82_765.txt", a relative path.
        bigram_path (str, optional): bigram frequency file. Defaults to
            "symspell_rsc/frequency_bigramdictionary_en_243_342.txt", a relative path.

    Returns:
        SymSpell: a SymSpell object with both dictionaries loaded
    """
    if dictionary_path is None:
        dictionary_path = r"symspell_rsc/frequency_dictionary_en_82_765.txt"
    if bigram_path is None:
        bigram_path = r"symspell_rsc/frequency_bigramdictionary_en_243_342.txt"

    checker = SymSpell(
        max_dictionary_edit_distance=edit_dist + 2,
        prefix_length=prefix_length,
    )
    # unigram file: term read from column 0, frequency from column 1
    checker.load_dictionary(dictionary_path, term_index=0, count_index=1)
    # bigram file: term read starting at column 0, frequency from column 2
    checker.load_bigram_dictionary(bigram_path, term_index=0, count_index=2)
    return checker


"""
# if using t5b_correction to check for spelling errors, use this code to initialize the objects

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

model_name = 'deep-learning-analytics/GrammarCorrector'
# torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch_device = 'cpu'
gc_tokenizer = T5Tokenizer.from_pretrained(model_name)
gc_model = T5ForConditionalGeneration.from_pretrained(model_name).to(torch_device)

"""


def t5b_correction(prompt: str, korrektor, verbose=False, beams=4):
    """
    t5b_correction - run a text2textgen pipeline model from transformers on a "grammar: "-prefixed prompt

    Parameters
    ----------
    prompt : str, required, input prompt to be corrected
    korrektor : transformers.pipeline, required, callable pipeline object
    verbose : bool, optional, print the computed lengths and the raw result. Defaults to False.
    beams : int, optional, number of beams to use for the correction. Defaults to 4.

    Returns
    -------
    whatever `korrektor` returns, unmodified. For a transformers text2text pipeline this is
    presumably a list of dicts with a "generated_text" key rather than a plain str - confirm with callers.
    """
    n_chars = len(prompt)
    # lengths are derived from the character count of the prompt
    # NOTE(review): the lower bound is only reported when verbose; it is never passed to the pipeline
    floor_len = int(math.ceil(0.9 * n_chars))
    ceil_len = int(math.ceil(1.1 * n_chars))
    if verbose:
        print(f"setting min to {floor_len} and max to {ceil_len}\n")

    result = korrektor(
        f"grammar: {prompt}",
        num_beams=beams,
        max_length=ceil_len,
        return_text=True,
        clean_up_tokenization_spaces=True,
        repetition_penalty=1.3,
        length_penalty=0.2,
        no_repeat_ngram_size=2,
    )

    if verbose:
        print(f"grammar correction result: \n\t{result}\n")
    return result


def all_neuspell_chkrs():
    """
    all_neuspell_chkrs - print and return the attribute names exposed by the neuspell module

    Note that dir(neuspell) lists every attribute name of the module (checker
    classes as well as submodules, helpers and dunder names), not only checkers.

    Parameters
    ----------
    None

    Returns
    -------
    checker_opts - list of str, the result of dir(neuspell)
    """
    attr_names = dir(neuspell)
    print("\navailable checkers:")
    pp.pprint(attr_names, indent=4, compact=True)
    return attr_names


def load_ns_checker(customckr=None, fast=False):
    """
    load_ns_checker - helper function, instantiate a pretrained neuspell checker

    Args:
        customckr (type, optional): a neuspell checker CLASS (e.g. BertChecker); it is
            instantiated with pretrained=True. If None, the class is chosen via `fast`.
        fast (bool, optional): only consulted when customckr is None. True selects
            SclstmChecker (faster but less accurate); False selects BertChecker
            (the default, best balance). Defaults to False.

    Returns:
        the instantiated neuspell checker object
    """
    start = time.perf_counter()

    if customckr is not None:
        checker_cls = customckr
    elif fast:
        checker_cls = SclstmChecker
    else:
        checker_cls = BertChecker

    # silence console output produced while the pretrained checker is set up
    with suppress_stdout():
        checker = checker_cls(pretrained=True)

    elapsed_min = (time.perf_counter() - start) / 60
    print(f"\n\nloaded checker in {elapsed_min} minutes")
    return checker


def neuspell_correct(input_text: str, checker=None, verbose=False):
    """
    neuspell_correct - correct a string using neuspell, then tidy spacing around punctuation.
                        note that modifications to the checker are needed if doing list-based corrections

    Strings shorter than 4 characters are returned unchanged. If no checker is
    given, a new SclstmChecker(pretrained=True) is created on every call.

    Parameters
    ----------
    input_text : str, required, input string to be corrected
    checker : neuspell checker object, optional, anything with a .correct(text) method. Defaults to None.
    verbose : bool, optional, whether to print the corrected string. Defaults to False.

    Returns
    -------
    str, corrected string
    """
    too_short = isinstance(input_text, str) and len(input_text) < 4
    if too_short:
        print(f"input text of {input_text} is too short to be corrected")
        return input_text

    active_checker = checker
    if active_checker is None:
        print("NOTE - no checker provided, loading default checker")
        active_checker = SclstmChecker(pretrained=True)

    cleaned_txt = fix_punct_spaces(active_checker.correct(input_text))
    if verbose:
        print(f"neuspell correction result: \n\t{cleaned_txt}\n")
    return cleaned_txt


def grammarpipe(corrector, qphrase: str):
    """
    grammarpipe - correct a string using a text2textgen pipeline model from transformers.
                  NOTE: this is the original function used in the project and is slated to be changed.

    Strings shorter than 4 characters are returned unchanged. The input is passed
    through cleantext's clean() (with its defaults) before correction. If the
    pipeline call fails for any reason, the error is printed and clean(qphrase)
    is returned instead (best-effort fallback).

    Args:
        corrector (transformers.pipeline): [transformers pipeline object, already created w/ relevant model]
        qphrase (str): [text to be corrected]
    Returns:
        [str]: [corrected text]
    """
    is_short = isinstance(qphrase, str) and len(qphrase) < 4
    if is_short:
        print(f"input text of {qphrase} is too short to be corrected")
        return qphrase

    try:
        generated = corrector(
            clean(qphrase), return_text=True, clean_up_tokenization_spaces=True
        )
        best = generated[0]["generated_text"]
    except Exception as err:
        print(f"NOTE - failed to correct with grammarpipe:\n {err}")
        return clean(qphrase)
    return best


def DLA_correct(qphrase: str, tokenizer=None, model=None, **gen_kwargs):
    """
    DLA_correct - an "overhead" function to call correct_grammar() on a string, correcting each sentence individually

    Args:
        qphrase (str): [string to be corrected]
        tokenizer (optional): tokenizer forwarded to correct_grammar(); required for inputs of 4+ characters
        model (optional): model forwarded to correct_grammar(); required for inputs of 4+ characters
        **gen_kwargs: extra keyword arguments forwarded to correct_grammar() (e.g. beams, device)

    Returns:
        str, the corrected sentences joined with " " (inputs shorter than 4 characters are returned unchanged)

    Raises:
        TypeError: if tokenizer or model is missing for an input long enough to be corrected
    """
    if isinstance(qphrase, str) and len(qphrase) < 4:
        print(f"input text of {qphrase} is too short to be corrected")
        return qphrase

    # correct_grammar() takes tokenizer and model as required positional arguments,
    # so fail early with an explicit message instead of an opaque missing-argument error
    if tokenizer is None or model is None:
        raise TypeError(
            "DLA_correct requires both `tokenizer` and `model` "
            "(they are forwarded to correct_grammar)"
        )

    sentences = split_sentences(qphrase)
    if len(sentences) == 1:
        return correct_grammar(sentences[0], tokenizer, model, **gen_kwargs)
    return " ".join(
        correct_grammar(clean(sen), tokenizer, model, **gen_kwargs) for sen in sentences
    )


def correct_grammar(
    input_text: str,
    tokenizer,
    model,
    n_results: int = 1,
    beams: int = 8,
    temp=1,
    uniq_ngrams=2,
    rep_penalty=1.5,
    device="cpu",
):
    """
    correct_grammar - correct a string using a seq2seq (T5-style) model and tokenizer from transformers.
                        This function is an alternative to the t5b_correction function.

    Inputs shorter than 5 characters are returned unchanged. Token lengths are
    derived from the CHARACTER count (x1.2, capped at 128) as a rough proxy, so
    longer inputs are truncated to 128 tokens. The runtime is always printed.

    Parameters
    ----------
    input_text : str, required, input string to be corrected
    tokenizer : transformers.T5Tokenizer, required, tokenizer object, already created w/ relevant model
    model : transformers.T5ForConditionalGeneration, required, model object, already created w/ relevant model
    n_results : int, optional, number of results to return. Defaults to 1.
    beams : int, optional, number of beams to use for the correction. Defaults to 8.
    temp : int, optional, temperature passed to generate(). Defaults to 1.
    uniq_ngrams : int, optional, no_repeat_ngram_size passed to generate(). Defaults to 2.
    rep_penalty : float, optional, repetition penalty passed to generate(). Defaults to 1.5.
    device : str, optional, device to move the tokenized batch to. Defaults to 'cpu'.

    Returns
    -------
    str, the top corrected string when n_results <= 1;
    list of str with all n_results candidates when n_results > 1
    """
    st = time.perf_counter()

    if len(input_text) < 5:
        return input_text
    max_length = min(int(math.ceil(len(input_text) * 1.2)), 128)
    batch = tokenizer(
        [input_text],
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    ).to(device)
    # NOTE(review): decoding is beam search without do_sample, so `temperature`
    # presumably has no effect here - confirm against the installed transformers version
    translated = model.generate(
        **batch,
        max_length=max_length,
        min_length=min(10, len(input_text)),
        no_repeat_ngram_size=uniq_ngrams,
        repetition_penalty=rep_penalty,
        num_beams=beams,
        num_return_sequences=n_results,
        temperature=temp,
    )

    tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
    rt_min = (time.perf_counter() - st) / 60
    print(f"\n\ncorrected in {rt_min} minutes")

    if not isinstance(tgt_text, list):
        return tgt_text
    # honour n_results: return every candidate when more than one was requested
    return tgt_text if n_results > 1 else tgt_text[0]