#! /usr/bin/env python
# coding=utf-8
# Copyright 2022 Bofeng Huang

"""Gradio demo transcribing German speech with a fine-tuned Whisper checkpoint.

The UI has an audio tab (microphone recording or uploaded file) and a video tab
(uploaded video or a video downloaded from YouTube). Transcriptions are shown
as a table, optionally split into timestamped segments.
"""

import datetime
import html
import logging
import os
import re
import subprocess
import urllib.parse
import warnings

import gradio as gr
import pandas as pd
import psutil
import pytube as pt
import torch
import whisper
from huggingface_hub import hf_hub_download, model_info
from transformers.utils.logging import disable_progress_bar

import nltk

# The punkt models are required by `sent_tokenize` at call time.
nltk.download("punkt")
from nltk.tokenize import sent_tokenize  # noqa: E402

warnings.filterwarnings("ignore")
disable_progress_bar()

DEFAULT_MODEL_NAME = "bofenghuang/whisper-large-v2-cv11-german"
CHECKPOINT_FILENAME = "checkpoint_openai.pt"

# Options forwarded as keyword arguments to `model.transcribe(...)`.
GEN_KWARGS = {
    "task": "transcribe",
    "language": "de",
    # "without_timestamps": True,
    # decode options
    # "beam_size": 5,
    # "patience": 2,
    # disable fallback
    # "compression_ratio_threshold": None,
    # "logprob_threshold": None,
    # vad threshold
    # "no_speech_threshold": None,
}

# NLTK punkt model matching GEN_KWARGS["language"] ("de"); the default would be English.
SENT_TOKENIZE_LANGUAGE = "german"

