File size: 4,447 Bytes
882d84f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
class Gramformer:
    """Grammar error correction and edit highlighting.

    Wraps the "prithivida/grammar_error_correcter_v1" seq2seq checkpoint
    (loaded through ``transformers``) for generating corrections, and an
    ERRANT annotator (built on the ``en_core_web_sm`` spaCy pipeline) for
    extracting the edits between an original and a corrected sentence.
    """

    def __init__(self, models=1, use_gpu=False):
        """Load the ERRANT annotator and, for ``models == 1``, the model.

        Args:
            models: ``1`` loads the correction tokenizer and model and sets
                ``self.model_loaded`` to True. ``2`` only prints a
                placeholder message. Any other value loads no model.
            use_gpu: When true the model is moved to ``"cuda:0"``,
                otherwise it stays on ``"cpu"``.
        """
        # Heavy third-party imports are kept local to the constructor so that
        # importing this module does not require them.
        from transformers import AutoTokenizer
        from transformers import AutoModelForSeq2SeqLM
        import errant
        import en_core_web_sm

        nlp = en_core_web_sm.load()
        self.annotator = errant.load('en', nlp)

        # NOTE: an LM-based scorer (lm_scorer, "gpt2") used to be created here
        # for ranking candidates; it is currently disabled.
        device = "cuda:0" if use_gpu else "cpu"
        self.device = device
        correction_model_tag = "prithivida/grammar_error_correcter_v1"
        self.model_loaded = False

        if models == 1:
            self.correction_tokenizer = AutoTokenizer.from_pretrained(
                correction_model_tag, use_auth_token=False)
            self.correction_model = AutoModelForSeq2SeqLM.from_pretrained(
                correction_model_tag, use_auth_token=False)
            self.correction_model = self.correction_model.to(device)
            self.model_loaded = True
            print("[Gramformer] Grammar error correct/highlight model loaded..")
        elif models == 2:
            # TODO
            print("TO BE IMPLEMENTED!!!")

    def correct(self, input_sentence, max_candidates=1):
        """Generate corrected candidates for ``input_sentence``.

        Args:
            input_sentence: The sentence to correct. It is prefixed with
                ``"gec: "`` before being tokenized.
            max_candidates: Passed to ``generate`` as
                ``num_return_sequences``.

        Returns:
            A set of decoded, stripped candidate strings (duplicates
            collapse, so it may hold fewer than ``max_candidates`` items),
            or ``None`` when no model has been loaded.
        """
        if not self.model_loaded:
            print("Model is not loaded")
            return None

        correction_prefix = "gec: "
        input_ids = self.correction_tokenizer.encode(
            correction_prefix + input_sentence, return_tensors='pt')
        input_ids = input_ids.to(self.device)

        # Sampling is enabled together with beam search, so repeated calls
        # with the same input may return different candidates.
        preds = self.correction_model.generate(
            input_ids,
            do_sample=True,
            max_length=128,
            num_beams=7,
            early_stopping=True,
            num_return_sequences=max_candidates)

        return {
            self.correction_tokenizer.decode(
                pred, skip_special_tokens=True).strip()
            for pred in preds
        }

    def highlight(self, orig, cor):
        """Return ``orig`` with every edit towards ``cor`` marked inline.

        Each edit is rendered as ``<s>CORRECTION</s> ORIGINAL`` in place of
        the first original token of the edit; the remaining original tokens
        of a multi-token edit are dropped from the token list because
        ``ORIGINAL`` already contains them. A deletion has an empty
        correction (``<s></s> ORIGINAL``). An insertion has no original
        text, so it is anchored on the preceding token, or on the first
        token when the insertion is at the very start of the sentence.

        Args:
            orig: The original sentence.
            cor: The corrected sentence.

        Returns:
            The annotated sentence, tokens joined by single spaces.

        Raises:
            IndexError: If an edit position does not exist in
                ``orig.split()``.
        """
        # NOTE(review): edit positions come from the annotator's tokenization
        # while the tokens here come from str.split(); this assumes both
        # agree on token boundaries -- TODO confirm for punctuation.
        edits = self._get_edits(orig, cor)
        orig_tokens = orig.split()
        ignore_indexes = set()

        for edit in edits:
            edit_str_start, edit_spos, edit_epos, edit_str_end = edit[1:5]

            # For a multi-token original span keep only the first slot; the
            # others are removed once all edits have been rendered.
            ignore_indexes.update(range(edit_spos + 1, edit_epos))

            marker = "<s>" + edit_str_end + "</s> "
            if edit_str_start == "":
                # Insertion: anchor on the previous token when there is one,
                # otherwise on the first token of the sentence.
                anchor = edit_spos - 1 if edit_spos >= 1 else edit_spos
                orig_tokens[anchor] = marker + orig_tokens[anchor]
            else:
                # Replacement, or deletion when edit_str_end is empty.
                orig_tokens[edit_spos] = marker + edit_str_start

        kept_tokens = [
            token for index, token in enumerate(orig_tokens)
            if index not in ignore_indexes
        ]
        return " ".join(kept_tokens)

    def detect(self, input_sentence):
        """Placeholder for error detection; not implemented, returns None."""
        # TO BE IMPLEMENTED
        pass

    def _get_edits(self, orig, cor):
        """Extract the classified edits that turn ``orig`` into ``cor``.

        Both sentences are parsed, aligned and merged by ``self.annotator``;
        every merged edit is then classified.

        Returns:
            A list of tuples ``(type, o_str, o_start, o_end, c_str,
            c_start, c_end)``, empty when there are no edits. ``type`` is
            the annotator's edit type with its first two characters removed
            (presumably the operation prefix such as ``"R:"`` -- verify
            against the annotator's documentation).
        """
        orig = self.annotator.parse(orig)
        cor = self.annotator.parse(cor)
        alignment = self.annotator.align(orig, cor)

        edit_annotations = []
        for e in self.annotator.merge(alignment):
            e = self.annotator.classify(e)
            edit_annotations.append(
                (e.type[2:], e.o_str, e.o_start, e.o_end,
                 e.c_str, e.c_start, e.c_end))
        return edit_annotations

    def get_edits(self, orig, cor):
        """Public entry point for edit extraction; see ``_get_edits``."""
        return self._get_edits(orig, cor)
|