from typing import List

from coqpit import Coqpit
from torch.utils.data import Dataset

from TTS.utils.audio import AudioProcessor
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset


def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool) -> Dataset:
    """Return the vocoder `Dataset` that matches `config.model`."""
    # GAN-family vocoders (e.g. melgan, hifigan) all share `GANDataset`.
    if "gan" in config.model.lower():
        dataset = GANDataset(
            ap=ap,
            items=data_items,
            seq_len=config.seq_len,
            hop_len=ap.hop_length,
            pad_short=config.pad_short,
            conv_pad=config.conv_pad,
            return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False,
            is_training=not is_eval,
            return_segments=not is_eval,
            use_noise_augment=config.use_noise_augment,
            use_cache=config.use_cache,
            verbose=verbose,
        )
        dataset.shuffle_mapping()
    elif config.model.lower() == "wavegrad":
        dataset = WaveGradDataset(
            ap=ap,
            items=data_items,
            seq_len=config.seq_len,
            hop_len=ap.hop_length,
            pad_short=config.pad_short,
            conv_pad=config.conv_pad,
            is_training=not is_eval,
            return_segments=True,
            use_noise_augment=False,
            use_cache=config.use_cache,
            verbose=verbose,
        )
    elif config.model.lower() == "wavernn":
        dataset = WaveRNNDataset(
            ap=ap,
            items=data_items,
            seq_len=config.seq_len,
            hop_len=ap.hop_length,
            pad=config.model_params.pad,
            mode=config.model_params.mode,
            mulaw=config.model_params.mulaw,
            is_training=not is_eval,
            verbose=verbose,
        )
    else:
        raise ValueError(f" [!] Dataset for model {config.model.lower()} cannot be found.")
    return dataset
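

# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module). It assumes a
# WaveGrad-style vocoder config exposing `data_path`, `eval_split_size`, and
# `batch_size`; the data path is a placeholder.
#
#   from torch.utils.data import DataLoader
#   from TTS.vocoder.configs import WavegradConfig
#
#   config = WavegradConfig(data_path="/data/wavs/")
#   ap = AudioProcessor(**config.audio)
#   # load_wav_data splits the discovered wav files into eval/train items
#   eval_items, train_items = load_wav_data(config.data_path, config.eval_split_size)
#   train_set = setup_dataset(config, ap, is_eval=False, data_items=train_items, verbose=True)
#   loader = DataLoader(train_set, batch_size=config.batch_size, shuffle=True)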