csukuangfj commited on
Commit
39b3b3e
1 Parent(s): 991cd55

small fixes

Browse files
Files changed (3) hide show
  1. app.py +27 -18
  2. model.py +159 -22
  3. offline_asr.py +5 -4
app.py CHANGED
@@ -26,16 +26,9 @@ from datetime import datetime
26
  import gradio as gr
27
  import torchaudio
28
 
29
- from model import (
30
- get_gigaspeech_pre_trained_model,
31
- sample_rate,
32
- get_wenetspeech_pre_trained_model,
33
- )
34
 
35
- models = {
36
- "Chinese": get_wenetspeech_pre_trained_model(),
37
- "English": get_gigaspeech_pre_trained_model(),
38
- }
39
 
40
 
41
  def convert_to_wav(in_filename: str) -> str:
@@ -46,12 +39,10 @@ def convert_to_wav(in_filename: str) -> str:
46
  return out_filename
47
 
48
 
49
- demo = gr.Blocks()
50
-
51
-
52
- def process(in_filename: str, language: str) -> str:
53
  print("in_filename", in_filename)
54
  print("language", language)
 
55
  filename = convert_to_wav(in_filename)
56
 
57
  now = datetime.now()
@@ -74,7 +65,7 @@ def process(in_filename: str, language: str) -> str:
74
  )
75
  wave = wave[0] # use only the first channel.
76
 
77
- hyp = models[language].decode_waves([wave])[0]
78
 
79
  date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
80
  end = time.time()
@@ -103,14 +94,32 @@ See more information by visiting the following links:
103
  - <https://github.com/lhotse-speech/lhotse>
104
  """
105
 
 
 
 
 
 
 
 
 
 
 
 
106
  with demo:
107
  gr.Markdown(title)
108
  gr.Markdown(description)
109
- language_choices = list(models.keys())
110
- language = gr.inputs.Radio(
 
111
  label="Language",
112
  choices=language_choices,
113
  )
 
 
 
 
 
 
114
 
115
  with gr.Tabs():
116
  with gr.TabItem("Upload from disk"):
@@ -140,12 +149,12 @@ with demo:
140
 
141
  upload_button.click(
142
  process,
143
- inputs=[uploaded_file, language],
144
  outputs=uploaded_output,
145
  )
146
  record_button.click(
147
  process,
148
- inputs=[microphone, language],
149
  outputs=recorded_output,
150
  )
151
 
 
26
  import gradio as gr
27
  import torchaudio
28
 
29
+ from model import get_pretrained_model, language_to_models, sample_rate
 
 
 
 
30
 
31
+ languages = sorted(language_to_models.keys())
 
 
 
32
 
33
 
34
  def convert_to_wav(in_filename: str) -> str:
 
39
  return out_filename
40
 
41
 
42
+ def process(in_filename: str, language: str, repo_id: str) -> str:
 
 
 
43
  print("in_filename", in_filename)
44
  print("language", language)
45
+ print("repo_id", repo_id)
46
  filename = convert_to_wav(in_filename)
47
 
48
  now = datetime.now()
 
65
  )
66
  wave = wave[0] # use only the first channel.
67
 
68
+ hyp = get_pretrained_model(repo_id).decode_waves([wave])[0]
69
 
70
  date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
71
  end = time.time()
 
94
  - <https://github.com/lhotse-speech/lhotse>
95
  """
96
 
97
+
98
+ def update_model_dropdown(language: str):
99
+ if language in language_to_models:
100
+ choices = language_to_models[language]
101
+ return gr.Dropdown.update(choices=choices, value=choices[0])
102
+
103
+ raise ValueError(f"Unsupported language: {language}")
104
+
105
+
106
+ demo = gr.Blocks()
107
+
108
  with demo:
109
  gr.Markdown(title)
110
  gr.Markdown(description)
111
+ language_choices = list(language_to_models.keys())
112
+
113
+ language_radio = gr.Radio(
114
  label="Language",
115
  choices=language_choices,
116
  )
