# Introduction
N'hésitez pas à nous contacter en cas de questions : antoine.caubriere@orange.com & elodie.gauthier@orange.com

Pensez à modifier l'ensemble des chemins (PATH) dans le fichier de configuration ASR_FLEURS-swahili_hf.yaml et dans le code Python ci-dessous (PATH_TO_YOUR_FOLDER).

Dans le cas d'un changement de corpus (autre sous-partie de FLEURS ou vos propres jeux de données), pensez à modifier la taille de la couche de sortie du modèle : paramètre `output_neurons` du fichier ASR_FLEURS-swahili_hf.yaml.


# Préparation des données FLEURS

### 1. Installation des dépendances

In [None]:
pip install datasets librosa soundfile

### 2. Téléchargement et formatage du dataset

In [None]:
from datasets import load_dataset
from pathlib import Path
from collections import OrderedDict
from tqdm import tqdm
import shutil
import os

dataset_write_base = "PATH_TO_YOUR_FOLDER/data_speechbrain/"
cache_dir = "PATH_TO_YOUR_FOLDER/data_huggingface/"

# Start from a clean slate: remove any previous HuggingFace cache and any
# previously formatted SpeechBrain output. shutil.rmtree is used instead of
# os.system("rm -rf " + path): no shell is spawned (paths with spaces or
# shell metacharacters are safe) and failures raise instead of being ignored.
for stale_dir in (cache_dir, dataset_write_base):
    if os.path.isdir(stale_dir):
        print("rm -rf " + stale_dir)
        shutil.rmtree(stale_dir)

# ------------------------------------------------------------------
# FLEURS languages to extract: display name -> FLEURS config name.
# Uncomment any alternative below to prepare more languages.
# ------------------------------------------------------------------
lang_dict = OrderedDict(
    [
        # ("Afrikaans", "af_za"),
        # ("Amharic", "am_et"),
        # ("Fula", "ff_sn"),
        # ("Ganda", "lg_ug"),
        # ("Hausa", "ha_ng"),
        # ("Igbo", "ig_ng"),
        # ("Kamba", "kam_ke"),
        # ("Lingala", "ln_cd"),
        # ("Luo", "luo_ke"),
        # ("Northern-Sotho", "nso_za"),
        # ("Nyanja", "ny_mw"),
        # ("Oromo", "om_et"),
        # ("Shona", "sn_zw"),
        # ("Somali", "so_so"),
        ("Swahili", "sw_ke"),
        # ("Umbundu", "umb_ao"),
        # ("Wolof", "wo_sn"),
        # ("Xhosa", "xh_za"),
        # ("Yoruba", "yo_ng"),
        # ("Zulu", "zu_za"),
    ]
)

# ------------------------------------------------------------------
# Dataset splits to process (used as keys into the loaded dataset).
# ------------------------------------------------------------------
datasets = ["train", "test", "validation"]

for lang in lang_dict:
    print("Prepare --->", lang)

    # ********************************
    # Download FLEURS from huggingface
    # ********************************
    fleurs_asr = load_dataset("google/fleurs", lang_dict[lang], cache_dir=cache_dir, trust_remote_code=True)

    for subparts in datasets:
        split = fleurs_asr[subparts]
        wav_dir = "/".join([dataset_write_base, lang, "wavs", subparts])
        Path(wav_dir).mkdir(parents=True, exist_ok=True)

        # IDs already written for this split. A set keeps the duplicate
        # check O(1) per row (a list made the whole loop quadratic).
        used_ids = set()

        # UTF-8 explicitly: transcriptions contain non-ASCII characters and
        # the platform default encoding is not guaranteed to handle them.
        csv_path = dataset_write_base + "/" + lang + "/" + subparts + ".csv"
        with open(csv_path, "w", encoding="utf-8") as f:
            # csv header
            f.write("ID,duration,wav,spk_id,wrd\n")

            # Fetch each row exactly once: accessing a row decodes its
            # `audio` column, so re-indexing the split for every field
            # repeated that decoding several times per utterance.
            for sample in tqdm(split):

                # ***************
                # format CSV line
                # ***************
                text_id = lang + "_" + str(sample["id"])

                # Some IDs are duplicated (same speaker, same transcription
                # but a different recording): suffix them to keep IDs unique.
                while text_id in used_ids:
                    text_id += "_bis"
                used_ids.add(text_id)

                audio = sample["audio"]
                duration = "{:.3f}".format(round(float(sample["num_samples"]) / float(audio["sampling_rate"]), 3))
                wav_path = "/".join([wav_dir, audio["path"].split("/")[-1]])
                # No speaker labels are used: each utterance is its own pseudo speaker.
                spk_id = "spk_" + text_id
                # Light "pseudo-normalisation" of edge cases (TODO: improve):
                # drop commas (the CSV below is written without quoting) and
                # double-quote characters, pad "$" with spaces, and turn
                # non-breaking spaces into regular spaces.
                wrd = (
                    sample["transcription"]
                    .replace(",", "")
                    .replace("$", " $ ")
                    .replace('"', "")
                    .replace("”", "")
                    .replace("\u00a0", " ")
                )

                # **************
                # write CSV line
                # **************
                f.write(text_id + "," + duration + "," + wav_path + "," + spk_id + "," + wrd + "\n")

                # *******************
                # Move wav from cache
                # *******************
                # The cached file sits in the directory of the row's "path",
                # under the relative name given by audio["path"].
                previous_path = "/".join(sample["path"].split("/")[:-1]) + "/" + audio["path"]
                shutil.move(previous_path, wav_path)

    print("--->", lang, "done")

