anonymoussubmitter222 committed on
Commit
d081411
1 Parent(s): 7cda826

first commit

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. EnglishCV/common_voice_prepare.py +410 -0
  3. EnglishCV/results/final_cs/hyperparams.yaml +144 -0
  4. EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/CKPT.yaml +4 -0
  5. EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/brain.ckpt +3 -0
  6. EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/counter.ckpt +3 -0
  7. EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/dataloader-TRAIN.ckpt +3 -0
  8. EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/model.ckpt +3 -0
  9. EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/modelopt.ckpt +3 -0
  10. EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/scheduler_encoder.ckpt +3 -0
  11. EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/scheduler_model.ckpt +3 -0
  12. EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/tokenizer.ckpt +3 -0
  13. EnglishCV/results/final_cs/save/label_encoder.txt +80 -0
  14. EnglishCV/results/final_cs/train_mixer.py +756 -0
  15. EnglishCV/results/wav2vec2_ctc_en/1234/hyperparams.yaml +190 -0
  16. EnglishCV/results/wav2vec2_ctc_en/1234/save/29_char.model +3 -0
  17. EnglishCV/results/wav2vec2_ctc_en/1234/save/29_char.vocab +28 -0
  18. EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/CKPT.yaml +4 -0
  19. EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/brain.ckpt +3 -0
  20. EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/counter.ckpt +3 -0
  21. EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/dataloader-TRAIN.ckpt +3 -0
  22. EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/model.ckpt +3 -0
  23. EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/modelopt.ckpt +3 -0
  24. EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/scheduler_model.ckpt +3 -0
  25. EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/scheduler_wav2vec.ckpt +3 -0
  26. EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/wav2vec2.ckpt +3 -0
  27. EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/wav2vec_opt.ckpt +3 -0
  28. EnglishCV/results/wav2vec2_ctc_en/1234/train_with_wav2vec.py +388 -0
  29. EnglishCV/train_en_with_wav2vec.yaml +184 -0
  30. EnglishCV/train_with_wav2vec.py +388 -0
  31. README.md +17 -9
  32. TunisianASR/README.md +21 -0
  33. TunisianASR/results/14epoch_tunisian/1234/env.log +347 -0
  34. TunisianASR/results/14epoch_tunisian/1234/hyperparams.yaml +194 -0
  35. TunisianASR/results/14epoch_tunisian/1234/log.txt +359 -0
  36. TunisianASR/results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00/CKPT.yaml +4 -0
  37. TunisianASR/results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00/brain.ckpt +3 -0
  38. TunisianASR/results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00/counter.ckpt +3 -0
  39. TunisianASR/results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00/dataloader-TRAIN.ckpt +3 -0
  40. TunisianASR/results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00/model.ckpt +3 -0
  41. TunisianASR/results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00/modelopt.ckpt +3 -0
  42. TunisianASR/results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00/scheduler_model.ckpt +3 -0
  43. TunisianASR/results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00/scheduler_wav2vec.ckpt +3 -0
  44. TunisianASR/results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00/wav2vec2.ckpt +3 -0
  45. TunisianASR/results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00/wav2vec_opt.ckpt +3 -0
  46. TunisianASR/results/14epoch_tunisian/1234/save/label_encoder.txt +44 -0
  47. TunisianASR/results/14epoch_tunisian/1234/train_with_wav2vec.py +399 -0
  48. TunisianASR/results/14epoch_tunisian/<seed>/copy_of_wavlm_tun.py +761 -0
  49. TunisianASR/results/14epoch_tunisian/<seed>/ctc_lin.py +756 -0
  50. TunisianASR/results/14epoch_tunisian/<seed>/env.log +97 -0
.gitattributes CHANGED
@@ -30,6 +30,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.tgz filter=lfs diff=lfs merge=lfs -text
  *.wasm filter=lfs diff=lfs merge=lfs -text
  *.xz filter=lfs diff=lfs merge=lfs -text
+ *.arpa filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
EnglishCV/common_voice_prepare.py ADDED
@@ -0,0 +1,410 @@
1
+ """
2
+ Data preparation.
3
+ Download: https://voice.mozilla.org/en/datasets
4
+ Author
5
+ ------
6
+ Titouan Parcollet
7
+ Luca Della Libera 2022
8
+ Pooneh Mousavi 2022
9
+ """
10
+
11
+ from dataclasses import dataclass
12
+ import os
13
+ import csv
14
+ import re
15
+ import logging
16
+ import torchaudio
17
+ from tqdm import tqdm
18
+ import unicodedata
19
+ import functools
20
+ torchaudio.set_audio_backend("soundfile")
21
+ from speechbrain.utils.parallel import parallel_map
22
+ from speechbrain.dataio.dataio import read_audio_info
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ def prepare_common_voice(
28
+ data_folder,
29
+ save_folder,
30
+ train_tsv_file=None,
31
+ dev_tsv_file=None,
32
+ test_tsv_file=None,
33
+ accented_letters=False,
34
+ language="en",
35
+ skip_prep=False,
36
+ ):
37
+ """
38
+ Prepares the csv files for the Mozilla Common Voice dataset.
39
+ Download: https://voice.mozilla.org/en/datasets
40
+ Arguments
41
+ ---------
42
+ data_folder : str
43
+ Path to the folder where the original Common Voice dataset is stored.
44
+ This path should include the lang: /datasets/CommonVoice/<language>/
45
+ save_folder : str
46
+ The directory where to store the csv files.
47
+ train_tsv_file : str, optional
48
+ Path to the Train Common Voice .tsv file (cs)
49
+ dev_tsv_file : str, optional
50
+ Path to the Dev Common Voice .tsv file (cs)
51
+ test_tsv_file : str, optional
52
+ Path to the Test Common Voice .tsv file (cs)
53
+ accented_letters : bool, optional
54
+ Defines if accented letters will be kept as individual letters or
55
+ transformed to the closest non-accented letters.
56
+ language: str
57
+ Specify the language for text normalization.
58
+ skip_prep: bool
59
+ If True, skip data preparation.
60
+ Example
61
+ -------
62
+ >>> from recipes.CommonVoice.common_voice_prepare import prepare_common_voice
63
+ >>> data_folder = '/datasets/CommonVoice/en'
64
+ >>> save_folder = 'exp/CommonVoice_exp'
65
+ >>> train_tsv_file = '/datasets/CommonVoice/en/train.tsv'
66
+ >>> dev_tsv_file = '/datasets/CommonVoice/en/dev.tsv'
67
+ >>> test_tsv_file = '/datasets/CommonVoice/en/test.tsv'
68
+ >>> accented_letters = False
69
+ >>> duration_threshold = 10
70
+ >>> prepare_common_voice( \
71
+ data_folder, \
72
+ save_folder, \
73
+ train_tsv_file, \
74
+ dev_tsv_file, \
75
+ test_tsv_file, \
76
+ accented_letters, \
77
+ language="en" \
78
+ )
79
+ """
80
+
81
+ if skip_prep:
82
+ return
83
+
84
+ # If not specified point toward standard location w.r.t CommonVoice tree
85
+ if train_tsv_file is None:
86
+ train_tsv_file = data_folder + "/train.tsv"
87
+ else:
88
+ train_tsv_file = train_tsv_file
89
+
90
+ if dev_tsv_file is None:
91
+ dev_tsv_file = data_folder + "/dev.tsv"
92
+ else:
93
+ dev_tsv_file = dev_tsv_file
94
+
95
+ if test_tsv_file is None:
96
+ test_tsv_file = data_folder + "/test.tsv"
97
+ else:
98
+ test_tsv_file = test_tsv_file
99
+
100
+ # Setting the save folder
101
+ if not os.path.exists(save_folder):
102
+ os.makedirs(save_folder)
103
+
104
+ # Setting output files
105
+ save_csv_train = save_folder + "/train.csv"
106
+ save_csv_dev = save_folder + "/dev.csv"
107
+ save_csv_test = save_folder + "/test.csv"
108
+
109
+ # If csv already exists, we skip the data preparation
110
+ if skip(save_csv_train, save_csv_dev, save_csv_test):
111
+
112
+ msg = "%s already exists, skipping data preparation!" % (save_csv_train)
113
+ logger.info(msg)
114
+
115
+ msg = "%s already exists, skipping data preparation!" % (save_csv_dev)
116
+ logger.info(msg)
117
+
118
+ msg = "%s already exists, skipping data preparation!" % (save_csv_test)
119
+ logger.info(msg)
120
+
121
+ return
122
+
123
+ # Additional checks to make sure the data folder contains Common Voice
124
+ check_commonvoice_folders(data_folder)
125
+ # Creating csv files for {train, dev, test} data
126
+ file_pairs = zip(
127
+ [train_tsv_file, dev_tsv_file, test_tsv_file],
128
+ [save_csv_train, save_csv_dev, save_csv_test],
129
+ )
130
+ for tsv_file, save_csv in file_pairs:
131
+ create_csv(
132
+ tsv_file, save_csv, data_folder, accented_letters, language,
133
+ )
134
+
135
+
136
+ def skip(save_csv_train, save_csv_dev, save_csv_test):
137
+ """
138
+ Detects if the Common Voice data preparation has been already done.
139
+ If the preparation has been done, we can skip it.
140
+ Returns
141
+ -------
142
+ bool
143
+ if True, the preparation phase can be skipped.
144
+ if False, it must be done.
145
+ """
146
+
147
+ # Checking folders and save options
148
+ skip = False
149
+
150
+ if (
151
+ os.path.isfile(save_csv_train)
152
+ and os.path.isfile(save_csv_dev)
153
+ and os.path.isfile(save_csv_test)
154
+ ):
155
+ skip = True
156
+
157
+ return skip
158
+
159
+
160
+ @dataclass
161
+ class CVRow:
162
+ snt_id: str
163
+ duration: float
164
+ mp3_path: str
165
+ spk_id: str
166
+ words: str
167
+
168
+
169
+ def process_line(line, data_folder, language, accented_letters):
170
+ # The audio path is at index 1 in Common Voice tsv files, and the .mp3 files
171
+ # are located in datasets/lang/clips/
172
+ mp3_path = data_folder + "/clips/" + line.split("\t")[1]
173
+ file_name = mp3_path.split(".")[-2].split("/")[-1]
174
+ spk_id = line.split("\t")[0]
175
+ snt_id = file_name
176
+
177
+ # Setting torchaudio backend to sox-io (needed to read mp3 files)
178
+ """
179
+ if torchaudio.get_audio_backend() != "sox_io":
180
+ logger.warning("This recipe needs the sox-io backend of torchaudio")
181
+ logger.warning("The torchaudio backend is changed to sox_io")
182
+ torchaudio.set_audio_backend("sox_io")
183
+ """
184
+ # Reading the signal (to retrieve duration in seconds)
185
+ if os.path.isfile(mp3_path):
186
+ info = read_audio_info(mp3_path)
187
+ else:
188
+ msg = "\tError loading: %s" % (str(file_name))
189
+ logger.info(msg)
190
+ return None
191
+
192
+ duration = info.num_frames / info.sample_rate
193
+
194
+ # Getting transcript
195
+ words = line.split("\t")[2]
196
+
197
+ # Unicode Normalization
198
+ words = unicode_normalisation(words)
199
+
200
+ # !! Language specific cleaning !!
201
+ words = language_specific_preprocess(language, words)
202
+
203
+ # Remove accents if specified
204
+ if not accented_letters:
205
+ words = strip_accents(words)
206
+ words = words.replace("'", " ")
207
+ words = words.replace("’", " ")
208
+
209
+ # Remove multiple spaces
210
+ words = re.sub(" +", " ", words)
211
+
212
+ # Remove spaces at the beginning and the end of the sentence
213
+ words = words.lstrip().rstrip()
214
+
215
+ # Getting chars
216
+ chars = words.replace(" ", "_")
217
+ chars = " ".join([char for char in chars][:])
218
+
219
+ # Remove too short sentences (or empty):
220
+ if language in ["ja", "ch"]:
221
+ if len(chars) < 3:
222
+ return None
223
+ else:
224
+ if len(words.split(" ")) < 3:
225
+ return None
226
+
227
+ # Composition of the csv_line
228
+ return CVRow(snt_id, duration, mp3_path, spk_id, words)
229
+
230
+
231
+ def create_csv(
232
+ orig_tsv_file, csv_file, data_folder, accented_letters=False, language="en"
233
+ ):
234
+ """
235
+ Creates the csv file given a list of wav files.
236
+ Arguments
237
+ ---------
238
+ orig_tsv_file : str
239
+ Path to the Common Voice tsv file (standard file).
240
+ data_folder : str
241
+ Path of the CommonVoice dataset.
242
+ accented_letters : bool, optional
243
+ Defines if accented letters will be kept as individual letters or
244
+ transformed to the closest non-accented letters.
245
+ Returns
246
+ -------
247
+ None
248
+ """
249
+
250
+ # Check if the given files exists
251
+ if not os.path.isfile(orig_tsv_file):
252
+ msg = "\t%s doesn't exist, verify your dataset!" % (orig_tsv_file)
253
+ logger.info(msg)
254
+ raise FileNotFoundError(msg)
255
+
256
+ # We load and skip the header
257
+ loaded_csv = open(orig_tsv_file, "r").readlines()[1:]
258
+ nb_samples = len(loaded_csv)
259
+
260
+ msg = "Preparing CSV files for %s samples ..." % (str(nb_samples))
261
+ logger.info(msg)
262
+
263
+ # Adding some Prints
264
+ msg = "Creating csv lists in %s ..." % (csv_file)
265
+ logger.info(msg)
266
+
267
+ # Process and write lines
268
+ total_duration = 0.0
269
+
270
+ line_processor = functools.partial(
271
+ process_line,
272
+ data_folder=data_folder,
273
+ language=language,
274
+ accented_letters=accented_letters,
275
+ )
276
+
277
+ # Stream into a .tmp file, and rename it to the real path at the end.
278
+ csv_file_tmp = csv_file + ".tmp"
279
+
280
+ with open(csv_file_tmp, mode="w", encoding="utf-8") as csv_f:
281
+ csv_writer = csv.writer(
282
+ csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL
283
+ )
284
+
285
+ csv_writer.writerow(["ID", "duration", "wav", "spk_id", "wrd"])
286
+ for line in tqdm(loaded_csv) :
287
+
288
+ row = line_processor(line)
289
+ if row is not None :
290
+ total_duration += row.duration
291
+ csv_writer.writerow(
292
+ [
293
+ row.snt_id,
294
+ str(row.duration),
295
+ row.mp3_path,
296
+ row.spk_id,
297
+ row.words,
298
+ ]
299
+ )
300
+
301
+ os.replace(csv_file_tmp, csv_file)
302
+
303
+ # Final prints
304
+ msg = "%s successfully created!" % (csv_file)
305
+ logger.info(msg)
306
+ msg = "Number of samples: %s " % (str(len(loaded_csv)))
307
+ logger.info(msg)
308
+ msg = "Total duration: %s Hours" % (str(round(total_duration / 3600, 2)))
309
+ logger.info(msg)
310
+
311
+
312
+ def language_specific_preprocess(language, words):
313
+ # !! Language specific cleaning !!
314
+ # Important: feel free to specify the text normalization
315
+ # corresponding to your alphabet.
316
+
317
+ if language in ["en", "fr", "it", "rw"]:
318
+ words = re.sub(
319
+ "[^’'A-Za-z0-9À-ÖØ-öø-ÿЀ-ӿéæœâçèàûî]+", " ", words
320
+ ).upper()
321
+
322
+ if language == "de":
323
+ # this replacement helps preserve the case of ß
324
+ # (and helps retain solitary occurrences of SS)
325
+ # since python's upper() converts ß to SS.
326
+ words = words.replace("ß", "0000ß0000")
327
+ words = re.sub("[^’'A-Za-z0-9öÖäÄüÜß]+", " ", words).upper()
328
+ words = words.replace("'", " ")
329
+ words = words.replace("’", " ")
330
+ words = words.replace(
331
+ "0000SS0000", "ß"
332
+ ) # restore ß where it originally occurred in the corpus
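+ # e.g. "Straße" -> "Stra0000ß0000e" -> upper() -> "STRA0000SS0000E" -> "STRAßE"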
333
+
334
+ if language == "fr":
335
+ # Replace J'y D'hui etc by J_ D_hui
336
+ words = words.replace("'", " ")
337
+ words = words.replace("’", " ")
338
+
339
+ elif language == "ar":
340
+ HAMZA = "\u0621"
341
+ ALEF_MADDA = "\u0622"
342
+ ALEF_HAMZA_ABOVE = "\u0623"
343
+ letters = (
344
+ "ابتةثجحخدذرزژشسصضطظعغفقكلمنهويىءآأؤإئ"
345
+ + HAMZA
346
+ + ALEF_MADDA
347
+ + ALEF_HAMZA_ABOVE
348
+ )
349
+ words = re.sub("[^" + letters + " ]+", "", words).upper()
350
+ elif language == "fa":
351
+ HAMZA = "\u0621"
352
+ ALEF_MADDA = "\u0622"
353
+ ALEF_HAMZA_ABOVE = "\u0623"
354
+ letters = (
355
+ "ابپتةثجحخچدذرزژسشصضطظعغفقگکلمنهویىءآأؤإئ"
356
+ + HAMZA
357
+ + ALEF_MADDA
358
+ + ALEF_HAMZA_ABOVE
359
+ )
360
+ words = re.sub("[^" + letters + " ]+", "", words).upper()
361
+ elif language == "ga-IE":
362
+ # Irish lower() is complicated, but upper() is nondeterministic, so use lowercase
363
+ def pfxuc(a):
364
+ return len(a) >= 2 and a[0] in "tn" and a[1] in "AEIOUÁÉÍÓÚ"
365
+
366
+ def galc(w):
367
+ return w.lower() if not pfxuc(w) else w[0] + "-" + w[1:].lower()
368
+
369
+ words = re.sub("[^-A-Za-z'ÁÉÍÓÚáéíóú]+", " ", words)
370
+ words = " ".join(map(galc, words.split(" ")))
371
+ elif language == "es":
372
+ # Fix the following error in dataset large:
373
+ # KeyError: 'The item En noviembre lanzaron Queen Elizabeth , coproducida por Foreign Noi$e . requires replacements which were not supplied.'
374
+ words = words.replace("$", "s")
375
+ return words
376
+
377
+
378
+ def check_commonvoice_folders(data_folder):
379
+ """
380
+ Check if the data folder actually contains the Common Voice dataset.
381
+ If not, raises an error.
382
+ Returns
383
+ -------
384
+ None
385
+ Raises
386
+ ------
387
+ FileNotFoundError
388
+ If data folder doesn't contain Common Voice dataset.
389
+ """
390
+ files_str = "/clips"
391
+ # Checking clips
392
+ if not os.path.exists(data_folder + files_str):
393
+ err_msg = (
394
+ "the folder %s does not exist (it is expected in "
395
+ "the Common Voice dataset)" % (data_folder + files_str)
396
+ )
397
+ raise FileNotFoundError(err_msg)
398
+
399
+
400
+ def unicode_normalisation(text):
401
+ return str(text)
402
+
403
+
404
+ def strip_accents(text):
405
+ text = (
406
+ unicodedata.normalize("NFD", text)
407
+ .encode("ascii", "ignore")
408
+ .decode("utf-8")
409
+ )
410
+ return str(text)
EnglishCV/results/final_cs/hyperparams.yaml ADDED
@@ -0,0 +1,144 @@
1
+ # Generated 2023-09-08 from:
2
+ # /gpfsssd/scratch/rech/nou/uzn19yk/switched_data/stac.yaml
3
+ # yamllint disable
4
+ # Generated 2023-08-03 from:
5
+ # /home/salah/new_tunisian_model/hparams/train_tunisian_withwavlm.yaml
6
+ # yamllint disable
7
+ # ################################
8
+ # Model: wav2vec2 + DNN + CTC
9
+ # Augmentation: SpecAugment
10
+ # Authors: Titouan Parcollet 2021
11
+ # ################################
12
+
13
+ seed: 1994
14
+ __set_seed: !!python/object/apply:torch.manual_seed [1234]
15
+ output_folder: results/non_semi_final_stac
16
+ wer_file: results/non_semi_final_stac/wer.txt
17
+ save_folder: results/non_semi_final_stac/save
18
+ train_log: results/non_semi_final_stac/train_log.txt
19
+
20
+
21
+
22
+ # Data files
23
+ data_folder: junk # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
24
+ train_tsv_file: junk/train.tsv # Standard CommonVoice .tsv files
25
+ dev_tsv_file: junk/dev.tsv # Standard CommonVoice .tsv files
26
+ test_tsv_file: junk/test.tsv # Standard CommonVoice .tsv files
27
+ accented_letters: true
28
+
29
+ csv_folder: /gpfsscratch/rech/nou/uzn19yk/switched_data/extended_clean/
30
+ train_csv: /gpfsscratch/rech/nou/uzn19yk/switched_data/extended_clean//train.csv
31
+ valid_csv: /gpfsscratch/rech/nou/uzn19yk/switched_data/extended_clean//dev.csv
32
+ test_csv:
33
+ - all_tests/cs_test.csv
34
+ - all_tests/stac_test.csv
35
+
36
+ # We remove utterance slonger than 10s in the train/dev/test sets as
37
+ # longer sentences certainly correspond to "open microphones".
38
+ avoid_if_longer_than: 13.0
39
+ avoid_if_shorter_than: 0.5
40
+
41
+ # Training parameters
42
+ number_of_epochs: 20
43
+ lr: 0.0002
44
+ lr_weights: 0.01
45
+ sorting: ascending
46
+ auto_mix_prec: false
47
+ sample_rate: 16000
48
+ language_modelling: true
49
+ ngram_lm_path:
50
+ /gpfsstore/rech/nou/uzn19yk/switched_code_tunisian/train/tunisian_asr/arpas/pluslanguages_everything.arpa
51
+
52
+ # With data_parallel batch_size is split into N jobs
53
+ # With DDP batch_size is multiplied by N jobs
54
+ # Must be 3 per GPU to fit 32GB of VRAM
55
+ batch_size: 3
56
+ test_batch_size: 4
57
+
58
+ # Dataloader options
59
+ dataloader_options:
60
+ batch_size: 3
61
+ num_workers: 6
62
+
63
+ test_dataloader_options:
64
+ batch_size: 4
65
+ num_workers: 6
66
+
67
+ # Model parameters
68
+ activation: !name:torch.nn.Sigmoid
69
+ dnn_layers: 1
70
+ dnn_neurons: 768
71
+ freeze_encoder: true
72
+
73
+ # Outputs
74
+ output_neurons: 76 # BPE size, index(blank/eos/bos) = 0
75
+
76
+ # Functions and classes
77
+ #
78
+ epoch_counter: &id006 !new:speechbrain.utils.epoch_loop.EpochCounter
79
+ limit: 20
80
+
81
+ encoder_dim: 3217
82
+ enc: &id001 !new:speechbrain.nnet.RNN.LSTM
83
+ input_shape: [null, null, 3217]
84
+ num_layers: 2
85
+ bidirectional: true
86
+ dropout: 0.2
87
+ hidden_size: 1024
88
+
89
+ ctc_lin: &id002 !new:speechbrain.nnet.linear.Linear
90
+
91
+ input_size: 2048
92
+ n_neurons: 76
93
+
94
+ log_softmax: !new:speechbrain.nnet.activations.Softmax
95
+ apply_log: true
96
+
97
+ ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
98
+ blank_index: 0
99
+
100
+ modules:
101
+ enc: *id001
102
+ ctc_lin: *id002
103
+ model: &id003 !new:torch.nn.ModuleList
104
+ - [*id001, *id002]
105
+ model_opt_class: !name:torch.optim.Adam
106
+ lr: 0.0002
107
+
108
+ weights_opt_class: !name:torch.optim.Adam
109
+ lr: 0.01
110
+
111
+ lr_annealing_model: &id004 !new:speechbrain.nnet.schedulers.NewBobScheduler
112
+ initial_value: 0.0002
113
+ improvement_threshold: 0.0025
114
+ annealing_factor: 0.8
115
+ patient: 0
116
+
117
+ lr_annealing_weights: &id005 !new:speechbrain.nnet.schedulers.NewBobScheduler
118
+ initial_value: 0.01
119
+ improvement_threshold: 0.0025
120
+ annealing_factor: 0.9
121
+ patient: 0
122
+
123
+ label_encoder: &id007 !new:speechbrain.dataio.encoder.CTCTextEncoder
124
+
125
+
126
+ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
127
+ checkpoints_dir: results/non_semi_final_stac/save
128
+ recoverables:
129
+ model: *id003
130
+ scheduler_model: *id004
131
+ scheduler_encoder: *id005
132
+ counter: *id006
133
+ tokenizer: *id007
134
+ blank_index: 0
135
+ unk_index: 1
136
+
137
+
138
+ train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
139
+ save_file: results/non_semi_final_stac/train_log.txt
140
+
141
+ error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
142
+
143
+ cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
144
+ split_tokens: true
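For orientation, a minimal sketch, not part of the commit, of how SpeechBrain consumes a hyperparams file like the one above; it simply mirrors the pattern used in train_mixer.py further down, and the file name hyperparams.yaml is an assumption:

import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

# the first argument is the hyperparams file (name assumed for this sketch)
hparams_file, run_opts, overrides = sb.parse_arguments(["hyperparams.yaml"])
with open(hparams_file) as fin:
    # instantiates the enc, ctc_lin, scheduler and checkpointer objects declared above
    hparams = load_hyperpyyaml(fin, overrides)

# the checkpointer can then restore the CKPT+2023-09-08+01-40-18+00 folder shown below
hparams["checkpointer"].recover_if_possible()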
EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/CKPT.yaml ADDED
@@ -0,0 +1,4 @@
+ # yamllint disable
+ WER: 51.292116454039906
+ end-of-epoch: true
+ unixtime: 1694130018.9642384
EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/brain.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e5c026fe6fa51700406bd476e131950c797b0b3bacb3daae0854e85689bb4cf9
+ size 50
EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/counter.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f5ca38f748a1d6eaf726b8a42fb575c3c71f1864a8143301782de13da2d9202b
+ size 2
EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/dataloader-TRAIN.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d7e1edcac43af8cea1439d222314af06354ae31da6a3d90b8cc6bcebc5c8e397
+ size 4
EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/model.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:da683a8efa5709a06af9b258452c243da841780a0a7942c196c472a3e21e5010
+ size 240389017
EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/modelopt.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:416feb314443cf839f4425fc382e555dec90e3dea26fa52b75e4ac1b702c5078
+ size 480787579
EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/scheduler_encoder.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2e2efd50f0cf28a080e2625fdd8a1852c669841537cdc0a57fce60bc6c1eec11
+ size 515
EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/scheduler_model.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cec54cc9236fa7aa965b397675d24299b973675cc0c6345de038fc70e51629ab
+ size 703
EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/tokenizer.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:21080a140faeb4f39fad188aaf081914ec782be9c4320d6415e8822709e18017
+ size 39
EnglishCV/results/final_cs/save/label_encoder.txt ADDED
@@ -0,0 +1,80 @@
1
+ 'و' => 74
2
+ 'ي' => 1
3
+ 'ن' => 2
4
+ ' ' => 3
5
+ 'م' => 4
6
+ 'ش' => 5
7
+ 'ل' => 6
8
+ 'س' => 7
9
+ 'ت' => 8
10
+ 'ا' => 9
11
+ 'د' => 10
12
+ 'ر' => 11
13
+ 'ى' => 12
14
+ 'ب' => 13
15
+ 'ح' => 14
16
+ 'ط' => 15
17
+ 'ع' => 16
18
+ 'ك' => 17
19
+ 'ف' => 18
20
+ 'ق' => 19
21
+ 'ذ' => 20
22
+ 'ث' => 21
23
+ 'ج' => 22
24
+ 'ة' => 23
25
+ 'غ' => 24
26
+ 'o' => 25
27
+ 'k' => 26
28
+ 'b' => 27
29
+ 'n' => 28
30
+ 'خ' => 29
31
+ 'ه' => 30
32
+ 'v' => 31
33
+ 'i' => 32
34
+ 'l' => 33
35
+ 'à' => 34
36
+ 'ص' => 35
37
+ 'ض' => 36
38
+ 'a' => 37
39
+ 'u' => 38
40
+ 't' => 39
41
+ 'm' => 40
42
+ 'q' => 41
43
+ 'e' => 42
44
+ 'd' => 43
45
+ 'c' => 44
46
+ 'p' => 45
47
+ 'r' => 46
48
+ 'أ' => 47
49
+ 'إ' => 48
50
+ 's' => 49
51
+ 'j' => 50
52
+ 'ز' => 51
53
+ 'ء' => 52
54
+ 'h' => 53
55
+ 'f' => 54
56
+ 'آ' => 55
57
+ 'ئ' => 56
58
+ 'ؤ' => 57
59
+ 'ظ' => 58
60
+ 'y' => 59
61
+ 'é' => 60
62
+ "'" => 61
63
+ 'z' => 62
64
+ 'x' => 63
65
+ 'w' => 64
66
+ 'g' => 65
67
+ 'è' => 66
68
+ 'û' => 67
69
+ 'ç' => 68
70
+ 'ê' => 69
71
+ 'ô' => 70
72
+ 'ù' => 71
73
+ 'î' => 72
74
+ 'â' => 73
75
+ '<blank>' => 0
76
+ 1 => 75
77
+ ================
78
+ 'starting_index' => 0
79
+ 'unk_label' => 1
80
+ 'blank_label' => '<blank>'
EnglishCV/results/final_cs/train_mixer.py ADDED
@@ -0,0 +1,756 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import os
4
+ import sys
5
+ import torch
6
+ import logging
7
+ import speechbrain as sb
8
+ from speechbrain.utils.distributed import run_on_main
9
+ from hyperpyyaml import load_hyperpyyaml
10
+ from pathlib import Path
11
+ import torchaudio.transforms as T
12
+ from cv_train import ASRCV
13
+ import torchaudio
14
+ import numpy as np
15
+ import kenlm
16
+ from pyctcdecode import build_ctcdecoder
17
+ import re
18
+
19
+ # Commented out IPython magic to ensure Python compatibility.
20
+ # %cd /content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm
21
+ #hparams_file, run_opts, overrides = sb.parse_arguments(["/gpfsstore/rech/nou/uzn19yk/switched_code_tunisian/train/tunisian_asr/hparams/train_semi.yaml"])
22
+ hparams_file, run_opts, overrides = sb.parse_arguments(["semi_supervised_test_tunisian.yaml"])
23
+
24
+ # If distributed_launch=True then
25
+ # create ddp_group with the right communication protocol
26
+ sb.utils.distributed.ddp_init_group(run_opts)
27
+
28
+ with open(hparams_file) as fin:
29
+ hparams = load_hyperpyyaml(fin, overrides)
30
+
31
+ # Create experiment directory
32
+ sb.create_experiment_directory(
33
+ experiment_directory=hparams["output_folder"],
34
+ hyperparams_to_save=hparams_file,
35
+ overrides=overrides,
36
+ )
37
+ # Dataset prep (parsing the CommonVoice-style CSV files)
38
+
39
+ def dataio_prepare(hparams):
40
+ """This function prepares the datasets to be used in the brain class.
41
+ It also defines the data processing pipeline through user-defined functions."""
42
+
43
+ # 1. Define datasets
44
+ data_folder = hparams["data_folder"]
45
+
46
+ train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
47
+ csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
48
+ )
49
+
50
+ if hparams["sorting"] == "ascending":
51
+ # we sort training data to speed up training and get better results.
52
+ train_data = train_data.filtered_sorted(
53
+ sort_key="duration",
54
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
55
+ )
56
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
57
+ hparams["dataloader_options"]["shuffle"] = False
58
+
59
+ elif hparams["sorting"] == "descending":
60
+ train_data = train_data.filtered_sorted(
61
+ sort_key="duration",
62
+ reverse=True,
63
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
64
+ )
65
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
66
+ hparams["dataloader_options"]["shuffle"] = False
67
+
68
+ elif hparams["sorting"] == "random":
69
+ pass
70
+
71
+ else:
72
+ raise NotImplementedError(
73
+ "sorting must be random, ascending or descending"
74
+ )
75
+
76
+ valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
77
+ csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
78
+ )
79
+ # We also sort the validation data so it is faster to validate
80
+ valid_data = valid_data.filtered_sorted(sort_key="duration")
81
+ test_datasets = {}
82
+ for csv_file in hparams["test_csv"]:
83
+ name = Path(csv_file).stem
84
+ test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
85
+ csv_path=csv_file, replacements={"data_root": data_folder}
86
+ )
87
+ test_datasets[name] = test_datasets[name].filtered_sorted(
88
+ sort_key="duration"
89
+ )
90
+
91
+ datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]
92
+
93
+
94
+ # 2. Define audio pipeline:
95
+ @sb.utils.data_pipeline.takes("wav")
96
+ @sb.utils.data_pipeline.provides("sig")
97
+ def audio_pipeline(wav):
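+ # Read the clip, downmix multi-channel audio to mono, and resample it to hparams["sample_rate"].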
98
+ info = torchaudio.info(wav)
99
+ sig = sb.dataio.dataio.read_audio(wav)
100
+ if len(sig.shape)>1 :
101
+ sig = torch.mean(sig, dim=1)
102
+ resampled = torchaudio.transforms.Resample(
103
+ info.sample_rate, hparams["sample_rate"],
104
+ )(sig)
105
+ return resampled
106
+
107
+ sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
108
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
109
+
110
+ # 3. Define text pipeline:
111
+ @sb.utils.data_pipeline.takes("wrd")
112
+ @sb.utils.data_pipeline.provides(
113
+ "wrd", "char_list", "tokens_list", "tokens"
114
+ )
115
+ def text_pipeline(wrd):
116
+ yield wrd
117
+ char_list = list(wrd)
118
+ yield char_list
119
+ tokens_list = label_encoder.encode_sequence(char_list)
120
+ yield tokens_list
121
+ tokens = torch.LongTensor(tokens_list)
122
+ yield tokens
123
+
124
+ sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
125
+ lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
126
+ special_labels = {
127
+ "blank_label": hparams["blank_index"],
128
+ "unk_label": hparams["unk_index"]
129
+ }
130
+ label_encoder.load_or_create(
131
+ path=lab_enc_file,
132
+ from_didatasets=[train_data],
133
+ output_key="char_list",
134
+ special_labels=special_labels,
135
+ sequence_input=True,
136
+ )
137
+
138
+ # 4. Set output:
139
+ sb.dataio.dataset.set_output_keys(
140
+ datasets, ["id", "sig", "wrd", "char_list", "tokens"],
141
+ )
142
+ return train_data, valid_data,test_datasets, label_encoder
143
+
144
+ class ASR(sb.core.Brain):
145
+ def compute_forward(self, batch, stage):
146
+ """Forward computations from the waveform batches to the output probabilities."""
147
+
148
+ batch = batch.to(self.device)
149
+ wavs, wav_lens = batch.sig
150
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
151
+
152
+ if stage == sb.Stage.TRAIN:
153
+ if hasattr(self.hparams, "augmentation"):
154
+ wavs = self.hparams.augmentation(wavs, wav_lens)
155
+
156
+ # Forward pass
157
+ feats = self.modules.wav2vec2(wavs, wav_lens)
158
+ x = self.modules.enc(feats)
159
+ logits = self.modules.ctc_lin(x)
160
+ p_ctc = self.hparams.log_softmax(logits)
161
+
162
+ return p_ctc, wav_lens
163
+
164
+ def custom_encode(self,wavs,wav_lens) :
165
+ wavs = wavs.to(self.device)
166
+ if wav_lens is not None: wav_lens = wav_lens.to(self.device)
167
+
168
+ feats = self.modules.wav2vec2(wavs, wav_lens)
169
+ x = self.modules.enc(feats)
170
+ logits = self.modules.ctc_lin(x)
171
+ p_ctc = self.hparams.log_softmax(logits)
172
+
173
+ return feats,p_ctc
174
+
175
+
176
+
177
+ def compute_objectives(self, predictions, batch, stage):
178
+ """Computes the loss (CTC) given predictions and targets."""
179
+
180
+ p_ctc, wav_lens = predictions
181
+
182
+ ids = batch.id
183
+ tokens, tokens_lens = batch.tokens
184
+
185
+ loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
186
+
187
+ if stage != sb.Stage.TRAIN:
188
+ predicted_tokens = sb.decoders.ctc_greedy_decode(
189
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
190
+ )
191
+ # Decode token terms to words
192
+ if self.hparams.use_language_modelling:
193
+ predicted_words = []
194
+ for logs in p_ctc:
195
+ text = decoder.decode(logs.detach().cpu().numpy())
196
+ predicted_words.append(text.split(" "))
197
+ else:
198
+ predicted_words = [
199
+ "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
200
+ for utt_seq in predicted_tokens
201
+ ]
202
+ # Convert indices to words
203
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
204
+
205
+ self.wer_metric.append(ids, predicted_words, target_words)
206
+ self.cer_metric.append(ids, predicted_words, target_words)
207
+
208
+ return loss
209
+
210
+ def fit_batch(self, batch):
211
+ """Train the parameters given a single batch in input"""
212
+ should_step = self.step % self.grad_accumulation_factor == 0
213
+ # Managing automatic mixed precision
214
+ # TOFIX: CTC fine-tuning currently is unstable
215
+ # This is certainly due to CTC being done in fp16 instead of fp32
216
+ if self.auto_mix_prec:
217
+ with torch.cuda.amp.autocast():
218
+ with self.no_sync():
219
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
220
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
221
+ with self.no_sync(not should_step):
222
+ self.scaler.scale(
223
+ loss / self.grad_accumulation_factor
224
+ ).backward()
225
+ if should_step:
226
+
227
+ if not self.hparams.wav2vec2.freeze:
228
+ self.scaler.unscale_(self.wav2vec_optimizer)
229
+ self.scaler.unscale_(self.model_optimizer)
230
+ if self.check_gradients(loss):
231
+ if not self.hparams.wav2vec2.freeze:
232
+ if self.optimizer_step >= self.hparams.warmup_steps:
233
+ self.scaler.step(self.wav2vec_optimizer)
234
+ self.scaler.step(self.model_optimizer)
235
+ self.scaler.update()
236
+ self.zero_grad()
237
+ self.optimizer_step += 1
238
+ else:
239
+ # This is mandatory because HF models have a weird behavior with DDP
240
+ # on the forward pass
241
+ with self.no_sync():
242
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
243
+
244
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
245
+
246
+ with self.no_sync(not should_step):
247
+ (loss / self.grad_accumulation_factor).backward()
248
+ if should_step:
249
+ if self.check_gradients(loss):
250
+ if not self.hparams.wav2vec2.freeze:
251
+ if self.optimizer_step >= self.hparams.warmup_steps:
252
+ self.wav2vec_optimizer.step()
253
+ self.model_optimizer.step()
254
+ self.zero_grad()
255
+ self.optimizer_step += 1
256
+
257
+ self.on_fit_batch_end(batch, outputs, loss, should_step)
258
+ return loss.detach().cpu()
259
+
260
+ def evaluate_batch(self, batch, stage):
261
+ """Computations needed for validation/test batches"""
262
+ predictions = self.compute_forward(batch, stage=stage)
263
+ with torch.no_grad():
264
+ loss = self.compute_objectives(predictions, batch, stage=stage)
265
+ return loss.detach()
266
+
267
+ def on_stage_start(self, stage, epoch):
268
+ """Gets called at the beginning of each epoch"""
269
+ if stage != sb.Stage.TRAIN:
270
+ self.cer_metric = self.hparams.cer_computer()
271
+ self.wer_metric = self.hparams.error_rate_computer()
272
+
273
+ def on_stage_end(self, stage, stage_loss, epoch):
274
+ """Gets called at the end of an epoch."""
275
+ # Compute/store important stats
276
+ stage_stats = {"loss": stage_loss}
277
+ if stage == sb.Stage.TRAIN:
278
+ self.train_stats = stage_stats
279
+ else:
280
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
281
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
282
+
283
+ # Perform end-of-iteration things, like annealing, logging, etc.
284
+ if stage == sb.Stage.VALID:
285
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
286
+ stage_stats["loss"]
287
+ )
288
+ old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
289
+ stage_stats["loss"]
290
+ )
291
+ sb.nnet.schedulers.update_learning_rate(
292
+ self.model_optimizer, new_lr_model
293
+ )
294
+ if not self.hparams.wav2vec2.freeze:
295
+ sb.nnet.schedulers.update_learning_rate(
296
+ self.wav2vec_optimizer, new_lr_wav2vec
297
+ )
298
+ self.hparams.train_logger.log_stats(
299
+ stats_meta={
300
+ "epoch": epoch,
301
+ "lr_model": old_lr_model,
302
+ "lr_wav2vec": old_lr_wav2vec,
303
+ },
304
+ train_stats=self.train_stats,
305
+ valid_stats=stage_stats,
306
+ )
307
+ self.checkpointer.save_and_keep_only(
308
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
309
+ )
310
+ elif stage == sb.Stage.TEST:
311
+ self.hparams.train_logger.log_stats(
312
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
313
+ test_stats=stage_stats,
314
+ )
315
+ with open(self.hparams.wer_file, "w") as w:
316
+ self.wer_metric.write_stats(w)
317
+
318
+ def init_optimizers(self):
319
+ "Initializes the wav2vec2 optimizer and model optimizer"
320
+
321
+ # If the wav2vec encoder is unfrozen, we create the optimizer
322
+ if not self.hparams.wav2vec2.freeze:
323
+ self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
324
+ self.modules.wav2vec2.parameters()
325
+ )
326
+ if self.checkpointer is not None:
327
+ self.checkpointer.add_recoverable(
328
+ "wav2vec_opt", self.wav2vec_optimizer
329
+ )
330
+
331
+ self.model_optimizer = self.hparams.model_opt_class(
332
+ self.hparams.model.parameters()
333
+ )
334
+
335
+ if self.checkpointer is not None:
336
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
337
+
338
+ def zero_grad(self, set_to_none=False):
339
+ if not self.hparams.wav2vec2.freeze:
340
+ self.wav2vec_optimizer.zero_grad(set_to_none)
341
+ self.model_optimizer.zero_grad(set_to_none)
342
+
343
+
344
+ """
345
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
346
+
347
+ train_data, valid_data, test_datasets, label_encoder = dataio_prepare(
348
+ hparams
349
+ )
350
+
351
+
352
+ # We dynamically add the tokenizer to our brain class.
353
+ # NB: This tokenizer corresponds to the one used for the LM!!
354
+ """
355
+ from speechbrain.pretrained import EncoderASR,EncoderDecoderASR
356
+ french_asr_model = EncoderASR.from_hparams(source="speechbrain/asr-wav2vec2-commonvoice-fr", savedir="pretrained_models/asr-wav2vec2-commonvoice-fr").cuda()
357
+ #french_asr_model = "r"
358
+
359
+ cvhparams_file, cvrun_opts, cvoverrides = sb.parse_arguments(["en_cv.yaml"])
360
+ with open(cvhparams_file) as cvfin:
361
+ cvhparams = load_hyperpyyaml(cvfin, cvoverrides)
362
+ english_asr_model = ASRCV(
363
+ modules=cvhparams["modules"],
364
+ hparams=cvhparams,
365
+ run_opts=cvrun_opts,
366
+ checkpointer=cvhparams["checkpointer"],
367
+ )
368
+ english_asr_model.checkpointer.recover_if_possible()
369
+ asr_brain = ASR(
370
+ modules=hparams["modules"],
371
+ hparams=hparams,
372
+ run_opts=run_opts,
373
+ checkpointer=hparams["checkpointer"],
374
+ )
375
+ asr_brain.checkpointer.recover_if_possible()
376
+ asr_brain.modules.eval()
377
+ english_asr_model.modules.eval()
378
+ french_asr_model.mods.eval()
379
+ """
380
+ asr_brain.tokenizer = label_encoder
381
+
382
+ # Testing
383
+ real = True
384
+ if real :
385
+ for k in test_datasets.keys(): # keys are test_clean, test_other etc
386
+ asr_brain.hparams.wer_file = os.path.join(
387
+ hparams["output_folder"], "wer_{}.txt".format(k)
388
+ )
389
+ asr_brain.evaluate(
390
+ test_datasets[k], test_loader_kwargs=hparams["dataloader_options"]
391
+ )
392
+ """
393
+
394
+ """
395
+ from torch.nn.utils.rnn import pad_sequence
396
+ def load_paths(wavs_path):
397
+ waveforms = []
398
+ for path in wavs_path :
399
+ waveform, _ = torchaudio.load(path)
400
+ waveforms.append(waveform.squeeze(0))
401
+ # pad shorter waveforms with zeros so they all match the longest one in the batch
402
+ padded_arrays = pad_sequence(waveforms, batch_first=True)
403
+ return torch.tensor(padded_arrays)
404
+
405
+ waveform = load_paths(["/content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm/samples/Salah10.wav","/content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm/samples/Salah10.wav"])
406
+ embeddings, posteriogram = asr_brain.custom_encode(waveform,None)
407
+ print(embeddings.shape)
408
+ print(posteriogram.shape)
409
+ """
410
+
411
+ from speechbrain.pretrained import EncoderASR,EncoderDecoderASR
412
+ import torchaudio
413
+ import speechbrain as sb
414
+ import torch
415
+ from torch.nn.utils.rnn import pad_sequence
416
+ import torch
417
+ import speechbrain as sb
418
+ import numpy as np
419
+ import torch.optim as optim
420
+ import torch.nn as nn
421
+
422
+ # Commented out IPython magic to ensure Python compatibility.
423
+ # %ls
424
+
425
+ # UTILITY FUNCTIONS
426
+ def get_size_dimensions(arr):
427
+ size_dimensions = []
428
+ while isinstance(arr, list):
429
+ size_dimensions.append(len(arr))
430
+ arr = arr[0]
431
+ return size_dimensions
432
+
433
+ def scale_array(batch,n):
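+ # Stretch each sequence in the batch to exactly n elements by repeating its items in place
+ # (raises ValueError if a sequence is already longer than n).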
434
+ scaled_batch = []
435
+
436
+ for array in batch:
437
+ if(n < len(array)): raise ValueError("Cannot scale Array down")
438
+
439
+ repeat = round(n/len(array))+1
440
+ scaled_length_array= []
441
+
442
+ for i in array:
443
+ for j in range(repeat) :
444
+ if(len(scaled_length_array) == n): break
445
+ scaled_length_array.append(i)
446
+
447
+ scaled_batch.append(scaled_length_array)
448
+
449
+ return torch.tensor(scaled_batch)
450
+
451
+
452
+ def load_paths(wavs_path):
453
+ waveforms = []
454
+ for path in wavs_path :
455
+ waveform, _ = torchaudio.load(path)
456
+ waveforms.append(waveform.squeeze(0))
457
+ # pad shorter waveforms with zeros so they all match the longest one in the batch
458
+ padded_arrays = pad_sequence(waveforms, batch_first=True)
459
+ return torch.tensor(padded_arrays)
460
+
461
+
462
+
463
+ def word_to_vec(input_string):
464
+ mapping= {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26, 'ا': 27, 'ب': 28, 'ت': 29, 'ث': 30, 'ج': 31, 'ح': 32, 'خ': 33, 'د': 34, 'ذ': 35, 'ر': 36, 'ز': 37, 'س': 38, 'ش': 39, 'ص': 40, 'ض': 41, 'ط': 42, 'ظ': 43, 'ع': 44, 'غ': 45, 'ف': 46, 'ق': 47, 'ك': 48, 'ل': 49, 'م': 50, 'ن': 51, 'ه': 52, 'و': 53, 'ي': 54,' ':55}
465
+
466
+ numbers = [mapping[word] for word in input_string if word in mapping]
467
+ return numbers
468
+
469
+ device = 'cuda'
470
+ verbose = 0
471
+ #FLOW LEVEL FUNCTIONS
472
+ def merge_strategy(embeddings1, embeddings2, embeddings3,post1, post2,post3):
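+ # Concatenate the posteriorgrams of the three ASR models along the feature axis, do the same
+ # with their encoder embeddings, then concatenate both results so every frame carries the
+ # French, English and Tunisian features side by side.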
473
+
474
+
475
+ post1 = post1.to(device)
476
+ post2 = post2.to(device)
477
+ post3 = post3.to(device)
478
+ embeddings1 = embeddings1.to(device)
479
+ embeddings2 = embeddings2.to(device)
480
+ embeddings3 = embeddings3.to(device)
481
+
482
+ posteriograms_merged = torch.cat((post1,post2,post3),dim=2)
483
+ embeddings_merged = torch.cat((embeddings1,embeddings2,embeddings3),dim=2)
484
+
485
+ if(verbose !=0):
486
+ print('MERGED POST ',posteriograms_merged.shape)
487
+ print('MERGED emb ',embeddings_merged.shape)
488
+
489
+ return torch.cat((posteriograms_merged,embeddings_merged),dim=2).to(device)
490
+
491
+ def decode(model,wavs,wav_lens):
492
+
493
+ with torch.no_grad():
494
+ wav_lens = wav_lens.to(model.device)
495
+ encoder_out = model.encode_batch(wavs, wav_lens)
496
+ predictions = model.decoding_function(encoder_out, wav_lens)
497
+ return predictions
498
+
499
+ def middle_layer(batch, lens):
500
+
501
+ tn_embeddings, tn_posteriogram = asr_brain.custom_encode(batch,None)
502
+
503
+ fr_embeddings = french_asr_model.mods.encoder.wav2vec2(batch)
504
+ fr_posteriogram =french_asr_model.encode_batch(batch,lens)
505
+ en_embeddings = english_asr_model.modules.wav2vec2(batch, lens)
506
+ x = english_asr_model.modules.enc(en_embeddings)
507
+ en_posteriogram = english_asr_model.modules.ctc_lin(x)
508
+ #scores, en_posteriogram = english_asr_model.mods.decoder(en_embeddings ,lens)
509
+ if(verbose !=0):
510
+ print('[EMBEDDINGS] FR:',fr_embeddings.shape, "EN:",en_embeddings.shape, "TN:", tn_embeddings.shape)
511
+ print('[POSTERIOGRAM] FR:',fr_posteriogram.shape, "EN:",en_posteriogram.shape,"TN:",tn_posteriogram.shape)
512
+
513
+
514
+ bilangual_sample = merge_strategy(fr_embeddings,en_embeddings,tn_embeddings,fr_posteriogram,en_posteriogram,tn_posteriogram)
515
+ return bilangual_sample
516
+
517
+ class Mixer(sb.core.Brain):
518
+
519
+ def compute_forward(self, batch, stage):
520
+ """Forward computations from the waveform batches to the output probabilities."""
521
+ wavs, wav_lens = batch.sig
522
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
523
+
524
+ if stage == sb.Stage.TRAIN:
525
+ if hasattr(self.hparams, "augmentation"):
526
+ wavs = self.hparams.augmentation(wavs, wav_lens)
527
+
528
+ multi_langual_feats = middle_layer(wavs, wav_lens)
529
+ multi_langual_feats= multi_langual_feats.to(device)
530
+ feats, _ = self.modules.enc(multi_langual_feats)
531
+ logits = self.modules.ctc_lin(feats)
532
+ p_ctc = self.hparams.log_softmax(logits)
533
+
534
+ if stage!= sb.Stage.TRAIN:
535
+ p_tokens = sb.decoders.ctc_greedy_decode(
536
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
537
+ )
538
+ else :
539
+ p_tokens = None
540
+ return p_ctc, wav_lens, p_tokens
541
+
542
+ def compute_objectives(self, predictions, batch, stage):
543
+ """Computes the loss (CTC) given predictions and targets."""
544
+
545
+ p_ctc, wav_lens , predicted_tokens= predictions
546
+
547
+ ids = batch.id
548
+ tokens, tokens_lens = batch.tokens
549
+
550
+ loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
551
+
552
+
553
+ if stage == sb.Stage.VALID:
554
+ predicted_words = [
555
+ "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
556
+ for utt_seq in predicted_tokens
557
+ ]
558
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
559
+ self.wer_metric.append(ids, predicted_words, target_words)
560
+ self.cer_metric.append(ids, predicted_words, target_words)
561
+ if stage ==sb.Stage.TEST :
562
+ if self.hparams.language_modelling:
563
+ predicted_words = []
564
+ for logs in p_ctc:
565
+ text = decoder.decode(logs.detach().cpu().numpy())
566
+ predicted_words.append(text.split(" "))
567
+ else :
568
+ predicted_words = [
569
+ "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
570
+ for utt_seq in predicted_tokens
571
+ ]
572
+
573
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
574
+ self.wer_metric.append(ids, predicted_words, target_words)
575
+ self.cer_metric.append(ids, predicted_words, target_words)
576
+
577
+ return loss
578
+
579
+ def fit_batch(self, batch):
580
+ """Train the parameters given a single batch in input"""
581
+ should_step = self.step % self.grad_accumulation_factor == 0
582
+ # Managing automatic mixed precision
583
+ # TOFIX: CTC fine-tuning currently is unstable
584
+ # This is certainly due to CTC being done in fp16 instead of fp32
585
+ if self.auto_mix_prec:
586
+ with torch.cuda.amp.autocast():
587
+ with self.no_sync():
588
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
589
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
590
+ with self.no_sync(not should_step):
591
+ self.scaler.scale(
592
+ loss / self.grad_accumulation_factor
593
+ ).backward()
594
+ if should_step:
595
+
596
+
597
+ self.scaler.unscale_(self.model_optimizer)
598
+ if self.check_gradients(loss):
599
+ self.scaler.step(self.model_optimizer)
600
+ self.scaler.update()
601
+ self.zero_grad()
602
+ self.optimizer_step += 1
603
+ else:
604
+ # This is mandatory because HF models have a weird behavior with DDP
605
+ # on the forward pass
606
+ with self.no_sync():
607
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
608
+
609
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
610
+
611
+ with self.no_sync(not should_step):
612
+ (loss / self.grad_accumulation_factor).backward()
613
+ if should_step:
614
+ if self.check_gradients(loss):
615
+ self.model_optimizer.step()
616
+ self.zero_grad()
617
+ self.optimizer_step += 1
618
+
619
+ self.on_fit_batch_end(batch, outputs, loss, should_step)
620
+ return loss.detach().cpu()
621
+
622
+ def evaluate_batch(self, batch, stage):
623
+ """Computations needed for validation/test batches"""
624
+ predictions = self.compute_forward(batch, stage=stage)
625
+ with torch.no_grad():
626
+ loss = self.compute_objectives(predictions, batch, stage=stage)
627
+ return loss.detach()
628
+
629
+ def on_stage_start(self, stage, epoch):
630
+ """Gets called at the beginning of each epoch"""
631
+ if stage != sb.Stage.TRAIN:
632
+ self.cer_metric = self.hparams.cer_computer()
633
+ self.wer_metric = self.hparams.error_rate_computer()
634
+
635
+ def on_stage_end(self, stage, stage_loss, epoch):
636
+ """Gets called at the end of an epoch."""
637
+ # Compute/store important stats
638
+ stage_stats = {"loss": stage_loss}
639
+ if stage == sb.Stage.TRAIN:
640
+ self.train_stats = stage_stats
641
+ else:
642
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
643
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
644
+
645
+ # Perform end-of-iteration things, like annealing, logging, etc.
646
+ if stage == sb.Stage.VALID:
647
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
648
+ stage_stats["loss"]
649
+ )
650
+ sb.nnet.schedulers.update_learning_rate(
651
+ self.model_optimizer, new_lr_model
652
+ )
653
+ self.hparams.train_logger.log_stats(
654
+ stats_meta={
655
+ "epoch": epoch,
656
+ "lr_model": old_lr_model,
657
+ },
658
+ train_stats=self.train_stats,
659
+ valid_stats=stage_stats,
660
+ )
661
+ self.checkpointer.save_and_keep_only(
662
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
663
+ )
664
+ elif stage == sb.Stage.TEST:
665
+ self.hparams.train_logger.log_stats(
666
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
667
+ test_stats=stage_stats,
668
+ )
669
+ with open(self.hparams.wer_file, "w") as w:
670
+ self.wer_metric.write_stats(w)
671
+
672
+ def init_optimizers(self):
673
+
674
+ self.model_optimizer = self.hparams.model_opt_class(
675
+ self.hparams.model.parameters()
676
+ )
677
+
678
+ if self.checkpointer is not None:
679
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
680
+
681
+ def zero_grad(self, set_to_none=False):
682
+
683
+ self.model_optimizer.zero_grad(set_to_none)
684
+
685
+
686
+ hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
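+ # Typical invocation: python train_mixer.py <hparams>.yaml, with extra command-line flags collected as YAML overrides.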
687
+
688
+ # If distributed_launch=True then
689
+ # create ddp_group with the right communication protocol
690
+ sb.utils.distributed.ddp_init_group(run_opts)
691
+
692
+ with open(hparams_file) as fin:
693
+ hparams = load_hyperpyyaml(fin, overrides)
694
+
695
+ # Create experiment directory
696
+ sb.create_experiment_directory(
697
+ experiment_directory=hparams["output_folder"],
698
+ hyperparams_to_save=hparams_file,
699
+ overrides=overrides,
700
+ )
701
+ def read_labels_file(labels_file):
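+ # Parse the "'char' => index" lines of label_encoder.txt (stopping at the "===" divider)
+ # and return the labels ordered by index, so position i matches CTC output unit i.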
702
+ with open(labels_file, "r",encoding="utf-8") as lf:
703
+ lines = lf.read().splitlines()
704
+ division = "==="
705
+ numbers = {}
706
+ for line in lines :
707
+ if division in line :
708
+ break
709
+ string, number = line.split("=>")
710
+ number = int(number)
711
+ string = string[1:-2]
712
+ numbers[number] = string
713
+ return [numbers[x] for x in range(len(numbers))]
714
+ train_data, valid_data, test_datasets, label_encoder = dataio_prepare(
715
+ hparams
716
+ )
717
+
718
+
719
+ labels = read_labels_file(os.path.join(hparams["save_folder"], "label_encoder.txt"))
720
+ labels = [""] + labels[1:-1] + ["1"]
721
+ if hparams["language_modelling"]:
722
+ decoder = build_ctcdecoder(
723
+ labels,
724
+ kenlm_model_path=hparams["ngram_lm_path"], # either .arpa or .bin file
725
+ alpha=0.5, # tuned on a val set
726
+ beta=1, # tuned on a val set
727
+ )
728
+
729
+
730
+
731
+
732
+ mixer = Mixer(
733
+ modules=hparams["modules"],
734
+ hparams=hparams,
735
+ run_opts=run_opts,
736
+ checkpointer=hparams["checkpointer"],
737
+ )
738
+ mixer.tokenizer = label_encoder
739
+
740
+
741
+ mixer.fit(
742
+ mixer.hparams.epoch_counter,
743
+ train_data,
744
+ valid_data,
745
+ train_loader_kwargs=hparams["dataloader_options"],
746
+ valid_loader_kwargs=hparams["test_dataloader_options"],
747
+ )
748
+ print(test_datasets.keys())
749
+ for k in test_datasets.keys(): # keys are test_clean, test_other etc
750
+ mixer.hparams.wer_file = os.path.join(
751
+ hparams["output_folder"], "wer_{}.txt".format(k)
752
+ )
753
+ mixer.evaluate(
754
+ test_datasets[k], test_loader_kwargs=hparams["test_dataloader_options"]
755
+ )
756
+
EnglishCV/results/wav2vec2_ctc_en/1234/hyperparams.yaml ADDED
@@ -0,0 +1,190 @@
1
+ # Generated 2023-09-06 from:
2
+ # /gpfsdswork/projects/rech/nou/uzn19yk/final_forke/speechbrain-3/recipes/CommonVoice/ASR/CTC/hparams/train_en_with_wav2vec.yaml
3
+ # yamllint disable
4
+ # ################################
5
+ # Model: wav2vec2 + DNN + CTC
6
+ # Augmentation: SpecAugment
7
+ # Authors: Titouan Parcollet 2021
8
+ # ################################
9
+
10
+ # Seed needs to be set at top of yaml, before objects with parameters are made
11
+ seed: 1234
12
+ __set_seed: !!python/object/apply:torch.manual_seed [1234]
13
+ output_folder: results/wav2vec2_ctc_en/1234
14
+ wer_file: results/wav2vec2_ctc_en/1234/wer.txt
15
+ save_folder: results/wav2vec2_ctc_en/1234/save
16
+ train_log: results/wav2vec2_ctc_en/1234/train_log.txt
17
+
18
+ # URL for the biggest Fairseq english wav2vec2 model.
19
+ wav2vec2_hub: facebook/wav2vec2-large-lv60
20
+ wav2vec2_folder: results/wav2vec2_ctc_en/1234/save/wav2vec2_checkpoint
21
+
22
+ # Data files
23
+ data_folder:
24
+ /gpfsscratch/rech/nou/uzn19yk/download/cv-corpus-12.0-2022-12-07/en/cv-corpus-12.0-2022-12-07/en # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
25
+ train_tsv_file:
26
+ /gpfsscratch/rech/nou/uzn19yk/download/cv-corpus-12.0-2022-12-07/en/cv-corpus-12.0-2022-12-07/en/train.tsv # Standard CommonVoice .tsv files
27
+ dev_tsv_file:
28
+ /gpfsscratch/rech/nou/uzn19yk/download/cv-corpus-12.0-2022-12-07/en/cv-corpus-12.0-2022-12-07/en/dev.tsv # Standard CommonVoice .tsv files
29
+ test_tsv_file:
30
+ /gpfsscratch/rech/nou/uzn19yk/download/cv-corpus-12.0-2022-12-07/en/cv-corpus-12.0-2022-12-07/en/test.tsv # Standard CommonVoice .tsv files
31
+ accented_letters: false
32
+ language: en # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
33
+ train_csv: results/wav2vec2_ctc_en/1234/save/train.csv
34
+ valid_csv: results/wav2vec2_ctc_en/1234/save/dev.csv
35
+ test_csv: results/wav2vec2_ctc_en/1234/save/test.csv
36
+ skip_prep: false # Skip data preparation
37
+
38
+ # We remove utterances longer than 10s in the train/dev/test sets as
39
+ # longer sentences certainly correspond to "open microphones".
40
+ avoid_if_longer_than: 10.0
41
+
42
+ # Training parameters
43
+ number_of_epochs: 10
44
+ lr: 1.0
45
+ lr_wav2vec: 0.0001
46
+ sorting: ascending
47
+ auto_mix_prec: false
48
+ sample_rate: 16000
49
+ ckpt_interval_minutes: 30 # save checkpoint every N min
50
+
51
+ # With data_parallel batch_size is split into N jobs
52
+ # With DDP batch_size is multiplied by N jobs
53
+ # Must be 8 per GPU to fit 32GB of VRAM
54
+ batch_size: 8
55
+ test_batch_size: 4
56
+
57
+ dataloader_options:
58
+ batch_size: 8
59
+ num_workers: 6
60
+ test_dataloader_options:
61
+ batch_size: 4
62
+ num_workers: 6
63
+
64
+ # BPE parameters
65
+ token_type: char # ["unigram", "bpe", "char"]
66
+ character_coverage: 1.0
67
+
68
+ # Model parameters
69
+ # activation: !name:torch.nn.LeakyReLU
70
+ wav2vec_output_dim: 1024
71
+ dnn_neurons: 1024
72
+ freeze_wav2vec: false
73
+ freeze_feature_extractor: true
74
+ dropout: 0.15
75
+ warmup_steps: 500
76
+
77
+ # Outputs
78
+ output_neurons: 29 # BPE size, index(blank/eos/bos) = 0
79
+
80
+ # Decoding parameters
81
+ # Be sure that the bos and eos index match with the BPEs ones
82
+ blank_index: 0
83
+ bos_index: 1
84
+ eos_index: 2
85
+
86
+ #
87
+ # Functions and classes
88
+ #
89
+ epoch_counter: &id007 !new:speechbrain.utils.epoch_loop.EpochCounter
90
+
91
+ limit: 10
92
+
93
+ augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
94
+ sample_rate: 16000
95
+ speeds: [95, 100, 105]
96
+
97
+ enc: &id002 !new:speechbrain.nnet.containers.Sequential
98
+ input_shape: [null, null, 1024]
99
+ linear1: !name:speechbrain.nnet.linear.Linear
100
+ n_neurons: 1024
101
+ bias: true
102
+ bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
103
+ activation: !new:torch.nn.LeakyReLU
104
+ drop: !new:torch.nn.Dropout
105
+ p: 0.15
106
+ linear2: !name:speechbrain.nnet.linear.Linear
107
+ n_neurons: 1024
108
+ bias: true
109
+ bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
110
+ activation2: !new:torch.nn.LeakyReLU
111
+ drop2: !new:torch.nn.Dropout
112
+ p: 0.15
113
+ linear3: !name:speechbrain.nnet.linear.Linear
114
+ n_neurons: 1024
115
+ bias: true
116
+ bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
117
+ activation3: !new:torch.nn.LeakyReLU
118
+
119
+ wav2vec2: &id001 !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
120
+ source: /gpfsscratch/rech/nou/uzn19yk/wav2vec2-large-lv60/
121
+ output_norm: true
122
+ freeze: false
123
+ freeze_feature_extractor: true
124
+ save_path: results/wav2vec2_ctc_en/1234/save/wav2vec2_checkpoint
125
+
126
+ #####
127
+ # Uncomment this block if you prefer to use a Fairseq pretrained model instead
128
+ # of a HuggingFace one. Here, we provide an URL that is obtained from the
129
+ # Fairseq github for the multilingual XLSR.
130
+ #
131
+ #wav2vec2_url: https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt
132
+ #wav2vec2: !new:speechbrain.lobes.models.fairseq_wav2vec.FairseqWav2Vec2
133
+ # pretrained_path: !ref <wav2vec2_url>
134
+ # output_norm: True
135
+ # freeze: False
136
+ # save_path: !ref <save_folder>/wav2vec2_checkpoint/model.pt
137
+ #####
138
+
139
+ ctc_lin: &id003 !new:speechbrain.nnet.linear.Linear
140
+
141
+ input_size: 1024
142
+ n_neurons: 29
143
+
144
+ log_softmax: !new:speechbrain.nnet.activations.Softmax
145
+ apply_log: true
146
+
147
+ ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
148
+ blank_index: 0
149
+
150
+ modules:
151
+ wav2vec2: *id001
152
+ enc: *id002
153
+ ctc_lin: *id003
154
+ model: &id004 !new:torch.nn.ModuleList
155
+ - [*id002, *id003]
156
+ model_opt_class: !name:torch.optim.Adadelta
157
+ lr: 1.0
158
+ rho: 0.95
159
+ eps: 1.e-8
160
+
161
+ wav2vec_opt_class: !name:torch.optim.Adam
162
+ lr: 0.0001
163
+
164
+ lr_annealing_model: &id005 !new:speechbrain.nnet.schedulers.NewBobScheduler
165
+ initial_value: 1.0
166
+ improvement_threshold: 0.0025
167
+ annealing_factor: 0.8
168
+ patient: 0
169
+
170
+ lr_annealing_wav2vec: &id006 !new:speechbrain.nnet.schedulers.NewBobScheduler
171
+ initial_value: 0.0001
172
+ improvement_threshold: 0.0025
173
+ annealing_factor: 0.9
174
+ patient: 0
175
+
176
+ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
177
+ checkpoints_dir: results/wav2vec2_ctc_en/1234/save
178
+ recoverables:
179
+ wav2vec2: *id001
180
+ model: *id004
181
+ scheduler_model: *id005
182
+ scheduler_wav2vec: *id006
183
+ counter: *id007
184
+ train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
185
+ save_file: results/wav2vec2_ctc_en/1234/train_log.txt
186
+
187
+ error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
188
+
189
+ cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
190
+ split_tokens: true
EnglishCV/results/wav2vec2_ctc_en/1234/save/29_char.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee4214a3ebba9461ca02ca61220a2338412bbf9ef5a5982f2bc40740c4ab91a8
3
+ size 238011
EnglishCV/results/wav2vec2_ctc_en/1234/save/29_char.vocab ADDED
@@ -0,0 +1,28 @@
1
+ <unk> 0
2
+ ▁ -1.786
3
+ E -2.27261
4
+ A -2.6326
5
+ T -2.64317
6
+ I -2.76341
7
+ S -2.81519
8
+ O -2.8189
9
+ N -2.83568
10
+ R -2.87568
11
+ H -3.22802
12
+ L -3.30075
13
+ D -3.43047
14
+ C -3.58554
15
+ U -3.84445
16
+ M -3.84732
17
+ F -4.07023
18
+ P -4.09107
19
+ G -4.16259
20
+ W -4.25412
21
+ Y -4.30147
22
+ B -4.36224
23
+ V -4.71267
24
+ K -5.1744
25
+ X -6.46672
26
+ J -6.5246
27
+ Z -6.95828
28
+ Q -7.12388
EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/CKPT.yaml ADDED
@@ -0,0 +1,4 @@
1
+ # yamllint disable
2
+ WER: 18.234978071545488
3
+ end-of-epoch: true
4
+ unixtime: 1694033791.9455216
EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/brain.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:06617abf655f8550362b963062fc2a57bd819826ab70e63701676ea09d23618d
3
+ size 51
EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/counter.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d4735e3a265e16eee03f59718b9b5d03019c07d8b6c51f90da3a666eec13ab35
3
+ size 1
EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/dataloader-TRAIN.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f21c20a479fcc07663ec4255ad1c85466afb791f514f8f3baa174bd56edca2d4
3
+ size 6
EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/model.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:422a1d7a30720e846d2cb79ff510832fe96c1495f559f08fb37bdd118269ea7b
3
+ size 12769326
EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/modelopt.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65cda77e4403deb7c8cee3052ac687bfc3bf6e68264dcb0e297e8f88bccf0d66
3
+ size 25485359
EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/scheduler_model.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9c36e38dd81971c68387a9f921cf0d61adad21f5b3f6420b6f3015b0f9d20df
3
+ size 511
EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/scheduler_wav2vec.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0293788921aad16c6e904d7ec0b7dba2dd4778fa3b7f1bfa04276b3965599999
3
+ size 515
EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/wav2vec2.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f7073aa70c88927f11cff4f2ba63a026c8ff6c119837391d84013feb229ad3e
3
+ size 1261924189
EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/wav2vec_opt.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:42691a96ebaba3dd3baf7e2521763db7f79b37a6bde9b0ea9d1adc2cac5bdf5e
3
+ size 2490156402
EnglishCV/results/wav2vec2_ctc_en/1234/train_with_wav2vec.py ADDED
@@ -0,0 +1,388 @@
1
+ #!/usr/bin/env python3
2
+ import sys
3
+ import torch
4
+ import logging
5
+ import speechbrain as sb
6
+ import torchaudio
7
+ from hyperpyyaml import load_hyperpyyaml
8
+ from speechbrain.tokenizers.SentencePiece import SentencePiece
9
+ from speechbrain.utils.data_utils import undo_padding
10
+ from speechbrain.utils.distributed import run_on_main
11
+
12
+ """Recipe for training a sequence-to-sequence ASR system with CommonVoice.
13
+ The system employs a wav2vec2 encoder and a CTC decoder.
14
+ Decoding is performed with greedy decoding (will be extended to beam search).
15
+
16
+ To run this recipe, do the following:
17
+ > python train_with_wav2vec2.py hparams/train_with_wav2vec2.yaml
18
+
19
+ With the default hyperparameters, the system employs a pretrained wav2vec2 encoder.
20
+ The wav2vec2 model is pretrained following the model given in the hparams file.
21
+ It may be dependent on the language.
22
+
23
+ The neural network is trained with CTC on sub-word units estimated with
24
+ Byte Pair Encoding (BPE).
25
+
26
+ The experiment file is flexible enough to support a large variety of
27
+ different systems. By properly changing the parameter files, you can try
28
+ different encoders, decoders, tokens (e.g, characters instead of BPE),
29
+ training languages (all CommonVoice languages), and many
30
+ other possible variations.
31
+
32
+ Authors
33
+ * Titouan Parcollet 2021
34
+ """
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
+ # Define training procedure
40
+ class ASR(sb.core.Brain):
41
+ def compute_forward(self, batch, stage):
42
+ """Forward computations from the waveform batches to the output probabilities."""
43
+
44
+ batch = batch.to(self.device)
45
+ wavs, wav_lens = batch.sig
46
+ tokens_bos, _ = batch.tokens_bos
47
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
48
+
49
+ if stage == sb.Stage.TRAIN:
50
+ if hasattr(self.hparams, "augmentation"):
51
+ wavs = self.hparams.augmentation(wavs, wav_lens)
52
+
53
+ # Forward pass
54
+ feats = self.modules.wav2vec2(wavs, wav_lens)
55
+ x = self.modules.enc(feats)
56
+ logits = self.modules.ctc_lin(x)
57
+ p_ctc = self.hparams.log_softmax(logits)
58
+
59
+ return p_ctc, wav_lens
60
+
61
+ def compute_objectives(self, predictions, batch, stage):
62
+ """Computes the loss (CTC) given predictions and targets."""
63
+
64
+ p_ctc, wav_lens = predictions
65
+
66
+ ids = batch.id
67
+ tokens_eos, tokens_eos_lens = batch.tokens_eos
68
+ tokens, tokens_lens = batch.tokens
69
+
70
+ loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
71
+
72
+ if stage != sb.Stage.TRAIN:
73
+ # Decode token terms to words
74
+ sequence = sb.decoders.ctc_greedy_decode(
75
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
76
+ )
77
+
78
+ predicted_words = self.tokenizer(sequence, task="decode_from_list")
79
+
80
+ # Convert indices to words
81
+ target_words = undo_padding(tokens, tokens_lens)
82
+ target_words = self.tokenizer(target_words, task="decode_from_list")
83
+
84
+ self.wer_metric.append(ids, predicted_words, target_words)
85
+ self.cer_metric.append(ids, predicted_words, target_words)
86
+
87
+ return loss
88
+
89
+ def fit_batch(self, batch):
90
+ """Train the parameters given a single batch in input"""
91
+ should_step = self.step % self.grad_accumulation_factor == 0
92
+ # Managing automatic mixed precision
93
+ # TOFIX: CTC fine-tuning currently is unstable
94
+ # This is certainly due to CTC being done in fp16 instead of fp32
95
+ if self.auto_mix_prec:
96
+ with torch.cuda.amp.autocast():
97
+ with self.no_sync():
98
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
99
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
100
+ with self.no_sync(not should_step):
101
+ self.scaler.scale(
102
+ loss / self.grad_accumulation_factor
103
+ ).backward()
104
+ if should_step:
105
+
106
+ if not self.hparams.wav2vec2.freeze:
107
+ self.scaler.unscale_(self.wav2vec_optimizer)
108
+ self.scaler.unscale_(self.model_optimizer)
109
+ if self.check_gradients(loss):
110
+ if not self.hparams.wav2vec2.freeze:
111
+ if self.optimizer_step >= self.hparams.warmup_steps:
112
+ self.scaler.step(self.wav2vec_optimizer)
113
+ self.scaler.step(self.model_optimizer)
114
+ self.scaler.update()
115
+ self.zero_grad()
116
+ self.optimizer_step += 1
117
+ else:
118
+ # This is mandatory because HF models have a weird behavior with DDP
119
+ # on the forward pass
120
+ with self.no_sync():
121
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
122
+
123
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
124
+
125
+ with self.no_sync(not should_step):
126
+ (loss / self.grad_accumulation_factor).backward()
127
+ if should_step:
128
+ if self.check_gradients(loss):
129
+ if not self.hparams.wav2vec2.freeze:
130
+ if self.optimizer_step >= self.hparams.warmup_steps:
131
+ self.wav2vec_optimizer.step()
132
+ self.model_optimizer.step()
133
+ self.zero_grad()
134
+ self.optimizer_step += 1
135
+
136
+ self.on_fit_batch_end(batch, outputs, loss, should_step)
137
+ return loss.detach().cpu()
138
+
139
+ def evaluate_batch(self, batch, stage):
140
+ """Computations needed for validation/test batches"""
141
+ predictions = self.compute_forward(batch, stage=stage)
142
+ with torch.no_grad():
143
+ loss = self.compute_objectives(predictions, batch, stage=stage)
144
+ return loss.detach()
145
+
146
+ def on_stage_start(self, stage, epoch):
147
+ """Gets called at the beginning of each epoch"""
148
+ if stage != sb.Stage.TRAIN:
149
+ self.cer_metric = self.hparams.cer_computer()
150
+ self.wer_metric = self.hparams.error_rate_computer()
151
+
152
+ def on_stage_end(self, stage, stage_loss, epoch):
153
+ """Gets called at the end of an epoch."""
154
+ # Compute/store important stats
155
+ stage_stats = {"loss": stage_loss}
156
+ if stage == sb.Stage.TRAIN:
157
+ self.train_stats = stage_stats
158
+ else:
159
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
160
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
161
+
162
+ # Perform end-of-iteration things, like annealing, logging, etc.
163
+ if stage == sb.Stage.VALID:
164
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
165
+ stage_stats["loss"]
166
+ )
167
+ old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
168
+ stage_stats["loss"]
169
+ )
170
+ sb.nnet.schedulers.update_learning_rate(
171
+ self.model_optimizer, new_lr_model
172
+ )
173
+ if not self.hparams.wav2vec2.freeze:
174
+ sb.nnet.schedulers.update_learning_rate(
175
+ self.wav2vec_optimizer, new_lr_wav2vec
176
+ )
177
+ self.hparams.train_logger.log_stats(
178
+ stats_meta={
179
+ "epoch": epoch,
180
+ "lr_model": old_lr_model,
181
+ "lr_wav2vec": old_lr_wav2vec,
182
+ },
183
+ train_stats=self.train_stats,
184
+ valid_stats=stage_stats,
185
+ )
186
+ self.checkpointer.save_and_keep_only(
187
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
188
+ )
189
+ elif stage == sb.Stage.TEST:
190
+ self.hparams.train_logger.log_stats(
191
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
192
+ test_stats=stage_stats,
193
+ )
194
+ with open(self.hparams.wer_file, "w") as w:
195
+ self.wer_metric.write_stats(w)
196
+
197
+ def init_optimizers(self):
198
+ "Initializes the wav2vec2 optimizer and model optimizer"
199
+
200
+ # If the wav2vec encoder is unfrozen, we create the optimizer
201
+ if not self.hparams.wav2vec2.freeze:
202
+ self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
203
+ self.modules.wav2vec2.parameters()
204
+ )
205
+ if self.checkpointer is not None:
206
+ self.checkpointer.add_recoverable(
207
+ "wav2vec_opt", self.wav2vec_optimizer
208
+ )
209
+
210
+ self.model_optimizer = self.hparams.model_opt_class(
211
+ self.hparams.model.parameters()
212
+ )
213
+
214
+ if self.checkpointer is not None:
215
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
216
+
217
+ def zero_grad(self, set_to_none=False):
218
+ if not self.hparams.wav2vec2.freeze:
219
+ self.wav2vec_optimizer.zero_grad(set_to_none)
220
+ self.model_optimizer.zero_grad(set_to_none)
221
+
222
+
223
+ # Define custom data procedure
224
+ def dataio_prepare(hparams, tokenizer):
225
+ """This function prepares the datasets to be used in the brain class.
226
+ It also defines the data processing pipeline through user-defined functions."""
227
+
228
+ # 1. Define datasets
229
+ data_folder = hparams["data_folder"]
230
+
231
+ train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
232
+ csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
233
+ )
234
+
235
+ if hparams["sorting"] == "ascending":
236
+ # we sort training data to speed up training and get better results.
237
+ train_data = train_data.filtered_sorted(
238
+ sort_key="duration",
239
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
240
+ )
241
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
242
+ hparams["dataloader_options"]["shuffle"] = False
243
+
244
+ elif hparams["sorting"] == "descending":
245
+ train_data = train_data.filtered_sorted(
246
+ sort_key="duration",
247
+ reverse=True,
248
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
249
+ )
250
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
251
+ hparams["dataloader_options"]["shuffle"] = False
252
+
253
+ elif hparams["sorting"] == "random":
254
+ pass
255
+
256
+ else:
257
+ raise NotImplementedError(
258
+ "sorting must be random, ascending or descending"
259
+ )
260
+
261
+ valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
262
+ csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
263
+ )
264
+ # We also sort the validation data so it is faster to validate
265
+ valid_data = valid_data.filtered_sorted(sort_key="duration")
266
+
267
+ test_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
268
+ csv_path=hparams["test_csv"], replacements={"data_root": data_folder},
269
+ )
270
+
271
+ # We also sort the validation data so it is faster to validate
272
+ test_data = test_data.filtered_sorted(sort_key="duration")
273
+
274
+ datasets = [train_data, valid_data, test_data]
275
+
276
+ # 2. Define audio pipeline:
277
+ @sb.utils.data_pipeline.takes("wav")
278
+ @sb.utils.data_pipeline.provides("sig")
279
+ def audio_pipeline(wav):
280
+ info = torchaudio.info(wav)
281
+ sig = sb.dataio.dataio.read_audio(wav)
282
+ resampled = torchaudio.transforms.Resample(
283
+ info.sample_rate, hparams["sample_rate"],
284
+ )(sig)
285
+ return resampled
286
+
287
+ sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
288
+
289
+ # 3. Define text pipeline:
290
+ @sb.utils.data_pipeline.takes("wrd")
291
+ @sb.utils.data_pipeline.provides(
292
+ "tokens_list", "tokens_bos", "tokens_eos", "tokens"
293
+ )
294
+ def text_pipeline(wrd):
295
+ tokens_list = tokenizer.sp.encode_as_ids(wrd)
296
+ yield tokens_list
297
+ tokens_bos = torch.LongTensor([hparams["bos_index"]] + (tokens_list))
298
+ yield tokens_bos
299
+ tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]])
300
+ yield tokens_eos
301
+ tokens = torch.LongTensor(tokens_list)
302
+ yield tokens
303
+
304
+ sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
305
+
306
+ # 4. Set output:
307
+ sb.dataio.dataset.set_output_keys(
308
+ datasets, ["id", "sig", "tokens_bos", "tokens_eos", "tokens"],
309
+ )
310
+ return train_data, valid_data, test_data
311
+
312
+
313
+ if __name__ == "__main__":
314
+
315
+ # Load hyperparameters file with command-line overrides
316
+ hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
317
+ with open(hparams_file) as fin:
318
+ hparams = load_hyperpyyaml(fin, overrides)
319
+
320
+ # If --distributed_launch then
321
+ # create ddp_group with the right communication protocol
322
+ sb.utils.distributed.ddp_init_group(run_opts)
323
+
324
+ # Dataset preparation (parsing CommonVoice)
325
+ from common_voice_prepare import prepare_common_voice # noqa
326
+
327
+ # Create experiment directory
328
+ sb.create_experiment_directory(
329
+ experiment_directory=hparams["output_folder"],
330
+ hyperparams_to_save=hparams_file,
331
+ overrides=overrides,
332
+ )
333
+
334
+ # Due to DDP, we do the preparation ONLY on the main python process
335
+ run_on_main(
336
+ prepare_common_voice,
337
+ kwargs={
338
+ "data_folder": hparams["data_folder"],
339
+ "save_folder": hparams["save_folder"],
340
+ "train_tsv_file": hparams["train_tsv_file"],
341
+ "dev_tsv_file": hparams["dev_tsv_file"],
342
+ "test_tsv_file": hparams["test_tsv_file"],
343
+ "accented_letters": hparams["accented_letters"],
344
+ "language": hparams["language"],
345
+ "skip_prep": hparams["skip_prep"],
346
+ },
347
+ )
348
+
349
+ # Defining tokenizer and loading it
350
+ tokenizer = SentencePiece(
351
+ model_dir=hparams["save_folder"],
352
+ vocab_size=hparams["output_neurons"],
353
+ annotation_train=hparams["train_csv"],
354
+ annotation_read="wrd",
355
+ model_type=hparams["token_type"],
356
+ character_coverage=hparams["character_coverage"],
357
+ )
358
+
359
+ # Create the datasets objects as well as tokenization and encoding :-D
360
+ train_data, valid_data, test_data = dataio_prepare(hparams, tokenizer)
361
+
362
+ # Trainer initialization
363
+ asr_brain = ASR(
364
+ modules=hparams["modules"],
365
+ hparams=hparams,
366
+ run_opts=run_opts,
367
+ checkpointer=hparams["checkpointer"],
368
+ )
369
+
370
+ # Adding objects to trainer.
371
+ asr_brain.tokenizer = tokenizer
372
+
373
+ # Training
374
+ asr_brain.fit(
375
+ asr_brain.hparams.epoch_counter,
376
+ train_data,
377
+ valid_data,
378
+ train_loader_kwargs=hparams["dataloader_options"],
379
+ valid_loader_kwargs=hparams["test_dataloader_options"],
380
+ )
381
+
382
+ # Test
383
+ asr_brain.hparams.wer_file = hparams["output_folder"] + "/wer_test.txt"
384
+ asr_brain.evaluate(
385
+ test_data,
386
+ min_key="WER",
387
+ test_loader_kwargs=hparams["test_dataloader_options"],
388
+ )
EnglishCV/train_en_with_wav2vec.yaml ADDED
@@ -0,0 +1,184 @@
1
+ # ################################
2
+ # Model: wav2vec2 + DNN + CTC
3
+ # Augmentation: SpecAugment
4
+ # Authors: Titouan Parcollet 2021
5
+ # ################################
6
+
7
+ # Seed needs to be set at top of yaml, before objects with parameters are made
8
+ seed: 1234
9
+ __set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
10
+ output_folder: !ref EnglishCV/results/wav2vec2_ctc_en/<seed>
11
+ wer_file: !ref <output_folder>/wer.txt
12
+ save_folder: !ref <output_folder>/save
13
+ train_log: !ref <output_folder>/train_log.txt
14
+
15
+ # URL for the biggest Fairseq english wav2vec2 model.
16
+ wav2vec2_hub: wav2vec2-large-lv60/
17
+ wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
18
+
19
+ # Data files
20
+ data_folder: /gpfsscratch/rech/nou/uzn19yk/download/cv-corpus-12.0-2022-12-07/en/cv-corpus-12.0-2022-12-07/en # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
21
+ train_tsv_file: !ref <data_folder>/train.tsv # Standard CommonVoice .tsv files
22
+ dev_tsv_file: !ref <data_folder>/dev.tsv # Standard CommonVoice .tsv files
23
+ test_tsv_file: !ref <data_folder>/test.tsv # Standard CommonVoice .tsv files
24
+ accented_letters: False
25
+ language: en # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
26
+ train_csv: !ref <save_folder>/train.csv
27
+ valid_csv: !ref <save_folder>/dev.csv
28
+ test_csv: !ref <save_folder>/test.csv
29
+ skip_prep: False # Skip data preparation
30
+
31
+ # We remove utterances longer than 10s in the train/dev/test sets as
32
+ # longer sentences certainly correspond to "open microphones".
33
+ avoid_if_longer_than: 10.0
34
+
35
+ # Training parameters
36
+ number_of_epochs: 10
37
+ lr: 1.0
38
+ lr_wav2vec: 0.0001
39
+ sorting: ascending
40
+ auto_mix_prec: False
41
+ sample_rate: 16000
42
+ ckpt_interval_minutes: 30 # save checkpoint every N min
43
+
44
+ # With data_parallel batch_size is split into N jobs
45
+ # With DDP batch_size is multiplied by N jobs
46
+ # Must be 8 per GPU to fit 32GB of VRAM
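+ # Example (illustrative): with batch_size 8 on 4 GPUs, DDP trains on an effective
+ # batch of 32, while data_parallel splits the same 8 samples across the 4 GPUs.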
47
+ batch_size: 8
48
+ test_batch_size: 4
49
+
50
+ dataloader_options:
51
+ batch_size: !ref <batch_size>
52
+ num_workers: 6
53
+ test_dataloader_options:
54
+ batch_size: !ref <test_batch_size>
55
+ num_workers: 6
56
+
57
+ # BPE parameters
58
+ token_type: char # ["unigram", "bpe", "char"]
59
+ character_coverage: 1.0
60
+
61
+ # Model parameters
62
+ # activation: !name:torch.nn.LeakyReLU
63
+ wav2vec_output_dim: 1024
64
+ dnn_neurons: 1024
65
+ freeze_wav2vec: False
66
+ freeze_feature_extractor: True
67
+ dropout: 0.15
68
+ warmup_steps: 500
69
+
70
+ # Outputs
71
+ output_neurons: 29 # BPE size, index(blank/eos/bos) = 0
72
+
73
+ # Decoding parameters
74
+ # Be sure that the bos and eos index match with the BPEs ones
75
+ blank_index: 0
76
+ bos_index: 1
77
+ eos_index: 2
78
+
79
+ #
80
+ # Functions and classes
81
+ #
82
+ epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
83
+ limit: !ref <number_of_epochs>
84
+
85
+ augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
86
+ sample_rate: !ref <sample_rate>
87
+ speeds: [95, 100, 105]
88
+
89
+ enc: !new:speechbrain.nnet.containers.Sequential
90
+ input_shape: [null, null, !ref <wav2vec_output_dim>]
91
+ linear1: !name:speechbrain.nnet.linear.Linear
92
+ n_neurons: !ref <dnn_neurons>
93
+ bias: True
94
+ bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
95
+ activation: !new:torch.nn.LeakyReLU
96
+ drop: !new:torch.nn.Dropout
97
+ p: !ref <dropout>
98
+ linear2: !name:speechbrain.nnet.linear.Linear
99
+ n_neurons: !ref <dnn_neurons>
100
+ bias: True
101
+ bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
102
+ activation2: !new:torch.nn.LeakyReLU
103
+ drop2: !new:torch.nn.Dropout
104
+ p: !ref <dropout>
105
+ linear3: !name:speechbrain.nnet.linear.Linear
106
+ n_neurons: !ref <dnn_neurons>
107
+ bias: True
108
+ bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
109
+ activation3: !new:torch.nn.LeakyReLU
110
+
111
+ wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
112
+ source: wav2vec2-large-lv60/
113
+ output_norm: True
114
+ freeze: !ref <freeze_wav2vec>
115
+ freeze_feature_extractor: !ref <freeze_feature_extractor>
116
+ save_path: !ref <wav2vec2_folder>
117
+
118
+ #####
119
+ # Uncomment this block if you prefer to use a Fairseq pretrained model instead
120
+ # of a HuggingFace one. Here, we provide an URL that is obtained from the
121
+ # Fairseq github for the multilingual XLSR.
122
+ #
123
+ #wav2vec2_url: https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt
124
+ #wav2vec2: !new:speechbrain.lobes.models.fairseq_wav2vec.FairseqWav2Vec2
125
+ # pretrained_path: !ref <wav2vec2_url>
126
+ # output_norm: True
127
+ # freeze: False
128
+ # save_path: !ref <save_folder>/wav2vec2_checkpoint/model.pt
129
+ #####
130
+
131
+ ctc_lin: !new:speechbrain.nnet.linear.Linear
132
+ input_size: !ref <dnn_neurons>
133
+ n_neurons: !ref <output_neurons>
134
+
135
+ log_softmax: !new:speechbrain.nnet.activations.Softmax
136
+ apply_log: True
137
+
138
+ ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
139
+ blank_index: !ref <blank_index>
140
+
141
+ modules:
142
+ wav2vec2: !ref <wav2vec2>
143
+ enc: !ref <enc>
144
+ ctc_lin: !ref <ctc_lin>
145
+
146
+ model: !new:torch.nn.ModuleList
147
+ - [!ref <enc>, !ref <ctc_lin>]
148
+
149
+ model_opt_class: !name:torch.optim.Adadelta
150
+ lr: !ref <lr>
151
+ rho: 0.95
152
+ eps: 1.e-8
153
+
154
+ wav2vec_opt_class: !name:torch.optim.Adam
155
+ lr: !ref <lr_wav2vec>
156
+
157
+ lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
158
+ initial_value: !ref <lr>
159
+ improvement_threshold: 0.0025
160
+ annealing_factor: 0.8
161
+ patient: 0
162
+
163
+ lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
164
+ initial_value: !ref <lr_wav2vec>
165
+ improvement_threshold: 0.0025
166
+ annealing_factor: 0.9
167
+ patient: 0
168
+
169
+ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
170
+ checkpoints_dir: !ref <save_folder>
171
+ recoverables:
172
+ wav2vec2: !ref <wav2vec2>
173
+ model: !ref <model>
174
+ scheduler_model: !ref <lr_annealing_model>
175
+ scheduler_wav2vec: !ref <lr_annealing_wav2vec>
176
+ counter: !ref <epoch_counter>
177
+
178
+ train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
179
+ save_file: !ref <train_log>
180
+
181
+ error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
182
+
183
+ cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
184
+ split_tokens: True
EnglishCV/train_with_wav2vec.py ADDED
@@ -0,0 +1,388 @@
1
+ #!/usr/bin/env python3
2
+ import sys
3
+ import torch
4
+ import logging
5
+ import speechbrain as sb
6
+ import torchaudio
7
+ from hyperpyyaml import load_hyperpyyaml
8
+ from speechbrain.tokenizers.SentencePiece import SentencePiece
9
+ from speechbrain.utils.data_utils import undo_padding
10
+ from speechbrain.utils.distributed import run_on_main
11
+
12
+ """Recipe for training a sequence-to-sequence ASR system with CommonVoice.
13
+ The system employs a wav2vec2 encoder and a CTC decoder.
14
+ Decoding is performed with greedy decoding (will be extended to beam search).
15
+
16
+ To run this recipe, do the following:
17
+ > python train_with_wav2vec2.py hparams/train_with_wav2vec2.yaml
18
+
19
+ With the default hyperparameters, the system employs a pretrained wav2vec2 encoder.
20
+ The wav2vec2 model is pretrained following the model given in the hparams file.
21
+ It may be dependent on the language.
22
+
23
+ The neural network is trained with CTC on sub-word units estimated with
24
+ Byte Pair Encoding (BPE).
25
+
26
+ The experiment file is flexible enough to support a large variety of
27
+ different systems. By properly changing the parameter files, you can try
28
+ different encoders, decoders, tokens (e.g, characters instead of BPE),
29
+ training languages (all CommonVoice languages), and many
30
+ other possible variations.
31
+
32
+ Authors
33
+ * Titouan Parcollet 2021
34
+ """
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
+ # Define training procedure
40
+ class ASRCV(sb.core.Brain):
41
+ def compute_forward(self, batch, stage):
42
+ """Forward computations from the waveform batches to the output probabilities."""
43
+
44
+ batch = batch.to(self.device)
45
+ wavs, wav_lens = batch.sig
46
+ tokens_bos, _ = batch.tokens_bos
47
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
48
+
49
+ if stage == sb.Stage.TRAIN:
50
+ if hasattr(self.hparams, "augmentation"):
51
+ wavs = self.hparams.augmentation(wavs, wav_lens)
52
+
53
+ # Forward pass
54
+ feats = self.modules.wav2vec2(wavs, wav_lens)
55
+ x = self.modules.enc(feats)
56
+ logits = self.modules.ctc_lin(x)
57
+ p_ctc = self.hparams.log_softmax(logits)
58
+
59
+ return p_ctc, wav_lens
60
+
61
+ def compute_objectives(self, predictions, batch, stage):
62
+ """Computes the loss (CTC) given predictions and targets."""
63
+
64
+ p_ctc, wav_lens = predictions
65
+
66
+ ids = batch.id
67
+ tokens_eos, tokens_eos_lens = batch.tokens_eos
68
+ tokens, tokens_lens = batch.tokens
69
+
70
+ loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
71
+
72
+ if stage != sb.Stage.TRAIN:
73
+ # Decode token terms to words
74
+ sequence = sb.decoders.ctc_greedy_decode(
75
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
76
+ )
77
+
78
+ predicted_words = self.tokenizer(sequence, task="decode_from_list")
79
+
80
+ # Convert indices to words
81
+ target_words = undo_padding(tokens, tokens_lens)
82
+ target_words = self.tokenizer(target_words, task="decode_from_list")
83
+
84
+ self.wer_metric.append(ids, predicted_words, target_words)
85
+ self.cer_metric.append(ids, predicted_words, target_words)
86
+
87
+ return loss
88
+
89
+ def fit_batch(self, batch):
90
+ """Train the parameters given a single batch in input"""
91
+ should_step = self.step % self.grad_accumulation_factor == 0
92
+ # Managing automatic mixed precision
93
+ # TOFIX: CTC fine-tuning currently is unstable
94
+ # This is certainly due to CTC being done in fp16 instead of fp32
95
+ if self.auto_mix_prec:
96
+ with torch.cuda.amp.autocast():
97
+ with self.no_sync():
98
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
99
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
100
+ with self.no_sync(not should_step):
101
+ self.scaler.scale(
102
+ loss / self.grad_accumulation_factor
103
+ ).backward()
104
+ if should_step:
105
+
106
+ if not self.hparams.wav2vec2.freeze:
107
+ self.scaler.unscale_(self.wav2vec_optimizer)
108
+ self.scaler.unscale_(self.model_optimizer)
109
+ if self.check_gradients(loss):
110
+ if not self.hparams.wav2vec2.freeze:
111
+ if self.optimizer_step >= self.hparams.warmup_steps:
112
+ self.scaler.step(self.wav2vec_optimizer)
113
+ self.scaler.step(self.model_optimizer)
114
+ self.scaler.update()
115
+ self.zero_grad()
116
+ self.optimizer_step += 1
117
+ else:
118
+ # This is mandatory because HF models have a weird behavior with DDP
119
+ # on the forward pass
120
+ with self.no_sync():
121
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
122
+
123
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
124
+
125
+ with self.no_sync(not should_step):
126
+ (loss / self.grad_accumulation_factor).backward()
127
+ if should_step:
128
+ if self.check_gradients(loss):
129
+ if not self.hparams.wav2vec2.freeze:
130
+ if self.optimizer_step >= self.hparams.warmup_steps:
131
+ self.wav2vec_optimizer.step()
132
+ self.model_optimizer.step()
133
+ self.zero_grad()
134
+ self.optimizer_step += 1
135
+
136
+ self.on_fit_batch_end(batch, outputs, loss, should_step)
137
+ return loss.detach().cpu()
138
+
139
+ def evaluate_batch(self, batch, stage):
140
+ """Computations needed for validation/test batches"""
141
+ predictions = self.compute_forward(batch, stage=stage)
142
+ with torch.no_grad():
143
+ loss = self.compute_objectives(predictions, batch, stage=stage)
144
+ return loss.detach()
145
+
146
+ def on_stage_start(self, stage, epoch):
147
+ """Gets called at the beginning of each epoch"""
148
+ if stage != sb.Stage.TRAIN:
149
+ self.cer_metric = self.hparams.cer_computer()
150
+ self.wer_metric = self.hparams.error_rate_computer()
151
+
152
+ def on_stage_end(self, stage, stage_loss, epoch):
153
+ """Gets called at the end of an epoch."""
154
+ # Compute/store important stats
155
+ stage_stats = {"loss": stage_loss}
156
+ if stage == sb.Stage.TRAIN:
157
+ self.train_stats = stage_stats
158
+ else:
159
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
160
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
161
+
162
+ # Perform end-of-iteration things, like annealing, logging, etc.
163
+ if stage == sb.Stage.VALID:
164
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
165
+ stage_stats["loss"]
166
+ )
167
+ old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
168
+ stage_stats["loss"]
169
+ )
170
+ sb.nnet.schedulers.update_learning_rate(
171
+ self.model_optimizer, new_lr_model
172
+ )
173
+ if not self.hparams.wav2vec2.freeze:
174
+ sb.nnet.schedulers.update_learning_rate(
175
+ self.wav2vec_optimizer, new_lr_wav2vec
176
+ )
177
+ self.hparams.train_logger.log_stats(
178
+ stats_meta={
179
+ "epoch": epoch,
180
+ "lr_model": old_lr_model,
181
+ "lr_wav2vec": old_lr_wav2vec,
182
+ },
183
+ train_stats=self.train_stats,
184
+ valid_stats=stage_stats,
185
+ )
186
+ self.checkpointer.save_and_keep_only(
187
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
188
+ )
189
+ elif stage == sb.Stage.TEST:
190
+ self.hparams.train_logger.log_stats(
191
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
192
+ test_stats=stage_stats,
193
+ )
194
+ with open(self.hparams.wer_file, "w") as w:
195
+ self.wer_metric.write_stats(w)
196
+
197
+ def init_optimizers(self):
198
+ "Initializes the wav2vec2 optimizer and model optimizer"
199
+
200
+ # If the wav2vec encoder is unfrozen, we create the optimizer
201
+ if not self.hparams.wav2vec2.freeze:
202
+ self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
203
+ self.modules.wav2vec2.parameters()
204
+ )
205
+ if self.checkpointer is not None:
206
+ self.checkpointer.add_recoverable(
207
+ "wav2vec_opt", self.wav2vec_optimizer
208
+ )
209
+
210
+ self.model_optimizer = self.hparams.model_opt_class(
211
+ self.hparams.model.parameters()
212
+ )
213
+
214
+ if self.checkpointer is not None:
215
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
216
+
217
+ def zero_grad(self, set_to_none=False):
218
+ if not self.hparams.wav2vec2.freeze:
219
+ self.wav2vec_optimizer.zero_grad(set_to_none)
220
+ self.model_optimizer.zero_grad(set_to_none)
221
+
222
+
223
+ # Define custom data procedure
224
+ def dataio_prepare(hparams, tokenizer):
225
+ """This function prepares the datasets to be used in the brain class.
226
+ It also defines the data processing pipeline through user-defined functions."""
227
+
228
+ # 1. Define datasets
229
+ data_folder = hparams["data_folder"]
230
+
231
+ train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
232
+ csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
233
+ )
234
+
235
+ if hparams["sorting"] == "ascending":
236
+ # we sort training data to speed up training and get better results.
237
+ train_data = train_data.filtered_sorted(
238
+ sort_key="duration",
239
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
240
+ )
241
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
242
+ hparams["dataloader_options"]["shuffle"] = False
243
+
244
+ elif hparams["sorting"] == "descending":
245
+ train_data = train_data.filtered_sorted(
246
+ sort_key="duration",
247
+ reverse=True,
248
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
249
+ )
250
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
251
+ hparams["dataloader_options"]["shuffle"] = False
252
+
253
+ elif hparams["sorting"] == "random":
254
+ pass
255
+
256
+ else:
257
+ raise NotImplementedError(
258
+ "sorting must be random, ascending or descending"
259
+ )
260
+
261
+ valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
262
+ csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
263
+ )
264
+ # We also sort the validation data so it is faster to validate
265
+ valid_data = valid_data.filtered_sorted(sort_key="duration")
266
+
267
+ test_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
268
+ csv_path=hparams["test_csv"], replacements={"data_root": data_folder},
269
+ )
270
+
271
+ # We also sort the validation data so it is faster to validate
272
+ test_data = test_data.filtered_sorted(sort_key="duration")
273
+
274
+ datasets = [train_data, valid_data, test_data]
275
+
276
+ # 2. Define audio pipeline:
277
+ @sb.utils.data_pipeline.takes("wav")
278
+ @sb.utils.data_pipeline.provides("sig")
279
+ def audio_pipeline(wav):
280
+ info = torchaudio.info(wav)
281
+ sig = sb.dataio.dataio.read_audio(wav)
282
+ resampled = torchaudio.transforms.Resample(
283
+ info.sample_rate, hparams["sample_rate"],
284
+ )(sig)
285
+ return resampled
286
+
287
+ sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
288
+
289
+ # 3. Define text pipeline:
290
+ @sb.utils.data_pipeline.takes("wrd")
291
+ @sb.utils.data_pipeline.provides(
292
+ "tokens_list", "tokens_bos", "tokens_eos", "tokens"
293
+ )
294
+ def text_pipeline(wrd):
295
+ tokens_list = tokenizer.sp.encode_as_ids(wrd)
296
+ yield tokens_list
297
+ tokens_bos = torch.LongTensor([hparams["bos_index"]] + (tokens_list))
298
+ yield tokens_bos
299
+ tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]])
300
+ yield tokens_eos
301
+ tokens = torch.LongTensor(tokens_list)
302
+ yield tokens
303
+
304
+ sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
305
+
306
+ # 4. Set output:
307
+ sb.dataio.dataset.set_output_keys(
308
+ datasets, ["id", "sig", "tokens_bos", "tokens_eos", "tokens"],
309
+ )
310
+ return train_data, valid_data, test_data
311
+
312
+
313
+ if __name__ == "__main__":
314
+
315
+ # Load hyperparameters file with command-line overrides
316
+ hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
317
+ with open(hparams_file) as fin:
318
+ hparams = load_hyperpyyaml(fin, overrides)
319
+
320
+ # If --distributed_launch then
321
+ # create ddp_group with the right communication protocol
322
+ sb.utils.distributed.ddp_init_group(run_opts)
323
+
324
+ # Dataset preparation (parsing CommonVoice)
325
+ from common_voice_prepare import prepare_common_voice # noqa
326
+
327
+ # Create experiment directory
328
+ sb.create_experiment_directory(
329
+ experiment_directory=hparams["output_folder"],
330
+ hyperparams_to_save=hparams_file,
331
+ overrides=overrides,
332
+ )
333
+
334
+ # Due to DDP, we do the preparation ONLY on the main python process
335
+ run_on_main(
336
+ prepare_common_voice,
337
+ kwargs={
338
+ "data_folder": hparams["data_folder"],
339
+ "save_folder": hparams["save_folder"],
340
+ "train_tsv_file": hparams["train_tsv_file"],
341
+ "dev_tsv_file": hparams["dev_tsv_file"],
342
+ "test_tsv_file": hparams["test_tsv_file"],
343
+ "accented_letters": hparams["accented_letters"],
344
+ "language": hparams["language"],
345
+ "skip_prep": hparams["skip_prep"],
346
+ },
347
+ )
348
+
349
+ # Defining tokenizer and loading it
350
+ tokenizer = SentencePiece(
351
+ model_dir=hparams["save_folder"],
352
+ vocab_size=hparams["output_neurons"],
353
+ annotation_train=hparams["train_csv"],
354
+ annotation_read="wrd",
355
+ model_type=hparams["token_type"],
356
+ character_coverage=hparams["character_coverage"],
357
+ )
358
+
359
+ # Create the datasets objects as well as tokenization and encoding :-D
360
+ train_data, valid_data, test_data = dataio_prepare(hparams, tokenizer)
361
+
362
+ # Trainer initialization
363
+ asr_brain = ASRCV(
364
+ modules=hparams["modules"],
365
+ hparams=hparams,
366
+ run_opts=run_opts,
367
+ checkpointer=hparams["checkpointer"],
368
+ )
369
+
370
+ # Adding objects to trainer.
371
+ asr_brain.tokenizer = tokenizer
372
+
373
+ # Training
374
+ asr_brain.fit(
375
+ asr_brain.hparams.epoch_counter,
376
+ train_data,
377
+ valid_data,
378
+ train_loader_kwargs=hparams["dataloader_options"],
379
+ valid_loader_kwargs=hparams["test_dataloader_options"],
380
+ )
381
+
382
+ # Test
383
+ asr_brain.hparams.wer_file = hparams["output_folder"] + "/wer_test.txt"
384
+ asr_brain.evaluate(
385
+ test_data,
386
+ min_key="WER",
387
+ test_loader_kwargs=hparams["test_dataloader_options"],
388
+ )
README.md CHANGED
@@ -1,13 +1,21 @@
1
  ---
2
- title: Code Switched Tunisian SpeechToText
3
- emoji: 👀
4
- colorFrom: yellow
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 3.44.4
8
- app_file: app.py
9
- pinned: false
10
  license: apache-2.0
11
  ---
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
 
2
  license: apache-2.0
3
  ---
4
+ # Tunisian Arabic ASR Model with wav2vec2 and code switching
5
+ This repository provides all the necessary tools to perform automatic speech recognition with an end-to-end system pretrained on the Tunisian Arabic dialect. The model uses a code-switching approach and can process English, French, and Tunisian Arabic.
6
+ ## Performance
7
+ The performance of the model is:
8
+ | Release Version | WER (%) | CER (%) |
9
+ |-----------------|---------|---------|
10
+ | v1.0 | 29.47 | 12.44 |
11
+ ## Pipeline
12
+ The architecture comprises three components:
13
+ * French ASR pretrained with wav2vec2 on French corpora
14
+ * English ASR pretrained with wav2vec2 on English corpora
15
+ * Custom Tunisian ASR pretrained using wav2vec2 on a Tunisian Arabic corpus
16
+ All three models process the audio; the resulting posteriorgrams are then combined and used as input to the Mixer, which produces the final posteriorgrams, as sketched below.
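+ The sketch below is a minimal illustration of this fusion step, with placeholder vocabulary sizes and a hypothetical mixer head rather than the exact module used in this repository:
+ ```python
+ import torch
+ import torch.nn as nn
+ 
+ # Hypothetical per-frame posteriorgrams from the three ASR branches: (batch, time, vocab).
+ post_fr = torch.rand(1, 100, 76)
+ post_en = torch.rand(1, 100, 29)
+ post_tn = torch.rand(1, 100, 80)
+ 
+ # Concatenate along the class dimension and map to the final (code-switched) vocabulary.
+ mixer = nn.Sequential(nn.Linear(76 + 29 + 80, 1024), nn.LeakyReLU(), nn.Linear(1024, 80))
+ final_post = mixer(torch.cat([post_fr, post_en, post_tn], dim=-1)).log_softmax(dim=-1)
+ ```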
17
+ ## Install
18
+ ```python
19
+ pip install speechbrain transformers
20
+ ```
21
 
 
TunisianASR/README.md ADDED
@@ -0,0 +1,21 @@
1
+ # Tunisian Arabic ASR Model with wav2vec2
2
+
3
+ This repository provides all the necessary tools to perform automatic speech recognition with an end-to-end system pretrained on the Tunisian Arabic dialect.
4
+
5
+ ## Performance
6
+ The performance of the model is:
7
+ | Release Version | Decoding | WER (%) | CER (%) |
8
+ |-----------------|----|---------|---------|
9
+ | v1.0 | Without LM | 11.82 | 6.33 |
10
+ ## Dataset
11
+ This ASR model was trained on:
12
+ * TARIC: The Tunisian Arabic Railway Interaction Corpus, a collection of audio recordings and transcriptions from dialogues in the Tunisian Railway Transport Network - [TARIC Corpus](https://aclanthology.org/L14-1385/)
13
+ * STAC: A corpus of spoken Tunisian Arabic - [STAC Corpus](https://www.researchgate.net/publication/307583782_Spoken_Tunisian_Arabic_Corpus_STAC_Transcription_and_Annotation)
14
+ * IWSLT: A Tunisian conversational speech corpus - [IWSLT Corpus](https://iwslt.org/2022/dialect)
15
+ * Tunspeech: Our custom dataset
16
+
17
+ ## Install
18
+ ```python
19
+ pip install speechbrain transformers
20
+ ```
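+ Inference would then look roughly like the following (a hypothetical sketch assuming the checkpoint is packaged for SpeechBrain's pretrained `EncoderASR` interface; the actual entry point for this repository may differ):
+ ```python
+ from speechbrain.pretrained import EncoderASR
+ 
+ # Hypothetical source/savedir paths.
+ asr_model = EncoderASR.from_hparams(source="TunisianASR", savedir="pretrained_models/tunisian_asr")
+ print(asr_model.transcribe_file("audio.wav"))
+ ```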
21
+
TunisianASR/results/14epoch_tunisian/1234/env.log ADDED
@@ -0,0 +1,347 @@
1
+ SpeechBrain system description
2
+ ==============================
3
+ Python version:
4
+ 3.8.5 (default, Sep 4 2020, 07:30:14)
5
+ [GCC 7.3.0]
6
+ ==============================
7
+ Installed Python packages:
8
+ absl-py==1.2.0
9
+ aiohttp==3.8.1
10
+ aiosignal==1.2.0
11
+ alabaster==0.7.12
12
+ anaconda-client==1.7.2
13
+ anaconda-navigator==1.10.0
14
+ anaconda-project==0.8.3
15
+ antlr4-python3-runtime==4.9.3
16
+ appdirs==1.4.4
17
+ argh==0.26.2
18
+ argon2-cffi @ file:///tmp/build/80754af9/argon2-cffi_1596828493937/work
19
+ asn1crypto @ file:///tmp/build/80754af9/asn1crypto_1596577642040/work
20
+ astroid @ file:///tmp/build/80754af9/astroid_1592495912941/work
21
+ astropy==4.0.2
22
+ async-generator==1.10
23
+ async-timeout==4.0.2
24
+ atomicwrites==1.4.0
25
+ attrs @ file:///tmp/build/80754af9/attrs_1604765588209/work
26
+ audioread==2.1.9
27
+ autopep8 @ file:///tmp/build/80754af9/autopep8_1596578164842/work
28
+ Babel @ file:///tmp/build/80754af9/babel_1605108370292/work
29
+ backcall==0.2.0
30
+ backports.functools-lru-cache==1.6.1
31
+ backports.shutil-get-terminal-size==1.0.0
32
+ backports.tempfile==1.0
33
+ backports.weakref==1.0.post1
34
+ beautifulsoup4 @ file:///tmp/build/80754af9/beautifulsoup4_1601924105527/work
35
+ bitarray @ file:///tmp/build/80754af9/bitarray_1605065113847/work
36
+ bkcharts==0.2
37
+ black==22.12.0
38
+ bleach @ file:///tmp/build/80754af9/bleach_1600439572647/work
39
+ bokeh @ file:///tmp/build/80754af9/bokeh_1603297833684/work
40
+ boto==2.49.0
41
+ boto3==1.28.43
42
+ botocore==1.31.43
43
+ Bottleneck==1.3.2
44
+ bpemb==0.3.4
45
+ brotlipy==0.7.0
46
+ cachetools==5.2.0
47
+ certifi==2020.6.20
48
+ cffi @ file:///tmp/build/80754af9/cffi_1600699146221/work
49
+ chardet==3.0.4
50
+ charset-normalizer==2.0.12
51
+ click==8.1.3
52
+ cloudpickle @ file:///tmp/build/80754af9/cloudpickle_1598884132938/work
53
+ clyent==1.2.2
54
+ colorama @ file:///tmp/build/80754af9/colorama_1603211150991/work
55
+ coloredlogs==15.0.1
56
+ conda==4.9.2
57
+ conda-build==3.20.5
58
+ conda-package-handling @ file:///tmp/build/80754af9/conda-package-handling_1603018141399/work
59
+ conda-verify==3.4.2
60
+ conllu==4.5.3
61
+ contextlib2==0.6.0.post1
62
+ cryptography @ file:///tmp/build/80754af9/cryptography_1601046815590/work
63
+ cycler==0.10.0
64
+ Cython @ file:///tmp/build/80754af9/cython_1594831566883/work
65
+ cytoolz==0.11.0
66
+ dask @ file:///tmp/build/80754af9/dask-core_1602083700509/work
67
+ datasets==1.18.3
68
+ decorator==4.4.2
69
+ defusedxml==0.6.0
70
+ Deprecated==1.2.14
71
+ diff-match-patch @ file:///tmp/build/80754af9/diff-match-patch_1594828741838/work
72
+ dill==0.3.4
73
+ distributed @ file:///tmp/build/80754af9/distributed_1605066520644/work
74
+ docutils==0.16
75
+ easyocr==1.2.1
76
+ einops==0.3.0
77
+ entrypoints==0.3
78
+ et-xmlfile==1.0.1
79
+ farasapy==0.0.14
80
+ fastcache==1.1.0
81
+ ffmpeg-python==0.2.0
82
+ filelock==3.0.12
83
+ flair==0.12.2
84
+ flake8 @ file:///tmp/build/80754af9/flake8_1601911421857/work
85
+ Flask==1.1.2
86
+ flatbuffers==22.9.24
87
+ frozenlist==1.3.0
88
+ fsspec==2022.3.0
89
+ ftfy==6.1.1
90
+ future==0.18.2
91
+ gdown==4.4.0
92
+ gensim==4.1.2
93
+ gevent @ file:///tmp/build/80754af9/gevent_1601397537062/work
94
+ glob2==0.7
95
+ gmpy2==2.0.8
96
+ google-auth==2.12.0
97
+ google-auth-oauthlib==0.4.6
98
+ greenlet @ file:///tmp/build/80754af9/greenlet_1600874013538/work
99
+ grpcio==1.49.1
100
+ h5py==2.10.0
101
+ HeapDict==1.0.1
102
+ html5lib @ file:///tmp/build/80754af9/html5lib_1593446221756/work
103
+ huggingface-hub==0.16.4
104
+ humanfriendly==10.0
105
+ hyperopt==0.2.7
106
+ idna @ file:///tmp/build/80754af9/idna_1593446292537/work
107
+ imageio @ file:///tmp/build/80754af9/imageio_1594161405741/work
108
+ imagesize==1.2.0
109
+ imhist==0.0.4
110
+ importlib-metadata==5.0.0
111
+ imWatermark==0.0.2
112
+ iniconfig @ file:///tmp/build/80754af9/iniconfig_1602780191262/work
113
+ install==1.3.5
114
+ intervaltree @ file:///tmp/build/80754af9/intervaltree_1598376443606/work
115
+ invisible-watermark==0.1.5
116
+ ipykernel @ file:///tmp/build/80754af9/ipykernel_1596207638929/work/dist/ipykernel-5.3.4-py3-none-any.whl
117
+ ipython @ file:///tmp/build/80754af9/ipython_1604101197014/work
118
+ ipython-genutils==0.2.0
119
+ ipywidgets @ file:///tmp/build/80754af9/ipywidgets_1601490159889/work
120
+ isort @ file:///tmp/build/80754af9/isort_1602603989581/work
121
+ itsdangerous==1.1.0
122
+ Janome==0.5.0
123
+ jdcal==1.4.1
124
+ jedi @ file:///tmp/build/80754af9/jedi_1592841866100/work
125
+ jeepney @ file:///tmp/build/80754af9/jeepney_1605069705079/work
126
+ Jinja2==2.11.2
127
+ jiwer==2.3.0
128
+ jmespath==1.0.1
129
+ joblib @ file:///tmp/build/80754af9/joblib_1601912903842/work
130
+ json5==0.9.5
131
+ jsonschema @ file:///tmp/build/80754af9/jsonschema_1602607155483/work
132
+ jupyter==1.0.0
133
+ jupyter-client @ file:///tmp/build/80754af9/jupyter_client_1601311786391/work
134
+ jupyter-console @ file:///tmp/build/80754af9/jupyter_console_1598884538475/work
135
+ jupyter-core==4.6.3
136
+ jupyterlab==2.2.6
137
+ jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work
138
+ jupyterlab-server @ file:///tmp/build/80754af9/jupyterlab_server_1594164409481/work
139
+ keyring @ file:///tmp/build/80754af9/keyring_1601490835422/work
140
+ kiwisolver @ file:///tmp/build/80754af9/kiwisolver_1604014535162/work
141
+ langdetect==1.0.9
142
+ lazy-object-proxy==1.4.3
143
+ libarchive-c==2.9
144
+ librosa==0.9.1
145
+ llvmlite==0.34.0
146
+ locket==0.2.0
147
+ lxml @ file:///tmp/build/80754af9/lxml_1603216285000/work
148
+ Markdown==3.4.1
149
+ MarkupSafe==1.1.1
150
+ matplotlib @ file:///tmp/build/80754af9/matplotlib-base_1603378225747/work
151
+ mccabe==0.6.1
152
+ mido==1.2.10
153
+ mistune==0.8.4
154
+ mkl-fft==1.2.0
155
+ mkl-random==1.1.1
156
+ mkl-service==2.3.0
157
+ mock==4.0.2
158
+ more-itertools @ file:///tmp/build/80754af9/more-itertools_1605111547926/work
159
+ mpld3==0.3
160
+ mpmath==1.1.0
161
+ msgpack==1.0.0
162
+ multidict==6.0.2
163
+ multipledispatch==0.6.0
164
+ multiprocess==0.70.12.2
165
+ mypy-extensions==0.4.3
166
+ navigator-updater==0.2.1
167
+ nbclient @ file:///tmp/build/80754af9/nbclient_1602783176460/work
168
+ nbconvert @ file:///tmp/build/80754af9/nbconvert_1601914830498/work
169
+ nbformat @ file:///tmp/build/80754af9/nbformat_1602783287752/work
170
+ nest-asyncio @ file:///tmp/build/80754af9/nest-asyncio_1605115881283/work
171
+ networkx @ file:///tmp/build/80754af9/networkx_1598376031484/work
172
+ nltk @ file:///tmp/build/80754af9/nltk_1592496090529/work
173
+ nose==1.3.7
174
+ notebook @ file:///tmp/build/80754af9/notebook_1601501575118/work
175
+ numba @ file:///tmp/build/80754af9/numba_1600100669015/work
176
+ numexpr==2.7.1
177
+ numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1603570489231/work
178
+ numpydoc @ file:///tmp/build/80754af9/numpydoc_1605117425582/work
179
+ oauthlib==3.2.1
180
+ olefile==0.46
181
+ omegaconf==2.2.3
182
+ onnx==1.12.0
183
+ onnxruntime==1.12.1
184
+ opencv-python==4.4.0.46
185
+ openpyxl @ file:///tmp/build/80754af9/openpyxl_1598113097404/work
186
+ packaging==20.9
187
+ pandas @ file:///tmp/build/80754af9/pandas_1602088120436/work
188
+ pandocfilters @ file:///tmp/build/80754af9/pandocfilters_1605120460739/work
189
+ parso==0.7.0
190
+ partd==1.1.0
191
+ path @ file:///tmp/build/80754af9/path_1598376507494/work
192
+ pathlib2==2.3.5
193
+ pathspec==0.10.3
194
+ pathtools==0.1.2
195
+ patsy==0.5.1
196
+ pep8==1.7.1
197
+ pexpect==4.8.0
198
+ pickleshare==0.7.5
199
+ Pillow @ file:///tmp/build/80754af9/pillow_1603822255246/work
200
+ pkginfo==1.6.1
201
+ platformdirs==2.6.0
202
+ pluggy==0.13.1
203
+ ply==3.11
204
+ pooch==1.6.0
205
+ pptree==3.1
206
+ pretty-midi==0.2.9
207
+ prometheus-client==0.8.0
208
+ prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1602688806899/work
209
+ protobuf==3.19.6
210
+ psutil @ file:///tmp/build/80754af9/psutil_1598370257551/work
211
+ ptyprocess==0.6.0
212
+ py @ file:///tmp/build/80754af9/py_1593446248552/work
213
+ py-espeak-ng==0.1.8
214
+ py4j==0.10.9.7
215
+ PyArabic==0.6.15
216
+ pyarrow==7.0.0
217
+ pyasn1==0.4.8
218
+ pyasn1-modules==0.2.8
219
+ pycodestyle==2.6.0
220
+ pycosat==0.6.3
221
+ pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work
222
+ pycurl==7.43.0.6
223
+ pyDeprecate==0.3.1
224
+ pydocstyle @ file:///tmp/build/80754af9/pydocstyle_1598885001695/work
225
+ pyflakes==2.2.0
226
+ Pygments @ file:///tmp/build/80754af9/pygments_1604103097372/work
227
+ pylint @ file:///tmp/build/80754af9/pylint_1598623985952/work
228
+ pyodbc===4.0.0-unsupported
229
+ pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1594392929924/work
230
+ pyparsing==2.4.7
231
+ pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1600141720057/work
232
+ PySocks==1.7.1
233
+ pytest==0.0.0
234
+ python-bidi==0.4.2
235
+ python-crfsuite==0.9.7
236
+ python-dateutil==2.8.1
237
+ python-jsonrpc-server @ file:///tmp/build/80754af9/python-jsonrpc-server_1600278539111/work
238
+ python-language-server @ file:///tmp/build/80754af9/python-language-server_1600454544709/work
239
+ python-Levenshtein==0.12.2
240
+ pytorch-lightning==1.4.2
241
+ pytorch-revgrad==0.2.0
242
+ pytz==2020.1
243
+ PyWavelets @ file:///tmp/build/80754af9/pywavelets_1601658317819/work
244
+ pyxdg @ file:///tmp/build/80754af9/pyxdg_1603822279816/work
245
+ PyYAML==5.3.1
246
+ pyzmq==19.0.2
247
+ QDarkStyle==2.8.1
248
+ QtAwesome @ file:///tmp/build/80754af9/qtawesome_1602272867890/work
249
+ qtconsole @ file:///tmp/build/80754af9/qtconsole_1600870028330/work
250
+ QtPy==1.9.0
251
+ regex @ file:///tmp/build/80754af9/regex_1602786672676/work
252
+ requests @ file:///tmp/build/80754af9/requests_1592841827918/work
253
+ requests-oauthlib==1.3.1
254
+ resampy==0.2.2
255
+ rope @ file:///tmp/build/80754af9/rope_1602264064449/work
256
+ rsa==4.9
257
+ Rtree==0.9.4
258
+ ruamel-yaml==0.15.87
259
+ s3transfer==0.6.2
260
+ sacremoses==0.0.49
261
+ safetensors==0.3.3
262
+ scikit-image==0.17.2
263
+ scikit-learn @ file:///tmp/build/80754af9/scikit-learn_1598376899566/work
264
+ scipy @ file:///tmp/build/80754af9/scipy_1597686649129/work
265
+ seaborn @ file:///tmp/build/80754af9/seaborn_1600553570093/work
266
+ SecretStorage==3.1.2
267
+ segtok==1.5.11
268
+ Send2Trash==1.5.0
269
+ sentencepiece==0.1.97
270
+ simplegeneric==0.8.1
271
+ singledispatch @ file:///tmp/build/80754af9/singledispatch_1602523705405/work
272
+ sip==4.19.13
273
+ six @ file:///tmp/build/80754af9/six_1605205327372/work
274
+ smart-open==5.2.1
275
+ snowballstemmer==2.0.0
276
+ sortedcollections==1.2.1
277
+ sortedcontainers==2.2.2
278
+ SoundFile==0.10.3.post1
279
+ soupsieve==2.0.1
280
+ sphfile==1.0.3
281
+ Sphinx @ file:///tmp/build/80754af9/sphinx_1597428793432/work
282
+ sphinxcontrib-applehelp==1.0.2
283
+ sphinxcontrib-devhelp==1.0.2
284
+ sphinxcontrib-htmlhelp==1.0.3
285
+ sphinxcontrib-jsmath==1.0.1
286
+ sphinxcontrib-qthelp==1.0.3
287
+ sphinxcontrib-serializinghtml==1.1.4
288
+ sphinxcontrib-websupport @ file:///tmp/build/80754af9/sphinxcontrib-websupport_1597081412696/work
289
+ spyder @ file:///tmp/build/80754af9/spyder_1599056981321/work
290
+ spyder-kernels @ file:///tmp/build/80754af9/spyder-kernels_1599056754858/work
291
+ SQLAlchemy @ file:///tmp/build/80754af9/sqlalchemy_1603397987316/work
292
+ sqlitedict==2.1.0
293
+ statsmodels @ file:///tmp/build/80754af9/statsmodels_1602280205159/work
294
+ sympy @ file:///tmp/build/80754af9/sympy_1605119542615/work
295
+ tables==3.6.1
296
+ tabulate==0.9.0
297
+ tblib @ file:///tmp/build/80754af9/tblib_1597928476713/work
298
+ tensorboard==2.10.1
299
+ tensorboard-data-server==0.6.1
300
+ tensorboard-plugin-wit==1.8.1
301
+ terminado==0.9.1
302
+ testpath==0.4.4
303
+ threadpoolctl @ file:///tmp/tmp9twdgx9k/threadpoolctl-2.1.0-py3-none-any.whl
304
+ tifffile==2020.10.1
305
+ tkseem==0.0.3
306
+ tokenizers==0.13.3
307
+ toml @ file:///tmp/build/80754af9/toml_1592853716807/work
308
+ tomli==2.0.1
309
+ toolz @ file:///tmp/build/80754af9/toolz_1601054250827/work
310
+ torch==1.11.0
311
+ torchaudio==0.11.0
312
+ torchmetrics==0.6.0
313
+ torchvision==0.8.2
314
+ tornado==6.0.4
315
+ tqdm==4.64.0
316
+ traitlets @ file:///tmp/build/80754af9/traitlets_1602787416690/work
317
+ transformer-smaller-training-vocab==0.3.1
318
+ transformers==4.33.1
319
+ typing-extensions==4.4.0
320
+ ujson @ file:///tmp/build/80754af9/ujson_1602523317881/work
321
+ unicodecsv==0.14.1
322
+ urllib3 @ file:///tmp/build/80754af9/urllib3_1603305693037/work
323
+ watchdog @ file:///tmp/build/80754af9/watchdog_1593447344699/work
324
+ wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work
325
+ webencodings==0.5.1
326
+ Werkzeug==1.0.1
327
+ widgetsnbextension==3.5.1
328
+ Wikipedia-API==0.6.0
329
+ wrapt==1.11.2
330
+ wurlitzer @ file:///tmp/build/80754af9/wurlitzer_1594753850195/work
331
+ xlrd==1.2.0
332
+ XlsxWriter @ file:///tmp/build/80754af9/xlsxwriter_1602692860603/work
333
+ xlwt==1.3.0
334
+ xmltodict==0.12.0
335
+ xxhash==3.0.0
336
+ yapf @ file:///tmp/build/80754af9/yapf_1593528177422/work
337
+ yarl==1.7.2
338
+ zict==2.0.0
339
+ zipp @ file:///tmp/build/80754af9/zipp_1604001098328/work
340
+ zope.event==4.5.0
341
+ zope.interface @ file:///tmp/build/80754af9/zope.interface_1602002420968/work
342
+ ==============================
343
+ Git revision:
344
+ 8a51838
345
+ ==============================
346
+ CUDA version:
347
+ 11.7
TunisianASR/results/14epoch_tunisian/1234/hyperparams.yaml ADDED
@@ -0,0 +1,194 @@
1
+ # Generated 2023-09-20 from:
2
+ # /home/salah/Code_Switched_Tunisian_Speech_Recognition/TunisianASR/semi_trained.yaml
3
+ # yamllint disable
4
+ # ################################
5
+ # Model: wav2vec2 + DNN + CTC
6
+ # Augmentation: SpecAugment
7
+ # Authors: Titouan Parcollet 2021
8
+ # ################################
9
+
10
+ # Seed needs to be set at top of yaml, before objects with parameters are made
11
+ seed: 1234
12
+ __set_seed: !!python/object/apply:torch.manual_seed [1234]
13
+ output_folder: TunisianASR/results/14epoch_tunisian/1234/
14
+ wer_file: TunisianASR/results/14epoch_tunisian/1234//wer.txt
15
+ save_folder: TunisianASR/results/14epoch_tunisian/1234//save
16
+ train_log: TunisianASR/results/14epoch_tunisian/1234//train_log.txt
17
+
18
+ # Folder where the wav2vec2 checkpoint (here WavLM large) is stored.
19
+ wav2vec2_folder: TunisianASR/results/14epoch_tunisian/1234//save/wav2vec2_checkpoint
20
+
21
+ # Data files
22
+ data_folder: /gpfsscratch/rech/nou/uzn19yk/tunisian_junk # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
23
+ train_tsv_file: /gpfsscratch/rech/nou/uzn19yk/tunisian_junk/train.tsv # Standard CommonVoice .tsv files
24
+ dev_tsv_file: /gpfsscratch/rech/nou/uzn19yk/tunisian_junk/dev.tsv # Standard CommonVoice .tsv files
25
+ test_tsv_file: /gpfsscratch/rech/nou/uzn19yk/tunisian_junk/test.tsv # Standard CommonVoice .tsv files
26
+ accented_letters: true
27
+ language: fr # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for English
28
+ train_csv: /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/good_final/train.csv
29
+ valid_csv: /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/good_final/dev.csv
30
+ test_csv:
31
+ - /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/full_annotation_test.csv
32
+ - /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/good_final/iwslt_test.csv
33
+ - /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/good_final/taric_test.csv
34
+
35
+ skip_prep: true # Skip data preparation
36
+
37
+ use_language_modelling: true
38
+ ngram_lm_path: arpas/outdomain.arpa
39
+
40
+ # We remove utterances longer than 10s in the train/dev/test sets as
41
+ # longer sentences certainly correspond to "open microphones".
42
+ avoid_if_longer_than: 10.0
43
+ avoid_if_shorter_than: 1.2
44
+
45
+
46
+ # Training parameters
47
+ number_of_epochs: 14
48
+ lr: 1.0
49
+ lr_wav2vec: 0.0001
50
+ sorting: ascending
51
+ auto_mix_prec: false
52
+ sample_rate: 16000
53
+ ckpt_interval_minutes: 30 # save checkpoint every N min
54
+
55
+ # With data_parallel batch_size is split into N jobs
56
+ # With DDP batch_size is multiplied by N jobs
57
+ # Must be 6 per GPU to fit 16GB of VRAM
58
+ batch_size: 10
59
+ test_batch_size: 4
60
+
61
+ dataloader_options:
62
+ batch_size: 10
63
+ num_workers: 6
64
+ test_dataloader_options:
65
+ batch_size: 4
66
+ num_workers: 6
67
+
68
+ # BPE parameters
69
+ token_type: char # ["unigram", "bpe", "char"]
70
+ character_coverage: 1.0
71
+
72
+ # Model parameters
73
+ # activation: !name:torch.nn.LeakyReLU
74
+ wav2vec_output_dim: 1024
75
+ dnn_neurons: 1024
76
+ freeze_wav2vec: false
77
+ freeze_feature_extractor: true
78
+ dropout: 0.15
79
+ warmup_steps: 500 # The wav2vec 2.0 model isn't updated for this number of steps
80
+
81
+ # Outputs
82
+ output_neurons: 40 # BPE size, index(blank/eos/bos) = 0
83
+
84
+ # Decoding parameters
85
+ # Be sure that the bos and eos index match with the BPEs ones
86
+ blank_index: 0
87
+ unk_index: 1
88
+
89
+ #
90
+ # Functions and classes
91
+ #
92
+ epoch_counter: &id007 !new:speechbrain.utils.epoch_loop.EpochCounter
93
+
94
+ limit: 14
95
+
96
+ augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
97
+ sample_rate: 16000
98
+ speeds: [95, 100, 105]
99
+
100
+ enc: &id002 !new:speechbrain.nnet.containers.Sequential
101
+ input_shape: [null, null, 1024]
102
+ linear1: !name:speechbrain.nnet.linear.Linear
103
+ n_neurons: 1024
104
+ bias: true
105
+ bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
106
+ activation: !new:torch.nn.LeakyReLU
107
+ drop: !new:torch.nn.Dropout
108
+ p: 0.15
109
+ linear2: !name:speechbrain.nnet.linear.Linear
110
+ n_neurons: 1024
111
+ bias: true
112
+ bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
113
+ activation2: !new:torch.nn.LeakyReLU
114
+ drop2: !new:torch.nn.Dropout
115
+ p: 0.15
116
+ linear3: !name:speechbrain.nnet.linear.Linear
117
+ n_neurons: 1024
118
+ bias: true
119
+ bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
120
+ activation3: !new:torch.nn.LeakyReLU
121
+
122
+ wav2vec2: &id001 !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
123
+ source: wavlm-large/
124
+ output_norm: false
125
+ freeze: false
126
+ freeze_feature_extractor: true
127
+ save_path: TunisianASR/results/14epoch_tunisian/1234//save/wav2vec2_checkpoint
128
+
129
+ #####
130
+ # Uncomment this block if you prefer to use a Fairseq pretrained model instead
131
+ # of a HuggingFace one. Here, we provide an URL that is obtained from the
132
+ # Fairseq github for the multilingual XLSR.
133
+ #
134
+ #wav2vec2_url: https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt
135
+ #wav2vec2: !new:speechbrain.lobes.models.fairseq_wav2vec.FairseqWav2Vec2
136
+ # pretrained_path: !ref <wav2vec2_url>
137
+ # output_norm: True
138
+ # freeze: False
139
+ # save_path: !ref <save_folder>/wav2vec2_checkpoint/model.pt
140
+ #####
141
+
142
+
143
+ ctc_lin: &id003 !new:speechbrain.nnet.linear.Linear
144
+
145
+ input_size: 1024
146
+ n_neurons: 40
147
+
148
+ log_softmax: !new:speechbrain.nnet.activations.Softmax
149
+ apply_log: true
150
+
151
+ ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
152
+ blank_index: 0
153
+
154
+ modules:
155
+ wav2vec2: *id001
156
+ enc: *id002
157
+ ctc_lin: *id003
158
+ model: &id004 !new:torch.nn.ModuleList
159
+ - [*id002, *id003]
160
+ model_opt_class: !name:torch.optim.Adadelta
161
+ lr: 1.0
162
+ rho: 0.95
163
+ eps: 1.e-8
164
+
165
+ wav2vec_opt_class: !name:torch.optim.Adam
166
+ lr: 0.0001
167
+
168
+ lr_annealing_model: &id005 !new:speechbrain.nnet.schedulers.NewBobScheduler
169
+ initial_value: 1.0
170
+ improvement_threshold: 0.0025
171
+ annealing_factor: 0.8
172
+ patient: 0
173
+
174
+ lr_annealing_wav2vec: &id006 !new:speechbrain.nnet.schedulers.NewBobScheduler
175
+ initial_value: 0.0001
176
+ improvement_threshold: 0.0025
177
+ annealing_factor: 0.9
178
+ patient: 0
179
+
180
+ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
181
+ checkpoints_dir: TunisianASR/results/14epoch_tunisian/1234//save
182
+ recoverables:
183
+ wav2vec2: *id001
184
+ model: *id004
185
+ scheduler_model: *id005
186
+ scheduler_wav2vec: *id006
187
+ counter: *id007
188
+ train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
189
+ save_file: TunisianASR/results/14epoch_tunisian/1234//train_log.txt
190
+
191
+ error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
192
+
193
+ cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
194
+ split_tokens: true
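
Note: the hyperparams.yaml above is consumed through hyperpyyaml by the training scripts added later in this commit. A minimal sketch of that loading step, assuming the yaml path is passed on the command line exactly as the scripts do (the final print is only illustrative):

# Minimal sketch: how a SpeechBrain hyperparams file like the one above is loaded.
# Mirrors the entry points of train_with_wav2vec.py / copy_of_wavlm_tun.py in this commit.
import sys
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
with open(hparams_file) as fin:
    hparams = load_hyperpyyaml(fin, overrides)   # resolves !new:/!name: tags into live objects

print(hparams["output_folder"])                  # e.g. TunisianASR/results/14epoch_tunisian/1234/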
TunisianASR/results/14epoch_tunisian/1234/log.txt ADDED
@@ -0,0 +1,359 @@
1
+ 2023-09-20 16:23:38,106 - speechbrain.core - INFO - Beginning experiment!
2
+ 2023-09-20 16:23:38,106 - speechbrain.core - INFO - Experiment folder: TunisianASR/results/14epoch_tunisian/1234/
3
+ 2023-09-20 16:23:39,287 - speechbrain.utils.superpowers - DEBUG - absl-py==1.2.0
4
+ aiohttp==3.8.1
5
+ aiosignal==1.2.0
6
+ alabaster==0.7.12
7
+ anaconda-client==1.7.2
8
+ anaconda-navigator==1.10.0
9
+ anaconda-project==0.8.3
10
+ antlr4-python3-runtime==4.9.3
11
+ appdirs==1.4.4
12
+ argh==0.26.2
13
+ argon2-cffi @ file:///tmp/build/80754af9/argon2-cffi_1596828493937/work
14
+ asn1crypto @ file:///tmp/build/80754af9/asn1crypto_1596577642040/work
15
+ astroid @ file:///tmp/build/80754af9/astroid_1592495912941/work
16
+ astropy==4.0.2
17
+ async-generator==1.10
18
+ async-timeout==4.0.2
19
+ atomicwrites==1.4.0
20
+ attrs @ file:///tmp/build/80754af9/attrs_1604765588209/work
21
+ audioread==2.1.9
22
+ autopep8 @ file:///tmp/build/80754af9/autopep8_1596578164842/work
23
+ Babel @ file:///tmp/build/80754af9/babel_1605108370292/work
24
+ backcall==0.2.0
25
+ backports.functools-lru-cache==1.6.1
26
+ backports.shutil-get-terminal-size==1.0.0
27
+ backports.tempfile==1.0
28
+ backports.weakref==1.0.post1
29
+ beautifulsoup4 @ file:///tmp/build/80754af9/beautifulsoup4_1601924105527/work
30
+ bitarray @ file:///tmp/build/80754af9/bitarray_1605065113847/work
31
+ bkcharts==0.2
32
+ black==22.12.0
33
+ bleach @ file:///tmp/build/80754af9/bleach_1600439572647/work
34
+ bokeh @ file:///tmp/build/80754af9/bokeh_1603297833684/work
35
+ boto==2.49.0
36
+ boto3==1.28.43
37
+ botocore==1.31.43
38
+ Bottleneck==1.3.2
39
+ bpemb==0.3.4
40
+ brotlipy==0.7.0
41
+ cachetools==5.2.0
42
+ certifi==2020.6.20
43
+ cffi @ file:///tmp/build/80754af9/cffi_1600699146221/work
44
+ chardet==3.0.4
45
+ charset-normalizer==2.0.12
46
+ click==8.1.3
47
+ cloudpickle @ file:///tmp/build/80754af9/cloudpickle_1598884132938/work
48
+ clyent==1.2.2
49
+ colorama @ file:///tmp/build/80754af9/colorama_1603211150991/work
50
+ coloredlogs==15.0.1
51
+ conda==4.9.2
52
+ conda-build==3.20.5
53
+ conda-package-handling @ file:///tmp/build/80754af9/conda-package-handling_1603018141399/work
54
+ conda-verify==3.4.2
55
+ conllu==4.5.3
56
+ contextlib2==0.6.0.post1
57
+ cryptography @ file:///tmp/build/80754af9/cryptography_1601046815590/work
58
+ cycler==0.10.0
59
+ Cython @ file:///tmp/build/80754af9/cython_1594831566883/work
60
+ cytoolz==0.11.0
61
+ dask @ file:///tmp/build/80754af9/dask-core_1602083700509/work
62
+ datasets==1.18.3
63
+ decorator==4.4.2
64
+ defusedxml==0.6.0
65
+ Deprecated==1.2.14
66
+ diff-match-patch @ file:///tmp/build/80754af9/diff-match-patch_1594828741838/work
67
+ dill==0.3.4
68
+ distributed @ file:///tmp/build/80754af9/distributed_1605066520644/work
69
+ docutils==0.16
70
+ easyocr==1.2.1
71
+ einops==0.3.0
72
+ entrypoints==0.3
73
+ et-xmlfile==1.0.1
74
+ farasapy==0.0.14
75
+ fastcache==1.1.0
76
+ ffmpeg-python==0.2.0
77
+ filelock==3.0.12
78
+ flair==0.12.2
79
+ flake8 @ file:///tmp/build/80754af9/flake8_1601911421857/work
80
+ Flask==1.1.2
81
+ flatbuffers==22.9.24
82
+ frozenlist==1.3.0
83
+ fsspec==2022.3.0
84
+ ftfy==6.1.1
85
+ future==0.18.2
86
+ gdown==4.4.0
87
+ gensim==4.1.2
88
+ gevent @ file:///tmp/build/80754af9/gevent_1601397537062/work
89
+ glob2==0.7
90
+ gmpy2==2.0.8
91
+ google-auth==2.12.0
92
+ google-auth-oauthlib==0.4.6
93
+ greenlet @ file:///tmp/build/80754af9/greenlet_1600874013538/work
94
+ grpcio==1.49.1
95
+ h5py==2.10.0
96
+ HeapDict==1.0.1
97
+ html5lib @ file:///tmp/build/80754af9/html5lib_1593446221756/work
98
+ huggingface-hub==0.16.4
99
+ humanfriendly==10.0
100
+ hyperopt==0.2.7
101
+ idna @ file:///tmp/build/80754af9/idna_1593446292537/work
102
+ imageio @ file:///tmp/build/80754af9/imageio_1594161405741/work
103
+ imagesize==1.2.0
104
+ imhist==0.0.4
105
+ importlib-metadata==5.0.0
106
+ imWatermark==0.0.2
107
+ iniconfig @ file:///tmp/build/80754af9/iniconfig_1602780191262/work
108
+ install==1.3.5
109
+ intervaltree @ file:///tmp/build/80754af9/intervaltree_1598376443606/work
110
+ invisible-watermark==0.1.5
111
+ ipykernel @ file:///tmp/build/80754af9/ipykernel_1596207638929/work/dist/ipykernel-5.3.4-py3-none-any.whl
112
+ ipython @ file:///tmp/build/80754af9/ipython_1604101197014/work
113
+ ipython-genutils==0.2.0
114
+ ipywidgets @ file:///tmp/build/80754af9/ipywidgets_1601490159889/work
115
+ isort @ file:///tmp/build/80754af9/isort_1602603989581/work
116
+ itsdangerous==1.1.0
117
+ Janome==0.5.0
118
+ jdcal==1.4.1
119
+ jedi @ file:///tmp/build/80754af9/jedi_1592841866100/work
120
+ jeepney @ file:///tmp/build/80754af9/jeepney_1605069705079/work
121
+ Jinja2==2.11.2
122
+ jiwer==2.3.0
123
+ jmespath==1.0.1
124
+ joblib @ file:///tmp/build/80754af9/joblib_1601912903842/work
125
+ json5==0.9.5
126
+ jsonschema @ file:///tmp/build/80754af9/jsonschema_1602607155483/work
127
+ jupyter==1.0.0
128
+ jupyter-client @ file:///tmp/build/80754af9/jupyter_client_1601311786391/work
129
+ jupyter-console @ file:///tmp/build/80754af9/jupyter_console_1598884538475/work
130
+ jupyter-core==4.6.3
131
+ jupyterlab==2.2.6
132
+ jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work
133
+ jupyterlab-server @ file:///tmp/build/80754af9/jupyterlab_server_1594164409481/work
134
+ keyring @ file:///tmp/build/80754af9/keyring_1601490835422/work
135
+ kiwisolver @ file:///tmp/build/80754af9/kiwisolver_1604014535162/work
136
+ langdetect==1.0.9
137
+ lazy-object-proxy==1.4.3
138
+ libarchive-c==2.9
139
+ librosa==0.9.1
140
+ llvmlite==0.34.0
141
+ locket==0.2.0
142
+ lxml @ file:///tmp/build/80754af9/lxml_1603216285000/work
143
+ Markdown==3.4.1
144
+ MarkupSafe==1.1.1
145
+ matplotlib @ file:///tmp/build/80754af9/matplotlib-base_1603378225747/work
146
+ mccabe==0.6.1
147
+ mido==1.2.10
148
+ mistune==0.8.4
149
+ mkl-fft==1.2.0
150
+ mkl-random==1.1.1
151
+ mkl-service==2.3.0
152
+ mock==4.0.2
153
+ more-itertools @ file:///tmp/build/80754af9/more-itertools_1605111547926/work
154
+ mpld3==0.3
155
+ mpmath==1.1.0
156
+ msgpack==1.0.0
157
+ multidict==6.0.2
158
+ multipledispatch==0.6.0
159
+ multiprocess==0.70.12.2
160
+ mypy-extensions==0.4.3
161
+ navigator-updater==0.2.1
162
+ nbclient @ file:///tmp/build/80754af9/nbclient_1602783176460/work
163
+ nbconvert @ file:///tmp/build/80754af9/nbconvert_1601914830498/work
164
+ nbformat @ file:///tmp/build/80754af9/nbformat_1602783287752/work
165
+ nest-asyncio @ file:///tmp/build/80754af9/nest-asyncio_1605115881283/work
166
+ networkx @ file:///tmp/build/80754af9/networkx_1598376031484/work
167
+ nltk @ file:///tmp/build/80754af9/nltk_1592496090529/work
168
+ nose==1.3.7
169
+ notebook @ file:///tmp/build/80754af9/notebook_1601501575118/work
170
+ numba @ file:///tmp/build/80754af9/numba_1600100669015/work
171
+ numexpr==2.7.1
172
+ numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1603570489231/work
173
+ numpydoc @ file:///tmp/build/80754af9/numpydoc_1605117425582/work
174
+ oauthlib==3.2.1
175
+ olefile==0.46
176
+ omegaconf==2.2.3
177
+ onnx==1.12.0
178
+ onnxruntime==1.12.1
179
+ opencv-python==4.4.0.46
180
+ openpyxl @ file:///tmp/build/80754af9/openpyxl_1598113097404/work
181
+ packaging==20.9
182
+ pandas @ file:///tmp/build/80754af9/pandas_1602088120436/work
183
+ pandocfilters @ file:///tmp/build/80754af9/pandocfilters_1605120460739/work
184
+ parso==0.7.0
185
+ partd==1.1.0
186
+ path @ file:///tmp/build/80754af9/path_1598376507494/work
187
+ pathlib2==2.3.5
188
+ pathspec==0.10.3
189
+ pathtools==0.1.2
190
+ patsy==0.5.1
191
+ pep8==1.7.1
192
+ pexpect==4.8.0
193
+ pickleshare==0.7.5
194
+ Pillow @ file:///tmp/build/80754af9/pillow_1603822255246/work
195
+ pkginfo==1.6.1
196
+ platformdirs==2.6.0
197
+ pluggy==0.13.1
198
+ ply==3.11
199
+ pooch==1.6.0
200
+ pptree==3.1
201
+ pretty-midi==0.2.9
202
+ prometheus-client==0.8.0
203
+ prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1602688806899/work
204
+ protobuf==3.19.6
205
+ psutil @ file:///tmp/build/80754af9/psutil_1598370257551/work
206
+ ptyprocess==0.6.0
207
+ py @ file:///tmp/build/80754af9/py_1593446248552/work
208
+ py-espeak-ng==0.1.8
209
+ py4j==0.10.9.7
210
+ PyArabic==0.6.15
211
+ pyarrow==7.0.0
212
+ pyasn1==0.4.8
213
+ pyasn1-modules==0.2.8
214
+ pycodestyle==2.6.0
215
+ pycosat==0.6.3
216
+ pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work
217
+ pycurl==7.43.0.6
218
+ pyDeprecate==0.3.1
219
+ pydocstyle @ file:///tmp/build/80754af9/pydocstyle_1598885001695/work
220
+ pyflakes==2.2.0
221
+ Pygments @ file:///tmp/build/80754af9/pygments_1604103097372/work
222
+ pylint @ file:///tmp/build/80754af9/pylint_1598623985952/work
223
+ pyodbc===4.0.0-unsupported
224
+ pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1594392929924/work
225
+ pyparsing==2.4.7
226
+ pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1600141720057/work
227
+ PySocks==1.7.1
228
+ pytest==0.0.0
229
+ python-bidi==0.4.2
230
+ python-crfsuite==0.9.7
231
+ python-dateutil==2.8.1
232
+ python-jsonrpc-server @ file:///tmp/build/80754af9/python-jsonrpc-server_1600278539111/work
233
+ python-language-server @ file:///tmp/build/80754af9/python-language-server_1600454544709/work
234
+ python-Levenshtein==0.12.2
235
+ pytorch-lightning==1.4.2
236
+ pytorch-revgrad==0.2.0
237
+ pytz==2020.1
238
+ PyWavelets @ file:///tmp/build/80754af9/pywavelets_1601658317819/work
239
+ pyxdg @ file:///tmp/build/80754af9/pyxdg_1603822279816/work
240
+ PyYAML==5.3.1
241
+ pyzmq==19.0.2
242
+ QDarkStyle==2.8.1
243
+ QtAwesome @ file:///tmp/build/80754af9/qtawesome_1602272867890/work
244
+ qtconsole @ file:///tmp/build/80754af9/qtconsole_1600870028330/work
245
+ QtPy==1.9.0
246
+ regex @ file:///tmp/build/80754af9/regex_1602786672676/work
247
+ requests @ file:///tmp/build/80754af9/requests_1592841827918/work
248
+ requests-oauthlib==1.3.1
249
+ resampy==0.2.2
250
+ rope @ file:///tmp/build/80754af9/rope_1602264064449/work
251
+ rsa==4.9
252
+ Rtree==0.9.4
253
+ ruamel-yaml==0.15.87
254
+ s3transfer==0.6.2
255
+ sacremoses==0.0.49
256
+ safetensors==0.3.3
257
+ scikit-image==0.17.2
258
+ scikit-learn @ file:///tmp/build/80754af9/scikit-learn_1598376899566/work
259
+ scipy @ file:///tmp/build/80754af9/scipy_1597686649129/work
260
+ seaborn @ file:///tmp/build/80754af9/seaborn_1600553570093/work
261
+ SecretStorage==3.1.2
262
+ segtok==1.5.11
263
+ Send2Trash==1.5.0
264
+ sentencepiece==0.1.97
265
+ simplegeneric==0.8.1
266
+ singledispatch @ file:///tmp/build/80754af9/singledispatch_1602523705405/work
267
+ sip==4.19.13
268
+ six @ file:///tmp/build/80754af9/six_1605205327372/work
269
+ smart-open==5.2.1
270
+ snowballstemmer==2.0.0
271
+ sortedcollections==1.2.1
272
+ sortedcontainers==2.2.2
273
+ SoundFile==0.10.3.post1
274
+ soupsieve==2.0.1
275
+ sphfile==1.0.3
276
+ Sphinx @ file:///tmp/build/80754af9/sphinx_1597428793432/work
277
+ sphinxcontrib-applehelp==1.0.2
278
+ sphinxcontrib-devhelp==1.0.2
279
+ sphinxcontrib-htmlhelp==1.0.3
280
+ sphinxcontrib-jsmath==1.0.1
281
+ sphinxcontrib-qthelp==1.0.3
282
+ sphinxcontrib-serializinghtml==1.1.4
283
+ sphinxcontrib-websupport @ file:///tmp/build/80754af9/sphinxcontrib-websupport_1597081412696/work
284
+ spyder @ file:///tmp/build/80754af9/spyder_1599056981321/work
285
+ spyder-kernels @ file:///tmp/build/80754af9/spyder-kernels_1599056754858/work
286
+ SQLAlchemy @ file:///tmp/build/80754af9/sqlalchemy_1603397987316/work
287
+ sqlitedict==2.1.0
288
+ statsmodels @ file:///tmp/build/80754af9/statsmodels_1602280205159/work
289
+ sympy @ file:///tmp/build/80754af9/sympy_1605119542615/work
290
+ tables==3.6.1
291
+ tabulate==0.9.0
292
+ tblib @ file:///tmp/build/80754af9/tblib_1597928476713/work
293
+ tensorboard==2.10.1
294
+ tensorboard-data-server==0.6.1
295
+ tensorboard-plugin-wit==1.8.1
296
+ terminado==0.9.1
297
+ testpath==0.4.4
298
+ threadpoolctl @ file:///tmp/tmp9twdgx9k/threadpoolctl-2.1.0-py3-none-any.whl
299
+ tifffile==2020.10.1
300
+ tkseem==0.0.3
301
+ tokenizers==0.13.3
302
+ toml @ file:///tmp/build/80754af9/toml_1592853716807/work
303
+ tomli==2.0.1
304
+ toolz @ file:///tmp/build/80754af9/toolz_1601054250827/work
305
+ torch==1.11.0
306
+ torchaudio==0.11.0
307
+ torchmetrics==0.6.0
308
+ torchvision==0.8.2
309
+ tornado==6.0.4
310
+ tqdm==4.64.0
311
+ traitlets @ file:///tmp/build/80754af9/traitlets_1602787416690/work
312
+ transformer-smaller-training-vocab==0.3.1
313
+ transformers==4.33.1
314
+ typing-extensions==4.4.0
315
+ ujson @ file:///tmp/build/80754af9/ujson_1602523317881/work
316
+ unicodecsv==0.14.1
317
+ urllib3 @ file:///tmp/build/80754af9/urllib3_1603305693037/work
318
+ watchdog @ file:///tmp/build/80754af9/watchdog_1593447344699/work
319
+ wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work
320
+ webencodings==0.5.1
321
+ Werkzeug==1.0.1
322
+ widgetsnbextension==3.5.1
323
+ Wikipedia-API==0.6.0
324
+ wrapt==1.11.2
325
+ wurlitzer @ file:///tmp/build/80754af9/wurlitzer_1594753850195/work
326
+ xlrd==1.2.0
327
+ XlsxWriter @ file:///tmp/build/80754af9/xlsxwriter_1602692860603/work
328
+ xlwt==1.3.0
329
+ xmltodict==0.12.0
330
+ xxhash==3.0.0
331
+ yapf @ file:///tmp/build/80754af9/yapf_1593528177422/work
332
+ yarl==1.7.2
333
+ zict==2.0.0
334
+ zipp @ file:///tmp/build/80754af9/zipp_1604001098328/work
335
+ zope.event==4.5.0
336
+ zope.interface @ file:///tmp/build/80754af9/zope.interface_1602002420968/work
337
+
338
+
339
+ 2023-09-20 16:23:39,866 - speechbrain.utils.superpowers - DEBUG - 8a51838
340
+
341
+
342
+ 2023-09-20 16:23:39,869 - speechbrain.pretrained.fetching - INFO - Fetch hyperparams.yaml: Using existing file/symlink in pretrained_models/asr-wav2vec2-commonvoice-fr/hyperparams.yaml.
343
+ 2023-09-20 16:23:39,871 - speechbrain.pretrained.fetching - INFO - Fetch custom.py: Linking to local file in /home/salah/Code_Switched_Tunisian_Speech_Recognition/asr-wav2vec2-commonvoice-fr/custom.py.
344
+ 2023-09-20 16:23:47,958 - speechbrain.lobes.models.huggingface_wav2vec - WARNING - speechbrain.lobes.models.huggingface_wav2vec - wav2vec 2.0 is frozen.
345
+ 2023-09-20 16:23:48,018 - speechbrain.utils.parameter_transfer - DEBUG - Collecting files (or symlinks) for pretraining in pretrained_models/asr-wav2vec2-commonvoice-fr.
346
+ 2023-09-20 16:23:48,023 - speechbrain.pretrained.fetching - INFO - Fetch wav2vec2.ckpt: Using existing file/symlink in pretrained_models/asr-wav2vec2-commonvoice-fr/wav2vec2.ckpt.
347
+ 2023-09-20 16:23:48,025 - speechbrain.pretrained.fetching - INFO - Fetch asr.ckpt: Using existing file/symlink in pretrained_models/asr-wav2vec2-commonvoice-fr/asr.ckpt.
348
+ 2023-09-20 16:23:48,028 - speechbrain.pretrained.fetching - INFO - Fetch tokenizer.ckpt: Using existing file/symlink in pretrained_models/asr-wav2vec2-commonvoice-fr/tokenizer.ckpt.
349
+ 2023-09-20 16:23:48,029 - speechbrain.utils.parameter_transfer - INFO - Loading pretrained files for: wav2vec2, asr, tokenizer
350
+ 2023-09-20 16:23:56,361 - speechbrain.lobes.models.huggingface_wav2vec - WARNING - speechbrain.lobes.models.huggingface_wav2vec - wav2vec 2.0 feature extractor is frozen.
351
+ 2023-09-20 16:23:56,366 - speechbrain.core - INFO - Info: auto_mix_prec arg from hparam file is used
352
+ 2023-09-20 16:23:56,366 - speechbrain.core - INFO - Info: ckpt_interval_minutes arg from hparam file is used
353
+ 2023-09-20 16:23:56,529 - speechbrain.core - INFO - 314.4M trainable parameters in ASRCV
354
+ 2023-09-20 16:23:57,316 - speechbrain.utils.checkpoints - INFO - Loading a checkpoint from EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00
355
+ 2023-09-20 16:23:59,928 - speechbrain.core - INFO - Info: auto_mix_prec arg from hparam file is used
356
+ 2023-09-20 16:23:59,940 - speechbrain.core - INFO - Info: ckpt_interval_minutes arg from hparam file is used
357
+ 2023-09-20 16:24:00,139 - speechbrain.core - INFO - 314.4M trainable parameters in ASR
358
+ 2023-09-20 16:24:00,967 - speechbrain.utils.checkpoints - INFO - Loading a checkpoint from TunisianASR/results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00
359
+ 2023-09-20 16:24:49,007 - speechbrain.utils.distributed - INFO - distributed_launch flag is disabled, this experiment will be executed without DDP.
TunisianASR/results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00/CKPT.yaml ADDED
@@ -0,0 +1,4 @@
1
+ # yamllint disable
2
+ WER: 26.88369650826989
3
+ end-of-epoch: true
4
+ unixtime: 1691019518.289327
TunisianASR/results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00/brain.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c991c52635ebf5f1d342ff11f149ab3000260e4b08bf1b4356e5134002a60feb
3
+ size 51
TunisianASR/results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00/counter.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8527a891e224136950ff32ca212b45bc93f69fbb801c3b1ebedac52775f99e61
3
+ size 2
TunisianASR/results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00/dataloader-TRAIN.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:646ecafa8b16fbb513bf9ddc56ba5e34c8818c0c8a7858871698ef9d15ddea68
3
+ size 5
TunisianASR/results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00/model.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7cb0cabac5780ffeb4b9850d30e8cd10f748896ddb04aa963d29512463f9b65c
3
+ size 12814446
TunisianASR/results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00/modelopt.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20f6aecbdbc179aeac4a305431e9f4d17a3436a4aa8426d20f066d6c99c7b449
3
+ size 25575599
TunisianASR/results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00/scheduler_model.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a25325a19660b044c1edc3405e8a298702a93f2f569b0548f8066bb50e8e3c8
3
+ size 639
TunisianASR/results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00/scheduler_wav2vec.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b238edab7c400d8ad289eb44c8912bc0d1d2144f2ac59e48b9ab736dc4ef5f79
3
+ size 643
TunisianASR/results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00/wav2vec2.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f9f5cd05dd7941f51ce7f19acd406b3eb562de4bc4c6ed818f709563a8308e8d
3
+ size 1262005979
TunisianASR/results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00/wav2vec_opt.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d191b5b044b5dd0414bafb8c762083c4d344e3c81de807ef6a092dba4a383dd
3
+ size 2490361859
TunisianASR/results/14epoch_tunisian/1234/save/label_encoder.txt ADDED
@@ -0,0 +1,44 @@
1
+ 'ب' => 38
2
+ 'ا' => 1
3
+ 'ه' => 2
4
+ 'ي' => 3
5
+ 'و' => 4
6
+ 'ن' => 5
7
+ 'أ' => 6
8
+ ' ' => 7
9
+ 'م' => 8
10
+ 'ش' => 9
11
+ 'ل' => 10
12
+ 'س' => 11
13
+ 'ت' => 12
14
+ 'د' => 13
15
+ 'ر' => 14
16
+ 'ى' => 15
17
+ 'ح' => 16
18
+ 'ط' => 17
19
+ 'ع' => 18
20
+ 'ك' => 19
21
+ 'ف' => 20
22
+ 'ق' => 21
23
+ 'آ' => 22
24
+ 'ة' => 23
25
+ 'ج' => 24
26
+ 'ض' => 25
27
+ 'ز' => 26
28
+ 'ص' => 27
29
+ 'إ' => 28
30
+ 'ث' => 29
31
+ 'خ' => 30
32
+ 'ڨ' => 31
33
+ 'ذ' => 32
34
+ 'ظ' => 33
35
+ 'ء' => 34
36
+ 'غ' => 35
37
+ 'ئ' => 36
38
+ 'ؤ' => 37
39
+ '<blank>' => 0
40
+ 1 => 39
41
+ ================
42
+ 'starting_index' => 0
43
+ 'unk_label' => 1
44
+ 'blank_label' => '<blank>'
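
The label_encoder.txt above maps each output character to its CTC index (index 0 is the blank, index 39 the unknown label). A small sketch of turning it back into an index-ordered label list, mirroring the read_labels_file helper that appears, commented out, in copy_of_wavlm_tun.py further down:

# Sketch: rebuild the ordered label list from the label_encoder.txt shown above.
# Same logic as the commented-out read_labels_file() in copy_of_wavlm_tun.py.
def read_labels_file(labels_file):
    with open(labels_file, "r", encoding="utf-8") as lf:
        lines = lf.read().splitlines()
    numbers = {}
    for line in lines:
        if "===" in line:                     # stop at the metadata divider
            break
        string, number = line.split("=>")
        numbers[int(number)] = string[1:-2]   # drop the surrounding quotes
    return [numbers[x] for x in range(len(numbers))]

labels = read_labels_file("TunisianASR/results/14epoch_tunisian/1234/save/label_encoder.txt")
# As in the scripts: index 0 becomes the empty-string blank and the unk label becomes "1".
labels = [""] + labels[1:-1] + ["1"]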
TunisianASR/results/14epoch_tunisian/1234/train_with_wav2vec.py ADDED
@@ -0,0 +1,399 @@
1
+ #!/usr/bin/env python3
2
+ import sys
3
+ import torch
4
+ import logging
5
+ import speechbrain as sb
6
+ from pathlib import Path
7
+ import os
8
+ import torchaudio
9
+ from hyperpyyaml import load_hyperpyyaml
10
+ from speechbrain.tokenizers.SentencePiece import SentencePiece
11
+ from speechbrain.utils.data_utils import undo_padding
12
+ from speechbrain.utils.distributed import run_on_main
13
+
14
+ """Recipe for training a sequence-to-sequence ASR system with CommonVoice.
15
+ The system employs a wav2vec2 encoder and a CTC decoder.
16
+ Decoding is performed with greedy decoding (will be extended to beam search).
17
+
18
+ To run this recipe, do the following:
19
+ > python train_with_wav2vec2.py hparams/train_with_wav2vec2.yaml
20
+
21
+ With the default hyperparameters, the system employs a pretrained wav2vec2 encoder.
22
+ The wav2vec2 encoder is initialized from the pretrained model given in the hparams file.
23
+ It may be dependent on the language.
24
+
25
+ The neural network is trained with CTC on sub-word units estimated with
26
+ Byte-Pair Encoding (BPE).
27
+
28
+ The experiment file is flexible enough to support a large variety of
29
+ different systems. By properly changing the parameter files, you can try
30
+ different encoders, decoders, tokens (e.g., characters instead of BPE),
31
+ training languages (all CommonVoice languages), and many
32
+ other possible variations.
33
+
34
+ Authors
35
+ * Titouan Parcollet 2021
36
+ """
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+
41
+ # Define training procedure
42
+ class ASR(sb.core.Brain):
43
+ def compute_forward(self, batch, stage):
44
+ """Forward computations from the waveform batches to the output probabilities."""
45
+
46
+ batch = batch.to(self.device)
47
+ wavs, wav_lens = batch.sig
48
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
49
+ if stage == sb.Stage.TRAIN:
50
+ if hasattr(self.hparams, "augmentation"):
51
+ wavs = self.hparams.augmentation(wavs, wav_lens)
52
+
53
+ # Forward pass
54
+ feats = self.modules.wav2vec2(wavs, wav_lens)
55
+ x = self.modules.enc(feats)
56
+ logits = self.modules.ctc_lin(x)
57
+ p_ctc = self.hparams.log_softmax(logits)
58
+
59
+ return p_ctc, wav_lens
60
+
61
+ def compute_objectives(self, predictions, batch, stage):
62
+ """Computes the loss (CTC) given predictions and targets."""
63
+
64
+ p_ctc, wav_lens = predictions
65
+
66
+ ids = batch.id
67
+ tokens, tokens_lens = batch.tokens
68
+
69
+ loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
70
+
71
+ if stage != sb.Stage.TRAIN:
72
+ predicted_tokens = sb.decoders.ctc_greedy_decode(
73
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
74
+ )
75
+ # Decode token terms to words
76
+ if self.hparams.use_language_modelling:
77
+ predicted_words = []
78
+ for logs in p_ctc:
79
+ text = decoder.decode(logs.detach().cpu().numpy())
80
+ predicted_words.append(text.split(" "))
81
+ else:
82
+ predicted_words = [
83
+ "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
84
+ for utt_seq in predicted_tokens
85
+ ]
86
+ # Convert indices to words
87
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
88
+
89
+ self.wer_metric.append(ids, predicted_words, target_words)
90
+ self.cer_metric.append(ids, predicted_words, target_words)
91
+
92
+ return loss
93
+
94
+ def fit_batch(self, batch):
95
+ """Train the parameters given a single batch in input"""
96
+ should_step = self.step % self.grad_accumulation_factor == 0
97
+ # Managing automatic mixed precision
98
+ # TOFIX: CTC fine-tuning currently is unstable
99
+ # This is certainly due to CTC being done in fp16 instead of fp32
100
+ if self.auto_mix_prec:
101
+ with torch.cuda.amp.autocast():
102
+ with self.no_sync():
103
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
104
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
105
+ with self.no_sync(not should_step):
106
+ self.scaler.scale(
107
+ loss / self.grad_accumulation_factor
108
+ ).backward()
109
+ if should_step:
110
+
111
+ if not self.hparams.wav2vec2.freeze:
112
+ self.scaler.unscale_(self.wav2vec_optimizer)
113
+ self.scaler.unscale_(self.model_optimizer)
114
+ if self.check_gradients(loss):
115
+ if not self.hparams.wav2vec2.freeze:
116
+ if self.optimizer_step >= self.hparams.warmup_steps:
117
+ self.scaler.step(self.wav2vec_optimizer)
118
+ self.scaler.step(self.model_optimizer)
119
+ self.scaler.update()
120
+ self.zero_grad()
121
+ self.optimizer_step += 1
122
+ else:
123
+ # This is mandatory because HF models have a weird behavior with DDP
124
+ # on the forward pass
125
+ with self.no_sync():
126
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
127
+
128
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
129
+
130
+ with self.no_sync(not should_step):
131
+ (loss / self.grad_accumulation_factor).backward()
132
+ if should_step:
133
+ if self.check_gradients(loss):
134
+ if not self.hparams.wav2vec2.freeze:
135
+ if self.optimizer_step >= self.hparams.warmup_steps:
136
+ self.wav2vec_optimizer.step()
137
+ self.model_optimizer.step()
138
+ self.zero_grad()
139
+ self.optimizer_step += 1
140
+
141
+ self.on_fit_batch_end(batch, outputs, loss, should_step)
142
+ return loss.detach().cpu()
143
+
144
+ def evaluate_batch(self, batch, stage):
145
+ """Computations needed for validation/test batches"""
146
+ predictions = self.compute_forward(batch, stage=stage)
147
+ with torch.no_grad():
148
+ loss = self.compute_objectives(predictions, batch, stage=stage)
149
+ return loss.detach()
150
+
151
+ def on_stage_start(self, stage, epoch):
152
+ """Gets called at the beginning of each epoch"""
153
+ if stage != sb.Stage.TRAIN:
154
+ self.cer_metric = self.hparams.cer_computer()
155
+ self.wer_metric = self.hparams.error_rate_computer()
156
+
157
+ def on_stage_end(self, stage, stage_loss, epoch):
158
+ """Gets called at the end of an epoch."""
159
+ # Compute/store important stats
160
+ stage_stats = {"loss": stage_loss}
161
+ if stage == sb.Stage.TRAIN:
162
+ self.train_stats = stage_stats
163
+ else:
164
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
165
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
166
+
167
+ # Perform end-of-iteration things, like annealing, logging, etc.
168
+ if stage == sb.Stage.VALID:
169
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
170
+ stage_stats["loss"]
171
+ )
172
+ old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
173
+ stage_stats["loss"]
174
+ )
175
+ sb.nnet.schedulers.update_learning_rate(
176
+ self.model_optimizer, new_lr_model
177
+ )
178
+ if not self.hparams.wav2vec2.freeze:
179
+ sb.nnet.schedulers.update_learning_rate(
180
+ self.wav2vec_optimizer, new_lr_wav2vec
181
+ )
182
+ self.hparams.train_logger.log_stats(
183
+ stats_meta={
184
+ "epoch": epoch,
185
+ "lr_model": old_lr_model,
186
+ "lr_wav2vec": old_lr_wav2vec,
187
+ },
188
+ train_stats=self.train_stats,
189
+ valid_stats=stage_stats,
190
+ )
191
+ self.checkpointer.save_and_keep_only(
192
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
193
+ )
194
+ elif stage == sb.Stage.TEST:
195
+ self.hparams.train_logger.log_stats(
196
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
197
+ test_stats=stage_stats,
198
+ )
199
+ with open(self.hparams.wer_file, "w") as w:
200
+ self.wer_metric.write_stats(w)
201
+
202
+ def init_optimizers(self):
203
+ "Initializes the wav2vec2 optimizer and model optimizer"
204
+
205
+ # If the wav2vec encoder is unfrozen, we create the optimizer
206
+ if not self.hparams.wav2vec2.freeze:
207
+ self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
208
+ self.modules.wav2vec2.parameters()
209
+ )
210
+ if self.checkpointer is not None:
211
+ self.checkpointer.add_recoverable(
212
+ "wav2vec_opt", self.wav2vec_optimizer
213
+ )
214
+
215
+ self.model_optimizer = self.hparams.model_opt_class(
216
+ self.hparams.model.parameters()
217
+ )
218
+
219
+ if self.checkpointer is not None:
220
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
221
+
222
+ def zero_grad(self, set_to_none=False):
223
+ if not self.hparams.wav2vec2.freeze:
224
+ self.wav2vec_optimizer.zero_grad(set_to_none)
225
+ self.model_optimizer.zero_grad(set_to_none)
226
+
227
+
228
+ # Define custom data procedure
229
+ def dataio_prepare(hparams):
230
+ """This function prepares the datasets to be used in the brain class.
231
+ It also defines the data processing pipeline through user-defined functions."""
232
+
233
+ # 1. Define datasets
234
+ data_folder = hparams["data_folder"]
235
+
236
+ train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
237
+ csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
238
+ )
239
+
240
+ if hparams["sorting"] == "ascending":
241
+ # we sort training data to speed up training and get better results.
242
+ train_data = train_data.filtered_sorted(
243
+ sort_key="duration",
244
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
245
+ )
246
+ # when sorting, do not shuffle in the dataloader; otherwise sorting is pointless
247
+ hparams["dataloader_options"]["shuffle"] = False
248
+
249
+ elif hparams["sorting"] == "descending":
250
+ train_data = train_data.filtered_sorted(
251
+ sort_key="duration",
252
+ reverse=True,
253
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
254
+ )
255
+ # when sorting, do not shuffle in the dataloader; otherwise sorting is pointless
256
+ hparams["dataloader_options"]["shuffle"] = False
257
+
258
+ elif hparams["sorting"] == "random":
259
+ pass
260
+
261
+ else:
262
+ raise NotImplementedError(
263
+ "sorting must be random, ascending or descending"
264
+ )
265
+
266
+ valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
267
+ csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
268
+ )
269
+ # We also sort the validation data so it is faster to validate
270
+ valid_data = valid_data.filtered_sorted(sort_key="duration")
271
+ test_datasets = {}
272
+ for csv_file in hparams["test_csv"]:
273
+ name = Path(csv_file).stem
274
+ test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
275
+ csv_path=csv_file, replacements={"data_root": data_folder}
276
+ )
277
+ test_datasets[name] = test_datasets[name].filtered_sorted(
278
+ sort_key="duration"
279
+ )
280
+
281
+ datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]
282
+
283
+
284
+ # 2. Define audio pipeline:
285
+ @sb.utils.data_pipeline.takes("wav")
286
+ @sb.utils.data_pipeline.provides("sig")
287
+ def audio_pipeline(wav):
288
+ info = torchaudio.info(wav)
289
+ sig = sb.dataio.dataio.read_audio(wav)
290
+ resampled = torchaudio.transforms.Resample(
291
+ info.sample_rate, hparams["sample_rate"],
292
+ )(sig)
293
+ return resampled
294
+
295
+ sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
296
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
297
+
298
+ # 3. Define text pipeline:
299
+ @sb.utils.data_pipeline.takes("wrd")
300
+ @sb.utils.data_pipeline.provides(
301
+ "wrd", "char_list", "tokens_list", "tokens"
302
+ )
303
+ def text_pipeline(wrd):
304
+ yield wrd
305
+ char_list = list(wrd)
306
+ yield char_list
307
+ tokens_list = label_encoder.encode_sequence(char_list)
308
+ yield tokens_list
309
+ tokens = torch.LongTensor(tokens_list)
310
+ yield tokens
311
+
312
+ sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
313
+ lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
314
+ special_labels = {
315
+ "blank_label": hparams["blank_index"],
316
+ "unk_label": hparams["unk_index"]
317
+ }
318
+ label_encoder.load_or_create(
319
+ path=lab_enc_file,
320
+ from_didatasets=[train_data],
321
+ output_key="char_list",
322
+ special_labels=special_labels,
323
+ sequence_input=True,
324
+ )
325
+
326
+ # 4. Set output:
327
+ sb.dataio.dataset.set_output_keys(
328
+ datasets, ["id", "sig", "wrd", "char_list", "tokens"],
329
+ )
330
+ return train_data, valid_data,test_datasets, label_encoder
331
+
332
+
333
+ if __name__ == "__main__":
334
+
335
+ # Load hyperparameters file with command-line overrides
336
+ hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
337
+ with open(hparams_file) as fin:
338
+ hparams = load_hyperpyyaml(fin, overrides)
339
+
340
+ # If --distributed_launch then
341
+ # create ddp_group with the right communication protocol
342
+ sb.utils.distributed.ddp_init_group(run_opts)
343
+
344
+
345
+ # Create experiment directory
346
+ sb.create_experiment_directory(
347
+ experiment_directory=hparams["output_folder"],
348
+ hyperparams_to_save=hparams_file,
349
+ overrides=overrides,
350
+ )
351
+
352
+ # Due to DDP, we do the preparation ONLY on the main python process
353
+ # Defining tokenizer and loading it
354
+ # Create the dataset objects as well as tokenization and encoding :-D
355
+ train_data, valid_data, test_datasets, label_encoder = dataio_prepare(hparams)
356
+ if hparams["use_language_modelling"]:
357
+ print("using language modelling")
358
+ from pyctcdecode import build_ctcdecoder
359
+ ind2lab = label_encoder.ind2lab
360
+ print(ind2lab)
361
+ labels = [ind2lab[x] for x in range(len(ind2lab))]
362
+ labels = [""] + labels[1:-1] + ["1"]
363
+ # Replace the <blank> token with a blank character, needed for PyCTCdecode
364
+ print(labels)
365
+ decoder = build_ctcdecoder(
366
+ labels,
367
+ kenlm_model_path=hparams["ngram_lm_path"], # .arpa or .bin
368
+ alpha=0.5, # Default by KenLM
369
+ beta=1.0, # Default by KenLM
370
+ )
371
+ # Trainer initialization
372
+ asr_brain = ASR(
373
+ modules=hparams["modules"],
374
+ hparams=hparams,
375
+ run_opts=run_opts,
376
+ checkpointer=hparams["checkpointer"],
377
+ )
378
+
379
+ # Adding objects to trainer.
380
+ asr_brain.tokenizer = label_encoder
381
+
382
+ # Training
383
+ asr_brain.fit(
384
+ asr_brain.hparams.epoch_counter,
385
+ train_data,
386
+ valid_data,
387
+ train_loader_kwargs=hparams["dataloader_options"],
388
+ valid_loader_kwargs=hparams["test_dataloader_options"],
389
+ )
390
+
391
+ # Test
392
+ for k in test_datasets.keys(): # keys are test_clean, test_other etc
393
+ asr_brain.hparams.wer_file = os.path.join(
394
+ hparams["output_folder"], "wer_{}.txt".format(k)
395
+ )
396
+ asr_brain.evaluate(
397
+ test_datasets[k], test_loader_kwargs=hparams["test_dataloader_options"]
398
+ )
399
+
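
When use_language_modelling is true, the script above decodes the CTC posteriors with a KenLM n-gram model through pyctcdecode instead of greedy decoding. A standalone sketch of that decoding step, reusing the labels list from the label_encoder sketch earlier and the arpas/outdomain.arpa path from the hyperparams; log_probs is a placeholder standing in for one utterance's CTC log-posteriors:

# Sketch of the n-gram-rescored decoding path used when use_language_modelling is true.
import numpy as np
from pyctcdecode import build_ctcdecoder

# labels: index-ordered symbols from label_encoder.txt with blank -> "" and unk -> "1"
# (see the read_labels_file sketch above).
decoder = build_ctcdecoder(
    labels,
    kenlm_model_path="arpas/outdomain.arpa",   # .arpa or .bin n-gram model
    alpha=0.5,                                 # LM weight (the script keeps these defaults)
    beta=1.0,                                  # word-insertion weight
)

# log_probs: CTC log-posteriors for one utterance, shape [time, len(labels)];
# random values here only so the sketch runs end to end.
log_probs = np.log(np.random.dirichlet(np.ones(len(labels)), size=200)).astype(np.float32)
text = decoder.decode(log_probs)               # best hypothesis as a plain string
predicted_words = text.split(" ")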
TunisianASR/results/14epoch_tunisian/<seed>/copy_of_wavlm_tun.py ADDED
@@ -0,0 +1,761 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import os
4
+ import sys
5
+ import torch
6
+ import logging
7
+ import speechbrain as sb
8
+ from speechbrain.utils.distributed import run_on_main
9
+ from hyperpyyaml import load_hyperpyyaml
10
+ from pathlib import Path
11
+ import torchaudio.transforms as T
12
+ import torchaudio
13
+ import numpy as np
14
+ import kenlm
15
+ from pyctcdecode import build_ctcdecoder
16
+ import re
17
+
18
+ # Commented out IPython magic to ensure Python compatibility.
19
+ # %cd /content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm
20
+
21
+ hparams_file, run_opts, overrides = sb.parse_arguments(["semi_supervised_test_tunisian.yaml"])
22
+
23
+ # If distributed_launch=True then
24
+ # create ddp_group with the right communication protocol
25
+ sb.utils.distributed.ddp_init_group(run_opts)
26
+
27
+ with open(hparams_file) as fin:
28
+ hparams = load_hyperpyyaml(fin, overrides)
29
+
30
+ # Create experiment directory
31
+ sb.create_experiment_directory(
32
+ experiment_directory=hparams["output_folder"],
33
+ hyperparams_to_save=hparams_file,
34
+ overrides=overrides,
35
+ )
36
+ """
37
+ def read_labels_file(labels_file):
38
+ with open(labels_file, "r",encoding="utf-8") as lf:
39
+ lines = lf.read().splitlines()
40
+ division = "==="
41
+ numbers = {}
42
+ for line in lines :
43
+ if division in line :
44
+ break
45
+ string, number = line.split("=>")
46
+ number = int(number)
47
+ string = string[1:-2]
48
+ numbers[number] = string
49
+ return [numbers[x] for x in range(len(numbers))]
50
+
51
+ labels = read_labels_file(os.path.join(hparams["save_folder"], "label_encoder.txt"))
52
+ labels = [""] + labels[1:-1] + ["1"]
53
+
54
+ # Dataset prep (parsing Librispeech)
55
+ """
56
+
57
+ def dataio_prepare(hparams):
58
+ """This function prepares the datasets to be used in the brain class.
59
+ It also defines the data processing pipeline through user-defined functions."""
60
+
61
+ # 1. Define datasets
62
+ data_folder = hparams["data_folder"]
63
+
64
+ train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
65
+ csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
66
+ )
67
+
68
+ if hparams["sorting"] == "ascending":
69
+ # we sort training data to speed up training and get better results.
70
+ train_data = train_data.filtered_sorted(
71
+ sort_key="duration",
72
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
73
+ )
74
+ # when sorting, do not shuffle in the dataloader; otherwise sorting is pointless
75
+ hparams["dataloader_options"]["shuffle"] = False
76
+
77
+ elif hparams["sorting"] == "descending":
78
+ train_data = train_data.filtered_sorted(
79
+ sort_key="duration",
80
+ reverse=True,
81
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
82
+ )
83
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
84
+ hparams["dataloader_options"]["shuffle"] = False
85
+
86
+ elif hparams["sorting"] == "random":
87
+ pass
88
+
89
+ else:
90
+ raise NotImplementedError(
91
+ "sorting must be random, ascending or descending"
92
+ )
93
+
94
+ valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
95
+ csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
96
+ )
97
+ # We also sort the validation data so it is faster to validate
98
+ valid_data = valid_data.filtered_sorted(sort_key="duration")
99
+ test_datasets = {}
100
+ for csv_file in hparams["test_csv"]:
101
+ name = Path(csv_file).stem
102
+ test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
103
+ csv_path=csv_file, replacements={"data_root": data_folder}
104
+ )
105
+ test_datasets[name] = test_datasets[name].filtered_sorted(
106
+ sort_key="duration"
107
+ )
108
+
109
+ datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]
110
+
111
+
112
+ # 2. Define audio pipeline:
113
+ @sb.utils.data_pipeline.takes("wav")
114
+ @sb.utils.data_pipeline.provides("sig")
115
+ def audio_pipeline(wav):
116
+ info = torchaudio.info(wav)
117
+ sig = sb.dataio.dataio.read_audio(wav)
118
+ if len(sig.shape)>1 :
119
+ sig = torch.mean(sig, dim=1)
120
+ resampled = torchaudio.transforms.Resample(
121
+ info.sample_rate, hparams["sample_rate"],
122
+ )(sig)
123
+ return resampled
124
+
125
+ sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
126
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
127
+
128
+ # 3. Define text pipeline:
129
+ @sb.utils.data_pipeline.takes("wrd")
130
+ @sb.utils.data_pipeline.provides(
131
+ "wrd", "char_list", "tokens_list", "tokens"
132
+ )
133
+ def text_pipeline(wrd):
134
+ yield wrd
135
+ char_list = list(wrd)
136
+ yield char_list
137
+ tokens_list = label_encoder.encode_sequence(char_list)
138
+ yield tokens_list
139
+ tokens = torch.LongTensor(tokens_list)
140
+ yield tokens
141
+
142
+ sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
143
+ lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
144
+ special_labels = {
145
+ "blank_label": hparams["blank_index"],
146
+ "unk_label": hparams["unk_index"]
147
+ }
148
+ label_encoder.load_or_create(
149
+ path=lab_enc_file,
150
+ from_didatasets=[train_data],
151
+ output_key="char_list",
152
+ special_labels=special_labels,
153
+ sequence_input=True,
154
+ )
155
+
156
+ # 4. Set output:
157
+ sb.dataio.dataset.set_output_keys(
158
+ datasets, ["id", "sig", "wrd", "char_list", "tokens"],
159
+ )
160
+ return train_data, valid_data,test_datasets, label_encoder
161
+
162
+ class ASR(sb.core.Brain):
163
+ def compute_forward(self, batch, stage):
164
+ """Forward computations from the waveform batches to the output probabilities."""
165
+
166
+ batch = batch.to(self.device)
167
+ wavs, wav_lens = batch.sig
168
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
169
+
170
+ if stage == sb.Stage.TRAIN:
171
+ if hasattr(self.hparams, "augmentation"):
172
+ wavs = self.hparams.augmentation(wavs, wav_lens)
173
+
174
+ # Forward pass
175
+ feats = self.modules.wav2vec2(wavs, wav_lens)
176
+ x = self.modules.enc(feats)
177
+ logits = self.modules.ctc_lin(x)
178
+ p_ctc = self.hparams.log_softmax(logits)
179
+
180
+ return p_ctc, wav_lens
181
+
182
+ def custom_encode(self,wavs,wav_lens) :
183
+ wavs = wavs.to(self.device)
184
+ if wav_lens is not None: wav_lens = wav_lens.to(self.device)
185
+
186
+ feats = self.modules.wav2vec2(wavs, wav_lens)
187
+ x = self.modules.enc(feats)
188
+ logits = self.modules.ctc_lin(x)
189
+ p_ctc = self.hparams.log_softmax(logits)
190
+
191
+ return feats,p_ctc
192
+
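+ # custom_encode() is the inference-time entry point used further below (see
+ # middle_layer): it returns the wav2vec2 frame embeddings together with the CTC
+ # log-posteriorgram for a padded batch. Minimal usage sketch (paths are placeholders):
+ #   wavs = load_paths(["a.wav", "b.wav"])               # (batch, time), zero-padded
+ #   feats, p_ctc = asr_brain.custom_encode(wavs, None)  # (B, T', D), (B, T', n_tokens)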
193
+
194
+
195
+ def compute_objectives(self, predictions, batch, stage):
196
+ """Computes the loss (CTC) given predictions and targets."""
197
+
198
+ p_ctc, wav_lens = predictions
199
+
200
+ ids = batch.id
201
+ tokens, tokens_lens = batch.tokens
202
+
203
+ loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
204
+
205
+ if stage != sb.Stage.TRAIN:
206
+ predicted_tokens = sb.decoders.ctc_greedy_decode(
207
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
208
+ )
209
+ # Decode token terms to words
210
+ if self.hparams.use_language_modelling:
211
+ predicted_words = []
212
+ for logs in p_ctc:
213
+ text = decoder.decode(logs.detach().cpu().numpy())
214
+ predicted_words.append(text.split(" "))
215
+ else:
216
+ predicted_words = [
217
+ "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
218
+ for utt_seq in predicted_tokens
219
+ ]
220
+ # Convert indices to words
221
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
222
+
223
+ self.wer_metric.append(ids, predicted_words, target_words)
224
+ self.cer_metric.append(ids, predicted_words, target_words)
225
+
226
+ return loss
227
+
228
+ def fit_batch(self, batch):
229
+ """Train the parameters given a single batch in input"""
230
+ should_step = self.step % self.grad_accumulation_factor == 0
231
+ # Managing automatic mixed precision
232
+ # TOFIX: CTC fine-tuning currently is unstable
233
+ # This is certainly due to CTC being done in fp16 instead of fp32
234
+ if self.auto_mix_prec:
235
+ with torch.cuda.amp.autocast():
236
+ with self.no_sync():
237
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
238
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
239
+ with self.no_sync(not should_step):
240
+ self.scaler.scale(
241
+ loss / self.grad_accumulation_factor
242
+ ).backward()
243
+ if should_step:
244
+
245
+ if not self.hparams.wav2vec2.freeze:
246
+ self.scaler.unscale_(self.wav2vec_optimizer)
247
+ self.scaler.unscale_(self.model_optimizer)
248
+ if self.check_gradients(loss):
249
+ if not self.hparams.wav2vec2.freeze:
250
+ if self.optimizer_step >= self.hparams.warmup_steps:
251
+ self.scaler.step(self.wav2vec_optimizer)
252
+ self.scaler.step(self.model_optimizer)
253
+ self.scaler.update()
254
+ self.zero_grad()
255
+ self.optimizer_step += 1
256
+ else:
257
+ # This is mandatory because HF models have a weird behavior with DDP
258
+ # on the forward pass
259
+ with self.no_sync():
260
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
261
+
262
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
263
+
264
+ with self.no_sync(not should_step):
265
+ (loss / self.grad_accumulation_factor).backward()
266
+ if should_step:
267
+ if self.check_gradients(loss):
268
+ if not self.hparams.wav2vec2.freeze:
269
+ if self.optimizer_step >= self.hparams.warmup_steps:
270
+ self.wav2vec_optimizer.step()
271
+ self.model_optimizer.step()
272
+ self.zero_grad()
273
+ self.optimizer_step += 1
274
+
275
+ self.on_fit_batch_end(batch, outputs, loss, should_step)
276
+ return loss.detach().cpu()
277
+
278
+ def evaluate_batch(self, batch, stage):
279
+ """Computations needed for validation/test batches"""
280
+ predictions = self.compute_forward(batch, stage=stage)
281
+ with torch.no_grad():
282
+ loss = self.compute_objectives(predictions, batch, stage=stage)
283
+ return loss.detach()
284
+
285
+ def on_stage_start(self, stage, epoch):
286
+ """Gets called at the beginning of each epoch"""
287
+ if stage != sb.Stage.TRAIN:
288
+ self.cer_metric = self.hparams.cer_computer()
289
+ self.wer_metric = self.hparams.error_rate_computer()
290
+
291
+ def on_stage_end(self, stage, stage_loss, epoch):
292
+ """Gets called at the end of an epoch."""
293
+ # Compute/store important stats
294
+ stage_stats = {"loss": stage_loss}
295
+ if stage == sb.Stage.TRAIN:
296
+ self.train_stats = stage_stats
297
+ else:
298
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
299
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
300
+
301
+ # Perform end-of-iteration things, like annealing, logging, etc.
302
+ if stage == sb.Stage.VALID:
303
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
304
+ stage_stats["loss"]
305
+ )
306
+ old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
307
+ stage_stats["loss"]
308
+ )
309
+ sb.nnet.schedulers.update_learning_rate(
310
+ self.model_optimizer, new_lr_model
311
+ )
312
+ if not self.hparams.wav2vec2.freeze:
313
+ sb.nnet.schedulers.update_learning_rate(
314
+ self.wav2vec_optimizer, new_lr_wav2vec
315
+ )
316
+ self.hparams.train_logger.log_stats(
317
+ stats_meta={
318
+ "epoch": epoch,
319
+ "lr_model": old_lr_model,
320
+ "lr_wav2vec": old_lr_wav2vec,
321
+ },
322
+ train_stats=self.train_stats,
323
+ valid_stats=stage_stats,
324
+ )
325
+ self.checkpointer.save_and_keep_only(
326
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
327
+ )
328
+ elif stage == sb.Stage.TEST:
329
+ self.hparams.train_logger.log_stats(
330
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
331
+ test_stats=stage_stats,
332
+ )
333
+ with open(self.hparams.wer_file, "w") as w:
334
+ self.wer_metric.write_stats(w)
335
+
336
+ def init_optimizers(self):
337
+ "Initializes the wav2vec2 optimizer and model optimizer"
338
+
339
+ # If the wav2vec encoder is unfrozen, we create the optimizer
340
+ if not self.hparams.wav2vec2.freeze:
341
+ self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
342
+ self.modules.wav2vec2.parameters()
343
+ )
344
+ if self.checkpointer is not None:
345
+ self.checkpointer.add_recoverable(
346
+ "wav2vec_opt", self.wav2vec_optimizer
347
+ )
348
+
349
+ self.model_optimizer = self.hparams.model_opt_class(
350
+ self.hparams.model.parameters()
351
+ )
352
+
353
+ if self.checkpointer is not None:
354
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
355
+
356
+ def zero_grad(self, set_to_none=False):
357
+ if not self.hparams.wav2vec2.freeze:
358
+ self.wav2vec_optimizer.zero_grad(set_to_none)
359
+ self.model_optimizer.zero_grad(set_to_none)
360
+
361
+
362
+ """
363
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
364
+
365
+ train_data, valid_data, test_datasets, label_encoder = dataio_prepare(
366
+ hparams
367
+ )
368
+
369
+
370
+ # We dynamically add the tokenizer to our brain class.
371
+ # NB: This tokenizer corresponds to the one used for the LM!!
372
+ decoder = build_ctcdecoder(
373
+ labels,
374
+ kenlm_model_path="/content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm/lm_data/arpas/indomain.arpa", # either .arpa or .bin file
375
+ alpha=0.5, # tuned on a val set
376
+ beta=1, # tuned on a val set
377
+ )
378
+ """
379
+ from speechbrain.pretrained import EncoderASR,EncoderDecoderASR
380
+ french_asr_model = EncoderASR.from_hparams(source="speechbrain/asr-wav2vec2-commonvoice-fr", savedir="pretrained_models/asr-wav2vec2-commonvoice-fr").cuda()
381
+ french_asr_model.mods.eval()
382
+ #french_asr_model = "r"
383
+
384
+ english_asr_model = EncoderDecoderASR.from_hparams(source="speechbrain/asr-wav2vec2-commonvoice-en", savedir="pretrained_models/asr-wav2vec2-commonvoice-en/").cuda()
385
+ english_asr_model.mods.eval()
386
+
387
+ asr_brain = ASR(
388
+ modules=hparams["modules"],
389
+ hparams=hparams,
390
+ run_opts=run_opts,
391
+ checkpointer=hparams["checkpointer"],
392
+ )
393
+ asr_brain.checkpointer.recover_if_possible()
394
+ asr_brain.modules.eval()
395
+ """
396
+ asr_brain.tokenizer = label_encoder
397
+
398
+ # Testing
399
+ real = True
400
+ if real :
401
+ for k in test_datasets.keys(): # keys are test_clean, test_other etc
402
+ asr_brain.hparams.wer_file = os.path.join(
403
+ hparams["output_folder"], "wer_{}.txt".format(k)
404
+ )
405
+ asr_brain.evaluate(
406
+ test_datasets[k], test_loader_kwargs=hparams["dataloader_options"]
407
+ )
408
+ """
409
+
410
+ """
411
+ from torch.nn.utils.rnn import pad_sequence
412
+ def load_paths(wavs_path):
413
+ waveforms = []
414
+ for path in wavs_path :
415
+ waveform, _ = torchaudio.load(path)
416
+ waveforms.append(waveform.squeeze(0))
417
+ # Normalize all waveforms to the length of the longest one by zero-padding
418
+ padded_arrays = pad_sequence(waveforms, batch_first=True)
419
+ return torch.tensor(padded_arrays)
420
+
421
+ waveform = load_paths(["/content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm/samples/Salah10.wav","/content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm/samples/Salah10.wav"])
422
+ embeddings, posteriogram = asr_brain.custom_encode(waveform,None)
423
+ print(embeddings.shape)
424
+ print(posteriogram.shape)
425
+ """
426
+
427
+ from speechbrain.pretrained import EncoderASR,EncoderDecoderASR
428
+ import torchaudio
429
+ import speechbrain as sb
430
+ import torch
431
+ from torch.nn.utils.rnn import pad_sequence
432
+ import torch
433
+ import speechbrain as sb
434
+ import numpy as np
435
+ import torch.optim as optim
436
+ import torch.nn as nn
437
+
438
+ # Commented out IPython magic to ensure Python compatibility.
439
+ # %ls
440
+
441
+ # UTILITY FUNCTIONS
442
+ def get_size_dimensions(arr):
443
+ size_dimensions = []
444
+ while isinstance(arr, list):
445
+ size_dimensions.append(len(arr))
446
+ arr = arr[0]
447
+ return size_dimensions
448
+
449
+ def scale_array(batch,n):
450
+ scaled_batch = []
451
+
452
+ for array in batch:
453
+ if(n < len(array)): raise ValueError("Cannot scale Array down")
454
+
455
+ repeat = round(n/len(array))+1
456
+ scaled_length_array= []
457
+
458
+ for i in array:
459
+ for j in range(repeat) :
460
+ if(len(scaled_length_array) == n): break
461
+ scaled_length_array.append(i)
462
+
463
+ scaled_batch.append(scaled_length_array)
464
+
465
+ return torch.tensor(scaled_batch)
466
+
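+ # Worked example (illustrative): scale_array([[1, 2]], 4) repeats each element
+ # round(4/2)+1 = 3 times, stopping once the target length is reached, and returns
+ # tensor([[1, 1, 1, 2]]) -- note the repeats are front-loaded rather than spread evenly.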
467
+
468
+ def load_paths(wavs_path):
469
+ waveforms = []
470
+ for path in wavs_path :
471
+ waveform, _ = torchaudio.load(path)
472
+ waveforms.append(waveform.squeeze(0))
473
+ # Normalize all waveforms to the length of the longest one by zero-padding
474
+ padded_arrays = pad_sequence(waveforms, batch_first=True)
475
+ return torch.tensor(padded_arrays)
476
+
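+ # Note: pad_sequence already returns a Tensor (assuming mono input files), so the
+ # torch.tensor(...) call above merely copies it; the result is a (batch, max_time)
+ # zero-padded waveform batch.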
477
+
478
+
479
+ def word_to_vec(input_string):
480
+ mapping= {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26, 'ا': 27, 'ب': 28, 'ت': 29, 'ث': 30, 'ج': 31, 'ح': 32, 'خ': 33, 'د': 34, 'ذ': 35, 'ر': 36, 'ز': 37, 'س': 38, 'ش': 39, 'ص': 40, 'ض': 41, 'ط': 42, 'ظ': 43, 'ع': 44, 'غ': 45, 'ف': 46, 'ق': 47, 'ك': 48, 'ل': 49, 'م': 50, 'ن': 51, 'ه': 52, 'و': 53, 'ي': 54,' ':55}
481
+
482
+ numbers = [mapping[word] for word in input_string if word in mapping]
483
+ return numbers
484
+
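+ # Worked example (illustrative): word_to_vec("ab ") -> [1, 2, 55]; characters that
+ # are not in the mapping are silently dropped.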
485
+ device = 'cuda'
486
+ verbose = 0
487
+ # FLOW-LEVEL FUNCTIONS
488
+ def merge_strategy(embeddings1, embeddings2, embeddings3,post1, post2,post3):
489
+
490
+
491
+ post1 = post1.to(device)
492
+ post2 = post2.to(device)
493
+ post3 = post3.to(device)
494
+ embeddings1 = embeddings1.to(device)
495
+ embeddings2 = embeddings2.to(device)
496
+ embeddings3 = embeddings3.to(device)
497
+
498
+ posteriograms_merged = torch.cat((post1,post2,post3),dim=2)
499
+ embeddings_merged = torch.cat((embeddings1,embeddings2,embeddings3),dim=2)
500
+
501
+ if(verbose !=0):
502
+ print('MERGED POST ',posteriograms_merged.shape)
503
+ print('MERGED emb ',embeddings_merged.shape)
504
+
505
+ return torch.cat((posteriograms_merged,embeddings_merged),dim=2).to(device)
506
+
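+ # merge_strategy() concatenates along the feature axis (dim=2), so the fused
+ # tensor's last dimension is the sum of the three posteriorgram sizes plus the
+ # three embedding sizes; all six inputs are assumed to share the same
+ # (batch, frames) shape, otherwise torch.cat raises a size-mismatch error.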
507
+ def decode(model,wavs,wav_lens):
508
+
509
+ with torch.no_grad():
510
+ wav_lens = wav_lens.to(model.device)
511
+ encoder_out = model.encode_batch(wavs, wav_lens)
512
+ predictions = model.decoding_function(encoder_out, wav_lens)
513
+ return predictions
514
+
515
+ def middle_layer(batch, lens):
516
+
517
+ tn_embeddings, tn_posteriogram = asr_brain.custom_encode(batch,None)
518
+
519
+ fr_embeddings = french_asr_model.mods.encoder.wav2vec2(batch)
520
+ fr_posteriogram =french_asr_model.encode_batch(batch,lens)
521
+
522
+ en_embeddings = english_asr_model.encode_batch(batch, lens)
523
+ #scores, en_posteriogram = english_asr_model.mods.decoder(en_embeddings ,lens)
524
+ en_posteriogram = en_embeddings
525
+
526
+ if(verbose !=0):
527
+ print('[EMBEDDINGS] FR:',fr_embeddings.shape, "EN:",en_embeddings.shape, "TN:", tn_embeddings.shape)
528
+ print('[POSTERIOGRAM] FR:',fr_posteriogram.shape, "EN:",en_posteriogram.shape,"TN:",tn_posteriogram.shape)
529
+
530
+
531
+ bilangual_sample = merge_strategy(fr_embeddings,en_embeddings,tn_embeddings,fr_posteriogram,en_posteriogram,tn_posteriogram)
532
+ return bilangual_sample
533
+
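+ # middle_layer() fuses three fixed models into one feature stream: the Tunisian
+ # ASR (custom_encode), the French CommonVoice wav2vec2 encoder and the English
+ # CommonVoice encoder. Since the decoder call is commented out, the English
+ # "posteriorgram" is simply the encoder output reused; the three models are
+ # assumed to produce frame-aligned outputs so that merge_strategy() can
+ # concatenate them.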
534
+ class Mixer(sb.core.Brain):
535
+
536
+ def compute_forward(self, batch, stage):
537
+ """Forward computations from the waveform batches to the output probabilities."""
538
+ wavs, wav_lens = batch.sig
539
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
540
+
541
+ if stage == sb.Stage.TRAIN:
542
+ if hasattr(self.hparams, "augmentation"):
543
+ wavs = self.hparams.augmentation(wavs, wav_lens)
544
+
545
+ multi_langual_feats = middle_layer(wavs, wav_lens)
546
+ multi_langual_feats= multi_langual_feats.to(device)
547
+ feats, _ = self.modules.enc(multi_langual_feats)
548
+ logits = self.modules.ctc_lin(feats)
549
+ p_ctc = self.hparams.log_softmax(logits)
550
+
551
+ if stage!= sb.Stage.TRAIN:
552
+ p_tokens = sb.decoders.ctc_greedy_decode(
553
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
554
+ )
555
+ else :
556
+ p_tokens = None
557
+ return p_ctc, wav_lens, p_tokens
558
+
559
+ def compute_objectives(self, predictions, batch, stage):
560
+ """Computes the loss (CTC) given predictions and targets."""
561
+
562
+ p_ctc, wav_lens , predicted_tokens= predictions
563
+
564
+ ids = batch.id
565
+ tokens, tokens_lens = batch.tokens
566
+
567
+ loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
568
+
569
+
570
+ if stage != sb.Stage.TRAIN:
571
+ predicted_words = [
572
+ "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
573
+ for utt_seq in predicted_tokens
574
+ ]
575
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
576
+ self.wer_metric.append(ids, predicted_words, target_words)
577
+ self.cer_metric.append(ids, predicted_words, target_words)
578
+
579
+ return loss
580
+
581
+ def fit_batch(self, batch):
582
+ """Train the parameters given a single batch in input"""
583
+ should_step = self.step % self.grad_accumulation_factor == 0
584
+ # Managing automatic mixed precision
585
+ # TOFIX: CTC fine-tuning currently is unstable
586
+ # This is certainly due to CTC being done in fp16 instead of fp32
587
+ if self.auto_mix_prec:
588
+ with torch.cuda.amp.autocast():
589
+ with self.no_sync():
590
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
591
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
592
+ with self.no_sync(not should_step):
593
+ self.scaler.scale(
594
+ loss / self.grad_accumulation_factor
595
+ ).backward()
596
+ if should_step:
597
+
598
+
599
+ self.scaler.unscale_(self.model_optimizer)
600
+ if self.check_gradients(loss):
601
+ self.scaler.step(self.model_optimizer)
602
+ self.scaler.update()
603
+ self.zero_grad()
604
+ self.optimizer_step += 1
605
+ else:
606
+ # This is mandatory because HF models have a weird behavior with DDP
607
+ # on the forward pass
608
+ with self.no_sync():
609
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
610
+
611
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
612
+
613
+ with self.no_sync(not should_step):
614
+ (loss / self.grad_accumulation_factor).backward()
615
+ if should_step:
616
+ if self.check_gradients(loss):
617
+ self.model_optimizer.step()
618
+ self.zero_grad()
619
+ self.optimizer_step += 1
620
+
621
+ self.on_fit_batch_end(batch, outputs, loss, should_step)
622
+ return loss.detach().cpu()
623
+
624
+ def evaluate_batch(self, batch, stage):
625
+ """Computations needed for validation/test batches"""
626
+ predictions = self.compute_forward(batch, stage=stage)
627
+ with torch.no_grad():
628
+ loss = self.compute_objectives(predictions, batch, stage=stage)
629
+ return loss.detach()
630
+
631
+ def on_stage_start(self, stage, epoch):
632
+ """Gets called at the beginning of each epoch"""
633
+ if stage != sb.Stage.TRAIN:
634
+ self.cer_metric = self.hparams.cer_computer()
635
+ self.wer_metric = self.hparams.error_rate_computer()
636
+
637
+ def on_stage_end(self, stage, stage_loss, epoch):
638
+ """Gets called at the end of an epoch."""
639
+ # Compute/store important stats
640
+ stage_stats = {"loss": stage_loss}
641
+ if stage == sb.Stage.TRAIN:
642
+ self.train_stats = stage_stats
643
+ else:
644
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
645
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
646
+
647
+ # Perform end-of-iteration things, like annealing, logging, etc.
648
+ if stage == sb.Stage.VALID:
649
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
650
+ stage_stats["loss"]
651
+ )
652
+ sb.nnet.schedulers.update_learning_rate(
653
+ self.model_optimizer, new_lr_model
654
+ )
655
+ self.hparams.train_logger.log_stats(
656
+ stats_meta={
657
+ "epoch": epoch,
658
+ "lr_model": old_lr_model,
659
+ },
660
+ train_stats=self.train_stats,
661
+ valid_stats=stage_stats,
662
+ )
663
+ self.checkpointer.save_and_keep_only(
664
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
665
+ )
666
+ elif stage == sb.Stage.TEST:
667
+ self.hparams.train_logger.log_stats(
668
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
669
+ test_stats=stage_stats,
670
+ )
671
+ with open(self.hparams.wer_file, "w") as w:
672
+ self.wer_metric.write_stats(w)
673
+
674
+ def init_optimizers(self):
675
+
676
+ self.model_optimizer = self.hparams.model_opt_class(
677
+ self.hparams.model.parameters()
678
+ )
679
+
680
+ if self.checkpointer is not None:
681
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
682
+
683
+ def zero_grad(self, set_to_none=False):
684
+
685
+ self.model_optimizer.zero_grad(set_to_none)
686
+
687
+
688
+ hparams_file, run_opts, overrides = sb.parse_arguments([sys.argv[1]])
689
+
690
+ # If distributed_launch=True then
691
+ # create ddp_group with the right communication protocol
692
+ sb.utils.distributed.ddp_init_group(run_opts)
693
+
694
+ with open(hparams_file) as fin:
695
+ hparams = load_hyperpyyaml(fin, overrides)
696
+
697
+ # Create experiment directory
698
+ sb.create_experiment_directory(
699
+ experiment_directory=hparams["output_folder"],
700
+ hyperparams_to_save=hparams_file,
701
+ overrides=overrides,
702
+ )
703
+ """
704
+ def read_labels_file(labels_file):
705
+ with open(labels_file, "r",encoding="utf-8") as lf:
706
+ lines = lf.read().splitlines()
707
+ division = "==="
708
+ numbers = {}
709
+ for line in lines :
710
+ if division in line :
711
+ break
712
+ string, number = line.split("=>")
713
+ number = int(number)
714
+ string = string[1:-2]
715
+ numbers[number] = string
716
+ return [numbers[x] for x in range(len(numbers))]
717
+ labels = read_labels_file(os.path.join(hparams["save_folder"], "label_encoder.txt"))
718
+ labels = [""] + labels[1:-1] + ["1"]
719
+
720
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
721
+ """
722
+ train_data, valid_data, test_datasets, label_encoder = dataio_prepare(
723
+ hparams
724
+ )
725
+
726
+
727
+
728
+
729
+ """
730
+ decoder = build_ctcdecoder(
731
+ labels,
732
+ kenlm_model_path="/content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm/lm_data/arpas/indomain.arpa", # either .arpa or .bin file
733
+ alpha=0.5, # tuned on a val set
734
+ beta=1, # tuned on a val set
735
+ )
736
+ """
737
+ mixer = Mixer(
738
+ modules=hparams["modules"],
739
+ hparams=hparams,
740
+ run_opts=run_opts,
741
+ checkpointer=hparams["checkpointer"],
742
+ )
743
+ mixer.tokenizer = label_encoder
744
+
745
+
746
+ mixer.fit(
747
+ mixer.hparams.epoch_counter,
748
+ train_data,
749
+ valid_data,
750
+ train_loader_kwargs=hparams["dataloader_options"],
751
+ valid_loader_kwargs=hparams["test_dataloader_options"],
752
+ )
753
+
754
+ for k in test_datasets.keys(): # keys are test_clean, test_other etc
755
+ mixer.hparams.wer_file = os.path.join(
756
+ hparams["output_folder"], "wer_{}.txt".format(k)
757
+ )
758
+ mixer.evaluate(
759
+ test_datasets[k], test_loader_kwargs=hparams["test_dataloader_options"]
760
+ )
761
+
TunisianASR/results/14epoch_tunisian/<seed>/ctc_lin.py ADDED
@@ -0,0 +1,756 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import os
4
+ import sys
5
+ import torch
6
+ import logging
7
+ import speechbrain as sb
8
+ from speechbrain.utils.distributed import run_on_main
9
+ from hyperpyyaml import load_hyperpyyaml
10
+ from pathlib import Path
11
+ import torchaudio.transforms as T
12
+ from cv_train import ASRCV
13
+ import torchaudio
14
+ import numpy as np
15
+ import kenlm
16
+ from pyctcdecode import build_ctcdecoder
17
+ import re
18
+
19
+ # Commented out IPython magic to ensure Python compatibility.
20
+ # %cd /content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm
21
+ #hparams_file, run_opts, overrides = sb.parse_arguments(["/gpfsstore/rech/nou/uzn19yk/switched_code_tunisian/train/tunisian_asr/hparams/train_semi.yaml"])
22
+ hparams_file, run_opts, overrides = sb.parse_arguments(["semi_supervised_test_tunisian.yaml"])
23
+
24
+ # If distributed_launch=True then
25
+ # create ddp_group with the right communication protocol
26
+ sb.utils.distributed.ddp_init_group(run_opts)
27
+
28
+ with open(hparams_file) as fin:
29
+ hparams = load_hyperpyyaml(fin, overrides)
30
+
31
+ # Create experiment directory
32
+ sb.create_experiment_directory(
33
+ experiment_directory=hparams["output_folder"],
34
+ hyperparams_to_save=hparams_file,
35
+ overrides=overrides,
36
+ )
37
+ # Dataset prep (building the train/valid/test datasets from the CSV files)
38
+
39
+ def dataio_prepare(hparams):
40
+ """This function prepares the datasets to be used in the brain class.
41
+ It also defines the data processing pipeline through user-defined functions."""
42
+
43
+ # 1. Define datasets
44
+ data_folder = hparams["data_folder"]
45
+
46
+ train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
47
+ csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
48
+ )
49
+
50
+ if hparams["sorting"] == "ascending":
51
+ # we sort training data to speed up training and get better results.
52
+ train_data = train_data.filtered_sorted(
53
+ sort_key="duration",
54
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
55
+ )
56
+ # When sorting, do not shuffle in the dataloader; otherwise sorting is pointless
57
+ hparams["dataloader_options"]["shuffle"] = False
58
+
59
+ elif hparams["sorting"] == "descending":
60
+ train_data = train_data.filtered_sorted(
61
+ sort_key="duration",
62
+ reverse=True,
63
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
64
+ )
65
+ # When sorting, do not shuffle in the dataloader; otherwise sorting is pointless
66
+ hparams["dataloader_options"]["shuffle"] = False
67
+
68
+ elif hparams["sorting"] == "random":
69
+ pass
70
+
71
+ else:
72
+ raise NotImplementedError(
73
+ "sorting must be random, ascending or descending"
74
+ )
75
+
76
+ valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
77
+ csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
78
+ )
79
+ # We also sort the validation data so it is faster to validate
80
+ valid_data = valid_data.filtered_sorted(sort_key="duration")
81
+ test_datasets = {}
82
+ for csv_file in hparams["test_csv"]:
83
+ name = Path(csv_file).stem
84
+ test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
85
+ csv_path=csv_file, replacements={"data_root": data_folder}
86
+ )
87
+ test_datasets[name] = test_datasets[name].filtered_sorted(
88
+ sort_key="duration"
89
+ )
90
+
91
+ datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]
92
+
93
+
94
+ # 2. Define audio pipeline:
95
+ @sb.utils.data_pipeline.takes("wav")
96
+ @sb.utils.data_pipeline.provides("sig")
97
+ def audio_pipeline(wav):
98
+ info = torchaudio.info(wav)
99
+ sig = sb.dataio.dataio.read_audio(wav)
100
+ if len(sig.shape)>1 :
101
+ sig = torch.mean(sig, dim=1)
102
+ resampled = torchaudio.transforms.Resample(
103
+ info.sample_rate, hparams["sample_rate"],
104
+ )(sig)
105
+ return resampled
106
+
107
+ sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
108
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
109
+
110
+ # 3. Define text pipeline:
111
+ @sb.utils.data_pipeline.takes("wrd")
112
+ @sb.utils.data_pipeline.provides(
113
+ "wrd", "char_list", "tokens_list", "tokens"
114
+ )
115
+ def text_pipeline(wrd):
116
+ yield wrd
117
+ char_list = list(wrd)
118
+ yield char_list
119
+ tokens_list = label_encoder.encode_sequence(char_list)
120
+ yield tokens_list
121
+ tokens = torch.LongTensor(tokens_list)
122
+ yield tokens
123
+
124
+ sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
125
+ lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
126
+ special_labels = {
127
+ "blank_label": hparams["blank_index"],
128
+ "unk_label": hparams["unk_index"]
129
+ }
130
+ label_encoder.load_or_create(
131
+ path=lab_enc_file,
132
+ from_didatasets=[train_data],
133
+ output_key="char_list",
134
+ special_labels=special_labels,
135
+ sequence_input=True,
136
+ )
137
+
138
+ # 4. Set output:
139
+ sb.dataio.dataset.set_output_keys(
140
+ datasets, ["id", "sig", "wrd", "char_list", "tokens"],
141
+ )
142
+ return train_data, valid_data,test_datasets, label_encoder
143
+
144
+ class ASR(sb.core.Brain):
145
+ def compute_forward(self, batch, stage):
146
+ """Forward computations from the waveform batches to the output probabilities."""
147
+
148
+ batch = batch.to(self.device)
149
+ wavs, wav_lens = batch.sig
150
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
151
+
152
+ if stage == sb.Stage.TRAIN:
153
+ if hasattr(self.hparams, "augmentation"):
154
+ wavs = self.hparams.augmentation(wavs, wav_lens)
155
+
156
+ # Forward pass
157
+ feats = self.modules.wav2vec2(wavs, wav_lens)
158
+ x = self.modules.enc(feats)
159
+ logits = self.modules.ctc_lin(x)
160
+ p_ctc = self.hparams.log_softmax(logits)
161
+
162
+ return p_ctc, wav_lens
163
+
164
+ def custom_encode(self,wavs,wav_lens) :
165
+ wavs = wavs.to(self.device)
166
+ if wav_lens is not None: wav_lens = wav_lens.to(self.device)
167
+
168
+ feats = self.modules.wav2vec2(wavs, wav_lens)
169
+ x = self.modules.enc(feats)
170
+ logits = self.modules.ctc_lin(x)
171
+ p_ctc = self.hparams.log_softmax(logits)
172
+
173
+ return feats,p_ctc
174
+
175
+
176
+
177
+ def compute_objectives(self, predictions, batch, stage):
178
+ """Computes the loss (CTC) given predictions and targets."""
179
+
180
+ p_ctc, wav_lens = predictions
181
+
182
+ ids = batch.id
183
+ tokens, tokens_lens = batch.tokens
184
+
185
+ loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
186
+
187
+ if stage != sb.Stage.TRAIN:
188
+ predicted_tokens = sb.decoders.ctc_greedy_decode(
189
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
190
+ )
191
+ # Decode token terms to words
192
+ if self.hparams.use_language_modelling:
193
+ predicted_words = []
194
+ for logs in p_ctc:
195
+ text = decoder.decode(logs.detach().cpu().numpy())
196
+ predicted_words.append(text.split(" "))
197
+ else:
198
+ predicted_words = [
199
+ "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
200
+ for utt_seq in predicted_tokens
201
+ ]
202
+ # Convert indices to words
203
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
204
+
205
+ self.wer_metric.append(ids, predicted_words, target_words)
206
+ self.cer_metric.append(ids, predicted_words, target_words)
207
+
208
+ return loss
209
+
210
+ def fit_batch(self, batch):
211
+ """Train the parameters given a single batch in input"""
212
+ should_step = self.step % self.grad_accumulation_factor == 0
213
+ # Managing automatic mixed precision
214
+ # TOFIX: CTC fine-tuning currently is unstable
215
+ # This is certainly due to CTC being done in fp16 instead of fp32
216
+ if self.auto_mix_prec:
217
+ with torch.cuda.amp.autocast():
218
+ with self.no_sync():
219
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
220
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
221
+ with self.no_sync(not should_step):
222
+ self.scaler.scale(
223
+ loss / self.grad_accumulation_factor
224
+ ).backward()
225
+ if should_step:
226
+
227
+ if not self.hparams.wav2vec2.freeze:
228
+ self.scaler.unscale_(self.wav2vec_optimizer)
229
+ self.scaler.unscale_(self.model_optimizer)
230
+ if self.check_gradients(loss):
231
+ if not self.hparams.wav2vec2.freeze:
232
+ if self.optimizer_step >= self.hparams.warmup_steps:
233
+ self.scaler.step(self.wav2vec_optimizer)
234
+ self.scaler.step(self.model_optimizer)
235
+ self.scaler.update()
236
+ self.zero_grad()
237
+ self.optimizer_step += 1
238
+ else:
239
+ # This is mandatory because HF models have a weird behavior with DDP
240
+ # on the forward pass
241
+ with self.no_sync():
242
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
243
+
244
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
245
+
246
+ with self.no_sync(not should_step):
247
+ (loss / self.grad_accumulation_factor).backward()
248
+ if should_step:
249
+ if self.check_gradients(loss):
250
+ if not self.hparams.wav2vec2.freeze:
251
+ if self.optimizer_step >= self.hparams.warmup_steps:
252
+ self.wav2vec_optimizer.step()
253
+ self.model_optimizer.step()
254
+ self.zero_grad()
255
+ self.optimizer_step += 1
256
+
257
+ self.on_fit_batch_end(batch, outputs, loss, should_step)
258
+ return loss.detach().cpu()
259
+
260
+ def evaluate_batch(self, batch, stage):
261
+ """Computations needed for validation/test batches"""
262
+ predictions = self.compute_forward(batch, stage=stage)
263
+ with torch.no_grad():
264
+ loss = self.compute_objectives(predictions, batch, stage=stage)
265
+ return loss.detach()
266
+
267
+ def on_stage_start(self, stage, epoch):
268
+ """Gets called at the beginning of each epoch"""
269
+ if stage != sb.Stage.TRAIN:
270
+ self.cer_metric = self.hparams.cer_computer()
271
+ self.wer_metric = self.hparams.error_rate_computer()
272
+
273
+ def on_stage_end(self, stage, stage_loss, epoch):
274
+ """Gets called at the end of an epoch."""
275
+ # Compute/store important stats
276
+ stage_stats = {"loss": stage_loss}
277
+ if stage == sb.Stage.TRAIN:
278
+ self.train_stats = stage_stats
279
+ else:
280
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
281
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
282
+
283
+ # Perform end-of-iteration things, like annealing, logging, etc.
284
+ if stage == sb.Stage.VALID:
285
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
286
+ stage_stats["loss"]
287
+ )
288
+ old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
289
+ stage_stats["loss"]
290
+ )
291
+ sb.nnet.schedulers.update_learning_rate(
292
+ self.model_optimizer, new_lr_model
293
+ )
294
+ if not self.hparams.wav2vec2.freeze:
295
+ sb.nnet.schedulers.update_learning_rate(
296
+ self.wav2vec_optimizer, new_lr_wav2vec
297
+ )
298
+ self.hparams.train_logger.log_stats(
299
+ stats_meta={
300
+ "epoch": epoch,
301
+ "lr_model": old_lr_model,
302
+ "lr_wav2vec": old_lr_wav2vec,
303
+ },
304
+ train_stats=self.train_stats,
305
+ valid_stats=stage_stats,
306
+ )
307
+ self.checkpointer.save_and_keep_only(
308
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
309
+ )
310
+ elif stage == sb.Stage.TEST:
311
+ self.hparams.train_logger.log_stats(
312
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
313
+ test_stats=stage_stats,
314
+ )
315
+ with open(self.hparams.wer_file, "w") as w:
316
+ self.wer_metric.write_stats(w)
317
+
318
+ def init_optimizers(self):
319
+ "Initializes the wav2vec2 optimizer and model optimizer"
320
+
321
+ # If the wav2vec encoder is unfrozen, we create the optimizer
322
+ if not self.hparams.wav2vec2.freeze:
323
+ self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
324
+ self.modules.wav2vec2.parameters()
325
+ )
326
+ if self.checkpointer is not None:
327
+ self.checkpointer.add_recoverable(
328
+ "wav2vec_opt", self.wav2vec_optimizer
329
+ )
330
+
331
+ self.model_optimizer = self.hparams.model_opt_class(
332
+ self.hparams.model.parameters()
333
+ )
334
+
335
+ if self.checkpointer is not None:
336
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
337
+
338
+ def zero_grad(self, set_to_none=False):
339
+ if not self.hparams.wav2vec2.freeze:
340
+ self.wav2vec_optimizer.zero_grad(set_to_none)
341
+ self.model_optimizer.zero_grad(set_to_none)
342
+
343
+
344
+ """
345
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
346
+
347
+ train_data, valid_data, test_datasets, label_encoder = dataio_prepare(
348
+ hparams
349
+ )
350
+
351
+
352
+ # We dynamically add the tokenizer to our brain class.
353
+ # NB: This tokenizer corresponds to the one used for the LM!!
354
+ """
355
+ from speechbrain.pretrained import EncoderASR,EncoderDecoderASR
356
+ french_asr_model = EncoderASR.from_hparams(source="speechbrain/asr-wav2vec2-commonvoice-fr", savedir="pretrained_models/asr-wav2vec2-commonvoice-fr").cuda()
357
+ #french_asr_model = "r"
358
+
359
+ cvhparams_file, cvrun_opts, cvoverrides = sb.parse_arguments(["en_cv.yaml"])
360
+ with open(cvhparams_file) as cvfin:
361
+ cvhparams = load_hyperpyyaml(cvfin, cvoverrides)
362
+ english_asr_model = ASRCV(
363
+ modules=cvhparams["modules"],
364
+ hparams=cvhparams,
365
+ run_opts=cvrun_opts,
366
+ checkpointer=cvhparams["checkpointer"],
367
+ )
368
+ english_asr_model.checkpointer.recover_if_possible()
369
+ asr_brain = ASR(
370
+ modules=hparams["modules"],
371
+ hparams=hparams,
372
+ run_opts=run_opts,
373
+ checkpointer=hparams["checkpointer"],
374
+ )
375
+ asr_brain.checkpointer.recover_if_possible()
376
+ asr_brain.modules.eval()
377
+ english_asr_model.modules.eval()
378
+ french_asr_model.mods.eval()
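+ # At this point three models are loaded: the French model from the HuggingFace
+ # Hub, the English model rebuilt locally from the CommonVoice recipe
+ # (ASRCV + en_cv.yaml) and restored from its checkpointer, and the Tunisian
+ # asr_brain recovered from its own checkpoint. All three run in eval mode and act
+ # as fixed feature extractors for middle_layer() below; only the Mixer's own
+ # model parameters (hparams["model"]) are optimized.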
379
+ """
380
+ asr_brain.tokenizer = label_encoder
381
+
382
+ # Testing
383
+ real = True
384
+ if real :
385
+ for k in test_datasets.keys(): # keys are test_clean, test_other etc
386
+ asr_brain.hparams.wer_file = os.path.join(
387
+ hparams["output_folder"], "wer_{}.txt".format(k)
388
+ )
389
+ asr_brain.evaluate(
390
+ test_datasets[k], test_loader_kwargs=hparams["dataloader_options"]
391
+ )
392
+ """
393
+
394
+ """
395
+ from torch.nn.utils.rnn import pad_sequence
396
+ def load_paths(wavs_path):
397
+ waveforms = []
398
+ for path in wavs_path :
399
+ waveform, _ = torchaudio.load(path)
400
+ waveforms.append(waveform.squeeze(0))
401
+ # Normalize all waveforms to the length of the longest one by zero-padding
402
+ padded_arrays = pad_sequence(waveforms, batch_first=True)
403
+ return torch.tensor(padded_arrays)
404
+
405
+ waveform = load_paths(["/content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm/samples/Salah10.wav","/content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm/samples/Salah10.wav"])
406
+ embeddings, posteriogram = asr_brain.custom_encode(waveform,None)
407
+ print(embeddings.shape)
408
+ print(posteriogram.shape)
409
+ """
410
+
411
+ from speechbrain.pretrained import EncoderASR,EncoderDecoderASR
412
+ import torchaudio
413
+ import speechbrain as sb
414
+ import torch
415
+ from torch.nn.utils.rnn import pad_sequence
416
+ import torch
417
+ import speechbrain as sb
418
+ import numpy as np
419
+ import torch.optim as optim
420
+ import torch.nn as nn
421
+
422
+ # Commented out IPython magic to ensure Python compatibility.
423
+ # %ls
424
+
425
+ # UTILITY FUNCTIONS
426
+ def get_size_dimensions(arr):
427
+ size_dimensions = []
428
+ while isinstance(arr, list):
429
+ size_dimensions.append(len(arr))
430
+ arr = arr[0]
431
+ return size_dimensions
432
+
433
+ def scale_array(batch,n):
434
+ scaled_batch = []
435
+
436
+ for array in batch:
437
+ if(n < len(array)): raise ValueError("Cannot scale Array down")
438
+
439
+ repeat = round(n/len(array))+1
440
+ scaled_length_array= []
441
+
442
+ for i in array:
443
+ for j in range(repeat) :
444
+ if(len(scaled_length_array) == n): break
445
+ scaled_length_array.append(i)
446
+
447
+ scaled_batch.append(scaled_length_array)
448
+
449
+ return torch.tensor(scaled_batch)
450
+
451
+
452
+ def load_paths(wavs_path):
453
+ waveforms = []
454
+ for path in wavs_path :
455
+ waveform, _ = torchaudio.load(path)
456
+ waveforms.append(waveform.squeeze(0))
457
+ # Normalize all waveforms to the length of the longest one by zero-padding
458
+ padded_arrays = pad_sequence(waveforms, batch_first=True)
459
+ return torch.tensor(padded_arrays)
460
+
461
+
462
+
463
+ def word_to_vec(input_string):
464
+ mapping= {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26, 'ا': 27, 'ب': 28, 'ت': 29, 'ث': 30, 'ج': 31, 'ح': 32, 'خ': 33, 'د': 34, 'ذ': 35, 'ر': 36, 'ز': 37, 'س': 38, 'ش': 39, 'ص': 40, 'ض': 41, 'ط': 42, 'ظ': 43, 'ع': 44, 'غ': 45, 'ف': 46, 'ق': 47, 'ك': 48, 'ل': 49, 'م': 50, 'ن': 51, 'ه': 52, 'و': 53, 'ي': 54,' ':55}
465
+
466
+ numbers = [mapping[word] for word in input_string if word in mapping]
467
+ return numbers
468
+
469
+ device = 'cuda'
470
+ verbose = 0
471
+ # FLOW-LEVEL FUNCTIONS
472
+ def merge_strategy(embeddings1, embeddings2, embeddings3,post1, post2,post3):
473
+
474
+
475
+ post1 = post1.to(device)
476
+ post2 = post2.to(device)
477
+ post3 = post3.to(device)
478
+ embeddings1 = embeddings1.to(device)
479
+ embeddings2 = embeddings2.to(device)
480
+ embeddings3 = embeddings3.to(device)
481
+
482
+ posteriograms_merged = torch.cat((post1,post2,post3),dim=2)
483
+ embeddings_merged = torch.cat((embeddings1,embeddings2,embeddings3),dim=2)
484
+
485
+ if(verbose !=0):
486
+ print('MERGED POST ',posteriograms_merged.shape)
487
+ print('MERGED emb ',embeddings_merged.shape)
488
+
489
+ return torch.cat((posteriograms_merged,embeddings_merged),dim=2).to(device)
490
+
491
+ def decode(model,wavs,wav_lens):
492
+
493
+ with torch.no_grad():
494
+ wav_lens = wav_lens.to(model.device)
495
+ encoder_out = model.encode_batch(wavs, wav_lens)
496
+ predictions = model.decoding_function(encoder_out, wav_lens)
497
+ return predictions
498
+
499
+ def middle_layer(batch, lens):
500
+
501
+ tn_embeddings, tn_posteriogram = asr_brain.custom_encode(batch,None)
502
+
503
+ fr_embeddings = french_asr_model.mods.encoder.wav2vec2(batch)
504
+ fr_posteriogram =french_asr_model.encode_batch(batch,lens)
505
+ en_embeddings = english_asr_model.modules.wav2vec2(batch, lens)
506
+ x = english_asr_model.modules.enc(en_embeddings)
507
+ en_posteriogram = english_asr_model.modules.ctc_lin(x)
508
+ #scores, en_posteriogram = english_asr_model.mods.decoder(en_embeddings ,lens)
509
+ if(verbose !=0):
510
+ print('[EMBEDDINGS] FR:',fr_embeddings.shape, "EN:",en_embeddings.shape, "TN:", tn_embeddings.shape)
511
+ print('[POSTERIOGRAM] FR:',fr_posteriogram.shape, "EN:",en_posteriogram.shape,"TN:",tn_posteriogram.shape)
512
+
513
+
514
+ bilangual_sample = merge_strategy(fr_embeddings,en_embeddings,tn_embeddings,fr_posteriogram,en_posteriogram,tn_posteriogram)
515
+ return bilangual_sample
516
+
517
+ class Mixer(sb.core.Brain):
518
+
519
+ def compute_forward(self, batch, stage):
520
+ """Forward computations from the waveform batches to the output probabilities."""
521
+ wavs, wav_lens = batch.sig
522
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
523
+
524
+ if stage == sb.Stage.TRAIN:
525
+ if hasattr(self.hparams, "augmentation"):
526
+ wavs = self.hparams.augmentation(wavs, wav_lens)
527
+
528
+ multi_langual_feats = middle_layer(wavs, wav_lens)
529
+ multi_langual_feats= multi_langual_feats.to(device)
530
+ feats, _ = self.modules.enc(multi_langual_feats)
531
+ logits = self.modules.ctc_lin(feats)
532
+ p_ctc = self.hparams.log_softmax(logits)
533
+
534
+ if stage!= sb.Stage.TRAIN:
535
+ p_tokens = sb.decoders.ctc_greedy_decode(
536
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
537
+ )
538
+ else :
539
+ p_tokens = None
540
+ return p_ctc, wav_lens, p_tokens
541
+
542
+ def compute_objectives(self, predictions, batch, stage):
543
+ """Computes the loss (CTC) given predictions and targets."""
544
+
545
+ p_ctc, wav_lens , predicted_tokens= predictions
546
+
547
+ ids = batch.id
548
+ tokens, tokens_lens = batch.tokens
549
+
550
+ loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
551
+
552
+
553
+ if stage == sb.Stage.VALID:
554
+ predicted_words = [
555
+ "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
556
+ for utt_seq in predicted_tokens
557
+ ]
558
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
559
+ self.wer_metric.append(ids, predicted_words, target_words)
560
+ self.cer_metric.append(ids, predicted_words, target_words)
561
+ if stage ==sb.Stage.TEST :
562
+ if self.hparams.language_modelling:
563
+ predicted_words = []
564
+ for logs in p_ctc:
565
+ text = decoder.decode(logs.detach().cpu().numpy())
566
+ predicted_words.append(text.split(" "))
567
+ else :
568
+ predicted_words = [
569
+ "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
570
+ for utt_seq in predicted_tokens
571
+ ]
572
+
573
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
574
+ self.wer_metric.append(ids, predicted_words, target_words)
575
+ self.cer_metric.append(ids, predicted_words, target_words)
576
+
577
+ return loss
578
+
579
+ def fit_batch(self, batch):
580
+ """Train the parameters given a single batch in input"""
581
+ should_step = self.step % self.grad_accumulation_factor == 0
582
+ # Managing automatic mixed precision
583
+ # TOFIX: CTC fine-tuning currently is unstable
584
+ # This is certainly due to CTC being done in fp16 instead of fp32
585
+ if self.auto_mix_prec:
586
+ with torch.cuda.amp.autocast():
587
+ with self.no_sync():
588
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
589
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
590
+ with self.no_sync(not should_step):
591
+ self.scaler.scale(
592
+ loss / self.grad_accumulation_factor
593
+ ).backward()
594
+ if should_step:
595
+
596
+
597
+ self.scaler.unscale_(self.model_optimizer)
598
+ if self.check_gradients(loss):
599
+ self.scaler.step(self.model_optimizer)
600
+ self.scaler.update()
601
+ self.zero_grad()
602
+ self.optimizer_step += 1
603
+ else:
604
+ # This is mandatory because HF models have a weird behavior with DDP
605
+ # on the forward pass
606
+ with self.no_sync():
607
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
608
+
609
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
610
+
611
+ with self.no_sync(not should_step):
612
+ (loss / self.grad_accumulation_factor).backward()
613
+ if should_step:
614
+ if self.check_gradients(loss):
615
+ self.model_optimizer.step()
616
+ self.zero_grad()
617
+ self.optimizer_step += 1
618
+
619
+ self.on_fit_batch_end(batch, outputs, loss, should_step)
620
+ return loss.detach().cpu()
621
+
622
+ def evaluate_batch(self, batch, stage):
623
+ """Computations needed for validation/test batches"""
624
+ predictions = self.compute_forward(batch, stage=stage)
625
+ with torch.no_grad():
626
+ loss = self.compute_objectives(predictions, batch, stage=stage)
627
+ return loss.detach()
628
+
629
+ def on_stage_start(self, stage, epoch):
630
+ """Gets called at the beginning of each epoch"""
631
+ if stage != sb.Stage.TRAIN:
632
+ self.cer_metric = self.hparams.cer_computer()
633
+ self.wer_metric = self.hparams.error_rate_computer()
634
+
635
+ def on_stage_end(self, stage, stage_loss, epoch):
636
+ """Gets called at the end of an epoch."""
637
+ # Compute/store important stats
638
+ stage_stats = {"loss": stage_loss}
639
+ if stage == sb.Stage.TRAIN:
640
+ self.train_stats = stage_stats
641
+ else:
642
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
643
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
644
+
645
+ # Perform end-of-iteration things, like annealing, logging, etc.
646
+ if stage == sb.Stage.VALID:
647
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
648
+ stage_stats["loss"]
649
+ )
650
+ sb.nnet.schedulers.update_learning_rate(
651
+ self.model_optimizer, new_lr_model
652
+ )
653
+ self.hparams.train_logger.log_stats(
654
+ stats_meta={
655
+ "epoch": epoch,
656
+ "lr_model": old_lr_model,
657
+ },
658
+ train_stats=self.train_stats,
659
+ valid_stats=stage_stats,
660
+ )
661
+ self.checkpointer.save_and_keep_only(
662
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
663
+ )
664
+ elif stage == sb.Stage.TEST:
665
+ self.hparams.train_logger.log_stats(
666
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
667
+ test_stats=stage_stats,
668
+ )
669
+ with open(self.hparams.wer_file, "w") as w:
670
+ self.wer_metric.write_stats(w)
671
+
672
+ def init_optimizers(self):
673
+
674
+ self.model_optimizer = self.hparams.model_opt_class(
675
+ self.hparams.model.parameters()
676
+ )
677
+
678
+ if self.checkpointer is not None:
679
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
680
+
681
+ def zero_grad(self, set_to_none=False):
682
+
683
+ self.model_optimizer.zero_grad(set_to_none)
684
+
685
+
686
+ hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
687
+
688
+ # If distributed_launch=True then
689
+ # create ddp_group with the right communication protocol
690
+ sb.utils.distributed.ddp_init_group(run_opts)
691
+
692
+ with open(hparams_file) as fin:
693
+ hparams = load_hyperpyyaml(fin, overrides)
694
+
695
+ # Create experiment directory
696
+ sb.create_experiment_directory(
697
+ experiment_directory=hparams["output_folder"],
698
+ hyperparams_to_save=hparams_file,
699
+ overrides=overrides,
700
+ )
701
+ def read_labels_file(labels_file):
702
+ with open(labels_file, "r",encoding="utf-8") as lf:
703
+ lines = lf.read().splitlines()
704
+ division = "==="
705
+ numbers = {}
706
+ for line in lines :
707
+ if division in line :
708
+ break
709
+ string, number = line.split("=>")
710
+ number = int(number)
711
+ string = string[1:-2]
712
+ numbers[number] = string
713
+ return [numbers[x] for x in range(len(numbers))]
714
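+ # read_labels_file() assumes the SpeechBrain label_encoder.txt layout: one
+ # "'<char>' => <index>" entry per line (e.g. "'b' => 5", the index being
+ # illustrative), terminated by a separator line containing "==="; it returns the
+ # characters ordered by index.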
+ train_data, valid_data, test_datasets, label_encoder = dataio_prepare(
715
+ hparams
716
+ )
717
+
718
+
719
+ labels = read_labels_file(os.path.join(hparams["save_folder"], "label_encoder.txt"))
720
+ labels = [""] + labels[1:-1] + ["1"]
721
+ if hparams["language_modelling"]:
722
+ decoder = build_ctcdecoder(
723
+ labels,
724
+ kenlm_model_path=hparams["ngram_lm_path"], # either .arpa or .bin file
725
+ alpha=0.5, # tuned on a val set
726
+ beta=1, # tuned on a val set
727
+ )
728
+
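+ # When language_modelling is enabled, this pyctcdecode beam-search decoder is
+ # applied per utterance at test time in Mixer.compute_objectives, roughly:
+ #   text = decoder.decode(p_ctc[i].detach().cpu().numpy())
+ # The labels list must line up index-for-index with the CTC output units, which
+ # is why it is rebuilt from label_encoder.txt above.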
729
+
730
+
731
+
732
+ mixer = Mixer(
733
+ modules=hparams["modules"],
734
+ hparams=hparams,
735
+ run_opts=run_opts,
736
+ checkpointer=hparams["checkpointer"],
737
+ )
738
+ mixer.tokenizer = label_encoder
739
+
740
+
741
+ mixer.fit(
742
+ mixer.hparams.epoch_counter,
743
+ train_data,
744
+ valid_data,
745
+ train_loader_kwargs=hparams["dataloader_options"],
746
+ valid_loader_kwargs=hparams["test_dataloader_options"],
747
+ )
748
+ print(test_datasets.keys())
749
+ for k in test_datasets.keys(): # keys are test_clean, test_other etc
750
+ mixer.hparams.wer_file = os.path.join(
751
+ hparams["output_folder"], "wer_{}.txt".format(k)
752
+ )
753
+ mixer.evaluate(
754
+ test_datasets[k], test_loader_kwargs=hparams["test_dataloader_options"]
755
+ )
756
+
TunisianASR/results/14epoch_tunisian/<seed>/env.log ADDED
@@ -0,0 +1,97 @@
1
+ SpeechBrain system description
2
+ ==============================
3
+ Python version:
4
+ 3.9.12 | packaged by conda-forge | (main, Mar 24 2022, 23:22:55)
5
+ [GCC 10.3.0]
6
+ ==============================
7
+ Installed Python packages:
8
+ aiohttp==3.8.5
9
+ aiosignal==1.3.1
10
+ async-timeout==4.0.3
11
+ attrs==23.1.0
12
+ audioread==3.0.0
13
+ certifi==2023.7.22
14
+ cffi==1.15.1
15
+ charset-normalizer==3.2.0
16
+ click==8.1.7
17
+ cmake==3.27.2
18
+ datasets==2.14.4
19
+ decorator==5.1.1
20
+ dill==0.3.7
21
+ exceptiongroup==1.1.3
22
+ filelock==3.12.3
23
+ frozenlist==1.4.0
24
+ fsspec==2023.6.0
25
+ huggingface-hub==0.16.4
26
+ HyperPyYAML==1.2.1
27
+ hypothesis==6.82.7
28
+ idna==3.4
29
+ Jinja2==3.1.2
30
+ jiwer==3.0.3
31
+ joblib==1.3.2
32
+ kenlm @ https://github.com/kpu/kenlm/archive/master.zip#sha256=4d002dcde70b52d519cafff4dc0008696c40cff1c9184a531b40c7b45905be6b
33
+ lazy_loader==0.3
34
+ librosa==0.10.1
35
+ lit==16.0.6
36
+ llvmlite==0.40.1
37
+ MarkupSafe==2.1.3
38
+ mpmath==1.3.0
39
+ msgpack==1.0.5
40
+ multidict==6.0.4
41
+ multiprocess==0.70.15
42
+ networkx==3.1
43
+ numba==0.57.1
44
+ numpy==1.24.4
45
+ nvidia-cublas-cu11==11.10.3.66
46
+ nvidia-cuda-cupti-cu11==11.7.101
47
+ nvidia-cuda-nvrtc-cu11==11.7.99
48
+ nvidia-cuda-runtime-cu11==11.7.99
49
+ nvidia-cudnn-cu11==8.5.0.96
50
+ nvidia-cufft-cu11==10.9.0.58
51
+ nvidia-curand-cu11==10.2.10.91
52
+ nvidia-cusolver-cu11==11.4.0.1
53
+ nvidia-cusparse-cu11==11.7.4.91
54
+ nvidia-nccl-cu11==2.14.3
55
+ nvidia-nvtx-cu11==11.7.91
56
+ packaging==23.1
57
+ pandas==2.0.3
58
+ platformdirs==3.10.0
59
+ pooch==1.7.0
60
+ pyarrow==13.0.0
61
+ pycparser==2.21
62
+ pyctcdecode==0.5.0
63
+ pygtrie==2.5.0
64
+ python-dateutil==2.8.2
65
+ pytz==2023.3
66
+ PyYAML==6.0.1
67
+ rapidfuzz==3.2.0
68
+ regex==2023.8.8
69
+ requests==2.31.0
70
+ ruamel.yaml==0.17.28
71
+ ruamel.yaml.clib==0.2.7
72
+ safetensors==0.3.3
73
+ scikit-learn==1.3.0
74
+ scipy==1.11.2
75
+ sentencepiece==0.1.99
76
+ six==1.16.0
77
+ sortedcontainers==2.4.0
78
+ soundfile==0.12.1
79
+ soxr==0.3.6
80
+ speechbrain==0.5.15
81
+ sympy==1.12
82
+ threadpoolctl==3.2.0
83
+ tokenizers==0.13.3
84
+ torch==2.0.1
85
+ torchaudio==2.0.2
86
+ tqdm==4.66.1
87
+ transformers==4.32.1
88
+ triton==2.0.0
89
+ typing_extensions==4.7.1
90
+ tzdata==2023.3
91
+ urllib3==2.0.4
92
+ xxhash==3.3.0
93
+ yarl==1.9.2
94
+ ==============================
95
+ Could not get git revision==============================
96
+ CUDA version:
97
+ 11.7