117
+ model_dropdown = gr.Dropdown(choices=[], label="Select a model")
118
+ language_radio.change(
119
+ update_model_dropdown,
120
+ inputs=language_radio,
121
+ outputs=model_dropdown,
122
+ )
123
 
124
  with gr.Tabs():
125
  with gr.TabItem("Upload from disk"):
 
149
 
150
  upload_button.click(
151
  process,
152
+ inputs=[uploaded_file, language_radio, model_dropdown],
153
  outputs=uploaded_output,
154
  )
155
  record_button.click(
156
  process,
157
+ inputs=[microphone, language_radio, model_dropdown],
158
  outputs=recorded_output,
159
  )
160
 
model.py CHANGED
@@ -23,52 +23,189 @@ from offline_asr import OfflineAsr
23
  sample_rate = 16000
24
 
25
 
26
- @lru_cache(maxsize=1)
27
- def get_gigaspeech_pre_trained_model():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  nn_model_filename = hf_hub_download(
29
- # It is converted from https://huggingface.co/wgb14/icefall-asr-gigaspeech-pruned-transducer-stateless2
30
- repo_id="csukuangfj/icefall-asr-gigaspeech-pruned-transducer-stateless2",
31
- filename="cpu_jit-epoch-29-avg-11-torch-1.10.0.pt",
32
- subfolder="exp",
33
  )
 
 
34
 
 
 
 
 
 
35
  bpe_model_filename = hf_hub_download(
36
- repo_id="wgb14/icefall-asr-gigaspeech-pruned-transducer-stateless2",
37
- filename="bpe.model",
38
- subfolder="data/lang_bpe_500",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  )
 
40
 
41
  return OfflineAsr(
42
  nn_model_filename=nn_model_filename,
43
  bpe_model_filename=bpe_model_filename,
44
  token_filename=None,
45
- decoding_method="greedy_search",
46
- num_active_paths=4,
47
  sample_rate=sample_rate,
48
  device="cpu",
49
  )
50
 
51
 
52
- @lru_cache(maxsize=1)
53
- def get_wenetspeech_pre_trained_model():
54
- nn_model_filename = hf_hub_download(
55
- repo_id="luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  filename="cpu_jit_epoch_10_avg_2_torch_1.7.1.pt",
57
- subfolder="exp",
58
  )
 
59
 
60
- token_filename = hf_hub_download(
61
- repo_id="luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2",
62
- filename="tokens.txt",
63
- subfolder="data/lang_char",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  )
 
65
 
