wldmr committed
Commit a01d1cf • 1 Parent(s): fa50d64

rpunct removed

myrpunct/__init__.py DELETED
@@ -1,2 +0,0 @@
- from .punctuate import RestorePuncts
- print("init executed ...")
 
myrpunct/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (231 Bytes)
 
myrpunct/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (227 Bytes)
 
myrpunct/__pycache__/punctuate.cpython-310.pyc DELETED
Binary file (5.71 kB)
 
myrpunct/__pycache__/punctuate.cpython-39.pyc DELETED
Binary file (5.69 kB)
 
myrpunct/punctuate.py DELETED
@@ -1,174 +0,0 @@
- # -*- coding: utf-8 -*-
- # 💾⚙️🔮
-
- __author__ = "Daulet N."
- __email__ = "daulet.nurmanbetov@gmail.com"
-
- import logging
- from langdetect import detect
- from simpletransformers.ner import NERModel, NERArgs
-
-
- class RestorePuncts:
-     def __init__(self, wrds_per_pred=250, use_cuda=False):
-         self.wrds_per_pred = wrds_per_pred
-         self.overlap_wrds = 30
-         self.valid_labels = ['OU', 'OO', '.O', '!O', ',O', '.U', '!U', ',U', ':O', ';O', ':U', "'O", '-O', '?O', '?U']
-         self.model_hf = "wldmr/felflare-bert-restore-punctuation"
-         self.model_args = NERArgs()
-         self.model_args.silent = True
-         self.model_args.max_seq_length = 512
-         #self.model_args.use_multiprocessing = False
-         self.model = NERModel("bert", self.model_hf, labels=self.valid_labels, use_cuda=use_cuda, args=self.model_args)
-         #self.model = NERModel("bert", self.model_hf, labels=self.valid_labels, use_cuda=use_cuda, args={"silent": True, "max_seq_length": 512, "use_multiprocessing": False})
-         print("class init ...")
-         print("use_multiprocessing: ", self.model_args.use_multiprocessing)
-
-     def status(self):
-         print("function called")
-
-     def punctuate(self, text: str, lang: str = ''):
-         """
-         Performs punctuation restoration on arbitrarily large text.
-         Detects if the input is not English; if non-English is detected, predictions are terminated.
-         Override by supplying `lang='en'`.
-
-         Args:
-         - text (str): Text to punctuate; can be a few words or as large as you want.
-         - lang (str): Explicit language of the input text.
-         """
-         if not lang and len(text) > 10:
-             lang = detect(text)
-         if lang != 'en':
-             raise Exception(f"""Non-English text detected. Restore Punctuation works only for English.
-             If you are certain the input is English, pass argument lang='en' to this function.
-             Punctuate received: {text}""")
-
-         # split up large text into BERT-digestible chunks
-         splits = self.split_on_toks(text, self.wrds_per_pred, self.overlap_wrds)
-         # predict slices
-         # full_preds_lst contains tuples of labels and logits
-         full_preds_lst = [self.predict(i['text']) for i in splits]
-         # extract predictions, and discard logits
-         preds_lst = [i[0][0] for i in full_preds_lst]
-         # join text slices
-         combined_preds = self.combine_results(text, preds_lst)
-         # create punctuated prediction
-         punct_text = self.punctuate_texts(combined_preds)
-         return punct_text
-
-     def predict(self, input_slice):
-         """
-         Passes the unpunctuated text to the model for punctuation.
-         """
-         predictions, raw_outputs = self.model.predict([input_slice])
-         return predictions, raw_outputs
-
-     @staticmethod
-     def split_on_toks(text, length, overlap):
-         """
-         Splits text into predefined slices of overlapping text with indexes (offsets)
-         that tie back to the original text.
-         This is done to bypass the 512-token limit on transformer models by sequentially
-         feeding chunks of < 512 tokens.
-         Example output:
-         [{...}, {"text": "...", 'start_idx': 31354, 'end_idx': 32648}, {...}]
-         """
-         wrds = text.replace('\n', ' ').split(" ")
-         resp = []
-         lst_chunk_idx = 0
-         i = 0
-
-         while True:
-             # words in the chunk and the overlapping portion
-             wrds_len = wrds[(length * i):(length * (i + 1))]
-             wrds_ovlp = wrds[(length * (i + 1)):((length * (i + 1)) + overlap)]
-             wrds_split = wrds_len + wrds_ovlp
-
-             # Break loop if no more words
-             if not wrds_split:
-                 break
-
-             wrds_str = " ".join(wrds_split)
-             nxt_chunk_start_idx = len(" ".join(wrds_len))
-             lst_char_idx = len(" ".join(wrds_split))
-
-             resp_obj = {
-                 "text": wrds_str,
-                 "start_idx": lst_chunk_idx,
-                 "end_idx": lst_char_idx + lst_chunk_idx,
-             }
-
-             resp.append(resp_obj)
-             lst_chunk_idx += nxt_chunk_start_idx + 1
-             i += 1
-         logging.info(f"Sliced transcript into {len(resp)} slices.")
-         return resp
-
-     @staticmethod
-     def combine_results(full_text: str, text_slices):
-         """
-         Given a full text and predictions for each slice, combines the predictions
-         into a single text again. Validates that the text was combined correctly.
-         """
-         split_full_text = full_text.replace('\n', ' ').split(" ")
-         split_full_text = [i for i in split_full_text if i]
-         split_full_text_len = len(split_full_text)
-         output_text = []
-         index = 0
-
-         if len(text_slices[-1]) <= 3 and len(text_slices) > 1:
-             text_slices = text_slices[:-1]
-
-         for _slice in text_slices:
-             slice_wrds = len(_slice)
-             for ix, wrd in enumerate(_slice):
-                 # print(index, "|", str(list(wrd.keys())[0]), "|", split_full_text[index])
-                 if index == split_full_text_len:
-                     break
-
-                 if split_full_text[index] == str(list(wrd.keys())[0]) and \
-                         ix <= slice_wrds - 3 and text_slices[-1] != _slice:
-                     index += 1
-                     pred_item_tuple = list(wrd.items())[0]
-                     output_text.append(pred_item_tuple)
-                 elif split_full_text[index] == str(list(wrd.keys())[0]) and text_slices[-1] == _slice:
-                     index += 1
-                     pred_item_tuple = list(wrd.items())[0]
-                     output_text.append(pred_item_tuple)
-         assert [i[0] for i in output_text] == split_full_text
-         return output_text
-
-     @staticmethod
-     def punctuate_texts(full_pred: list):
-         """
-         Given a list of predictions from the model, applies the predictions to the text,
-         thus punctuating it.
-         """
-         punct_resp = ""
-         for i in full_pred:
-             word, label = i
-             if label[-1] == "U":
-                 punct_wrd = word.capitalize()
-             else:
-                 punct_wrd = word
-
-             if label[0] != "O":
-                 punct_wrd += label[0]
-
-             punct_resp += punct_wrd + " "
-         punct_resp = punct_resp.strip()
-         # Append trailing period if one doesn't exist.
-         if punct_resp[-1].isalnum():
-             punct_resp += "."
-         return punct_resp
-
-
- if __name__ == "__main__":
-     punct_model = RestorePuncts()
-     # read test file
-     with open('../tests/sample_text.txt', 'r') as fp:
-         test_sample = fp.read()
-     # predict text and print
-     punctuated = punct_model.punctuate(test_sample)
-     print(punctuated)
 
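Note: the chunking scheme in the deleted `split_on_toks` is the core trick here. The word list is sliced into `wrds_per_pred`-sized chunks (default 250) that each carry a 30-word overlap into the next slice, keeping every slice under the model's 512-token limit. A standalone sketch of just that slicing loop (the 600-word `words` input is invented for illustration):

    # Re-derivation of the slicing loop from the deleted split_on_toks;
    # length/overlap mirror the defaults wrds_per_pred=250, overlap_wrds=30.
    words = [f"w{n}" for n in range(600)]  # hypothetical input
    length, overlap = 250, 30

    chunks = []
    i = 0
    while True:
        body = words[length * i : length * (i + 1)]
        tail = words[length * (i + 1) : length * (i + 1) + overlap]
        if not (body + tail):  # no words left
            break
        chunks.append(body + tail)
        i += 1

    print([len(c) for c in chunks])  # [280, 280, 100]: each slice overlaps the next by 30 words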
 
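The label scheme decoded by the deleted `punctuate_texts` packs two decisions per token: the first character is the punctuation mark to append ('O' for none), and the last character says whether to capitalize ('U') or leave the word as-is ('O'). Decoding an invented prediction list by hand:

    # Hand-decoding of the (word, label) tuples punctuate_texts consumes;
    # the preds list is made up for illustration.
    preds = [("my", "OU"), ("name", "OO"), ("is", "OO"), ("clara", ",U")]

    out = []
    for word, label in preds:
        w = word.capitalize() if label[-1] == "U" else word  # trailing 'U' => capitalize
        if label[0] != "O":                                  # leading char is a literal mark
            w += label[0]
        out.append(w)
    print(" ".join(out))  # My name is Clara,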
 
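For reference, the deleted `__main__` block shows how the removed API was driven. A minimal sketch, assuming the pre-removal `myrpunct` package and its dependencies (`simpletransformers`, `langdetect`) are installed; the sample sentence is invented:

    from myrpunct import RestorePuncts  # only works against the pre-removal package

    rpunct = RestorePuncts(wrds_per_pred=250, use_cuda=False)
    # lang='en' bypasses langdetect's auto-detection entirely.
    punctuated = rpunct.punctuate("my name is clara and i live in berkeley california", lang='en')
    print(punctuated)  # expected along the lines of: My name is Clara and I live in Berkeley, California.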
 
myrpunct/utils.py DELETED
@@ -1,34 +0,0 @@
- # -*- coding: utf-8 -*-
- # 💾⚙️🔮
-
- __author__ = "Daulet N."
- __email__ = "daulet.nurmanbetov@gmail.com"
-
- def prepare_unpunct_text(text):
-     """
-     Given a text, normalizes it to subsequently restore punctuation
-     """
-     formatted_txt = text.replace('\n', '').strip()
-     formatted_txt = formatted_txt.lower()
-     formatted_txt_lst = formatted_txt.split(" ")
-     punct_strp_txt = [strip_punct(i) for i in formatted_txt_lst]
-     normalized_txt = " ".join([i for i in punct_strp_txt if i])
-     return normalized_txt
-
- def strip_punct(wrd):
-     """
-     Given a word, strips non-alphanumeric characters that precede and follow it
-     """
-     if not wrd:
-         return wrd
-
-     while not wrd[-1:].isalnum():
-         if not wrd:
-             break
-         wrd = wrd[:-1]
-
-     while not wrd[:1].isalnum():
-         if not wrd:
-             break
-         wrd = wrd[1:]
-     return wrd
 
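These helpers normalized text before punctuation restoration: lowercase everything, strip punctuation wrapping each word, and drop words that were punctuation only. A quick behavior sketch (again assuming the pre-removal `myrpunct.utils`; inputs are invented):

    from myrpunct.utils import prepare_unpunct_text, strip_punct  # pre-removal module

    print(strip_punct("--hello!!"))    # hello
    print(strip_punct("..."))          # empty string: punctuation-only words strip to nothing
    print(prepare_unpunct_text("Hello, World! It's me..."))  # hello world it's me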
 
repunct.py CHANGED
@@ -1,7 +1,8 @@
- from myrpunct import RestorePuncts
+ #from myrpunct import RestorePuncts

  def predict(input_text):
-     rpunct = RestorePuncts()
-     output_text = rpunct.punctuate(input_text)
+     #rpunct = RestorePuncts()
+     #output_text = rpunct.punctuate(input_text)
+     output_text = "Error: Module not available"
      print("Punctuation finished...")
      return output_text
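
After this commit, `predict` no longer touches the model: it ignores its input and returns a fixed error string. A caller would see, for example:

    from repunct import predict

    result = predict("this text will not be punctuated")
    # prints "Punctuation finished..." and returns the stub value
    print(result)  # Error: Module not available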