File size: 1,987 Bytes
66a6dc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import glob
import torch
import torchaudio
from tqdm import tqdm


class StyleDataset(torch.utils.data.Dataset):
    """In-memory dataset of audio clips labelled by style.

    The expected layout on disk is ``<audio_dir>/<subset>/<style>/*.wav``.
    Every file is loaded eagerly in ``__init__``: it is averaged down to a
    single channel, resampled to ``sample_rate`` when necessary, and cropped
    to at most ``length`` samples.

    Notes:
        * Clips shorter than ``length`` are kept at their own length (they are
          not padded), so examples may differ in size.
        * ``num_classes`` is the number of style directories found on disk,
          which is not necessarily ``len(class_labels)``.
    """

    def __init__(
        self,
        audio_dir: str,
        subset: str = "train",
        sample_rate: int = 24000,
        length: int = 131072,
    ) -> None:
        """Scan the subset directory and load every clip into memory.

        Args:
            audio_dir: Root directory containing one folder per subset.
            subset: Name of the subset folder inside ``audio_dir``.
            sample_rate: Sample rate all clips are converted to.
            length: Maximum number of samples kept per clip.

        Raises:
            KeyError: If a style directory that contains ``.wav`` files has a
                name that is not a key of ``class_labels``.
        """
        super().__init__()
        self.audio_dir = audio_dir
        self.subset = subset
        self.sample_rate = sample_rate
        self.length = length

        # glob yields paths in arbitrary order; sort so that example indices
        # are reproducible across runs and machines.
        self.style_dirs = sorted(
            sd
            for sd in glob.glob(os.path.join(audio_dir, subset, "*"))
            if os.path.isdir(sd)
        )
        self.num_classes = len(self.style_dirs)
        self.class_labels = {"broadcast" : 0, "telephone": 1, "neutral": 2, "bright": 3, "warm": 4}

        # One resampler per source rate, built lazily and reused across files.
        resamplers = {}

        self.examples = []
        for style_dir in self.style_dirs:
            style_name = os.path.basename(style_dir)

            # get all files in style dir
            style_filepaths = sorted(glob.glob(os.path.join(style_dir, "*.wav")))
            if not style_filepaths:
                # nothing to load (and nothing to label) for this directory
                continue

            # Fail early with an actionable message instead of a bare KeyError
            # raised only after the first file has already been decoded.
            if style_name not in self.class_labels:
                raise KeyError(
                    f"Unknown style directory '{style_name}' in {style_dir}; "
                    f"expected one of {sorted(self.class_labels)}"
                )
            label = self.class_labels[style_name]

            for style_filepath in tqdm(style_filepaths, ncols=120):
                # load audio file
                x, sr = torchaudio.load(style_filepath)

                # average the channels down to mono if needed
                # (treats dim 0 as the channel dimension)
                if x.shape[0] > 1:
                    x = x.mean(dim=0, keepdim=True)

                # resample
                if sr != self.sample_rate:
                    resampler = resamplers.get(sr)
                    if resampler is None:
                        resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
                        resamplers[sr] = resampler
                    x = resampler(x)

                # crop length after resample; slicing is a no-op for clips
                # that are already shorter than ``length``
                x = x[..., :self.length]

                # store example
                self.examples.append((x, label))

        print(f"Loaded {len(self.examples)} examples for {subset} subset.")

    def __len__(self) -> int:
        """Return the number of loaded examples."""
        return len(self.examples)

    def __getitem__(self, idx: int):
        """Return the ``(waveform, label)`` pair stored at position ``idx``."""
        x, y = self.examples[idx]
        return x, y