#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Download a model and its token table from the Hugging Face Hub and
construct a sherpa offline recognizer from the downloaded files."""
import argparse
from pathlib import Path

import huggingface_hub
import sherpa

from project_settings import project_path


def get_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace with ``repo_id``, ``model_filename``,
        ``tokens_filename``, ``pretrained_model_dir``, ``sample_rate``,
        ``decoding_method`` and ``num_active_paths``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--repo_id",
        # Another repository this script has been pointed at:
        # "csukuangfj/wenet-english-model"
        default="csukuangfj/wenet-chinese-model",
        type=str,
    )
    parser.add_argument("--model_filename", default="final.zip", type=str)
    parser.add_argument("--tokens_filename", default="units.txt", type=str)
    parser.add_argument(
        "--pretrained_model_dir",
        default=(project_path / "pretrained_models").as_posix(),
        type=str,
    )
    # The options below feed values that main() previously read from
    # undefined names. NOTE(review): the defaults mirror the upstream sherpa
    # example scripts -- confirm they suit the selected model.
    parser.add_argument("--sample_rate", default=16000, type=int)
    parser.add_argument("--decoding_method", default="greedy_search", type=str)
    parser.add_argument("--num_active_paths", default=4, type=int)
    return parser.parse_args()


def _download(repo_id: str, filename: str, local_dir: Path) -> str:
    """Fetch ``filename`` from ``repo_id`` into ``local_dir``; return its local path."""
    return huggingface_hub.hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=".",
        local_dir=local_dir.as_posix(),
    )


def main():
    """Download the model and tokens files, then build the recognizer."""
    args = get_args()

    # repo_id has the form "<owner>/<name>", so the target directory is nested
    # several levels deep; parents=True creates the intermediate directories.
    model_dir = Path(args.pretrained_model_dir) / "huggingface" / args.repo_id
    model_dir.mkdir(parents=True, exist_ok=True)

    print("download model")
    model_filename = _download(args.repo_id, args.model_filename, model_dir)
    print(model_filename)

    print("download tokens")
    token_filename = _download(args.repo_id, args.tokens_filename, model_dir)
    print(token_filename)

    feat_config = sherpa.FeatureConfig(normalize_samples=False)
    feat_config.fbank_opts.frame_opts.samp_freq = args.sample_rate
    feat_config.fbank_opts.mel_opts.num_bins = 80
    feat_config.fbank_opts.frame_opts.dither = 0

    config = sherpa.OfflineRecognizerConfig(
        nn_model=model_filename,
        tokens=token_filename,
        use_gpu=False,
        feat_config=feat_config,
        decoding_method=args.decoding_method,
        num_active_paths=args.num_active_paths,
    )
    # The recognizer is constructed but not used further; building it is the
    # last step of this script.
    recognizer = sherpa.OfflineRecognizer(config)
    return


if __name__ == "__main__":
    main()