# Recette ASR

## 1. Installation des dépendances

In [None]:
pip install torch==2.2.2 torchaudio==2.2.2 torchvision==0.17.2 speechbrain transformers jdc

## 2. Mise en place de la recette SpeechBrain — classe Brain

### 2.1 Imports & logger

In [None]:
import logging
import os
import sys
from pathlib import Path

import torch
from hyperpyyaml import load_hyperpyyaml

import speechbrain as sb
from speechbrain.utils.distributed import if_main_process, run_on_main

import jdc

logger = logging.getLogger(__name__)

### 2.2 Création de notre classe héritant de la classe Brain

In [None]:
# Define training procedure
class MY_SSA_ASR(sb.Brain):
    """SpeechBrain ``Brain`` subclass for the FLEURS CTC ASR recipe.

    The body is intentionally empty: the methods (``compute_forward``,
    ``compute_objectives``, ``on_stage_start``, ``on_stage_end``,
    ``init_optimizers``) are attached in the following notebook cells
    with the jdc ``%%add_to MY_SSA_ASR`` magic. Re-running this cell
    redefines the class, so those cells must be re-run afterwards.
    """

### 2.3 Définition de la fonction forward 

In [None]:
%%add_to MY_SSA_ASR
def compute_forward(self, batch, stage):
    """Forward computations from the waveform batches to the output probabilities.

    Arguments
    ---------
    batch : SpeechBrain batch
        Must expose ``sig`` as a ``(wavs, wav_lens)`` pair.
    stage : sb.Stage
        Current stage (TRAIN, VALID or TEST).

    Returns
    -------
    tuple
        ``(p_ctc, wav_lens, p_tokens)``: the CTC log-probabilities, the
        (possibly augmented) relative lengths, and the decoded hypotheses.
        ``p_tokens`` is ``None`` during training, the greedy CTC decoding
        during validation, and the beam-search hypotheses of
        ``test_searcher`` during testing (replaced by the rescorer output
        when ``hparams.rescorer`` is configured).
    """
    batch = batch.to(self.device)
    wavs, wav_lens = batch.sig
    wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)

    # Downsample the inputs if specified
    if hasattr(self.modules, "downsampler"):
        wavs = self.modules.downsampler(wavs)

    # Waveform augmentation is only applied while training; it may extend
    # the batch, which compute_objectives compensates for on the labels.
    if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
        wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)

    # Forward pass: HuBERT features -> top linear layer -> CTC projection.
    feats = self.modules.hubert(wavs, wav_lens)
    x = self.modules.top_lin(feats)
    logits = self.modules.ctc_lin(x)
    p_ctc = self.hparams.log_softmax(logits)

    p_tokens = None
    if stage == sb.Stage.VALID:
        p_tokens = sb.decoders.ctc_greedy_decode(p_ctc, wav_lens, blank_id=self.hparams.blank_index)

    elif stage == sb.Stage.TEST:
        # NOTE: `test_searcher` is a module-level CTCBeamSearcher created in
        # the testing cell; it must exist before evaluate() is called.
        p_tokens = test_searcher(p_ctc, wav_lens)

        if hasattr(self.hparams, "rescorer"):
            # One list of candidate texts / scores per utterance.
            candidates = [[hyp.text for hyp in utt_hyps] for utt_hyps in p_tokens]
            scores = [[hyp.score for hyp in utt_hyps] for utt_hyps in p_tokens]
            p_tokens, _ = self.hparams.rescorer.rescore(candidates, scores)

    return p_ctc, wav_lens, p_tokens


