asr / toolbox /k2_sherpa /nn_models.py
HoneyTian's picture
update
3e60665
raw
history blame
3.61 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from enum import Enum
from functools import lru_cache
import os
import huggingface_hub
import sherpa
class EnumDecodingMethod(Enum):
    """Decoding method names.

    The values are the plain strings this module passes around as
    ``decoding_method`` (whose default elsewhere in the file is
    ``"greedy_search"``).
    """
    greedy_search = "greedy_search"
    modified_beam_search = "modified_beam_search"
class EnumRecognizerType(Enum):
    """Recognizer backends a model entry can declare via ``recognizer_type``.

    Each value is the dotted ``"module.ClassName"`` of the backend. Entries in
    ``model_map`` store the ``.value`` string, and ``load_recognizer`` compares
    against it; only ``"sherpa.OfflineRecognizer"`` is handled there.
    """
    sherpa_offline_recognizer = "sherpa.OfflineRecognizer"
    sherpa_online_recognizer = "sherpa.OnlineRecognizer"
    sherpa_onnx_offline_recognizer = "sherpa_onnx.OfflineRecognizer"
    sherpa_onnx_online_recognizer = "sherpa_onnx.OnlineRecognizer"
# Catalogue of downloadable models, keyed by language name. Each entry holds
# the Hugging Face repo id, the model and tokens file names inside that repo,
# the repo sub folder containing them, and the recognizer backend to use
# (stored as an ``EnumRecognizerType`` value string).
model_map = {
    "Chinese": [
        dict(
            repo_id="csukuangfj/wenet-chinese-model",
            nn_model_file="final.zip",
            tokens_file="units.txt",
            sub_folder=".",
            recognizer_type=EnumRecognizerType.sherpa_offline_recognizer.value,
        ),
    ],
}
def download_model(repo_id: str,
                   nn_model_file: str,
                   tokens_file: str,
                   sub_folder: str,
                   local_model_dir: str,
                   ):
    """Download a model file and its tokens file from the Hugging Face Hub.

    Both files are fetched from ``sub_folder`` of ``repo_id`` into
    ``local_model_dir``; the model file is fetched first, then the tokens file.

    :param repo_id: Hugging Face Hub repository id.
    :param nn_model_file: model file name inside the repo sub folder.
    :param tokens_file: tokens file name inside the repo sub folder.
    :param sub_folder: sub folder of the repo holding both files.
    :param local_model_dir: local directory to download into.
    :return: ``(nn_model_path, tokens_path)`` as returned by
        ``huggingface_hub.hf_hub_download`` for each file.
    """
    def _fetch(filename: str):
        # One Hub download into the shared local directory.
        return huggingface_hub.hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=sub_folder,
            local_dir=local_model_dir,
        )

    local_paths = tuple(_fetch(name) for name in (nn_model_file, tokens_file))
    return local_paths
def load_sherpa_offline_recognizer(nn_model_file: str,
                                   tokens_file: str,
                                   sample_rate: int = 16000,
                                   num_active_paths: int = 2,
                                   decoding_method: str = "greedy_search",
                                   num_mel_bins: int = 80,
                                   frame_dither: float = 0,
                                   use_gpu: bool = False,
                                   ):
    """Create a ``sherpa.OfflineRecognizer`` from local model and tokens files.

    :param nn_model_file: local path of the neural-network model file.
    :param tokens_file: local path of the tokens file.
    :param sample_rate: sampling frequency written into the fbank frame options.
    :param num_active_paths: forwarded to ``sherpa.OfflineRecognizerConfig``.
    :param decoding_method: forwarded to ``sherpa.OfflineRecognizerConfig``,
        e.g. ``"greedy_search"``.
    :param num_mel_bins: number of mel bins of the fbank features.
    :param frame_dither: dither amount used during fbank frame extraction.
    :param use_gpu: forwarded to ``sherpa.OfflineRecognizerConfig``; the
        default ``False`` keeps inference on the CPU.
    :return: the constructed ``sherpa.OfflineRecognizer``.
    """
    # NOTE(review): normalize_samples=False presumably means audio is fed in
    # integer-sample range rather than [-1, 1] — confirm against the caller.
    feat_config = sherpa.FeatureConfig(normalize_samples=False)
    feat_config.fbank_opts.frame_opts.samp_freq = sample_rate
    feat_config.fbank_opts.mel_opts.num_bins = num_mel_bins
    feat_config.fbank_opts.frame_opts.dither = frame_dither

    config = sherpa.OfflineRecognizerConfig(
        nn_model=nn_model_file,
        tokens=tokens_file,
        use_gpu=use_gpu,
        feat_config=feat_config,
        decoding_method=decoding_method,
        num_active_paths=num_active_paths,
    )
    recognizer = sherpa.OfflineRecognizer(config)
    return recognizer
def _resolve_local_file(filename: str,
                        sub_folder: str,
                        local_model_dir: str,
                        ) -> str:
    """Return the on-disk path for a repo file, preferring ``local_model_dir``.

    ``os.path.join`` drops the leading components when ``filename`` is
    absolute, so an already-resolved absolute path passes through unchanged.
    If the joined candidate does not exist, ``filename`` is returned as given.
    """
    candidate = os.path.join(local_model_dir, sub_folder or "", filename)
    if os.path.isfile(candidate):
        return candidate
    # Fall back to the argument as given (e.g. a path relative to the CWD).
    return filename


def load_recognizer(repo_id: str,
                    nn_model_file: str,
                    tokens_file: str,
                    sub_folder: str,
                    local_model_dir: str,
                    recognizer_type: str,
                    decoding_method: str = "greedy_search",
                    num_active_paths: int = 4,
                    ):
    """Build a recognizer for a Hub-hosted model, downloading it when needed.

    :param repo_id: Hugging Face Hub repository id.
    :param nn_model_file: model file name inside the repo sub folder (or an
        already resolved local path).
    :param tokens_file: tokens file name inside the repo sub folder (or an
        already resolved local path).
    :param sub_folder: sub folder of the repo holding both files.
    :param local_model_dir: directory the files are downloaded into; the
        download is skipped when this directory already exists.
    :param recognizer_type: an ``EnumRecognizerType`` value string; only
        ``"sherpa.OfflineRecognizer"`` is implemented.
    :param decoding_method: forwarded to the recognizer loader.
    :param num_active_paths: forwarded to the recognizer loader.
    :return: the constructed recognizer.
    :raises NotImplementedError: if ``recognizer_type`` is not supported.
    """
    # Validate first so an unsupported type fails before a potentially large
    # download instead of after it.
    if recognizer_type != EnumRecognizerType.sherpa_offline_recognizer.value:
        raise NotImplementedError("recognizer_type not support: {}".format(recognizer_type))

    if not os.path.exists(local_model_dir):
        # The arguments are repo-relative names; use the local paths that the
        # download reports instead of the bare names.
        nn_model_file, tokens_file = download_model(
            repo_id=repo_id,
            nn_model_file=nn_model_file,
            tokens_file=tokens_file,
            sub_folder=sub_folder,
            local_model_dir=local_model_dir,
        )
    else:
        nn_model_file = _resolve_local_file(nn_model_file, sub_folder, local_model_dir)
        tokens_file = _resolve_local_file(tokens_file, sub_folder, local_model_dir)

    recognizer = load_sherpa_offline_recognizer(
        nn_model_file=nn_model_file,
        tokens_file=tokens_file,
        decoding_method=decoding_method,
        num_active_paths=num_active_paths,
    )
    return recognizer
if __name__ == "__main__":
    # Intentionally a no-op: this module only defines helpers for importers.
    pass