wldmr committed
Commit 47ae719
0 Parent(s):

Duplicate from wldmr/punct-tube-gr

Files changed (7)
  1. .gitattributes +34 -0
  2. README.md +14 -0
  3. app.py +53 -0
  4. myrpunct/__init__.py +2 -0
  5. myrpunct/punctuate.py +174 -0
  6. myrpunct/utils.py +34 -0
  7. requirements.txt +4 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Punct Tube Gr
+ emoji: 💻
+ colorFrom: gray
+ colorTo: yellow
+ sdk: gradio
+ sdk_version: 3.12.0
+ app_file: app.py
+ pinned: false
+ license: mit
+ duplicated_from: wldmr/punct-tube-gr
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,53 @@
+ from myrpunct import RestorePuncts
+ import gradio as gr
+ import re
+
+ def predict(input_text):
+     rpunct = RestorePuncts()
+     output_text = rpunct.punctuate(input_text)
+     print("Punctuation finished...")
+
+     # restore the carriage returns
+     srt_file = input_text
+     punctuated = output_text
+
+     srt_file_strip = srt_file.strip()
+     srt_file_sub = re.sub(r'\s*\n\s*', '# ', srt_file_strip)
+     srt_file_array = srt_file_sub.split(' ')
+     pcnt_file_array = punctuated.split(' ')
+
+     # goal: restore the break points, i.e. the same number of lines as the srt file
+     # this is necessary because each line in the srt file corresponds to a frame from the video
+     if len(srt_file_array) != len(pcnt_file_array):
+         return f"AssertError: the transcript and the punctuated text should have the same number of words: {len(srt_file_array)} vs. {len(pcnt_file_array)}"
+     pcnt_file_array_hash = []
+     for idx, item in enumerate(srt_file_array):
+         if item.endswith('#'):
+             pcnt_file_array_hash.append(pcnt_file_array[idx] + '#')
+         else:
+             pcnt_file_array_hash.append(pcnt_file_array[idx])
+
+     # assemble the array back into a string
+     pcnt_file_cr = ' '.join(pcnt_file_array_hash).replace('#', '\n')
+
+     return pcnt_file_cr
+
+ if __name__ == "__main__":
+
+     title = "Rpunct App"
+     description = """
+     <b>Description</b>: <br>
+     The model restores punctuation and case, i.e. the following punctuation marks -- [! ? . , - : ; ' ] -- as well as the upper-casing of words. <br>
+     """
+     examples = ["my name is clara and i live in berkeley california"]
+
+     interface = gr.Interface(fn=predict,
+                              inputs=["text"],
+                              outputs=["text"],
+                              title=title,
+                              description=description,
+                              examples=examples,
+                              allow_flagging="never")
+
+     interface.launch()
+
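The line-break restoration in `predict` above aligns the punctuated output word-by-word with the original transcript. Below is a minimal sketch of just that step, with a hand-written punctuated string standing in for the model output so it runs without downloading the model; the example strings are made up for illustration.

```python
import re

# a two-line transcript as it might come out of an srt file (made-up example)
transcript = "my name is clara\nand i live in berkeley california"
# what rpunct.punctuate() might return for it (hand-written stand-in here)
punctuated = "My name is Clara and I live in Berkeley, California."

# mark the original line ends with '#' so they survive the word-by-word alignment
marked = re.sub(r'\s*\n\s*', '# ', transcript.strip())
src_words = marked.split(' ')
out_words = punctuated.split(' ')
assert len(src_words) == len(out_words)  # predict() returns an error message instead

# copy each '#' marker onto the corresponding punctuated word,
# then turn the markers back into newlines
restored = ' '.join(
    out + '#' if src.endswith('#') else out
    for src, out in zip(src_words, out_words)
).replace('#', '\n')
print(restored)
# -> "My name is Clara\n and I live in Berkeley, California."
#    (the space after the newline matches predict()'s behavior above)
```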
myrpunct/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .punctuate import RestorePuncts
+ print("init executed ...")
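The package-level export mirrors how app.py uses the class; a minimal direct call, assuming the dependencies from requirements.txt are installed (the first call downloads the wldmr/felflare-bert-restore-punctuation checkpoint):

```python
from myrpunct import RestorePuncts

rpunct = RestorePuncts()  # CPU by default (use_cuda=False)
text = "my name is clara and i live in berkeley california"
print(rpunct.punctuate(text))
# expected output along the lines of:
# "My name is Clara and I live in Berkeley, California."
```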
myrpunct/punctuate.py ADDED
@@ -0,0 +1,174 @@
+ # -*- coding: utf-8 -*-
+ # 💾⚙️🔮
+
+ __author__ = "Daulet N."
+ __email__ = "daulet.nurmanbetov@gmail.com"
+
+ import logging
+ from langdetect import detect
+ from simpletransformers.ner import NERModel, NERArgs
+
+
+ class RestorePuncts:
+     def __init__(self, wrds_per_pred=250, use_cuda=False):
+         self.wrds_per_pred = wrds_per_pred
+         self.overlap_wrds = 30
+         self.valid_labels = ['OU', 'OO', '.O', '!O', ',O', '.U', '!U', ',U', ':O', ';O', ':U', "'O", '-O', '?O', '?U']
+         self.model_hf = "wldmr/felflare-bert-restore-punctuation"
+         self.model_args = NERArgs()
+         self.model_args.silent = True
+         self.model_args.max_seq_length = 512
+         #self.model_args.use_multiprocessing = False
+         self.model = NERModel("bert", self.model_hf, labels=self.valid_labels, use_cuda=use_cuda, args=self.model_args)
+         #self.model = NERModel("bert", self.model_hf, labels=self.valid_labels, use_cuda=use_cuda, args={"silent": True, "max_seq_length": 512, "use_multiprocessing": False})
+         print("class init ...")
+         print("use_multiprocessing: ", self.model_args.use_multiprocessing)
+
+     def status(self):
+         print("function called")
+
+     def punctuate(self, text: str, lang: str = ''):
+         """
+         Performs punctuation restoration on arbitrarily large text.
+         Detects if the input is not English; if non-English text is detected, predictions are terminated.
+         Override by supplying `lang='en'`.
+
+         Args:
+             - text (str): Text to punctuate, from a few words to as large as you want.
+             - lang (str): Explicit language of the input text.
+         """
+         if not lang and len(text) > 10:
+             lang = detect(text)
+         if lang != 'en':
+             raise Exception(f"""Non-English text detected. Restore Punctuation works only for English.
+             If you are certain the input is English, pass the argument lang='en' to this function.
+             Punctuate received: {text}""")
+
+         # split up large text into BERT-digestible chunks
+         splits = self.split_on_toks(text, self.wrds_per_pred, self.overlap_wrds)
+         # predict slices
+         # full_preds_lst contains tuples of labels and logits
+         full_preds_lst = [self.predict(i['text']) for i in splits]
+         # extract predictions, and discard logits
+         preds_lst = [i[0][0] for i in full_preds_lst]
+         # join text slices
+         combined_preds = self.combine_results(text, preds_lst)
+         # create punctuated prediction
+         punct_text = self.punctuate_texts(combined_preds)
+         return punct_text
+
+     def predict(self, input_slice):
+         """
+         Passes the unpunctuated text to the model for punctuation.
+         """
+         predictions, raw_outputs = self.model.predict([input_slice])
+         return predictions, raw_outputs
+
+     @staticmethod
+     def split_on_toks(text, length, overlap):
+         """
+         Splits text into predefined slices of overlapping text with indexes (offsets)
+         that tie back to the original text.
+         This is done to bypass the 512-token limit of transformer models by sequentially
+         feeding chunks of < 512 tokens.
+         Example output:
+         [{...}, {"text": "...", 'start_idx': 31354, 'end_idx': 32648}, {...}]
+         """
+         wrds = text.replace('\n', ' ').split(" ")
+         resp = []
+         lst_chunk_idx = 0
+         i = 0
+
+         while True:
+             # words in the chunk and the overlapping portion
+             wrds_len = wrds[(length * i):(length * (i + 1))]
+             wrds_ovlp = wrds[(length * (i + 1)):((length * (i + 1)) + overlap)]
+             wrds_split = wrds_len + wrds_ovlp
+
+             # break the loop if there are no more words
+             if not wrds_split:
+                 break
+
+             wrds_str = " ".join(wrds_split)
+             nxt_chunk_start_idx = len(" ".join(wrds_len))
+             lst_char_idx = len(" ".join(wrds_split))
+
+             resp_obj = {
+                 "text": wrds_str,
+                 "start_idx": lst_chunk_idx,
+                 "end_idx": lst_char_idx + lst_chunk_idx,
+             }
+
+             resp.append(resp_obj)
+             lst_chunk_idx += nxt_chunk_start_idx + 1
+             i += 1
+         logging.info(f"Sliced transcript into {len(resp)} slices.")
+         return resp
+
+     @staticmethod
+     def combine_results(full_text: str, text_slices):
+         """
+         Given the full text and the predictions for each slice, combines the predictions into a single text again.
+         Performs validation of whether the text was combined correctly.
+         """
+         split_full_text = full_text.replace('\n', ' ').split(" ")
+         split_full_text = [i for i in split_full_text if i]
+         split_full_text_len = len(split_full_text)
+         output_text = []
+         index = 0
+
+         if len(text_slices[-1]) <= 3 and len(text_slices) > 1:
+             text_slices = text_slices[:-1]
+
+         for _slice in text_slices:
+             slice_wrds = len(_slice)
+             for ix, wrd in enumerate(_slice):
+                 # print(index, "|", str(list(wrd.keys())[0]), "|", split_full_text[index])
+                 if index == split_full_text_len:
+                     break
+
+                 if split_full_text[index] == str(list(wrd.keys())[0]) and \
+                         ix <= slice_wrds - 3 and text_slices[-1] != _slice:
+                     index += 1
+                     pred_item_tuple = list(wrd.items())[0]
+                     output_text.append(pred_item_tuple)
+                 elif split_full_text[index] == str(list(wrd.keys())[0]) and text_slices[-1] == _slice:
+                     index += 1
+                     pred_item_tuple = list(wrd.items())[0]
+                     output_text.append(pred_item_tuple)
+         assert [i[0] for i in output_text] == split_full_text
+         return output_text
+
+     @staticmethod
+     def punctuate_texts(full_pred: list):
+         """
+         Given a list of predictions from the model, applies the predictions to the text,
+         thus punctuating it.
+         """
+         punct_resp = ""
+         for i in full_pred:
+             word, label = i
+             if label[-1] == "U":
+                 punct_wrd = word.capitalize()
+             else:
+                 punct_wrd = word
+
+             if label[0] != "O":
+                 punct_wrd += label[0]
+
+             punct_resp += punct_wrd + " "
+         punct_resp = punct_resp.strip()
+         # append a trailing period if one doesn't exist
+         if punct_resp[-1].isalnum():
+             punct_resp += "."
+         return punct_resp
+
+
+ if __name__ == "__main__":
+     punct_model = RestorePuncts()
+     # read test file
+     with open('../tests/sample_text.txt', 'r') as fp:
+         test_sample = fp.read()
+     # predict text and print
+     punctuated = punct_model.punctuate(test_sample)
+     print(punctuated)
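Because split_on_toks is a @staticmethod, the chunking logic can be exercised without instantiating the model. A small sketch (the word counts below are deliberately tiny; the class defaults are 250 words per prediction with a 30-word overlap):

```python
from myrpunct.punctuate import RestorePuncts

# 60 dummy words, split into chunks of 25 with a 5-word overlap
text = " ".join(f"word{i}" for i in range(60))
slices = RestorePuncts.split_on_toks(text, length=25, overlap=5)

for s in slices:
    print(s["start_idx"], s["end_idx"], len(s["text"].split()))
# three slices: the first two carry 25 chunk words plus 5 overlap words,
# the last carries the remaining 10; start_idx/end_idx are character offsets
# into the original text
```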
myrpunct/utils.py ADDED
@@ -0,0 +1,34 @@
+ # -*- coding: utf-8 -*-
+ # 💾⚙️🔮
+
+ __author__ = "Daulet N."
+ __email__ = "daulet.nurmanbetov@gmail.com"
+
+ def prepare_unpunct_text(text):
+     """
+     Given a text, normalizes it in order to subsequently restore punctuation
+     """
+     formatted_txt = text.replace('\n', '').strip()
+     formatted_txt = formatted_txt.lower()
+     formatted_txt_lst = formatted_txt.split(" ")
+     punct_strp_txt = [strip_punct(i) for i in formatted_txt_lst]
+     normalized_txt = " ".join([i for i in punct_strp_txt if i])
+     return normalized_txt
+
+ def strip_punct(wrd):
+     """
+     Given a word, strips non-alphanumeric characters that precede and follow it
+     """
+     if not wrd:
+         return wrd
+
+     while not wrd[-1:].isalnum():
+         if not wrd:
+             break
+         wrd = wrd[:-1]
+
+     while not wrd[:1].isalnum():
+         if not wrd:
+             break
+         wrd = wrd[1:]
+     return wrd
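utils.py is not imported by app.py or punctuate.py, but its helpers can be exercised on their own; a short sketch (the example strings are made up):

```python
from myrpunct.utils import prepare_unpunct_text, strip_punct

print(strip_punct("(hello!)"))
# -> "hello"
print(prepare_unpunct_text("Hello, World... it's me!"))
# -> "hello world it's me"
```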
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ transformers
+ torch
+ langdetect
+ simpletransformers