"""Gradio demo for NeMo ASR models hosted on the Hugging Face Hub.

Short clips are transcribed through the hosted Hugging Face inference
endpoint of the selected model. Clips longer than
``BUFFERED_INFERENCE_DURATION_THRESHOLD`` seconds are transcoded to wav and
transcribed locally with NeMo's buffered (chunked) CTC / RNNT inference
scripts.

NOTE(review): the HTML fragments in this file (``ARTICLE``, the ``<br>``
separators, the result ``<div>`` wrapper and the YouTube ``<iframe>``) were
reconstructed from a copy in which markup had been stripped; confirm them
against the deployed version.
"""
import html
import json
import os
import re
import shutil
import subprocess
import tempfile
import time
import traceback
import uuid
from typing import Optional, Tuple

import gradio as gr
import nemo.collections.asr as nemo_asr
import pytube as pt
import torch
from nemo.utils import logging

import speech_to_text_buffered_infer_ctc as buffered_ctc
import speech_to_text_buffered_infer_rnnt as buffered_rnnt

# Set NeMo cache dir as /tmp
from nemo import constants

os.environ[constants.NEMO_ENV_CACHE_DIR] = "/tmp/nemo/"


SAMPLE_RATE = 16000  # Default sample rate for ASR
BUFFERED_INFERENCE_DURATION_THRESHOLD = 60.0  # 60 second and above will require chunked inference.
CHUNK_LEN_IN_SEC = 20.0  # Chunk size
BUFFER_LEN_IN_SEC = 30.0  # Total buffer size
MANIFEST_FILENAME = "output.json"  # Manifest written by the buffered inference scripts

TITLE = "NeMo ASR Inference on Hugging Face"
DESCRIPTION = "Demo of all languages supported by NeMo ASR"
DEFAULT_EN_MODEL = "nvidia/stt_en_conformer_transducer_xlarge"
DEFAULT_BUFFERED_EN_MODEL = "nvidia/stt_en_conformer_transducer_large"

# Pre-download and cache the model in disk space
logging.setLevel(logging.ERROR)
tmp_model = nemo_asr.models.ASRModel.from_pretrained(DEFAULT_BUFFERED_EN_MODEL, map_location='cpu')
del tmp_model
logging.setLevel(logging.INFO)

MARKDOWN = f"""
# {TITLE}

## {DESCRIPTION}
"""

CSS = """
p.big {
  font-size: 20px;
}

/* From https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition/blob/main/app.py */

.result {display:flex;flex-direction:column}
.result_item {padding:15px;margin-bottom:8px;border-radius:15px;width:100%;font-size:20px;}
.result_item_success {background-color:mediumaquamarine;color:white;align-self:start}
.result_item_error {background-color:#ff7070;color:white;align-self:start}
"""

# NOTE(review): link targets reconstructed; only the visible text
# "NeMo ASR | Github Repo" survived in the reviewed copy.
ARTICLE = """
<br><br>
<p style='text-align: center'>
    <a href='https://github.com/NVIDIA/NeMo' target='_blank'>NeMo ASR</a>
    |
    <a href='https://github.com/NVIDIA/NeMo' target='_blank'>Github Repo</a>
</p>
"""

# Matches the "Duration: HH:MM:SS.xx," line that ffmpeg prints for an input file.
_DURATION_PATTERN = re.compile(
    r"Duration:\s{1}(?P<hours>\d+?):(?P<minutes>\d+?):(?P<seconds>\d+\.\d+?),", re.DOTALL
)

SUPPORTED_LANGUAGES = set()
SUPPORTED_MODEL_NAMES = set()

# HF models, grouped by language identifier
hf_filter = nemo_asr.models.ASRModel.get_hf_model_filter()
hf_filter.task = "automatic-speech-recognition"

hf_infos = nemo_asr.models.ASRModel.search_huggingface_models(model_filter=hf_filter)
for info in hf_infos:
    SUPPORTED_MODEL_NAMES.add(info.modelId)
    # Model ids look like "<org>/stt_<lang>_<arch>..."; the second "_" field is
    # the language id. Ids without that field are kept but not tied to a language.
    id_parts = info.modelId.split("_")
    if len(id_parts) > 1:
        SUPPORTED_LANGUAGES.add(id_parts[1])