### 2.4 Définition de la fonction objectives

In [None]:
%%add_to MY_SSA_ASR
def compute_objectives(self, predictions, batch, stage):
    """Compute the CTC loss; outside training, also update WER/CER metrics.

    Arguments
    ---------
    predictions : tuple
        ``(p_ctc, wav_lens, predicted_tokens)`` as returned by
        ``compute_forward``.
    batch : SpeechBrain batch
        Must expose ``id``, ``tokens`` and ``wrd``.
    stage : sb.Stage
        Current stage (TRAIN, VALID or TEST).

    Returns
    -------
    The CTC loss computed by ``hparams.ctc_cost``.
    """
    p_ctc, wav_lens, predicted_tokens = predictions

    ids = batch.id
    tokens, tokens_lens = batch.tokens

    # Labels must be extended if parallel augmentation or concatenated
    # augmentation was performed on the input (increasing the time dimension),
    # so that they stay aligned with the augmented waveforms.
    if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
        tokens, tokens_lens = self.hparams.wav_augment.replicate_multiple_labels(tokens, tokens_lens)

    # Compute loss
    loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)

    if stage == sb.Stage.TRAIN:
        return loss

    if stage == sb.Stage.VALID:
        # Greedy outputs are token-id sequences: decode them to characters.
        predicted_words = ["".join(self.tokenizer.decode_ndim(utt_seq)).split(" ") for utt_seq in predicted_tokens]
    elif hasattr(self.hparams, "rescorer"):
        # After rescoring, each utterance's hypotheses are plain strings
        # (best first) rather than beam-search objects with a `.text` field.
        predicted_words = [hyp[0].split(" ") for hyp in predicted_tokens]
    else:
        predicted_words = [hyp[0].text.split(" ") for hyp in predicted_tokens]

    target_words = [wrd.split(" ") for wrd in batch.wrd]
    self.wer_metric.append(ids, predicted_words, target_words)
    self.cer_metric.append(ids, predicted_words, target_words)

    return loss


### 2.5 Définition du comportement au début d'un "stage"

In [None]:
%%add_to MY_SSA_ASR
# Stage handling
def on_stage_start(self, stage, epoch):
    """Prepare fresh evaluation metrics at the start of a VALID/TEST stage.

    Nothing is done for training stages. For the TEST stage, any configured
    ``rescorer`` is also moved to the device before decoding starts.
    """
    if stage == sb.Stage.TRAIN:
        return

    # New accumulators, so CER/WER are computed per stage.
    self.cer_metric = self.hparams.cer_computer()
    self.wer_metric = self.hparams.error_rate_computer()

    if stage == sb.Stage.TEST and hasattr(self.hparams, "rescorer"):
        self.hparams.rescorer.move_rescorers_to_device()



### 2.6 Définition du comportement à la fin d'un "stage"

