Spaces:
Build error
Build error
new
Browse files- app.py +94 -4
- myrpunct/__init__.py +2 -0
- myrpunct/punctuate.py +174 -0
- myrpunct/utils.py +34 -0
- requirements.txt +5 -0
- sample.srt +20 -0
app.py
CHANGED
@@ -1,7 +1,97 @@
|
|
|
|
|
|
1 |
import gradio as gr
|
|
|
2 |
|
3 |
-
def
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from myrpunct import RestorePuncts
|
2 |
+
from youtube_transcript_api import YouTubeTranscriptApi
|
3 |
import gradio as gr
|
4 |
+
import re
|
5 |
|
6 |
+
def get_srt(input_link):
|
7 |
+
if "v=" in input_link:
|
8 |
+
video_id = input_link.split("v=")[1]
|
9 |
+
else:
|
10 |
+
return "Error: Invalid Link, it does not have the pattern 'v=' in it."
|
11 |
+
print("video_id: ",video_id)
|
12 |
+
transcript_raw = YouTubeTranscriptApi.get_transcript(video_id)
|
13 |
+
transcript_text= '\n'.join([i['text'] for i in transcript_raw])
|
14 |
+
return transcript_text
|
15 |
|
16 |
+
def predict(input_text, input_file, input_link, input_checkbox):
|
17 |
+
|
18 |
+
if input_checkbox=="File" and input_file is not None:
|
19 |
+
print("Input File ...")
|
20 |
+
with open(input_file.name) as file:
|
21 |
+
input_file_read = file.read()
|
22 |
+
return run_predict(input_file_read)
|
23 |
+
elif input_checkbox=="Text" and len(input_text) >0:
|
24 |
+
print("Input Text ...")
|
25 |
+
return run_predict(input_text)
|
26 |
+
elif input_checkbox=="Link" and len(input_link)>0:
|
27 |
+
print("Input Link ...", input_link)
|
28 |
+
input_link_text = get_srt(input_link)
|
29 |
+
if "Error" in input_link_text:
|
30 |
+
return input_link_text
|
31 |
+
else:
|
32 |
+
return run_predict(input_link_text)
|
33 |
+
else:
|
34 |
+
return "Error: Please provide either an input text or file and select an option accordingly."
|
35 |
+
|
36 |
+
def run_predict(input_text):
|
37 |
+
rpunct = RestorePuncts()
|
38 |
+
output_text = rpunct.punctuate(input_text)
|
39 |
+
print("Punctuation finished...")
|
40 |
+
|
41 |
+
# restore the carrige returns
|
42 |
+
srt_file = input_text
|
43 |
+
punctuated = output_text
|
44 |
+
|
45 |
+
srt_file_strip=srt_file.strip()
|
46 |
+
srt_file_sub=re.sub('\s*\n\s*','# ',srt_file_strip)
|
47 |
+
srt_file_array=srt_file_sub.split(' ')
|
48 |
+
pcnt_file_array=punctuated.split(' ')
|
49 |
+
|
50 |
+
# goal: restore the break points i.e. the same number of lines as the srt file
|
51 |
+
# this is necessary, because each line in the srt file corresponds to a frame from the video
|
52 |
+
if len(srt_file_array)!=len(pcnt_file_array):
|
53 |
+
return "AssertError: The length of the transcript and the punctuated file should be the same: ",len(srt_file_array),len(pcnt_file_array)
|
54 |
+
pcnt_file_array_hash = []
|
55 |
+
for idx, item in enumerate(srt_file_array):
|
56 |
+
if item.endswith('#'):
|
57 |
+
pcnt_file_array_hash.append(pcnt_file_array[idx]+'#')
|
58 |
+
else:
|
59 |
+
pcnt_file_array_hash.append(pcnt_file_array[idx])
|
60 |
+
|
61 |
+
# assemble the array back to a string
|
62 |
+
pcnt_file_cr=' '.join(pcnt_file_array_hash).replace('#','\n')
|
63 |
+
|
64 |
+
return pcnt_file_cr
|
65 |
+
|
66 |
+
if __name__ == "__main__":
|
67 |
+
|
68 |
+
title = "Rpunct Gradio App"
|
69 |
+
description = """
|
70 |
+
<b>Description</b>: <br>
|
71 |
+
Model restores punctuation and case i.e. of the following punctuations -- [! ? . , - : ; ' ] and also the upper-casing of words. <br>
|
72 |
+
<b>Usage</b>: <br>
|
73 |
+
There are three input types any text, a file that can be uploaded or a YouTube video. <br>
|
74 |
+
Because all three options can be provided by the user (that is you) at the same time <br>
|
75 |
+
the user has to decisde which input type has to be processed.
|
76 |
+
"""
|
77 |
+
article = "Model by [felflare](https://huggingface.co/felflare/bert-restore-punctuation)"
|
78 |
+
|
79 |
+
sample_link = "https://www.youtube.com/watch?v=6MI0f6YjJIk"
|
80 |
+
|
81 |
+
examples = [["my name is clara and i live in berkeley california", "sample.srt", sample_link, "Text"]]
|
82 |
+
|
83 |
+
interface = gr.Interface(fn = predict,
|
84 |
+
inputs = ["text", "file", "text", gr.Radio(["Text", "File", "Link"], type="value", label='Input Type')],
|
85 |
+
outputs = ["text"],
|
86 |
+
title = title,
|
87 |
+
description = description,
|
88 |
+
article = article,
|
89 |
+
examples=examples,
|
90 |
+
allow_flagging="never")
|
91 |
+
|
92 |
+
interface.launch()
|
93 |
+
|
94 |
+
# save flagging to a hf dataset
|
95 |
+
# https://github.com/gradio-app/gradio/issues/914
|
96 |
+
# the best option here is to use a Hugging Face dataset as the storage for flagged data. And to do that, please check out the HuggingFaceDatasetSaver() flagging handler, which allows you to do that easily.
|
97 |
+
#Here is an example Space that uses this: https://huggingface.co/spaces/abidlabs/crowd-speech
|
myrpunct/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .punctuate import RestorePuncts
|
2 |
+
print("init executed ...")
|
myrpunct/punctuate.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# 💾⚙️🔮
|
3 |
+
|
4 |
+
__author__ = "Daulet N."
|
5 |
+
__email__ = "daulet.nurmanbetov@gmail.com"
|
6 |
+
|
7 |
+
import logging
|
8 |
+
from langdetect import detect
|
9 |
+
from simpletransformers.ner import NERModel, NERArgs
|
10 |
+
|
11 |
+
|
12 |
+
class RestorePuncts:
|
13 |
+
def __init__(self, wrds_per_pred=250, use_cuda=False):
|
14 |
+
self.wrds_per_pred = wrds_per_pred
|
15 |
+
self.overlap_wrds = 30
|
16 |
+
self.valid_labels = ['OU', 'OO', '.O', '!O', ',O', '.U', '!U', ',U', ':O', ';O', ':U', "'O", '-O', '?O', '?U']
|
17 |
+
self.model_hf = "wldmr/felflare-bert-restore-punctuation"
|
18 |
+
self.model_args = NERArgs()
|
19 |
+
self.model_args.silent = True
|
20 |
+
self.model_args.max_seq_length = 512
|
21 |
+
#self.model_args.use_multiprocessing = False
|
22 |
+
self.model = NERModel("bert", self.model_hf, labels=self.valid_labels, use_cuda=use_cuda, args=self.model_args)
|
23 |
+
#self.model = NERModel("bert", self.model_hf, labels=self.valid_labels, use_cuda=use_cuda, args={"silent": True, "max_seq_length": 512, "use_multiprocessing": False})
|
24 |
+
print("class init ...")
|
25 |
+
print("use_multiprocessing: ",self.model_args.use_multiprocessing)
|
26 |
+
|
27 |
+
def status(self):
|
28 |
+
print("function called")
|
29 |
+
|
30 |
+
def punctuate(self, text: str, lang:str=''):
|
31 |
+
"""
|
32 |
+
Performs punctuation restoration on arbitrarily large text.
|
33 |
+
Detects if input is not English, if non-English was detected terminates predictions.
|
34 |
+
Overrride by supplying `lang='en'`
|
35 |
+
|
36 |
+
Args:
|
37 |
+
- text (str): Text to punctuate, can be few words to as large as you want.
|
38 |
+
- lang (str): Explicit language of input text.
|
39 |
+
"""
|
40 |
+
if not lang and len(text) > 10:
|
41 |
+
lang = detect(text)
|
42 |
+
if lang != 'en':
|
43 |
+
raise Exception(F"""Non English text detected. Restore Punctuation works only for English.
|
44 |
+
If you are certain the input is English, pass argument lang='en' to this function.
|
45 |
+
Punctuate received: {text}""")
|
46 |
+
|
47 |
+
# plit up large text into bert digestable chunks
|
48 |
+
splits = self.split_on_toks(text, self.wrds_per_pred, self.overlap_wrds)
|
49 |
+
# predict slices
|
50 |
+
# full_preds_lst contains tuple of labels and logits
|
51 |
+
full_preds_lst = [self.predict(i['text']) for i in splits]
|
52 |
+
# extract predictions, and discard logits
|
53 |
+
preds_lst = [i[0][0] for i in full_preds_lst]
|
54 |
+
# join text slices
|
55 |
+
combined_preds = self.combine_results(text, preds_lst)
|
56 |
+
# create punctuated prediction
|
57 |
+
punct_text = self.punctuate_texts(combined_preds)
|
58 |
+
return punct_text
|
59 |
+
|
60 |
+
def predict(self, input_slice):
|
61 |
+
"""
|
62 |
+
Passes the unpunctuated text to the model for punctuation.
|
63 |
+
"""
|
64 |
+
predictions, raw_outputs = self.model.predict([input_slice])
|
65 |
+
return predictions, raw_outputs
|
66 |
+
|
67 |
+
@staticmethod
|
68 |
+
def split_on_toks(text, length, overlap):
|
69 |
+
"""
|
70 |
+
Splits text into predefined slices of overlapping text with indexes (offsets)
|
71 |
+
that tie-back to original text.
|
72 |
+
This is done to bypass 512 token limit on transformer models by sequentially
|
73 |
+
feeding chunks of < 512 toks.
|
74 |
+
Example output:
|
75 |
+
[{...}, {"text": "...", 'start_idx': 31354, 'end_idx': 32648}, {...}]
|
76 |
+
"""
|
77 |
+
wrds = text.replace('\n', ' ').split(" ")
|
78 |
+
resp = []
|
79 |
+
lst_chunk_idx = 0
|
80 |
+
i = 0
|
81 |
+
|
82 |
+
while True:
|
83 |
+
# words in the chunk and the overlapping portion
|
84 |
+
wrds_len = wrds[(length * i):(length * (i + 1))]
|
85 |
+
wrds_ovlp = wrds[(length * (i + 1)):((length * (i + 1)) + overlap)]
|
86 |
+
wrds_split = wrds_len + wrds_ovlp
|
87 |
+
|
88 |
+
# Break loop if no more words
|
89 |
+
if not wrds_split:
|
90 |
+
break
|
91 |
+
|
92 |
+
wrds_str = " ".join(wrds_split)
|
93 |
+
nxt_chunk_start_idx = len(" ".join(wrds_len))
|
94 |
+
lst_char_idx = len(" ".join(wrds_split))
|
95 |
+
|
96 |
+
resp_obj = {
|
97 |
+
"text": wrds_str,
|
98 |
+
"start_idx": lst_chunk_idx,
|
99 |
+
"end_idx": lst_char_idx + lst_chunk_idx,
|
100 |
+
}
|
101 |
+
|
102 |
+
resp.append(resp_obj)
|
103 |
+
lst_chunk_idx += nxt_chunk_start_idx + 1
|
104 |
+
i += 1
|
105 |
+
logging.info(f"Sliced transcript into {len(resp)} slices.")
|
106 |
+
return resp
|
107 |
+
|
108 |
+
@staticmethod
|
109 |
+
def combine_results(full_text: str, text_slices):
|
110 |
+
"""
|
111 |
+
Given a full text and predictions of each slice combines predictions into a single text again.
|
112 |
+
Performs validataion wether text was combined correctly
|
113 |
+
"""
|
114 |
+
split_full_text = full_text.replace('\n', ' ').split(" ")
|
115 |
+
split_full_text = [i for i in split_full_text if i]
|
116 |
+
split_full_text_len = len(split_full_text)
|
117 |
+
output_text = []
|
118 |
+
index = 0
|
119 |
+
|
120 |
+
if len(text_slices[-1]) <= 3 and len(text_slices) > 1:
|
121 |
+
text_slices = text_slices[:-1]
|
122 |
+
|
123 |
+
for _slice in text_slices:
|
124 |
+
slice_wrds = len(_slice)
|
125 |
+
for ix, wrd in enumerate(_slice):
|
126 |
+
# print(index, "|", str(list(wrd.keys())[0]), "|", split_full_text[index])
|
127 |
+
if index == split_full_text_len:
|
128 |
+
break
|
129 |
+
|
130 |
+
if split_full_text[index] == str(list(wrd.keys())[0]) and \
|
131 |
+
ix <= slice_wrds - 3 and text_slices[-1] != _slice:
|
132 |
+
index += 1
|
133 |
+
pred_item_tuple = list(wrd.items())[0]
|
134 |
+
output_text.append(pred_item_tuple)
|
135 |
+
elif split_full_text[index] == str(list(wrd.keys())[0]) and text_slices[-1] == _slice:
|
136 |
+
index += 1
|
137 |
+
pred_item_tuple = list(wrd.items())[0]
|
138 |
+
output_text.append(pred_item_tuple)
|
139 |
+
assert [i[0] for i in output_text] == split_full_text
|
140 |
+
return output_text
|
141 |
+
|
142 |
+
@staticmethod
|
143 |
+
def punctuate_texts(full_pred: list):
|
144 |
+
"""
|
145 |
+
Given a list of Predictions from the model, applies the predictions to text,
|
146 |
+
thus punctuating it.
|
147 |
+
"""
|
148 |
+
punct_resp = ""
|
149 |
+
for i in full_pred:
|
150 |
+
word, label = i
|
151 |
+
if label[-1] == "U":
|
152 |
+
punct_wrd = word.capitalize()
|
153 |
+
else:
|
154 |
+
punct_wrd = word
|
155 |
+
|
156 |
+
if label[0] != "O":
|
157 |
+
punct_wrd += label[0]
|
158 |
+
|
159 |
+
punct_resp += punct_wrd + " "
|
160 |
+
punct_resp = punct_resp.strip()
|
161 |
+
# Append trailing period if doesnt exist.
|
162 |
+
if punct_resp[-1].isalnum():
|
163 |
+
punct_resp += "."
|
164 |
+
return punct_resp
|
165 |
+
|
166 |
+
|
167 |
+
if __name__ == "__main__":
|
168 |
+
punct_model = RestorePuncts()
|
169 |
+
# read test file
|
170 |
+
with open('../tests/sample_text.txt', 'r') as fp:
|
171 |
+
test_sample = fp.read()
|
172 |
+
# predict text and print
|
173 |
+
punctuated = punct_model.punctuate(test_sample)
|
174 |
+
print(punctuated)
|
myrpunct/utils.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# 💾⚙️🔮
|
3 |
+
|
4 |
+
__author__ = "Daulet N."
|
5 |
+
__email__ = "daulet.nurmanbetov@gmail.com"
|
6 |
+
|
7 |
+
def prepare_unpunct_text(text):
|
8 |
+
"""
|
9 |
+
Given a text, normalizes it to subsequently restore punctuation
|
10 |
+
"""
|
11 |
+
formatted_txt = text.replace('\n', '').strip()
|
12 |
+
formatted_txt = formatted_txt.lower()
|
13 |
+
formatted_txt_lst = formatted_txt.split(" ")
|
14 |
+
punct_strp_txt = [strip_punct(i) for i in formatted_txt_lst]
|
15 |
+
normalized_txt = " ".join([i for i in punct_strp_txt if i])
|
16 |
+
return normalized_txt
|
17 |
+
|
18 |
+
def strip_punct(wrd):
|
19 |
+
"""
|
20 |
+
Given a word, strips non aphanumeric characters that precede and follow it
|
21 |
+
"""
|
22 |
+
if not wrd:
|
23 |
+
return wrd
|
24 |
+
|
25 |
+
while not wrd[-1:].isalnum():
|
26 |
+
if not wrd:
|
27 |
+
break
|
28 |
+
wrd = wrd[:-1]
|
29 |
+
|
30 |
+
while not wrd[:1].isalnum():
|
31 |
+
if not wrd:
|
32 |
+
break
|
33 |
+
wrd = wrd[1:]
|
34 |
+
return wrd
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformers
|
2 |
+
torch
|
3 |
+
langdetect
|
4 |
+
simpletransformers
|
5 |
+
youtube_transcript_api
|
sample.srt
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
in 2018 cornell researchers built a
|
2 |
+
high-powered detector that in combination
|
3 |
+
with an algorithm-driven process called
|
4 |
+
ptychography set a world record by tripling
|
5 |
+
the resolution of a state-of-the-art electron
|
6 |
+
microscope as successful as it was that approach
|
7 |
+
had a weakness it only worked with ultrathin
|
8 |
+
samples that were a few atoms thick anything
|
9 |
+
thicker would cause the electrons to scatter
|
10 |
+
in ways that could not be disentangled now a
|
11 |
+
team again led by
|
12 |
+
david muller
|
13 |
+
the samuel beckert professor of engineering
|
14 |
+
has bested its own
|
15 |
+
record by a factor of two with an electron
|
16 |
+
microscope pixel array detector empad that
|
17 |
+
incorporates even more sophisticated 3d
|
18 |
+
reconstruction algorithms the resolution is so
|
19 |
+
fine-tuned the only blurring that remains is
|
20 |
+
the thermal jiggling of the atoms themselves
|