Spaces:
Sleeping
Sleeping
update
Browse files
examples/wenet/toolbox_download.py
ADDED
File without changes
|
toolbox/k2_sherpa/nn_models.py
CHANGED
@@ -2,6 +2,7 @@
|
|
2 |
# -*- coding: utf-8 -*-
|
3 |
from enum import Enum
|
4 |
from functools import lru_cache
|
|
|
5 |
import os
|
6 |
import platform
|
7 |
from pathlib import Path
|
@@ -10,6 +11,8 @@ import huggingface_hub
|
|
10 |
import sherpa
|
11 |
import sherpa_onnx
|
12 |
|
|
|
|
|
13 |
|
14 |
class EnumDecodingMethod(Enum):
|
15 |
greedy_search = "greedy_search"
|
@@ -104,6 +107,7 @@ def download_model(local_model_dir: str,
|
|
104 |
repo_id = kwargs["repo_id"]
|
105 |
|
106 |
if "nn_model_file" in kwargs.keys():
|
|
|
107 |
_ = huggingface_hub.hf_hub_download(
|
108 |
repo_id=repo_id,
|
109 |
filename=kwargs["nn_model_file"],
|
@@ -112,6 +116,7 @@ def download_model(local_model_dir: str,
|
|
112 |
)
|
113 |
|
114 |
if "encoder_model_file" in kwargs.keys():
|
|
|
115 |
_ = huggingface_hub.hf_hub_download(
|
116 |
repo_id=repo_id,
|
117 |
filename=kwargs["encoder_model_file"],
|
@@ -120,6 +125,7 @@ def download_model(local_model_dir: str,
|
|
120 |
)
|
121 |
|
122 |
if "decoder_model_file" in kwargs.keys():
|
|
|
123 |
_ = huggingface_hub.hf_hub_download(
|
124 |
repo_id=repo_id,
|
125 |
filename=kwargs["decoder_model_file"],
|
@@ -128,6 +134,7 @@ def download_model(local_model_dir: str,
|
|
128 |
)
|
129 |
|
130 |
if "joiner_model_file" in kwargs.keys():
|
|
|
131 |
_ = huggingface_hub.hf_hub_download(
|
132 |
repo_id=repo_id,
|
133 |
filename=kwargs["joiner_model_file"],
|
@@ -136,6 +143,7 @@ def download_model(local_model_dir: str,
|
|
136 |
)
|
137 |
|
138 |
if "tokens_file" in kwargs.keys():
|
|
|
139 |
_ = huggingface_hub.hf_hub_download(
|
140 |
repo_id=repo_id,
|
141 |
filename=kwargs["tokens_file"],
|
@@ -158,6 +166,9 @@ def load_sherpa_offline_recognizer(nn_model_file: str,
|
|
158 |
feat_config.fbank_opts.mel_opts.num_bins = num_mel_bins
|
159 |
feat_config.fbank_opts.frame_opts.dither = frame_dither
|
160 |
|
|
|
|
|
|
|
161 |
config = sherpa.OfflineRecognizerConfig(
|
162 |
nn_model=nn_model_file,
|
163 |
tokens=tokens_file,
|
@@ -220,7 +231,7 @@ def load_recognizer(local_model_dir: Path,
|
|
220 |
num_active_paths: int = 4,
|
221 |
**kwargs
|
222 |
):
|
223 |
-
if not
|
224 |
download_model(
|
225 |
local_model_dir=local_model_dir.as_posix(),
|
226 |
**kwargs,
|
|
|
2 |
# -*- coding: utf-8 -*-
|
3 |
from enum import Enum
|
4 |
from functools import lru_cache
|
5 |
+
import logging
|
6 |
import os
|
7 |
import platform
|
8 |
from pathlib import Path
|
|
|
11 |
import sherpa
|
12 |
import sherpa_onnx
|
13 |
|
14 |
+
main_logger = logging.getLogger("main")
|
15 |
+
|
16 |
|
17 |
class EnumDecodingMethod(Enum):
|
18 |
greedy_search = "greedy_search"
|
|
|
107 |
repo_id = kwargs["repo_id"]
|
108 |
|
109 |
if "nn_model_file" in kwargs.keys():
|
110 |
+
main_logger.info("download nn_model_file. filename: {}, subfolder: {}".format(kwargs["nn_model_file"], kwargs["nn_model_file_sub_folder"]))
|
111 |
_ = huggingface_hub.hf_hub_download(
|
112 |
repo_id=repo_id,
|
113 |
filename=kwargs["nn_model_file"],
|
|
|
116 |
)
|
117 |
|
118 |
if "encoder_model_file" in kwargs.keys():
|
119 |
+
main_logger.info("download encoder_model_file. filename: {}, subfolder: {}".format(kwargs["encoder_model_file"], kwargs["encoder_model_file_sub_folder"]))
|
120 |
_ = huggingface_hub.hf_hub_download(
|
121 |
repo_id=repo_id,
|
122 |
filename=kwargs["encoder_model_file"],
|
|
|
125 |
)
|
126 |
|
127 |
if "decoder_model_file" in kwargs.keys():
|
128 |
+
main_logger.info("download decoder_model_file. filename: {}, subfolder: {}".format(kwargs["decoder_model_file"], kwargs["decoder_model_file_sub_folder"]))
|
129 |
_ = huggingface_hub.hf_hub_download(
|
130 |
repo_id=repo_id,
|
131 |
filename=kwargs["decoder_model_file"],
|
|
|
134 |
)
|
135 |
|
136 |
if "joiner_model_file" in kwargs.keys():
|
137 |
+
main_logger.info("download joiner_model_file. filename: {}, subfolder: {}".format(kwargs["joiner_model_file"], kwargs["joiner_model_file_sub_folder"]))
|
138 |
_ = huggingface_hub.hf_hub_download(
|
139 |
repo_id=repo_id,
|
140 |
filename=kwargs["joiner_model_file"],
|
|
|
143 |
)
|
144 |
|
145 |
if "tokens_file" in kwargs.keys():
|
146 |
+
main_logger.info("download tokens_file. filename: {}, subfolder: {}".format(kwargs["tokens_file"], kwargs["tokens_file_sub_folder"]))
|
147 |
_ = huggingface_hub.hf_hub_download(
|
148 |
repo_id=repo_id,
|
149 |
filename=kwargs["tokens_file"],
|
|
|
166 |
feat_config.fbank_opts.mel_opts.num_bins = num_mel_bins
|
167 |
feat_config.fbank_opts.frame_opts.dither = frame_dither
|
168 |
|
169 |
+
if not os.path.exists(nn_model_file):
|
170 |
+
raise AssertionError("nn_model_file not found. ")
|
171 |
+
|
172 |
config = sherpa.OfflineRecognizerConfig(
|
173 |
nn_model=nn_model_file,
|
174 |
tokens=tokens_file,
|
|
|
231 |
num_active_paths: int = 4,
|
232 |
**kwargs
|
233 |
):
|
234 |
+
if not local_model_dir.exists():
|
235 |
download_model(
|
236 |
local_model_dir=local_model_dir.as_posix(),
|
237 |
**kwargs,
|