csukuangfj committed on
Commit
f5b2e32
1 Parent(s): 39b3b3e

Support selecting decoding method.

Browse files
Files changed (2) hide show
  1. app.py +41 -5
  2. offline_asr.py +33 -28
app.py CHANGED
@@ -39,7 +39,13 @@ def convert_to_wav(in_filename: str) -> str:
39
  return out_filename
40
 
41
 
42
- def process(in_filename: str, language: str, repo_id: str) -> str:
 
 
 
 
 
 
43
  print("in_filename", in_filename)
44
  print("language", language)
45
  print("repo_id", repo_id)
@@ -65,7 +71,11 @@ def process(in_filename: str, language: str, repo_id: str) -> str:
65
  )
66
  wave = wave[0] # use only the first channel.
67
 
68
- hyp = get_pretrained_model(repo_id).decode_waves([wave])[0]
 
 
 
 
69
 
70
  date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
71
  end = time.time()
@@ -107,12 +117,12 @@ demo = gr.Blocks()
107
 
108
  with demo:
109
  gr.Markdown(title)
110
- gr.Markdown(description)
111
  language_choices = list(language_to_models.keys())
112
 
113
  language_radio = gr.Radio(
114
  label="Language",
115
  choices=language_choices,
 
116
  )
117
  model_dropdown = gr.Dropdown(choices=[], label="Select a model")
118
  language_radio.change(
@@ -121,6 +131,19 @@ with demo:
121
  outputs=model_dropdown,
122
  )
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  with gr.Tabs():
125
  with gr.TabItem("Upload from disk"):
126
  uploaded_file = gr.inputs.Audio(
@@ -149,14 +172,27 @@ with demo:
149
 
150
  upload_button.click(
151
  process,
152
- inputs=[uploaded_file, language_radio, model_dropdown],
 
 
 
 
 
 
153
  outputs=uploaded_output,
154
  )
155
  record_button.click(
156
  process,
157
- inputs=[microphone, language_radio, model_dropdown],
 
 
 
 
 
 
158
  outputs=recorded_output,
159
  )
 
160
 
161
  if __name__ == "__main__":
162
  demo.launch()
 
39
  return out_filename
40
 
41
 
42
+ def process(
43
+ in_filename: str,
44
+ language: str,
45
+ repo_id: str,
46
+ decoding_method: str,
47
+ num_active_paths: int,
48
+ ) -> str:
49
  print("in_filename", in_filename)
50
  print("language", language)
51
  print("repo_id", repo_id)
 
71
  )
72
  wave = wave[0] # use only the first channel.
73
 
74
+ hyp = get_pretrained_model(repo_id).decode_waves(
75
+ [wave],
76
+ decoding_method=decoding_method,
77
+ num_active_paths=num_active_paths,
78
+ )[0]
79
 
80
  date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
81
  end = time.time()
 
117
 
118
  with demo:
119
  gr.Markdown(title)
 
120
  language_choices = list(language_to_models.keys())
121
 
122
  language_radio = gr.Radio(
123
  label="Language",
124
  choices=language_choices,
125
+ value=language_choices[0],
126
  )
127
  model_dropdown = gr.Dropdown(choices=[], label="Select a model")
128
  language_radio.change(
 
131
  outputs=model_dropdown,
132
  )
133
 
134
+ decoding_method_radio = gr.Radio(
135
+ label="Decoding method",
136
+ choices=["greedy_search", "modified_beam_search"],
137
+ value="greedy_search",
138
+ )
139
+
140
+ num_active_paths_slider = gr.Slider(
141
+ minimum=1,
142
+ value=4,
143
+ step=1,
144
+ label="Number of active paths for modified_beam_search",
145
+ )
146
+
147
  with gr.Tabs():
148
  with gr.TabItem("Upload from disk"):
149
  uploaded_file = gr.inputs.Audio(
 
172
 
173
  upload_button.click(
174
  process,
175
+ inputs=[
176
+ uploaded_file,
177
+ language_radio,
178
+ model_dropdown,
179
+ decoding_method_radio,
180
+ num_active_paths_slider,
181
+ ],
182
  outputs=uploaded_output,
183
  )
184
  record_button.click(
185
  process,
186
+ inputs=[
187
+ microphone,
188
+ language_radio,
189
+ model_dropdown,
190
+ decoding_method_radio,
191
+ num_active_paths_slider,
192
+ ],
193
  outputs=recorded_output,
194
  )
195
+ gr.Markdown(description)
196
 
197
  if __name__ == "__main__":
198
  demo.launch()
offline_asr.py CHANGED
@@ -223,14 +223,6 @@ class OfflineAsr(object):
223
  token_filename:
224
  Path to tokens.txt. If it is None, you have to provide
225
  `bpe_model_filename`.
226
- decoding_method:
227
- The decoding method to use. Currently, only greedy_search and
228
- modified_beam_search are implemented.
229
- num_active_paths:
230
- Used only when decoding_method is modified_beam_search.
231
- It specifies number of active paths for each utterance. Due to
232
- merging paths with identical token sequences, the actual number
233
- may be less than "num_active_paths".
234
  sample_rate:
235
  Expected sample rate of the feature extractor.
236
  device:
@@ -254,24 +246,6 @@ class OfflineAsr(object):
254
  device=device,
255
  )
