File size: 4,930 Bytes
9b2107c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import glob
import os
import random
from multiprocessing import Manager
from typing import List, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset


class WaveGradDataset(Dataset):
    """Dataset yielding (mel, audio) pairs for WaveGrad training.

    Each item of ``items`` is passed to ``ap.load_wav``; the acoustic feature
    is computed on the fly with ``ap.melspectrogram``. When
    ``return_segments`` is True a random ``seq_len``-sample window of the
    audio is returned, otherwise the full (hop-aligned) clip is returned.

    Args:
        ap: Audio processor object. Must provide ``load_wav(path)`` and
            ``melspectrogram(audio)``.
        items: Indexable collection of wav file paths.
        seq_len (int): Number of audio samples per returned segment. Must be
            a multiple of ``hop_len`` when ``return_segments`` is True.
        hop_len (int): Hop length; loaded audio is zero-padded up to the next
            multiple of it.
        pad_short (int): Extra samples added on top of ``seq_len`` when
            computing the minimum clip length in segment mode; shorter clips
            are zero-padded up to ``seq_len + pad_short``.
        conv_pad (int): Only used to compute ``feat_frame_len``.
        is_training (bool): Noise augmentation is applied only when True.
        return_segments (bool): Return random segments instead of full clips.
        use_noise_augment (bool): Add low-amplitude Gaussian noise to the
            audio segment (training + segment mode only).
        use_cache (bool): Keep loaded and padded audio in a
            ``multiprocessing.Manager`` list.
        verbose (bool): Stored on the instance; not used by this class itself.
    """

    def __init__(
        self,
        ap,
        items,
        seq_len,
        hop_len,
        pad_short,
        conv_pad=2,
        is_training=True,
        return_segments=True,
        use_noise_augment=False,
        use_cache=False,
        verbose=False,
    ):
        super().__init__()
        self.ap = ap
        self.item_list = items
        self.seq_len = seq_len if return_segments else None
        self.hop_len = hop_len
        self.pad_short = pad_short
        self.conv_pad = conv_pad
        self.is_training = is_training
        self.return_segments = return_segments
        self.use_cache = use_cache
        self.use_noise_augment = use_noise_augment
        self.verbose = verbose

        if return_segments:
            assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len."
        self.feat_frame_len = seq_len // hop_len + (2 * conv_pad)

        # cache acoustic features
        if use_cache:
            self.create_feature_cache()

    def create_feature_cache(self):
        """Create a Manager-backed list with one empty (None) slot per item."""
        self.manager = Manager()
        self.cache = self.manager.list([None] * len(self.item_list))

    @staticmethod
    def find_wav_files(path):
        """Recursively collect every ``*.wav`` file found under ``path``."""
        return glob.glob(os.path.join(path, "**", "*.wav"), recursive=True)

    def __len__(self):
        return len(self.item_list)

    def __getitem__(self, idx):
        item = self.load_item(idx)
        return item

    def load_test_samples(self, num_samples: int) -> List[Tuple]:
        """Return test samples.

        The first ``num_samples`` items are loaded as full clips
        (``return_segments`` is temporarily switched off).

        Args:
            num_samples (int): Number of samples to return.

        Returns:
            List[Tuple]: melspectrogram and audio.

        Shapes:
            - melspectrogram (Tensor): :math:`[C, T]`
            - audio (Tensor): :math:`[T_audio]`
        """
        samples = []
        return_segments = self.return_segments
        self.return_segments = False
        try:
            for idx in range(num_samples):
                mel, audio = self.load_item(idx)
                samples.append([mel, audio])
        finally:
            # Always restore the flag, even if loading an item failed, so the
            # dataset is not left stuck in full-clip mode.
            self.return_segments = return_segments
        return samples

    def load_item(self, idx):
        """Load the (mel, audio) couple for item ``idx``.

        Returns:
            Tuple[Tensor, Tensor]: float ``mel`` (leading dim squeezed) and
            float ``audio``.
        """
        # compute features from wav
        wavpath = self.item_list[idx]

        if self.use_cache and self.cache[idx] is not None:
            audio = self.cache[idx]
        else:
            audio = self.ap.load_wav(wavpath)

            if self.return_segments:
                # correct audio length wrt segment length
                min_len = self.seq_len + self.pad_short
                if audio.shape[-1] < min_len:
                    audio = np.pad(audio, (0, min_len - len(audio)), mode="constant", constant_values=0.0)
                assert audio.shape[-1] >= min_len, f"{audio.shape[-1]} vs {min_len}"

            # correct the audio length wrt hop length
            # (always pads at least one sample: a clip that is already a
            # multiple of hop_len gets one extra full hop)
            p = (audio.shape[-1] // self.hop_len + 1) * self.hop_len - audio.shape[-1]
            audio = np.pad(audio, (0, p), mode="constant", constant_values=0.0)

            if self.use_cache:
                self.cache[idx] = audio

        if self.return_segments:
            # pick a random window of seq_len samples
            max_start = len(audio) - self.seq_len
            start = random.randint(0, max_start)
            end = start + self.seq_len
            audio = audio[start:end]

        if self.use_noise_augment and self.is_training and self.return_segments:
            # `audio` is still a NumPy array at this point, so the noise has to
            # be drawn with NumPy (torch.randn_like only accepts Tensors).
            noise = np.random.standard_normal(audio.shape)
            if np.issubdtype(audio.dtype, np.floating):
                # keep the audio dtype instead of silently upcasting to float64
                noise = noise.astype(audio.dtype, copy=False)
            audio = audio + (1 / 32768) * noise

        mel = self.ap.melspectrogram(audio)
        mel = mel[..., :-1]  # ignore the padding

        audio = torch.from_numpy(audio).float()
        mel = torch.from_numpy(mel).float().squeeze(0)
        return (mel, audio)

    @staticmethod
    def collate_full_clips(batch):
        """This is used in tune_wavegrad.py.
        It pads sequences to the max length.

        Args:
            batch: Sequence of ``(mel, audio)`` pairs with ``mel`` of shape
                ``[C, T]`` and ``audio`` of shape ``[T_audio]``.

        Returns:
            Tuple[Tensor, Tensor]: zero-padded ``mels`` ``[B, C, T_max]`` and
            ``audios`` ``[B, T_audio_max]``.
        """
        max_mel_length = max(b[0].shape[1] for b in batch)
        max_audio_length = max(b[1].shape[0] for b in batch)

        mels = torch.zeros([len(batch), batch[0][0].shape[0], max_mel_length])
        audios = torch.zeros([len(batch), max_audio_length])

        for idx, (mel, audio) in enumerate(batch):
            mels[idx, :, : mel.shape[1]] = mel
            audios[idx, : audio.shape[0]] = audio

        return mels, audios