"""SpeechEmotionDetector / train_with_wav2vec.py

Trains a speech emotion classifier on top of a wav2vec 2.0 encoder with SpeechBrain.
"""
import os
import sys
import logging
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml
import json
import random
import torch
# Check if a GPU is available (informational; training uses run_opts["device"]).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
logger = logging.getLogger(__name__)
SAMPLERATE = 16000

def prepare_data(data_original, save_json_train, save_json_valid, save_json_test, split_ratio=(80, 10, 10), seed=12):
    """Scans data_original for <label>/<file>.wav, then splits the file list
    into train/valid/test sets and writes one JSON manifest per split."""
    # Setting seeds for reproducible code.
    random.seed(seed)
    # Check if data preparation has already been done (skip if files exist).
    if skip(save_json_train, save_json_valid, save_json_test):
        logger.info("Preparation completed in previous run, skipping.")
        return
    # Collect audio files and labels (one sub-directory per label).
    wav_list = []
    labels = os.listdir(data_original)
    for label in labels:
        label_dir = os.path.join(data_original, label)
        if os.path.isdir(label_dir):
            for audio_file in os.listdir(label_dir):
                if audio_file.endswith(".wav"):
                    wav_file = os.path.join(label_dir, audio_file)
                    if os.path.isfile(wav_file):
                        wav_list.append((wav_file, label))
                    else:
                        logger.warning(f"Skipping invalid audio file: {wav_file}")
    # Shuffle and split the data.
    random.shuffle(wav_list)
    n_total = len(wav_list)
    n_train = n_total * split_ratio[0] // 100
    n_valid = n_total * split_ratio[1] // 100
    train_set = wav_list[:n_train]
    valid_set = wav_list[n_train:n_train + n_valid]
    test_set = wav_list[n_train + n_valid:]
    # Create JSON manifests for the train, valid, and test sets.
    create_json(train_set, save_json_train)
    create_json(valid_set, save_json_valid)
    create_json(test_set, save_json_test)
    logger.info(f"Created {save_json_train}, {save_json_valid}, and {save_json_test}")

def create_json(wav_list, json_file):
    """Writes a JSON manifest with one entry per utterance."""
    json_dict = {}
    for wav_file, label in wav_list:
        # Read the audio once to compute the duration in seconds.
        signal = sb.dataio.dataio.read_audio(wav_file)
        duration = signal.shape[0] / SAMPLERATE
        uttid = os.path.splitext(os.path.basename(wav_file))[0]
        json_dict[uttid] = {
            "wav": wav_file,
            "length": duration,
            "label": label,
        }
    with open(json_file, mode="w") as json_f:
        json.dump(json_dict, json_f, indent=2)
    logger.info(f"Created {json_file}")

def skip(*filenames):
    """Returns True if every given file already exists."""
    for filename in filenames:
        if not os.path.isfile(filename):
            return False
    return True

class EmoIdBrain(sb.Brain):
    def compute_forward(self, batch, stage):
        """Computation pipeline based on an encoder + emotion classifier."""
        batch = batch.to(self.device)
        wavs, lens = batch.sig
        outputs = self.modules.wav2vec2(wavs, lens)
        # Apply pooling and MLP layers
        outputs = self.hparams.avg_pool(outputs, lens)
        outputs = outputs.view(outputs.shape[0], -1)
        outputs = self.modules.output_mlp(outputs)
        outputs = self.hparams.log_softmax(outputs)
        return outputs
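
    # Shape walk-through (illustrative, assuming avg_pool collapses the time
    # axis to a single vector; B = batch, T = frames, H = hidden, C = classes):
    #   wavs [B, samples] -> wav2vec2 -> [B, T, H]
    #   avg_pool -> [B, 1, H] -> view -> [B, H]
    #   output_mlp -> [B, C] -> log_softmax -> [B, C]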

    def compute_objectives(self, predictions, batch, stage):
        """Computes the NLL loss given predictions and encoded labels."""
        # PaddedBatch exposes each dynamic item as a (data, lengths) pair;
        # the labels arrive as shape [batch, 1] and nll_loss expects [batch].
        emo_encoded, _ = batch.emo_encoded
        emo_encoded = emo_encoded.squeeze(1).to(self.device)
        loss = self.hparams.compute_cost(predictions, emo_encoded)
        if stage != sb.Stage.TRAIN:
            self.error_metrics.append(batch.id, predictions, emo_encoded)
        return loss

    def on_stage_start(self, stage, epoch=None):
        """Gets called at the beginning of each epoch."""
        self.loss_metric = sb.utils.metric_stats.MetricStats(
            metric=sb.nnet.losses.nll_loss
        )
        if stage != sb.Stage.TRAIN:
            self.error_metrics = self.hparams.error_stats()

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """Gets called at the end of an epoch."""
        if stage == sb.Stage.TRAIN:
            self.train_loss = stage_loss
        else:
            stats = {
                "loss": stage_loss,
            }
            if self.error_metrics is not None and len(self.error_metrics.scores) > 0:
                # Calculate the error rate only if scores were accumulated.
                stats["error_rate"] = self.error_metrics.summarize("average")
            else:
                # No scores available: record NaN instead of failing.
                stats["error_rate"] = float("nan")
        if stage == sb.Stage.VALID:
            old_lr, new_lr = self.hparams.lr_annealing(stats["error_rate"])
            sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)
            self.hparams.train_logger.log_stats(
                {"Epoch": epoch, "lr": old_lr},
                train_stats={"loss": self.train_loss},
                valid_stats=stats,
            )
            self.checkpointer.save_and_keep_only(meta=stats, min_keys=["error_rate"])
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                {"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stats,
            )

    def init_optimizers(self):
        """Initializes the optimizer."""
        self.optimizer = self.hparams.opt_class(self.hparams.model.parameters())
        if self.checkpointer is not None:
            self.checkpointer.add_recoverable("optimizer", self.optimizer)
        self.optimizers_dict = {"model_optimizer": self.optimizer}
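
# The Brain above pulls several objects from the HyperPyYAML file. A sketch of
# the keys it assumes (the names must match the attribute accesses in the code;
# the example values are plausible SpeechBrain choices, not confirmed ones):
#   modules:       wav2vec2 encoder and output_mlp classifier
#   avg_pool:      e.g. speechbrain.nnet.pooling.StatisticsPooling
#   log_softmax:   e.g. speechbrain.nnet.activations.Softmax with apply_log=True
#   compute_cost:  e.g. speechbrain.nnet.losses.nll_loss
#   error_stats:   a factory returning sb.utils.metric_stats.MetricStats
#   lr_annealing:  e.g. speechbrain.nnet.schedulers.NewBobScheduler
#   plus opt_class, model, epoch_counter, train_logger, and checkpointer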

def dataio_prep(hparams):
    """Prepares the datasets to be used in the brain class."""

    # Define the audio processing pipeline.
    @sb.utils.data_pipeline.takes("wav")
    @sb.utils.data_pipeline.provides("sig")
    def audio_pipeline(wav):
        """Load the signal from a WAV file."""
        sig = sb.dataio.dataio.read_audio(wav)
        return sig

    # Initialize the label encoder.
    label_encoder = sb.dataio.encoder.CategoricalEncoder()
    label_encoder.add_unk()

    # Fixed label-to-index mapping; keeps class indices stable across runs.
    label_to_index = {
        "angry": 0,
        "happy": 1,
        "neutral": 2,
        "sad": 3,
        "surprise": 4,
        "disgust": 5,
        "fear": 6,
    }
@sb.utils.data_pipeline.takes("label")
@sb.utils.data_pipeline.provides("label", "emo_encoded")
def label_pipeline(label):
"""Encode the emotion label."""
if label in label_to_index:
emo_encoded = label_to_index[label]
else:
raise ValueError(f"Unknown label encountered: {label}")
yield label, torch.tensor(emo_encoded, dtype=torch.long)

    # Define datasets dictionary.
    datasets = {}
    data_info = {
        "train": hparams["train_annotation"],
        "valid": hparams["valid_annotation"],
        "test": hparams["test_annotation"],
    }
    # Load datasets and apply the pipelines.
    for dataset_name, json_path in data_info.items():
        datasets[dataset_name] = sb.dataio.dataset.DynamicItemDataset.from_json(
            json_path=json_path,
            replacements={"data_root": hparams["data_original"]},
            dynamic_items=[audio_pipeline, label_pipeline],
            output_keys=["id", "sig", "label", "emo_encoded"],
        )

    # Persist the label set next to the checkpoints for reuse at inference time.
    lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
    label_encoder.load_or_create(
        path=lab_enc_file,
        from_didatasets=[datasets["train"]],
        output_key="label",
    )
    return datasets
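
# Quick sanity check (hypothetical snippet, e.g. in an interactive session):
# indexing a DynamicItemDataset evaluates the pipelines for one utterance.
#
#   datasets = dataio_prep(hparams)
#   ex = datasets["train"][0]
#   print(ex["id"], ex["sig"].shape, ex["label"], ex["emo_encoded"])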

if __name__ == "__main__":
    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
    sb.utils.distributed.ddp_init_group(run_opts)
    try:
        with open(hparams_file) as fin:
            hparams = load_hyperpyyaml(fin, overrides)
        data_original = hparams.get("data_original")
        if data_original is None:
            raise ValueError("data_original path is not specified in the YAML configuration.")
        data_original = os.path.normpath(data_original)
        if not os.path.exists(data_original):
            raise ValueError(f"data_original path '{data_original}' does not exist.")
    except Exception as e:
        print(f"Error occurred: {e}")
        sys.exit(1)

    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )
    if not hparams["skip_prep"]:
        prepare_kwargs = {
            "data_original": hparams["data_original"],
            "save_json_train": hparams["train_annotation"],
            "save_json_valid": hparams["valid_annotation"],
            "save_json_test": hparams["test_annotation"],
            "split_ratio": hparams["split_ratio"],
            "seed": hparams["seed"],
        }
        # Run data preparation on the main process only (safe under DDP).
        sb.utils.distributed.run_on_main(prepare_data, kwargs=prepare_kwargs)
    datasets = dataio_prep(hparams)
    # Move the wav2vec 2.0 encoder to the training device.
    hparams["wav2vec2"] = hparams["wav2vec2"].to(device=run_opts["device"])
    # Optionally keep the convolutional feature extractor frozen while the
    # rest of the encoder is fine-tuned.
    if not hparams["freeze_wav2vec2"] and hparams["freeze_wav2vec2_conv"]:
        hparams["wav2vec2"].model.feature_extractor._freeze_parameters()

    emo_id_brain = EmoIdBrain(
        modules=hparams["modules"],
        opt_class=hparams["opt_class"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )
    emo_id_brain.fit(
        epoch_counter=emo_id_brain.hparams.epoch_counter,
        train_set=datasets["train"],
        valid_set=datasets["valid"],
        train_loader_kwargs=hparams["dataloader_options"],
        valid_loader_kwargs=hparams["dataloader_options"],
    )
    # Evaluate with the checkpoint that achieved the lowest validation error rate.
    test_stats = emo_id_brain.evaluate(
        test_set=datasets["test"],
        min_key="error_rate",
        test_loader_kwargs=hparams["dataloader_options"],
    )
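
    # Typical invocation (file and directory names are illustrative):
    #   python train_with_wav2vec.py hparams/train_with_wav2vec.yaml \
    #       --data_original=/path/to/emotion_dataset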