Antoine-caubriere commited on
Commit
131ff81
1 Parent(s): fa7be55

Delete SB_ASR_FLEURS_finetuning.ipynb

Browse files
Files changed (1) hide show
  1. SB_ASR_FLEURS_finetuning.ipynb +0 -689
SB_ASR_FLEURS_finetuning.ipynb DELETED
@@ -1,689 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "49b85514-0fb6-49c6-be76-259bfeb638c6",
6
- "metadata": {},
7
- "source": [
8
- "# Introduction\n",
9
- "N'hésitez pas à nous contacter en cas de questions : antoine.caubriere@orange.com & elodie.gauthier@orange.com\n",
10
- "\n",
11
- "Pensez à modifier l'ensemble des PATH dans le fichier de configuration ASR_FLEURSswahili_hf.yaml et dans le code python ci-dessous (PATH_TO_YOUR_FOLDER).\n",
12
- "\n",
13
- "Dans le cas d'un changement de corpus (autre sous partie de FLEURS / vos propres jeux de données), pensez à modifier la taille de la couche de sortie du modèle : ASR_swahili_hf.yaml/output_neurons\n"
14
- ]
15
- },
16
- {
17
- "cell_type": "markdown",
18
- "id": "e62faa86-911a-48ce-82bc-8a34e13ffbc4",
19
- "metadata": {},
20
- "source": [
21
- "# Préparation des données FLEURS"
22
- ]
23
- },
24
- {
25
- "cell_type": "markdown",
26
- "id": "c6ccf4a5-cad1-4632-8954-f4e454ff3540",
27
- "metadata": {},
28
- "source": [
29
- "### 1. Installation des dépendances"
30
- ]
31
- },
32
- {
33
- "cell_type": "code",
34
- "execution_count": null,
35
- "id": "7bb8b44e-826f-4f13-b128-eebbd18dedc5",
36
- "metadata": {
37
- "jupyter": {
38
- "source_hidden": true
39
- }
40
- },
41
- "outputs": [],
42
- "source": [
43
- "pip install datasets librosa soundfile"
44
- ]
45
- },
46
- {
47
- "cell_type": "markdown",
48
- "id": "016d7646-bcca-4422-8b28-9d12d4b86c8f",
49
- "metadata": {},
50
- "source": [
51
- "### 2. Téléchargement et formatage du dataset"
52
- ]
53
- },
54
- {
55
- "cell_type": "code",
56
- "execution_count": null,
57
- "id": "da273973-05ee-4de5-830e-34d7f2220353",
58
- "metadata": {},
59
- "outputs": [],
60
- "source": [
61
- "from datasets import load_dataset\n",
62
- "from pathlib import Path\n",
63
- "from collections import OrderedDict\n",
64
- "from tqdm import tqdm\n",
65
- "import shutil\n",
66
- "import os\n",
67
- "\n",
68
- "dataset_write_base = \"/opt/marcel-c3/workdir/zqdb1553/jupyter/data_speechbrain/\"\n",
69
- "cache_dir = \"/opt/marcel-c3/workdir/zqdb1553/jupyter/data_huggingface/\"\n",
70
- "\n",
71
- "if os.path.isdir(cache_dir):\n",
72
- " print(\"rm -rf \"+cache_dir)\n",
73
- " os.system(\"rm -rf \"+cache_dir)\n",
74
- "\n",
75
- "if os.path.isdir(dataset_write_base):\n",
76
- " print(\"rm -rf \"+dataset_write_base)\n",
77
- " os.system(\"rm -rf \"+dataset_write_base)\n",
78
- "\n",
79
- "# **************************************\n",
80
- "# choix des langues à extraire de FLEURS\n",
81
- "# **************************************\n",
82
- "lang_dict = OrderedDict([\n",
83
- " #(\"Afrikaans\",\"af_za\"),\n",
84
- " #(\"Amharic\", \"am_et\"),\n",
85
- " #(\"Fula\", \"ff_sn\"),\n",
86
- " #(\"Ganda\", \"lg_ug\"),\n",
87
- " #(\"Hausa\", \"ha_ng\"),\n",
88
- " #(\"Igbo\", \"ig_ng\"),\n",
89
- " #(\"Kamba\", \"kam_ke\"),\n",
90
- " #(\"Lingala\", \"ln_cd\"),\n",
91
- " #(\"Luo\", \"luo_ke\"),\n",
92
- " #(\"Northern-Sotho\", \"nso_za\"),\n",
93
- " #(\"Nyanja\", \"ny_mw\"),\n",
94
- " #(\"Oromo\", \"om_et\"),\n",
95
- " #(\"Shona\", \"sn_zw\"),\n",
96
- " #(\"Somali\", \"so_so\"),\n",
97
- " (\"Swahili\", \"sw_ke\"),\n",
98
- " #(\"Umbundu\", \"umb_ao\"),\n",
99
- " #(\"Wolof\", \"wo_sn\"), \n",
100
- " #(\"Xhosa\", \"xh_za\"), \n",
101
- " #(\"Yoruba\", \"yo_ng\"), \n",
102
- " #(\"Zulu\", \"zu_za\")\n",
103
- " ])\n",
104
- "\n",
105
- "# ********************************\n",
106
- "# choix des sous-parties à traiter\n",
107
- "# ********************************\n",
108
- "datasets = [\"train\",\"test\",\"validation\"]\n",
109
- "\n",
110
- "for lang in lang_dict:\n",
111
- " print(\"Prepare --->\", lang)\n",
112
- " \n",
113
- " # ********************************\n",
114
- " # Download FLEURS from huggingface\n",
115
- " # ********************************\n",
116
- " fleurs_asr = load_dataset(\"google/fleurs\", lang_dict[lang],cache_dir=cache_dir, trust_remote_code=True)\n",
117
- "\n",
118
- " for subparts in datasets:\n",
119
- " \n",
120
- " used_ID = []\n",
121
- " Path(dataset_write_base+\"/\"+lang+\"/wavs/\"+subparts).mkdir(parents=True, exist_ok=True)\n",
122
- " \n",
123
- " # csv header\n",
124
- " f = open(dataset_write_base+\"/\"+lang+\"/\"+subparts+\".csv\", \"w\")\n",
125
- " f.write(\"ID,duration,wav,spk_id,wrd\\n\")\n",
126
- "\n",
127
- " for uid in tqdm(range(len(fleurs_asr[subparts]))):\n",
128
- "\n",
129
- " # ***************\n",
130
- " # format CSV line\n",
131
- " # ***************\n",
132
- " text_id = lang+\"_\"+str(fleurs_asr[subparts][uid][\"id\"])\n",
133
- " \n",
134
- " # some ID are duplicated (same speaker, same transcription BUT different recording)\n",
135
- " while(text_id in used_ID):\n",
136
- " text_id += \"_bis\"\n",
137
- " used_ID.append(text_id)\n",
138
- "\n",
139
- " duration = \"{:.3f}\".format(round(float(fleurs_asr[subparts][uid][\"num_samples\"])/float(fleurs_asr[subparts][uid][\"audio\"][\"sampling_rate\"]),3))\n",
140
- " wav_path = \"/\".join([dataset_write_base, lang, \"wavs\",subparts, fleurs_asr[subparts][uid][\"audio\"][\"path\"].split('/')[-1]])\n",
141
- " spk_id = \"spk_\" + text_id\n",
142
- " # AC : \"pseudo-normalisation\" de cas marginaux -- TODO mieux\n",
143
- " wrd = fleurs_asr[subparts][uid][\"transcription\"].replace(',','').replace('$',' $ ').replace('\"','').replace('”','').replace(' ',' ')\n",
144
- "\n",
145
- " # **************\n",
146
- " # write CSV line\n",
147
- " # **************\n",
148
- " f.write(text_id+\",\"+duration+\",\"+wav_path+\",\"+spk_id+\",\"+wrd+\"\\n\") \n",
149
- "\n",
150
- " # *******************\n",
151
- " # Move wav from cache\n",
152
- " # *******************\n",
153
- " previous_path = \"/\".join(fleurs_asr[subparts][uid][\"path\"].split('/')[:-1]) + \"/\" + fleurs_asr[subparts][uid][\"audio\"][\"path\"]\n",
154
- " new_path = \"/\".join([dataset_write_base,lang,\"wavs\",subparts,fleurs_asr[subparts][uid][\"audio\"][\"path\"].split('/')[-1]])\n",
155
- " shutil.move(previous_path,new_path)\n",
156
- " \n",
157
- " f.close()\n",
158
- " print(\"--->\", lang, \"done\")"
159
- ]
160
- },
161
- {
162
- "cell_type": "markdown",
163
- "id": "4c32e369-f0f9-4695-8c9a-aa3a9de7bf7b",
164
- "metadata": {},
165
- "source": [
166
- "# Recette ASR"
167
- ]
168
- },
169
- {
170
- "cell_type": "markdown",
171
- "id": "77fb2c55-3f8c-4f34-81f0-ad48a632e010",
172
- "metadata": {
173
- "jp-MarkdownHeadingCollapsed": true
174
- },
175
- "source": [
176
- "## 1. Installation des dépendances"
177
- ]
178
- },
179
- {
180
- "cell_type": "code",
181
- "execution_count": null,
182
- "id": "fbe25635-e765-480c-8416-c48a31ee6140",
183
- "metadata": {},
184
- "outputs": [],
185
- "source": [
186
- "pip install torch==2.2.2 torchaudio==2.2.2 torchvision==0.17.2 speechbrain transformers jdc"
187
- ]
188
- },
189
- {
190
- "cell_type": "markdown",
191
- "id": "6acf1f8c-2cf3-4c9c-8a45-e2580ecbee27",
192
- "metadata": {},
193
- "source": [
194
- "## 2. Mise en place de la recette Speechbrain -- class Brain"
195
- ]
196
- },
197
- {
198
- "cell_type": "markdown",
199
- "id": "d5e8884d-3542-40ff-a454-597078fcf97c",
200
- "metadata": {},
201
- "source": [
202
- "### 2.1 Imports & logger"
203
- ]
204
- },
205
- {
206
- "cell_type": "code",
207
- "execution_count": null,
208
- "id": "6c677f9f-6abe-423f-b4dd-fdf5ded357cd",
209
- "metadata": {},
210
- "outputs": [],
211
- "source": [
212
- "import logging\n",
213
- "import os\n",
214
- "import sys\n",
215
- "from pathlib import Path\n",
216
- "\n",
217
- "import torch\n",
218
- "from hyperpyyaml import load_hyperpyyaml\n",
219
- "\n",
220
- "import speechbrain as sb\n",
221
- "from speechbrain.utils.distributed import if_main_process, run_on_main\n",
222
- "\n",
223
- "import jdc\n",
224
- "\n",
225
- "logger = logging.getLogger(__name__)"
226
- ]
227
- },
228
- {
229
- "cell_type": "markdown",
230
- "id": "9698bb92-16ad-4b61-8938-c74b62ee93b2",
231
- "metadata": {},
232
- "source": [
233
- "### 2.2 Création de notre classe héritant de la classe brain"
234
- ]
235
- },
236
- {
237
- "cell_type": "code",
238
- "execution_count": null,
239
- "id": "7c7cd624-6249-449b-8ee9-d4a73b7b3301",
240
- "metadata": {},
241
- "outputs": [],
242
- "source": [
243
- "# Define training procedure\n",
244
- "class MY_SSA_ASR(sb.Brain):\n",
245
- " print(\"\")\n",
246
- " # define here"
247
- ]
248
- },
249
- {
250
- "cell_type": "markdown",
251
- "id": "ecf31c9c-15dd-4428-aa10-b3cc5e127f0d",
252
- "metadata": {},
253
- "source": [
254
- "### 2.3 Définition de la fonction forward "
255
- ]
256
- },
257
- {
258
- "cell_type": "code",
259
- "execution_count": null,
260
- "id": "4368b488-b9d8-49ff-8ce3-78a12d46be83",
261
- "metadata": {},
262
- "outputs": [],
263
- "source": [
264
- "%%add_to MY_SSA_ASR\n",
265
- "def compute_forward(self, batch, stage):\n",
266
- " \"\"\"Forward computations from the waveform batches to the output probabilities.\"\"\"\n",
267
- " batch = batch.to(self.device)\n",
268
- " wavs, wav_lens = batch.sig\n",
269
- " wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)\n",
270
- "\n",
271
- " # Downsample the inputs if specified\n",
272
- " if hasattr(self.modules, \"downsampler\"):\n",
273
- " wavs = self.modules.downsampler(wavs)\n",
274
- "\n",
275
- " # Add waveform augmentation if specified.\n",
276
- " if stage == sb.Stage.TRAIN and hasattr(self.hparams, \"wav_augment\"):\n",
277
- " wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)\n",
278
- "\n",
279
- " # Forward pass\n",
280
- " feats = self.modules.hubert(wavs, wav_lens)\n",
281
- " x = self.modules.top_lin(feats)\n",
282
- "\n",
283
- " # Compute outputs\n",
284
- " logits = self.modules.ctc_lin(x)\n",
285
- " p_ctc = self.hparams.log_softmax(logits)\n",
286
- "\n",
287
- "\n",
288
- " p_tokens = None\n",
289
- " if stage == sb.Stage.VALID:\n",
290
- " p_tokens = sb.decoders.ctc_greedy_decode(p_ctc, wav_lens, blank_id=self.hparams.blank_index)\n",
291
- "\n",
292
- " elif stage == sb.Stage.TEST:\n",
293
- " p_tokens = test_searcher(p_ctc, wav_lens)\n",
294
- "\n",
295
- " candidates = []\n",
296
- " scores = []\n",
297
- "\n",
298
- " for batch in p_tokens:\n",
299
- " candidates.append([hyp.text for hyp in batch])\n",
300
- " scores.append([hyp.score for hyp in batch])\n",
301
- "\n",
302
- " if hasattr(self.hparams, \"rescorer\"):\n",
303
- " p_tokens, _ = self.hparams.rescorer.rescore(candidates, scores)\n",
304
- "\n",
305
- " return p_ctc, wav_lens, p_tokens\n"
306
- ]
307
- },
308
- {
309
- "cell_type": "markdown",
310
- "id": "f0052b79-5a27-4c4c-8601-7ab064e8c951",
311
- "metadata": {},
312
- "source": [
313
- "### 2.4 Définition de la fonction objectives"
314
- ]
315
- },
316
- {
317
- "cell_type": "code",
318
- "execution_count": null,
319
- "id": "3608aee8-c9c3-4e34-98bc-667513fa7f7b",
320
- "metadata": {},
321
- "outputs": [],
322
- "source": [
323
- "%%add_to MY_SSA_ASR\n",
324
- "def compute_objectives(self, predictions, batch, stage):\n",
325
- " \"\"\"Computes the loss (CTC+NLL) given predictions and targets.\"\"\"\n",
326
- "\n",
327
- " p_ctc, wav_lens, predicted_tokens = predictions\n",
328
- "\n",
329
- " ids = batch.id\n",
330
- " tokens, tokens_lens = batch.tokens\n",
331
- "\n",
332
- " # Labels must be extended if parallel augmentation or concatenated\n",
333
- " # augmentation was performed on the input (increasing the time dimension)\n",
334
- " if stage == sb.Stage.TRAIN and hasattr(self.hparams, \"wav_augment\"):\n",
335
- " (tokens, tokens_lens) = self.hparams.wav_augment.replicate_multiple_labels(tokens, tokens_lens)\n",
336
- "\n",
337
- "\n",
338
- "\n",
339
- " # Compute loss\n",
340
- " loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)\n",
341
- "\n",
342
- " if stage == sb.Stage.VALID:\n",
343
- " # Decode token terms to words\n",
344
- " predicted_words = [\"\".join(self.tokenizer.decode_ndim(utt_seq)).split(\" \") for utt_seq in predicted_tokens]\n",
345
- " \n",
346
- " elif stage == sb.Stage.TEST:\n",
347
- " predicted_words = [hyp[0].text.split(\" \") for hyp in predicted_tokens]\n",
348
- "\n",
349
- " if stage != sb.Stage.TRAIN:\n",
350
- " target_words = [wrd.split(\" \") for wrd in batch.wrd]\n",
351
- " self.wer_metric.append(ids, predicted_words, target_words)\n",
352
- " self.cer_metric.append(ids, predicted_words, target_words)\n",
353
- "\n",
354
- " return loss\n"
355
- ]
356
- },
357
- {
358
- "cell_type": "markdown",
359
- "id": "9a514c50-89ad-41cb-882a-23daf829a538",
360
- "metadata": {},
361
- "source": [
362
- "### 2.5 définition du comportement au début d'un \"stage\""
363
- ]
364
- },
365
- {
366
- "cell_type": "code",
367
- "execution_count": null,
368
- "id": "609814ce-3ef0-4818-a70f-cadc293c9dd2",
369
- "metadata": {},
370
- "outputs": [],
371
- "source": [
372
- "%%add_to MY_SSA_ASR\n",
373
- "# stage gestion\n",
374
- "def on_stage_start(self, stage, epoch):\n",
375
- " \"\"\"Gets called at the beginning of each epoch\"\"\"\n",
376
- " if stage != sb.Stage.TRAIN:\n",
377
- " self.cer_metric = self.hparams.cer_computer()\n",
378
- " self.wer_metric = self.hparams.error_rate_computer()\n",
379
- "\n",
380
- " if stage == sb.Stage.TEST:\n",
381
- " if hasattr(self.hparams, \"rescorer\"):\n",
382
- " self.hparams.rescorer.move_rescorers_to_device()\n",
383
- "\n"
384
- ]
385
- },
386
- {
387
- "cell_type": "markdown",
388
- "id": "55929209-c94a-4f8b-8f2e-9dd5d9de8be9",
389
- "metadata": {},
390
- "source": [
391
- "### 2.6 définition du comportement à la fin d'un \"stage\""
392
- ]
393
- },
394
- {
395
- "cell_type": "code",
396
- "execution_count": null,
397
- "id": "8f297542-10d5-47bf-9938-c141f5a99ab8",
398
- "metadata": {},
399
- "outputs": [],
400
- "source": [
401
- "%%add_to MY_SSA_ASR\n",
402
- "def on_stage_end(self, stage, stage_loss, epoch):\n",
403
- " \"\"\"Gets called at the end of an epoch.\"\"\"\n",
404
- " # Compute/store important stats\n",
405
- " stage_stats = {\"loss\": stage_loss}\n",
406
- " if stage == sb.Stage.TRAIN:\n",
407
- " self.train_stats = stage_stats\n",
408
- " else:\n",
409
- " stage_stats[\"CER\"] = self.cer_metric.summarize(\"error_rate\")\n",
410
- " stage_stats[\"WER\"] = self.wer_metric.summarize(\"error_rate\")\n",
411
- "\n",
412
- " # Perform end-of-iteration things, like annealing, logging, etc.\n",
413
- " if stage == sb.Stage.VALID:\n",
414
- " # *******************************\n",
415
- " # Anneal and update Learning Rate\n",
416
- " # *******************************\n",
417
- " old_lr_model, new_lr_model = self.hparams.lr_annealing_model(stage_stats[\"loss\"])\n",
418
- " old_lr_hubert, new_lr_hubert = self.hparams.lr_annealing_hubert(stage_stats[\"loss\"])\n",
419
- " sb.nnet.schedulers.update_learning_rate(self.model_optimizer, new_lr_model)\n",
420
- " sb.nnet.schedulers.update_learning_rate(self.hubert_optimizer, new_lr_hubert)\n",
421
- "\n",
422
- " # *****************\n",
423
- " # Logs informations\n",
424
- " # *****************\n",
425
- " self.hparams.train_logger.log_stats(stats_meta={\"epoch\": epoch, \"lr_model\": old_lr_model, \"lr_hubert\": old_lr_hubert}, train_stats=self.train_stats, valid_stats=stage_stats)\n",
426
- "\n",
427
- " # ***************\n",
428
- " # Save checkpoint\n",
429
- " # ***************\n",
430
- " self.checkpointer.save_and_keep_only(meta={\"WER\": stage_stats[\"WER\"]},min_keys=[\"WER\"])\n",
431
- "\n",
432
- " elif stage == sb.Stage.TEST:\n",
433
- " self.hparams.train_logger.log_stats(stats_meta={\"Epoch loaded\": self.hparams.epoch_counter.current},test_stats=stage_stats)\n",
434
- " if if_main_process():\n",
435
- " with open(self.hparams.test_wer_file, \"w\") as w:\n",
436
- " self.wer_metric.write_stats(w)\n"
437
- ]
438
- },
439
- {
440
- "cell_type": "markdown",
441
- "id": "0c656457-6b61-4316-8199-70021f92babf",
442
- "metadata": {},
443
- "source": [
444
- "### 2.7 définition de l'initialisation des optimizers"
445
- ]
446
- },
447
- {
448
- "cell_type": "code",
449
- "execution_count": null,
450
- "id": "da8d9cb5-c5ad-4e78-83d3-e129e138a741",
451
- "metadata": {},
452
- "outputs": [],
453
- "source": [
454
- "%%add_to MY_SSA_ASR\n",
455
- "def init_optimizers(self):\n",
456
- " \"Initializes the hubert optimizer and model optimizer\"\n",
457
- " self.hubert_optimizer = self.hparams.hubert_opt_class(self.modules.hubert.parameters())\n",
458
- " self.model_optimizer = self.hparams.model_opt_class(self.hparams.model.parameters())\n",
459
- "\n",
460
- " # save the optimizers in a dictionary\n",
461
- " # the key will be used in `freeze_optimizers()`\n",
462
- " self.optimizers_dict = {\"model_optimizer\": self.model_optimizer}\n",
463
- " if not self.hparams.freeze_hubert:\n",
464
- " self.optimizers_dict[\"hubert_optimizer\"] = self.hubert_optimizer\n",
465
- "\n",
466
- " if self.checkpointer is not None:\n",
467
- " self.checkpointer.add_recoverable(\"hubert_opt\", self.hubert_optimizer)\n",
468
- " self.checkpointer.add_recoverable(\"model_opt\", self.model_optimizer)\n"
469
- ]
470
- },
471
- {
472
- "cell_type": "markdown",
473
- "id": "cf2e730c-2faa-41f2-b98d-e5fbb2305cc2",
474
- "metadata": {},
475
- "source": [
476
- "## 3 Définition de la lecture des datasets"
477
- ]
478
- },
479
- {
480
- "cell_type": "code",
481
- "execution_count": null,
482
- "id": "c5e667f7-6269-4b49-88bb-5e431762c8fe",
483
- "metadata": {},
484
- "outputs": [],
485
- "source": [
486
- "def dataio_prepare(hparams):\n",
487
- " \"\"\"This function prepares the datasets to be used in the brain class.\n",
488
- " It also defines the data processing pipeline through user-defined functions.\n",
489
- " \"\"\"\n",
490
- "\n",
491
- " # **************\n",
492
- " # Load CSV files\n",
493
- " # **************\n",
494
- " data_folder = hparams[\"data_folder\"]\n",
495
- "\n",
496
- " train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(csv_path=hparams[\"train_csv\"],replacements={\"data_root\": data_folder})\n",
497
- " # we sort training data to speed up training and get better results.\n",
498
- " train_data = train_data.filtered_sorted(sort_key=\"duration\")\n",
499
- " hparams[\"train_dataloader_opts\"][\"shuffle\"] = False # when sorting do not shuffle in dataloader ! otherwise is pointless\n",
500
- "\n",
501
- " valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(csv_path=hparams[\"valid_csv\"],replacements={\"data_root\": data_folder})\n",
502
- " valid_data = valid_data.filtered_sorted(sort_key=\"duration\")\n",
503
- "\n",
504
- " # test is separate\n",
505
- " test_datasets = {}\n",
506
- " for csv_file in hparams[\"test_csv\"]:\n",
507
- " name = Path(csv_file).stem\n",
508
- " test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(csv_path=csv_file, replacements={\"data_root\": data_folder})\n",
509
- " test_datasets[name] = test_datasets[name].filtered_sorted(sort_key=\"duration\")\n",
510
- "\n",
511
- " datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]\n",
512
- "\n",
513
- " # *************************\n",
514
- " # 2. Define audio pipeline:\n",
515
- " # *************************\n",
516
- " @sb.utils.data_pipeline.takes(\"wav\")\n",
517
- " @sb.utils.data_pipeline.provides(\"sig\")\n",
518
- " def audio_pipeline(wav):\n",
519
- " sig = sb.dataio.dataio.read_audio(wav)\n",
520
- " return sig\n",
521
- "\n",
522
- " sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)\n",
523
- "\n",
524
- " # ************************\n",
525
- " # 3. Define text pipeline:\n",
526
- " # ************************\n",
527
- " label_encoder = sb.dataio.encoder.CTCTextEncoder()\n",
528
- " \n",
529
- " @sb.utils.data_pipeline.takes(\"wrd\")\n",
530
- " @sb.utils.data_pipeline.provides(\"wrd\", \"char_list\", \"tokens_list\", \"tokens\")\n",
531
- " def text_pipeline(wrd):\n",
532
- " yield wrd\n",
533
- " char_list = list(wrd)\n",
534
- " yield char_list\n",
535
- " tokens_list = label_encoder.encode_sequence(char_list)\n",
536
- " yield tokens_list\n",
537
- " tokens = torch.LongTensor(tokens_list)\n",
538
- " yield tokens\n",
539
- "\n",
540
- " sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)\n",
541
- "\n",
542
- "\n",
543
- " # *******************************\n",
544
- " # 4. Create or load label encoder\n",
545
- " # *******************************\n",
546
- " lab_enc_file = os.path.join(hparams[\"save_folder\"], \"label_encoder.txt\")\n",
547
- " special_labels = {\"blank_label\": hparams[\"blank_index\"]}\n",
548
- " label_encoder.add_unk()\n",
549
- " label_encoder.load_or_create(path=lab_enc_file, from_didatasets=[train_data], output_key=\"char_list\", special_labels=special_labels, sequence_input=True)\n",
550
- "\n",
551
- " # **************\n",
552
- " # 5. Set output:\n",
553
- " # **************\n",
554
- " sb.dataio.dataset.set_output_keys(datasets,[\"id\", \"sig\", \"wrd\", \"char_list\", \"tokens\"],)\n",
555
- "\n",
556
- " return train_data, valid_data, test_datasets, label_encoder\n"
557
- ]
558
- },
559
- {
560
- "cell_type": "markdown",
561
- "id": "e97c4f20-6951-4d12-8e17-9eb818a52bb1",
562
- "metadata": {},
563
- "source": [
564
- "## 4. Utilisation de la recette Créée"
565
- ]
566
- },
567
- {
568
- "cell_type": "markdown",
569
- "id": "76b72148-6bd0-48bd-ad40-cb6f8bfd34c0",
570
- "metadata": {},
571
- "source": [
572
- "### 4.1 Préparation au lancement"
573
- ]
574
- },
575
- {
576
- "cell_type": "code",
577
- "execution_count": null,
578
- "id": "d47ec39a-5562-4a63-8243-656c9235b7a2",
579
- "metadata": {},
580
- "outputs": [],
581
- "source": [
582
- "hparams_file, run_opts, overrides = sb.parse_arguments([\"/opt/marcel-c3/workdir/zqdb1553/jupyter/ASR_FLEURS-swahili_hf.yaml\"])\n",
583
- "# create ddp_group with the right communication protocol\n",
584
- "sb.utils.distributed.ddp_init_group(run_opts)\n",
585
- "\n",
586
- "# ***********************************\n",
587
- "# Chargement du fichier de paramètres\n",
588
- "# ***********************************\n",
589
- "with open(hparams_file) as fin:\n",
590
- " hparams = load_hyperpyyaml(fin, overrides)\n",
591
- "\n",
592
- "# ***************************\n",
593
- "# Create experiment directory\n",
594
- "# ***************************\n",
595
- "sb.create_experiment_directory(experiment_directory=hparams[\"output_folder\"], hyperparams_to_save=hparams_file, overrides=overrides)\n",
596
- "\n",
597
- "# ***************************\n",
598
- "# Create the datasets objects\n",
599
- "# ***************************\n",
600
- "train_data, valid_data, test_datasets, label_encoder = dataio_prepare(hparams)\n",
601
- "\n",
602
- "# **********************\n",
603
- "# Trainer initialization\n",
604
- "# **********************\n",
605
- "asr_brain = MY_SSA_ASR(modules=hparams[\"modules\"], hparams=hparams, run_opts=run_opts, checkpointer=hparams[\"checkpointer\"])\n",
606
- "asr_brain.tokenizer = label_encoder"
607
- ]
608
- },
609
- {
610
- "cell_type": "markdown",
611
- "id": "62ae72eb-416c-4ef0-9348-d02bbc268fbd",
612
- "metadata": {},
613
- "source": [
614
- "### 4.2 Apprentissage du modèle"
615
- ]
616
- },
617
- {
618
- "cell_type": "code",
619
- "execution_count": null,
620
- "id": "d3dd30ee-89c0-40ea-a9d2-0e2b9d8c8686",
621
- "metadata": {},
622
- "outputs": [],
623
- "source": [
624
- "# ********\n",
625
- "# Training\n",
626
- "# ********\n",
627
- "asr_brain.fit(asr_brain.hparams.epoch_counter, \n",
628
- " train_data, valid_data, \n",
629
- " train_loader_kwargs=hparams[\"train_dataloader_opts\"], \n",
630
- " valid_loader_kwargs=hparams[\"valid_dataloader_opts\"],\n",
631
- " )\n",
632
- "\n"
633
- ]
634
- },
635
- {
636
- "cell_type": "markdown",
637
- "id": "1b55af4c-c544-45ff-8435-58226218328f",
638
- "metadata": {},
639
- "source": [
640
- "### 4.3 Test du Modèle"
641
- ]
642
- },
643
- {
644
- "cell_type": "code",
645
- "execution_count": null,
646
- "id": "9cef9011-1a3e-43a4-ab16-8cfb2b57dbd9",
647
- "metadata": {},
648
- "outputs": [],
649
- "source": [
650
- "# *******\n",
651
- "# Testing\n",
652
- "# *******\n",
653
- "if not os.path.exists(hparams[\"output_wer_folder\"]):\n",
654
- " os.makedirs(hparams[\"output_wer_folder\"])\n",
655
- "\n",
656
- "from speechbrain.decoders.ctc import CTCBeamSearcher\n",
657
- "\n",
658
- "ind2lab = label_encoder.ind2lab\n",
659
- "vocab_list = [ind2lab[x] for x in range(len(ind2lab))]\n",
660
- "test_searcher = CTCBeamSearcher(**hparams[\"test_beam_search\"], vocab_list=vocab_list)\n",
661
- "\n",
662
- "for k in test_datasets.keys(): # Allow multiple evaluation throught list of test sets\n",
663
- " asr_brain.hparams.test_wer_file = os.path.join(hparams[\"output_wer_folder\"], f\"wer_{k}.txt\")\n",
664
- " asr_brain.evaluate(test_datasets[k], test_loader_kwargs=hparams[\"test_dataloader_opts\"], min_key=\"WER\")\n"
665
- ]
666
- }
667
- ],
668
- "metadata": {
669
- "kernelspec": {
670
- "display_name": "Python 3 (ipykernel)",
671
- "language": "python",
672
- "name": "python3"
673
- },
674
- "language_info": {
675
- "codemirror_mode": {
676
- "name": "ipython",
677
- "version": 3
678
- },
679
- "file_extension": ".py",
680
- "mimetype": "text/x-python",
681
- "name": "python",
682
- "nbconvert_exporter": "python",
683
- "pygments_lexer": "ipython3",
684
- "version": "3.10.14"
685
- }
686
- },
687
- "nbformat": 4,
688
- "nbformat_minor": 5
689
- }