csukuangfj committed
Commit ee0d936
Parent: 08d2e6b

small fixes

Files changed (4)
  1. app.py +12 -19
  2. giga-tokens.txt +500 -0
  3. model.py +205 -88
  4. offline_asr.py +0 -427
app.py CHANGED
@@ -37,7 +37,7 @@ def convert_to_wav(in_filename: str) -> str:
     """Convert the input audio file to a wave file"""
     out_filename = in_filename + ".wav"
     logging.info(f"Converting '{in_filename}' to '{out_filename}'")
-    _ = os.system(f"ffmpeg -hide_banner -i '{in_filename}' '{out_filename}'")
+    _ = os.system(f"ffmpeg -hide_banner -i '{in_filename}' -ar 16000 '{out_filename}'")
     return out_filename
 
 
@@ -128,31 +128,24 @@ def process(
     logging.info(f"Started at {date_time}")
 
     start = time.time()
-    wave, wave_sample_rate = torchaudio.load(filename)
 
-    if wave_sample_rate != sample_rate:
-        logging.info(
-            f"Expected sample rate: {sample_rate}. Given: {wave_sample_rate}. "
-            f"Resampling to {sample_rate}."
-        )
-
-        wave = torchaudio.functional.resample(
-            wave,
-            orig_freq=wave_sample_rate,
-            new_freq=sample_rate,
-        )
-    wave = wave[0]  # use only the first channel.
-
-    hyp = get_pretrained_model(repo_id).decode_waves(
-        [wave],
+    recognizer = get_pretrained_model(
+        repo_id,
         decoding_method=decoding_method,
         num_active_paths=num_active_paths,
-    )[0]
+    )
+    s = recognizer.create_stream()
+
+    s.accept_wave_file(filename)
+    recognizer.decode_stream(s)
+
+    logging.info(s.text)
 
     date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
     end = time.time()
 
-    duration = wave.shape[0] / sample_rate
+    metadata = torchaudio.info(filename)
+    duration = metadata.num_frames / sample_rate
     rtf = (end - start) / duration
 
     logging.info(f"Finished at {date_time} s. Elapsed: {end - start: .3f} s")
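The rewritten process() above drops the manual torchaudio load/resample path: convert_to_wav() now forces 16 kHz output via "-ar 16000", and sherpa handles reading the wave file and extracting features through its stream API; torchaudio is only used to read the duration for the RTF computation. A minimal standalone sketch of the same decode flow (the repo id and wav path are placeholders; it assumes model.py from this commit is importable):

# Sketch of the decode flow introduced in app.py above.
# Assumptions: model.py from this commit is on the import path,
# and test.wav is a 16 kHz mono wav file.
from model import get_pretrained_model

recognizer = get_pretrained_model(
    "wgb14/icefall-asr-gigaspeech-pruned-transducer-stateless2",
    decoding_method="greedy_search",
    num_active_paths=4,
)

s = recognizer.create_stream()  # one stream per utterance
s.accept_wave_file("test.wav")  # sherpa reads the file and computes fbank features
recognizer.decode_stream(s)     # runs the network and the search
print(s.text)                   # final transcript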
giga-tokens.txt ADDED
@@ -0,0 +1,500 @@
+<blk> 0
+<sos/eos> 1
+<unk> 2
+S 3
+T 4
+▁THE 5
+▁A 6
+E 7
+▁AND 8
+▁TO 9
+N 10
+D 11
+▁OF 12
+' 13
+ING 14
+▁I 15
+Y 16
+▁IN 17
+ED 18
+▁THAT 19
+▁ 20
+P 21
+R 22
+▁YOU 23
+M 24
+RE 25
+ER 26
+C 27
+O 28
+▁IT 29
+L 30
+A 31
+U 32
+G 33
+▁WE 34
+▁IS 35
+▁SO 36
+AL 37
+I 38
+▁S 39
+▁RE 40
+AR 41
+B 42
+▁FOR 43
+▁C 44
+▁BE 45
+LE 46
+F 47
+W 48
+▁E 49
+▁HE 50
+LL 51
+▁WAS 52
+LY 53
+OR 54
+IN 55
+▁F 56
+VE 57
+▁THIS 58
+TH 59
+K 60
+▁ON 61
+IT 62
+▁B 63
+▁WITH 64
+▁BUT 65
+EN 66
+CE 67
+RI 68
+▁DO 69
+UR 70
+▁HAVE 71
+▁DE 72
+▁ME 73
+▁T 74
+ENT 75
+CH 76
+▁THEY 77
+▁NOT 78
+ES 79
+V 80
+▁AS 81
+RA 82
+▁P 83
+ON 84
+TER 85
+▁ARE 86
+▁WHAT 87
+IC 88
+▁ST 89
+▁LIKE 90
+ATION 91
+▁OR 92
+▁CA 93
+▁AT 94
+H 95
+▁KNOW 96
+▁G 97
+AN 98
+▁CON 99
+IL 100
+ND 101
+RO 102
+▁HIS 103
+▁CAN 104
+▁ALL 105
+TE 106
+▁THERE 107
+▁SU 108
+▁MO 109
+▁MA 110
+LI 111
+▁ONE 112
+▁ABOUT 113
+LA 114
+▁CO 115
+- 116
+▁MY 117
+▁HAD 118
+CK 119
+NG 120
+▁NO 121
+MENT 122
+AD 123
+LO 124
+ME 125
+▁AN 126
+▁FROM 127
+NE 128
+▁IF 129
+VER 130
+▁JUST 131
+▁PRO 132
+ION 133
+▁PA 134
+▁WHO 135
+▁SE 136
+EL 137
+IR 138
+▁US 139
+▁UP 140
+▁YOUR 141
+CI 142
+RY 143
+▁GO 144
+▁SHE 145
+▁LE 146
+▁OUT 147
+▁PO 148
+▁HO 149
+ATE 150
+▁BO 151
+▁BY 152
+▁FA 153
+▁MI 154
+AS 155
+MP 156
+▁HER 157
+VI 158
+▁THINK 159
+▁SOME 160
+▁WHEN 161
+▁AH 162
+▁PEOPLE 163
+IG 164
+▁WA 165
+▁TE 166
+▁LA 167
+▁WERE 168
+▁LI 169
+▁WOULD 170
+▁SEE 171
+▁WHICH 172
+DE 173
+GE 174
+▁K 175
+IGHT 176
+▁HA 177
+▁OUR 178
+UN 179
+▁HOW 180
+▁GET 181
+IS 182
+UT 183
+Z 184
+CO 185
+ET 186
+UL 187
+IES 188
+IVE 189
+AT 190
+▁O 191
+▁DON 192
+LU 193
+▁TIME 194
+▁WILL 195
+▁MORE 196
+▁SP 197
+▁NOW 198
+RU 199
+▁THEIR 200
+▁UN 201
+ITY 202
+OL 203
+X 204
+TI 205
+US 206
+▁VERY 207
+TION 208
+▁FI 209
+▁SAY 210
+▁BECAUSE 211
+▁EX 212
+▁RO 213
+ERS 214
+IST 215
+▁DA 216
+TING 217
+▁EN 218
+OM 219
+▁BA 220
+▁BEEN 221
+▁LO 222
+▁UM 223
+AGE 224
+ABLE 225
+▁WO 226
+▁RA 227
+▁OTHER 228
+▁REALLY 229
+ENCE 230
+▁GOING 231
+▁HIM 232
+▁HAS 233
+▁THEM 234
+▁DIS 235
+▁WANT 236
+ID 237
+TA 238
+▁LOOK 239
+KE 240
+▁DID 241
+▁SA 242
+▁VI 243
+▁SAID 244
+▁RIGHT 245
+▁THESE 246
+▁WORK 247
+▁COM 248
+ALLY 249
+FF 250
+QU 251
+AC 252
+▁DR 253
+▁WAY 254
+▁INTO 255
+MO 256
+TED 257
+EST 258
+▁HERE 259
+OK 260
+▁COULD 261
+▁WELL 262
+MA 263
+▁PRE 264
+▁DI 265
+MAN 266
+▁COMP 267
+▁THEN 268
+IM 269
+▁PER 270
+▁NA 271
+▁WHERE 272
+▁TWO 273
+▁WI 274
+▁FE 275
+INE 276
+▁ANY 277
+TURE 278
+▁OVER 279
+BO 280
+ACH 281
+OW 282
+▁MAKE 283
+▁TRA 284
+HE 285
+UND 286
+▁EVEN 287
+ANCE 288
+▁YEAR 289
+HO 290
+AM 291
+▁CHA 292
+▁BACK 293
+VO 294
+ANT 295
+DI 296
+▁ALSO 297
+▁THOSE 298
+▁MAN 299
+CTION 300
+ICAL 301
+▁JO 302
+▁OP 303
+▁NEW 304
+▁MU 305
+▁HU 306
+▁KIND 307
+▁NE 308
+CA 309
+END 310
+TIC 311
+FUL 312
+▁YEAH 313
+SH 314
+▁APP 315
+▁THINGS 316
+SIDE 317
+▁GOOD 318
+ONE 319
+▁TAKE 320
+CU 321
+▁EVERY 322
+▁MEAN 323
+▁FIRST 324
+OP 325
+▁TH 326
+▁MUCH 327
+▁PART 328
+UGH 329
+▁COME 330
+J 331
+▁THAN 332
+▁EXP 333
+▁AGAIN 334
+▁LITTLE 335
+MB 336
+▁NEED 337
+▁TALK 338
+IF 339
+FOR 340
+▁SH 341
+ISH 342
+▁STA 343
+ATED 344
+▁GU 345
+▁LET 346
+IA 347
+▁MAR 348
+▁DOWN 349
+▁DAY 350
+▁GA 351
+▁SOMETHING 352
+▁BU 353
+DUC 354
+HA 355
+▁LOT 356
+▁RU 357
+▁THOUGH 358
+▁GREAT 359
+AIN 360
+▁THROUGH 361
+▁THING 362
+OUS 363
+▁PRI 364
+▁GOT 365
+▁SHOULD 366
+▁AFTER 367
+▁HEAR 368
+▁TA 369
+▁ONLY 370
+▁CHI 371
+IOUS 372
+▁SHA 373
+▁MOST 374
+▁ACTUALLY 375
+▁START 376
+LIC 377
+▁VA 378
+▁RI 379
+DAY 380
+IAN 381
+▁DOES 382
+ROW 383
+▁GRA 384
+ITION 385
+▁MANY 386
+▁BEFORE 387
+▁GIVE 388
+PORT 389
+QUI 390
+▁LIFE 391
+▁WORLD 392
+▁PI 393
+▁LONG 394
+▁THREE 395
+IZE 396
+NESS 397
+▁SHOW 398
+PH 399
+▁WHY 400
+▁QUESTION 401
+WARD 402
+▁THANK 403
+▁PH 404
+▁DIFFERENT 405
+▁OWN 406
+▁FEEL 407
+▁MIGHT 408
+▁HAPPEN 409
+▁MADE 410
+▁BRO 411
+IBLE 412
+▁HI 413
+▁STATE 414
+▁HAND 415
+▁NEVER 416
+▁PLACE 417
+▁LOVE 418
+▁DU 419
+▁POINT 420
+▁HELP 421
+▁COUNT 422
+▁STILL 423
+▁MR 424
+▁FIND 425
+▁PERSON 426
+▁CAME 427
+▁SAME 428
+▁LAST 429
+▁HIGH 430
+▁OLD 431
+▁UNDER 432
+▁FOUR 433
+▁AROUND 434
+▁SORT 435
+▁CHANGE 436
+▁YES 437
+SHIP 438
+▁ANOTHER 439
+ATIVE 440
+▁FOUND 441
+▁JA 442
+▁ALWAYS 443
+▁NEXT 444
+▁TURN 445
+▁JU 446
+▁SIX 447
+▁FACT 448
+▁INTEREST 449
+▁WORD 450
+▁THOUSAND 451
+▁HUNDRED 452
+▁NUMBER 453
+▁IDEA 454
+▁PLAN 455
+▁COURSE 456
+▁SCHOOL 457
+▁HOUSE 458
+▁TWENTY 459
+▁JE 460
+▁PLAY 461
+▁AWAY 462
+▁LEARN 463
+▁HARD 464
+▁WEEK 465
+▁BETTER 466
+▁WHILE 467
+▁FRIEND 468
+▁OKAY 469
+▁NINE 470
+▁UNDERSTAND 471
+▁KEEP 472
+▁GONNA 473
+▁SYSTEM 474
+▁AMERICA 475
+▁POWER 476
+▁IMPORTANT 477
+▁WITHOUT 478
+▁MAYBE 479
+▁SEVEN 480
+▁BETWEEN 481
+▁BUILD 482
+▁CERTAIN 483
+▁PROBLEM 484
+▁MONEY 485
+▁BELIEVE 486
+▁SECOND 487
+▁REASON 488
+▁TOGETHER 489
+▁PUBLIC 490
+▁ANYTHING 491
+▁SPEAK 492
+▁BUSINESS 493
+▁EVERYTHING 494
+▁CLOSE 495
+▁QUITE 496
+▁ANSWER 497
+▁ENOUGH 498
+Q 499
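giga-tokens.txt uses the usual icefall tokens.txt layout: one "SYMBOL ID" pair per line, with "▁" marking a word boundary in the BPE symbols; the GigaSpeech recognizer in model.py below points its tokens field at this file. A small illustrative reader (load_token_table is a hypothetical helper, not part of this commit):

# Illustrative reader for the "SYMBOL ID" layout of giga-tokens.txt.
# load_token_table is a hypothetical helper, not part of this commit.
def load_token_table(filename: str = "giga-tokens.txt") -> dict:
    id2sym = {}
    with open(filename, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            sym, idx = line.rsplit(maxsplit=1)  # symbol first, integer id last
            id2sym[int(idx)] = sym
    return id2sym

# Example: ids [5, 163] map to "▁THE" + "▁PEOPLE"; joining and replacing
# "▁" with a space yields "THE PEOPLE" after stripping.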
model.py CHANGED
@@ -16,23 +16,35 @@
 
 from huggingface_hub import hf_hub_download
 from functools import lru_cache
+import sherpa
 
 
-from offline_asr import OfflineAsr
-
 sample_rate = 16000
 
 
 @lru_cache(maxsize=30)
-def get_pretrained_model(repo_id: str) -> OfflineAsr:
+def get_pretrained_model(
+    repo_id: str,
+    decoding_method: str,
+    num_active_paths: int,
+) -> sherpa.OfflineRecognizer:
     if repo_id in chinese_models:
-        return chinese_models[repo_id](repo_id)
+        return chinese_models[repo_id](
+            repo_id, decoding_method=decoding_method, num_active_paths=num_active_paths
+        )
     elif repo_id in english_models:
-        return english_models[repo_id](repo_id)
+        return english_models[repo_id](
+            repo_id, decoding_method=decoding_method, num_active_paths=num_active_paths
+        )
     elif repo_id in chinese_english_mixed_models:
-        return chinese_english_mixed_models[repo_id](repo_id)
+        return chinese_english_mixed_models[repo_id](
+            repo_id, decoding_method=decoding_method, num_active_paths=num_active_paths
+        )
     elif repo_id in tibetan_models:
-        return tibetan_models[repo_id](repo_id)
+        return tibetan_models[repo_id](
+            repo_id, decoding_method=decoding_method, num_active_paths=num_active_paths
+        )
     else:
         raise ValueError(f"Unsupported repo_id: {repo_id}")
 
@@ -77,7 +89,11 @@ def _get_token_filename(
 
 
 @lru_cache(maxsize=10)
-def _get_aishell2_pretrained_model(repo_id: str) -> OfflineAsr:
+def _get_aishell2_pretrained_model(
+    repo_id: str,
+    decoding_method: str,
+    num_active_paths: int,
+) -> sherpa.OfflineRecognizer:
     assert repo_id in [
         # context-size 1
         "yuekai/icefall-asr-aishell2-pruned-transducer-stateless5-A-2022-07-12",  # noqa
@@ -85,44 +101,68 @@ def _get_aishell2_pretrained_model(repo_id: str) -> OfflineAsr:
         "yuekai/icefall-asr-aishell2-pruned-transducer-stateless5-B-2022-07-12",  # noqa
     ], repo_id
 
-    nn_model_filename = _get_nn_model_filename(
+    nn_model = _get_nn_model_filename(
         repo_id=repo_id,
         filename="cpu_jit.pt",
     )
-    token_filename = _get_token_filename(repo_id=repo_id)
-
-    return OfflineAsr(
-        nn_model_filename=nn_model_filename,
-        bpe_model_filename=None,
-        token_filename=token_filename,
-        sample_rate=sample_rate,
-        device="cpu",
+    tokens = _get_token_filename(repo_id=repo_id)
+
+    feat_config = sherpa.FeatureConfig()
+    feat_config.fbank_opts.frame_opts.samp_freq = sample_rate
+    feat_config.fbank_opts.mel_opts.num_bins = 80
+    feat_config.fbank_opts.frame_opts.dither = 0
+
+    config = sherpa.OfflineRecognizerConfig(
+        nn_model=nn_model,
+        tokens=tokens,
+        use_gpu=False,
+        feat_config=feat_config,
+        decoding_method=decoding_method,
+        num_active_paths=num_active_paths,
     )
 
+    recognizer = sherpa.OfflineRecognizer(config)
+
+    return recognizer
+
 
 @lru_cache(maxsize=10)
-def _get_gigaspeech_pre_trained_model(repo_id: str) -> OfflineAsr:
+def _get_gigaspeech_pre_trained_model(
+    repo_id: str,
+    decoding_method: str,
+    num_active_paths: int,
+) -> sherpa.OfflineRecognizer:
     assert repo_id in [
         "wgb14/icefall-asr-gigaspeech-pruned-transducer-stateless2",
     ], repo_id
 
-    nn_model_filename = _get_nn_model_filename(
+    nn_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="cpu_jit-iter-3488000-avg-20.pt",
     )
-    bpe_model_filename = _get_bpe_model_filename(repo_id=repo_id)
-
-    return OfflineAsr(
-        nn_model_filename=nn_model_filename,
-        bpe_model_filename=bpe_model_filename,
-        token_filename=None,
-        sample_rate=sample_rate,
-        device="cpu",
+    tokens = "./giga-tokens.txt"
+
+    feat_config = sherpa.FeatureConfig()
+    feat_config.fbank_opts.frame_opts.samp_freq = sample_rate
+    feat_config.fbank_opts.mel_opts.num_bins = 80
+    feat_config.fbank_opts.frame_opts.dither = 0
+
+    config = sherpa.OfflineRecognizerConfig(
+        nn_model=nn_model,
+        tokens=tokens,
+        use_gpu=False,
+        feat_config=feat_config,
+        decoding_method=decoding_method,
+        num_active_paths=num_active_paths,
     )
 
+    recognizer = sherpa.OfflineRecognizer(config)
+
+    return recognizer
+
 
 @lru_cache(maxsize=10)
-def _get_librispeech_pre_trained_model(repo_id: str) -> OfflineAsr:
+def _get_librispeech_pre_trained_model(
+    repo_id: str,
+    decoding_method: str,
+    num_active_paths: int,
+) -> sherpa.OfflineRecognizer:
     assert repo_id in [
         "WeijiZhuang/icefall-asr-librispeech-pruned-transducer-stateless8-2022-12-02",  # noqa
         "csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13",  # noqa
@@ -143,107 +183,172 @@ def _get_librispeech_pre_trained_model(repo_id: str) -> OfflineAsr:
     ):
         filename = "cpu_jit-torch-1.10.pt"
 
-    nn_model_filename = _get_nn_model_filename(
+    nn_model = _get_nn_model_filename(
         repo_id=repo_id,
         filename=filename,
     )
-    bpe_model_filename = _get_bpe_model_filename(repo_id=repo_id)
-
-    return OfflineAsr(
-        nn_model_filename=nn_model_filename,
-        bpe_model_filename=bpe_model_filename,
-        token_filename=None,
-        sample_rate=sample_rate,
-        device="cpu",
+    tokens = _get_token_filename(repo_id=repo_id, subfolder="data/lang_bpe_500")
+
+    feat_config = sherpa.FeatureConfig()
+    feat_config.fbank_opts.frame_opts.samp_freq = sample_rate
+    feat_config.fbank_opts.mel_opts.num_bins = 80
+    feat_config.fbank_opts.frame_opts.dither = 0
+
+    config = sherpa.OfflineRecognizerConfig(
+        nn_model=nn_model,
+        tokens=tokens,
+        use_gpu=False,
+        feat_config=feat_config,
+        decoding_method=decoding_method,
+        num_active_paths=num_active_paths,
     )
 
+    recognizer = sherpa.OfflineRecognizer(config)
+
+    return recognizer
+
 
 @lru_cache(maxsize=10)
-def _get_wenetspeech_pre_trained_model(repo_id: str):
+def _get_wenetspeech_pre_trained_model(
+    repo_id: str,
+    decoding_method: str,
+    num_active_paths: int,
+):
     assert repo_id in [
         "luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2",
     ], repo_id
 
-    nn_model_filename = _get_nn_model_filename(
+    nn_model = _get_nn_model_filename(
         repo_id=repo_id,
         filename="cpu_jit_epoch_10_avg_2_torch_1.7.1.pt",
     )
-    token_filename = _get_token_filename(repo_id=repo_id)
-
-    return OfflineAsr(
-        nn_model_filename=nn_model_filename,
-        bpe_model_filename=None,
-        token_filename=token_filename,
-        sample_rate=sample_rate,
-        device="cpu",
+    tokens = _get_token_filename(repo_id=repo_id)
+
+    feat_config = sherpa.FeatureConfig()
+    feat_config.fbank_opts.frame_opts.samp_freq = sample_rate
+    feat_config.fbank_opts.mel_opts.num_bins = 80
+    feat_config.fbank_opts.frame_opts.dither = 0
+
+    config = sherpa.OfflineRecognizerConfig(
+        nn_model=nn_model,
+        tokens=tokens,
+        use_gpu=False,
+        feat_config=feat_config,
+        decoding_method=decoding_method,
+        num_active_paths=num_active_paths,
     )
 
+    recognizer = sherpa.OfflineRecognizer(config)
+
+    return recognizer
+
 
 @lru_cache(maxsize=10)
-def _get_tal_csasr_pre_trained_model(repo_id: str):
+def _get_tal_csasr_pre_trained_model(
+    repo_id: str,
+    decoding_method: str,
+    num_active_paths: int,
+):
     assert repo_id in [
         "luomingshuang/icefall_asr_tal-csasr_pruned_transducer_stateless5",
     ], repo_id
 
-    nn_model_filename = _get_nn_model_filename(
+    nn_model = _get_nn_model_filename(
         repo_id=repo_id,
         filename="cpu_jit.pt",
     )
-    token_filename = _get_token_filename(repo_id=repo_id)
-
-    return OfflineAsr(
-        nn_model_filename=nn_model_filename,
-        bpe_model_filename=None,
-        token_filename=token_filename,
-        sample_rate=sample_rate,
-        device="cpu",
+    tokens = _get_token_filename(repo_id=repo_id)
+
+    feat_config = sherpa.FeatureConfig()
+    feat_config.fbank_opts.frame_opts.samp_freq = sample_rate
+    feat_config.fbank_opts.mel_opts.num_bins = 80
+    feat_config.fbank_opts.frame_opts.dither = 0
+
+    config = sherpa.OfflineRecognizerConfig(
+        nn_model=nn_model,
+        tokens=tokens,
+        use_gpu=False,
+        feat_config=feat_config,
+        decoding_method=decoding_method,
+        num_active_paths=num_active_paths,
     )
 
+    recognizer = sherpa.OfflineRecognizer(config)
+
+    return recognizer
+
 
 @lru_cache(maxsize=10)
-def _get_alimeeting_pre_trained_model(repo_id: str):
+def _get_alimeeting_pre_trained_model(
+    repo_id: str,
+    decoding_method: str,
+    num_active_paths: int,
+):
     assert repo_id in [
         "luomingshuang/icefall_asr_alimeeting_pruned_transducer_stateless2",
     ], repo_id
 
-    nn_model_filename = _get_nn_model_filename(
+    nn_model = _get_nn_model_filename(
         repo_id=repo_id,
         filename="cpu_jit_torch_1.7.1.pt",
     )
-    token_filename = _get_token_filename(repo_id=repo_id)
-
-    return OfflineAsr(
-        nn_model_filename=nn_model_filename,
-        bpe_model_filename=None,
-        token_filename=token_filename,
-        sample_rate=sample_rate,
-        device="cpu",
+    tokens = _get_token_filename(repo_id=repo_id)
+
+    feat_config = sherpa.FeatureConfig()
+    feat_config.fbank_opts.frame_opts.samp_freq = sample_rate
+    feat_config.fbank_opts.mel_opts.num_bins = 80
+    feat_config.fbank_opts.frame_opts.dither = 0
+
+    config = sherpa.OfflineRecognizerConfig(
+        nn_model=nn_model,
+        tokens=tokens,
+        use_gpu=False,
+        feat_config=feat_config,
+        decoding_method=decoding_method,
+        num_active_paths=num_active_paths,
     )
 
+    recognizer = sherpa.OfflineRecognizer(config)
+
+    return recognizer
+
 
 @lru_cache(maxsize=10)
-def _get_aidatatang_200zh_pretrained_mode(repo_id: str):
+def _get_aidatatang_200zh_pretrained_mode(
+    repo_id: str,
+    decoding_method: str,
+    num_active_paths: int,
+):
     assert repo_id in [
         "luomingshuang/icefall_asr_aidatatang-200zh_pruned_transducer_stateless2",
     ], repo_id
 
-    nn_model_filename = _get_nn_model_filename(
+    nn_model = _get_nn_model_filename(
         repo_id=repo_id,
         filename="cpu_jit_torch.1.7.1.pt",
     )
-    token_filename = _get_token_filename(repo_id=repo_id)
-
-    return OfflineAsr(
-        nn_model_filename=nn_model_filename,
-        bpe_model_filename=None,
-        token_filename=token_filename,
-        sample_rate=sample_rate,
-        device="cpu",
+    tokens = _get_token_filename(repo_id=repo_id)
+
+    feat_config = sherpa.FeatureConfig()
+    feat_config.fbank_opts.frame_opts.samp_freq = sample_rate
+    feat_config.fbank_opts.mel_opts.num_bins = 80
+    feat_config.fbank_opts.frame_opts.dither = 0
+
+    config = sherpa.OfflineRecognizerConfig(
+        nn_model=nn_model,
+        tokens=tokens,
+        use_gpu=False,
+        feat_config=feat_config,
+        decoding_method=decoding_method,
+        num_active_paths=num_active_paths,
    )
 
+    recognizer = sherpa.OfflineRecognizer(config)
+
+    return recognizer
+
 
 @lru_cache(maxsize=10)
-def _get_tibetan_pre_trained_model(repo_id: str):
+def _get_tibetan_pre_trained_model(
+    repo_id: str,
+    decoding_method: str,
+    num_active_paths: int,
+):
     assert repo_id in [
         "syzym/icefall-asr-xbmu-amdo31-pruned-transducer-stateless7-2022-12-02",
         "syzym/icefall-asr-xbmu-amdo31-pruned-transducer-stateless5-2022-11-29",
@@ -254,21 +359,33 @@ def _get_tibetan_pre_trained_model(repo_id: str):
         repo_id
         == "syzym/icefall-asr-xbmu-amdo31-pruned-transducer-stateless5-2022-11-29"
     ):
-        nn_model_filename = _get_nn_model_filename(
-            repo_id=repo_id,
-            filename="cpu_jit-epoch-28-avg-23-torch-1.10.0.pt",
-        )
+        filename = "cpu_jit-epoch-28-avg-23-torch-1.10.0.pt"
+
+    nn_model = _get_nn_model_filename(
+        repo_id=repo_id,
+        filename=filename,
+    )
+
+    tokens = _get_token_filename(repo_id=repo_id, subfolder="data/lang_bpe_500")
 
-    bpe_model_filename = _get_bpe_model_filename(repo_id=repo_id)
+    feat_config = sherpa.FeatureConfig()
+    feat_config.fbank_opts.frame_opts.samp_freq = sample_rate
+    feat_config.fbank_opts.mel_opts.num_bins = 80
+    feat_config.fbank_opts.frame_opts.dither = 0
 
-    return OfflineAsr(
-        nn_model_filename=nn_model_filename,
-        bpe_model_filename=bpe_model_filename,
-        token_filename=None,
-        sample_rate=sample_rate,
-        device="cpu",
+    config = sherpa.OfflineRecognizerConfig(
+        nn_model=nn_model,
+        tokens=tokens,
+        use_gpu=False,
+        feat_config=feat_config,
+        decoding_method=decoding_method,
+        num_active_paths=num_active_paths,
     )
 
+    recognizer = sherpa.OfflineRecognizer(config)
+
+    return recognizer
+
 
 chinese_models = {
     "luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2": _get_wenetspeech_pre_trained_model,  # noqa
offline_asr.py DELETED
@@ -1,427 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
-#
-# Copied from https://github.com/k2-fsa/sherpa/blob/master/sherpa/bin/conformer_rnnt/offline_asr.py
-#
-# See LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-A standalone script for offline ASR recognition.
-
-It loads a torchscript model, decodes the given wav files, and exits.
-
-Usage:
-    ./offline_asr.py --help
-
-For BPE based models (e.g., LibriSpeech):
-
-    ./offline_asr.py \
-        --nn-model-filename /path/to/cpu_jit.pt \
-        --bpe-model-filename /path/to/bpe.model \
-        --decoding-method greedy_search \
-        ./foo.wav \
-        ./bar.wav \
-        ./foobar.wav
-
-For character based models (e.g., aishell):
-
-    ./offline.py \
-        --nn-model-filename /path/to/cpu_jit.pt \
-        --token-filename /path/to/lang_char/tokens.txt \
-        --decoding-method greedy_search \
-        ./foo.wav \
-        ./bar.wav \
-        ./foobar.wav
-
-Note: We provide pre-trained models for testing.
-
-(1) Pre-trained model with the LibriSpeech dataset
-
-    sudo apt-get install git-lfs
-    git lfs install
-    git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
-
-    nn_model_filename=./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/cpu_jit-torch-1.6.0.pt
-    bpe_model=./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/bpe.model
-
-    wav1=./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav
-    wav2=./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav
-    wav3=./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav
-
-    sherpa/bin/conformer_rnnt/offline_asr.py \
-        --nn-model-filename $nn_model_filename \
-        --bpe-model $bpe_model \
-        $wav1 \
-        $wav2 \
-        $wav3
-
-(2) Pre-trained model with the aishell dataset
-
-    sudo apt-get install git-lfs
-    git lfs install
-    git clone https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20
-
-    nn_model_filename=./icefall-aishell-pruned-transducer-stateless3-2022-06-20/exp/cpu_jit-epoch-29-avg-5-torch-1.6.0.pt
-    token_filename=./icefall-aishell-pruned-transducer-stateless3-2022-06-20/data/lang_char/tokens.txt
-
-    wav1=./icefall-aishell-pruned-transducer-stateless3-2022-06-20/test_wavs/BAC009S0764W0121.wav
-    wav2=./icefall-aishell-pruned-transducer-stateless3-2022-06-20/test_wavs/BAC009S0764W0122.wav
-    wav3=./icefall-aishell-pruned-transducer-stateless3-2022-06-20/test_wavs/BAC009S0764W0123.wav
-
-    sherpa/bin/conformer_rnnt/offline_asr.py \
-        --nn-model-filename $nn_model_filename \
-        --token-filename $token_filename \
-        $wav1 \
-        $wav2 \
-        $wav3
-"""
-import argparse
-import functools
-import logging
-from typing import List, Optional, Union
-
-import k2
-import kaldifeat
-import sentencepiece as spm
-import torch
-import torchaudio
-from sherpa import RnntConformerModel
-
-from decode import run_model_and_do_greedy_search, run_model_and_do_modified_beam_search
-
-
-def get_args():
-    parser = argparse.ArgumentParser(
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter
-    )
-
-    parser.add_argument(
-        "--nn-model-filename",
-        type=str,
-        help="""The torchscript model. You can use
-        icefall/egs/librispeech/ASR/pruned_transducer_statelessX/export.py \
-                --jit=1
-        to generate this model.
-        """,
-    )
-
-    parser.add_argument(
-        "--bpe-model-filename",
-        type=str,
-        help="""The BPE model
-        You can find it in the directory egs/librispeech/ASR/data/lang_bpe_xxx
-        from icefall,
-        where xxx is the number of BPE tokens you used to train the model.
-        Note: Use it only when your model is using BPE. You don't need to
-        provide it if you provide `--token-filename`
-        """,
-    )
-
-    parser.add_argument(
-        "--token-filename",
-        type=str,
-        help="""Filename for tokens.txt
-        You can find it in the directory
-        egs/aishell/ASR/data/lang_char/tokens.txt from icefall.
-        Note: You don't need to provide it if you provide `--bpe-model`
-        """,
-    )
-
-    parser.add_argument(
-        "--decoding-method",
-        type=str,
-        default="greedy_search",
-        help="""Decoding method to use. Currently, only greedy_search and
-        modified_beam_search are implemented.
-        """,
-    )
-
-    parser.add_argument(
-        "--num-active-paths",
-        type=int,
-        default=4,
-        help="""Used only when decoding_method is modified_beam_search.
-        It specifies number of active paths for each utterance. Due to
-        merging paths with identical token sequences, the actual number
-        may be less than "num_active_paths".
-        """,
-    )
-
-    parser.add_argument(
-        "--sample-rate",
-        type=int,
-        default=16000,
-        help="The expected sample rate of the input sound files",
-    )
-
-    parser.add_argument(
-        "sound_files",
-        type=str,
-        nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to equal to `--sample-rate`.",
-    )
-
-    return parser.parse_args()
-
-
-def read_sound_files(
-    filenames: List[str],
-    expected_sample_rate: int,
-) -> List[torch.Tensor]:
-    """Read a list of sound files into a list 1-D float32 torch tensors.
-    Args:
-      filenames:
-        A list of sound filenames.
-      expected_sample_rate:
-        The expected sample rate of the sound files.
-    Returns:
-      Return a list of 1-D float32 torch tensors.
-    """
-    ans = []
-    for f in filenames:
-        wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
-        # We use only the first channel
-        ans.append(wave[0])
-    return ans
-
-
-class OfflineAsr(object):
-    def __init__(
-        self,
-        nn_model_filename: str,
-        bpe_model_filename: Optional[str] = None,
-        token_filename: Optional[str] = None,
-        decoding_method: str = "greedy_search",
-        num_active_paths: int = 4,
-        sample_rate: int = 16000,
-        device: Union[str, torch.device] = "cpu",
-    ):
-        """
-        Args:
-          nn_model_filename:
-            Path to the torch script model.
-          bpe_model_filename:
-            Path to the BPE model. If it is None, you have to provide
-            `token_filename`.
-          token_filename:
-            Path to tokens.txt. If it is None, you have to provide
-            `bpe_model_filename`.
-          sample_rate:
-            Expected sample rate of the feature extractor.
-          device:
-            The device to use for computation.
-        """
-        self.model = RnntConformerModel(
-            filename=nn_model_filename,
-            device=device,
-            optimize_for_inference=False,
-        )
-
-        if bpe_model_filename:
-            self.sp = spm.SentencePieceProcessor()
-            self.sp.load(bpe_model_filename)
-        else:
-            assert token_filename is not None, token_filename
-            self.token_table = k2.SymbolTable.from_file(token_filename)
-
-        self.feature_extractor = self._build_feature_extractor(
-            sample_rate=sample_rate,
-            device=device,
-        )
-
-        self.device = device
-
-    def _build_feature_extractor(
-        self,
-        sample_rate: int = 16000,
-        device: Union[str, torch.device] = "cpu",
-    ) -> kaldifeat.OfflineFeature:
-        """Build a fbank feature extractor for extracting features.
-
-        Args:
-          sample_rate:
-            Expected sample rate of the feature extractor.
-          device:
-            The device to use for computation.
-        Returns:
-          Return a fbank feature extractor.
-        """
-        opts = kaldifeat.FbankOptions()
-        opts.device = device
-        opts.frame_opts.dither = 0
-        opts.frame_opts.snip_edges = False
-        opts.frame_opts.samp_freq = sample_rate
-        opts.mel_opts.num_bins = 80
-
-        fbank = kaldifeat.Fbank(opts)
-
-        return fbank
-
-    def decode_waves(
-        self,
-        waves: List[torch.Tensor],
-        decoding_method: str,
-        num_active_paths: int,
-    ) -> List[List[str]]:
-        """
-        Args:
-          waves:
-            A list of 1-D torch.float32 tensors containing audio samples.
-            wavs[i] contains audio samples for the i-th utterance.
-
-            Note:
-              Whether it should be in the range [-32768, 32767] or be normalized
-              to [-1, 1] depends on which range you used for your training data.
-              For instance, if your training data used [-32768, 32767],
-              then the given waves have to contain samples in this range.
-
-              All models trained in icefall use the normalized range [-1, 1].
-          decoding_method:
-            The decoding method to use. Currently, only greedy_search and
-            modified_beam_search are implemented.
-          num_active_paths:
-            Used only when decoding_method is modified_beam_search.
-            It specifies number of active paths for each utterance. Due to
-            merging paths with identical token sequences, the actual number
-            may be less than "num_active_paths".
-        Returns:
-          Return a list of decoded results. `ans[i]` contains the decoded
-          results for `wavs[i]`.
-        """
-        assert decoding_method in (
-            "greedy_search",
-            "modified_beam_search",
-        ), decoding_method
-
-        if decoding_method == "greedy_search":
-            nn_and_decoding_func = run_model_and_do_greedy_search
-        elif decoding_method == "modified_beam_search":
-            nn_and_decoding_func = functools.partial(
-                run_model_and_do_modified_beam_search,
-                num_active_paths=num_active_paths,
-            )
-        else:
-            raise ValueError(
-                f"Unsupported decoding_method: {decoding_method} "
-                "Please use greedy_search or modified_beam_search"
-            )
-
-        waves = [w.to(self.device) for w in waves]
-        features = self.feature_extractor(waves)
-
-        tokens = nn_and_decoding_func(self.model, features)
-
-        if hasattr(self, "sp"):
-            results = self.sp.decode(tokens)
-        else:
-            results = [[self.token_table[i] for i in hyp] for hyp in tokens]
-            blank = chr(0x2581)
-            results = ["".join(r) for r in results]
-            results = [r.replace(blank, " ") for r in results]
-
-        return results
-
-
-@torch.no_grad()
-def main():
-    args = get_args()
-    logging.info(vars(args))
-
-    nn_model_filename = args.nn_model_filename
-    bpe_model_filename = args.bpe_model_filename
-    token_filename = args.token_filename
-    decoding_method = args.decoding_method
-    num_active_paths = args.num_active_paths
-    sample_rate = args.sample_rate
-    sound_files = args.sound_files
-
-    assert decoding_method in ("greedy_search", "modified_beam_search"), decoding_method
-
-    if decoding_method == "modified_beam_search":
-        assert num_active_paths >= 1, num_active_paths
-
-    if bpe_model_filename:
-        assert token_filename is None
-
-    if token_filename:
-        assert bpe_model_filename is None
-
-    device = torch.device("cpu")
-    if torch.cuda.is_available():
-        device = torch.device("cuda", 0)
-
-    logging.info(f"device: {device}")
-
-    offline_asr = OfflineAsr(
-        nn_model_filename=nn_model_filename,
-        bpe_model_filename=bpe_model_filename,
-        token_filename=token_filename,
-        decoding_method=decoding_method,
-        num_active_paths=num_active_paths,
-        sample_rate=sample_rate,
-        device=device,
-    )
-
-    waves = read_sound_files(
-        filenames=sound_files,
-        expected_sample_rate=sample_rate,
-    )
-
-    logging.info("Decoding started.")
-
-    hyps = offline_asr.decode_waves(waves)
-
-    s = "\n"
-    for filename, hyp in zip(sound_files, hyps):
-        s += f"{filename}:\n{hyp}\n\n"
-    logging.info(s)
-
-    logging.info("Decoding done.")
-
-
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
-# See https://github.com/pytorch/pytorch/issues/38342
-# and https://github.com/pytorch/pytorch/issues/33354
-#
-# If we don't do this, the delay increases whenever there is
-# a new request that changes the actual batch size.
-# If you use `py-spy dump --pid <server-pid> --native`, you will
-# see a lot of time is spent in re-compiling the torch script model.
-torch._C._jit_set_profiling_executor(False)
-torch._C._jit_set_profiling_mode(False)
-torch._C._set_graph_executor_optimize(False)
-"""
-// Use the following in C++
-torch::jit::getExecutorMode() = false;
-torch::jit::getProfilingMode() = false;
-torch::jit::setGraphExecutorOptimize(false);
-"""
-
-if __name__ == "__main__":
-    torch.manual_seed(20220609)
-
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"  # noqa
-    )
-    logging.basicConfig(format=formatter, level=logging.INFO)
-
-    main()