kinensake committed
Commit e8ca0bf
1 Parent(s): 7e8d880

Modify: requirements.txt

gramformer_backup/__init__.py ADDED
@@ -0,0 +1 @@
+ from gramformer.gramformer import Gramformer
gramformer_backup/demo.py ADDED
@@ -0,0 +1,30 @@
+ from gramformer import Gramformer
+ import torch
+
+ def set_seed(seed):
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
+
+ set_seed(1212)
+
+
+ gf = Gramformer(models=1, use_gpu=False)  # 1=corrector, 2=detector
+
+ influent_sentences = [
+     "Matt like fish",
+     "the collection of letters was original used by the ancient Romans",
+     "We enjoys horror movies",
+     "Anna and Mike is going skiing",
+     "I walk to the store and I bought milk",
+     "We all eat the fish and then made dessert",
+     "I will eat fish for dinner and drank milk",
+     "what be the reason for everyone leave the company",
+ ]
+
+ for influent_sentence in influent_sentences:
+     corrected_sentences = gf.correct(influent_sentence, max_candidates=1)
+     print("[Input] ", influent_sentence)
+     for corrected_sentence in corrected_sentences:
+         print("[Correction] ", corrected_sentence)
+     print("-" * 100)
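
Note: correct() (defined in gramformer.py below) returns ranked (candidate, log_score) tuples, so each "[Correction]" line in this demo prints a tuple rather than a bare string. A minimal sketch of unpacking the score explicitly, reusing the gf instance built in the demo above:

    # Each element is (corrected_text, GPT-2 log-score), best-scored candidate first
    for candidate, log_score in gf.correct("Matt like fish", max_candidates=2):
        print("[Correction]", candidate, "(log-score: %.2f)" % log_score)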
gramformer_backup/gramformer.py ADDED
@@ -0,0 +1,127 @@
+ class Gramformer:
+
+     def __init__(self, models=1, use_gpu=False):
+         from transformers import AutoTokenizer
+         from transformers import AutoModelForSeq2SeqLM
+         from lm_scorer.models.auto import AutoLMScorer as LMScorer
+         import errant
+         self.annotator = errant.load('en')
+
+         if use_gpu:
+             device = "cuda:0"
+         else:
+             device = "cpu"
+         batch_size = 1
+         self.scorer = LMScorer.from_pretrained("gpt2", device=device, batch_size=batch_size)
+         self.device = device
+         correction_model_tag = "prithivida/grammar_error_correcter_v1"
+         self.model_loaded = False
+
+         if models == 1:
+             self.correction_tokenizer = AutoTokenizer.from_pretrained(correction_model_tag)
+             self.correction_model = AutoModelForSeq2SeqLM.from_pretrained(correction_model_tag)
+             self.correction_model = self.correction_model.to(device)
+             self.model_loaded = True
+             print("[Gramformer] Grammar error correct/highlight model loaded..")
+         elif models == 2:
+             # TODO
+             print("TO BE IMPLEMENTED!!!")
+
+     def correct(self, input_sentence, max_candidates=1):
+         if self.model_loaded:
+             correction_prefix = "gec: "
+             input_sentence = correction_prefix + input_sentence
+             input_ids = self.correction_tokenizer.encode(input_sentence, return_tensors='pt')
+             input_ids = input_ids.to(self.device)
+
+             preds = self.correction_model.generate(
+                 input_ids,
+                 do_sample=True,
+                 max_length=128,
+                 top_k=50,
+                 top_p=0.95,
+                 early_stopping=True,
+                 num_return_sequences=max_candidates)
+
+             corrected = set()
+             for pred in preds:
+                 corrected.add(self.correction_tokenizer.decode(pred, skip_special_tokens=True).strip())
+
+             corrected = list(corrected)
+             scores = self.scorer.sentence_score(corrected, log=True)
+             ranked_corrected = [(c, s) for c, s in zip(corrected, scores)]
+             ranked_corrected.sort(key=lambda x: x[1], reverse=True)
+             return ranked_corrected
+         else:
+             print("Model is not loaded")
+             return None
+
+     def highlight(self, orig, cor):
+         edits = self._get_edits(orig, cor)
+         orig_tokens = orig.split()
+
+         ignore_indexes = []
+
+         for edit in edits:
+             edit_type = edit[0]
+             edit_str_start = edit[1]
+             edit_spos = edit[2]
+             edit_epos = edit[3]
+             edit_str_end = edit[4]
+
+             # if no_of_tokens(edit_str_start) > 1 ==> excluding the first token, mark all other tokens for deletion
+             for i in range(edit_spos + 1, edit_epos):
+                 ignore_indexes.append(i)
+
+             if edit_str_start == "":
+                 if edit_spos - 1 >= 0:
+                     new_edit_str = orig_tokens[edit_spos - 1]
+                     edit_spos -= 1
+                 else:
+                     new_edit_str = orig_tokens[edit_spos + 1]
+                     edit_spos += 1
+                 if edit_type == "PUNCT":
+                     st = "<a type='" + edit_type + "' edit='" + \
+                         edit_str_end + "'>" + new_edit_str + "</a>"
+                 else:
+                     st = "<a type='" + edit_type + "' edit='" + new_edit_str + \
+                         " " + edit_str_end + "'>" + new_edit_str + "</a>"
+                 orig_tokens[edit_spos] = st
+             elif edit_str_end == "":
+                 st = "<d type='" + edit_type + "' edit=''>" + edit_str_start + "</d>"
+                 orig_tokens[edit_spos] = st
+             else:
+                 st = "<c type='" + edit_type + "' edit='" + \
+                     edit_str_end + "'>" + edit_str_start + "</c>"
+                 orig_tokens[edit_spos] = st
+
+         for i in sorted(ignore_indexes, reverse=True):
+             del(orig_tokens[i])
+
+         return(" ".join(orig_tokens))
+
+     def detect(self, input_sentence):
+         # TO BE IMPLEMENTED
+         pass
+
+     def _get_edits(self, orig, cor):
+         orig = self.annotator.parse(orig)
+         cor = self.annotator.parse(cor)
+         alignment = self.annotator.align(orig, cor)
+         edits = self.annotator.merge(alignment)
+
+         if len(edits) == 0:
+             return []
+
+         edit_annotations = []
+         for e in edits:
+             e = self.annotator.classify(e)
+             edit_annotations.append((e.type[2:], e.o_str, e.o_start, e.o_end, e.c_str, e.c_start, e.c_end))
+
+         if len(edit_annotations) > 0:
+             return edit_annotations
+         else:
+             return []
+
+     def get_edits(self, orig, cor):
+         return self._get_edits(orig, cor)
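
Beyond correct(), gramformer.py also exposes get_edits() and highlight(), which align the original and corrected sentences with errant. A minimal usage sketch, assuming the errant English model is installed and reusing the same constructor call as demo.py (the sentence is illustrative):

    from gramformer import Gramformer

    gf = Gramformer(models=1, use_gpu=False)

    orig = "Matt like fish"
    corrected, _score = gf.correct(orig, max_candidates=1)[0]

    # errant edit tuples: (type, orig_str, o_start, o_end, cor_str, c_start, c_end)
    print(gf.get_edits(orig, corrected))

    # original tokens wrapped in <a>/<c>/<d> tags marking insertions, changes, and deletions
    print(gf.highlight(orig, corrected))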
requirements.txt CHANGED
@@ -1,4 +1,10 @@
- pip==20.1.1
  st-annotated-text
  beautifulsoup4
- git+https://github.com/PrithivirajDamodaran/Gramformer.git
+ transformers
+ sentencepiece==0.1.95
+ python-Levenshtein==0.12.2
+ fuzzywuzzy==0.18.0
+ tokenizers==0.10.2
+ fsspec==2021.5.0
+ lm-scorer==0.4.2
+ errant
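
As a quick sanity check that the pinned stack above resolves together (a sketch; the exact version output is illustrative):

    # confirm the libraries used by gramformer.py import cleanly and report their versions
    import transformers, tokenizers, fsspec
    print(transformers.__version__, tokenizers.__version__, fsspec.__version__)
    from lm_scorer.models.auto import AutoLMScorer  # the scorer backend used in gramformer.py
    import errant  # errant.load('en') additionally needs spaCy's English model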
setup.py ADDED
@@ -0,0 +1,19 @@
+ import setuptools
+
+ setuptools.setup(
+     name="gramformer",
+     version="1.0",
+     author="prithiviraj damodaran",
+     author_email="",
+     description="Gramformer",
+     long_description="A framework for detecting, highlighting and correcting grammatical errors on natural language text",
+     url="https://github.com/PrithivirajDamodaran/Gramformer.git",
+     packages=setuptools.find_packages(),
+     install_requires=['transformers', 'sentencepiece==0.1.95', 'python-Levenshtein==0.12.2', 'fuzzywuzzy==0.18.0', 'tokenizers==0.10.2', 'fsspec==2021.5.0', 'lm-scorer==0.4.2', 'errant', 'st-annotated-text'],
+     classifiers=[
+         "Programming Language :: Python :: 3.7",
+         "License :: Apache 2.0",
+         "Operating System :: OS Independent",
+     ],
+ )
+