In [None]:
%%add_to MY_SSA_ASR
def on_stage_end(self, stage, stage_loss, epoch):
    """Summarise a finished stage and act on it.

    TRAIN: remember the stats so the next validation log line includes them.
    VALID: anneal both learning rates on the validation loss, log, and save a
    checkpoint keeping the one with the lowest WER.
    TEST: log the stats and (main process only) write the detailed WER report
    to ``hparams.test_wer_file``.
    """
    stage_stats = {"loss": stage_loss}

    if stage == sb.Stage.TRAIN:
        self.train_stats = stage_stats
        return

    stage_stats["CER"] = self.cer_metric.summarize("error_rate")
    stage_stats["WER"] = self.wer_metric.summarize("error_rate")

    if stage == sb.Stage.VALID:
        # Anneal and update learning rates (model first, then HuBERT).
        valid_loss = stage_stats["loss"]
        old_lr_model, new_lr_model = self.hparams.lr_annealing_model(valid_loss)
        old_lr_hubert, new_lr_hubert = self.hparams.lr_annealing_hubert(valid_loss)
        sb.nnet.schedulers.update_learning_rate(self.model_optimizer, new_lr_model)
        sb.nnet.schedulers.update_learning_rate(self.hubert_optimizer, new_lr_hubert)

        # Log with the learning rates that were in effect during this epoch.
        self.hparams.train_logger.log_stats(
            stats_meta={"epoch": epoch, "lr_model": old_lr_model, "lr_hubert": old_lr_hubert},
            train_stats=self.train_stats,
            valid_stats=stage_stats,
        )

        self.checkpointer.save_and_keep_only(
            meta={"WER": stage_stats["WER"]},
            min_keys=["WER"],
        )
    elif stage == sb.Stage.TEST:
        self.hparams.train_logger.log_stats(
            stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
            test_stats=stage_stats,
        )
        if if_main_process():
            with open(self.hparams.test_wer_file, "w") as wer_file:
                self.wer_metric.write_stats(wer_file)


### 2.7 Définition de l'initialisation des optimiseurs

In [None]:
%%add_to MY_SSA_ASR
def init_optimizers(self):
    """Create separate optimizers for the HuBERT encoder and the rest of the model.

    Both optimizers are always created and registered with the checkpointer
    (when there is one), but ``optimizers_dict`` only lists the HuBERT
    optimizer when ``hparams.freeze_hubert`` is false. Its keys are used by
    ``freeze_optimizers()``.
    """
    hp = self.hparams
    self.hubert_optimizer = hp.hubert_opt_class(self.modules.hubert.parameters())
    self.model_optimizer = hp.model_opt_class(hp.model.parameters())

    optimizers = {"model_optimizer": self.model_optimizer}
    if not hp.freeze_hubert:
        optimizers["hubert_optimizer"] = self.hubert_optimizer
    self.optimizers_dict = optimizers

    if self.checkpointer is not None:
        recoverables = (
            ("hubert_opt", self.hubert_optimizer),
            ("model_opt", self.model_optimizer),
        )
        for recoverable_name, optimizer in recoverables:
            self.checkpointer.add_recoverable(recoverable_name, optimizer)


## 3. Définition de la lecture des datasets

In [None]:
def dataio_prepare(hparams):
    """Build the train/valid/test datasets and the character label encoder.

    Arguments
    ---------
    hparams : dict
        Loaded hyperparameters. Reads ``data_folder``, ``train_csv``,
        ``valid_csv``, ``test_csv`` (an iterable of CSV paths),
        ``save_folder`` and ``blank_index``. Sets
        ``hparams["train_dataloader_opts"]["shuffle"]`` to ``False`` in place.

    Returns
    -------
    tuple
        ``(train_data, valid_data, test_datasets, label_encoder)`` where
        ``test_datasets`` maps each test CSV file stem to its dataset.
    """
    data_folder = hparams["data_folder"]

    def _load_sorted(csv_path):
        # Load one CSV ($data_root substituted) sorted by utterance duration.
        dataset = sb.dataio.dataset.DynamicItemDataset.from_csv(
            csv_path=csv_path, replacements={"data_root": data_folder}
        )
        return dataset.filtered_sorted(sort_key="duration")

    # --- 1. Load CSV files -------------------------------------------------
    train_data = _load_sorted(hparams["train_csv"])
    # Training data is duration-sorted to speed up training; shuffling in the
    # dataloader would undo that ordering, so it is disabled.
    hparams["train_dataloader_opts"]["shuffle"] = False

    valid_data = _load_sorted(hparams["valid_csv"])

    # Each test set is kept separately, keyed by its file name without suffix.
    test_datasets = {Path(csv_file).stem: _load_sorted(csv_file) for csv_file in hparams["test_csv"]}

    all_datasets = [train_data, valid_data, *test_datasets.values()]

    # --- 2. Audio pipeline: wav path -> signal -------------------------------
    @sb.utils.data_pipeline.takes("wav")
    @sb.utils.data_pipeline.provides("sig")
    def audio_pipeline(wav):
        return sb.dataio.dataio.read_audio(wav)

    sb.dataio.dataset.add_dynamic_item(all_datasets, audio_pipeline)

    # --- 3. Text pipeline: words -> characters -> token ids ------------------
    label_encoder = sb.dataio.encoder.CTCTextEncoder()

    @sb.utils.data_pipeline.takes("wrd")
    @sb.utils.data_pipeline.provides("wrd", "char_list", "tokens_list", "tokens")
    def text_pipeline(wrd):
        yield wrd
        chars = list(wrd)
        yield chars
        token_ids = label_encoder.encode_sequence(chars)
        yield token_ids
        yield torch.LongTensor(token_ids)

    sb.dataio.dataset.add_dynamic_item(all_datasets, text_pipeline)

    # --- 4. Create (from the training characters) or load the label encoder --
    label_encoder.add_unk()
    label_encoder.load_or_create(
        path=os.path.join(hparams["save_folder"], "label_encoder.txt"),
        from_didatasets=[train_data],
        output_key="char_list",
        special_labels={"blank_label": hparams["blank_index"]},
        sequence_input=True,
    )

    # --- 5. Keys returned for each batch -------------------------------------
    sb.dataio.dataset.set_output_keys(all_datasets, ["id", "sig", "wrd", "char_list", "tokens"])

    return train_data, valid_data, test_datasets, label_encoder