logging.basicConfig(
    format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
    # NOTE(review): asctime is local time, so the trailing "Z" is only accurate on UTC hosts.
    datefmt="%Y-%m-%dT%H:%M:%SZ",
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# device = 0 if torch.cuda.is_available() else "cpu"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Model will be loaded on device `{device}`")

# Loaded Whisper models keyed by Hub repo id, so each checkpoint is loaded once per process.
cached_models = {}


def format_timestamp(seconds):
    """Format a duration in seconds as ``H:MM:SS``, rounded to the nearest second."""
    return str(datetime.timedelta(seconds=round(seconds)))


def _return_yt_html_embed(yt_url):
    """Return an HTML snippet embedding the YouTube video referenced by *yt_url*.

    The video id is taken from the ``v`` query parameter when present (so extra
    parameters such as ``&t=10s`` are dropped). Otherwise it falls back to the
    text after ``?v=``, as before.
    """
    query_video_ids = urllib.parse.parse_qs(urllib.parse.urlparse(yt_url).query).get("v")
    video_id = query_video_ids[0] if query_video_ids else yt_url.split("?v=")[-1]
    # The id comes from user input and ends up in an HTML attribute: escape it.
    safe_video_id = html.escape(video_id, quote=True)
    HTML_str = (
        f'<center><iframe width="500" height="320" src="https://www.youtube.com/embed/{safe_video_id}">'
        "</iframe></center>"
    )
    return HTML_str


def download_audio_from_youtube(yt_url, downloaded_filename="audio.wav"):
    """Download the first audio-only stream of a YouTube video.

    The file is saved under *downloaded_filename* as-is (no transcoding),
    whatever container the stream uses. Returns the saved file name.
    """
    yt = pt.YouTube(yt_url)
    stream = yt.streams.filter(only_audio=True)[0]
    stream.download(filename=downloaded_filename)
    return downloaded_filename


def download_video_from_youtube(yt_url, downloaded_filename="video.mp4"):
    """Download the highest-resolution progressive MP4 stream of a YouTube video.

    Returns the saved file name.

    Raises:
        ValueError: if the video offers no progressive MP4 stream.
    """
    yt = pt.YouTube(yt_url)
    stream = yt.streams.filter(progressive=True, file_extension="mp4").order_by("resolution").desc().first()
    if stream is None:
        # `.first()` returns None on an empty query; fail with a clear message instead of AttributeError.
        raise ValueError(f"No progressive MP4 stream available for {yt_url}")
    stream.download(filename=downloaded_filename)
    logger.info(f"Download YouTube video from {yt_url}")
    return downloaded_filename


def _print_memory_info():
    """Log available, used (percent) and total system RAM."""
    memory = psutil.virtual_memory()
    logger.info(
        f"Memory info - Free: {memory.available / (1024 ** 3):.2f} Gb, used: {memory.percent}%, total: {memory.total / (1024 ** 3):.2f} Gb"
    )


def _print_cuda_memory_info():
    """Log free/used/total memory of the current CUDA device; no-op without CUDA."""
    if not torch.cuda.is_available():
        # mem_get_info() raises on CPU-only hosts; this runs during model loading at import time.
        return
    # mem_get_info() returns (free, total) in bytes.
    free_mem, tot_mem = torch.cuda.mem_get_info()
    logger.info(
        f"CUDA memory info - Free: {free_mem / 1024 ** 3:.2f} Gb, used: {(tot_mem - free_mem) / 1024 ** 3:.2f} Gb, total: {tot_mem / 1024 ** 3:.2f} Gb"
    )


def print_memory_info():
    """Log system RAM usage and, when available, CUDA device memory usage."""
    _print_memory_info()
    _print_cuda_memory_info()


def maybe_load_cached_pipeline(model_name):
    """Return the Whisper model for Hub repo *model_name*, loading it on first use.

    The OpenAI-format checkpoint ``CHECKPOINT_FILENAME`` is downloaded from the
    Hub, loaded onto ``device`` and stored in ``cached_models``.
    """
    model = cached_models.get(model_name)
    if model is None:
        downloaded_model_path = hf_hub_download(repo_id=model_name, filename=CHECKPOINT_FILENAME)

        model = whisper.load_model(downloaded_model_path, device=device)
        logger.info(f"`{model_name}` has been loaded on device `{device}`")

        print_memory_info()

        cached_models[model_name] = model
    return model


def infer(model, filename, with_timestamps, return_df=False):
    """Transcribe *filename* with *model* using ``GEN_KWARGS``.

    Args:
        model: A loaded Whisper model.
        filename: Path of the audio/video file to transcribe.
        with_timestamps: Keep Whisper's segment boundaries in the output.
        return_df: Return a DataFrame instead of a plain string.

    Returns:
        * with_timestamps and return_df: DataFrame with ``start``/``end``
          (``H:MM:SS``) and stripped ``text`` columns, one row per segment.
        * with_timestamps only: one "Segment i from Xs to Ys" paragraph per segment.
        * return_df only: DataFrame with a ``text`` column, one row per sentence.
        * neither: the full transcription text.
    """
    if with_timestamps:
        model_outputs = model.transcribe(filename, **GEN_KWARGS)
        segments = model_outputs["segments"]
        if return_df:
            # `columns=` keeps the schema when no segment was produced (silent input),
            # where selecting columns from an empty frame would raise KeyError.
            model_outputs_df = pd.DataFrame(segments, columns=["start", "end", "text"])
            model_outputs_df["start"] = model_outputs_df["start"].map(format_timestamp)
            model_outputs_df["end"] = model_outputs_df["end"].map(format_timestamp)
            model_outputs_df["text"] = model_outputs_df["text"].str.strip()
            return model_outputs_df
        return "\n\n".join(
            [
                f'Segment {segment["id"]+1} from {segment["start"]:.2f}s to {segment["end"]:.2f}s:\n{segment["text"].strip()}'
                for segment in segments
            ]
        )

    text = model.transcribe(filename, without_timestamps=True, **GEN_KWARGS)["text"]
    if return_df:
        return pd.DataFrame({"text": sent_tokenize(text, language=SENT_TOKENIZE_LANGUAGE)})
    return text


def transcribe(microphone, file_upload, with_timestamps, model_name=DEFAULT_MODEL_NAME):
    """Transcribe a microphone recording or an uploaded audio file for the UI table.

    The microphone recording wins when both inputs are given.

    Returns:
        The DataFrame produced by ``infer(..., return_df=True)``. When neither
        input is provided, a one-row ``text`` DataFrame holding the error message.
    """
    if (microphone is None) and (file_upload is None):
        # The output component is a gr.DataFrame; in gradio 3.x a returned str is read as a
        # CSV file path, so the message must be wrapped in a table to be displayed.
        return pd.DataFrame({"text": ["ERROR: You have to either use the microphone or upload an audio file"]})
    if (microphone is not None) and (file_upload is not None):
        logger.warning(
            "WARNING: You've uploaded an audio file and used the microphone. "
            "The recorded file from the microphone will be used and the uploaded audio will be discarded."
        )

    file = microphone if microphone is not None else file_upload

    model = maybe_load_cached_pipeline(model_name)
    text = infer(model, file, with_timestamps, return_df=True)

    logger.info(f'Transcription by `{model_name}`:\n{text.to_json(orient="index", force_ascii=False, indent=2)}\n')

    return text


def yt_transcribe(yt_url, with_timestamps, model_name=DEFAULT_MODEL_NAME):
    """Download the audio of a YouTube video and return its transcription DataFrame."""
    audio_file_path = download_audio_from_youtube(yt_url)

    model = maybe_load_cached_pipeline(model_name)
    text = infer(model, audio_file_path, with_timestamps, return_df=True)

    logger.info(f'Transcription by `{model_name}` of "{yt_url}":\n{text.to_json(orient="index", force_ascii=False, indent=2)}\n')

    return text


def _audio_path_for(video_file_path):
    """Return the .wav path written next to *video_file_path*; never the input path itself."""
    base, _ = os.path.splitext(video_file_path)
    audio_file_path = f"{base}.wav"
    if audio_file_path == video_file_path:
        # Input already has a .wav extension: do not let ffmpeg overwrite its own input.
        audio_file_path = f"{base}_16k.wav"
    return audio_file_path


def video_transcribe(video_file_path, with_timestamps, model_name=DEFAULT_MODEL_NAME):
    """Extract 16 kHz mono PCM audio from a video with ffmpeg and transcribe it.

    Raises:
        ValueError: if no video path was provided.
        subprocess.CalledProcessError: if ffmpeg fails to extract the audio.
    """
    if video_file_path is None:
        raise ValueError("Failed to transcribe video as no video_file_path has been defined")

    audio_file_path = _audio_path_for(video_file_path)
    # Argument list (no shell), so the file name cannot inject shell syntax.
    # -y overwrites a stale .wav from a previous run instead of prompting.
    subprocess.run(
        [
            "ffmpeg", "-y",
            "-i", video_file_path,
            "-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le",
            audio_file_path,
        ],
        check=True,
    )

    model = maybe_load_cached_pipeline(model_name)
    text = infer(model, audio_file_path, with_timestamps, return_df=True)

    logger.info(f'Transcription by `{model_name}`:\n{text.to_json(orient="index", force_ascii=False, indent=2)}\n')

    return text


# Load the default model eagerly so the first request does not pay the download/load cost.
maybe_load_cached_pipeline(DEFAULT_MODEL_NAME)

# default_text_output_df = pd.DataFrame(columns=["start", "end", "text"])
default_text_output_df = pd.DataFrame(columns=["text"])

with gr.Blocks() as demo:

    with gr.Tab("Transcribe Audio"):
        gr.Markdown(
            f"""
# Whisper German Demo 🇩🇪 : Transcribe Audio

Transcribe long-form microphone or audio inputs!

Demo uses the fine-tuned checkpoint: [{DEFAULT_MODEL_NAME}](https://huggingface.co/{DEFAULT_MODEL_NAME}) to transcribe audio files of arbitrary length.
"""
        )

        microphone_input = gr.inputs.Audio(source="microphone", type="filepath", label="Record", optional=True)
        upload_input = gr.inputs.Audio(source="upload", type="filepath", label="Upload File", optional=True)
        with_timestamps_input = gr.Checkbox(label="With timestamps?")

        microphone_transcribe_btn = gr.Button("Transcribe Audio")

        text_output_df2 = gr.DataFrame(
            value=default_text_output_df,
            label="Transcription",
            row_count=(0, "dynamic"),
            max_rows=10,
            wrap=True,
            overflow_row_behaviour="paginate",
        )

        microphone_transcribe_btn.click(
            transcribe, inputs=[microphone_input, upload_input, with_timestamps_input], outputs=text_output_df2
        )

    # with gr.Tab("Transcribe YouTube"):
    #     gr.Markdown(
    #         f"""
    # # Whisper German Demo 🇩🇪 : Transcribe YouTube
    #
    # Transcribe long-form YouTube videos!
    #
    # Demo uses the fine-tuned checkpoint: [{DEFAULT_MODEL_NAME}](https://huggingface.co/{DEFAULT_MODEL_NAME}) to transcribe video files of arbitrary length.
    # """
    #     )
    #
    #     yt_link_input2 = gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL")
    #     with_timestamps_input2 = gr.Checkbox(label="With timestamps?", value=True)
    #
    #     yt_transcribe_btn = gr.Button("Transcribe YouTube")
    #     text_output_df3 = gr.DataFrame(
    #         value=default_text_output_df,
    #         label="Transcription",
    #         row_count=(0, "dynamic"),
    #         max_rows=10,
    #         wrap=True,
    #         overflow_row_behaviour="paginate",
    #     )
    #     yt_transcribe_btn.click(yt_transcribe, inputs=[yt_link_input2, with_timestamps_input2], outputs=[text_output_df3])

    with gr.Tab("Transcribe Video"):
        gr.Markdown(
            f"""
# Whisper German Demo 🇩🇪 : Transcribe Video

Transcribe long-form YouTube videos or uploaded video inputs!

Demo uses the fine-tuned checkpoint: [{DEFAULT_MODEL_NAME}](https://huggingface.co/{DEFAULT_MODEL_NAME}) to transcribe video files of arbitrary length.
"""
        )

        yt_link_input = gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL")
        download_youtube_btn = gr.Button("Download Youtube video")
        downloaded_video_output = gr.Video(label="Video file", mirror_webcam=False)
        download_youtube_btn.click(download_video_from_youtube, inputs=[yt_link_input], outputs=[downloaded_video_output])

        with_timestamps_input3 = gr.Checkbox(label="With timestamps?", value=True)
        video_transcribe_btn = gr.Button("Transcribe video")
        text_output_df = gr.DataFrame(
            value=default_text_output_df,
            label="Transcription",
            row_count=(0, "dynamic"),
            max_rows=10,
            wrap=True,
            overflow_row_behaviour="paginate",
        )

        video_transcribe_btn.click(video_transcribe, inputs=[downloaded_video_output, with_timestamps_input3], outputs=[text_output_df])

# demo.launch(server_name="0.0.0.0", debug=True)
# demo.launch(server_name="0.0.0.0", debug=True, share=True)
demo.launch(enable_queue=True)