Antoine-caubriere
commited on
Commit
•
fa7be55
1
Parent(s):
212bcf5
Upload SB_ASR_FLEURS_finetuning.ipynb
Browse files- SB_ASR_FLEURS_finetuning.ipynb +689 -0
SB_ASR_FLEURS_finetuning.ipynb
ADDED
@@ -0,0 +1,689 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "49b85514-0fb6-49c6-be76-259bfeb638c6",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"# Introduction\n",
|
9 |
+
"N'hésitez pas à nous contacter en cas de questions : antoine.caubriere@orange.com & elodie.gauthier@orange.com\n",
|
10 |
+
"\n",
|
11 |
+
"Pensez à modifier l'ensemble des PATH dans le fichier de configuration ASR_FLEURSswahili_hf.yaml et dans le code python ci-dessous (PATH_TO_YOUR_FOLDER).\n",
|
12 |
+
"\n",
|
13 |
+
"Dans le cas d'un changement de corpus (autre sous partie de FLEURS / vos propres jeux de données), pensez à modifier la taille de la couche de sortie du modèle : ASR_swahili_hf.yaml/output_neurons\n"
|
14 |
+
]
|
15 |
+
},
|
16 |
+
{
|
17 |
+
"cell_type": "markdown",
|
18 |
+
"id": "e62faa86-911a-48ce-82bc-8a34e13ffbc4",
|
19 |
+
"metadata": {},
|
20 |
+
"source": [
|
21 |
+
"# Préparation des données FLEURS"
|
22 |
+
]
|
23 |
+
},
|
24 |
+
{
|
25 |
+
"cell_type": "markdown",
|
26 |
+
"id": "c6ccf4a5-cad1-4632-8954-f4e454ff3540",
|
27 |
+
"metadata": {},
|
28 |
+
"source": [
|
29 |
+
"### 1. Installation des dépendances"
|
30 |
+
]
|
31 |
+
},
|
32 |
+
{
|
33 |
+
"cell_type": "code",
|
34 |
+
"execution_count": null,
|
35 |
+
"id": "7bb8b44e-826f-4f13-b128-eebbd18dedc5",
|
36 |
+
"metadata": {
|
37 |
+
"jupyter": {
|
38 |
+
"source_hidden": true
|
39 |
+
}
|
40 |
+
},
|
41 |
+
"outputs": [],
|
42 |
+
"source": [
|
43 |
+
"pip install datasets librosa soundfile"
|
44 |
+
]
|
45 |
+
},
|
46 |
+
{
|
47 |
+
"cell_type": "markdown",
|
48 |
+
"id": "016d7646-bcca-4422-8b28-9d12d4b86c8f",
|
49 |
+
"metadata": {},
|
50 |
+
"source": [
|
51 |
+
"### 2. Téléchargement et formatage du dataset"
|
52 |
+
]
|
53 |
+
},
|
54 |
+
{
|
55 |
+
"cell_type": "code",
|
56 |
+
"execution_count": null,
|
57 |
+
"id": "da273973-05ee-4de5-830e-34d7f2220353",
|
58 |
+
"metadata": {},
|
59 |
+
"outputs": [],
|
60 |
+
"source": [
|
61 |
+
"from datasets import load_dataset\n",
|
62 |
+
"from pathlib import Path\n",
|
63 |
+
"from collections import OrderedDict\n",
|
64 |
+
"from tqdm import tqdm\n",
|
65 |
+
"import shutil\n",
|
66 |
+
"import os\n",
|
67 |
+
"\n",
|
68 |
+
"dataset_write_base = \"/opt/marcel-c3/workdir/zqdb1553/jupyter/data_speechbrain/\"\n",
|
69 |
+
"cache_dir = \"/opt/marcel-c3/workdir/zqdb1553/jupyter/data_huggingface/\"\n",
|
70 |
+
"\n",
|
71 |
+
"if os.path.isdir(cache_dir):\n",
|
72 |
+
" print(\"rm -rf \"+cache_dir)\n",
|
73 |
+
" os.system(\"rm -rf \"+cache_dir)\n",
|
74 |
+
"\n",
|
75 |
+
"if os.path.isdir(dataset_write_base):\n",
|
76 |
+
" print(\"rm -rf \"+dataset_write_base)\n",
|
77 |
+
" os.system(\"rm -rf \"+dataset_write_base)\n",
|
78 |
+
"\n",
|
79 |
+
"# **************************************\n",
|
80 |
+
"# choix des langues à extraire de FLEURS\n",
|
81 |
+
"# **************************************\n",
|
82 |
+
"lang_dict = OrderedDict([\n",
|
83 |
+
" #(\"Afrikaans\",\"af_za\"),\n",
|
84 |
+
" #(\"Amharic\", \"am_et\"),\n",
|
85 |
+
" #(\"Fula\", \"ff_sn\"),\n",
|
86 |
+
" #(\"Ganda\", \"lg_ug\"),\n",
|
87 |
+
" #(\"Hausa\", \"ha_ng\"),\n",
|
88 |
+
" #(\"Igbo\", \"ig_ng\"),\n",
|
89 |
+
" #(\"Kamba\", \"kam_ke\"),\n",
|
90 |
+
" #(\"Lingala\", \"ln_cd\"),\n",
|
91 |
+
" #(\"Luo\", \"luo_ke\"),\n",
|
92 |
+
" #(\"Northern-Sotho\", \"nso_za\"),\n",
|
93 |
+
" #(\"Nyanja\", \"ny_mw\"),\n",
|
94 |
+
" #(\"Oromo\", \"om_et\"),\n",
|
95 |
+
" #(\"Shona\", \"sn_zw\"),\n",
|
96 |
+
" #(\"Somali\", \"so_so\"),\n",
|
97 |
+
" (\"Swahili\", \"sw_ke\"),\n",
|
98 |
+
" #(\"Umbundu\", \"umb_ao\"),\n",
|
99 |
+
" #(\"Wolof\", \"wo_sn\"), \n",
|
100 |
+
" #(\"Xhosa\", \"xh_za\"), \n",
|
101 |
+
" #(\"Yoruba\", \"yo_ng\"), \n",
|
102 |
+
" #(\"Zulu\", \"zu_za\")\n",
|
103 |
+
" ])\n",
|
104 |
+
"\n",
|
105 |
+
"# ********************************\n",
|
106 |
+
"# choix des sous-parties à traiter\n",
|
107 |
+
"# ********************************\n",
|
108 |
+
"datasets = [\"train\",\"test\",\"validation\"]\n",
|
109 |
+
"\n",
|
110 |
+
"for lang in lang_dict:\n",
|
111 |
+
" print(\"Prepare --->\", lang)\n",
|
112 |
+
" \n",
|
113 |
+
" # ********************************\n",
|
114 |
+
" # Download FLEURS from huggingface\n",
|
115 |
+
" # ********************************\n",
|
116 |
+
" fleurs_asr = load_dataset(\"google/fleurs\", lang_dict[lang],cache_dir=cache_dir, trust_remote_code=True)\n",
|
117 |
+
"\n",
|
118 |
+
" for subparts in datasets:\n",
|
119 |
+
" \n",
|
120 |
+
" used_ID = []\n",
|
121 |
+
" Path(dataset_write_base+\"/\"+lang+\"/wavs/\"+subparts).mkdir(parents=True, exist_ok=True)\n",
|
122 |
+
" \n",
|
123 |
+
" # csv header\n",
|
124 |
+
" f = open(dataset_write_base+\"/\"+lang+\"/\"+subparts+\".csv\", \"w\")\n",
|
125 |
+
" f.write(\"ID,duration,wav,spk_id,wrd\\n\")\n",
|
126 |
+
"\n",
|
127 |
+
" for uid in tqdm(range(len(fleurs_asr[subparts]))):\n",
|
128 |
+
"\n",
|
129 |
+
" # ***************\n",
|
130 |
+
" # format CSV line\n",
|
131 |
+
" # ***************\n",
|
132 |
+
" text_id = lang+\"_\"+str(fleurs_asr[subparts][uid][\"id\"])\n",
|
133 |
+
" \n",
|
134 |
+
" # some ID are duplicated (same speaker, same transcription BUT different recording)\n",
|
135 |
+
" while(text_id in used_ID):\n",
|
136 |
+
" text_id += \"_bis\"\n",
|
137 |
+
" used_ID.append(text_id)\n",
|
138 |
+
"\n",
|
139 |
+
" duration = \"{:.3f}\".format(round(float(fleurs_asr[subparts][uid][\"num_samples\"])/float(fleurs_asr[subparts][uid][\"audio\"][\"sampling_rate\"]),3))\n",
|
140 |
+
" wav_path = \"/\".join([dataset_write_base, lang, \"wavs\",subparts, fleurs_asr[subparts][uid][\"audio\"][\"path\"].split('/')[-1]])\n",
|
141 |
+
" spk_id = \"spk_\" + text_id\n",
|
142 |
+
" # AC : \"pseudo-normalisation\" de cas marginaux -- TODO mieux\n",
|
143 |
+
" wrd = fleurs_asr[subparts][uid][\"transcription\"].replace(',','').replace('$',' $ ').replace('\"','').replace('”','').replace(' ',' ')\n",
|
144 |
+
"\n",
|
145 |
+
" # **************\n",
|
146 |
+
" # write CSV line\n",
|
147 |
+
" # **************\n",
|
148 |
+
" f.write(text_id+\",\"+duration+\",\"+wav_path+\",\"+spk_id+\",\"+wrd+\"\\n\") \n",
|
149 |
+
"\n",
|
150 |
+
" # *******************\n",
|
151 |
+
" # Move wav from cache\n",
|
152 |
+
" # *******************\n",
|
153 |
+
" previous_path = \"/\".join(fleurs_asr[subparts][uid][\"path\"].split('/')[:-1]) + \"/\" + fleurs_asr[subparts][uid][\"audio\"][\"path\"]\n",
|
154 |
+
" new_path = \"/\".join([dataset_write_base,lang,\"wavs\",subparts,fleurs_asr[subparts][uid][\"audio\"][\"path\"].split('/')[-1]])\n",
|
155 |
+
" shutil.move(previous_path,new_path)\n",
|
156 |
+
" \n",
|
157 |
+
" f.close()\n",
|
158 |
+
" print(\"--->\", lang, \"done\")"
|
159 |
+
]
|
160 |
+
},
|
161 |
+
{
|
162 |
+
"cell_type": "markdown",
|
163 |
+
"id": "4c32e369-f0f9-4695-8c9a-aa3a9de7bf7b",
|
164 |
+
"metadata": {},
|
165 |
+
"source": [
|
166 |
+
"# Recette ASR"
|
167 |
+
]
|
168 |
+
},
|
169 |
+
{
|
170 |
+
"cell_type": "markdown",
|
171 |
+
"id": "77fb2c55-3f8c-4f34-81f0-ad48a632e010",
|
172 |
+
"metadata": {
|
173 |
+
"jp-MarkdownHeadingCollapsed": true
|
174 |
+
},
|
175 |
+
"source": [
|
176 |
+
"## 1. Installation des dépendances"
|
177 |
+
]
|
178 |
+
},
|
179 |
+
{
|
180 |
+
"cell_type": "code",
|
181 |
+
"execution_count": null,
|
182 |
+
"id": "fbe25635-e765-480c-8416-c48a31ee6140",
|
183 |
+
"metadata": {},
|
184 |
+
"outputs": [],
|
185 |
+
"source": [
|
186 |
+
"pip install torch==2.2.2 torchaudio==2.2.2 torchvision==0.17.2 speechbrain transformers jdc"
|
187 |
+
]
|
188 |
+
},
|
189 |
+
{
|
190 |
+
"cell_type": "markdown",
|
191 |
+
"id": "6acf1f8c-2cf3-4c9c-8a45-e2580ecbee27",
|
192 |
+
"metadata": {},
|
193 |
+
"source": [
|
194 |
+
"## 2. Mise en place de la recette Speechbrain -- class Brain"
|
195 |
+
]
|
196 |
+
},
|
197 |
+
{
|
198 |
+
"cell_type": "markdown",
|
199 |
+
"id": "d5e8884d-3542-40ff-a454-597078fcf97c",
|
200 |
+
"metadata": {},
|
201 |
+
"source": [
|
202 |
+
"### 2.1 Imports & logger"
|
203 |
+
]
|
204 |
+
},
|
205 |
+
{
|
206 |
+
"cell_type": "code",
|
207 |
+
"execution_count": null,
|
208 |
+
"id": "6c677f9f-6abe-423f-b4dd-fdf5ded357cd",
|
209 |
+
"metadata": {},
|
210 |
+
"outputs": [],
|
211 |
+
"source": [
|
212 |
+
"import logging\n",
|
213 |
+
"import os\n",
|
214 |
+
"import sys\n",
|
215 |
+
"from pathlib import Path\n",
|
216 |
+
"\n",
|
217 |
+
"import torch\n",
|
218 |
+
"from hyperpyyaml import load_hyperpyyaml\n",
|
219 |
+
"\n",
|
220 |
+
"import speechbrain as sb\n",
|
221 |
+
"from speechbrain.utils.distributed import if_main_process, run_on_main\n",
|
222 |
+
"\n",
|
223 |
+
"import jdc\n",
|
224 |
+
"\n",
|
225 |
+
"logger = logging.getLogger(__name__)"
|
226 |
+
]
|
227 |
+
},
|
228 |
+
{
|
229 |
+
"cell_type": "markdown",
|
230 |
+
"id": "9698bb92-16ad-4b61-8938-c74b62ee93b2",
|
231 |
+
"metadata": {},
|
232 |
+
"source": [
|
233 |
+
"### 2.2 Création de notre classe héritant de la classe brain"
|
234 |
+
]
|
235 |
+
},
|
236 |
+
{
|
237 |
+
"cell_type": "code",
|
238 |
+
"execution_count": null,
|
239 |
+
"id": "7c7cd624-6249-449b-8ee9-d4a73b7b3301",
|
240 |
+
"metadata": {},
|
241 |
+
"outputs": [],
|
242 |
+
"source": [
|
243 |
+
"# Define training procedure\n",
|
244 |
+
"class MY_SSA_ASR(sb.Brain):\n",
|
245 |
+
" print(\"\")\n",
|
246 |
+
" # define here"
|
247 |
+
]
|
248 |
+
},
|
249 |
+
{
|
250 |
+
"cell_type": "markdown",
|
251 |
+
"id": "ecf31c9c-15dd-4428-aa10-b3cc5e127f0d",
|
252 |
+
"metadata": {},
|
253 |
+
"source": [
|
254 |
+
"### 2.3 Définition de la fonction forward "
|
255 |
+
]
|
256 |
+
},
|
257 |
+
{
|
258 |
+
"cell_type": "code",
|
259 |
+
"execution_count": null,
|
260 |
+
"id": "4368b488-b9d8-49ff-8ce3-78a12d46be83",
|
261 |
+
"metadata": {},
|
262 |
+
"outputs": [],
|
263 |
+
"source": [
|
264 |
+
"%%add_to MY_SSA_ASR\n",
|
265 |
+
"def compute_forward(self, batch, stage):\n",
|
266 |
+
" \"\"\"Forward computations from the waveform batches to the output probabilities.\"\"\"\n",
|
267 |
+
" batch = batch.to(self.device)\n",
|
268 |
+
" wavs, wav_lens = batch.sig\n",
|
269 |
+
" wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)\n",
|
270 |
+
"\n",
|
271 |
+
" # Downsample the inputs if specified\n",
|
272 |
+
" if hasattr(self.modules, \"downsampler\"):\n",
|
273 |
+
" wavs = self.modules.downsampler(wavs)\n",
|
274 |
+
"\n",
|
275 |
+
" # Add waveform augmentation if specified.\n",
|
276 |
+
" if stage == sb.Stage.TRAIN and hasattr(self.hparams, \"wav_augment\"):\n",
|
277 |
+
" wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)\n",
|
278 |
+
"\n",
|
279 |
+
" # Forward pass\n",
|
280 |
+
" feats = self.modules.hubert(wavs, wav_lens)\n",
|
281 |
+
" x = self.modules.top_lin(feats)\n",
|
282 |
+
"\n",
|
283 |
+
" # Compute outputs\n",
|
284 |
+
" logits = self.modules.ctc_lin(x)\n",
|
285 |
+
" p_ctc = self.hparams.log_softmax(logits)\n",
|
286 |
+
"\n",
|
287 |
+
"\n",
|
288 |
+
" p_tokens = None\n",
|
289 |
+
" if stage == sb.Stage.VALID:\n",
|
290 |
+
" p_tokens = sb.decoders.ctc_greedy_decode(p_ctc, wav_lens, blank_id=self.hparams.blank_index)\n",
|
291 |
+
"\n",
|
292 |
+
" elif stage == sb.Stage.TEST:\n",
|
293 |
+
" p_tokens = test_searcher(p_ctc, wav_lens)\n",
|
294 |
+
"\n",
|
295 |
+
" candidates = []\n",
|
296 |
+
" scores = []\n",
|
297 |
+
"\n",
|
298 |
+
" for batch in p_tokens:\n",
|
299 |
+
" candidates.append([hyp.text for hyp in batch])\n",
|
300 |
+
" scores.append([hyp.score for hyp in batch])\n",
|
301 |
+
"\n",
|
302 |
+
" if hasattr(self.hparams, \"rescorer\"):\n",
|
303 |
+
" p_tokens, _ = self.hparams.rescorer.rescore(candidates, scores)\n",
|
304 |
+
"\n",
|
305 |
+
" return p_ctc, wav_lens, p_tokens\n"
|
306 |
+
]
|
307 |
+
},
|
308 |
+
{
|
309 |
+
"cell_type": "markdown",
|
310 |
+
"id": "f0052b79-5a27-4c4c-8601-7ab064e8c951",
|
311 |
+
"metadata": {},
|
312 |
+
"source": [
|
313 |
+
"### 2.4 Définition de la fonction objectives"
|
314 |
+
]
|
315 |
+
},
|
316 |
+
{
|
317 |
+
"cell_type": "code",
|
318 |
+
"execution_count": null,
|
319 |
+
"id": "3608aee8-c9c3-4e34-98bc-667513fa7f7b",
|
320 |
+
"metadata": {},
|
321 |
+
"outputs": [],
|
322 |
+
"source": [
|
323 |
+
"%%add_to MY_SSA_ASR\n",
|
324 |
+
"def compute_objectives(self, predictions, batch, stage):\n",
|
325 |
+
" \"\"\"Computes the loss (CTC+NLL) given predictions and targets.\"\"\"\n",
|
326 |
+
"\n",
|
327 |
+
" p_ctc, wav_lens, predicted_tokens = predictions\n",
|
328 |
+
"\n",
|
329 |
+
" ids = batch.id\n",
|
330 |
+
" tokens, tokens_lens = batch.tokens\n",
|
331 |
+
"\n",
|
332 |
+
" # Labels must be extended if parallel augmentation or concatenated\n",
|
333 |
+
" # augmentation was performed on the input (increasing the time dimension)\n",
|
334 |
+
" if stage == sb.Stage.TRAIN and hasattr(self.hparams, \"wav_augment\"):\n",
|
335 |
+
" (tokens, tokens_lens) = self.hparams.wav_augment.replicate_multiple_labels(tokens, tokens_lens)\n",
|
336 |
+
"\n",
|
337 |
+
"\n",
|
338 |
+
"\n",
|
339 |
+
" # Compute loss\n",
|
340 |
+
" loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)\n",
|
341 |
+
"\n",
|
342 |
+
" if stage == sb.Stage.VALID:\n",
|
343 |
+
" # Decode token terms to words\n",
|
344 |
+
" predicted_words = [\"\".join(self.tokenizer.decode_ndim(utt_seq)).split(\" \") for utt_seq in predicted_tokens]\n",
|
345 |
+
" \n",
|
346 |
+
" elif stage == sb.Stage.TEST:\n",
|
347 |
+
" predicted_words = [hyp[0].text.split(\" \") for hyp in predicted_tokens]\n",
|
348 |
+
"\n",
|
349 |
+
" if stage != sb.Stage.TRAIN:\n",
|
350 |
+
" target_words = [wrd.split(\" \") for wrd in batch.wrd]\n",
|
351 |
+
" self.wer_metric.append(ids, predicted_words, target_words)\n",
|
352 |
+
" self.cer_metric.append(ids, predicted_words, target_words)\n",
|
353 |
+
"\n",
|
354 |
+
" return loss\n"
|
355 |
+
]
|
356 |
+
},
|
357 |
+
{
|
358 |
+
"cell_type": "markdown",
|
359 |
+
"id": "9a514c50-89ad-41cb-882a-23daf829a538",
|
360 |
+
"metadata": {},
|
361 |
+
"source": [
|
362 |
+
"### 2.5 définition du comportement au début d'un \"stage\""
|
363 |
+
]
|
364 |
+
},
|
365 |
+
{
|
366 |
+
"cell_type": "code",
|
367 |
+
"execution_count": null,
|
368 |
+
"id": "609814ce-3ef0-4818-a70f-cadc293c9dd2",
|
369 |
+
"metadata": {},
|
370 |
+
"outputs": [],
|
371 |
+
"source": [
|
372 |
+
"%%add_to MY_SSA_ASR\n",
|
373 |
+
"# stage gestion\n",
|
374 |
+
"def on_stage_start(self, stage, epoch):\n",
|
375 |
+
" \"\"\"Gets called at the beginning of each epoch\"\"\"\n",
|
376 |
+
" if stage != sb.Stage.TRAIN:\n",
|
377 |
+
" self.cer_metric = self.hparams.cer_computer()\n",
|
378 |
+
" self.wer_metric = self.hparams.error_rate_computer()\n",
|
379 |
+
"\n",
|
380 |
+
" if stage == sb.Stage.TEST:\n",
|
381 |
+
" if hasattr(self.hparams, \"rescorer\"):\n",
|
382 |
+
" self.hparams.rescorer.move_rescorers_to_device()\n",
|
383 |
+
"\n"
|
384 |
+
]
|
385 |
+
},
|
386 |
+
{
|
387 |
+
"cell_type": "markdown",
|
388 |
+
"id": "55929209-c94a-4f8b-8f2e-9dd5d9de8be9",
|
389 |
+
"metadata": {},
|
390 |
+
"source": [
|
391 |
+
"### 2.6 définition du comportement à la fin d'un \"stage\""
|
392 |
+
]
|
393 |
+
},
|
394 |
+
{
|
395 |
+
"cell_type": "code",
|
396 |
+
"execution_count": null,
|
397 |
+
"id": "8f297542-10d5-47bf-9938-c141f5a99ab8",
|
398 |
+
"metadata": {},
|
399 |
+
"outputs": [],
|
400 |
+
"source": [
|
401 |
+
"%%add_to MY_SSA_ASR\n",
|
402 |
+
"def on_stage_end(self, stage, stage_loss, epoch):\n",
|
403 |
+
" \"\"\"Gets called at the end of an epoch.\"\"\"\n",
|
404 |
+
" # Compute/store important stats\n",
|
405 |
+
" stage_stats = {\"loss\": stage_loss}\n",
|
406 |
+
" if stage == sb.Stage.TRAIN:\n",
|
407 |
+
" self.train_stats = stage_stats\n",
|
408 |
+
" else:\n",
|
409 |
+
" stage_stats[\"CER\"] = self.cer_metric.summarize(\"error_rate\")\n",
|
410 |
+
" stage_stats[\"WER\"] = self.wer_metric.summarize(\"error_rate\")\n",
|
411 |
+
"\n",
|
412 |
+
" # Perform end-of-iteration things, like annealing, logging, etc.\n",
|
413 |
+
" if stage == sb.Stage.VALID:\n",
|
414 |
+
" # *******************************\n",
|
415 |
+
" # Anneal and update Learning Rate\n",
|
416 |
+
" # *******************************\n",
|
417 |
+
" old_lr_model, new_lr_model = self.hparams.lr_annealing_model(stage_stats[\"loss\"])\n",
|
418 |
+
" old_lr_hubert, new_lr_hubert = self.hparams.lr_annealing_hubert(stage_stats[\"loss\"])\n",
|
419 |
+
" sb.nnet.schedulers.update_learning_rate(self.model_optimizer, new_lr_model)\n",
|
420 |
+
" sb.nnet.schedulers.update_learning_rate(self.hubert_optimizer, new_lr_hubert)\n",
|
421 |
+
"\n",
|
422 |
+
" # *****************\n",
|
423 |
+
" # Logs informations\n",
|
424 |
+
" # *****************\n",
|
425 |
+
" self.hparams.train_logger.log_stats(stats_meta={\"epoch\": epoch, \"lr_model\": old_lr_model, \"lr_hubert\": old_lr_hubert}, train_stats=self.train_stats, valid_stats=stage_stats)\n",
|
426 |
+
"\n",
|
427 |
+
" # ***************\n",
|
428 |
+
" # Save checkpoint\n",
|
429 |
+
" # ***************\n",
|
430 |
+
" self.checkpointer.save_and_keep_only(meta={\"WER\": stage_stats[\"WER\"]},min_keys=[\"WER\"])\n",
|
431 |
+
"\n",
|
432 |
+
" elif stage == sb.Stage.TEST:\n",
|
433 |
+
" self.hparams.train_logger.log_stats(stats_meta={\"Epoch loaded\": self.hparams.epoch_counter.current},test_stats=stage_stats)\n",
|
434 |
+
" if if_main_process():\n",
|
435 |
+
" with open(self.hparams.test_wer_file, \"w\") as w:\n",
|
436 |
+
" self.wer_metric.write_stats(w)\n"
|
437 |
+
]
|
438 |
+
},
|
439 |
+
{
|
440 |
+
"cell_type": "markdown",
|
441 |
+
"id": "0c656457-6b61-4316-8199-70021f92babf",
|
442 |
+
"metadata": {},
|
443 |
+
"source": [
|
444 |
+
"### 2.7 définition de l'initialisation des optimizers"
|
445 |
+
]
|
446 |
+
},
|
447 |
+
{
|
448 |
+
"cell_type": "code",
|
449 |
+
"execution_count": null,
|
450 |
+
"id": "da8d9cb5-c5ad-4e78-83d3-e129e138a741",
|
451 |
+
"metadata": {},
|
452 |
+
"outputs": [],
|
453 |
+
"source": [
|
454 |
+
"%%add_to MY_SSA_ASR\n",
|
455 |
+
"def init_optimizers(self):\n",
|
456 |
+
" \"Initializes the hubert optimizer and model optimizer\"\n",
|
457 |
+
" self.hubert_optimizer = self.hparams.hubert_opt_class(self.modules.hubert.parameters())\n",
|
458 |
+
" self.model_optimizer = self.hparams.model_opt_class(self.hparams.model.parameters())\n",
|
459 |
+
"\n",
|
460 |
+
" # save the optimizers in a dictionary\n",
|
461 |
+
" # the key will be used in `freeze_optimizers()`\n",
|
462 |
+
" self.optimizers_dict = {\"model_optimizer\": self.model_optimizer}\n",
|
463 |
+
" if not self.hparams.freeze_hubert:\n",
|
464 |
+
" self.optimizers_dict[\"hubert_optimizer\"] = self.hubert_optimizer\n",
|
465 |
+
"\n",
|
466 |
+
" if self.checkpointer is not None:\n",
|
467 |
+
" self.checkpointer.add_recoverable(\"hubert_opt\", self.hubert_optimizer)\n",
|
468 |
+
" self.checkpointer.add_recoverable(\"model_opt\", self.model_optimizer)\n"
|
469 |
+
]
|
470 |
+
},
|
471 |
+
{
|
472 |
+
"cell_type": "markdown",
|
473 |
+
"id": "cf2e730c-2faa-41f2-b98d-e5fbb2305cc2",
|
474 |
+
"metadata": {},
|
475 |
+
"source": [
|
476 |
+
"## 3 Définition de la lecture des datasets"
|
477 |
+
]
|
478 |
+
},
|
479 |
+
{
|
480 |
+
"cell_type": "code",
|
481 |
+
"execution_count": null,
|
482 |
+
"id": "c5e667f7-6269-4b49-88bb-5e431762c8fe",
|
483 |
+
"metadata": {},
|
484 |
+
"outputs": [],
|
485 |
+
"source": [
|
486 |
+
"def dataio_prepare(hparams):\n",
|
487 |
+
" \"\"\"This function prepares the datasets to be used in the brain class.\n",
|
488 |
+
" It also defines the data processing pipeline through user-defined functions.\n",
|
489 |
+
" \"\"\"\n",
|
490 |
+
"\n",
|
491 |
+
" # **************\n",
|
492 |
+
" # Load CSV files\n",
|
493 |
+
" # **************\n",
|
494 |
+
" data_folder = hparams[\"data_folder\"]\n",
|
495 |
+
"\n",
|
496 |
+
" train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(csv_path=hparams[\"train_csv\"],replacements={\"data_root\": data_folder})\n",
|
497 |
+
" # we sort training data to speed up training and get better results.\n",
|
498 |
+
" train_data = train_data.filtered_sorted(sort_key=\"duration\")\n",
|
499 |
+
" hparams[\"train_dataloader_opts\"][\"shuffle\"] = False # when sorting do not shuffle in dataloader ! otherwise is pointless\n",
|
500 |
+
"\n",
|
501 |
+
" valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(csv_path=hparams[\"valid_csv\"],replacements={\"data_root\": data_folder})\n",
|
502 |
+
" valid_data = valid_data.filtered_sorted(sort_key=\"duration\")\n",
|
503 |
+
"\n",
|
504 |
+
" # test is separate\n",
|
505 |
+
" test_datasets = {}\n",
|
506 |
+
" for csv_file in hparams[\"test_csv\"]:\n",
|
507 |
+
" name = Path(csv_file).stem\n",
|
508 |
+
" test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(csv_path=csv_file, replacements={\"data_root\": data_folder})\n",
|
509 |
+
" test_datasets[name] = test_datasets[name].filtered_sorted(sort_key=\"duration\")\n",
|
510 |
+
"\n",
|
511 |
+
" datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]\n",
|
512 |
+
"\n",
|
513 |
+
" # *************************\n",
|
514 |
+
" # 2. Define audio pipeline:\n",
|
515 |
+
" # *************************\n",
|
516 |
+
" @sb.utils.data_pipeline.takes(\"wav\")\n",
|
517 |
+
" @sb.utils.data_pipeline.provides(\"sig\")\n",
|
518 |
+
" def audio_pipeline(wav):\n",
|
519 |
+
" sig = sb.dataio.dataio.read_audio(wav)\n",
|
520 |
+
" return sig\n",
|
521 |
+
"\n",
|
522 |
+
" sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)\n",
|
523 |
+
"\n",
|
524 |
+
" # ************************\n",
|
525 |
+
" # 3. Define text pipeline:\n",
|
526 |
+
" # ************************\n",
|
527 |
+
" label_encoder = sb.dataio.encoder.CTCTextEncoder()\n",
|
528 |
+
" \n",
|
529 |
+
" @sb.utils.data_pipeline.takes(\"wrd\")\n",
|
530 |
+
" @sb.utils.data_pipeline.provides(\"wrd\", \"char_list\", \"tokens_list\", \"tokens\")\n",
|
531 |
+
" def text_pipeline(wrd):\n",
|
532 |
+
" yield wrd\n",
|
533 |
+
" char_list = list(wrd)\n",
|
534 |
+
" yield char_list\n",
|
535 |
+
" tokens_list = label_encoder.encode_sequence(char_list)\n",
|
536 |
+
" yield tokens_list\n",
|
537 |
+
" tokens = torch.LongTensor(tokens_list)\n",
|
538 |
+
" yield tokens\n",
|
539 |
+
"\n",
|
540 |
+
" sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)\n",
|
541 |
+
"\n",
|
542 |
+
"\n",
|
543 |
+
" # *******************************\n",
|
544 |
+
" # 4. Create or load label encoder\n",
|
545 |
+
" # *******************************\n",
|
546 |
+
" lab_enc_file = os.path.join(hparams[\"save_folder\"], \"label_encoder.txt\")\n",
|
547 |
+
" special_labels = {\"blank_label\": hparams[\"blank_index\"]}\n",
|
548 |
+
" label_encoder.add_unk()\n",
|
549 |
+
" label_encoder.load_or_create(path=lab_enc_file, from_didatasets=[train_data], output_key=\"char_list\", special_labels=special_labels, sequence_input=True)\n",
|
550 |
+
"\n",
|
551 |
+
" # **************\n",
|
552 |
+
" # 5. Set output:\n",
|
553 |
+
" # **************\n",
|
554 |
+
" sb.dataio.dataset.set_output_keys(datasets,[\"id\", \"sig\", \"wrd\", \"char_list\", \"tokens\"],)\n",
|
555 |
+
"\n",
|
556 |
+
" return train_data, valid_data, test_datasets, label_encoder\n"
|
557 |
+
]
|
558 |
+
},
|
559 |
+
{
|
560 |
+
"cell_type": "markdown",
|
561 |
+
"id": "e97c4f20-6951-4d12-8e17-9eb818a52bb1",
|
562 |
+
"metadata": {},
|
563 |
+
"source": [
|
564 |
+
"## 4. Utilisation de la recette Créée"
|
565 |
+
]
|
566 |
+
},
|
567 |
+
{
|
568 |
+
"cell_type": "markdown",
|
569 |
+
"id": "76b72148-6bd0-48bd-ad40-cb6f8bfd34c0",
|
570 |
+
"metadata": {},
|
571 |
+
"source": [
|
572 |
+
"### 4.1 Préparation au lancement"
|
573 |
+
]
|
574 |
+
},
|
575 |
+
{
|
576 |
+
"cell_type": "code",
|
577 |
+
"execution_count": null,
|
578 |
+
"id": "d47ec39a-5562-4a63-8243-656c9235b7a2",
|
579 |
+
"metadata": {},
|
580 |
+
"outputs": [],
|
581 |
+
"source": [
|
582 |
+
"hparams_file, run_opts, overrides = sb.parse_arguments([\"/opt/marcel-c3/workdir/zqdb1553/jupyter/ASR_FLEURS-swahili_hf.yaml\"])\n",
|
583 |
+
"# create ddp_group with the right communication protocol\n",
|
584 |
+
"sb.utils.distributed.ddp_init_group(run_opts)\n",
|
585 |
+
"\n",
|
586 |
+
"# ***********************************\n",
|
587 |
+
"# Chargement du fichier de paramètres\n",
|
588 |
+
"# ***********************************\n",
|
589 |
+
"with open(hparams_file) as fin:\n",
|
590 |
+
" hparams = load_hyperpyyaml(fin, overrides)\n",
|
591 |
+
"\n",
|
592 |
+
"# ***************************\n",
|
593 |
+
"# Create experiment directory\n",
|
594 |
+
"# ***************************\n",
|
595 |
+
"sb.create_experiment_directory(experiment_directory=hparams[\"output_folder\"], hyperparams_to_save=hparams_file, overrides=overrides)\n",
|
596 |
+
"\n",
|
597 |
+
"# ***************************\n",
|
598 |
+
"# Create the datasets objects\n",
|
599 |
+
"# ***************************\n",
|
600 |
+
"train_data, valid_data, test_datasets, label_encoder = dataio_prepare(hparams)\n",
|
601 |
+
"\n",
|
602 |
+
"# **********************\n",
|
603 |
+
"# Trainer initialization\n",
|
604 |
+
"# **********************\n",
|
605 |
+
"asr_brain = MY_SSA_ASR(modules=hparams[\"modules\"], hparams=hparams, run_opts=run_opts, checkpointer=hparams[\"checkpointer\"])\n",
|
606 |
+
"asr_brain.tokenizer = label_encoder"
|
607 |
+
]
|
608 |
+
},
|
609 |
+
{
|
610 |
+
"cell_type": "markdown",
|
611 |
+
"id": "62ae72eb-416c-4ef0-9348-d02bbc268fbd",
|
612 |
+
"metadata": {},
|
613 |
+
"source": [
|
614 |
+
"### 4.2 Apprentissage du modèle"
|
615 |
+
]
|
616 |
+
},
|
617 |
+
{
|
618 |
+
"cell_type": "code",
|
619 |
+
"execution_count": null,
|
620 |
+
"id": "d3dd30ee-89c0-40ea-a9d2-0e2b9d8c8686",
|
621 |
+
"metadata": {},
|
622 |
+
"outputs": [],
|
623 |
+
"source": [
|
624 |
+
"# ********\n",
|
625 |
+
"# Training\n",
|
626 |
+
"# ********\n",
|
627 |
+
"asr_brain.fit(asr_brain.hparams.epoch_counter, \n",
|
628 |
+
" train_data, valid_data, \n",
|
629 |
+
" train_loader_kwargs=hparams[\"train_dataloader_opts\"], \n",
|
630 |
+
" valid_loader_kwargs=hparams[\"valid_dataloader_opts\"],\n",
|
631 |
+
" )\n",
|
632 |
+
"\n"
|
633 |
+
]
|
634 |
+
},
|
635 |
+
{
|
636 |
+
"cell_type": "markdown",
|
637 |
+
"id": "1b55af4c-c544-45ff-8435-58226218328f",
|
638 |
+
"metadata": {},
|
639 |
+
"source": [
|
640 |
+
"### 4.3 Test du Modèle"
|
641 |
+
]
|
642 |
+
},
|
643 |
+
{
|
644 |
+
"cell_type": "code",
|
645 |
+
"execution_count": null,
|
646 |
+
"id": "9cef9011-1a3e-43a4-ab16-8cfb2b57dbd9",
|
647 |
+
"metadata": {},
|
648 |
+
"outputs": [],
|
649 |
+
"source": [
|
650 |
+
"# *******\n",
|
651 |
+
"# Testing\n",
|
652 |
+
"# *******\n",
|
653 |
+
"if not os.path.exists(hparams[\"output_wer_folder\"]):\n",
|
654 |
+
" os.makedirs(hparams[\"output_wer_folder\"])\n",
|
655 |
+
"\n",
|
656 |
+
"from speechbrain.decoders.ctc import CTCBeamSearcher\n",
|
657 |
+
"\n",
|
658 |
+
"ind2lab = label_encoder.ind2lab\n",
|
659 |
+
"vocab_list = [ind2lab[x] for x in range(len(ind2lab))]\n",
|
660 |
+
"test_searcher = CTCBeamSearcher(**hparams[\"test_beam_search\"], vocab_list=vocab_list)\n",
|
661 |
+
"\n",
|
662 |
+
"for k in test_datasets.keys(): # Allow multiple evaluation throught list of test sets\n",
|
663 |
+
" asr_brain.hparams.test_wer_file = os.path.join(hparams[\"output_wer_folder\"], f\"wer_{k}.txt\")\n",
|
664 |
+
" asr_brain.evaluate(test_datasets[k], test_loader_kwargs=hparams[\"test_dataloader_opts\"], min_key=\"WER\")\n"
|
665 |
+
]
|
666 |
+
}
|
667 |
+
],
|
668 |
+
"metadata": {
|
669 |
+
"kernelspec": {
|
670 |
+
"display_name": "Python 3 (ipykernel)",
|
671 |
+
"language": "python",
|
672 |
+
"name": "python3"
|
673 |
+
},
|
674 |
+
"language_info": {
|
675 |
+
"codemirror_mode": {
|
676 |
+
"name": "ipython",
|
677 |
+
"version": 3
|
678 |
+
},
|
679 |
+
"file_extension": ".py",
|
680 |
+
"mimetype": "text/x-python",
|
681 |
+
"name": "python",
|
682 |
+
"nbconvert_exporter": "python",
|
683 |
+
"pygments_lexer": "ipython3",
|
684 |
+
"version": "3.10.14"
|
685 |
+
}
|
686 |
+
},
|
687 |
+
"nbformat": 4,
|
688 |
+
"nbformat_minor": 5
|
689 |
+
}
|