## 4. Utilisation de la recette créée

### 4.1 Préparation au lancement

In [None]:
# Parse the YAML path as if it had been passed on the command line.
hparams_file, run_opts, overrides = sb.parse_arguments(
    ["PATH_TO_YOUR_FOLDER/ASR_FLEURS-swahili_hf.yaml"]
)

# Create the DDP group with the right communication protocol.
sb.utils.distributed.ddp_init_group(run_opts)

# -------------------------------------------------------------------
# Load the hyperparameter file (with any overrides applied)
# -------------------------------------------------------------------
with open(hparams_file) as fin:
    hparams = load_hyperpyyaml(fin, overrides)

# -------------------------------------------------------------------
# Create the experiment directory (saving the hyperparameter file)
# -------------------------------------------------------------------
sb.create_experiment_directory(
    experiment_directory=hparams["output_folder"],
    hyperparams_to_save=hparams_file,
    overrides=overrides,
)

# -------------------------------------------------------------------
# Build the dataset objects and the character label encoder
# -------------------------------------------------------------------
train_data, valid_data, test_datasets, label_encoder = dataio_prepare(hparams)

# -------------------------------------------------------------------
# Trainer initialisation. The label encoder is attached as `tokenizer`
# because compute_objectives decodes validation hypotheses with it.
# -------------------------------------------------------------------
asr_brain = MY_SSA_ASR(
    modules=hparams["modules"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
)
asr_brain.tokenizer = label_encoder

### 4.2 Apprentissage du modèle

In [None]:
# -------------------------------------------------------------------
# Training: epochs are driven by the epoch counter from the YAML file;
# loader options come from the *_dataloader_opts hyperparameters.
# -------------------------------------------------------------------
asr_brain.fit(
    asr_brain.hparams.epoch_counter,
    train_data,
    valid_data,
    train_loader_kwargs=hparams["train_dataloader_opts"],
    valid_loader_kwargs=hparams["valid_dataloader_opts"],
)



### 4.3 Test du modèle

In [None]:
# *******
# Testing
# *******
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(hparams["output_wer_folder"], exist_ok=True)

from speechbrain.decoders.ctc import CTCBeamSearcher

# Vocabulary in label-index order, so beam-search token ids map back to characters.
ind2lab = label_encoder.ind2lab
vocab_list = [ind2lab[x] for x in range(len(ind2lab))]
# Must stay a module-level name: compute_forward looks up `test_searcher`
# as a global during the TEST stage.
test_searcher = CTCBeamSearcher(**hparams["test_beam_search"], vocab_list=vocab_list)

# Allow multiple evaluations through the list of test sets: one WER report
# per set, each evaluated with the checkpoint that has the lowest saved WER.
for k, test_set in test_datasets.items():
    asr_brain.hparams.test_wer_file = os.path.join(hparams["output_wer_folder"], f"wer_{k}.txt")
    asr_brain.evaluate(test_set, test_loader_kwargs=hparams["test_dataloader_opts"], min_key="WER")
