#!/usr/bin/python3 # -*- coding: utf-8 -*- import argparse from collections import defaultdict from datetime import datetime import functools import logging import os from pathlib import Path import platform import time import tempfile # import uuid import shortuuid from project_settings import project_path, log_directory import log log.setup(log_directory=log_directory) import gradio as gr import torch import torchaudio from toolbox.k2_sherpa.examples import examples from toolbox.k2_sherpa import decode, nn_models from toolbox.k2_sherpa.utils import audio_convert main_logger = logging.getLogger("main") def get_args(): parser = argparse.ArgumentParser() parser.add_argument( "--pretrained_model_dir", default=(project_path / "pretrained_models").as_posix(), type=str ) args = parser.parse_args() return args def update_model_dropdown(language: str): if language not in nn_models.model_map.keys(): raise ValueError(f"Unsupported language: {language}") choices = nn_models.model_map[language] choices = [c["repo_id"] for c in choices] return gr.Dropdown( choices=choices, value=choices[0], interactive=True, ) def build_html_output(s: str, style: str = "result_item_success"): return f"""
{s}
""" @torch.no_grad() def process( language: str, repo_id: str, decoding_method: str, num_active_paths: int, add_punctuation: str, in_filename: str, pretrained_model_dir: Path, ): main_logger.info("language: {}".format(language)) main_logger.info("repo_id: {}".format(repo_id)) main_logger.info("decoding_method: {}".format(decoding_method)) main_logger.info("num_active_paths: {}".format(num_active_paths)) main_logger.info("in_filename: {}".format(in_filename)) # audio convert in_filename = Path(in_filename) out_filename = Path(tempfile.gettempdir()) / "asr" / in_filename.name out_filename.parent.mkdir(parents=True, exist_ok=True) audio_convert(in_filename=in_filename.as_posix(), out_filename=out_filename.as_posix(), ) # model settings m_list = nn_models.model_map.get(language) if m_list is None: raise AssertionError("language invalid: {}".format(language)) m_dict = None for m in m_list: if m["repo_id"] == repo_id: m_dict = m if m_dict is None: raise AssertionError("repo_id invalid: {}".format(repo_id)) # local_model_dir repo_id: Path = Path(repo_id) if len(repo_id.parts) == 1: repo_name = repo_id.parts[-1] if len(repo_name) > 40: # repo_name = str(uuid.uuid4()) repo_name = str(shortuuid.uuid()) # repo_name = repo_name[:40] folder = repo_name elif len(repo_id.parts) == 2: repo_supplier = repo_id.parts[-2] repo_name = repo_id.parts[-1] if len(repo_name) > 40: # repo_name = str(uuid.uuid4()) repo_name = str(shortuuid.uuid()) # repo_name = repo_name[:40] folder = "{}/{}".format(repo_supplier, repo_name) else: raise AssertionError("repo_id parts count invalid: {}".format(len(repo_id.parts))) local_model_dir = pretrained_model_dir / "huggingface" / folder # load recognizer recognizer = nn_models.load_recognizer( local_model_dir=local_model_dir, decoding_method=decoding_method, num_active_paths=num_active_paths, **m_dict ) # transcribe now = datetime.now() date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f") logging.info(f"Started at {date_time}") start = time.time() text = decode.decode_by_recognizer(recognizer=recognizer, filename=out_filename.as_posix(), ) date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f") end = time.time() # statistics metadata = torchaudio.info(out_filename.as_posix()) duration = metadata.num_frames / 16000 rtf = (end - start) / duration main_logger.info(f"Finished at {date_time} s. Elapsed: {end - start: .3f} s") info = f""" Wave duration : {duration: .3f} s
Processing time: {end - start: .3f} s
RTF: {end - start: .3f}/{duration: .3f} = {rtf:.3f}
""" main_logger.info(info) main_logger.info(f"\nrepo_id: {repo_id}\nhyp: {text}") return text, build_html_output(info) def process_uploaded_file(language: str, repo_id: str, decoding_method: str, num_active_paths: int, add_punctuation: str, in_filename: str, pretrained_model_dir: Path, ): if in_filename is None or in_filename == "": return "", build_html_output( "Please first upload a file and then click " 'the button "submit for recognition"', "result_item_error", ) main_logger.info(f"Processing uploaded file: {in_filename}") try: return process( in_filename=in_filename, language=language, repo_id=repo_id, decoding_method=decoding_method, num_active_paths=num_active_paths, add_punctuation=add_punctuation, pretrained_model_dir=pretrained_model_dir, ) except Exception as e: msg = "transcribe error: {}".format(str(e)) main_logger.info(msg) return "", build_html_output(msg, "result_item_error") # css style is copied from # https://huggingface.co/spaces/alphacep/asr/blob/main/app.py#L113 css = """ .result {display:flex;flex-direction:column} .result_item {padding:15px;margin-bottom:8px;border-radius:15px;width:100%} .result_item_success {background-color:mediumaquamarine;color:white;align-self:start} .result_item_error {background-color:#ff7070;color:white;align-self:start} """ def main(): args = get_args() pretrained_model_dir = Path(args.pretrained_model_dir) pretrained_model_dir.mkdir(exist_ok=True) process_uploaded_file_ = functools.partial( process_uploaded_file, pretrained_model_dir=pretrained_model_dir, ) title = "# Automatic Speech Recognition with Next-gen Kaldi" language_choices = list(nn_models.model_map.keys()) language_to_models = defaultdict(list) for k, v in nn_models.model_map.items(): for m in v: repo_id = m["repo_id"] language_to_models[k].append(repo_id) # blocks with gr.Blocks(css=css) as blocks: gr.Markdown(value=title) with gr.Tabs(): with gr.TabItem("Upload from disk"): language_radio = gr.Radio( label="Language", choices=language_choices, value=language_choices[0], ) model_dropdown = gr.Dropdown( choices=language_to_models[language_choices[0]], label="Select a model", value=language_to_models[language_choices[0]][0], ) decoding_method_radio = gr.Radio( label="Decoding method", choices=["greedy_search", "modified_beam_search"], value="greedy_search", ) num_active_paths_slider = gr.Slider( minimum=1, value=4, step=1, label="Number of active paths for modified_beam_search", ) punct_radio = gr.Radio( label="Whether to add punctuation (Only for Chinese and English)", choices=["Yes", "No"], value="Yes", ) uploaded_file = gr.Audio( sources=["upload"], type="filepath", label="Upload from disk", ) upload_button = gr.Button("Submit for recognition") uploaded_output = gr.Textbox(label="Recognized speech from uploaded file") uploaded_html_info = gr.HTML(label="Info") gr.Examples( examples=examples, inputs=[ language_radio, model_dropdown, decoding_method_radio, num_active_paths_slider, punct_radio, uploaded_file, ], outputs=[uploaded_output, uploaded_html_info], fn=process_uploaded_file_, ) upload_button.click( process_uploaded_file_, inputs=[ language_radio, model_dropdown, decoding_method_radio, num_active_paths_slider, punct_radio, uploaded_file, ], outputs=[uploaded_output, uploaded_html_info], ) language_radio.change( update_model_dropdown, inputs=language_radio, outputs=model_dropdown, ) blocks.queue().launch( share=False if platform.system() == "Windows" else False, server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0", server_port=7860 ) return if __name__ == "__main__": main()