rpunct removed
Browse files- myrpunct/__init__.py +0 -2
- myrpunct/__pycache__/__init__.cpython-310.pyc +0 -0
- myrpunct/__pycache__/__init__.cpython-39.pyc +0 -0
- myrpunct/__pycache__/punctuate.cpython-310.pyc +0 -0
- myrpunct/__pycache__/punctuate.cpython-39.pyc +0 -0
- myrpunct/punctuate.py +0 -174
- myrpunct/utils.py +0 -34
- repunct.py +4 -3
myrpunct/__init__.py
DELETED
@@ -1,2 +0,0 @@
|
|
1 |
-
from .punctuate import RestorePuncts
|
2 |
-
print("init executed ...")
|
|
|
|
|
|
myrpunct/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (231 Bytes)
|
|
myrpunct/__pycache__/__init__.cpython-39.pyc
DELETED
Binary file (227 Bytes)
|
|
myrpunct/__pycache__/punctuate.cpython-310.pyc
DELETED
Binary file (5.71 kB)
|
|
myrpunct/__pycache__/punctuate.cpython-39.pyc
DELETED
Binary file (5.69 kB)
|
|
myrpunct/punctuate.py
DELETED
@@ -1,174 +0,0 @@
|
|
1 |
-
# -*- coding: utf-8 -*-
|
2 |
-
# 💾⚙️🔮
|
3 |
-
|
4 |
-
__author__ = "Daulet N."
|
5 |
-
__email__ = "daulet.nurmanbetov@gmail.com"
|
6 |
-
|
7 |
-
import logging
|
8 |
-
from langdetect import detect
|
9 |
-
from simpletransformers.ner import NERModel, NERArgs
|
10 |
-
|
11 |
-
|
12 |
-
class RestorePuncts:
|
13 |
-
def __init__(self, wrds_per_pred=250, use_cuda=False):
|
14 |
-
self.wrds_per_pred = wrds_per_pred
|
15 |
-
self.overlap_wrds = 30
|
16 |
-
self.valid_labels = ['OU', 'OO', '.O', '!O', ',O', '.U', '!U', ',U', ':O', ';O', ':U', "'O", '-O', '?O', '?U']
|
17 |
-
self.model_hf = "wldmr/felflare-bert-restore-punctuation"
|
18 |
-
self.model_args = NERArgs()
|
19 |
-
self.model_args.silent = True
|
20 |
-
self.model_args.max_seq_length = 512
|
21 |
-
#self.model_args.use_multiprocessing = False
|
22 |
-
self.model = NERModel("bert", self.model_hf, labels=self.valid_labels, use_cuda=use_cuda, args=self.model_args)
|
23 |
-
#self.model = NERModel("bert", self.model_hf, labels=self.valid_labels, use_cuda=use_cuda, args={"silent": True, "max_seq_length": 512, "use_multiprocessing": False})
|
24 |
-
print("class init ...")
|
25 |
-
print("use_multiprocessing: ",self.model_args.use_multiprocessing)
|
26 |
-
|
27 |
-
def status(self):
|
28 |
-
print("function called")
|
29 |
-
|
30 |
-
def punctuate(self, text: str, lang:str=''):
|
31 |
-
"""
|
32 |
-
Performs punctuation restoration on arbitrarily large text.
|
33 |
-
Detects if input is not English, if non-English was detected terminates predictions.
|
34 |
-
Override by supplying `lang='en'`
|
35 |
-
|
36 |
-
Args:
|
37 |
-
- text (str): Text to punctuate, can be few words to as large as you want.
|
38 |
-
- lang (str): Explicit language of input text.
|
39 |
-
"""
|
40 |
-
if not lang and len(text) > 10:
|
41 |
-
lang = detect(text)
|
42 |
-
if lang != 'en':
|
43 |
-
raise Exception(F"""Non English text detected. Restore Punctuation works only for English.
|
44 |
-
If you are certain the input is English, pass argument lang='en' to this function.
|
45 |
-
Punctuate received: {text}""")
|
46 |
-
|
47 |
-
# split up large text into bert-digestible chunks
|
48 |
-
splits = self.split_on_toks(text, self.wrds_per_pred, self.overlap_wrds)
|
49 |
-
# predict slices
|
50 |
-
# full_preds_lst contains tuple of labels and logits
|
51 |
-
full_preds_lst = [self.predict(i['text']) for i in splits]
|
52 |
-
# extract predictions, and discard logits
|
53 |
-
preds_lst = [i[0][0] for i in full_preds_lst]
|
54 |
-
# join text slices
|
55 |
-
combined_preds = self.combine_results(text, preds_lst)
|
56 |
-
# create punctuated prediction
|
57 |
-
punct_text = self.punctuate_texts(combined_preds)
|
58 |
-
return punct_text
|
59 |
-
|
60 |
-
def predict(self, input_slice):
|
61 |
-
"""
|
62 |
-
Passes the unpunctuated text to the model for punctuation.
|
63 |
-
"""
|
64 |
-
predictions, raw_outputs = self.model.predict([input_slice])
|
65 |
-
return predictions, raw_outputs
|
66 |
-
|
67 |
-
@staticmethod
|
68 |
-
def split_on_toks(text, length, overlap):
|
69 |
-
"""
|
70 |
-
Splits text into predefined slices of overlapping text with indexes (offsets)
|
71 |
-
that tie-back to original text.
|
72 |
-
This is done to bypass 512 token limit on transformer models by sequentially
|
73 |
-
feeding chunks of < 512 toks.
|
74 |
-
Example output:
|
75 |
-
[{...}, {"text": "...", 'start_idx': 31354, 'end_idx': 32648}, {...}]
|
76 |
-
"""
|
77 |
-
wrds = text.replace('\n', ' ').split(" ")
|
78 |
-
resp = []
|
79 |
-
lst_chunk_idx = 0
|
80 |
-
i = 0
|
81 |
-
|
82 |
-
while True:
|
83 |
-
# words in the chunk and the overlapping portion
|
84 |
-
wrds_len = wrds[(length * i):(length * (i + 1))]
|
85 |
-
wrds_ovlp = wrds[(length * (i + 1)):((length * (i + 1)) + overlap)]
|
86 |
-
wrds_split = wrds_len + wrds_ovlp
|
87 |
-
|
88 |
-
# Break loop if no more words
|
89 |
-
if not wrds_split:
|
90 |
-
break
|
91 |
-
|
92 |
-
wrds_str = " ".join(wrds_split)
|
93 |
-
nxt_chunk_start_idx = len(" ".join(wrds_len))
|
94 |
-
lst_char_idx = len(" ".join(wrds_split))
|
95 |
-
|
96 |
-
resp_obj = {
|
97 |
-
"text": wrds_str,
|
98 |
-
"start_idx": lst_chunk_idx,
|
99 |
-
"end_idx": lst_char_idx + lst_chunk_idx,
|
100 |
-
}
|
101 |
-
|
102 |
-
resp.append(resp_obj)
|
103 |
-
lst_chunk_idx += nxt_chunk_start_idx + 1
|
104 |
-
i += 1
|
105 |
-
logging.info(f"Sliced transcript into {len(resp)} slices.")
|
106 |
-
return resp
|
107 |
-
|
108 |
-
@staticmethod
|
109 |
-
def combine_results(full_text: str, text_slices):
|
110 |
-
"""
|
111 |
-
Given a full text and predictions of each slice combines predictions into a single text again.
|
112 |
-
Performs validation whether text was combined correctly
|
113 |
-
"""
|
114 |
-
split_full_text = full_text.replace('\n', ' ').split(" ")
|
115 |
-
split_full_text = [i for i in split_full_text if i]
|
116 |
-
split_full_text_len = len(split_full_text)
|
117 |
-
output_text = []
|
118 |
-
index = 0
|
119 |
-
|
120 |
-
if len(text_slices[-1]) <= 3 and len(text_slices) > 1:
|
121 |
-
text_slices = text_slices[:-1]
|
122 |
-
|
123 |
-
for _slice in text_slices:
|
124 |
-
slice_wrds = len(_slice)
|
125 |
-
for ix, wrd in enumerate(_slice):
|
126 |
-
# print(index, "|", str(list(wrd.keys())[0]), "|", split_full_text[index])
|
127 |
-
if index == split_full_text_len:
|
128 |
-
break
|
129 |
-
|
130 |
-
if split_full_text[index] == str(list(wrd.keys())[0]) and \
|
131 |
-
ix <= slice_wrds - 3 and text_slices[-1] != _slice:
|
132 |
-
index += 1
|
133 |
-
pred_item_tuple = list(wrd.items())[0]
|
134 |
-
output_text.append(pred_item_tuple)
|
135 |
-
elif split_full_text[index] == str(list(wrd.keys())[0]) and text_slices[-1] == _slice:
|
136 |
-
index += 1
|
137 |
-
pred_item_tuple = list(wrd.items())[0]
|
138 |
-
output_text.append(pred_item_tuple)
|
139 |
-
assert [i[0] for i in output_text] == split_full_text
|
140 |
-
return output_text
|
141 |
-
|
142 |
-
@staticmethod
|
143 |
-
def punctuate_texts(full_pred: list):
|
144 |
-
"""
|
145 |
-
Given a list of Predictions from the model, applies the predictions to text,
|
146 |
-
thus punctuating it.
|
147 |
-
"""
|
148 |
-
punct_resp = ""
|
149 |
-
for i in full_pred:
|
150 |
-
word, label = i
|
151 |
-
if label[-1] == "U":
|
152 |
-
punct_wrd = word.capitalize()
|
153 |
-
else:
|
154 |
-
punct_wrd = word
|
155 |
-
|
156 |
-
if label[0] != "O":
|
157 |
-
punct_wrd += label[0]
|
158 |
-
|
159 |
-
punct_resp += punct_wrd + " "
|
160 |
-
punct_resp = punct_resp.strip()
|
161 |
-
# Append trailing period if it doesn't exist.
|
162 |
-
if punct_resp[-1].isalnum():
|
163 |
-
punct_resp += "."
|
164 |
-
return punct_resp
|
165 |
-
|
166 |
-
|
167 |
-
if __name__ == "__main__":
|
168 |
-
punct_model = RestorePuncts()
|
169 |
-
# read test file
|
170 |
-
with open('../tests/sample_text.txt', 'r') as fp:
|
171 |
-
test_sample = fp.read()
|
172 |
-
# predict text and print
|
173 |
-
punctuated = punct_model.punctuate(test_sample)
|
174 |
-
print(punctuated)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
myrpunct/utils.py
DELETED
@@ -1,34 +0,0 @@
|
|
1 |
-
# -*- coding: utf-8 -*-
|
2 |
-
# 💾⚙️🔮
|
3 |
-
|
4 |
-
__author__ = "Daulet N."
|
5 |
-
__email__ = "daulet.nurmanbetov@gmail.com"
|
6 |
-
|
7 |
-
def prepare_unpunct_text(text):
|
8 |
-
"""
|
9 |
-
Given a text, normalizes it to subsequently restore punctuation
|
10 |
-
"""
|
11 |
-
formatted_txt = text.replace('\n', '').strip()
|
12 |
-
formatted_txt = formatted_txt.lower()
|
13 |
-
formatted_txt_lst = formatted_txt.split(" ")
|
14 |
-
punct_strp_txt = [strip_punct(i) for i in formatted_txt_lst]
|
15 |
-
normalized_txt = " ".join([i for i in punct_strp_txt if i])
|
16 |
-
return normalized_txt
|
17 |
-
|
18 |
-
def strip_punct(wrd):
|
19 |
-
"""
|
20 |
-
Given a word, strips non-alphanumeric characters that precede and follow it
|
21 |
-
"""
|
22 |
-
if not wrd:
|
23 |
-
return wrd
|
24 |
-
|
25 |
-
while not wrd[-1:].isalnum():
|
26 |
-
if not wrd:
|
27 |
-
break
|
28 |
-
wrd = wrd[:-1]
|
29 |
-
|
30 |
-
while not wrd[:1].isalnum():
|
31 |
-
if not wrd:
|
32 |
-
break
|
33 |
-
wrd = wrd[1:]
|
34 |
-
return wrd
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
repunct.py
CHANGED
@@ -1,7 +1,8 @@
|
|
1 |
-
from myrpunct import RestorePuncts
|
2 |
|
3 |
def predict(input_text):
|
4 |
-
rpunct = RestorePuncts()
|
5 |
-
output_text = rpunct.punctuate(input_text)
|
|
|
6 |
print("Punctuation finished...")
|
7 |
return output_text
|
|
|
1 |
+
#from myrpunct import RestorePuncts
|
2 |
|
3 |
def predict(input_text):
|
4 |
+
#rpunct = RestorePuncts()
|
5 |
+
#output_text = rpunct.punctuate(input_text)
|
6 |
+
output_text = "Error: Module not available"
|
7 |
print("Punctuation finished...")
|
8 |
return output_text
|