from myrpunct import RestorePuncts
from youtube_transcript_api import YouTubeTranscriptApi
import gradio as gr
import re
def get_srt(input_link):
    """Fetch the transcript of a YouTube video as newline-separated text.

    Args:
        input_link: A link (or fragment) containing ``v=<video id>``.

    Returns:
        The ``text`` entries of the transcript joined with newlines, or a
        message starting with ``"Error:"`` when no video id can be
        extracted from ``input_link``.
    """
    if "v=" not in input_link:
        return "Error: Invalid Link, it does not have the pattern 'v=' in it."

    # Take what follows the first "v=" and drop trailing query parameters
    # or fragments (e.g. "&t=10s", "#comments"), which are not part of the id.
    video_id = input_link.split("v=", 1)[1]
    for separator in ("&", "#"):
        video_id = video_id.split(separator, 1)[0]
    video_id = video_id.strip()

    if not video_id:
        return "Error: Invalid Link, the video id after 'v=' is empty."

    print("video_id: ", video_id)
    transcript_raw = YouTubeTranscriptApi.get_transcript(video_id)
    # One transcript entry per output line.
    return '\n'.join(entry['text'] for entry in transcript_raw)
def predict(input_text, input_file, input_link, input_checkbox):
    """Dispatch the selected input type to the punctuation pipeline.

    Args:
        input_text: Raw text to punctuate (used when the type is "Text").
        input_file: Uploaded file object exposing a ``.name`` path (used
            when the type is "File"); may be ``None``.
        input_link: YouTube link (used when the type is "Link").
        input_checkbox: Selected input type: "Text", "File" or "Link".

    Returns:
        The result of ``run_predict`` for the selected input, or a message
        starting with ``"Error:"`` when the selected input is missing or
        the link could not be resolved.
    """
    if input_checkbox == "File" and input_file is not None:
        print("Input File ...")
        # Explicit encoding keeps behaviour independent of the platform
        # default; "utf-8-sig" also strips a leading BOM if present.
        with open(input_file.name, encoding="utf-8-sig") as file:
            input_file_read = file.read()
        return run_predict(input_file_read)

    # Truthiness covers both empty strings and None.
    if input_checkbox == "Text" and input_text:
        print("Input Text ...")
        return run_predict(input_text)

    if input_checkbox == "Link" and input_link:
        print("Input Link ...", input_link)
        input_link_text = get_srt(input_link)
        # Only treat the result as a failure when it is an error message;
        # a transcript that merely contains the word "Error" is valid input.
        if input_link_text.startswith("Error:"):
            return input_link_text
        return run_predict(input_link_text)

    return "Error: Please provide an input text, file or link and select the matching input type."
def run_predict(input_text):
    """Punctuate ``input_text`` and restore its original line breaks.

    The text is passed through ``RestorePuncts().punctuate`` and the result
    is re-split so that the output has the same number of lines, with the
    same number of space-separated words per line, as the input. This is
    needed because each line of an srt transcript corresponds to a frame
    of the video.

    Args:
        input_text: Text whose lines are separated by newlines.

    Returns:
        The punctuated text with the input's line structure, or a message
        starting with ``"AssertError:"`` when the punctuated text does not
        have the same number of space-separated words as the input.
    """
    # NOTE(review): a new RestorePuncts is built on every call; if its
    # construction is expensive, consider creating it once and reusing it.
    rpunct = RestorePuncts()
    output_text = rpunct.punctuate(input_text)
    print("Punctuation finished...")

    # Restore the carriage returns: record how many words each input line
    # holds instead of marking line ends with a sentinel character, so that
    # text which already contains the sentinel is not corrupted.
    source_lines = re.split(r'\s*\n\s*', input_text.strip())
    words_per_line = [len(line.split(' ')) for line in source_lines]
    punctuated_words = output_text.split(' ')

    expected_words = sum(words_per_line)
    if expected_words != len(punctuated_words):
        # Return a single string: the caller displays this value as text.
        return (
            "AssertError: The length of the transcript and the punctuated "
            f"file should be the same: {expected_words} {len(punctuated_words)}"
        )

    # Hand out the punctuated words line by line, following the input layout.
    restored_lines = []
    start = 0
    for count in words_per_line:
        restored_lines.append(' '.join(punctuated_words[start:start + count]))
        start += count
    return '\n'.join(restored_lines)
if __name__ == "__main__":
    # Texts shown on the Gradio page.
    title = "Rpunct Gradio App"
    description = """
Description:
Model restores punctuation and case i.e. of the following punctuations -- [! ? . , - : ; ' ] and also the upper-casing of words.
Usage:
There are three input types: any text, a file that can be uploaded or a YouTube video.
Because all three options can be provided by the user (that is you) at the same time
the user has to decide which input type has to be processed.
"""
    article = "Model by [felflare](https://huggingface.co/felflare/bert-restore-punctuation)"

    # One example row; the values follow the order of the `inputs` list
    # below: text, file, link, input type.
    sample_link = "https://www.youtube.com/watch?v=6MI0f6YjJIk"
    examples = [["my name is clara and i live in berkeley california", "sample.srt", sample_link, "Text"]]

    # The four inputs map positionally onto the parameters of `predict`.
    interface = gr.Interface(
        fn=predict,
        inputs=["text", "file", "text", gr.Radio(["Text", "File", "Link"], type="value", label='Input Type')],
        outputs=["text"],
        title=title,
        description=description,
        article=article,
        examples=examples,
        allow_flagging="never",
    )
    interface.launch()
# TODO: save flagged samples to a Hugging Face dataset.
# See https://github.com/gradio-app/gradio/issues/914
# The best option here is to use a Hugging Face dataset as the storage for flagged data. To do that, check out the HuggingFaceDatasetSaver() flagging handler, which makes this easy.
# Here is an example Space that uses this: https://huggingface.co/spaces/abidlabs/crowd-speech