HoneyTian committed
Commit 8f39105
1 Parent(s): 74df484
examples/wenet/toolbox_download.py ADDED
File without changes
toolbox/k2_sherpa/nn_models.py CHANGED
@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-
 from enum import Enum
 from functools import lru_cache
+import logging
 import os
 import platform
 from pathlib import Path
@@ -10,6 +11,8 @@ import huggingface_hub
 import sherpa
 import sherpa_onnx
 
+main_logger = logging.getLogger("main")
+
 
 class EnumDecodingMethod(Enum):
     greedy_search = "greedy_search"
@@ -104,6 +107,7 @@ def download_model(local_model_dir: str,
     repo_id = kwargs["repo_id"]
 
     if "nn_model_file" in kwargs.keys():
+        main_logger.info("download nn_model_file. filename: {}, subfolder: {}".format(kwargs["nn_model_file"], kwargs["nn_model_file_sub_folder"]))
         _ = huggingface_hub.hf_hub_download(
             repo_id=repo_id,
             filename=kwargs["nn_model_file"],
@@ -112,6 +116,7 @@ def download_model(local_model_dir: str,
         )
 
     if "encoder_model_file" in kwargs.keys():
+        main_logger.info("download encoder_model_file. filename: {}, subfolder: {}".format(kwargs["encoder_model_file"], kwargs["encoder_model_file_sub_folder"]))
         _ = huggingface_hub.hf_hub_download(
             repo_id=repo_id,
             filename=kwargs["encoder_model_file"],
@@ -120,6 +125,7 @@ def download_model(local_model_dir: str,
         )
 
     if "decoder_model_file" in kwargs.keys():
+        main_logger.info("download decoder_model_file. filename: {}, subfolder: {}".format(kwargs["decoder_model_file"], kwargs["decoder_model_file_sub_folder"]))
         _ = huggingface_hub.hf_hub_download(
             repo_id=repo_id,
             filename=kwargs["decoder_model_file"],
@@ -128,6 +134,7 @@ def download_model(local_model_dir: str,
         )
 
     if "joiner_model_file" in kwargs.keys():
+        main_logger.info("download joiner_model_file. filename: {}, subfolder: {}".format(kwargs["joiner_model_file"], kwargs["joiner_model_file_sub_folder"]))
         _ = huggingface_hub.hf_hub_download(
             repo_id=repo_id,
             filename=kwargs["joiner_model_file"],
@@ -136,6 +143,7 @@ def download_model(local_model_dir: str,
         )
 
     if "tokens_file" in kwargs.keys():
+        main_logger.info("download tokens_file. filename: {}, subfolder: {}".format(kwargs["tokens_file"], kwargs["tokens_file_sub_folder"]))
         _ = huggingface_hub.hf_hub_download(
             repo_id=repo_id,
             filename=kwargs["tokens_file"],
@@ -158,6 +166,9 @@ def load_sherpa_offline_recognizer(nn_model_file: str,
     feat_config.fbank_opts.mel_opts.num_bins = num_mel_bins
     feat_config.fbank_opts.frame_opts.dither = frame_dither
 
+    if not os.path.exists(nn_model_file):
+        raise AssertionError("nn_model_file not found. ")
+
     config = sherpa.OfflineRecognizerConfig(
         nn_model=nn_model_file,
         tokens=tokens_file,
@@ -220,7 +231,7 @@ def load_recognizer(local_model_dir: Path,
                     num_active_paths: int = 4,
                     **kwargs
                     ):
-    if not os.path.exists(local_model_dir):
+    if not local_model_dir.exists():
        download_model(
            local_model_dir=local_model_dir.as_posix(),
            **kwargs,
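
For reference, a minimal sketch of the download pattern this commit instruments, written against `huggingface_hub.hf_hub_download`. The `repo_id`, filenames, and subfolders below are placeholder assumptions (the real values come from the kwargs the toolbox passes per model), and the `subfolder`/`local_dir` arguments are an assumed completion of the call whose trailing arguments are truncated in the hunks above; logging is configured on the caller side, since the module itself only fetches the `"main"` logger.

```python
import logging
from pathlib import Path

import huggingface_hub

# Caller-side logging setup; nn_models.py only does logging.getLogger("main").
logging.basicConfig(level=logging.INFO)
main_logger = logging.getLogger("main")

# Placeholder kwargs for illustration only; keys mirror the ones the new log
# lines reference (nn_model_file, nn_model_file_sub_folder, tokens_file, ...).
kwargs = {
    "repo_id": "your-username/your-sherpa-model",   # assumption, not from the diff
    "nn_model_file": "model.pt",                     # assumption
    "nn_model_file_sub_folder": "exp",               # assumption
    "tokens_file": "tokens.txt",                     # assumption
    "tokens_file_sub_folder": "data/lang_char",      # assumption
}

local_model_dir = Path("./pretrained_models/example")

# Mirrors the pattern the diff adds: skip if the directory already exists,
# otherwise log each optional file and fetch it from the Hub.
if not local_model_dir.exists():
    for key in ("nn_model_file", "tokens_file"):
        sub_key = key + "_sub_folder"
        main_logger.info(
            "download {}. filename: {}, subfolder: {}".format(key, kwargs[key], kwargs[sub_key])
        )
        _ = huggingface_hub.hf_hub_download(
            repo_id=kwargs["repo_id"],
            filename=kwargs[key],
            subfolder=kwargs[sub_key],               # assumed argument of the truncated call
            local_dir=local_model_dir.as_posix(),    # assumed argument of the truncated call
        )
```

The other two changes in the commit are visible in full above: `load_sherpa_offline_recognizer` now fails fast with an `AssertionError` when `nn_model_file` is missing, and `load_recognizer` checks `local_model_dir.exists()` on the `Path` object instead of going through `os.path.exists`.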