Spaces:
Running
on
A10G
Running
on
A10G
File size: 3,651 Bytes
df2accb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torchaudio
import json
import os
import numpy as np
import librosa
from torch.nn.utils.rnn import pad_sequence
from modules import whisper_extractor as whisper
class TorchaudioDataset(torch.utils.data.Dataset):
    """Dataset yielding one mono waveform per metadata entry, loaded with torchaudio.

    Each item is a tuple ``(utt_info, wav, length)`` where ``wav`` is a 1-D
    tensor at the target sample rate and ``length`` is its number of samples.
    """

    def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None):
        """
        Args:
            cfg: config. Only read when ``metadata`` is None, in which case
                ``cfg.preprocess.processed_dir``, ``cfg.preprocess.train_file``
                and ``cfg.preprocess.valid_file`` locate the metadata files.
            dataset: dataset name (must be a ``str``); used as the
                sub-directory of ``processed_dir`` holding the metadata files.
            sr: target sample rate of the returned waveforms.
            accelerator: optional object exposing a ``device`` attribute.
                When omitted, CUDA is used if available, otherwise CPU.
            metadata: optional pre-loaded list of utterance entries. When
                given, nothing is read from disk.
        """
        assert isinstance(dataset, str)

        self.sr = sr
        self.cfg = cfg

        if metadata is None:
            self.train_metadata_path = os.path.join(
                cfg.preprocess.processed_dir, dataset, cfg.preprocess.train_file
            )
            self.valid_metadata_path = os.path.join(
                cfg.preprocess.processed_dir, dataset, cfg.preprocess.valid_file
            )
            self.metadata = self.get_metadata()
        else:
            self.metadata = metadata

        if accelerator is not None:
            self.device = accelerator.device
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

    def get_metadata(self):
        """Load the train and valid metadata JSON files and concatenate them.

        Returns:
            list: train entries followed by valid entries.
        """
        metadata = []
        with open(self.train_metadata_path, "r", encoding="utf-8") as t:
            metadata.extend(json.load(t))
        with open(self.valid_metadata_path, "r", encoding="utf-8") as v:
            metadata.extend(json.load(v))
        return metadata

    def __len__(self):
        """Return the number of metadata entries."""
        return len(self.metadata)

    def __getitem__(self, index):
        """Load, downmix and resample the utterance at ``index``.

        Args:
            index: position in ``self.metadata``; the entry must provide the
                audio file location under the key ``"Path"``.

        Returns:
            tuple: ``(utt_info, wav, length)`` with ``wav`` of shape (T) and
            ``length`` equal to ``T``.
        """
        utt_info = self.metadata[index]
        wav_path = utt_info["Path"]

        # wav: (C, T) -- torchaudio's default channels-first layout
        wav, sr = torchaudio.load(wav_path)

        # downmixing: done before resampling so that only a single channel
        # goes through the resampler. Channel averaging and resampling are
        # both linear, so the result matches the reverse order up to
        # floating-point rounding.
        if wav.shape[0] > 1:
            wav = torch.mean(wav, dim=0, keepdim=True)
        assert wav.shape[0] == 1

        # resample
        if sr != self.sr:
            wav = torchaudio.functional.resample(wav, sr, self.sr)

        # wav: (T)
        wav = wav.squeeze(0)

        # record the length of wav without padding
        length = wav.shape[0]

        return utt_info, wav, length
class LibrosaDataset(TorchaudioDataset):
    """Variant of ``TorchaudioDataset`` that decodes audio with librosa.

    The constructor is inherited unchanged:
    ``(cfg, dataset, sr, accelerator=None, metadata=None)``.
    """

    def __getitem__(self, index):
        """Load the utterance at ``index`` with ``librosa.load``.

        Args:
            index: position in ``self.metadata``; the entry must provide the
                audio file location under the key ``"Path"``.

        Returns:
            tuple: ``(utt_info, wav, length)`` where ``wav`` is a tensor
            built from librosa's output and ``length`` is its first
            dimension.
        """
        utt_info = self.metadata[index]
        wav_path = utt_info["Path"]

        # librosa is asked for the target sample rate directly
        # wav: (T) -- relies on librosa's default mono loading
        wav, _ = librosa.load(wav_path, sr=self.sr)

        # convert to torch tensor (shares memory with the numpy array)
        wav = torch.from_numpy(wav)

        # record the length of wav without padding
        length = wav.shape[0]

        return utt_info, wav, length
class FFmpegDataset(TorchaudioDataset):
    """Variant of ``TorchaudioDataset`` that decodes audio via the whisper loader.

    The constructor is inherited unchanged:
    ``(cfg, dataset, sr, accelerator=None, metadata=None)``.

    NOTE(review): ``self.sr`` is not passed to ``whisper.load_audio`` here;
    the loader is presumably fixed at 16000 Hz -- confirm that callers
    construct this dataset with ``sr=16000``.
    """

    def __getitem__(self, index):
        """Load the utterance at ``index`` with ``whisper.load_audio``.

        Args:
            index: position in ``self.metadata``; the entry must provide the
                audio file location under the key ``"Path"``.

        Returns:
            tuple: ``(utt_info, wav, length)`` where ``wav`` is a tensor
            built from the loader's output and ``length`` is its first
            dimension.
        """
        utt_info = self.metadata[index]
        wav_path = utt_info["Path"]

        # wav: (T,)
        wav = whisper.load_audio(wav_path)  # sr = 16000

        # convert to torch tensor (shares memory with the numpy array)
        wav = torch.from_numpy(wav)

        # record the length of wav without padding
        length = wav.shape[0]

        return utt_info, wav, length
def collate_batch(batch_list):
    """Collate dataset items into one padded batch.

    Args:
        batch_list: list of (metadata, wav, length) items, as produced by
            the datasets' ``__getitem__``.

    Returns:
        tuple: ``(metadata, wavs, lens)`` where
            metadata: list of the per-item metadata, in batch order
            wavs: (B, T) tensor; shorter waveforms are right-padded by
                ``pad_sequence`` (zeros by default) up to the longest one
            lens: list of the per-item lengths, in batch order
    """
    metadata = [item[0] for item in batch_list]

    # wavs: (B, T)
    wavs = pad_sequence([item[1] for item in batch_list], batch_first=True)

    # unpadded lengths, so consumers can tell real samples from padding
    lens = [item[2] for item in batch_list]

    return metadata, wavs, lens
|