Spaces:
Runtime error
Runtime error
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import random | |
import torch | |
from torch.nn.utils.rnn import pad_sequence | |
from utils.data_utils import * | |
from models.base.base_dataset import ( | |
BaseOfflineCollator, | |
BaseOfflineDataset, | |
BaseTestDataset, | |
BaseTestCollator, | |
) | |
import librosa | |
from transformers import AutoTokenizer | |
class AudioLDMDataset(BaseOfflineDataset):
    """Offline dataset that serves mel-spectrograms, waveforms and captions.

    Which features are indexed and returned is driven by the
    ``use_melspec``, ``use_wav`` and ``use_caption`` flags found on
    ``cfg.preprocess``. Utterances are keyed by ``"<Dataset>_<Uid>"``.
    """

    def __init__(self, cfg, dataset, is_valid=False):
        """Build the per-utterance lookup tables enabled in ``cfg.preprocess``.

        Args:
            cfg: Configuration object; its ``preprocess`` section is read for
                the feature flags and the directory names.
            dataset: Forwarded unchanged to ``BaseOfflineDataset.__init__``.
            is_valid: Forwarded unchanged to ``BaseOfflineDataset.__init__``.
        """
        BaseOfflineDataset.__init__(self, cfg, dataset, is_valid=is_valid)
        self.cfg = cfg
        preprocess = cfg.preprocess

        # Utterance key -> path of the stored mel-spectrogram (.npy).
        if preprocess.use_melspec:
            self.utt2melspec_path = {
                self._utt_key(info): os.path.join(
                    preprocess.processed_dir,
                    info["Dataset"],
                    preprocess.melspec_dir,
                    info["Uid"] + ".npy",
                )
                for info in self.metadata
            }

        # Utterance key -> path of the waveform file (.wav).
        if preprocess.use_wav:
            self.utt2wav_path = {
                self._utt_key(info): os.path.join(
                    preprocess.processed_dir,
                    info["Dataset"],
                    preprocess.wav_dir,
                    info["Uid"] + ".wav",
                )
                for info in self.metadata
            }

        # Utterance key -> caption text taken from the metadata entry.
        if preprocess.use_caption:
            self.utt2caption = {
                self._utt_key(info): info["Caption"] for info in self.metadata
            }

    @staticmethod
    def _utt_key(utt_info):
        """Return the ``"<Dataset>_<Uid>"`` key for one metadata entry."""
        return "{}_{}".format(utt_info["Dataset"], utt_info["Uid"])

    def __getitem__(self, index):
        """Return the feature dict for the utterance at ``index``.

        Starts from whatever ``BaseOfflineDataset.__getitem__`` returns and
        adds, depending on the enabled flags:

        * ``"melspec"``: array loaded with ``np.load`` (the original author
          notes the layout as ``(n_mels, T)``).
        * ``"wav"``: waveform loaded with ``librosa.load`` (noted as ``(T,)``).
        * ``"caption"``: the caption string, or ``""`` when it is randomly
          masked out.
        """
        features = BaseOfflineDataset.__getitem__(self, index)
        preprocess = self.cfg.preprocess
        utt = self._utt_key(self.metadata[index])

        if preprocess.use_melspec:
            features["melspec"] = np.load(self.utt2melspec_path[utt])

        if preprocess.use_wav:
            # The sampling rate is hard-coded to 16000 rather than read from cfg.
            waveform, _ = librosa.load(self.utt2wav_path[utt], sr=16000)
            features["wav"] = waveform

        if preprocess.use_caption:
            # Draw 1 with probability cond_mask_prob; a draw of 1 replaces the
            # caption with an empty string (the original note suggests the
            # probabilities are (0.1, 0.9) -- confirm against the config).
            mask_prob = preprocess.cond_mask_prob
            drop_caption = np.random.choice([1, 0], p=[mask_prob, 1 - mask_prob])
            features["caption"] = "" if drop_caption else self.utt2caption[utt]

        return features

    def __len__(self):
        """Return the number of metadata entries."""
        return len(self.metadata)
class AudioLDMCollator(BaseOfflineCollator):
    """Collate ``AudioLDMDataset`` samples into batched tensors.

    Captions are tokenized with the pretrained ``"t5-base"`` tokenizer.
    """

    def __init__(self, cfg):
        """Initialise the base collator and load the caption tokenizer.

        Args:
            cfg: Configuration object forwarded to ``BaseOfflineCollator``.
        """
        BaseOfflineCollator.__init__(self, cfg)
        self.tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)

    def __call__(self, batch):
        """Pack a list of per-sample feature dicts into one batch dict.

        Only the keys present in the first sample are considered, and only
        ``"melspec"``, ``"wav"`` and ``"caption"`` are packed:

        * ``"melspec"``: each array is cropped to its first 624 columns, then
          all are stacked (noted by the original author as ``(B, n_mels, T)``).
        * ``"wav"``: waveforms are zero-padded to the longest in the batch
          (noted as ``(B, T)``).
        * ``"caption"``: tokenized into ``"text_input_ids"`` and
          ``"text_attention_mask"`` (noted as ``(B, L)`` each).

        Args:
            batch: Non-empty list of feature dicts.

        Returns:
            dict: The packed batch features.
        """
        max_frames = 624  # fixed crop length along the second melspec axis
        packed = {}

        for key in batch[0]:
            if key == "melspec":
                cropped = [sample["melspec"][:, :max_frames] for sample in batch]
                packed["melspec"] = torch.from_numpy(np.array(cropped))
            elif key == "wav":
                waveforms = [torch.from_numpy(sample[key]) for sample in batch]
                packed[key] = pad_sequence(
                    waveforms, batch_first=True, padding_value=0
                )
            elif key == "caption":
                encoded = self.tokenizer(
                    [sample[key] for sample in batch],
                    return_tensors="pt",
                    truncation=True,
                    padding="longest",
                )
                packed["text_input_ids"] = encoded["input_ids"]
                packed["text_attention_mask"] = encoded["attention_mask"]

        return packed
class AudioLDMTestDataset(BaseTestDataset):
    """Test-time dataset for AudioLDM; adds nothing to ``BaseTestDataset``."""
class AudioLDMTestCollator(BaseTestCollator):
    """Test-time collator for AudioLDM; adds nothing to ``BaseTestCollator``."""