# bytesep/data/batch_data_preprocessors.py
from typing import Dict, List, Tuple

import torch


class BasicBatchDataPreprocessor:
    def __init__(self, target_source_types: List[str]):
        r"""Batch data preprocessor. Used for preparing mixtures and targets for
        training. If there are multiple target source types, the waveforms of
        those sources will be stacked along the channel dimension.

        Args:
            target_source_types: List[str], e.g., ['vocals', 'bass', ...]
        """
        self.target_source_types = target_source_types

    def __call__(self, batch_data_dict: Dict) -> Tuple[Dict, Dict]:
        r"""Format waveforms and targets for training.

        Args:
            batch_data_dict: dict, e.g., {
                'mixture': (batch_size, channels_num, segment_samples),
                'vocals': (batch_size, channels_num, segment_samples),
                'bass': (batch_size, channels_num, segment_samples),
                ...,
            }

        Returns:
            input_dict: dict, e.g., {
                'waveform': (batch_size, channels_num, segment_samples),
            }
            target_dict: dict, e.g., {
                'waveform': (batch_size, target_sources_num * channels_num, segment_samples),
            }
        """
        mixtures = batch_data_dict['mixture']
        # mixtures: (batch_size, channels_num, segment_samples)

        # Concatenate waveforms of multiple targets along the channel axis.
        targets = torch.cat(
            [batch_data_dict[source_type] for source_type in self.target_source_types],
            dim=1,
        )
        # targets: (batch_size, target_sources_num * channels_num, segment_samples)

        input_dict = {'waveform': mixtures}
        target_dict = {'waveform': targets}

        return input_dict, target_dict
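

# Illustrative usage sketch for BasicBatchDataPreprocessor: it passes the mixture
# through unchanged and stacks the requested target waveforms along the channel
# axis. The shapes and source names below are arbitrary example assumptions.
def _demo_basic_batch_data_preprocessor():
    batch_size, channels_num, segment_samples = 4, 2, 44100  # assumed example sizes
    batch_data_dict = {
        'mixture': torch.randn(batch_size, channels_num, segment_samples),
        'vocals': torch.randn(batch_size, channels_num, segment_samples),
        'bass': torch.randn(batch_size, channels_num, segment_samples),
    }

    preprocessor = BasicBatchDataPreprocessor(target_source_types=['vocals', 'bass'])
    input_dict, target_dict = preprocessor(batch_data_dict)

    # The mixture keeps its shape; the two targets are concatenated on dim=1.
    assert input_dict['waveform'].shape == (batch_size, channels_num, segment_samples)
    assert target_dict['waveform'].shape == (batch_size, 2 * channels_num, segment_samples)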


class ConditionalSisoBatchDataPreprocessor:
    def __init__(self, target_source_types: List[str]):
        r"""Conditional single-input single-output (SISO) batch data
        preprocessor. Selects one target source from the target sources as the
        training target for each example and prepares the corresponding
        conditional vector.

        Args:
            target_source_types: List[str], e.g., ['vocals', 'bass', ...]
        """
        self.target_source_types = target_source_types

    def __call__(self, batch_data_dict: Dict) -> Tuple[Dict, Dict]:
        r"""Format waveforms and targets for training.

        Args:
            batch_data_dict: dict, e.g., {
                'mixture': (batch_size, channels_num, segment_samples),
                'vocals': (batch_size, channels_num, segment_samples),
                'bass': (batch_size, channels_num, segment_samples),
                ...,
            }

        Returns:
            input_dict: dict, e.g., {
                'waveform': (batch_size, channels_num, segment_samples),
                'condition': (batch_size, target_sources_num),
            }
            target_dict: dict, e.g., {
                'waveform': (batch_size, channels_num, segment_samples),
            }
        """
        batch_size = len(batch_data_dict['mixture'])
        target_sources_num = len(self.target_source_types)

        assert (
            batch_size % target_sources_num == 0
        ), "Batch size must be divisible by the number of target sources."

        mixtures = batch_data_dict['mixture']
        # mixtures: (batch_size, channels_num, segment_samples)

        conditions = torch.zeros(batch_size, target_sources_num).to(mixtures.device)
        # conditions: (batch_size, target_sources_num)

        targets = []

        for n in range(batch_size):
            k = n % target_sources_num  # source class index
            source_type = self.target_source_types[k]

            targets.append(batch_data_dict[source_type][n])
            conditions[n, k] = 1

        # conditions will look like:
        # [[1, 0, 0, 0],
        #  [0, 1, 0, 0],
        #  [0, 0, 1, 0],
        #  [0, 0, 0, 1],
        #  [1, 0, 0, 0],
        #  [0, 1, 0, 0],
        #  ...,
        # ]

        targets = torch.stack(targets, dim=0)
        # targets: (batch_size, channels_num, segment_samples)

        input_dict = {
            'waveform': mixtures,
            'condition': conditions,
        }
        target_dict = {'waveform': targets}

        return input_dict, target_dict
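

# Illustrative usage sketch for ConditionalSisoBatchDataPreprocessor: each example
# in the batch is assigned one target source in round-robin order, together with a
# matching one-hot condition vector. Shapes and source names are example assumptions.
def _demo_conditional_siso_batch_data_preprocessor():
    batch_size, channels_num, segment_samples = 4, 2, 44100  # assumed example sizes
    batch_data_dict = {
        'mixture': torch.randn(batch_size, channels_num, segment_samples),
        'vocals': torch.randn(batch_size, channels_num, segment_samples),
        'bass': torch.randn(batch_size, channels_num, segment_samples),
    }

    preprocessor = ConditionalSisoBatchDataPreprocessor(target_source_types=['vocals', 'bass'])
    input_dict, target_dict = preprocessor(batch_data_dict)

    # Conditions cycle through the sources, e.g., [[1, 0], [0, 1], [1, 0], [0, 1]].
    assert input_dict['condition'].shape == (batch_size, 2)
    # Each example keeps a single-source target, so the channel count is unchanged.
    assert target_dict['waveform'].shape == (batch_size, channels_num, segment_samples)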


def get_batch_data_preprocessor_class(batch_data_preprocessor_type: str) -> object:
    r"""Get batch data preprocessor class."""
    if batch_data_preprocessor_type == 'BasicBatchDataPreprocessor':
        return BasicBatchDataPreprocessor

    elif batch_data_preprocessor_type == 'ConditionalSisoBatchDataPreprocessor':
        return ConditionalSisoBatchDataPreprocessor

    else:
        raise NotImplementedError(
            "Unsupported batch_data_preprocessor_type: {}".format(
                batch_data_preprocessor_type
            )
        )
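

# Illustrative usage sketch for the factory: the preprocessor type string would
# normally come from a training configuration; the value and source names below
# are arbitrary example assumptions.
def _demo_get_batch_data_preprocessor_class():
    preprocessor_class = get_batch_data_preprocessor_class(
        'ConditionalSisoBatchDataPreprocessor'
    )
    preprocessor = preprocessor_class(
        target_source_types=['vocals', 'bass', 'drums', 'other']
    )
    # The returned preprocessor expects batch sizes divisible by the number of
    # target sources (4 in this example).
    return preprocessor


if __name__ == '__main__':
    # Run the illustrative demos above with small random batches.
    _demo_basic_batch_data_preprocessor()
    _demo_conditional_siso_batch_data_preprocessor()
    _demo_get_batch_data_preprocessor_class()
    print('All demo checks passed.')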