# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torchaudio
import json
import os
import numpy as np
import librosa
from torch.nn.utils.rnn import pad_sequence

from modules import whisper_extractor as whisper


class TorchaudioDataset(torch.utils.data.Dataset):
    """Waveform dataset that loads audio with torchaudio, resamples it to
    ``sr``, and downmixes multi-channel audio to mono."""

    def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None):
        """
        Args:
            cfg: config
            dataset: dataset name
        """
        assert isinstance(dataset, str)

        self.sr = sr
        self.cfg = cfg

        if metadata is None:
            self.train_metadata_path = os.path.join(
                cfg.preprocess.processed_dir, dataset, cfg.preprocess.train_file
            )
            self.valid_metadata_path = os.path.join(
                cfg.preprocess.processed_dir, dataset, cfg.preprocess.valid_file
            )
            self.metadata = self.get_metadata()
        else:
            self.metadata = metadata

        # resolve the target device: prefer the accelerator's, then CUDA, then CPU
        if accelerator is not None:
            self.device = accelerator.device
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

    def get_metadata(self):
        # merge the train and valid metadata JSON lists into one index
        metadata = []
        with open(self.train_metadata_path, "r", encoding="utf-8") as t:
            metadata.extend(json.load(t))
        with open(self.valid_metadata_path, "r", encoding="utf-8") as v:
            metadata.extend(json.load(v))
        return metadata

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, index):
        utt_info = self.metadata[index]
        wav_path = utt_info["Path"]

        wav, sr = torchaudio.load(wav_path)

        # resample
        if sr != self.sr:
            wav = torchaudio.functional.resample(wav, sr, self.sr)
        # downmixing
        if wav.shape[0] > 1:
            wav = torch.mean(wav, dim=0, keepdim=True)
        assert wav.shape[0] == 1
        wav = wav.squeeze(0)

        # record the length of wav without padding
        length = wav.shape[0]
        # wav: (T)
        return utt_info, wav, length


class LibrosaDataset(TorchaudioDataset):
    """Variant that loads audio with librosa, which handles resampling to
    ``sr`` inside ``librosa.load``."""

    def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None):
        super().__init__(cfg, dataset, sr, accelerator, metadata)

    def __getitem__(self, index):
        utt_info = self.metadata[index]
        wav_path = utt_info["Path"]

        wav, _ = librosa.load(wav_path, sr=self.sr)
        # wav: (T)
        wav = torch.from_numpy(wav)

        # record the length of wav without padding
        length = wav.shape[0]
        return utt_info, wav, length


class FFmpegDataset(TorchaudioDataset):
    """Variant that loads audio through Whisper's FFmpeg-based loader, which
    always outputs 16 kHz mono audio."""

    def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None):
        super().__init__(cfg, dataset, sr, accelerator, metadata)

    def __getitem__(self, index):
        utt_info = self.metadata[index]
        wav_path = utt_info["Path"]

        # wav: (T,)
        wav = whisper.load_audio(wav_path)  # sr = 16000
        # convert to torch tensor
        wav = torch.from_numpy(wav)

        # record the length of wav without padding
        length = wav.shape[0]
        return utt_info, wav, length


def collate_batch(batch_list):
    """
    Args:
        batch_list: list of (metadata, wav, length)

    Returns:
        metadata list, a zero-padded wav batch of shape (B, T_max),
        and a list of the original (unpadded) lengths.
    """
    metadata = [item[0] for item in batch_list]
    # wavs: (B, T)
    wavs = pad_sequence([item[1] for item in batch_list], batch_first=True)
    lens = [item[2] for item in batch_list]
    return metadata, wavs, lens
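
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): how these datasets pair with
# collate_batch in a torch DataLoader. The metadata entries, file paths, the
# "LJSpeech" dataset name, and the _Cfg stand-in below are hypothetical
# placeholders; they assume metadata records of the form {"Path": ".../utt.wav"}.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    # Passing `metadata` directly sidesteps the processed_dir/train_file/
    # valid_file lookup, so no Amphion config tree is required here.
    metadata = [
        {"Path": "/path/to/utt0.wav"},  # hypothetical path
        {"Path": "/path/to/utt1.wav"},  # hypothetical path
    ]

    class _Cfg:  # minimal stand-in for the real config object
        pass

    dataset = TorchaudioDataset(_Cfg(), "LJSpeech", sr=16000, metadata=metadata)
    loader = DataLoader(dataset, batch_size=2, collate_fn=collate_batch)

    for utt_infos, wavs, lens in loader:
        # wavs: (B, T_max), zero-padded; lens holds the unpadded lengths.
        print(wavs.shape, lens)
        break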