256
 
257
- assert decoding_method in (
258
- "greedy_search",
259
- "modified_beam_search",
260
- ), decoding_method
261
- if decoding_method == "greedy_search":
262
- nn_and_decoding_func = run_model_and_do_greedy_search
263
- elif decoding_method == "modified_beam_search":
264
- nn_and_decoding_func = functools.partial(
265
- run_model_and_do_modified_beam_search,
266
- num_active_paths=num_active_paths,
267
- )
268
- else:
269
- raise ValueError(
270
- f"Unsupported decoding_method: {decoding_method} "
271
- "Please use greedy_search or modified_beam_search"
272
- )
273
-
274
- self.nn_and_decoding_func = nn_and_decoding_func
275
  self.device = device
276
 
277
  def _build_feature_extractor(
@@ -300,7 +274,12 @@ class OfflineAsr(object):
300
 
301
  return fbank
302
 
303
- def decode_waves(self, waves: List[torch.Tensor]) -> List[List[str]]:
 
 
 
 
 
304
  """
305
  Args:
306
  waves:
@@ -314,14 +293,40 @@ class OfflineAsr(object):
314
  then the given waves have to contain samples in this range.
315
 
316
  All models trained in icefall use the normalized range [-1, 1].
 
 
 
 
 
 
 
 
317
  Returns:
318
  Return a list of decoded results. `ans[i]` contains the decoded
319
  results for `wavs[i]`.
320
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
  waves = [w.to(self.device) for w in waves]
322
  features = self.feature_extractor(waves)
323
 
324
- tokens = self.nn_and_decoding_func(self.model, features)
325
 
326
  if hasattr(self, "sp"):
327
  results = self.sp.decode(tokens)
 
223
  token_filename:
224
  Path to tokens.txt. If it is None, you have to provide
225
  `bpe_model_filename`.
 
 
 
 
 
 
 
 
226
  sample_rate:
227
  Expected sample rate of the feature extractor.
228
  device:
 
246
  device=device,
247
  )
248
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  self.device = device
250
 
251
  def _build_feature_extractor(
 
274
 
275
  return fbank
276
 
277
+ def decode_waves(
278
+ self,
279
+ waves: List[torch.Tensor],
280
+ decoding_method: str,
281
+ num_active_paths: int,
282
+ ) -> List[List[str]]:
283
  """
284
  Args:
285
  waves:
 
293
  then the given waves have to contain samples in this range.
294
 
295
  All models trained in icefall use the normalized range [-1, 1].
296
+ decoding_method:
297
+ The decoding method to use. Currently, only greedy_search and
298
+ modified_beam_search are implemented.
299
+ num_active_paths:
300
+ Used only when decoding_method is modified_beam_search.
301
+ It specifies number of active paths for each utterance. Due to
302
+ merging paths with identical token sequences, the actual number
303
+ may be less than "num_active_paths".
304
  Returns:
305
  Return a list of decoded results. `ans[i]` contains the decoded
306
  results for `wavs[i]`.
307
  """
308
+ assert decoding_method in (
309
+ "greedy_search",
310
+ "modified_beam_search",
311
+ ), decoding_method
312
+
313
+ if decoding_method == "greedy_search":
314
+ nn_and_decoding_func = run_model_and_do_greedy_search
315
+ elif decoding_method == "modified_beam_search":
316
+ nn_and_decoding_func = functools.partial(
317
+ run_model_and_do_modified_beam_search,
318
+ num_active_paths=num_active_paths,
319
+ )
320
+ else:
321
+ raise ValueError(
322
+ f"Unsupported decoding_method: {decoding_method} "
323
+ "Please use greedy_search or modified_beam_search"
324
+ )
325
+
326
  waves = [w.to(self.device) for w in waves]
327
  features = self.feature_extractor(waves)
328
 
329
+ tokens = nn_and_decoding_func(self.model, features)
330
 
331
  if hasattr(self, "sp"):
332
  results = self.sp.decode(tokens)