File size: 4,446 Bytes
45ee559
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os

import numpy as np
from torch.utils.data import DataLoader

from tests import get_tests_output_path, get_tests_path
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.configs import BaseGANVocoderConfig
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.datasets.preprocess import load_wav_data

# Directory that contains this test module.
file_path = os.path.dirname(os.path.realpath(__file__))
# Output folder for the loader tests; created at import time if it does not exist yet.
OUTPATH = os.path.join(get_tests_output_path(), "loader_tests/")
os.makedirs(OUTPATH, exist_ok=True)

# Default GAN vocoder config; its `audio` section configures the AudioProcessor
# and supplies the hop length used to build the test parameters below.
C = BaseGANVocoderConfig()

# Location of the LJSpeech sample data under the tests path, and whether it is present on disk.
test_data_path = os.path.join(get_tests_path(), "data/ljspeech/")
ok_ljspeech = os.path.exists(test_data_path)


def gan_dataset_case(
    batch_size, seq_len, hop_len, conv_pad, return_pairs, return_segments, use_noise_augment, use_cache, num_workers
):
    """Build a ``GANDataset`` loader with the given options and assert on the batches it yields.

    At most ``max_iter`` (10) batches are inspected, whatever the mode.

    Args:
        batch_size: Batch size passed to the ``DataLoader``.
        seq_len: Segment length in samples, forwarded to ``GANDataset``.
        hop_len: Hop length in samples, forwarded to ``GANDataset``.
        conv_pad: Forwarded to ``GANDataset``; the checks expect ``conv_pad`` extra
            feature frames on each side of a segment.
        return_pairs: Forwarded to ``GANDataset``. Together with ``return_segments``
            each batch is unpacked as two ``(feat, wav)`` items.
        return_segments: Forwarded to ``GANDataset``. Selects between the fixed-length
            segment checks and the whole-audio shape checks.
        use_noise_augment: Forwarded to ``GANDataset``. When truthy, the feature vs.
            recomputed mel-spectrogram comparison is skipped.
        use_cache: Forwarded to ``GANDataset``.
        num_workers: Number of ``DataLoader`` worker processes.

    Raises:
        AssertionError: If a batch has an unexpected shape or the features do not
            match the mel-spectrogram recomputed from the returned audio.
    """
    ap = AudioProcessor(**C.audio)
    _, train_items = load_wav_data(test_data_path, 10)
    dataset = GANDataset(
        ap,
        train_items,
        seq_len=seq_len,
        hop_len=hop_len,
        pad_short=2000,
        conv_pad=conv_pad,
        return_pairs=return_pairs,
        return_segments=return_segments,
        use_noise_augment=use_noise_augment,
        use_cache=use_cache,
    )
    loader = DataLoader(
        dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True, drop_last=True
    )

    # upper bound on the number of batches checked in every mode
    max_iter = 10

    def check_item(feat, wav):
        """Check one batch of features against its batch of waveforms"""
        feat = feat.numpy()
        wav = wav.numpy()
        expected_feat_shape = (batch_size, ap.num_mels, seq_len // hop_len + conv_pad * 2)

        # check shapes
        assert np.all(feat.shape == expected_feat_shape), f" [!] {feat.shape} vs {expected_feat_shape}"
        assert (feat.shape[2] - conv_pad * 2) * hop_len == wav.shape[2]

        # the audio is perturbed by noise augmentation, so features cannot be
        # reproduced from it
        if use_noise_augment:
            return

        # check feature vs audio match
        for idx in range(batch_size):
            audio = wav[idx].squeeze()
            # keep the batch array intact: rebinding `feat` here would make the
            # next iteration index into a single sample instead of the batch
            sample_feat = feat[idx]
            mel = ap.melspectrogram(audio)
            # the first 2 and the last 2 frames are skipped due to the padding
            # differences in stft
            max_diff = abs((sample_feat - mel[:, : sample_feat.shape[-1]])[:, 2:-2]).max()
            assert max_diff <= 1e-6, f" [!] {max_diff}"

    # `zip(range(max_iter), loader)` stops after `max_iter` batches without
    # fetching an extra one from the loader.
    if return_segments and return_pairs:
        # random segments, two items per batch
        for _, (item1, item2) in zip(range(max_iter), loader):
            feat1, wav1 = item1
            feat2, wav2 = item2
            check_item(feat1, wav1)
            check_item(feat2, wav2)
    elif return_segments:
        # random segments, one item per batch
        for _, item in zip(range(max_iter), loader):
            feat1, wav1 = item
            check_item(feat1, wav1)
    else:
        # whole audio: only the feature/waveform length relation is checked
        for _, item in zip(range(max_iter), loader):
            feat, wav = item
            expected_feat_shape = (batch_size, ap.num_mels, (wav.shape[-1] // hop_len) + (conv_pad * 2))
            assert np.all(feat.shape == expected_feat_shape), f" [!] {feat.shape} vs {expected_feat_shape}"
            assert (feat.shape[2] - conv_pad * 2) * hop_len == wav.shape[2]


def test_parametrized_gan_dataset():
    """Run ``gan_dataset_case`` over a fixed table of loader configurations."""
    hop = C.audio["hop_length"]
    ten_hops = hop * 10

    # columns: batch_size, seq_len, hop_len, conv_pad, return_pairs,
    #          return_segments, use_noise_augment, use_cache, num_workers
    cases = [
        [32, ten_hops, hop, 0, True, True, False, True, 0],
        [32, ten_hops, hop, 0, True, True, False, True, 4],
        [1, ten_hops, hop, 0, True, True, True, True, 0],
        [1, hop, hop, 0, True, True, True, True, 0],
        [1, ten_hops, hop, 2, True, True, True, True, 0],
        [1, ten_hops, hop, 0, True, False, True, True, 0],
        [1, ten_hops, hop, 0, True, True, False, True, 0],
        [1, ten_hops, hop, 0, False, True, True, False, 0],
        [1, ten_hops, hop, 0, True, False, False, False, 0],
        [1, ten_hops, hop, 0, True, False, False, False, 0],
    ]
    for case in cases:
        # echo the configuration so a failing case can be identified in the test log
        print(case)
        gan_dataset_case(*case)