SalimaM committed
Commit 51bc847 · verified · 1 Parent(s): 4514493

Upload 12 files
CKPT.yaml ADDED
@@ -0,0 +1,4 @@
1
+ # yamllint disable
2
+ COER: 35.85329341317365
3
+ end-of-epoch: true
4
+ unixtime: 1701399679.8773978
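
CKPT.yaml is the per-checkpoint metadata that SpeechBrain's Checkpointer writes alongside the saved weights: the COER recorded for this checkpoint (the concept error rate tracked by the coer_computer defined in hyperparams.yaml below), an end-of-epoch flag, and the save timestamp. A minimal sketch for inspecting it, assuming the file has been downloaded locally:

    # Read the checkpoint metadata (the local path is an assumption).
    import yaml

    with open("CKPT.yaml") as f:
        meta = yaml.safe_load(f)

    # -> {'COER': 35.85329341317365, 'end-of-epoch': True, 'unixtime': 1701399679.8773978}
    print(meta["COER"], meta["end-of-epoch"])
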
SLU2.py ADDED
@@ -0,0 +1,1345 @@
1
+ """ Specifies the inference interfaces for Automatic speech Recognition (ASR) modules.
2
+
3
+ Authors:
4
+ * Aku Rouhe 2021
5
+ * Peter Plantinga 2021
6
+ * Loren Lugosch 2020
7
+ * Mirco Ravanelli 2020
8
+ * Titouan Parcollet 2021
9
+ * Abdel Heba 2021
10
+ * Andreas Nautsch 2022, 2023
11
+ * Pooneh Mousavi 2023
12
+ * Sylvain de Langen 2023, 2024
13
+ * Adel Moumen 2023, 2024
14
+ * Pradnya Kandarkar 2023
15
+ """
16
+
17
+ import functools
18
+ import itertools
19
+ from dataclasses import dataclass
20
+ from typing import Any, List, Optional, Tuple
21
+
22
+ import sentencepiece
23
+ import torch
24
+ import torchaudio
25
+ from tqdm import tqdm
26
+
27
+ import speechbrain
28
+ from speechbrain.inference.interfaces import Pretrained
29
+ from speechbrain.utils.data_utils import split_path
30
+ from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
31
+ from speechbrain.utils.fetching import fetch
32
+ from speechbrain.utils.streaming import split_fixed_chunks
33
+
34
+
35
+ class EncoderDecoderASR(Pretrained):
36
+ """A ready-to-use Encoder-Decoder ASR model
37
+
38
+ The class can be used either to run only the encoder (encode()) to extract
39
+ features or to run the entire encoder-decoder model
40
+ (transcribe()) to transcribe speech. The given YAML must contain the fields
41
+ specified in the *_NEEDED[] lists.
42
+
43
+ Arguments
44
+ ---------
45
+ *args : tuple
46
+ **kwargs : dict
47
+ Arguments are forwarded to ``Pretrained`` parent class.
48
+
49
+ Example
50
+ -------
51
+ >>> from speechbrain.inference.ASR import EncoderDecoderASR
52
+ >>> tmpdir = getfixture("tmpdir")
53
+ >>> asr_model = EncoderDecoderASR.from_hparams(
54
+ ... source="speechbrain/asr-crdnn-rnnlm-librispeech",
55
+ ... savedir=tmpdir,
56
+ ... ) # doctest: +SKIP
57
+ >>> asr_model.transcribe_file("tests/samples/single-mic/example2.flac") # doctest: +SKIP
58
+ "MY FATHER HAS REVEALED THE CULPRIT'S NAME"
59
+ """
60
+
61
+ HPARAMS_NEEDED = ["tokenizer"]
62
+ MODULES_NEEDED = ["encoder", "decoder"]
63
+
64
+ def __init__(self, *args, **kwargs):
65
+ super().__init__(*args, **kwargs)
66
+ self.tokenizer = self.hparams.tokenizer
67
+ self.transducer_beam_search = False
68
+ self.transformer_beam_search = False
69
+ if hasattr(self.hparams, "transducer_beam_search"):
70
+ self.transducer_beam_search = self.hparams.transducer_beam_search
71
+ if hasattr(self.hparams, "transformer_beam_search"):
72
+ self.transformer_beam_search = self.hparams.transformer_beam_search
73
+
74
+ def transcribe_file(self, path, **kwargs):
75
+ """Transcribes the given audiofile into a sequence of words.
76
+
77
+ Arguments
78
+ ---------
79
+ path : str
80
+ Path to audio file which to transcribe.
81
+ **kwargs : dict
82
+ Arguments forwarded to ``load_audio``.
83
+
84
+ Returns
85
+ -------
86
+ str
87
+ The audiofile transcription produced by this ASR system.
88
+ """
89
+ waveform = self.load_audio(path, **kwargs)
90
+ # Fake a batch:
91
+ batch = waveform.unsqueeze(0)
92
+ rel_length = torch.tensor([1.0])
93
+ predicted_words, predicted_tokens = self.transcribe_batch(
94
+ batch, rel_length
95
+ )
96
+ return predicted_words[0]
97
+
98
+ def encode_batch(self, wavs, wav_lens):
99
+ """Encodes the input audio into a sequence of hidden states
100
+
101
+ The waveforms should already be in the model's desired format.
102
+ You can call:
103
+ ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
104
+ to get a correctly converted signal in most cases.
105
+
106
+ Arguments
107
+ ---------
108
+ wavs : torch.Tensor
109
+ Batch of waveforms [batch, time, channels] or [batch, time]
110
+ depending on the model.
111
+ wav_lens : torch.Tensor
112
+ Lengths of the waveforms relative to the longest one in the
113
+ batch, tensor of shape [batch]. The longest one should have
114
+ relative length 1.0 and others len(waveform) / max_length.
115
+ Used for ignoring padding.
116
+
117
+ Returns
118
+ -------
119
+ torch.Tensor
120
+ The encoded batch
121
+ """
122
+ wavs = wavs.float()
123
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
124
+ encoder_out = self.mods.encoder(wavs, wav_lens)
125
+ if self.transformer_beam_search:
126
+ encoder_out = self.mods.transformer.encode(encoder_out, wav_lens)
127
+ return encoder_out
128
+
129
+ def transcribe_batch(self, wavs, wav_lens):
130
+ """Transcribes the input audio into a sequence of words
131
+
132
+ The waveforms should already be in the model's desired format.
133
+ You can call:
134
+ ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
135
+ to get a correctly converted signal in most cases.
136
+
137
+ Arguments
138
+ ---------
139
+ wavs : torch.Tensor
140
+ Batch of waveforms [batch, time, channels] or [batch, time]
141
+ depending on the model.
142
+ wav_lens : torch.Tensor
143
+ Lengths of the waveforms relative to the longest one in the
144
+ batch, tensor of shape [batch]. The longest one should have
145
+ relative length 1.0 and others len(waveform) / max_length.
146
+ Used for ignoring padding.
147
+
148
+ Returns
149
+ -------
150
+ list
151
+ Each waveform in the batch transcribed.
152
+ tensor
153
+ Each predicted token id.
154
+ """
155
+ with torch.no_grad():
156
+ wav_lens = wav_lens.to(self.device)
157
+ encoder_out = self.encode_batch(wavs, wav_lens)
158
+ if self.transducer_beam_search:
159
+ inputs = [encoder_out]
160
+ else:
161
+ inputs = [encoder_out, wav_lens]
162
+ predicted_tokens, _, _, _ = self.mods.decoder(*inputs)
163
+ predicted_words = [
164
+ self.tokenizer.decode_ids(token_seq)
165
+ for token_seq in predicted_tokens
166
+ ]
167
+ return predicted_words, predicted_tokens
168
+
169
+ def forward(self, wavs, wav_lens):
170
+ """Runs full transcription - note: no gradients through decoding"""
171
+ return self.transcribe_batch(wavs, wav_lens)
172
+
173
+
174
+ class EncoderASR(Pretrained):
175
+ """A ready-to-use Encoder ASR model
176
+
177
+ The class can be used either to run only the encoder (encode()) to extract
178
+ features or to run the full encoder + decoding-function pipeline
179
+ (transcribe()) to transcribe speech. The given YAML must contain the fields
180
+ specified in the *_NEEDED[] lists.
181
+
182
+ Arguments
183
+ ---------
184
+ *args : tuple
185
+ **kwargs : dict
186
+ Arguments are forwarded to ``Pretrained`` parent class.
187
+
188
+ Example
189
+ -------
190
+ >>> from speechbrain.inference.ASR import EncoderASR
191
+ >>> tmpdir = getfixture("tmpdir")
192
+ >>> asr_model = EncoderASR.from_hparams(
193
+ ... source="speechbrain/asr-wav2vec2-commonvoice-fr",
194
+ ... savedir=tmpdir,
195
+ ... ) # doctest: +SKIP
196
+ >>> asr_model.transcribe_file("samples/audio_samples/example_fr.wav") # doctest: +SKIP
197
+ """
198
+
199
+ HPARAMS_NEEDED = ["tokenizer", "decoding_function"]
200
+ MODULES_NEEDED = ["wav2vec", "dec", "output_lin"]
201
+
202
+ def __init__(self, *args, **kwargs):
203
+ super().__init__(*args, **kwargs)
204
+
205
+ self.tokenizer = self.hparams.tokenizer
206
+ self.set_decoding_function()
207
+
208
+ def set_decoding_function(self):
209
+ """Set the decoding function based on the parameters defined in the hyperparameter file.
210
+
211
+ The decoding function is determined by the `decoding_function` specified in the hyperparameter file.
212
+ It can be either a functools.partial object representing a decoding function or an instance of
213
+ `speechbrain.decoders.ctc.CTCBaseSearcher` for beam search decoding.
214
+
215
+ Raises:
216
+ ValueError: If the decoding function is neither a functools.partial nor an instance of
217
+ speechbrain.decoders.ctc.CTCBaseSearcher.
218
+
219
+ Note:
220
+ - For greedy decoding (functools.partial), the provided `decoding_function` is assigned directly.
221
+ - For CTCBeamSearcher decoding, an instance of the specified `decoding_function` is created, and
222
+ additional parameters are added based on the tokenizer type.
223
+ """
224
+ # Greedy Decoding case
225
+ if isinstance(self.hparams.decoding_function, functools.partial):
226
+ self.decoding_function = self.hparams.decoding_function
227
+ # CTCBeamSearcher case
228
+ else:
229
+ # 1. check if the decoding function is an instance of speechbrain.decoders.CTCBaseSearcher
230
+ if issubclass(
231
+ self.hparams.decoding_function,
232
+ speechbrain.decoders.ctc.CTCBaseSearcher,
233
+ ):
234
+ # If so, we need to retrieve the vocab list from the tokenizer.
235
+ # We also need to check if the tokenizer is a sentencepiece or a CTCTextEncoder.
236
+ if isinstance(
237
+ self.tokenizer, speechbrain.dataio.encoder.CTCTextEncoder
238
+ ):
239
+ ind2lab = self.tokenizer.ind2lab
240
+ vocab_list = [ind2lab[x] for x in range(len(ind2lab))]
241
+ elif isinstance(
242
+ self.tokenizer, sentencepiece.SentencePieceProcessor
243
+ ):
244
+ vocab_list = [
245
+ self.tokenizer.id_to_piece(i)
246
+ for i in range(self.tokenizer.vocab_size())
247
+ ]
248
+ else:
249
+ raise ValueError(
250
+ "The tokenizer must be sentencepiece or CTCTextEncoder"
251
+ )
252
+
253
+ # We can now instantiate the decoding class and add all the parameters
254
+ if hasattr(self.hparams, "test_beam_search"):
255
+ opt_beam_search_params = self.hparams.test_beam_search
256
+ # check if the kenlm_model_path is provided and fetch it if necessary
257
+ if "kenlm_model_path" in opt_beam_search_params:
258
+ source, fl = split_path(
259
+ opt_beam_search_params["kenlm_model_path"]
260
+ )
261
+ kenlm_model_path = str(
262
+ fetch(
263
+ fl, source=source, savedir=self.hparams.savedir
264
+ )
265
+ )
266
+ # we need to update the kenlm_model_path in the opt_beam_search_params
267
+ opt_beam_search_params["kenlm_model_path"] = (
268
+ kenlm_model_path
269
+ )
270
+ else:
271
+ opt_beam_search_params = {}
272
+ self.decoding_function = self.hparams.decoding_function(
273
+ **opt_beam_search_params, vocab_list=vocab_list
274
+ )
275
+ else:
276
+ raise ValueError(
277
+ "The decoding function must be an instance of speechbrain.decoders.CTCBaseSearcher"
278
+ )
279
+
280
+ def transcribe_file(self, path, **kwargs):
281
+ """Transcribes the given audiofile into a sequence of words.
282
+
283
+ Arguments
284
+ ---------
285
+ path : str
286
+ Path to audio file which to transcribe.
287
+ **kwargs : dict
288
+ Arguments forwarded to ``load_audio``.
289
+
290
+ Returns
291
+ -------
292
+ str
293
+ The audiofile transcription produced by this ASR system.
294
+ """
295
+ waveform = self.load_audio(path, **kwargs)
296
+ # Fake a batch:
297
+ batch = waveform.unsqueeze(0)
298
+ rel_length = torch.tensor([1.0])
299
+ predicted_words, predicted_tokens = self.transcribe_batch(
300
+ batch, rel_length
301
+ )
302
+ return str(predicted_words[0])
303
+
304
+ def encode_batch(self, wavs, wav_lens):
305
+ """Encodes the input audio into a sequence of hidden states
306
+
307
+ The waveforms should already be in the model's desired format.
308
+ You can call:
309
+ ``normalized = EncoderASR.normalizer(signal, sample_rate)``
310
+ to get a correctly converted signal in most cases.
311
+
312
+ Arguments
313
+ ---------
314
+ wavs : torch.Tensor
315
+ Batch of waveforms [batch, time, channels] or [batch, time]
316
+ depending on the model.
317
+ wav_lens : torch.Tensor
318
+ Lengths of the waveforms relative to the longest one in the
319
+ batch, tensor of shape [batch]. The longest one should have
320
+ relative length 1.0 and others len(waveform) / max_length.
321
+ Used for ignoring padding.
322
+
323
+ Returns
324
+ -------
325
+ torch.Tensor
326
+ The encoded batch
327
+ """
328
+ wavs = wavs.float()
329
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
330
+ encoder_out = self.mods.wav2vec(wavs, wav_lens)
331
+ x = self.mods.dec(encoder_out)
332
+ logits = self.mods.output_lin(x)
333
+ p_ctc = self.hparams.softmax(logits)
334
+ return p_ctc
335
+
336
+ def transcribe_batch(self, wavs, wav_lens):
337
+ """Transcribes the input audio into a sequence of words
338
+
339
+ The waveforms should already be in the model's desired format.
340
+ You can call:
341
+ ``normalized = EncoderASR.normalizer(signal, sample_rate)``
342
+ to get a correctly converted signal in most cases.
343
+
344
+ Arguments
345
+ ---------
346
+ wavs : torch.Tensor
347
+ Batch of waveforms [batch, time, channels] or [batch, time]
348
+ depending on the model.
349
+ wav_lens : torch.Tensor
350
+ Lengths of the waveforms relative to the longest one in the
351
+ batch, tensor of shape [batch]. The longest one should have
352
+ relative length 1.0 and others len(waveform) / max_length.
353
+ Used for ignoring padding.
354
+
355
+ Returns
356
+ -------
357
+ list
358
+ Each waveform in the batch transcribed.
359
+ tensor
360
+ Each predicted token id.
361
+ """
362
+ with torch.no_grad():
363
+ wav_lens = wav_lens.to(self.device)
364
+ encoder_out = self.encode_batch(wavs, wav_lens)
365
+ predictions = self.decoding_function(encoder_out, wav_lens)
367
+ is_ctc_text_encoder_tokenizer = isinstance(
368
+ self.tokenizer, speechbrain.dataio.encoder.CTCTextEncoder
369
+ )
370
+ self.tokenizer.load('sample_data/SLU/labelencoder.txt')
371
+ if isinstance(self.hparams.decoding_function, functools.partial):
372
+ if is_ctc_text_encoder_tokenizer:
373
+ predicted_words = [
374
+ "".join(self.tokenizer.decode_ndim(token_seq))
375
+ for token_seq in predictions
376
+ ]
377
+ else:
378
+ predicted_words = [
379
+ self.tokenizer.decode_ids(token_seq)
380
+ for token_seq in predictions
381
+ ]
382
+ else:
383
+ predicted_words = [hyp[0].text for hyp in predictions]
384
+
385
+ return predicted_words, predictions
386
+
387
+ def forward(self, wavs, wav_lens):
388
+ """Runs the encoder"""
389
+ return self.encode_batch(wavs, wav_lens)
390
+
391
+
392
+ @dataclass
393
+ class ASRWhisperSegment:
394
+ """A single chunk of audio for Whisper ASR streaming.
395
+
396
+ This object is intended to be mutated as streaming progresses and passed across calls
397
+ to the lower-level APIs such as `encode_chunk`, `decode_chunk`, etc.
398
+
399
+ Attributes
400
+ ----------
401
+ start : float
402
+ The start time of the audio chunk.
403
+ end : float
404
+ The end time of the audio chunk.
405
+ chunk : torch.Tensor
406
+ The audio chunk, shape [time, channels].
407
+ lang_id : str
408
+ The language identifier associated with the audio chunk.
409
+ words : str
410
+ The predicted words for the audio chunk.
411
+ tokens : List[int]
412
+ The predicted tokens for the audio chunk.
413
+ prompt : List[str]
414
+ The prompt associated with the audio chunk.
415
+ avg_log_probs : float
416
+ The average log probability associated with the prediction.
417
+ no_speech_prob : float
418
+ The probability of no speech in the audio chunk.
419
+ """
420
+
421
+ start: float
422
+ end: float
423
+ chunk: torch.Tensor
424
+ lang_id: Optional[str] = None
425
+ words: Optional[str] = None
426
+ tokens: Optional[List[int]] = None
427
+ prompt: Optional[List[str]] = None
428
+ avg_log_probs: Optional[float] = None
429
+ no_speech_prob: Optional[float] = None
430
+
431
+
432
+ class WhisperASR(Pretrained):
433
+ """A ready-to-use Whisper ASR model.
434
+
435
+ The class can be used to run the entire encoder-decoder whisper model.
436
+ The supported tasks are: ``transcribe``, ``translate``, and ``lang_id``.
437
+ The given YAML must contain the fields specified in the *_NEEDED[] lists.
438
+
439
+ Arguments
440
+ ---------
441
+ *args : tuple
442
+ **kwargs : dict
443
+ Arguments are forwarded to ``Pretrained`` parent class.
444
+
445
+ Example
446
+ -------
447
+ >>> from speechbrain.inference.ASR import WhisperASR
448
+ >>> tmpdir = getfixture("tmpdir")
449
+ >>> asr_model = WhisperASR.from_hparams(source="speechbrain/asr-whisper-medium-commonvoice-it", savedir=tmpdir,) # doctest: +SKIP
450
+ >>> hyp = asr_model.transcribe_file("speechbrain/asr-whisper-medium-commonvoice-it/example-it.wav") # doctest: +SKIP
451
+ >>> hyp # doctest: +SKIP
452
+ buongiorno a tutti e benvenuti a bordo
453
+ >>> _, probs = asr_model.detect_language_file("speechbrain/asr-whisper-medium-commonvoice-it/example-it.wav") # doctest: +SKIP
454
+ >>> print(f"Detected language: {max(probs[0], key=probs[0].get)}") # doctest: +SKIP
455
+ Detected language: it
456
+ """
457
+
458
+ HPARAMS_NEEDED = ["language", "sample_rate"]
459
+ MODULES_NEEDED = ["whisper", "decoder"]
460
+ TASKS = ["transcribe", "translate", "lang_id"]
461
+
462
+ def __init__(self, *args, **kwargs):
463
+ super().__init__(*args, **kwargs)
464
+ self.tokenizer = self.hparams.whisper.tokenizer
465
+
466
+ @torch.no_grad()
467
+ def detect_language_file(self, path: str):
468
+ """Detects the language of the given audiofile.
469
+ This method only works on input files of 30 seconds or less.
470
+
471
+ Arguments
472
+ ---------
473
+ path : str
474
+ Path to audio file which to transcribe.
475
+
476
+ Returns
477
+ -------
478
+ language_tokens : torch.Tensor
479
+ The detected language tokens.
480
+ language_probs : dict
481
+ The probabilities of the detected language tokens.
482
+
483
+ Raises
484
+ ------
485
+ ValueError
486
+ If the model doesn't have language tokens.
487
+ """
488
+ wavs = self.load_audio(path).float().to(self.device).unsqueeze(0)
489
+ mel = self.mods.whisper._get_mel(wavs)
490
+ language_tokens, language_probs = self.mods.whisper.detect_language(mel)
491
+ return language_tokens, language_probs
492
+
493
+ @torch.no_grad()
494
+ def detect_language_batch(self, wav: torch.Tensor):
495
+ """Detects the language of the given wav Tensor.
496
+ This method only works on wav files of 30 seconds or less.
497
+
498
+ Arguments
499
+ ---------
500
+ wav : torch.tensor
501
+ Batch of waveforms [batch, time, channels].
502
+
503
+ Returns
504
+ -------
505
+ language_tokens : torch.Tensor of shape (batch_size,)
506
+ ids of the most probable language tokens, which appear after the startoftranscript token.
507
+ language_probs : List[Dict[str, float]]
508
+ list of dictionaries containing the probability distribution over all languages.
509
+
510
+ Raises
511
+ ------
512
+ ValueError
513
+ If the model doesn't have language tokens.
514
+
515
+ Example
516
+ -------
517
+ >>> from speechbrain.inference.ASR import WhisperASR
518
+ >>> import torchaudio
519
+ >>> tmpdir = getfixture("tmpdir")
520
+ >>> asr_model = WhisperASR.from_hparams(
521
+ ... source="speechbrain/asr-whisper-medium-commonvoice-it",
522
+ ... savedir=tmpdir,
523
+ ... ) # doctest: +SKIP
524
+ >>> wav, _ = torchaudio.load("your_audio") # doctest: +SKIP
525
+ >>> language_tokens, language_probs = asr_model.detect_language(wav) # doctest: +SKIP
526
+ """
527
+ mel = self.mods.whisper._get_mel(wav)
528
+ language_tokens, language_probs = self.mods.whisper.detect_language(mel)
529
+ return language_tokens, language_probs
530
+
531
+ @torch.no_grad()
532
+ def _detect_language(self, mel: torch.Tensor, task: str):
533
+ """Detects the language of the given mel spectrogram.
534
+
535
+ Arguments
536
+ ---------
537
+ mel : torch.tensor
538
+ Batch of mel spectrograms [batch, time, channels].
539
+ task : str
540
+ The task to perform.
541
+
542
+ Returns
543
+ -------
544
+ language_tokens : Tensor, shape = (n_audio,)
545
+ ids of the most probable language tokens, which appear after the startoftranscript token.
546
+ language_probs : List[Dict[str, float]], length = n_audio
547
+ list of dictionaries containing the probability distribution over all languages.
548
+ """
549
+ languages = [self.mods.whisper.language] * mel.shape[0]
550
+ lang_probs = None
551
+
552
+ if self.mods.whisper.language is None or task == "lang_id":
553
+ lang_tokens, lang_probs = self.mods.whisper.detect_language(mel)
554
+ languages = [max(probs, key=probs.get) for probs in lang_probs]
555
+ self.mods.decoder.set_lang_tokens(lang_tokens)
556
+ return languages, lang_probs
557
+
558
+ def _get_audio_stream(
559
+ self, streamer: "torchaudio.io.StreamReader", frames_per_chunk: int
560
+ ):
561
+ """From a :class:`torchaudio.io.StreamReader`, identifies the audio
562
+ stream and returns an iterable stream of chunks (after resampling and
563
+ downmixing to mono).
564
+
565
+ Arguments
566
+ ---------
567
+ streamer : torchaudio.io.StreamReader
568
+ The stream object. Must hold exactly one source stream of an
569
+ audio type.
570
+ frames_per_chunk : int
571
+ The number of frames per chunk. For a streaming model, this should
572
+ be determined from the DynChunkTrain configuration.
573
+
574
+ Yields
575
+ ------
576
+ chunks from streamer
577
+ """
578
+
579
+ stream_infos = [
580
+ streamer.get_src_stream_info(i)
581
+ for i in range(streamer.num_src_streams)
582
+ ]
583
+
584
+ audio_stream_infos = [
585
+ (i, stream_info)
586
+ for i, stream_info in enumerate(stream_infos)
587
+ if stream_info.media_type == "audio"
588
+ ]
589
+
590
+ if len(audio_stream_infos) != 1:
591
+ raise ValueError(
592
+ f"Expected stream to have only 1 stream (with any number of channels), got {len(audio_stream_infos)} (with streams: {stream_infos})"
593
+ )
594
+
595
+ # find the index of the first (and only) audio stream
596
+ audio_stream_index = audio_stream_infos[0][0]
597
+
598
+ # output stream #0
599
+ streamer.add_basic_audio_stream(
600
+ frames_per_chunk=frames_per_chunk,
601
+ stream_index=audio_stream_index,
602
+ sample_rate=self.audio_normalizer.sample_rate,
603
+ format="fltp", # torch.float32
604
+ num_channels=1,
605
+ )
606
+
607
+ for (chunk,) in streamer.stream():
608
+ chunk = chunk.squeeze(-1) # we deal with mono, remove that dim
609
+ chunk = chunk.unsqueeze(0) # create a fake batch dim
610
+ yield chunk
611
+
612
+ @torch.no_grad()
613
+ def transcribe_file_streaming(
614
+ self,
615
+ path: str,
616
+ task: Optional[str] = None,
617
+ initial_prompt: Optional[str] = None,
618
+ logprob_threshold: Optional[float] = -1.0,
619
+ no_speech_threshold=0.6,
620
+ condition_on_previous_text: bool = False,
621
+ verbose: bool = False,
622
+ use_torchaudio_streaming: bool = False,
623
+ chunk_size: Optional[int] = 30,
624
+ **kwargs,
625
+ ):
626
+ """Transcribes the given audiofile into a sequence of words.
627
+ This method supports the following tasks: ``transcribe``, ``translate``, and ``lang_id``.
628
+ It can process an input audio file longer than 30 seconds by splitting it into chunk_size-second segments.
629
+
630
+ Arguments
631
+ ---------
632
+ path : str
633
+ URI/path to the audio to transcribe. When
634
+ ``use_torchaudio_streaming`` is ``False``, uses SB fetching to allow
635
+ fetching from HF or a local file. When ``True``, resolves the URI
636
+ through ffmpeg, as documented in
637
+ :class:`torchaudio.io.StreamReader`.
638
+ task : Optional[str]
639
+ The task to perform. If None, the default task is the one passed in the Whisper model.
640
+ initial_prompt : Optional[str]
641
+ The initial prompt to condition the model on.
642
+ logprob_threshold : Optional[float]
643
+ The log probability threshold to continue decoding the current segment.
644
+ no_speech_threshold : float
645
+ The threshold to skip decoding segment if the no_speech_prob is higher than this value.
646
+ condition_on_previous_text : bool
647
+ If True, the model will be conditioned on the last 224 tokens.
648
+ verbose : bool
649
+ If True, print the transcription of each segment.
650
+ use_torchaudio_streaming : bool
651
+ Whether the audio file can be loaded in a streaming fashion. If not,
652
+ transcription is still performed through chunks of audio, but the
653
+ entire audio file is fetched and loaded at once.
654
+ This skips the usual fetching method and instead resolves the URI
655
+ using torchaudio (via ffmpeg).
656
+ chunk_size : Optional[int]
657
+ The size of the chunks to split the audio into. The default
658
+ chunk size is 30 seconds which corresponds to the maximal length
659
+ that the model can process in one go.
660
+ **kwargs : dict
661
+ Arguments forwarded to ``load_audio``
662
+
663
+ Yields
664
+ ------
665
+ ASRWhisperSegment
666
+ A new ASRWhisperSegment instance initialized with the provided parameters.
667
+ """
668
+ if task is not None:
669
+ if task in self.TASKS:
670
+ if task != "lang_id":
671
+ self.mods.decoder.set_task(task)
672
+ else:
673
+ raise ValueError(
674
+ f"Task {task} not supported. Supported tasks are {self.TASKS}"
675
+ )
676
+
677
+ # create chunks of chunk_size seconds
678
+ num_frames_per_chunk = chunk_size * self.hparams.sample_rate
679
+ if use_torchaudio_streaming:
680
+ streamer = torchaudio.io.StreamReader(path)
681
+ segments = self._get_audio_stream(streamer, num_frames_per_chunk)
682
+ else:
683
+ waveform = self.load_audio(path, **kwargs)
684
+ batch = waveform.unsqueeze(0)
685
+ segments = split_fixed_chunks(batch, num_frames_per_chunk)
686
+
687
+ rel_length = torch.tensor([1.0])
688
+
689
+ all_tokens = []
690
+ prompt_reset_since = 0
691
+ if initial_prompt is not None:
692
+ initial_prompt_tokens = self.tokenizer.encode(
693
+ " " + initial_prompt.strip()
694
+ )
695
+ all_tokens.extend(initial_prompt_tokens)
696
+ else:
697
+ initial_prompt_tokens = []
698
+
699
+ for i, segment in enumerate(tqdm(segments, disable=verbose)):
700
+ # move the segment on the device
701
+ segment = segment.to(self.device)
702
+
703
+ # extract mel spectrogram
704
+ mel_segment = self.mods.whisper._get_mel(segment)
705
+
706
+ start = i * chunk_size
707
+ end = (i + 1) * chunk_size
708
+
709
+ encoder_out = self.mods.whisper.forward_encoder(mel_segment)
710
+ languages, _ = self._detect_language(mel_segment, task)
711
+
712
+ if task == "lang_id":
713
+ yield ASRWhisperSegment(
714
+ start=start,
715
+ end=end,
716
+ chunk=segment,
717
+ lang_id=languages[0],
718
+ )
719
+ continue
720
+
721
+ prompt = all_tokens[prompt_reset_since:]
722
+ self.mods.decoder.set_prompt(prompt)
723
+
724
+ predicted_tokens, _, scores, _ = self.mods.decoder(
725
+ encoder_out, rel_length
726
+ )
727
+ avg_log_probs = scores.sum() / (len(predicted_tokens[0]) + 1)
728
+
729
+ if no_speech_threshold is not None:
730
+ should_skip = (
731
+ self.mods.decoder.no_speech_probs[0] > no_speech_threshold
732
+ )
733
+ if (
734
+ logprob_threshold is not None
735
+ and avg_log_probs > logprob_threshold
736
+ ):
737
+ # don't skip if the logprob is high enough, despite the no_speech_prob
738
+ should_skip = False
739
+
740
+ if should_skip:
741
+ yield ASRWhisperSegment(
742
+ start=start,
743
+ end=end,
744
+ chunk=segment,
745
+ lang_id=languages[0],
746
+ words="",
747
+ tokens=[],
748
+ prompt=prompt,
749
+ avg_log_probs=avg_log_probs.item(),
750
+ no_speech_prob=self.mods.decoder.no_speech_probs[0],
751
+ )
752
+ continue
753
+
754
+ predicted_words = [
755
+ self.tokenizer.decode(t, skip_special_tokens=True).strip()
756
+ for t in predicted_tokens
757
+ ]
758
+
759
+ yield ASRWhisperSegment(
760
+ start=start,
761
+ end=end,
762
+ chunk=segment,
763
+ lang_id=languages[0],
764
+ words=predicted_words[0],
765
+ tokens=predicted_tokens[0],
766
+ prompt=prompt,
767
+ avg_log_probs=avg_log_probs.item(),
768
+ no_speech_prob=self.mods.decoder.no_speech_probs[0],
769
+ )
770
+
771
+ all_tokens.extend(predicted_tokens[0])
772
+
773
+ if (
774
+ not condition_on_previous_text
775
+ or self.mods.decoder.temperature > 0.5
776
+ ):
777
+ prompt_reset_since = len(all_tokens)
778
+
779
+ def transcribe_file(
780
+ self,
781
+ path: str,
782
+ task: Optional[str] = None,
783
+ initial_prompt: Optional[str] = None,
784
+ logprob_threshold: Optional[float] = -1.0,
785
+ no_speech_threshold=0.6,
786
+ condition_on_previous_text: bool = False,
787
+ verbose: bool = False,
788
+ use_torchaudio_streaming: bool = False,
789
+ chunk_size: Optional[int] = 30,
790
+ **kwargs,
791
+ ) -> List[ASRWhisperSegment]:
792
+ """Run the Whisper model using the specified task on the given audio file and return the ``ASRWhisperSegment`` objects
793
+ for each segment.
794
+
795
+ This method supports the following tasks: ``transcribe``, ``translate``, and ``lang_id``.
796
+ It can process an input audio file longer than 30 seconds by splitting it into chunk_size-second segments.
797
+
798
+ Arguments
799
+ ---------
800
+ path : str
801
+ URI/path to the audio to transcribe. When
802
+ ``use_torchaudio_streaming`` is ``False``, uses SB fetching to allow
803
+ fetching from HF or a local file. When ``True``, resolves the URI
804
+ through ffmpeg, as documented in
805
+ :class:`torchaudio.io.StreamReader`.
806
+ task : Optional[str]
807
+ The task to perform. If None, the default task is the one passed in the Whisper model.
808
+ It can be one of the following: ``transcribe``, ``translate``, ``lang_id``.
809
+ initial_prompt : Optional[str]
810
+ The initial prompt to condition the model on.
811
+ logprob_threshold : Optional[float]
812
+ The log probability threshold to continue decoding the current segment.
813
+ no_speech_threshold : float
814
+ The threshold to skip decoding segment if the no_speech_prob is higher than this value.
815
+ condition_on_previous_text : bool
816
+ If True, the model will be conditioned on the last 224 tokens.
817
+ verbose : bool
818
+ If True, print the details of each segment.
819
+ use_torchaudio_streaming : bool
820
+ Whether the audio file can be loaded in a streaming fashion. If not,
821
+ transcription is still performed through chunks of audio, but the
822
+ entire audio file is fetched and loaded at once.
823
+ This skips the usual fetching method and instead resolves the URI
824
+ using torchaudio (via ffmpeg).
825
+ chunk_size : Optional[int]
826
+ The size of the chunks to split the audio into. The default
827
+ chunk size is 30 seconds which corresponds to the maximal length
828
+ that the model can process in one go.
829
+ **kwargs : dict
830
+ Arguments forwarded to ``load_audio``
831
+
832
+ Returns
833
+ -------
834
+ results : list
835
+ A list of ``ASRWhisperSegment`` objects, each containing the task result.
836
+ """
837
+ results = []
838
+ for whisper_segment in self.transcribe_file_streaming(
839
+ path,
840
+ task=task,
841
+ initial_prompt=initial_prompt,
842
+ logprob_threshold=logprob_threshold,
843
+ no_speech_threshold=no_speech_threshold,
844
+ condition_on_previous_text=condition_on_previous_text,
845
+ verbose=verbose,
846
+ use_torchaudio_streaming=use_torchaudio_streaming,
847
+ chunk_size=chunk_size,
848
+ **kwargs,
849
+ ):
850
+ results.append(whisper_segment)
851
+ if verbose:
852
+ pred = (
853
+ whisper_segment.words
854
+ if task != "lang_id"
855
+ else whisper_segment.lang_id
856
+ )
857
+ print(
858
+ f"[{whisper_segment.start}s --> {whisper_segment.end}s] {pred}"
859
+ )
860
+ return results
861
+
862
+ def encode_batch(self, wavs, wav_lens):
863
+ """Encodes the input audio into a sequence of hidden states
864
+
865
+ The waveforms should already be in the model's desired format.
866
+ You can call:
867
+ ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
868
+ to get a correctly converted signal in most cases.
869
+
870
+ Arguments
871
+ ---------
872
+ wavs : torch.tensor
873
+ Batch of waveforms [batch, time, channels].
874
+ wav_lens : torch.tensor
875
+ Lengths of the waveforms relative to the longest one in the
876
+ batch, tensor of shape [batch]. The longest one should have
877
+ relative length 1.0 and others len(waveform) / max_length.
878
+ Used for ignoring padding.
879
+
880
+ Returns
881
+ -------
882
+ torch.tensor
883
+ The encoded batch
884
+ """
885
+ wavs = wavs.to(device=self.device, dtype=torch.float32)
886
+ mel = self.mods.whisper._get_mel(wavs)
887
+ encoder_out = self.mods.whisper.forward_encoder(mel)
888
+ return encoder_out
889
+
890
+ @torch.no_grad()
891
+ def transcribe_batch(self, wavs, wav_lens):
892
+ """Transcribes the input audio into a sequence of words
893
+
894
+ The waveforms should already be in the model's desired format.
895
+ You can call:
896
+ ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
897
+ to get a correctly converted signal in most cases.
898
+
899
+ Arguments
900
+ ---------
901
+ wavs : torch.tensor
902
+ Batch of waveforms [batch, time, channels].
903
+ wav_lens : torch.tensor
904
+ Lengths of the waveforms relative to the longest one in the
905
+ batch, tensor of shape [batch]. The longest one should have
906
+ relative length 1.0 and others len(waveform) / max_length.
907
+ Used for ignoring padding.
908
+
909
+ Returns
910
+ -------
911
+ list
912
+ Each waveform in the batch transcribed.
913
+ tensor
914
+ Each predicted token id.
915
+ """
916
+ wav_lens = wav_lens.float().to(self.device)
917
+ encoder_out = self.encode_batch(wavs, wav_lens)
918
+ predicted_tokens, _, _, _ = self.mods.decoder(encoder_out, wav_lens)
919
+ predicted_words = [
920
+ self.tokenizer.decode(t, skip_special_tokens=True).strip()
921
+ for t in predicted_tokens
922
+ ]
923
+ if self.hparams.normalized_transcripts:
924
+ predicted_words = [
925
+ self.tokenizer.normalize(text).split(" ")
926
+ for text in predicted_words
927
+ ]
928
+
929
+ return predicted_words, predicted_tokens
930
+
931
+ def forward(self, wavs, wav_lens):
932
+ """Runs full transcription - note: no gradients through decoding"""
933
+ return self.transcribe_batch(wavs, wav_lens)
934
+
935
+
936
+ @dataclass
937
+ class ASRStreamingContext:
938
+ """Streaming metadata, initialized by
939
+ :meth:`~StreamingASR.make_streaming_context` (see there for details on
940
+ initialization of fields here).
941
+
942
+ This object is intended to be mutated: the same object should be passed
943
+ across calls as streaming progresses (namely when using the lower-level
944
+ :meth:`~StreamingASR.encode_chunk`, etc. APIs).
945
+
946
+ Holds some references to opaque streaming contexts, so the context is
947
+ model-agnostic to an extent."""
948
+
949
+ config: DynChunkTrainConfig
950
+ """Dynamic chunk training configuration used to initialize the streaming
951
+ context. Cannot be modified on the fly."""
952
+
953
+ fea_extractor_context: Any
954
+ """Opaque feature extractor streaming context."""
955
+
956
+ encoder_context: Any
957
+ """Opaque encoder streaming context."""
958
+
959
+ decoder_context: Any
960
+ """Opaque decoder streaming context."""
961
+
962
+ tokenizer_context: Optional[List[Any]]
963
+ """Opaque streaming context for the tokenizer. Initially `None`. Initialized
964
+ to a list of tokenizer contexts once batch size can be determined."""
965
+
966
+
967
+ class StreamingASR(Pretrained):
968
+ """A ready-to-use, streaming-capable ASR model.
969
+
970
+ Arguments
971
+ ---------
972
+ *args : tuple
973
+ **kwargs : dict
974
+ Arguments are forwarded to ``Pretrained`` parent class.
975
+
976
+ Example
977
+ -------
978
+ >>> from speechbrain.inference.ASR import StreamingASR
979
+ >>> from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
980
+ >>> tmpdir = getfixture("tmpdir")
981
+ >>> asr_model = StreamingASR.from_hparams(source="speechbrain/asr-conformer-streaming-librispeech", savedir=tmpdir,) # doctest: +SKIP
982
+ >>> asr_model.transcribe_file("speechbrain/asr-conformer-streaming-librispeech/test-en.wav", DynChunkTrainConfig(24, 8)) # doctest: +SKIP
983
+ """
984
+
985
+ HPARAMS_NEEDED = [
986
+ "fea_streaming_extractor",
987
+ "make_decoder_streaming_context",
988
+ "decoding_function",
989
+ "make_tokenizer_streaming_context",
990
+ "tokenizer_decode_streaming",
991
+ ]
992
+ MODULES_NEEDED = ["enc", "proj_enc"]
993
+
994
+ def __init__(self, *args, **kwargs):
995
+ super().__init__(*args, **kwargs)
996
+
997
+ self.filter_props = self.hparams.fea_streaming_extractor.properties
998
+
999
+ def _get_audio_stream(
1000
+ self, streamer: "torchaudio.io.StreamReader", frames_per_chunk: int
1001
+ ):
1002
+ """From a :class:`torchaudio.io.StreamReader`, identifies the audio
1003
+ stream and returns an iterable stream of chunks (after resampling and
1004
+ downmixing to mono).
1005
+
1006
+ Arguments
1007
+ ---------
1008
+ streamer : torchaudio.io.StreamReader
1009
+ The stream object. Must hold exactly one source stream of an
1010
+ audio type.
1011
+ frames_per_chunk : int
1012
+ The number of frames per chunk. For a streaming model, this should
1013
+ be determined from the DynChunkTrain configuration.
1014
+
1015
+ Yields
1016
+ ------
1017
+ chunks from streamer
1018
+ """
1019
+
1020
+ stream_infos = [
1021
+ streamer.get_src_stream_info(i)
1022
+ for i in range(streamer.num_src_streams)
1023
+ ]
1024
+
1025
+ audio_stream_infos = [
1026
+ (i, stream_info)
1027
+ for i, stream_info in enumerate(stream_infos)
1028
+ if stream_info.media_type == "audio"
1029
+ ]
1030
+
1031
+ if len(audio_stream_infos) != 1:
1032
+ raise ValueError(
1033
+ f"Expected stream to have only 1 stream (with any number of channels), got {len(audio_stream_infos)} (with streams: {stream_infos})"
1034
+ )
1035
+
1036
+ # find the index of the first (and only) audio stream
1037
+ audio_stream_index = audio_stream_infos[0][0]
1038
+
1039
+ # output stream #0
1040
+ streamer.add_basic_audio_stream(
1041
+ frames_per_chunk=frames_per_chunk,
1042
+ stream_index=audio_stream_index,
1043
+ sample_rate=self.audio_normalizer.sample_rate,
1044
+ format="fltp", # torch.float32
1045
+ num_channels=1,
1046
+ )
1047
+
1048
+ for (chunk,) in streamer.stream():
1049
+ chunk = chunk.squeeze(-1) # we deal with mono, remove that dim
1050
+ chunk = chunk.unsqueeze(0) # create a fake batch dim
1051
+ yield chunk
1052
+
1053
+ def transcribe_file_streaming(
1054
+ self,
1055
+ path,
1056
+ dynchunktrain_config: DynChunkTrainConfig,
1057
+ use_torchaudio_streaming: bool = True,
1058
+ **kwargs,
1059
+ ):
1060
+ """Transcribes the given audio file into a sequence of words, in a
1061
+ streaming fashion, meaning that text is yielded from this
1062
+ generator, in the form of strings to concatenate.
1063
+
1064
+ Arguments
1065
+ ---------
1066
+ path : str
1067
+ URI/path to the audio to transcribe. When
1068
+ ``use_torchaudio_streaming`` is ``False``, uses SB fetching to allow
1069
+ fetching from HF or a local file. When ``True``, resolves the URI
1070
+ through ffmpeg, as documented in
1071
+ :class:`torchaudio.io.StreamReader`.
1072
+ dynchunktrain_config : DynChunkTrainConfig
1073
+ Streaming configuration. Sane values and how much time chunks
1074
+ actually represent are model-dependent.
1075
+ use_torchaudio_streaming : bool
1076
+ Whether the audio file can be loaded in a streaming fashion. If not,
1077
+ transcription is still performed through chunks of audio, but the
1078
+ entire audio file is fetched and loaded at once.
1079
+ This skips the usual fetching method and instead resolves the URI
1080
+ using torchaudio (via ffmpeg).
1081
+ **kwargs : dict
1082
+ Arguments forwarded to ``load_audio``
1083
+
1084
+ Yields
1085
+ ------
1086
+ generator of str
1087
+ An iterator yielding transcribed chunks (strings). There is a yield
1088
+ for every chunk, even if the transcribed string for that chunk is an
1089
+ empty string.
1090
+ """
1091
+
1092
+ chunk_size = self.get_chunk_size_frames(dynchunktrain_config)
1093
+
1094
+ if use_torchaudio_streaming:
1095
+ streamer = torchaudio.io.StreamReader(path)
1096
+ chunks = self._get_audio_stream(streamer, chunk_size)
1097
+ else:
1098
+ waveform = self.load_audio(path, **kwargs)
1099
+ batch = waveform.unsqueeze(0) # create batch dim
1100
+ chunks = split_fixed_chunks(batch, chunk_size)
1101
+
1102
+ rel_length = torch.tensor([1.0])
1103
+ context = self.make_streaming_context(dynchunktrain_config)
1104
+
1105
+ final_chunks = [
1106
+ torch.zeros((1, chunk_size), device=self.device)
1107
+ ] * self.hparams.fea_streaming_extractor.get_recommended_final_chunk_count(
1108
+ chunk_size
1109
+ )
1110
+
1111
+ for chunk in itertools.chain(chunks, final_chunks):
1112
+ predicted_words = self.transcribe_chunk(context, chunk, rel_length)
1113
+ yield predicted_words[0]
1114
+
1115
+ def transcribe_file(
1116
+ self,
1117
+ path,
1118
+ dynchunktrain_config: DynChunkTrainConfig,
1119
+ use_torchaudio_streaming: bool = True,
1120
+ ):
1121
+ """Transcribes the given audio file into a sequence of words.
1122
+
1123
+ Arguments
1124
+ ---------
1125
+ path : str
1126
+ URI/path to the audio to transcribe. When
1127
+ ``use_torchaudio_streaming`` is ``False``, uses SB fetching to allow
1128
+ fetching from HF or a local file. When ``True``, resolves the URI
1129
+ through ffmpeg, as documented in
1130
+ :class:`torchaudio.io.StreamReader`.
1131
+ dynchunktrain_config : DynChunkTrainConfig
1132
+ Streaming configuration. Sane values and how much time chunks
1133
+ actually represent are model-dependent.
1134
+ use_torchaudio_streaming : bool
1135
+ Whether the audio file can be loaded in a streaming fashion. If not,
1136
+ transcription is still performed through chunks of audio, but the
1137
+ entire audio file is fetched and loaded at once.
1138
+ This skips the usual fetching method and instead resolves the URI
1139
+ using torchaudio (via ffmpeg).
1140
+
1141
+ Returns
1142
+ -------
1143
+ str
1144
+ The audio file transcription produced by this ASR system.
1145
+ """
1146
+
1147
+ pred = ""
1148
+
1149
+ for text_chunk in self.transcribe_file_streaming(
1150
+ path, dynchunktrain_config, use_torchaudio_streaming
1151
+ ):
1152
+ pred += text_chunk
1153
+
1154
+ return pred
1155
+
1156
+ def make_streaming_context(self, dynchunktrain_config: DynChunkTrainConfig):
1157
+ """Create a blank streaming context to be passed around for chunk
1158
+ encoding/transcription.
1159
+
1160
+ Arguments
1161
+ ---------
1162
+ dynchunktrain_config : DynChunkTrainConfig
1163
+ Streaming configuration. Sane values and how much time chunks
1164
+ actually represent are model-dependent.
1165
+
1166
+ Returns
1167
+ -------
1168
+ ASRStreamingContext
1169
+ """
1170
+
1171
+ return ASRStreamingContext(
1172
+ config=dynchunktrain_config,
1173
+ fea_extractor_context=self.hparams.fea_streaming_extractor.make_streaming_context(),
1174
+ encoder_context=self.mods.enc.make_streaming_context(
1175
+ dynchunktrain_config
1176
+ ),
1177
+ decoder_context=self.hparams.make_decoder_streaming_context(),
1178
+ tokenizer_context=None,
1179
+ )
1180
+
1181
+ def get_chunk_size_frames(
1182
+ self, dynchunktrain_config: DynChunkTrainConfig
1183
+ ) -> int:
1184
+ """Returns the chunk size in actual audio samples, i.e. the exact
1185
+ expected length along the time dimension of an input chunk tensor (as
1186
+ passed to :meth:`~StreamingASR.encode_chunk` and similar low-level
1187
+ streaming functions).
1188
+
1189
+ Arguments
1190
+ ---------
1191
+ dynchunktrain_config : DynChunkTrainConfig
1192
+ The streaming configuration to determine the chunk frame count of.
1193
+
1194
+ Returns
1195
+ -------
1196
+ chunk size
1197
+ """
1198
+
1199
+ return (self.filter_props.stride - 1) * dynchunktrain_config.chunk_size
1200
+
1201
+ @torch.no_grad()
1202
+ def encode_chunk(
1203
+ self,
1204
+ context: ASRStreamingContext,
1205
+ chunk: torch.Tensor,
1206
+ chunk_len: Optional[torch.Tensor] = None,
1207
+ ):
1208
+ """Encoding of a batch of audio chunks into a batch of encoded
1209
+ sequences.
1210
+ For full speech-to-text offline transcription, use `transcribe_batch` or
1211
+ `transcribe_file`.
1212
+ Must be called over a given context in the correct order of chunks over
1213
+ time.
1214
+
1215
+ Arguments
1216
+ ---------
1217
+ context : ASRStreamingContext
1218
+ Mutable streaming context object, which must be specified and reused
1219
+ across calls when streaming.
1220
+ You can obtain an initial context by calling
1221
+ `asr.make_streaming_context(config)`.
1222
+
1223
+ chunk : torch.Tensor
1224
+ The tensor for an audio chunk of shape `[batch size, time]`.
1225
+ The time dimension must strictly match
1226
+ `asr.get_chunk_size_frames(config)`.
1227
+ The waveform is expected to be in the model's expected format (i.e.
1228
+ the sampling rate must be correct).
1229
+
1230
+ chunk_len : torch.Tensor, optional
1231
+ The relative chunk length tensor of shape `[batch size]`. This is to
1232
+ be used when the audio in one of the chunks of the batch is ending
1233
+ within this chunk.
1234
+ If unspecified, equivalent to `torch.ones((batch_size,))`.
1235
+
1236
+ Returns
1237
+ -------
1238
+ torch.Tensor
1239
+ Encoded output, of a model-dependent shape."""
1240
+
1241
+ if chunk_len is None:
1242
+ chunk_len = torch.ones((chunk.size(0),))
1243
+
1244
+ chunk = chunk.float()
1245
+ chunk, chunk_len = chunk.to(self.device), chunk_len.to(self.device)
1246
+
1247
+ assert chunk.shape[-1] <= self.get_chunk_size_frames(context.config)
1248
+
1249
+ x = self.hparams.fea_streaming_extractor(
1250
+ chunk, context=context.fea_extractor_context, lengths=chunk_len
1251
+ )
1252
+ x = self.mods.enc.forward_streaming(x, context.encoder_context)
1253
+ x = self.mods.proj_enc(x)
1254
+ return x
1255
+
1256
+ @torch.no_grad()
1257
+ def decode_chunk(
1258
+ self, context: ASRStreamingContext, x: torch.Tensor
1259
+ ) -> Tuple[List[str], List[List[int]]]:
1260
+ """Decodes the output of the encoder into tokens and the associated
1261
+ transcription.
1262
+ Must be called over a given context in the correct order of chunks over
1263
+ time.
1264
+
1265
+ Arguments
1266
+ ---------
1267
+ context : ASRStreamingContext
1268
+ Mutable streaming context object, which should be the same object
1269
+ that was passed to `encode_chunk`.
1270
+
1271
+ x : torch.Tensor
1272
+ The output of `encode_chunk` for a given chunk.
1273
+
1274
+ Returns
1275
+ -------
1276
+ list of str
1277
+ Decoded tokens of length `batch_size`. The decoded strings can be
1278
+ of 0-length.
1279
+ list of list of output token hypotheses
1280
+ List of length `batch_size`, each holding a list of tokens of any
1281
+ length `>=0`.
1282
+ """
1283
+ tokens = self.hparams.decoding_function(x, context.decoder_context)
1284
+
1285
+ # initialize token context for real now that we know the batch size
1286
+ if context.tokenizer_context is None:
1287
+ context.tokenizer_context = [
1288
+ self.hparams.make_tokenizer_streaming_context()
1289
+ for _ in range(len(tokens))
1290
+ ]
1291
+
1292
+ words = [
1293
+ self.hparams.tokenizer_decode_streaming(
1294
+ self.hparams.tokenizer, cur_tokens, context.tokenizer_context[i]
1295
+ )
1296
+ for i, cur_tokens in enumerate(tokens)
1297
+ ]
1298
+
1299
+ return words, tokens
1300
+
1301
+ def transcribe_chunk(
1302
+ self,
1303
+ context: ASRStreamingContext,
1304
+ chunk: torch.Tensor,
1305
+ chunk_len: Optional[torch.Tensor] = None,
1306
+ ):
1307
+ """Transcription of a batch of audio chunks into transcribed text.
1308
+ Must be called over a given context in the correct order of chunks over
1309
+ time.
1310
+
1311
+ Arguments
1312
+ ---------
1313
+ context : ASRStreamingContext
1314
+ Mutable streaming context object, which must be specified and reused
1315
+ across calls when streaming.
1316
+ You can obtain an initial context by calling
1317
+ `asr.make_streaming_context(config)`.
1318
+ chunk : torch.Tensor
1319
+ The tensor for an audio chunk of shape `[batch size, time]`.
1320
+ The time dimension must strictly match
1321
+ `asr.get_chunk_size_frames(config)`.
1322
+ The waveform is expected to be in the model's expected format (i.e.
1323
+ the sampling rate must be correct).
1324
+ chunk_len : torch.Tensor, optional
1325
+ The relative chunk length tensor of shape `[batch size]`. This is to
1326
+ be used when the audio in one of the chunks of the batch is ending
1327
+ within this chunk.
1328
+ If unspecified, equivalent to `torch.ones((batch_size,))`.
1329
+
1330
+ Returns
1331
+ -------
1332
+ str
1333
+ Transcribed string for this chunk, might be of length zero.
1334
+ """
1335
+
1336
+ if chunk_len is None:
1337
+ chunk_len = torch.ones((chunk.size(0),))
1338
+
1339
+ chunk = chunk.float()
1340
+ chunk, chunk_len = chunk.to(self.device), chunk_len.to(self.device)
1341
+
1342
+ x = self.encode_chunk(context, chunk, chunk_len)
1343
+ words, _ = self.decode_chunk(context, x)
1344
+
1345
+ return words
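
SLU2.py mirrors speechbrain.inference.ASR with local modifications for this SLU setup: EncoderASR.encode_batch runs the wav2vec/dec/output_lin modules declared in hyperparams.yaml, and transcribe_batch reloads the label encoder from the hard-coded path 'sample_data/SLU/labelencoder.txt'. A minimal usage sketch, assuming the files of this commit are gathered in one local folder and SLU2.py is importable; the folder name and audio path are placeholders, not part of the upload:

    # Sketch only: load the uploaded SLU model and decode one utterance.
    from SLU2 import EncoderASR  # the class defined in SLU2.py above

    slu_model = EncoderASR.from_hparams(
        source="./SLU",                  # folder holding hyperparams.yaml and the *.ckpt files
        hparams_file="hyperparams.yaml",
        savedir="./SLU",
    )

    # Returns the CTC-decoded character/tag sequence (see labelencoder.txt for the label set).
    hypothesis = slu_model.transcribe_file("example.wav")
    print(hypothesis)

Note that transcribe_batch expects 'sample_data/SLU/labelencoder.txt' to exist relative to the working directory, so that path (or the working directory) may need adjusting.
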
brain.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33809a026a2c1febce7b03c8aafaee4ddfc851b2c70f180f8c06bf1017f4df5c
3
+ size 46
counter.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:95aebc97bc646c67fdcd923a5965b001f3c8a5c4d3a77075112e12a3a311d760
3
+ size 3
hyperparams.yaml ADDED
@@ -0,0 +1,170 @@
1
+ # Data parameters:
2
+ # With data_parallel batch_size is split into N jobs.
3
+ # With DDP batch_size is multiplied by N jobs.
4
+ batch_size: 6
5
+ test_batch_size: 2
6
+ # We remove utterances longer than 90s in the train/dev/test sets as
7
+ # longer sentences certainly correspond to "open microphones".
8
+ avoid_if_longer_than: 90.0
9
+ avoid_if_smaller_than: 0.0
10
+ dataloader_options:
11
+ batch_size: 6
12
+ num_workers: 6
13
+ shuffle: true
14
+ test_dataloader_options:
15
+ batch_size: 2
16
+ num_workers: 3
17
+
18
+ # Feature parameters:
19
+ sample_rate: 16000
20
+ feats_dim: 1024
21
+
22
+ # Training parameters:
23
+ number_of_epochs: 80
24
+ lr: 1
25
+ lr_wav2vec: 0.0001
26
+ annealing_factor: 0.8
27
+ annealing_factor_wav2vec: 0.9
28
+ improvement_threshold: 0.0025
29
+ improvement_threshold_wav2vec: 0.0025
30
+ patient: 0
31
+ patient_wav2vec: 0
32
+ sorting: random
33
+
34
+ # Model parameters:
35
+ activation: &id001 !name:torch.nn.LeakyReLU
36
+ dropout: 0.15
37
+ cnn_blocks: 0
38
+ rnn_layers: 0
39
+ dnn_blocks: 1
40
+ rnn_neurons: 0
41
+ dnn_neurons: 1024
42
+
43
+ # Wav2Vec parameters:
44
+ freeze: false
45
+
46
+ # Decoding parameters:
47
+ blank_index: 0
48
+
49
+ # Outputs:
50
+ output_neurons: 113
51
+
52
+ # ------ Functions and classes
53
+
54
+ epoch_counter: &id008 !new:speechbrain.utils.epoch_loop.EpochCounter
55
+
56
+ limit: 80
57
+
58
+ wav2vec: &id002 !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
59
+ source: microsoft/wavlm-large
60
+ output_norm: true
61
+ freeze: false
62
+ save_path: results/TARIC_SLU_wav2vec_wavLM_with_intent_criterion_a100_copie/1212/save/wav2vec.pt
63
+
64
+ dec: &id003 !new:speechbrain.lobes.models.VanillaNN.VanillaNN
65
+ input_shape: [null, null, 1024]
66
+ activation: *id001
67
+ dnn_blocks: 1
68
+ dnn_neurons: 1024
69
+
70
+ output_lin: &id004 !new:speechbrain.nnet.linear.Linear
71
+
72
+ input_size: 1024
73
+ n_neurons: 113
74
+ bias: true
75
+
76
+ softmax: !new:speechbrain.nnet.activations.Softmax
77
+ apply_log: true
78
+
79
+ ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
80
+ blank_index: 0
81
+
82
+ modules:
83
+ wav2vec: *id002
84
+ dec: *id003
85
+ output_lin: *id004
86
+ model: &id005 !new:torch.nn.ModuleList
87
+ - [*id003, *id004]
88
+ model_wav2vec: !new:torch.nn.ModuleList
89
+ - [*id002]
90
+ opt_class: !name:torch.optim.Adadelta
91
+ lr: 1
92
+ rho: 0.95
93
+ eps: 1.e-8
94
+
95
+ opt_class_wav2vec: !name:torch.optim.Adam
96
+ lr: 0.0001
97
+
98
+ lr_annealing: &id006 !new:speechbrain.nnet.schedulers.NewBobScheduler
99
+ initial_value: 1
100
+ improvement_threshold: 0.0025
101
+ annealing_factor: 0.8
102
+ patient: 0
103
+
104
+ lr_annealing_wav2vec: &id007 !new:speechbrain.nnet.schedulers.NewBobScheduler
105
+ initial_value: 0.0001
106
+ improvement_threshold: 0.0025
107
+ annealing_factor: 0.9
108
+ patient: 0
109
+
110
+ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
111
+ checkpoints_dir: results/TARIC_SLU_wav2vec_wavLM_with_intent_criterion_a100_copie/1212/save
112
+ recoverables:
113
+ model: *id005
114
+ wav2vec: *id002
115
+ lr_annealing: *id006
116
+ lr_annealing_wav2vec: *id007
117
+ counter: *id008
118
+ train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
119
+ save_file: results/TARIC_SLU_wav2vec_wavLM_with_intent_criterion_a100_copie/1212/train_log.txt
120
+
121
+ ctc_computer: !name:speechbrain.utils.metric_stats.MetricStats
122
+ metric: !name:speechbrain.nnet.losses.ctc_loss
123
+ blank_index: 0
124
+ reduction: batch
125
+
126
+ error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
127
+
128
+ cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
129
+ merge_tokens: true
130
+
131
+ coer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
132
+ extract_concepts_values: true
133
+ keep_values: false
134
+ tag_in: <
135
+ tag_out: >
136
+
137
+ cver_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
138
+ extract_concepts_values: true
139
+ keep_values: true
140
+ tag_in: <
141
+ tag_out: >
142
+
143
+ tokenizer: !new:speechbrain.dataio.encoder.CTCTextEncoder
144
+
145
+ pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
146
+ loadables:
147
+ model: !ref <model>
148
+ wav2vec: !ref <wav2vec>
149
+ tokenizer: !ref <tokenizer>
150
+ paths:
151
+ model: !ref /content/sample_data/SLU/model.cpkt
152
+ wav2vec: !ref /content/sample_data/SLU/wav2vec.cpkt
153
+ tokenizer: !ref /content/sample_data/SLU/label_encoder.txt
154
+
155
+ decoding_function: !name:speechbrain.decoders.ctc_greedy_decode
156
+ blank_id: 0
157
+
158
+ # Tag list:
159
+ tag_list: <politeness>, <directives_query>, <directives_answer>, <age>, <age_req>,
160
+ <age_ticket>, <an>, <answer>, <arrival_time>, <card_price>, <card_type>, <city>,
161
+ <city_name_arrival>, <city_name_before>, <city_name_departure>, <city_name_direction>,
162
+ <class_number>, <class_type>, <command_task>, <comparatif_age>, <comparatif_distance>,
163
+ <comparatif_price>, <comparatif_time>, <coreference_city>, <coreference_departure>,
164
+ <date>, <day>, <departure_time>, <discount_gain>, <discount_pourcent>, <duration>,
165
+ <duration_req>, <existance>, <existance_req>, <hour_req>, <money_exchange>, <month>,
166
+ <negation>, <number>, <number_class>, <number_of_train>, <number_req>, <object>,
167
+ <option>, <other_transport>, <part_price>, <part_time>, <period_day>, <period_year>,
168
+ <person_name>, <price_req>, <rang>, <ref_object>, <ref_person>, <ref_time>, <relative_day>,
169
+ <relative_time>, <state>, <tarif>, <task>, <ticket_number>, <ticket_price>, <ticket_type>,
170
+ <time>, <train_type>
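
In this configuration the model emits posteriors over output_neurons: 113 CTC labels (the characters and semantic tags listed in labelencoder.txt below), and decoding_function applies speechbrain.decoders.ctc_greedy_decode with blank_id 0, exactly as called in EncoderASR.transcribe_batch. A small sketch of that decoding step with toy posteriors (the tensors are placeholders, not real model output):

    # Sketch of the greedy CTC decoding configured above (toy inputs).
    import torch
    from speechbrain.decoders import ctc_greedy_decode

    batch, time, n_labels = 2, 50, 113                    # output_neurons: 113
    log_probs = torch.randn(batch, time, n_labels).log_softmax(dim=-1)
    rel_lens = torch.tensor([1.0, 0.8])                   # relative lengths, as in transcribe_batch

    # Collapses repeated labels and drops the blank (blank_index: 0),
    # returning one list of label ids per utterance.
    hyps = ctc_greedy_decode(log_probs, rel_lens, blank_id=0)
    print(hyps[0][:10])
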
labelencoder.txt ADDED
@@ -0,0 +1,113 @@
1
+ '<politeness>' => 109
2
+ '_' => 1
3
+ 'A' => 2
4
+ 'y' => 3
5
+ 't' => 4
6
+ 'f' => 5
7
+ 'D' => 6
8
+ 'l' => 7
9
+ 'x' => 8
10
+ 'w' => 9
11
+ '<directives_query>' => 10
12
+ 'm' => 11
13
+ 'E' => 12
14
+ '<hour_req>' => 13
15
+ 'q' => 14
16
+ '3' => 15
17
+ '>' => 16
18
+ 'b' => 17
19
+ 'h' => 18
20
+ '<object>' => 19
21
+ 'r' => 20
22
+ 'n' => 21
23
+ '<directives_answer>' => 22
24
+ '<departure_time>' => 23
25
+ 's' => 24
26
+ '<existance_req>' => 25
27
+ 'v' => 26
28
+ '<ref_object>' => 27
29
+ 'H' => 28
30
+ 'd' => 29
31
+ '<relative_time>' => 30
32
+ '<answer>' => 31
33
+ 'k' => 32
34
+ 'ç' => 33
35
+ '<coreference_departure>' => 34
36
+ '<existance>' => 35
37
+ '<ticket_number>' => 36
38
+ 'z' => 37
39
+ '<city_name_arrival>' => 38
40
+ 'S' => 39
41
+ 'j' => 40
42
+ '<train_type>' => 41
43
+ '9' => 42
44
+ 'g' => 43
45
+ '<arrival_time>' => 44
46
+ '<command_task>' => 45
47
+ 'T' => 46
48
+ '<ticket_price>' => 47
49
+ '<discount_gain>' => 48
50
+ '<discount_pourcent>' => 49
51
+ '<number_of_train>' => 50
52
+ '<person_name>' => 51
53
+ '<comparatif_time>' => 52
54
+ '<card_type>' => 53
55
+ '<relative_day>' => 54
56
+ '<negation>' => 55
57
+ '<price_req>' => 56
58
+ '<class_type>' => 57
59
+ '<money_exchange>' => 58
60
+ '<card_price>' => 59
61
+ '<ticket_type>' => 60
62
+ '<city_name_direction>' => 61
63
+ '<other_transport>' => 62
64
+ 'Z' => 63
65
+ '7' => 64
66
+ '<age_ticket>' => 65
67
+ '<comparatif_age>' => 66
68
+ '<age>' => 67
69
+ '<tarif>' => 68
70
+ '<rang>' => 69
71
+ '<part_time>' => 70
72
+ '<period_day>' => 71
73
+ '<duration_req>' => 72
74
+ '<number>' => 73
75
+ '<part_price>' => 74
76
+ 'ڥ' => 75
77
+ '<day>' => 76
78
+ '<coreference_city>' => 77
79
+ '<ref_time>' => 78
80
+ '<state>' => 79
81
+ '<city_name_departure>' => 80
82
+ '<comparatif_price>' => 81
83
+ '<duration>' => 82
84
+ '.' => 83
85
+ '<city_name_before>' => 84
86
+ '<date>' => 85
87
+ '<ref_person>' => 86
88
+ '<comparatif_distance>' => 87
89
+ '<number_req>' => 88
90
+ '<age_req>' => 89
91
+ '<option>' => 90
92
+ '<time>' => 91
93
+ '<an>' => 92
94
+ '<period_year>' => 93
95
+ '<month>' => 94
96
+ '$' => 95
97
+ 'i' => 96
98
+ 'e' => 97
99
+ 'c' => 98
100
+ 'u' => 99
101
+ 'a' => 100
102
+ 'p' => 101
103
+ 'o' => 102
104
+ '<class_number>' => 103
105
+ '<directives_answer_request>' => 104
106
+ '<task>' => 105
107
+ '<city>' => 106
108
+ '<directives_request>' => 107
109
+ '<number_class>' => 108
110
+ '<blank>' => 0
111
+ ================
112
+ 'starting_index' => 0
113
+ 'blank_label' => '<blank>'
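
labelencoder.txt is the CTCTextEncoder vocabulary: each output unit of the CTC head (transliteration characters plus semantic tags such as <politeness> or <city_name_arrival>) is mapped to an integer id, with id 0 reserved for the CTC blank. It is the same file that EncoderASR.transcribe_batch loads before decoding. A minimal sketch of loading it and turning ids back into labels (the local path is an assumption):

    # Load the label encoder shipped in this commit and decode a few ids.
    from speechbrain.dataio.encoder import CTCTextEncoder

    tokenizer = CTCTextEncoder()
    tokenizer.load("labelencoder.txt")

    ids = [109, 1, 11, 20, 18, 17, 100]         # <politeness> _ m r h b a
    print("".join(tokenizer.decode_ndim(ids)))  # -> "<politeness>_mrhba"

decode_ndim on a flat list of ids returns the corresponding list of labels, which is how transcribe_batch builds its output strings for this tokenizer.
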
lr_annealing.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c4ea943b3cc3d6c91aa6843cf37362ffcad693e8f4cddfb85159458cc445598
3
+ size 697
lr_annealing_wav2vec.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9043595d8cb86f5dc698ec4c3880a6eba4ba0994c1389703069a1ddac323e905
3
+ size 713
model.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:94ad8f0789775a5708c8a5c365e1f5d7442270963566248075043d606570884d
3
+ size 4663251
optimizer.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a18feb3922345456cb19d72567f0145816f4e7936d4e07917d35e50103c7bd0
3
+ size 9326243
optimizer_wav2vec.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c2acedf6d0996452544892ba315e242de4ef1bb38fef3609e355a1b7d3e51903
3
+ size 2524050533
wav2vec.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5e85d339d968c46bb6acb664586d8a11fcfa247f7f77546735a040649a47d8f4
3
+ size 1262004913