csukuangfj committed on
Commit
b0eec9a
1 Parent(s): 5decad3

Add Japanese models

Browse files
app.py CHANGED
@@ -30,7 +30,7 @@ import torch
30
  import torchaudio
31
 
32
  from examples import examples
33
- from model import get_pretrained_model, language_to_models, sample_rate
34
 
35
  languages = list(language_to_models.keys())
36
 
@@ -146,12 +146,8 @@ def process(
146
  decoding_method=decoding_method,
147
  num_active_paths=num_active_paths,
148
  )
149
- s = recognizer.create_stream()
150
 
151
- s.accept_wave_file(filename)
152
- recognizer.decode_stream(s)
153
-
154
- text = s.result.text.strip()
155
 
156
  date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
157
  end = time.time()
 
30
  import torchaudio
31
 
32
  from examples import examples
33
+ from model import decode, get_pretrained_model, language_to_models, sample_rate
34
 
35
  languages = list(language_to_models.keys())
36
 
 
146
  decoding_method=decoding_method,
147
  num_active_paths=num_active_paths,
148
  )
 
149
 
150
+ text = decode(recognizer, filename)
 
 
 
151
 
152
  date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
153
  end = time.time()
model.py CHANGED
@@ -17,6 +17,7 @@
17
  from huggingface_hub import hf_hub_download
18
  from functools import lru_cache
19
  import os
 
20
 