SUPPORTED_MODEL_NAMES = sorted(SUPPORTED_MODEL_NAMES)

# DEBUG FILTER
# SUPPORTED_MODEL_NAMES = list(filter(lambda x: "en" in x and "conformer_transducer_large" in x, SUPPORTED_MODEL_NAMES))

model_dict = {model_name: gr.Interface.load(f'models/{model_name}') for model_name in SUPPORTED_MODEL_NAMES}

# Language id -> sorted model ids. SUPPORTED_MODEL_NAMES is already sorted, so
# each per-language list is sorted too. Languages without a match are omitted.
SUPPORTED_LANG_MODEL_DICT = {}
for lang in SUPPORTED_LANGUAGES:
    lang_model_ids = [model_id for model_id in SUPPORTED_MODEL_NAMES if f"_{lang}_" in model_id]
    if lang_model_ids:
        SUPPORTED_LANG_MODEL_DICT[lang] = lang_model_ids


def get_device() -> str:
    """Return the CUDA device name when a GPU is available, otherwise "CPU"."""
    if torch.cuda.is_available():
        return torch.cuda.get_device_name()
    return "CPU"


def parse_duration(audio_file) -> float:
    """
    FFMPEG to calculate durations. Libraries can do it too, but filetypes cause different libraries to behave
    differently.

    Args:
        audio_file: Path to the audio file to probe.

    Returns:
        Duration of the file in seconds.

    Raises:
        ValueError: If ffmpeg's output contains no parsable "Duration:" entry.
    """
    # ffmpeg prints the stream info (including the duration) to stderr, so
    # stderr is merged into stdout. It exits non-zero because no output file is
    # given, hence check=False.
    completed = subprocess.run(
        ['ffmpeg', '-i', audio_file], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, check=False
    )
    ffmpeg_output = completed.stdout.decode(errors="replace")

    match = _DURATION_PATTERN.search(ffmpeg_output)
    if match is None:
        raise ValueError(f"Could not determine the duration of audio file : {audio_file}")

    return (
        float(match.group('hours')) * 60.0 * 60.0
        + float(match.group('minutes')) * 60.0
        + float(match.group('seconds'))
    )


def resolve_model_type(model_name: str) -> Optional[str]:
    """
    Map model name to a class type, without loading the model. Has some hardcoded assumptions in
    semantics of model naming.

    Returns:
        One of 'hybrid', 'transducer', 'ctc', or None when the type cannot be inferred from the name.
    """
    # Loss specific maps ('hybrid' also covers 'hybrid_ctc' / 'hybrid_transducer')
    if 'hybrid' in model_name:
        return 'hybrid'
    if 'transducer' in model_name or 'rnnt' in model_name:
        return 'transducer'
    if 'ctc' in model_name:
        return 'ctc'

    # Model specific maps
    if 'jasper' in model_name or 'quartznet' in model_name or 'citrinet' in model_name:
        return 'ctc'
    if 'contextnet' in model_name:
        return 'transducer'

    return None


# Architecture keyword -> encoder stride; checked in order, first match wins.
_MODEL_STRIDES = (
    ('jasper', 2),
    ('quartznet', 2),
    ('conformer', 4),
    ('squeezeformer', 4),
    ('citrinet', 8),
    ('contextnet', 8),
)


def resolve_model_stride(model_name) -> int:
    """
    Model specific pre-calc of stride levels.
    Don't load the model to get such info.

    Returns:
        The stride for the architecture named in `model_name`, or -1 if unknown.
    """
    for keyword, stride in _MODEL_STRIDES:
        if keyword in model_name:
            return stride
    return -1


def convert_audio(audio_filepath):
    """
    Transcode all mp3 files to monochannel 16 kHz wav files.

    Args:
        audio_filepath: Path to the input audio file.

    Returns:
        Path to the wav file (the input path itself when it already ends in '.wav'),
        or None if ffmpeg did not produce the output file.
    """
    # splitext keeps the directory in `stem` and the leading dot in `ext`.
    stem, ext = os.path.splitext(audio_filepath)
    if ext == '.wav':
        return audio_filepath

    out_filename = stem + '.wav'

    subprocess.run(
        ['ffmpeg', '-y', '-i', audio_filepath, '-ac', '1', '-ar', str(SAMPLE_RATE), out_filename],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        close_fds=True,
        check=False,
    )

    return out_filename if os.path.exists(out_filename) else None


