Spaces:
Sleeping
Sleeping
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
from enum import Enum | |
from functools import lru_cache | |
import os | |
import huggingface_hub | |
import sherpa | |
class EnumDecodingMethod(Enum):
    """Decoding strategies selectable for a recognizer.

    Every member's value is identical to its name, so ``EnumDecodingMethod(s)``
    maps a plain string back to the corresponding member.
    """

    greedy_search = "greedy_search"
    modified_beam_search = "modified_beam_search"
class EnumRecognizerType(Enum):
    """Identifiers for recognizer implementations.

    Each value is the dotted ``module.ClassName`` string naming the recognizer
    class it stands for; model entries store this string (not the member).
    """

    sherpa_offline_recognizer = "sherpa.OfflineRecognizer"
    sherpa_online_recognizer = "sherpa.OnlineRecognizer"
    sherpa_onnx_offline_recognizer = "sherpa_onnx.OfflineRecognizer"
    sherpa_onnx_online_recognizer = "sherpa_onnx.OnlineRecognizer"
# Available models grouped by language. Each entry's keys match parameter
# names of `load_recognizer` (repo_id, nn_model_file, tokens_file, sub_folder,
# recognizer_type); `recognizer_type` holds the enum's string value.
model_map = dict(
    Chinese=[
        dict(
            repo_id="csukuangfj/wenet-chinese-model",
            nn_model_file="final.zip",
            tokens_file="units.txt",
            sub_folder=".",
            recognizer_type=EnumRecognizerType.sherpa_offline_recognizer.value,
        ),
    ],
)
def download_model(repo_id: str,
                   nn_model_file: str,
                   tokens_file: str,
                   sub_folder: str,
                   local_model_dir: str,
                   ):
    """Download a model file and its tokens file from a Hugging Face Hub repo.

    Both files are fetched from ``sub_folder`` of ``repo_id`` into
    ``local_model_dir``, model file first.

    Returns:
        ``(nn_model_path, tokens_path)`` -- the local paths returned by
        ``huggingface_hub.hf_hub_download`` for the two files.
    """
    # One download call per requested file, in the order (model, tokens).
    nn_model_path, tokens_path = (
        huggingface_hub.hf_hub_download(
            repo_id=repo_id,
            filename=file_name,
            subfolder=sub_folder,
            local_dir=local_model_dir,
        )
        for file_name in (nn_model_file, tokens_file)
    )
    return nn_model_path, tokens_path
def load_sherpa_offline_recognizer(nn_model_file: str,
                                   tokens_file: str,
                                   sample_rate: int = 16000,
                                   num_active_paths: int = 2,
                                   decoding_method: EnumDecodingMethod = EnumDecodingMethod.greedy_search,
                                   num_mel_bins: int = 80,
                                   frame_dither: int = 0,
                                   ):
    """Build a CPU ``sherpa.OfflineRecognizer`` from a model file and a tokens file.

    Args:
        nn_model_file: path to the neural-network model file.
        tokens_file: path to the tokens file.
        sample_rate: sampling frequency set on the fbank frame options.
        num_active_paths: forwarded to ``sherpa.OfflineRecognizerConfig``.
        decoding_method: an ``EnumDecodingMethod`` member or its plain string
            value (e.g. ``"greedy_search"``).
        num_mel_bins: number of mel bins for the fbank features.
        frame_dither: dither setting for fbank frame extraction.

    Returns:
        The constructed ``sherpa.OfflineRecognizer`` (``use_gpu`` is False).
    """
    # sherpa.OfflineRecognizerConfig takes the decoding method as a plain string
    # such as "greedy_search"; unwrap Enum members so the default works.
    if isinstance(decoding_method, Enum):
        decoding_method = decoding_method.value

    feat_config = sherpa.FeatureConfig()
    feat_config.fbank_opts.frame_opts.samp_freq = sample_rate
    feat_config.fbank_opts.mel_opts.num_bins = num_mel_bins
    feat_config.fbank_opts.frame_opts.dither = frame_dither

    config = sherpa.OfflineRecognizerConfig(
        nn_model=nn_model_file,
        tokens=tokens_file,
        use_gpu=False,
        feat_config=feat_config,
        decoding_method=decoding_method,
        num_active_paths=num_active_paths,
    )
    recognizer = sherpa.OfflineRecognizer(config)
    return recognizer
def load_recognizer(repo_id: str,
                    nn_model_file: str,
                    tokens_file: str,
                    sub_folder: str,
                    local_model_dir: str,
                    recognizer_type: str,
                    decoding_method: EnumDecodingMethod = EnumDecodingMethod.greedy_search,
                    num_active_paths: int = 4,
                    ):
    """Ensure the model files exist locally, then build the requested recognizer.

    The files are expected at ``local_model_dir/sub_folder/<file name>``. If
    either one is missing, both are fetched with ``download_model`` and the
    paths it returns are used.

    Args:
        repo_id: Hugging Face Hub repository that holds the model.
        nn_model_file: file name of the model inside ``sub_folder`` of the repo.
        tokens_file: file name of the tokens file inside ``sub_folder``.
        sub_folder: repo sub-folder containing both files.
        local_model_dir: local directory the files are stored in / downloaded to.
        recognizer_type: an ``EnumRecognizerType`` value (its member is also
            accepted); only ``"sherpa.OfflineRecognizer"`` is implemented.
        decoding_method: an ``EnumDecodingMethod`` member or its string value.
        num_active_paths: forwarded to the recognizer configuration.

    Returns:
        The constructed recognizer.

    Raises:
        NotImplementedError: if ``recognizer_type`` is not supported.
    """
    if isinstance(recognizer_type, Enum):
        recognizer_type = recognizer_type.value
    # Validate before any download so an unsupported type fails fast,
    # without network I/O. `recognizer_type` is a str here (no `.value`).
    if recognizer_type != EnumRecognizerType.sherpa_offline_recognizer.value:
        raise NotImplementedError("recognizer_type not support: {}".format(recognizer_type))

    # The sherpa config expects the decoding method as a plain string.
    if isinstance(decoding_method, Enum):
        decoding_method = decoding_method.value

    nn_model_path = os.path.join(local_model_dir, sub_folder, nn_model_file)
    tokens_path = os.path.join(local_model_dir, sub_folder, tokens_file)
    # Check the files themselves, not just the directory: the directory can
    # exist while a file is missing (e.g. an interrupted download).
    if not (os.path.isfile(nn_model_path) and os.path.isfile(tokens_path)):
        nn_model_path, tokens_path = download_model(
            repo_id=repo_id,
            nn_model_file=nn_model_file,
            tokens_file=tokens_file,
            sub_folder=sub_folder,
            local_model_dir=local_model_dir,
        )

    # Pass full paths: bare repo file names would be resolved relative to the
    # current working directory rather than `local_model_dir`.
    recognizer = load_sherpa_offline_recognizer(
        nn_model_file=nn_model_path,
        tokens_file=tokens_path,
        decoding_method=decoding_method,
        num_active_paths=num_active_paths,
    )
    return recognizer
if __name__ == "__main__":
    # No script behaviour is defined; executing the file directly does nothing.
    ...