{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "49b85514-0fb6-49c6-be76-259bfeb638c6",
   "metadata": {},
   "source": [
    "# Introduction\n",
    "N'hésitez pas à nous contacter en cas de questions : antoine.caubriere@orange.com & elodie.gauthier@orange.com\n",
    "\n",
    "Pensez à modifier l'ensemble des chemins (`PATH_TO_YOUR_FOLDER`) dans le fichier de configuration `ASR_FLEURS-swahili_hf.yaml` ainsi que dans le code Python ci-dessous.\n",
    "\n",
    "En cas de changement de corpus (autre sous-partie de FLEURS ou vos propres jeux de données), pensez à adapter la taille de la couche de sortie du modèle : paramètre `output_neurons` du fichier `ASR_FLEURS-swahili_hf.yaml`.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e62faa86-911a-48ce-82bc-8a34e13ffbc4",
   "metadata": {},
   "source": [
    "# Préparation des données FLEURS"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6ccf4a5-cad1-4632-8954-f4e454ff3540",
   "metadata": {},
   "source": [
    "### 1. Installation des dépendances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bb8b44e-826f-4f13-b128-eebbd18dedc5",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "%pip install datasets librosa soundfile"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "016d7646-bcca-4422-8b28-9d12d4b86c8f",
   "metadata": {},
   "source": [
    "### 2. Téléchargement et formatage du dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da273973-05ee-4de5-830e-34d7f2220353",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the selected FLEURS languages from Hugging Face and export them in the\n",
    "# SpeechBrain layout: <base>/<lang>/<split>.csv + <base>/<lang>/wavs/<split>/*.wav\n",
    "from collections import OrderedDict\n",
    "from pathlib import Path\n",
    "import os\n",
    "import shutil\n",
    "\n",
    "from datasets import load_dataset\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "dataset_write_base = \"PATH_TO_YOUR_FOLDER/data_speechbrain/\"\n",
    "cache_dir = \"PATH_TO_YOUR_FOLDER/data_huggingface/\"\n",
    "\n",
    "# Start from a clean state: drop any previous Hugging Face cache and SpeechBrain export.\n",
    "# shutil.rmtree avoids spawning a shell, so paths with spaces/metacharacters are safe.\n",
    "for folder in (cache_dir, dataset_write_base):\n",
    "    if os.path.isdir(folder):\n",
    "        print(\"rm -rf \"+folder)\n",
    "        shutil.rmtree(folder)\n",
    "\n",
    "# ********************************\n",
    "# languages to extract from FLEURS\n",
    "# ********************************\n",
    "lang_dict = OrderedDict([\n",
    "    #(\"Afrikaans\",\"af_za\"),\n",
    "    #(\"Amharic\", \"am_et\"),\n",
    "    #(\"Fula\", \"ff_sn\"),\n",
    "    #(\"Ganda\", \"lg_ug\"),\n",
    "    #(\"Hausa\", \"ha_ng\"),\n",
    "    #(\"Igbo\", \"ig_ng\"),\n",
    "    #(\"Kamba\", \"kam_ke\"),\n",
    "    #(\"Lingala\", \"ln_cd\"),\n",
    "    #(\"Luo\", \"luo_ke\"),\n",
    "    #(\"Northern-Sotho\", \"nso_za\"),\n",
    "    #(\"Nyanja\", \"ny_mw\"),\n",
    "    #(\"Oromo\", \"om_et\"),\n",
    "    #(\"Shona\", \"sn_zw\"),\n",
    "    #(\"Somali\", \"so_so\"),\n",
    "    (\"Swahili\", \"sw_ke\"),\n",
    "    #(\"Umbundu\", \"umb_ao\"),\n",
    "    #(\"Wolof\", \"wo_sn\"),\n",
    "    #(\"Xhosa\", \"xh_za\"),\n",
    "    #(\"Yoruba\", \"yo_ng\"),\n",
    "    #(\"Zulu\", \"zu_za\")\n",
    "])\n",
    "\n",
    "# *****************\n",
    "# splits to process\n",
    "# *****************\n",
    "datasets = [\"train\",\"test\",\"validation\"]\n",
    "\n",
    "for lang in lang_dict:\n",
    "    print(\"Prepare --->\", lang)\n",
    "\n",
    "    # ********************************\n",
    "    # Download FLEURS from huggingface\n",
    "    # ********************************\n",
    "    fleurs_asr = load_dataset(\"google/fleurs\", lang_dict[lang], cache_dir=cache_dir, trust_remote_code=True)\n",
    "\n",
    "    for subparts in datasets:\n",
    "        split_data = fleurs_asr[subparts]\n",
    "        wav_dir = \"/\".join([dataset_write_base, lang, \"wavs\", subparts])\n",
    "        Path(wav_dir).mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "        # IDs already written for this split (a set keeps membership tests O(1)).\n",
    "        used_ID = set()\n",
    "\n",
    "        # `with` guarantees the CSV is closed even if a row fails; UTF-8 because\n",
    "        # transcriptions may contain non-ASCII characters.\n",
    "        with open(dataset_write_base+\"/\"+lang+\"/\"+subparts+\".csv\", \"w\", encoding=\"utf-8\") as f:\n",
    "            # csv header\n",
    "            f.write(\"ID,duration,wav,spk_id,wrd\\n\")\n",
    "\n",
    "            # Iterate over the rows directly: every `split_data[uid]` lookup rebuilds\n",
    "            # (and re-decodes) the whole row, so each row is now fetched exactly once.\n",
    "            for sample in tqdm(split_data, desc=lang+\"/\"+subparts):\n",
    "                audio_name = sample[\"audio\"][\"path\"].split('/')[-1]\n",
    "\n",
    "                # ***************\n",
    "                # format CSV line\n",
    "                # ***************\n",
    "                text_id = lang+\"_\"+str(sample[\"id\"])\n",
    "\n",
    "                # some ID are duplicated (same speaker, same transcription BUT different recording)\n",
    "                while text_id in used_ID:\n",
    "                    text_id += \"_bis\"\n",
    "                used_ID.add(text_id)\n",
    "\n",
    "                duration = \"{:.3f}\".format(round(float(sample[\"num_samples\"])/float(sample[\"audio\"][\"sampling_rate\"]),3))\n",
    "                wav_path = \"/\".join([wav_dir, audio_name])\n",
    "                spk_id = \"spk_\" + text_id\n",
    "                # AC: \"pseudo-normalisation\" of marginal cases (a comma would break the CSV) -- TODO improve\n",
    "                wrd = sample[\"transcription\"].replace(',','').replace('$',' $ ').replace('\"','').replace('”','').replace('  ',' ')\n",
    "\n",
    "                # **************\n",
    "                # write CSV line\n",
    "                # **************\n",
    "                f.write(text_id+\",\"+duration+\",\"+wav_path+\",\"+spk_id+\",\"+wrd+\"\\n\")\n",
    "\n",
    "                # *******************\n",
    "                # Move wav from cache\n",
    "                # *******************\n",
    "                previous_path = \"/\".join(sample[\"path\"].split('/')[:-1]) + \"/\" + sample[\"audio\"][\"path\"]\n",
    "                shutil.move(previous_path, wav_path)\n",
    "\n",
    "    print(\"--->\", lang, \"done\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id":
"4c32e369-f0f9-4695-8c9a-aa3a9de7bf7b",
   "metadata": {},
   "source": [
    "# Recette ASR"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77fb2c55-3f8c-4f34-81f0-ad48a632e010",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## 1. Installation des dépendances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbe25635-e765-480c-8416-c48a31ee6140",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install torch==2.2.2 torchaudio==2.2.2 torchvision==0.17.2 speechbrain transformers jdc"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6acf1f8c-2cf3-4c9c-8a45-e2580ecbee27",
   "metadata": {},
   "source": [
    "## 2. Mise en place de la recette SpeechBrain -- classe Brain"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5e8884d-3542-40ff-a454-597078fcf97c",
   "metadata": {},
   "source": [
    "### 2.1 Imports & logger"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c677f9f-6abe-423f-b4dd-fdf5ded357cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "import os\n",
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "from hyperpyyaml import load_hyperpyyaml\n",
    "\n",
    "import speechbrain as sb\n",
    "from speechbrain.utils.distributed import if_main_process, run_on_main\n",
    "\n",
    "# jdc provides the `%%add_to` cell magic used below to attach methods to the class.\n",
    "import jdc\n",
    "\n",
    "logger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9698bb92-16ad-4b61-8938-c74b62ee93b2",
   "metadata": {},
   "source": [
    "### 2.2 Création de notre classe héritant de la classe Brain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c7cd624-6249-449b-8ee9-d4a73b7b3301",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define training procedure\n",
    "class MY_SSA_ASR(sb.Brain):\n",
    "    \"\"\"CTC speech recognizer built on a HuBERT encoder.\n",
    "\n",
    "    The methods are attached one by one in the following cells with jdc's\n",
    "    `%%add_to` magic.\n",
    "    \"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ecf31c9c-15dd-4428-aa10-b3cc5e127f0d",
   "metadata": {},
   "source": [
    "### 2.3 Définition de la fonction forward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4368b488-b9d8-49ff-8ce3-78a12d46be83",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%add_to MY_SSA_ASR\n",
    "def compute_forward(self, batch, stage):\n",
    "    \"\"\"Forward computations from the waveform batches to the output probabilities.\n",
    "\n",
    "    Returns a tuple (p_ctc, wav_lens, p_tokens):\n",
    "      - p_ctc: log-probabilities from `log_softmax` applied to the `ctc_lin` outputs;\n",
    "      - wav_lens: relative lengths (possibly changed by waveform augmentation);\n",
    "      - p_tokens: None in TRAIN, greedy CTC decoding in VALID, beam-search n-best\n",
    "        hypotheses in TEST (replaced by the rescorer output when a `rescorer` exists).\n",
    "    \"\"\"\n",
    "    batch = batch.to(self.device)\n",
    "    wavs, wav_lens = batch.sig\n",
    "    wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)\n",
    "\n",
    "    # Downsample the inputs if specified\n",
    "    if hasattr(self.modules, \"downsampler\"):\n",
    "        wavs = self.modules.downsampler(wavs)\n",
    "\n",
    "    # Add waveform augmentation if specified.\n",
    "    if stage == sb.Stage.TRAIN and hasattr(self.hparams, \"wav_augment\"):\n",
    "        wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)\n",
    "\n",
    "    # Forward pass\n",
    "    feats = self.modules.hubert(wavs, wav_lens)\n",
    "    x = self.modules.top_lin(feats)\n",
    "\n",
    "    # Compute outputs\n",
    "    logits = self.modules.ctc_lin(x)\n",
    "    p_ctc = self.hparams.log_softmax(logits)\n",
    "\n",
    "    p_tokens = None\n",
    "    if stage == sb.Stage.VALID:\n",
    "        p_tokens = sb.decoders.ctc_greedy_decode(p_ctc, wav_lens, blank_id=self.hparams.blank_index)\n",
    "\n",
    "    elif stage == sb.Stage.TEST:\n",
    "        # Prefer a searcher attached to the Brain; otherwise fall back to the\n",
    "        # notebook-global `test_searcher` built in the testing cell.\n",
    "        searcher = getattr(self, \"test_searcher\", None)\n",
    "        if searcher is None:\n",
    "            searcher = test_searcher\n",
    "        p_tokens = searcher(p_ctc, wav_lens)\n",
    "\n",
    "        # n-best texts and scores of every utterance, used for optional rescoring.\n",
    "        candidates = [[hyp.text for hyp in nbest] for nbest in p_tokens]\n",
    "        scores = [[hyp.score for hyp in nbest] for nbest in p_tokens]\n",
    "\n",
    "        if hasattr(self.hparams, \"rescorer\"):\n",
    "            p_tokens, _ = self.hparams.rescorer.rescore(candidates, scores)\n",
    "\n",
    "    return p_ctc, wav_lens, p_tokens"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0052b79-5a27-4c4c-8601-7ab064e8c951",
   "metadata": {},
   "source": [
    "### 2.4 Définition de la fonction objectives"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3608aee8-c9c3-4e34-98bc-667513fa7f7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%add_to MY_SSA_ASR\n",
    "def compute_objectives(self, predictions, batch, stage):\n",
    "    \"\"\"Computes the CTC loss given predictions and targets.\n",
    "\n",
    "    Outside of TRAIN, the decoded hypotheses are also accumulated into the\n",
    "    WER and CER metrics created in `on_stage_start`.\n",
    "    \"\"\"\n",
    "    p_ctc, wav_lens, predicted_tokens = predictions\n",
    "\n",
    "    ids = batch.id\n",
    "    tokens, tokens_lens = batch.tokens\n",
    "\n",
    "    # Labels must be extended if parallel augmentation or concatenated\n",
    "    # augmentation was performed on the input (increasing the time dimension)\n",
    "    if stage == sb.Stage.TRAIN and hasattr(self.hparams, \"wav_augment\"):\n",
    "        (tokens, tokens_lens) = self.hparams.wav_augment.replicate_multiple_labels(tokens, tokens_lens)\n",
    "\n",
    "    # Compute loss\n",
    "    loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)\n",
    "\n",
    "    if stage == sb.Stage.VALID:\n",
    "        # Decode token terms to words\n",
    "        predicted_words = [\"\".join(self.tokenizer.decode_ndim(utt_seq)).split(\" \") for utt_seq in predicted_tokens]\n",
    "\n",
    "    elif stage == sb.Stage.TEST:\n",
    "        if hasattr(self.hparams, \"rescorer\"):\n",
    "            # The rescorer returns plain strings (best first), not hypothesis objects.\n",
    "            predicted_words = [hyp[0].split(\" \") for hyp in predicted_tokens]\n",
    "        else:\n",
    "            # Beam-search hypotheses expose their transcription through `.text`.\n",
    "            predicted_words = [hyp[0].text.split(\" \") for hyp in predicted_tokens]\n",
    "\n",
    "    if stage != sb.Stage.TRAIN:\n",
    "        target_words = [wrd.split(\" \") for wrd in batch.wrd]\n",
    "        self.wer_metric.append(ids, predicted_words, target_words)\n",
    "        self.cer_metric.append(ids, predicted_words, target_words)\n",
    "\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a514c50-89ad-41cb-882a-23daf829a538",
   "metadata": {},
   "source": [
    "### 2.5 Définition du comportement au début d'un \"stage\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "609814ce-3ef0-4818-a70f-cadc293c9dd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%add_to MY_SSA_ASR\n",
    "# stage management\n",
    "def on_stage_start(self, stage, epoch):\n",
    "    \"\"\"Gets called at the beginning of each stage (TRAIN, VALID or TEST).\"\"\"\n",
    "    if stage != sb.Stage.TRAIN:\n",
    "        # Fresh metric accumulators for every evaluation pass.\n",
    "        self.cer_metric = self.hparams.cer_computer()\n",
    "        self.wer_metric = self.hparams.error_rate_computer()\n",
    "\n",
    "    if stage == sb.Stage.TEST and hasattr(self.hparams, \"rescorer\"):\n",
    "        self.hparams.rescorer.move_rescorers_to_device()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55929209-c94a-4f8b-8f2e-9dd5d9de8be9",
   "metadata": {},
   "source": [
    "### 2.6 Définition du comportement à la fin d'un \"stage\""
   ]
  },
  {
"cell_type": "code",
   "execution_count": null,
   "id": "8f297542-10d5-47bf-9938-c141f5a99ab8",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%add_to MY_SSA_ASR\n",
    "def on_stage_end(self, stage, stage_loss, epoch):\n",
    "    \"\"\"Gets called at the end of every stage.\n",
    "\n",
    "    TRAIN: keep the loss so it can be logged next to the validation stats.\n",
    "    VALID: anneal both learning rates from the validation loss, log the epoch\n",
    "    and save a checkpoint ranked by WER (lower is better).\n",
    "    TEST: log the final metrics and write the detailed WER report.\n",
    "    \"\"\"\n",
    "    current_stats = {\"loss\": stage_loss}\n",
    "\n",
    "    if stage == sb.Stage.TRAIN:\n",
    "        self.train_stats = current_stats\n",
    "        return\n",
    "\n",
    "    # Error rates are only tracked outside of training.\n",
    "    current_stats[\"CER\"] = self.cer_metric.summarize(\"error_rate\")\n",
    "    current_stats[\"WER\"] = self.wer_metric.summarize(\"error_rate\")\n",
    "\n",
    "    if stage == sb.Stage.VALID:\n",
    "        valid_loss = current_stats[\"loss\"]\n",
    "\n",
    "        # Anneal, then apply, both learning rates\n",
    "        old_lr_model, new_lr_model = self.hparams.lr_annealing_model(valid_loss)\n",
    "        old_lr_hubert, new_lr_hubert = self.hparams.lr_annealing_hubert(valid_loss)\n",
    "        sb.nnet.schedulers.update_learning_rate(self.model_optimizer, new_lr_model)\n",
    "        sb.nnet.schedulers.update_learning_rate(self.hubert_optimizer, new_lr_hubert)\n",
    "\n",
    "        # Log train + valid statistics with the learning rates used this epoch\n",
    "        epoch_meta = {\"epoch\": epoch, \"lr_model\": old_lr_model, \"lr_hubert\": old_lr_hubert}\n",
    "        self.hparams.train_logger.log_stats(stats_meta=epoch_meta, train_stats=self.train_stats, valid_stats=current_stats)\n",
    "\n",
    "        # Save checkpoint\n",
    "        self.checkpointer.save_and_keep_only(meta={\"WER\": current_stats[\"WER\"]}, min_keys=[\"WER\"])\n",
    "    elif stage == sb.Stage.TEST:\n",
    "        self.hparams.train_logger.log_stats(stats_meta={\"Epoch loaded\": self.hparams.epoch_counter.current}, test_stats=current_stats)\n",
    "        if if_main_process():\n",
    "            with open(self.hparams.test_wer_file, \"w\") as report:\n",
    "                self.wer_metric.write_stats(report)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c656457-6b61-4316-8199-70021f92babf",
   "metadata": {},
"source": [
    "### 2.7 Définition de l'initialisation des optimizers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da8d9cb5-c5ad-4e78-83d3-e129e138a741",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%add_to MY_SSA_ASR\n",
    "def init_optimizers(self):\n",
    "    \"\"\"Initializes the hubert optimizer and the model optimizer.\n",
    "\n",
    "    Both optimizers are always created and registered with the checkpointer as\n",
    "    recoverables, but the hubert optimizer is only added to `optimizers_dict`\n",
    "    when `freeze_hubert` is False.\n",
    "    \"\"\"\n",
    "    self.hubert_optimizer = self.hparams.hubert_opt_class(self.modules.hubert.parameters())\n",
    "    self.model_optimizer = self.hparams.model_opt_class(self.hparams.model.parameters())\n",
    "\n",
    "    # save the optimizers in a dictionary\n",
    "    # the key will be used in `freeze_optimizers()`\n",
    "    self.optimizers_dict = {\"model_optimizer\": self.model_optimizer}\n",
    "    if not self.hparams.freeze_hubert:\n",
    "        self.optimizers_dict[\"hubert_optimizer\"] = self.hubert_optimizer\n",
    "\n",
    "    if self.checkpointer is not None:\n",
    "        self.checkpointer.add_recoverable(\"hubert_opt\", self.hubert_optimizer)\n",
    "        self.checkpointer.add_recoverable(\"model_opt\", self.model_optimizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf2e730c-2faa-41f2-b98d-e5fbb2305cc2",
   "metadata": {},
   "source": [
    "## 3. Définition de la lecture des datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5e667f7-6269-4b49-88bb-5e431762c8fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "def dataio_prepare(hparams):\n",
    "    \"\"\"This function prepares the datasets to be used in the brain class.\n",
    "    It also defines the data processing pipeline through user-defined functions.\n",
    "\n",
    "    Returns (train_data, valid_data, test_datasets, label_encoder), where\n",
    "    test_datasets maps the file stem of each entry of hparams[\"test_csv\"] to its dataset.\n",
    "    Note: hparams[\"train_dataloader_opts\"][\"shuffle\"] is forced to False.\n",
    "    \"\"\"\n",
    "\n",
    "    # *****************\n",
    "    # 1. Load CSV files\n",
    "    # *****************\n",
    "    data_folder = hparams[\"data_folder\"]\n",
    "\n",
    "    train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(csv_path=hparams[\"train_csv\"], replacements={\"data_root\": data_folder})\n",
    "    # we sort training data to speed up training and get better results.\n",
    "    train_data = train_data.filtered_sorted(sort_key=\"duration\")\n",
    "    hparams[\"train_dataloader_opts\"][\"shuffle\"] = False  # when sorting do not shuffle in dataloader ! otherwise is pointless\n",
    "\n",
    "    valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(csv_path=hparams[\"valid_csv\"], replacements={\"data_root\": data_folder})\n",
    "    valid_data = valid_data.filtered_sorted(sort_key=\"duration\")\n",
    "\n",
    "    # test is separate\n",
    "    test_datasets = {}\n",
    "    for csv_file in hparams[\"test_csv\"]:\n",
    "        name = Path(csv_file).stem\n",
    "        test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(csv_path=csv_file, replacements={\"data_root\": data_folder})\n",
    "        test_datasets[name] = test_datasets[name].filtered_sorted(sort_key=\"duration\")\n",
    "\n",
    "    datasets = [train_data, valid_data] + list(test_datasets.values())\n",
    "\n",
    "    # *************************\n",
    "    # 2. Define audio pipeline:\n",
    "    # *************************\n",
    "    @sb.utils.data_pipeline.takes(\"wav\")\n",
    "    @sb.utils.data_pipeline.provides(\"sig\")\n",
    "    def audio_pipeline(wav):\n",
    "        sig = sb.dataio.dataio.read_audio(wav)\n",
    "        return sig\n",
    "\n",
    "    sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)\n",
    "\n",
    "    # ************************\n",
    "    # 3. Define text pipeline:\n",
    "    # ************************\n",
    "    # Character-level encoder, filled from the training transcriptions in step 4.\n",
    "    label_encoder = sb.dataio.encoder.CTCTextEncoder()\n",
    "\n",
    "    @sb.utils.data_pipeline.takes(\"wrd\")\n",
    "    @sb.utils.data_pipeline.provides(\"wrd\", \"char_list\", \"tokens_list\", \"tokens\")\n",
    "    def text_pipeline(wrd):\n",
    "        yield wrd\n",
    "        char_list = list(wrd)\n",
    "        yield char_list\n",
    "        tokens_list = label_encoder.encode_sequence(char_list)\n",
    "        yield tokens_list\n",
    "        tokens = torch.LongTensor(tokens_list)\n",
    "        yield tokens\n",
    "\n",
    "    sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)\n",
    "\n",
    "    # *******************************\n",
    "    # 4. Create or load label encoder\n",
    "    # *******************************\n",
    "    lab_enc_file = os.path.join(hparams[\"save_folder\"], \"label_encoder.txt\")\n",
    "    special_labels = {\"blank_label\": hparams[\"blank_index\"]}\n",
    "    label_encoder.add_unk()\n",
    "    label_encoder.load_or_create(path=lab_enc_file, from_didatasets=[train_data], output_key=\"char_list\", special_labels=special_labels, sequence_input=True)\n",
    "\n",
    "    # **************\n",
    "    # 5. Set output:\n",
    "    # **************\n",
    "    sb.dataio.dataset.set_output_keys(datasets, [\"id\", \"sig\", \"wrd\", \"char_list\", \"tokens\"])\n",
    "\n",
    "    return train_data, valid_data, test_datasets, label_encoder"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e97c4f20-6951-4d12-8e17-9eb818a52bb1",
   "metadata": {},
   "source": [
    "## 4. Utilisation de la recette créée"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76b72148-6bd0-48bd-ad40-cb6f8bfd34c0",
   "metadata": {},
   "source": [
    "### 4.1 Préparation au lancement"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d47ec39a-5562-4a63-8243-656c9235b7a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "hparams_file, run_opts, overrides = sb.parse_arguments([\"PATH_TO_YOUR_FOLDER/ASR_FLEURS-swahili_hf.yaml\"])\n",
    "# create ddp_group with the right communication protocol\n",
    "sb.utils.distributed.ddp_init_group(run_opts)\n",
    "\n",
    "# ****************************\n",
    "# Load the hyperparameter file\n",
    "# ****************************\n",
    "with open(hparams_file) as fin:\n",
    "    hparams = load_hyperpyyaml(fin, overrides)\n",
    "\n",
    "# ***************************\n",
    "# Create experiment directory\n",
    "# ***************************\n",
    "sb.create_experiment_directory(experiment_directory=hparams[\"output_folder\"], hyperparams_to_save=hparams_file, overrides=overrides)\n",
    "\n",
    "# ***************************\n",
    "# Create the datasets objects\n",
    "# ***************************\n",
    "train_data, valid_data, test_datasets, label_encoder = dataio_prepare(hparams)\n",
    "\n",
    "# **********************\n",
    "# Trainer initialization\n",
    "# **********************\n",
    "asr_brain = MY_SSA_ASR(modules=hparams[\"modules\"], hparams=hparams, run_opts=run_opts, checkpointer=hparams[\"checkpointer\"])\n",
    "# compute_objectives decodes VALID hypotheses through `self.tokenizer`.\n",
    "asr_brain.tokenizer = label_encoder"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62ae72eb-416c-4ef0-9348-d02bbc268fbd",
   "metadata": {},
   "source": [
    "### 4.2 Apprentissage du modèle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3dd30ee-89c0-40ea-a9d2-0e2b9d8c8686",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ********\n",
    "# Training\n",
    "# ********\n",
    "asr_brain.fit(\n",
    "    asr_brain.hparams.epoch_counter,\n",
    "    train_data,\n",
    "    valid_data,\n",
    "    train_loader_kwargs=hparams[\"train_dataloader_opts\"],\n",
    "    valid_loader_kwargs=hparams[\"valid_dataloader_opts\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b55af4c-c544-45ff-8435-58226218328f",
   "metadata": {},
   "source": [
    "### 4.3 Test du modèle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cef9011-1a3e-43a4-ab16-8cfb2b57dbd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# *******\n",
    "# Testing\n",
    "# *******\n",
    "from speechbrain.decoders.ctc import CTCBeamSearcher\n",
    "\n",
    "# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.\n",
    "os.makedirs(hparams[\"output_wer_folder\"], exist_ok=True)\n",
    "\n",
    "# Beam-search vocabulary: the label encoder's labels, ordered by index.\n",
    "ind2lab = label_encoder.ind2lab\n",
    "vocab_list = [ind2lab[x] for x in range(len(ind2lab))]\n",
    "test_searcher = CTCBeamSearcher(**hparams[\"test_beam_search\"], vocab_list=vocab_list)\n",
    "# Expose the searcher on the Brain as well as under the notebook-global name.\n",
    "asr_brain.test_searcher = test_searcher\n",
    "\n",
    "for k in test_datasets.keys():  # Allow multiple evaluation through a list of test sets\n",
    "    asr_brain.hparams.test_wer_file = os.path.join(hparams[\"output_wer_folder\"], f\"wer_{k}.txt\")\n",
    "    asr_brain.evaluate(test_datasets[k], test_loader_kwargs=hparams[\"test_dataloader_opts\"], min_key=\"WER\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}