def extract_result_from_manifest(filepath, model_name) -> Tuple[bool, str]:
    """
    Parse the written manifest which is result of the buffered inference process.

    Args:
        filepath: Path to a JSON-lines manifest whose entries carry a 'pred_text' field.
        model_name: Model name, used only in the failure message.

    Returns:
        (True, first transcription) on success, otherwise (False, failure message).
    """
    predictions = []
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                predictions.append(json.loads(line)['pred_text'])
            except (json.JSONDecodeError, KeyError, TypeError):
                # Skip blank / malformed lines and entries without a prediction.
                continue

    if predictions:
        return True, predictions[0]
    return False, f"Could not perform inference on model with name : {model_name}"


def build_html_output(s: str, style: str = "result_item_success"):
    """Wrap `s` in the result markup; `style` selects the CSS class (success / error)."""
    return f"""
    <div class='result'>
        <div class='result_item {style}'>
          {s}
        </div>
    </div>
    """


def _run_buffered_inference(infer_module, model_name: str, audio_dir: str, model_stride: int) -> Tuple[bool, str]:
    """Run one buffered inference script (CTC or RNNT) over `audio_dir` and read back its manifest."""
    config = infer_module.TranscriptionConfig(
        pretrained_name=model_name,
        audio_dir=audio_dir,
        output_filename=MANIFEST_FILENAME,
        audio_type="wav",
        overwrite_transcripts=True,
        model_stride=model_stride,
        chunk_len_in_secs=CHUNK_LEN_IN_SEC,
        total_buffer_in_secs=BUFFER_LEN_IN_SEC,
    )
    infer_module.main(config)
    return extract_result_from_manifest(MANIFEST_FILENAME, model_name)


def infer_audio(model_name: str, audio_file: str) -> str:
    """
    Main method that switches from HF inference for small audio files to Buffered CTC/RNNT mode for long audio files.

    Args:
        model_name: Str name of the model (potentially with / to denote HF models)
        audio_file: Path to an audio file (mp3 or wav)

    Returns:
        str which is the transcription if successful, otherwise a message starting with "Error:-".
    """
    # Parse the duration of the audio file
    duration = parse_duration(audio_file)

    if duration <= BUFFERED_INFERENCE_DURATION_THRESHOLD:
        # Short clip: use the HF API through the cached Gradio model function
        model = model_dict.get(model_name)
        if model is None:
            return (
                f"Error:- Could not find model {model_name} in list of available models : "
                f"{list(model_dict)}"
            )
        return model(audio_file)

    # Longer than one minute; use buffered mode
    # Process audio to be of wav type (possible youtube audio)
    audio_file = convert_audio(audio_file)

    # If audio file transcoding failed, let user know
    if audio_file is None:
        return "Error:- Failed to convert audio file to wav."

    # Extract audio dir from resolved audio filepath
    audio_dir = os.path.split(audio_file)[0]

    # Next calculate the stride of each model
    model_stride = resolve_model_stride(model_name)
    if model_stride < 0:
        return f"Error:- Failed to compute the model stride for model with name : {model_name}"

    type_error = f"Error:- Could not parse model type; failed to perform inference with model {model_name}!"

    # Process model type (CTC/RNNT/Hybrid)
    model_type = resolve_model_type(model_name)

    if model_type is None:
        # Model type could not be inferred. Try all feasible options and return
        # the first one that yields a transcription.
        for infer_module in (buffered_ctc, buffered_rnnt):
            try:
                succeeded, text = _run_buffered_inference(infer_module, model_name, audio_dir, model_stride)
            except Exception:
                # Best effort: a model of the other family is expected to fail here.
                logging.warning(
                    f"Buffered inference via {infer_module.__name__} failed for {model_name}:\n"
                    f"{traceback.format_exc()}"
                )
                continue
            if succeeded:
                return text
        return type_error

    if model_type == 'ctc':
        infer_module = buffered_ctc  # CTC Buffered Inference
    elif model_type == 'transducer':
        infer_module = buffered_rnnt  # RNNT Buffered Inference
    else:
        return type_error

    succeeded, text = _run_buffered_inference(infer_module, model_name, audio_dir, model_stride)
    # Surface a manifest failure as an error rather than as a transcript.
    return text if succeeded else f"Error:- {text}"


