from typing import Dict, List, Tuple

import torch


class BasicBatchDataPreprocessor:
    def __init__(self, target_source_types: List[str]):
        r"""Batch data preprocessor. Used for preparing mixtures and targets
        for training. If there are multiple target source types, the
        waveforms of those sources will be stacked along the channel
        dimension.

        Args:
            target_source_types: List[str], e.g., ['vocals', 'bass', ...]
        """
        self.target_source_types = target_source_types

    def __call__(self, batch_data_dict: Dict) -> Tuple[Dict, Dict]:
        r"""Format waveforms and targets for training.

        Args:
            batch_data_dict: dict, e.g., {
                'mixture': (batch_size, channels_num, segment_samples),
                'vocals': (batch_size, channels_num, segment_samples),
                'bass': (batch_size, channels_num, segment_samples),
                ...,
            }

        Returns:
            input_dict: dict, e.g., {
                'waveform': (batch_size, channels_num, segment_samples),
            }
            target_dict: dict, e.g., {
                'waveform': (batch_size, target_sources_num * channels_num, segment_samples)
            }
        """
        mixtures = batch_data_dict['mixture']
        # mixtures: (batch_size, channels_num, segment_samples)

        # Concatenate waveforms of multiple targets along the channel axis.
        targets = torch.cat(
            [batch_data_dict[source_type] for source_type in self.target_source_types],
            dim=1,
        )
        # targets: (batch_size, target_sources_num * channels_num, segment_samples)

        input_dict = {'waveform': mixtures}
        target_dict = {'waveform': targets}

        return input_dict, target_dict


class ConditionalSisoBatchDataPreprocessor:
    def __init__(self, target_source_types: List[str]):
        r"""Conditional single input single output (SISO) batch data
        preprocessor. Select one target source from several target sources as
        training target and prepare the corresponding conditional vector.

        Args:
            target_source_types: List[str], e.g., ['vocals', 'bass', ...]
        """
        self.target_source_types = target_source_types

    def __call__(self, batch_data_dict: Dict) -> Tuple[Dict, Dict]:
        r"""Format waveforms and targets for training.

        The n-th example of the batch is assigned the source type with index
        ``n % target_sources_num``, so source types are cycled through the
        batch in order.

        Args:
            batch_data_dict: dict, e.g., {
                'mixture': (batch_size, channels_num, segment_samples),
                'vocals': (batch_size, channels_num, segment_samples),
                'bass': (batch_size, channels_num, segment_samples),
                ...,
            }

        Returns:
            input_dict: dict, e.g., {
                'waveform': (batch_size, channels_num, segment_samples),
                'condition': (batch_size, target_sources_num),
            }
            target_dict: dict, e.g., {
                'waveform': (batch_size, channels_num, segment_samples)
            }

        Raises:
            AssertionError: if the batch size is not a multiple of the number
                of target source types.
        """
        mixtures = batch_data_dict['mixture']
        # mixtures: (batch_size, channels_num, segment_samples)

        batch_size = len(mixtures)
        target_sources_num = len(self.target_source_types)

        assert (
            batch_size % target_sources_num == 0
        ), "Batch size should be evenly divided by target sources number."

        device = mixtures.device

        # Source class index of every example: 0, 1, ..., K - 1, 0, 1, ...
        example_indices = torch.arange(batch_size, device=device)
        source_indices = example_indices % target_sources_num

        # Allocate directly on the target device and fill the one-hot rows
        # with a single indexed assignment instead of one write per example.
        conditions = torch.zeros(batch_size, target_sources_num, device=device)
        conditions[example_indices, source_indices] = 1
        # conditions: (batch_size, target_sources_num)
        # conditions will look like:
        # [[1, 0, 0, 0],
        #  [0, 1, 0, 0],
        #  [0, 0, 1, 0],
        #  [0, 0, 0, 1],
        #  [1, 0, 0, 0],
        #  [0, 1, 0, 0],
        #  ...,
        # ]

        # Pick, for every example, the waveform of its assigned source type.
        targets = torch.stack(
            [
                batch_data_dict[self.target_source_types[n % target_sources_num]][n]
                for n in range(batch_size)
            ],
            dim=0,
        )
        # targets: (batch_size, channels_num, segment_samples)

        input_dict = {
            'waveform': mixtures,
            'condition': conditions,
        }

        target_dict = {'waveform': targets}

        return input_dict, target_dict


def get_batch_data_preprocessor_class(batch_data_preprocessor_type: str) -> type:
    r"""Get batch data preprocessor class.

    Args:
        batch_data_preprocessor_type: str, name of the preprocessor class,
            e.g., 'BasicBatchDataPreprocessor'

    Returns:
        The preprocessor class (not an instance) with the given name.

    Raises:
        NotImplementedError: if the name does not match a known preprocessor.
    """
    if batch_data_preprocessor_type == 'BasicBatchDataPreprocessor':
        return BasicBatchDataPreprocessor

    elif batch_data_preprocessor_type == 'ConditionalSisoBatchDataPreprocessor':
        return ConditionalSisoBatchDataPreprocessor

    else:
        raise NotImplementedError(
            f"Unknown batch data preprocessor type: {batch_data_preprocessor_type!r}"
        )