File size: 4,757 Bytes
5019931
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from typing import Dict, List, Tuple

import torch


class BasicBatchDataPreprocessor:
    def __init__(self, target_source_types: List[str]):
        r"""Batch data preprocessor. Used for preparing mixtures and targets
        for training. If there are multiple target source types, the waveforms
        of those sources will be stacked along the channel dimension.

        Args:
            target_source_types: List[str], e.g., ['vocals', 'bass', ...]
        """
        self.target_source_types = target_source_types

    def __call__(self, batch_data_dict: Dict) -> Tuple[Dict, Dict]:
        r"""Format waveforms and targets for training.

        Args:
            batch_data_dict: dict, e.g., {
                'mixture': (batch_size, channels_num, segment_samples),
                'vocals': (batch_size, channels_num, segment_samples),
                'bass': (batch_size, channels_num, segment_samples),
                ...,
            }

        Returns:
            input_dict: dict, e.g., {
                'waveform': (batch_size, channels_num, segment_samples),
            }
            target_dict: dict, e.g., {
                'waveform': (batch_size, target_sources_num * channels_num, segment_samples)
            }
        """
        mixtures = batch_data_dict['mixture']
        # mixtures: (batch_size, channels_num, segment_samples)

        # Concatenate waveforms of multiple targets along the channel axis,
        # in the order given by self.target_source_types.
        targets = torch.cat(
            [batch_data_dict[source_type] for source_type in self.target_source_types],
            dim=1,
        )
        # targets: (batch_size, target_sources_num * channels_num, segment_samples)

        input_dict = {'waveform': mixtures}
        target_dict = {'waveform': targets}

        return input_dict, target_dict


class ConditionalSisoBatchDataPreprocessor:
    def __init__(self, target_source_types: List[str]):
        r"""Conditional single input single output (SISO) batch data
        preprocessor. Select one target source from several target sources as
        training target and prepare the corresponding conditional vector.

        Args:
            target_source_types: List[str], e.g., ['vocals', 'bass', ...]
        """
        self.target_source_types = target_source_types

    def __call__(self, batch_data_dict: Dict) -> Tuple[Dict, Dict]:
        r"""Format waveforms and targets for training.

        Args:
            batch_data_dict: dict, e.g., {
                'mixture': (batch_size, channels_num, segment_samples),
                'vocals': (batch_size, channels_num, segment_samples),
                'bass': (batch_size, channels_num, segment_samples),
                ...,
            }

        Returns:
            input_dict: dict, e.g., {
                'waveform': (batch_size, channels_num, segment_samples),
                'condition': (batch_size, target_sources_num),
            }
            target_dict: dict, e.g., {
                'waveform': (batch_size, channels_num, segment_samples)
            }

        Raises:
            ValueError: if the batch size is not evenly divisible by the
                number of target source types.
        """
        batch_size = len(batch_data_dict['mixture'])
        target_sources_num = len(self.target_source_types)

        # Explicit exception instead of `assert`, so the check is not
        # stripped when Python runs with the -O flag.
        if batch_size % target_sources_num != 0:
            raise ValueError(
                "Batch size should be evenly divided by target sources number."
            )

        mixtures = batch_data_dict['mixture']
        # mixtures: (batch_size, channels_num, segment_samples)

        # Allocate the one-hot condition matrix directly on the mixtures'
        # device to avoid an extra host allocation followed by a copy.
        conditions = torch.zeros(
            batch_size, target_sources_num, device=mixtures.device
        )
        # conditions: (batch_size, target_sources_num)

        targets = []

        # The n-th example in the batch gets the (n % target_sources_num)-th
        # source as its training target, cycling through all source types.
        for n in range(batch_size):

            k = n % target_sources_num  # source class index
            source_type = self.target_source_types[k]

            targets.append(batch_data_dict[source_type][n])

            conditions[n, k] = 1

        # conditions will look like:
        # [[1, 0, 0, 0],
        #  [0, 1, 0, 0],
        #  [0, 0, 1, 0],
        #  [0, 0, 0, 1],
        #  [1, 0, 0, 0],
        #  [0, 1, 0, 0],
        #  ...,
        # ]

        targets = torch.stack(targets, dim=0)
        # targets: (batch_size, channels_num, segment_samples)

        input_dict = {
            'waveform': mixtures,
            'condition': conditions,
        }

        target_dict = {'waveform': targets}

        return input_dict, target_dict


def get_batch_data_preprocessor_class(batch_data_preprocessor_type: str) -> object:
    r"""Get batch data preprocessor class by its type name.

    Args:
        batch_data_preprocessor_type: str, e.g., 'BasicBatchDataPreprocessor'

    Returns:
        The batch data preprocessor class (not an instance).

    Raises:
        NotImplementedError: if the type name is not recognized.
    """
    if batch_data_preprocessor_type == 'BasicBatchDataPreprocessor':
        return BasicBatchDataPreprocessor

    elif batch_data_preprocessor_type == 'ConditionalSisoBatchDataPreprocessor':
        return ConditionalSisoBatchDataPreprocessor

    else:
        # Include the offending name so misconfiguration is easy to diagnose.
        raise NotImplementedError(
            "Unknown batch data preprocessor type: {}".format(
                batch_data_preprocessor_type
            )
        )