"""Gradio app that restores punctuation and casing with the rpunct model.

The text to punctuate is taken from a text box, an uploaded file, or the
transcript of a YouTube video, depending on the selected input type.
"""
import functools
import re

import gradio as gr
from myrpunct import RestorePuncts
from youtube_transcript_api import YouTubeTranscriptApi

INVALID_LINK_ERROR = "Error: Invalid Link, it does not have the pattern 'v=' in it."
NO_INPUT_ERROR = (
    "Error: Please provide either an input text or file and select an option accordingly."
)

# The id is everything after a stand-alone "v=" up to the next query
# separator, fragment or whitespace. The look-behind keeps parameters that
# merely end in "v" (e.g. "rev=1") from being mistaken for the video id.
_VIDEO_ID_PATTERN = re.compile(r"(?<!\w)v=([^&#\s]+)")


def _extract_video_id(input_link):
    """Return the video id found in *input_link*, or None if there is none."""
    match = _VIDEO_ID_PATTERN.search(input_link or "")
    return match.group(1) if match else None


def get_srt(input_link: str) -> str:
    """Fetch the transcript of the YouTube video referenced by *input_link*.

    Args:
        input_link: A link containing a ``v=<video id>`` parameter.

    Returns:
        The ``text`` fields of the transcript entries joined with newlines,
        or an ``"Error: ..."`` message when no video id can be found.
    """
    video_id = _extract_video_id(input_link)
    if video_id is None:
        return INVALID_LINK_ERROR
    print("video_id: ", video_id)
    transcript_raw = YouTubeTranscriptApi.get_transcript(video_id)
    return "\n".join(entry["text"] for entry in transcript_raw)


def predict(input_text, input_file, input_link, input_checkbox):
    """Punctuate the input selected by *input_checkbox*.

    Args:
        input_text: Text typed by the user.
        input_file: Uploaded file object exposing a ``name`` path, or None.
        input_link: YouTube link typed by the user.
        input_checkbox: One of ``"Text"``, ``"File"`` or ``"Link"``.

    Returns:
        The punctuated text, or an ``"Error: ..."`` message when the selected
        input is missing or the link is invalid.
    """
    if input_checkbox == "File" and input_file is not None:
        print("Input File ...")
        # "utf-8-sig" reads plain UTF-8 and also drops a leading BOM, which
        # would otherwise be glued to the first word.
        with open(input_file.name, encoding="utf-8-sig") as file:
            file_content = file.read()
        return run_predict(file_content)

    if input_checkbox == "Text" and input_text:
        print("Input Text ...")
        return run_predict(input_text)

    if input_checkbox == "Link" and input_link:
        print("Input Link ...", input_link)
        # Validate the link up front instead of searching the transcript for
        # the word "Error", which a real transcript may legitimately contain.
        if _extract_video_id(input_link) is None:
            return INVALID_LINK_ERROR
        return run_predict(get_srt(input_link))

    return NO_INPUT_ERROR


@functools.lru_cache(maxsize=1)
def _get_rpunct():
    """Return a shared RestorePuncts instance, created on first use."""
    return RestorePuncts()


def run_predict(input_text: str) -> str:
    """Punctuate *input_text* while keeping its original line breaks.

    Each line of an srt file corresponds to a frame from the video, so the
    output must have the same number of lines as the input.

    Args:
        input_text: The unpunctuated text, possibly spanning several lines.

    Returns:
        The punctuated text with the input's line structure, or an
        ``"AssertError: ..."`` message when the punctuated text does not
        have the same number of words as the input.
    """
    # Words per non-blank input line; blank lines collapse into one break.
    words_per_line = [
        len(words)
        for words in (line.split() for line in input_text.split("\n"))
        if words
    ]
    if not words_per_line:
        # Nothing to punctuate.
        return ""

    output_text = _get_rpunct().punctuate(input_text)
    print("Punctuation finished...")

    punctuated_words = output_text.split()
    expected_count = sum(words_per_line)
    if expected_count != len(punctuated_words):
        return (
            "AssertError: The length of the transcript and the punctuated file "
            f"should be the same: {expected_count} {len(punctuated_words)}"
        )

    # Re-cut the punctuated words by position rather than with a marker
    # character, so text that already contains "#" is left untouched.
    restored_lines = []
    start = 0
    for count in words_per_line:
        restored_lines.append(" ".join(punctuated_words[start:start + count]))
        start += count
    return "\n".join(restored_lines)


def main():
    """Build the Gradio interface and launch it."""
    title = "Rpunct Gradio App"
    description = """
Description:
Model restores punctuation and case i.e. of the following punctuations -- [! ? . , - : ; ' ] and also the upper-casing of words.
Usage:
There are three input types any text, a file that can be uploaded or a YouTube video.
Because all three options can be provided by the user (that is you) at the same time
the user has to decide which input type has to be processed.
"""
    article = "Model by [felflare](https://huggingface.co/felflare/bert-restore-punctuation)"
    sample_link = "https://www.youtube.com/watch?v=6MI0f6YjJIk"
    examples = [
        [
            "my name is clara and i live in berkeley california",
            "sample.srt",
            sample_link,
            "Text",
        ]
    ]

    interface = gr.Interface(
        fn=predict,
        inputs=[
            "text",
            "file",
            "text",
            gr.Radio(["Text", "File", "Link"], type="value", label='Input Type'),
        ],
        outputs=["text"],
        title=title,
        description=description,
        article=article,
        examples=examples,
        allow_flagging="never",
    )
    interface.launch()


if __name__ == "__main__":
    main()

# Saving flagged data to a Hugging Face dataset:
# https://github.com/gradio-app/gradio/issues/914
# The best option here is to use a Hugging Face dataset as the storage for
# flagged data. To do that, check out the HuggingFaceDatasetSaver() flagging
# handler, which allows you to do that easily.
# Here is an example Space that uses this:
# https://huggingface.co/spaces/abidlabs/crowd-speech