def transcribe(microphone, audio_file, model_name):
    """
    Transcribe microphone or uploaded audio with the selected model.

    Args:
        microphone: Path to the recorded audio, or None.
        audio_file: Path to the uploaded audio, or None.
        model_name: Name of the model to use.

    Returns:
        Tuple of (transcription text, HTML status block).
    """
    warn_output = ""
    if microphone is not None and audio_file is not None:
        warn_output = (
            "WARNING: You've uploaded an audio file and used the microphone. "
            "The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
        )
        audio_data = microphone
    elif microphone is None and audio_file is None:
        # Nothing to transcribe: report that directly instead of failing later.
        return "", build_html_output(
            "ERROR: You have to either use the microphone or upload an audio file", style="result_item_error"
        )
    elif microphone is not None:
        audio_data = microphone
    else:
        audio_data = audio_file

    try:
        audio_duration = parse_duration(audio_data)
    except ValueError as err:
        return "", build_html_output(f"Error:- {err}", style="result_item_error")

    time_diff = None
    try:
        with tempfile.TemporaryDirectory() as tempdir:
            # Work on a private copy and drop the source file.
            filename = os.path.split(audio_data)[-1]
            new_audio_data = os.path.join(tempdir, filename)
            shutil.copy2(audio_data, new_audio_data)

            if os.path.exists(audio_data):
                os.remove(audio_data)

            audio_data = new_audio_data

            # Use HF API for transcription
            start = time.time()
            transcriptions = infer_audio(model_name, audio_data)
            end = time.time()
            time_diff = end - start

    except Exception:
        traceback.print_exc()
        transcriptions = ""

        if warn_output != "":
            warn_output += "<br><br>"

        warn_output += (
            f"The model `{model_name}` is currently loading and cannot be used "
            f"for transcription.<br>"
            f"Please try another model or wait a few minutes."
        )

    # Built HTML output
    if warn_output != "":
        html_output = build_html_output(warn_output, style="result_item_error")
    elif transcriptions.startswith("Error:-"):
        html_output = build_html_output(transcriptions, style="result_item_error")
    else:
        output = f"Successfully transcribed on {get_device()} ! <br>" f"Transcription Time : {time_diff: 0.3f} s"

        if audio_duration > BUFFERED_INFERENCE_DURATION_THRESHOLD:
            output += f""" <br><br>
            Note: Audio duration was {audio_duration: 0.3f} s, so model had to be downloaded, initialized, and then
            buffered inference was used. <br>
            """

        html_output = build_html_output(output)

    return transcriptions, html_output


def _return_yt_html_embed(yt_url):
    """ Obtained from https://huggingface.co/spaces/whisper-event/whisper-demo """
    # The URL is user supplied, so escape the id before putting it in an attribute.
    video_id = html.escape(yt_url.split("?v=")[-1], quote=True)
    HTML_str = (
        f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
        " </center>"
    )
    return HTML_str