21
  os.system(
22
  "cp -v /home/user/.local/lib/python3.8/site-packages/k2/lib/*.so /home/user/.local/lib/python3.8/site-packages/sherpa/lib/"
@@ -29,6 +30,56 @@ import sherpa
29
  sample_rate = 16000
30
 
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  @lru_cache(maxsize=30)
33
  def get_pretrained_model(
34
  repo_id: str,
@@ -547,6 +598,55 @@ def _get_german_pre_trained_model(
547
  return recognizer
548
 
549
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
550
  chinese_models = {
551
  "luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2": _get_wenetspeech_pre_trained_model, # noqa
552
  "desh2608/icefall-asr-alimeeting-pruned-transducer-stateless7": _get_alimeeting_pre_trained_model,
@@ -555,6 +655,7 @@ chinese_models = {
555
  "luomingshuang/icefall_asr_aidatatang-200zh_pruned_transducer_stateless2": _get_aidatatang_200zh_pretrained_mode, # noqa
556
  "luomingshuang/icefall_asr_alimeeting_pruned_transducer_stateless2": _get_alimeeting_pre_trained_model, # noqa
557
  "csukuangfj/wenet-chinese-model": _get_wenet_model,
 
558
  }
559
 
560
  english_models = {
@@ -587,10 +688,16 @@ german_models = {
587
  "csukuangfj/wav2vec2.0-torchaudio": _get_german_pre_trained_model,
588
  }
589
 
 
 
 
 
 
590
  all_models = {
591
  **chinese_models,
592
  **english_models,
593
  **chinese_english_mixed_models,
 
594
  **tibetan_models,
595
  **arabic_models,
596
  **german_models,
@@ -600,6 +707,7 @@ language_to_models = {
600
  "Chinese": list(chinese_models.keys()),
601
  "English": list(english_models.keys()),
602
  "Chinese+English": list(chinese_english_mixed_models.keys()),
 
603
  "Tibetan": list(tibetan_models.keys()),
604
  "Arabic": list(arabic_models.keys()),
605
  "German": list(german_models.keys()),
 
17
  from huggingface_hub import hf_hub_download
18
  from functools import lru_cache
19
  import os
20
+ import torchaudio
21
 
22
  os.system(
23
  "cp -v /home/user/.local/lib/python3.8/site-packages/k2/lib/*.so /home/user/.local/lib/python3.8/site-packages/sherpa/lib/"
 
30
  sample_rate = 16000
31
 
32
 
33
+ def decode_offline_recognizer(
34
+ recognizer: Union[sherpa.OfflineRecognizer, sherpa.OnlineRecognizer],
35
+ filename: str,
36
+ ) -> str:
37
+ s = recognizer.create_stream()
38
+
39
+ s.accept_wave_file(filename)
40
+ recognizer.decode_stream(s)
41
+
42
+ text = s.result.text.strip()
43
+ return text.lower()
44
+
45
+
46
+ def decode_online_recognizer(
47
+ recognizer: Union[sherpa.OfflineRecognizer, sherpa.OnlineRecognizer],
48
+ filename: str,
49
+ ) -> str:
50
+ samples, actual_sample_rate = torchaudio.load(filename)
51
+ assert sample_rate == actual_sample_rate, (
52
+ sample_rate,
53
+ actual_sample_rate,
54
+ )
55
+ samples = samples[0].contiguous()
56
+
57
+ s = recognizer.create_stream()
58
+
59
+ tail_padding = torch.zeros(int(sample_rate * 0.3), dtype=torch.float32)
60
+ s.accept_waveform(sample_rate, samples)
61
+ s.accept_waveform(sample_rate, tail_padding)
62
+ s.input_finished()
63
+
64
+ while recognizer.is_ready(s):
65
+ recognizer.decode_stream(s)
66
+
67
+ text = recognizer.get_result(s).text
68
+ return text.strip().lower()
69
+
70
+
71
+ def decode(
72
+ recognizer: Union[sherpa.OfflineRecognizer, sherpa.OnlineRecognizer],
73
+ filename: str,
74
+ ) -> str:
75
+ if isinstance(recognizer, sherpa.OfflineRecognizer):
76
+ return decode_offline_recognizer(recognizer, filename)
77
+ elif isinstance(recognizer, sherpa.OnlineRecognizer):
78
+ return decode_online_recognizer(recognizer, filename)
79
+ else:
80
+ raise ValueError(f"Unknown recongizer type {type(recognizer)}")
81
+
82
+
83
  @lru_cache(maxsize=30)
84
  def get_pretrained_model(
85
  repo_id: str,
 
598
  return recognizer
599
 
600
 
601
+ @lru_cache(maxsize=10)
602
+ def _get_japanese_pre_trained_model(
603
+ repo_id: str,
604
+ decoding_method: str,
605
+ num_active_paths: int,
606
+ ):
607
+ repo_id, kind = repo_id.rsplit("-", maxsplit=1)
608
+
609
+ assert repo_id in [
610
+ "TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208"
611
+ ], repo_id
612
+ assert kind in ("fluent", "disfluent"), kind
613
+
614
+ encoder_model = _get_nn_model_filename(
615
+ repo_id=repo_id, filename="encoder_jit_trace.pt", subfolder=f"exp_{kind}"
616
+ )
617
+
618
+ decoder_model = _get_nn_model_filename(
619
+ repo_id=repo_id, filename="decoder_jit_trace.pt", subfolder=f"exp_{kind}"
620
+ )
621
+
622
+ joiner_model = _get_nn_model_filename(
623
+ repo_id=repo_id, filename="joiner_jit_trace.pt", subfolder=f"exp_{kind}"
624
+ )
625
+
626
+ tokens = _get_token_filename(repo_id=repo_id)
627
+
628
+ feat_config = sherpa.FeatureConfig()
629
+ feat_config.fbank_opts.frame_opts.samp_freq = sample_rate
630
+ feat_config.fbank_opts.mel_opts.num_bins = 80
631
+ feat_config.fbank_opts.frame_opts.dither = 0
632
+
633
+ config = sherpa.OnlineRecognizerConfig(
634
+ nn_model="",
635
+ encoder_model=encoder_model,
636
+ decoder_model=decoder_model,
637
+ joiner_model=joiner_model,
638
+ tokens=tokens,
639
+ use_gpu=False,
640
+ feat_config=feat_config,
641
+ decoding_method="greedy_search",
642
+ chunk_size=32,
643
+ )
644
+
645
+ recognizer = sherpa.OnlineRecognizer(config)
646
+
647
+ return recognizer
648
+
649
+
650
  chinese_models = {
651
  "luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2": _get_wenetspeech_pre_trained_model, # noqa
652
  "desh2608/icefall-asr-alimeeting-pruned-transducer-stateless7": _get_alimeeting_pre_trained_model,
 
655
  "luomingshuang/icefall_asr_aidatatang-200zh_pruned_transducer_stateless2": _get_aidatatang_200zh_pretrained_mode, # noqa
656
  "luomingshuang/icefall_asr_alimeeting_pruned_transducer_stateless2": _get_alimeeting_pre_trained_model, # noqa
657
  "csukuangfj/wenet-chinese-model": _get_wenet_model,
658
+ "csukuangfj/icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14": _get_lstm_transducer_model,
659
  }
660
 
661
  english_models = {
 
688
  "csukuangfj/wav2vec2.0-torchaudio": _get_german_pre_trained_model,
689
  }
690
 
691
+ japanese_models = {
692
+ "TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208-fluent": _get_japanese_pre_trained_model,
693
+ "TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208-disfluent": _get_japanese_pre_trained_model,
694
+ }
695
+
696
  all_models = {
697
  **chinese_models,
698
  **english_models,
699
  **chinese_english_mixed_models,
700
+ **japanese_models,
701
  **tibetan_models,
702
  **arabic_models,
703
  **german_models,
 
707
  "Chinese": list(chinese_models.keys()),
708
  "English": list(english_models.keys()),
709
  "Chinese+English": list(chinese_english_mixed_models.keys()),
710
+ "Japanese": list(japanese_models.keys()),
711
  "Tibetan": list(tibetan_models.keys()),
712
  "Arabic": list(arabic_models.keys()),
713
  "German": list(german_models.keys()),
test_wavs/alimeeting/165.wav ADDED
Binary file (263 kB). View file
 
test_wavs/alimeeting/209.wav ADDED
Binary file (155 kB). View file
 
test_wavs/alimeeting/74.wav ADDED
Binary file (120 kB). View file