File size: 4,581 Bytes
e8ca0bf 945af81 e8ca0bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
class Gramformer:
    """Grammar error correction and edit highlighting.

    Wraps three components created in ``__init__``: a seq2seq correction
    model with its tokenizer, a GPT-2 based sentence scorer used to rank
    the generated candidates, and an ERRANT annotator used to derive the
    edits between an original and a corrected sentence.
    """

    def __init__(self, models=1, use_gpu=False):
        """Load the annotator, the scorer and (optionally) the correction model.

        Args:
            models: ``1`` loads the correction/highlight model. ``2`` is not
                implemented and only prints a notice. For any value other
                than ``1`` no model is loaded and ``model_loaded`` stays
                ``False``.
            use_gpu: When true everything is placed on ``"cuda:0"``,
                otherwise on ``"cpu"``.
        """
        # Imported here rather than at module level, so importing this module
        # does not by itself require these packages.
        from transformers import AutoTokenizer
        from transformers import AutoModelForSeq2SeqLM
        from lm_scorer.models.auto import AutoLMScorer as LMScorer
        import errant
        import en_core_web_sm

        nlp = en_core_web_sm.load()
        self.annotator = errant.load('en', nlp)

        device = "cuda:0" if use_gpu else "cpu"
        batch_size = 1
        self.scorer = LMScorer.from_pretrained(
            "gpt2", device=device, batch_size=batch_size)
        self.device = device

        correction_model_tag = "prithivida/grammar_error_correcter_v1"
        self.model_loaded = False

        if models == 1:
            self.correction_tokenizer = AutoTokenizer.from_pretrained(
                correction_model_tag)
            self.correction_model = AutoModelForSeq2SeqLM.from_pretrained(
                correction_model_tag)
            self.correction_model = self.correction_model.to(device)
            self.model_loaded = True
            print("[Gramformer] Grammar error correct/highlight model loaded..")
        elif models == 2:
            # TODO
            print("TO BE IMPLEMENTED!!!")

    def correct(self, input_sentence, max_candidates=1):
        """Generate corrected candidates for ``input_sentence`` and rank them.

        Candidates are sampled (``do_sample=True``), so repeated calls may
        return different results. Duplicate candidates are collapsed before
        scoring, so fewer than ``max_candidates`` items may come back.

        Args:
            input_sentence: The sentence to correct.
            max_candidates: Number of sequences requested from the model.

        Returns:
            A list of ``(sentence, score)`` tuples sorted by score, highest
            first, where the score is whatever
            ``scorer.sentence_score(..., log=True)`` yields. ``None`` when
            no correction model is loaded.
        """
        if not self.model_loaded:
            print("Model is not loaded")
            return None

        # The model is prompted with a task prefix in front of the sentence.
        prompt = "gec: " + input_sentence
        input_ids = self.correction_tokenizer.encode(
            prompt, return_tensors='pt').to(self.device)

        preds = self.correction_model.generate(
            input_ids,
            do_sample=True,
            max_length=128,
            top_k=50,
            top_p=0.95,
            early_stopping=True,
            num_return_sequences=max_candidates)

        # A set drops identical candidates produced by different samples.
        candidates = list({
            self.correction_tokenizer.decode(
                pred, skip_special_tokens=True).strip()
            for pred in preds
        })
        scores = self.scorer.sentence_score(candidates, log=True)
        return sorted(
            zip(candidates, scores), key=lambda pair: pair[1], reverse=True)

    def highlight(self, orig, cor):
        """Return ``orig`` with each edit towards ``cor`` wrapped in a tag.

        The sentence is split on whitespace and every edit reported by
        ``_get_edits`` replaces the token at its start index with one of:

        * ``<a type=.. edit=..>anchor</a>`` for an insertion. The tag is
          placed on the token before the insertion point, or on the first
          token when the insertion happens at the very start.
        * ``<d type=.. edit=''>removed</d>`` for a deletion.
        * ``<c type=.. edit='new'>old</c>`` for a replacement.

        Args:
            orig: The original sentence.
            cor: The corrected sentence.

        Returns:
            The tagged sentence, tokens joined by single spaces.
        """
        # NOTE(review): edit offsets are used as indexes into orig.split();
        # this assumes the annotator tokenises on whitespace as well.
        # NOTE(review): values are concatenated into the markup unescaped, so
        # a quote inside a token ends up verbatim inside the attribute.
        edits = self._get_edits(orig, cor)
        orig_tokens = orig.split()
        ignore_indexes = set()

        for edit_type, edit_str_start, edit_spos, edit_epos, edit_str_end, *_ in edits:
            # A multi-token source span is rendered inside the tag on its
            # first token; the remaining tokens are dropped at the end.
            ignore_indexes.update(range(edit_spos + 1, edit_epos))

            if edit_str_start == "":
                # Insertion: there is no source token to wrap, so borrow a
                # neighbouring one as the anchor.
                if edit_spos >= 1:
                    # Anchor on the preceding token; new text follows it.
                    edit_spos -= 1
                    new_edit_str = orig_tokens[edit_spos]
                    suggestion = new_edit_str + " " + edit_str_end
                elif orig_tokens:
                    # Insertion before the first token: anchor on that token
                    # and put the new text in front of it.
                    new_edit_str = orig_tokens[edit_spos]
                    suggestion = edit_str_end + " " + new_edit_str
                else:
                    # Empty source sentence: nothing to anchor on, so create
                    # an empty slot to carry the tag.
                    orig_tokens.append("")
                    new_edit_str = ""
                    suggestion = edit_str_end
                if edit_type == "PUNCT":
                    # Punctuation edits carry only the inserted mark.
                    suggestion = edit_str_end
                st = "<a type='" + edit_type + "' edit='" + \
                    suggestion + "'>" + new_edit_str + "</a>"
            elif edit_str_end == "":
                st = "<d type='" + edit_type + "' edit=''>" + \
                    edit_str_start + "</d>"
            else:
                st = "<c type='" + edit_type + "' edit='" + \
                    edit_str_end + "'>" + edit_str_start + "</c>"
            orig_tokens[edit_spos] = st

        # Delete from the right so earlier indexes stay valid.
        for i in sorted(ignore_indexes, reverse=True):
            del orig_tokens[i]

        return " ".join(orig_tokens)

    def detect(self, input_sentence):
        """Placeholder for error detection; not implemented, returns ``None``."""
        # TO BE IMPLEMENTED
        pass

    def _get_edits(self, orig, cor):
        """Compute the classified edits that turn ``orig`` into ``cor``.

        Both sentences are parsed, aligned and merged by the annotator, and
        each resulting edit is classified.

        Returns:
            A list of tuples ``(type, o_str, o_start, o_end, c_str, c_start,
            c_end)``, where ``type`` is the edit type with its first two
            characters removed. Empty when there are no edits.
        """
        orig = self.annotator.parse(orig)
        cor = self.annotator.parse(cor)
        alignment = self.annotator.align(orig, cor)
        edits = self.annotator.merge(alignment)

        classified = (self.annotator.classify(e) for e in edits)
        return [
            (e.type[2:], e.o_str, e.o_start, e.o_end,
             e.c_str, e.c_start, e.c_end)
            for e in classified
        ]

    def get_edits(self, orig, cor):
        """Public entry point for ``_get_edits``; same arguments and result."""
        return self._get_edits(orig, cor)
|