File size: 6,742 Bytes
837fdb6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
# -*- coding: utf-8 -*-
# 💾⚙️🔮
__author__ = "Daulet N."
__email__ = "daulet.nurmanbetov@gmail.com"
import logging
from langdetect import detect
from simpletransformers.ner import NERModel, NERArgs
class RestorePuncts:
    """Restore punctuation and capitalisation in English text.

    Wraps a simpletransformers ``NERModel`` whose labels are two-character
    codes: the first character is the punctuation mark appended after the
    word (``'O'`` means none) and the second is ``'U'`` to upper-case the
    word's first letter or ``'O'`` to leave the word's case untouched.
    Long input is split into overlapping word windows, each window is
    predicted separately, and the predictions are stitched back together.
    """

    def __init__(self, wrds_per_pred=250, use_cuda=False):
        """Load the punctuation model.

        Args:
            wrds_per_pred (int): Words per prediction window, not counting the
                ``overlap_wrds`` words appended to each window. Must be >= 1.
            use_cuda (bool): Passed through to ``NERModel``.

        Raises:
            ValueError: If ``wrds_per_pred`` is smaller than 1 (``split_on_toks``
                would otherwise never terminate or build garbage windows).
        """
        if wrds_per_pred < 1:
            raise ValueError(f"wrds_per_pred must be >= 1, got {wrds_per_pred!r}")
        self.wrds_per_pred = wrds_per_pred
        # Words appended to every window. combine_results discards the last two
        # predictions of each non-final window and may drop a final window of
        # <= 3 words, relying on this overlap to supply those words again.
        self.overlap_wrds = 30
        self.valid_labels = ['OU', 'OO', '.O', '!O', ',O', '.U', '!U', ',U', ':O', ';O', ':U', "'O", '-O', '?O', '?U']
        # Model name handed to NERModel (looks like a Hugging Face Hub id).
        self.model_hf = "wldmr/felflare-bert-restore-punctuation"
        self.model_args = NERArgs()
        self.model_args.silent = True
        self.model_args.max_seq_length = 512
        self.model = NERModel("bert", self.model_hf, labels=self.valid_labels, use_cuda=use_cuda, args=self.model_args)
        # Diagnostics go to the logger rather than stdout so that library users
        # are not spammed on every construction.
        logging.debug("class init ...")
        logging.debug("use_multiprocessing: %s", self.model_args.use_multiprocessing)

    def status(self):
        """Print a fixed marker to stdout; a debugging aid with no other effect."""
        print("function called")

    def punctuate(self, text: str, lang: str = ''):
        """
        Performs punctuation restoration on arbitrarily large text.
        Detects if input is not English; if non-English was detected, terminates predictions.
        Override by supplying `lang='en'`.

        Args:
            - text (str): Text to punctuate, can be few words to as large as you want.
            - lang (str): Explicit language of input text. When empty, the language
              is auto-detected for texts longer than 10 characters; shorter texts
              are too short to detect and are treated as English.

        Returns:
            str: The punctuated text, or "" when ``text`` contains no words.

        Raises:
            ValueError: If the supplied or detected language is not English.
        """
        # Collapse every whitespace run (newlines, tabs, "\r", repeated spaces)
        # into a single space so that chunking, the model's own word split and
        # combine_results all see exactly the same sequence of words.
        normalized = " ".join(text.split())
        if not normalized:
            return ""
        if not lang and len(text) > 10:
            lang = detect(text)
        # An empty `lang` here means nothing was supplied and the text was too
        # short to detect; only reject text that was actually identified as
        # non-English.
        if lang and lang != 'en':
            raise ValueError(
                "Non English text detected. Restore Punctuation works only for English.\n"
                "If you are certain the input is English, pass argument lang='en' to this function.\n"
                f"Punctuate received: {text}"
            )
        # split up large text into bert digestible chunks
        splits = self.split_on_toks(normalized, self.wrds_per_pred, self.overlap_wrds)
        # predict slices; each entry is a (predictions, raw_outputs) tuple
        full_preds_lst = [self.predict(i['text']) for i in splits]
        # keep the predictions of the single sentence in each batch, discard logits
        preds_lst = [i[0][0] for i in full_preds_lst]
        # join text slices
        combined_preds = self.combine_results(normalized, preds_lst)
        # create punctuated prediction
        return self.punctuate_texts(combined_preds)

    def predict(self, input_slice):
        """
        Passes the unpunctuated text to the model for punctuation.

        Returns the ``(predictions, raw_outputs)`` pair from ``NERModel.predict``
        for a one-element batch containing ``input_slice``.
        """
        predictions, raw_outputs = self.model.predict([input_slice])
        return predictions, raw_outputs

    @staticmethod
    def split_on_toks(text, length, overlap):
        """
        Splits text into predefined slices of overlapping text with indexes (offsets)
        that tie-back to original text.
        This is done to bypass 512 token limit on transformer models by sequentially
        feeding chunks of < 512 toks.

        Slice ``i`` holds words ``[length*i, length*(i+1) + overlap)`` of the text,
        where words are the result of splitting on single spaces after turning
        newlines into spaces (so offsets index into the original string).

        Example output:
        [{...}, {"text": "...", 'start_idx': 31354, 'end_idx': 32648}, {...}]

        Raises:
            ValueError: If ``length`` < 1 (the loop would never terminate) or
                ``overlap`` < 0 (negative slice ends would select wrong words).
        """
        if length < 1:
            raise ValueError(f"length must be >= 1, got {length!r}")
        if overlap < 0:
            raise ValueError(f"overlap must be >= 0, got {overlap!r}")
        wrds = text.replace('\n', ' ').split(" ")
        resp = []
        lst_chunk_idx = 0
        i = 0
        while True:
            start = length * i
            stop = start + length
            # words in the chunk and the overlapping portion
            wrds_len = wrds[start:stop]
            wrds_ovlp = wrds[stop:stop + overlap]
            wrds_split = wrds_len + wrds_ovlp
            # Break loop if no more words
            if not wrds_split:
                break
            wrds_str = " ".join(wrds_split)
            # Character length of the non-overlapping part: where the next chunk starts.
            nxt_chunk_start_idx = len(" ".join(wrds_len))
            lst_char_idx = len(wrds_str)
            resp.append({
                "text": wrds_str,
                "start_idx": lst_chunk_idx,
                "end_idx": lst_char_idx + lst_chunk_idx,
            })
            # +1 for the space separating this chunk's last word from the next chunk.
            lst_chunk_idx += nxt_chunk_start_idx + 1
            i += 1
        logging.info("Sliced transcript into %d slices.", len(resp))
        return resp

    @staticmethod
    def combine_results(full_text: str, text_slices):
        """
        Given a full text and predictions of each slice combines predictions into a single text again.
        Performs validation whether text was combined correctly.

        Args:
            full_text (str): The complete input text; split on any whitespace.
            text_slices (list): One list per slice of single-entry ``{word: label}``
                dicts, in slice order.

        Returns:
            list: ``(word, label)`` tuples, one per word of ``full_text``.

        Raises:
            AssertionError: If the combined words do not reproduce ``full_text``.
        """
        split_full_text = full_text.split()
        split_full_text_len = len(split_full_text)
        output_text = []
        index = 0
        # A trailing slice of <= 3 words is fully covered by the previous
        # slice's overlap, so drop it.
        if len(text_slices) > 1 and len(text_slices[-1]) <= 3:
            text_slices = text_slices[:-1]
        last_slice_idx = len(text_slices) - 1
        for slice_idx, _slice in enumerate(text_slices):
            # Non-final slices only contribute predictions up to ix <= len - 3,
            # i.e. their last two words are taken from the next, overlapping slice.
            if slice_idx == last_slice_idx:
                max_ix = len(_slice) - 1
            else:
                max_ix = len(_slice) - 3
            for ix, wrd in enumerate(_slice):
                if index == split_full_text_len or ix > max_ix:
                    break
                word, label = next(iter(wrd.items()))
                # Greedy match against the next expected word; this skips the
                # leading overlap already emitted from the previous slice.
                if split_full_text[index] == str(word):
                    index += 1
                    output_text.append((word, label))
        # Explicit raise (not `assert`) so the check survives `python -O`.
        if [i[0] for i in output_text] != split_full_text:
            raise AssertionError("Combined slice predictions do not reproduce the input text.")
        return output_text

    @staticmethod
    def punctuate_texts(full_pred: list):
        """
        Given a list of Predictions from the model, applies the predictions to text,
        thus punctuating it.

        Args:
            full_pred (list): ``(word, label)`` tuples as produced by ``combine_results``.

        Returns:
            str: Space-joined punctuated words; a final "." is appended when the
            text ends in an alphanumeric character. "" for empty input.
        """
        punct_wrds = []
        for word, label in full_pred:
            if label[-1] == "U":
                # Upper-case only the first character: str.capitalize() would also
                # lower-case the rest and mangle acronyms ("USA" -> "Usa").
                word = word[:1].upper() + word[1:]
            if label[0] != "O":
                word += label[0]
            punct_wrds.append(word)
        punct_resp = " ".join(punct_wrds).strip()
        # Append trailing period if it doesn't exist.
        if punct_resp and punct_resp[-1].isalnum():
            punct_resp += "."
        return punct_resp
if __name__ == "__main__":
    # Manual smoke run: punctuate the bundled sample text and print the result.
    punct_model = RestorePuncts()
    # read test file; path is relative to the current working directory.
    # Explicit encoding so the result does not depend on the platform locale.
    with open('../tests/sample_text.txt', 'r', encoding='utf-8') as fp:
        test_sample = fp.read()
    # predict text and print
    punctuated = punct_model.punctuate(test_sample)
    print(punctuated)
|