import logging
import os
import time
from typing import List, NoReturn

import librosa
import numpy as np
import pysepm
import pytorch_lightning as pl
import torch.nn as nn
from pesq import pesq
from pytorch_lightning.utilities import rank_zero_only

from bytesep.callbacks.base_callbacks import SaveCheckpointsCallback
from bytesep.inference import Separator
from bytesep.utils import StatisticsContainer, read_yaml


def get_voicebank_demand_callbacks(
    config_yaml: str,
    workspace: str,
    checkpoints_dir: str,
    statistics_path: str,
    logger: pl.loggers.TensorBoardLogger,
    model: nn.Module,
    evaluate_device: str,
) -> List[pl.Callback]:
    r"""Get Voicebank-Demand callbacks of a config yaml.

    Args:
        config_yaml: str
        workspace: str
        checkpoints_dir: str, directory to save checkpoints
        statistics_path: str, path to save statistics
        logger: pl.loggers.TensorBoardLogger
        model: nn.Module
        evaluate_device: str

    Returns:
        callbacks: List[pl.Callback]
    """
    configs = read_yaml(config_yaml)
    task_name = configs['task_name']
    target_source_types = configs['train']['target_source_types']
    input_channels = configs['train']['channels']
    evaluation_audios_dir = os.path.join(workspace, "evaluation_audios", task_name)
    sample_rate = configs['train']['sample_rate']
    evaluate_step_frequency = configs['train']['evaluate_step_frequency']
    save_step_frequency = configs['train']['save_step_frequency']
    test_batch_size = configs['evaluate']['batch_size']
    test_segment_seconds = configs['evaluate']['segment_seconds']

    test_segment_samples = int(test_segment_seconds * sample_rate)

    # Only the single-target 'speech' source is supported for this task.
    assert len(target_source_types) == 1
    target_source_type = target_source_types[0]
    assert target_source_type == 'speech'

    # Save checkpoint callback.
    save_checkpoints_callback = SaveCheckpointsCallback(
        model=model,
        checkpoints_dir=checkpoints_dir,
        save_step_frequency=save_step_frequency,
    )

    # Statistics container.
    statistics_container = StatisticsContainer(statistics_path)

    # Evaluation callback.
    evaluate_test_callback = EvaluationCallback(
        model=model,
        input_channels=input_channels,
        sample_rate=sample_rate,
        evaluation_audios_dir=evaluation_audios_dir,
        segment_samples=test_segment_samples,
        batch_size=test_batch_size,
        device=evaluate_device,
        evaluate_step_frequency=evaluate_step_frequency,
        logger=logger,
        statistics_container=statistics_container,
    )

    callbacks = [save_checkpoints_callback, evaluate_test_callback]

    return callbacks


class EvaluationCallback(pl.Callback):
    def __init__(
        self,
        model: nn.Module,
        input_channels: int,
        evaluation_audios_dir: str,
        sample_rate: int,
        segment_samples: int,
        batch_size: int,
        device: str,
        evaluate_step_frequency: int,
        logger: pl.loggers.TensorBoardLogger,
        statistics_container: StatisticsContainer,
    ):
        r"""Callback to evaluate every #evaluate_step_frequency steps.
        Args:
            model: nn.Module
            input_channels: int
            evaluation_audios_dir: str, directory containing audios for evaluation
            sample_rate: int
            segment_samples: int, length of segments to be input to a model, e.g., 44100*30
            batch_size: int, e.g., 12
            device: str, e.g., 'cuda'
            evaluate_step_frequency: int, evaluate every #evaluate_step_frequency steps
            logger: pl.loggers.TensorBoardLogger
            statistics_container: StatisticsContainer
        """
        self.model = model
        self.mono = True
        self.sample_rate = sample_rate
        self.segment_samples = segment_samples
        self.evaluate_step_frequency = evaluate_step_frequency
        self.logger = logger
        self.statistics_container = statistics_container

        self.clean_dir = os.path.join(evaluation_audios_dir, "clean_testset_wav")
        self.noisy_dir = os.path.join(evaluation_audios_dir, "noisy_testset_wav")

        self.EVALUATION_SAMPLE_RATE = 16000  # Evaluation sample rate of the
        # Voicebank-Demand task.

        # Separator used to run the model on long audios segment by segment.
        self.separator = Separator(model, self.segment_samples, batch_size, device)

    @rank_zero_only
    def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn:
        r"""Evaluate the model on the Voicebank-Demand test set every
        #evaluate_step_frequency steps, computing PESQ, CSIG, CBAK, COVL, and
        SSNR. These metrics are used for observing training progress.
        """
        global_step = trainer.global_step

        if global_step % self.evaluate_step_frequency == 0:

            audio_names = sorted(
                [
                    audio_name
                    for audio_name in sorted(os.listdir(self.clean_dir))
                    if audio_name.endswith('.wav')
                ]
            )

            error_str = "Directory {} does not contain audios for evaluation!".format(
                self.clean_dir
            )
            assert len(audio_names) > 0, error_str

            pesqs, csigs, cbaks, covls, ssnrs = [], [], [], [], []

            logging.info("--- Step {} ---".format(global_step))
            logging.info("Total {} pieces for evaluation:".format(len(audio_names)))

            eval_time = time.time()

            for n, audio_name in enumerate(audio_names):

                # Load audio.
                clean_path = os.path.join(self.clean_dir, audio_name)
                mixture_path = os.path.join(self.noisy_dir, audio_name)

                mixture, _ = librosa.core.load(
                    mixture_path, sr=self.sample_rate, mono=self.mono
                )

                if mixture.ndim == 1:
                    mixture = mixture[None, :]
                # (channels_num, audio_length)

                # Separate.
                input_dict = {'waveform': mixture}

                sep_wav = self.separator.separate(input_dict)
                # (channels_num, audio_length)

                # Target.
                clean, _ = librosa.core.load(
                    clean_path, sr=self.EVALUATION_SAMPLE_RATE, mono=self.mono
                )

                # To mono.
                sep_wav = np.squeeze(sep_wav)

                # Resample for evaluation.
                sep_wav = librosa.resample(
                    sep_wav,
                    orig_sr=self.sample_rate,
                    target_sr=self.EVALUATION_SAMPLE_RATE,
                )

                sep_wav = librosa.util.fix_length(sep_wav, size=len(clean), axis=0)
                # (audio_length,)

                # Evaluate metrics.
                pesq_ = pesq(self.EVALUATION_SAMPLE_RATE, clean, sep_wav, 'wb')

                (csig, cbak, covl) = pysepm.composite(
                    clean, sep_wav, self.EVALUATION_SAMPLE_RATE
                )

                ssnr = pysepm.SNRseg(clean, sep_wav, self.EVALUATION_SAMPLE_RATE)

                pesqs.append(pesq_)
                csigs.append(csig)
                cbaks.append(cbak)
                covls.append(covl)
                ssnrs.append(ssnr)

                print(
                    '{}, {}, PESQ: {:.3f}, CSIG: {:.3f}, CBAK: {:.3f}, COVL: {:.3f}, SSNR: {:.3f}'.format(
                        n, audio_name, pesq_, csig, cbak, covl, ssnr
                    )
                )

            logging.info("-----------------------------")
            logging.info('Avg PESQ: {:.3f}'.format(np.mean(pesqs)))
            logging.info('Avg CSIG: {:.3f}'.format(np.mean(csigs)))
            logging.info('Avg CBAK: {:.3f}'.format(np.mean(cbaks)))
            logging.info('Avg COVL: {:.3f}'.format(np.mean(covls)))
            logging.info('Avg SSNR: {:.3f}'.format(np.mean(ssnrs)))
            logging.info("Evaluation time: {:.3f}".format(time.time() - eval_time))

            statistics = {"pesq": np.mean(pesqs)}
            self.statistics_container.append(global_step, statistics, 'test')
            self.statistics_container.dump()
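

# -----------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original training pipeline): how
# these callbacks might be constructed and handed to a pl.Trainer. The
# workspace layout, config path, and the placeholder model below are
# assumptions; in practice the model is built from the same config yaml, which
# must provide task_name, train.{target_source_types, channels, sample_rate,
# evaluate_step_frequency, save_step_frequency}, and
# evaluate.{batch_size, segment_seconds}.
if __name__ == '__main__':

    import torch

    class _PlaceholderModel(nn.Module):
        r"""Stand-in nn.Module; replace with an actual separation model."""

        def forward(self, input_dict):
            # Identity mapping, only to make the sketch self-contained.
            return {'waveform': input_dict['waveform']}

    _workspace = "./workspaces/voicebank_demand"  # hypothetical workspace path
    _config_yaml = "./configs/voicebank_demand.yaml"  # hypothetical config path

    _model = _PlaceholderModel()
    _logger = pl.loggers.TensorBoardLogger(save_dir=os.path.join(_workspace, "logs"))

    _callbacks = get_voicebank_demand_callbacks(
        config_yaml=_config_yaml,
        workspace=_workspace,
        checkpoints_dir=os.path.join(_workspace, "checkpoints"),
        statistics_path=os.path.join(_workspace, "statistics.pkl"),
        logger=_logger,
        model=_model,
        evaluate_device='cuda' if torch.cuda.is_available() else 'cpu',
    )

    # The callbacks would then be passed to a trainer, e.g.:
    # trainer = pl.Trainer(callbacks=_callbacks, ...)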