def yt_transcribe(yt_url, model_name):
    """
    Download the audio track of a YouTube video and transcribe it.

    Modified from https://huggingface.co/spaces/whisper-event/whisper-demo

    Args:
        yt_url: URL of the YouTube video.
        model_name: Name of the model to use.

    Returns:
        Tuple of (transcription text, HTML video embed, HTML status block).
    """
    if not yt_url or not yt_url.strip():
        html_output = build_html_output(
            """
            Error:- No YouTube URL was provide !
            """,
            style='result_item_error',
        )
        return "", "", html_output

    yt = pt.YouTube(yt_url)
    html_embed_str = _return_yt_html_embed(yt_url)

    with tempfile.TemporaryDirectory() as tempdir:
        audio_path = f"{tempdir}/{uuid.uuid4().hex}.mp3"

        # Download YT Audio temporarily
        download_time_start = time.time()

        audio_streams = yt.streams.filter(only_audio=True)
        if len(audio_streams) == 0:
            html_output = build_html_output(
                "Error:- No audio stream is available for this YouTube video.", style='result_item_error'
            )
            return "", html_embed_str, html_output

        audio_streams[0].download(filename=audio_path)

        download_time_end = time.time()

        # Get audio duration
        audio_duration = parse_duration(audio_path)

        # Perform transcription
        infer_time_start = time.time()

        text = infer_audio(model_name, audio_path)

        infer_time_end = time.time()

    if text.startswith("Error:-"):
        html_output = build_html_output(text, style='result_item_error')
    else:
        html_output = f"""
        Successfully transcribed on {get_device()} ! <br>
        Audio Download Time : {download_time_end - download_time_start: 0.3f} s <br>
        Transcription Time : {infer_time_end - infer_time_start: 0.3f} s <br>
        """

        if audio_duration > BUFFERED_INFERENCE_DURATION_THRESHOLD:
            html_output += f""" <br>
            Note: Audio duration was {audio_duration: 0.3f} s, so model had to be downloaded, initialized, and then
            buffered inference was used. <br>
            """

        html_output = build_html_output(html_output)

    return text, html_embed_str, html_output


def create_lang_selector_component(default_en_model=DEFAULT_EN_MODEL):
    """
    Utility function to select a language from a dropdown menu, and simultaneously update another dropdown
    containing the corresponding model checkpoints for that language.

    Args:
        default_en_model: str name of a default english model that should be the set default.

    Returns:
        Gradio components for lang_selector (Dropdown menu) and models_in_lang (Dropdown menu)
    """
    lang_selector = gr.components.Dropdown(
        # Only offer languages that have at least one model, so every choice
        # resolves in SUPPORTED_LANG_MODEL_DICT.
        choices=sorted(SUPPORTED_LANG_MODEL_DICT),
        value="en",
        type="value",
        label="Languages",
        interactive=True,
    )
    models_in_lang = gr.components.Dropdown(
        choices=sorted(SUPPORTED_LANG_MODEL_DICT["en"]),
        value=default_en_model,
        label="Models",
        interactive=True,
    )

    def update_models_with_lang(lang):
        models_names = sorted(SUPPORTED_LANG_MODEL_DICT[lang])
        default = default_en_model if lang == 'en' else models_names[0]
        return models_in_lang.update(choices=models_names, value=default)

    lang_selector.change(update_models_with_lang, inputs=[lang_selector], outputs=[models_in_lang])

    return lang_selector, models_in_lang


"""
Define the GUI
"""
demo = gr.Blocks(title=TITLE, css=CSS)

with demo:
    header = gr.Markdown(MARKDOWN)

    with gr.Tab("Transcribe Audio"):
        with gr.Row() as row:
            file_upload = gr.components.Audio(source="upload", type='filepath', label='Upload File')
            microphone = gr.components.Audio(source="microphone", type='filepath', label='Microphone')

        lang_selector, models_in_lang = create_lang_selector_component()

        run = gr.components.Button('Transcribe')

        transcript = gr.components.Label(label='Transcript')
        audio_html_output = gr.components.HTML()

        run.click(
            transcribe, inputs=[microphone, file_upload, models_in_lang], outputs=[transcript, audio_html_output]
        )

    with gr.Tab("Transcribe Youtube"):
        yt_url = gr.components.Textbox(
            lines=1, label="Youtube URL", placeholder="Paste the URL to a YouTube video here"
        )

        lang_selector_yt, models_in_lang_yt = create_lang_selector_component(
            default_en_model=DEFAULT_BUFFERED_EN_MODEL
        )

        with gr.Row():
            run = gr.components.Button('Transcribe YouTube')
            embedded_video = gr.components.HTML()

        transcript = gr.components.Label(label='Transcript')
        yt_html_output = gr.components.HTML()

        run.click(
            yt_transcribe, inputs=[yt_url, models_in_lang_yt], outputs=[transcript, embedded_video, yt_html_output]
        )

    gr.components.HTML(ARTICLE)

demo.queue(concurrency_count=1)
demo.launch(enable_queue=True)