Spaces:
Runtime error
Runtime error
File size: 2,940 Bytes
6bc94ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
from encoder.data_objects.random_cycler import RandomCycler
from encoder.data_objects.speaker_batch import SpeakerBatch
from encoder.data_objects.utterance_batch import UtteranceBatch
from encoder.data_objects.speaker import Speaker
from encoder.params_data import partials_n_frames
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from os import listdir
from os.path import isfile
import numpy as np
# TODO: improve with a pool of speakers for data efficiency
class Train_Dataset(Dataset):
    """Training dataset that serves speakers drawn from a RandomCycler.

    Every sub-directory directly under ``datasets_root`` is wrapped in a
    ``Speaker``. The requested index is ignored when fetching an item: each
    access simply advances the internal speaker cycler.
    """

    def __init__(self, datasets_root: Path):
        """Collect the speaker directories found directly under ``datasets_root``.

        Raises:
            Exception: if ``datasets_root`` contains no sub-directory.
        """
        self.root = datasets_root
        found_dirs = [entry for entry in self.root.glob("*") if entry.is_dir()]
        if not found_dirs:
            raise Exception("No speakers found. Make sure you are pointing to the directory "
                            "containing all preprocessed speaker directories.")
        self.speakers = list(map(Speaker, found_dirs))
        self.speaker_cycler = RandomCycler(self.speakers)

    def __len__(self):
        """Return a fixed nominal length (1e8), independent of the speaker count."""
        return 100_000_000

    def __getitem__(self, index):
        """Return the next speaker from the cycler; ``index`` is not used."""
        cycler = self.speaker_cycler
        return next(cycler)

    def get_logs(self):
        """Return the concatenated contents of every ``*.txt`` file in the root."""
        pieces = []
        for txt_path in self.root.glob("*.txt"):
            with txt_path.open("r") as handle:
                pieces.append(handle.read())
        return "".join(pieces)
class Dev_Dataset(Dataset):
    """Development dataset whose length equals the number of speakers found.

    Every sub-directory directly under ``datasets_root`` is wrapped in a
    ``Speaker``. Items are taken from a RandomCycler over those speakers, so
    the index passed to ``__getitem__`` does not select a particular speaker.
    """

    def __init__(self, datasets_root: Path):
        """Collect the speaker directories found directly under ``datasets_root``.

        Raises:
            Exception: if ``datasets_root`` contains no sub-directory.
        """
        self.root = datasets_root

        directories = []
        for candidate in self.root.glob("*"):
            if candidate.is_dir():
                directories.append(candidate)

        # Fail early: an empty root almost certainly means a wrong path.
        if not directories:
            raise Exception("No speakers found. Make sure you are pointing to the directory "
                            "containing all preprocessed speaker directories.")

        self.speakers = [Speaker(directory) for directory in directories]
        self.speaker_cycler = RandomCycler(self.speakers)

    def __len__(self):
        """Return the number of speakers loaded from the root directory."""
        speaker_count = len(self.speakers)
        return speaker_count

    def __getitem__(self, index):
        """Return the next speaker from the cycler; ``index`` is not used."""
        upcoming = next(self.speaker_cycler)
        return upcoming
class DataLoader(DataLoader):
    """torch ``DataLoader`` subclass that collates speakers into a ``SpeakerBatch``.

    ``speakers_per_batch`` is forwarded as the parent's ``batch_size``,
    ``drop_last`` is always ``False`` and the collate function is fixed to
    :meth:`collate`. The remaining arguments are passed through unchanged.
    """

    def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, shuffle, sampler=None,
                 batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
                 worker_init_fn=None):
        # Stored before the parent constructor runs, as collate() reads it.
        self.utterances_per_speaker = utterances_per_speaker

        loader_options = {
            "dataset": dataset,
            "batch_size": speakers_per_batch,
            "shuffle": shuffle,
            "sampler": sampler,
            "batch_sampler": batch_sampler,
            "num_workers": num_workers,
            "collate_fn": self.collate,
            "pin_memory": pin_memory,
            "drop_last": False,
            "timeout": timeout,
            "worker_init_fn": worker_init_fn,
        }
        super().__init__(**loader_options)

    def collate(self, speakers):
        """Build a ``SpeakerBatch`` from the speakers gathered for one batch."""
        per_speaker = self.utterances_per_speaker
        return SpeakerBatch(speakers, per_speaker, partials_n_frames)
|