66
  return OfflineAsr(
67
  nn_model_filename=nn_model_filename,
68
  bpe_model_filename=None,
69
  token_filename=token_filename,
70
- decoding_method="greedy_search",
71
- num_active_paths=4,
72
  sample_rate=sample_rate,
73
  device="cpu",
74
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  sample_rate = 16000
24
 
25
 
26
+ @lru_cache(maxsize=30)
27
+ def get_pretrained_model(repo_id: str) -> OfflineAsr:
28
+ if repo_id in chinese_models:
29
+ return chinese_models[repo_id](repo_id)
30
+ elif repo_id in english_models:
31
+ return english_models[repo_id](repo_id)
32
+ elif repo_id in chinese_english_mixed_models:
33
+ chinese_english_mixed_models[repo_id](repo_id)
34
+ else:
35
+ raise ValueError(f"Unsupported repo_id: {repo_id}")
36
+
37
+
38
+ def _get_nn_model_filename(
39
+ repo_id: str,
40
+ filename: str,
41
+ subfolder: str = "exp",
42
+ ) -> str:
43
  nn_model_filename = hf_hub_download(
44
+ repo_id=repo_id,
45
+ filename=filename,
46
+ subfolder=subfolder,
 
47
  )
48
+ return nn_model_filename
49
+
50
 
51
+ def _get_bpe_model_filename(
52
+ repo_id: str,
53
+ filename: str = "bpe.model",
54
+ subfolder: str = "data/lang_bpe_500",
55
+ ) -> str:
56
  bpe_model_filename = hf_hub_download(
57
+ repo_id=repo_id,
58
+ filename=filename,
59
+ subfolder=subfolder,
60
+ )
61
+ return bpe_model_filename
62
+
63
+
64
+ def _get_token_filename(
65
+ repo_id: str,
66
+ filename: str = "tokens.txt",
67
+ subfolder: str = "data/lang_char",
68
+ ) -> str:
69
+ token_filename = hf_hub_download(
70
+ repo_id=repo_id,
71
+ filename=filename,
72
+ subfolder=subfolder,
73
+ )
74
+ return token_filename
75
+
76
+
77
+ @lru_cache(maxsize=10)
78
+ def _get_aishell2_pretrained_model(repo_id: str) -> OfflineAsr:
79
+ assert repo_id in [
80
+ # context-size 1
81
+ "yuekai/icefall-asr-aishell2-pruned-transducer-stateless5-A-2022-07-12", # noqa
82
+ # context-size 2
83
+ "yuekai/icefall-asr-aishell2-pruned-transducer-stateless5-B-2022-07-12", # noqa
84
+ ]
85
+
86
+ nn_model_filename = _get_nn_model_filename(
87
+ repo_id=repo_id,
88
+ filename="cpu_jit.pt",
89
+ )
90
+ token_filename = _get_token_filename(repo_id=repo_id)
91
+
92
+ return OfflineAsr(
93
+ nn_model_filename=nn_model_filename,
94
+ bpe_model_filename=None,
95
+ token_filename=token_filename,
96
+ sample_rate=sample_rate,
97
+ device="cpu",
98
+ )
99
+
100
+
101
+ @lru_cache(maxsize=10)
102
+ def _get_gigaspeech_pre_trained_model(repo_id: str) -> OfflineAsr:
103
+ assert repo_id in [
104
+ "wgb14/icefall-asr-gigaspeech-pruned-transducer-stateless2",
105
+ ]
106
+
107
+ nn_model_filename = _get_nn_model_filename(
108
+ # It is converted from https://huggingface.co/wgb14/icefall-asr-gigaspeech-pruned-transducer-stateless2 # noqa
109
+ repo_id="csukuangfj/icefall-asr-gigaspeech-pruned-transducer-stateless2", # noqa
110
+ filename="cpu_jit-epoch-29-avg-11-torch-1.10.0.pt",
111
  )
112
+ bpe_model_filename = _get_bpe_model_filename(repo_id=repo_id)
113
 
114
  return OfflineAsr(
115
  nn_model_filename=nn_model_filename,
116
  bpe_model_filename=bpe_model_filename,
117
  token_filename=None,
 
 
118
  sample_rate=sample_rate,
119
  device="cpu",
120
  )
121
 
122
 
123
+ @lru_cache(maxsize=10)
124
+ def _get_librispeech_pre_trained_model(repo_id: str) -> OfflineAsr:
125
+ assert repo_id in [
126
+ "csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13", # noqa
127
+ ]
128
+
129
+ nn_model_filename = _get_nn_model_filename(
130
+ repo_id=repo_id,
131
+ filename="cpu_jit.pt",
132
+ )
133
+ bpe_model_filename = _get_bpe_model_filename(repo_id=repo_id)
134
+
135
+ return OfflineAsr(
136
+ nn_model_filename=nn_model_filename,
137
+ bpe_model_filename=bpe_model_filename,
138
+ token_filename=None,
139
+ sample_rate=sample_rate,
140
+ device="cpu",
141
+ )
142
+
143
+
144
+ @lru_cache(maxsize=10)
145
+ def _get_wenetspeech_pre_trained_model(repo_id: str):
146
+ assert repo_id in [
147
+ "luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2",
148
+ ]
149
+
150
+ nn_model_filename = _get_nn_model_filename(
151
+ repo_id=repo_id,
152
  filename="cpu_jit_epoch_10_avg_2_torch_1.7.1.pt",
 
153
  )
154
+ token_filename = _get_token_filename(repo_id=repo_id)
155
 
156
+ return OfflineAsr(
157
+ nn_model_filename=nn_model_filename,
158
+ bpe_model_filename=None,
159
+ token_filename=token_filename,
160
+ sample_rate=sample_rate,
161
+ device="cpu",
162
+ )
163
+
164
+
165
+ @lru_cache(maxsize=10)
166
+ def _get_tal_csasr_pre_trained_model(repo_id: str):
167
+ assert repo_id in [
168
+ "luomingshuang/icefall_asr_tal-csasr_pruned_transducer_stateless5",
169
+ ]
170
+
171
+ nn_model_filename = _get_nn_model_filename(
172
+ repo_id=repo_id,
173
+ filename="cpu_jit.pt",
174
  )
175
+ token_filename = _get_token_filename(repo_id=repo_id)
176
 
177
  return OfflineAsr(
178
  nn_model_filename=nn_model_filename,
179
  bpe_model_filename=None,
180
  token_filename=token_filename,
 
 
181
  sample_rate=sample_rate,
182
  device="cpu",
183
  )
184
+
185
+
186
+ chinese_models = {
187
+ "yuekai/icefall-asr-aishell2-pruned-transducer-stateless5-A-2022-07-12": _get_aishell2_pretrained_model, # noqa
188
+ "yuekai/icefall-asr-aishell2-pruned-transducer-stateless5-B-2022-07-12": _get_aishell2_pretrained_model, # noqa
189
+ "luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2": _get_wenetspeech_pre_trained_model, # noqa
190
+ }
191
+
192
+ english_models = {
193
+ "wgb14/icefall-asr-gigaspeech-pruned-transducer-stateless2": _get_gigaspeech_pre_trained_model, # noqa
194
+ "csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13": _get_librispeech_pre_trained_model, # noqa
195
+ }
196
+
197
+ chinese_english_mixed_models = {
198
+ "luomingshuang/icefall_asr_tal-csasr_pruned_transducer_stateless5": _get_tal_csasr_pre_trained_model, # noqa
199
+ }
200
+
201
+ all_models = {
202
+ **chinese_models,
203
+ **english_models,
204
+ **chinese_english_mixed_models,
205
+ }
206
+
207
+ language_to_models = {
208
+ "Chinese": sorted(chinese_models.keys()),
209
+ "English": sorted(english_models.keys()),
210
+ "Chinese+English": sorted(chinese_english_mixed_models.keys()),
211
+ }
offline_asr.py CHANGED
@@ -206,10 +206,10 @@ class OfflineAsr(object):
206
  def __init__(
207
  self,
208
  nn_model_filename: str,
209
- bpe_model_filename: Optional[str],
210
- token_filename: Optional[str],
211
- decoding_method: str,
212
- num_active_paths: int,
213
  sample_rate: int = 16000,
214
  device: Union[str, torch.device] = "cpu",
215
  ):
@@ -246,6 +246,7 @@ class OfflineAsr(object):
246
  self.sp = spm.SentencePieceProcessor()
247
  self.sp.load(bpe_model_filename)
248
  else:
 
249
  self.token_table = k2.SymbolTable.from_file(token_filename)
250
 
251
  self.feature_extractor = self._build_feature_extractor(
 
206
  def __init__(
207
  self,
208
  nn_model_filename: str,
209
+ bpe_model_filename: Optional[str] = None,
210
+ token_filename: Optional[str] = None,
211
+ decoding_method: str = "greedy_search",
212
+ num_active_paths: int = 4,
213
  sample_rate: int = 16000,
214
  device: Union[str, torch.device] = "cpu",
215
  ):
 
246
  self.sp = spm.SentencePieceProcessor()
247
  self.sp.load(bpe_model_filename)
248
  else:
249
+ assert token_filename is not None, token_filename
250
  self.token_table = k2.SymbolTable.from_file(token_filename)
251
 
252
  self.feature_extractor = self._build_feature_extractor(