# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import json
import random

import numpy as np
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence

from utils.data_utils import *
from models.base.base_dataset import (
    BaseOfflineCollator,
    BaseOfflineDataset,
    BaseTestDataset,
    BaseTestCollator,
)
from text import text_to_sequence


class JetsDataset(BaseOfflineDataset):
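    """Offline training dataset for JETS.

    On top of the features prepared by BaseOfflineDataset, this loads the
    preprocessed phone labels, phone durations, pitch/energy features,
    speaker ids, and (optionally) waveforms produced by Amphion's
    preprocessing pipeline.
    """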

    def __init__(self, cfg, dataset, is_valid=False):
        BaseOfflineDataset.__init__(self, cfg, dataset, is_valid=is_valid)
        self.batch_size = cfg.train.batch_size
        cfg = cfg.preprocess

        # utt2duration
        self.utt2duration_path = {}
        for utt_info in self.metadata:
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)

            self.utt2duration_path[utt] = os.path.join(
                cfg.processed_dir,
                dataset,
                cfg.duration_dir,
                uid + ".npy",
            )
        self.utt2dur = self.read_duration()

        if cfg.use_frame_energy:
            self.frame_utt2energy, self.energy_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.energy_dir,
                use_log_scale=cfg.use_log_scale_energy,
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )
        elif cfg.use_phone_energy:
            self.phone_utt2energy, self.energy_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.phone_energy_dir,
                use_log_scale=cfg.use_log_scale_energy,
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )

        if cfg.use_frame_pitch:
            self.frame_utt2pitch, self.pitch_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.pitch_dir,
                use_log_scale=cfg.use_log_scale_pitch,
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )
        elif cfg.use_phone_pitch:
            self.phone_utt2pitch, self.pitch_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.phone_pitch_dir,
                use_log_scale=cfg.use_log_scale_pitch,
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )

        # utt2lab
        self.utt2lab_path = {}
        for utt_info in self.metadata:
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)

            self.utt2lab_path[utt] = os.path.join(
                cfg.processed_dir,
                dataset,
                cfg.lab_dir,
                uid + ".txt",
            )

        self.speaker_map = {}
        spk2id_path = os.path.join(cfg.processed_dir, "spk2id.json")
        if os.path.exists(spk2id_path):
            with open(spk2id_path) as f:
                self.speaker_map = json.load(f)

        self.metadata = self.check_metadata()

        if cfg.use_audios:
            self.utt2audio_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                if cfg.extract_audio:
                    self.utt2audio_path[utt] = os.path.join(
                        cfg.processed_dir,
                        dataset,
                        cfg.audio_dir,
                        uid + ".wav",
                    )
                else:
                    self.utt2audio_path[utt] = utt_info["Path"]

    def __getitem__(self, index):
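        """Return one training sample as a dict of numpy arrays and scalars,
        extending the features prepared by BaseOfflineDataset.__getitem__."""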
        single_feature = BaseOfflineDataset.__getitem__(self, index)

        utt_info = self.metadata[index]
        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        duration = self.utt2dur[utt]

        # text
        with open(self.utt2lab_path[utt], "r") as f:
            phones = f.readlines()[0].strip()
        # TODO: add cleaner (chenxi)
        phones_ids = np.array(text_to_sequence(phones, ["english_cleaners"]))
        text_len = len(phones_ids)

        if self.cfg.preprocess.use_frame_pitch:
            pitch = self.frame_utt2pitch[utt]
        elif self.cfg.preprocess.use_phone_pitch:
            pitch = self.phone_utt2pitch[utt]

        if self.cfg.preprocess.use_frame_energy:
            energy = self.frame_utt2energy[utt]
        elif self.cfg.preprocess.use_phone_energy:
            energy = self.phone_utt2energy[utt]

        # speaker
        if len(self.speaker_map) > 0:
            speaker_id = self.speaker_map[utt_info["Singer"]]
        else:
            speaker_id = 0

        single_feature.update(
            {
                "durations": duration,
                "texts": phones_ids,
                "spk_id": speaker_id,
                "text_len": text_len,
                "pitch": pitch,
                "energy": energy,
                "uid": uid,
            }
        )

        if self.cfg.preprocess.use_audios:
            audio, sr = torchaudio.load(self.utt2audio_path[utt])
            audio = audio.cpu().numpy().squeeze()
            single_feature["audio"] = audio
            single_feature["audio_len"] = audio.shape[0]

        return self.clip_if_too_long(single_feature)

    def read_duration(self):
        # read duration
        utt2dur = {}
        for index in range(len(self.metadata)):
            utt_info = self.metadata[index]
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)

            if not os.path.exists(self.utt2mel_path[utt]) or not os.path.exists(
                self.utt2duration_path[utt]
            ):
                continue

            mel = np.load(self.utt2mel_path[utt]).transpose(1, 0)
            duration = np.load(self.utt2duration_path[utt])
            assert mel.shape[0] == sum(
                duration
            ), f"{utt}: mismatch length between mel {mel.shape[0]} and sum(duration) {sum(duration)}"
            utt2dur[utt] = duration
        return utt2dur

    def __len__(self):
        return len(self.metadata)

    def random_select(self, feature_seq_len, max_seq_len, ending_ts=2812):
        """
        ending_ts: to avoid invalid whisper features for over 30s audios
            2812 = 30 * 24000 // 256
        """
        ts = max(feature_seq_len - max_seq_len, 0)
        ts = min(ts, ending_ts - max_seq_len)

        start = random.randint(0, ts)
        end = start + max_seq_len
        return start, end

    def clip_if_too_long(self, sample, max_seq_len=1000):
        """
        sample :
            {
                'spk_id': (1,),
                'target_len': int,
                'mel': (seq_len, dim),
                'frame_pitch': (seq_len,),
                'frame_energy': (seq_len,),
                'content_vector_feat': (seq_len, dim),
            }
        """
        if sample["target_len"] <= max_seq_len:
            return sample

        start, end = self.random_select(sample["target_len"], max_seq_len)
        sample["target_len"] = end - start

        for k in sample.keys():
            if k not in ["spk_id", "target_len"]:
                sample[k] = sample[k][start:end]

        return sample

    def check_metadata(self):
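        """Keep only utterances whose duration and mel files exist on disk."""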
        new_metadata = []
        for utt_info in self.metadata:
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)
            if not os.path.exists(self.utt2duration_path[utt]) or not os.path.exists(
                self.utt2mel_path[utt]
            ):
                continue
            else:
                new_metadata.append(utt_info)
        return new_metadata


class JetsCollator(BaseOfflineCollator):
    """Zero-pads model inputs and targets based on number of frames per step"""

    def __init__(self, cfg):
        BaseOfflineCollator.__init__(self, cfg)
        self.sort = cfg.train.sort_sample
        self.batch_size = cfg.train.batch_size
        self.drop_last = cfg.train.drop_last

    def __call__(self, batch):
        # mel: [b, T, n_mels]
        # frame_pitch, frame_energy: [b, T]
        # target_len: [b]
        # spk_id: [b, 1]
        # mask: [b, T, 1]

        packed_batch_features = dict()

        for key in batch[0].keys():
            if key == "target_len":
                packed_batch_features["target_len"] = torch.LongTensor(
                    [b["target_len"] for b in batch]
                )
                masks = [
                    torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
                ]
                packed_batch_features["mask"] = pad_sequence(
                    masks, batch_first=True, padding_value=0
                )
            elif key == "text_len":
                packed_batch_features["text_len"] = torch.LongTensor(
                    [b["text_len"] for b in batch]
                )
                masks = [
                    torch.ones((b["text_len"], 1), dtype=torch.long) for b in batch
                ]
                packed_batch_features["text_mask"] = pad_sequence(
                    masks, batch_first=True, padding_value=0
                )
            elif key == "spk_id":
                packed_batch_features["spk_id"] = torch.LongTensor(
                    [b["spk_id"] for b in batch]
                )
            elif key == "uid":
                packed_batch_features[key] = [b["uid"] for b in batch]
            elif key == "audio_len":
                packed_batch_features["audio_len"] = torch.LongTensor(
                    [b["audio_len"] for b in batch]
                )
            else:
                values = [torch.from_numpy(b[key]) for b in batch]
                packed_batch_features[key] = pad_sequence(
                    values, batch_first=True, padding_value=0
                )
        return packed_batch_features
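
# A minimal usage sketch (illustrative, not part of Amphion's training recipe):
# wiring the dataset and collator into a PyTorch DataLoader. The `cfg` object and
# the dataset name "LJSpeech" below are assumptions standing in for values that
# Amphion normally reads from the experiment's JSON config.
#
#   from torch.utils.data import DataLoader
#
#   train_set = JetsDataset(cfg, dataset="LJSpeech", is_valid=False)
#   train_loader = DataLoader(
#       train_set,
#       batch_size=cfg.train.batch_size,
#       shuffle=True,
#       collate_fn=JetsCollator(cfg),
#       drop_last=cfg.train.drop_last,
#   )
#   for batch in train_loader:
#       # e.g. batch["texts"]: (B, max_text_len), batch["mask"]: (B, max_frame_len, 1)
#       break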


class JetsTestDataset(BaseTestDataset):
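    """Test-time dataset for JETS: builds samples from a test list file or a
    predefined split and returns only phone sequences, text lengths, and
    speaker ids."""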

    def __init__(self, args, cfg, infer_type=None):
        datasets = cfg.dataset
        cfg = cfg.preprocess

        is_bigdata = False

        assert len(datasets) >= 1
        if len(datasets) > 1:
            datasets.sort()
            bigdata_version = "_".join(datasets)
            processed_data_dir = os.path.join(cfg.processed_dir, bigdata_version)
            is_bigdata = True
        else:
            processed_data_dir = os.path.join(cfg.processed_dir, args.dataset)

        if args.test_list_file:
            self.metafile_path = args.test_list_file
            self.metadata = self.get_metadata()
        else:
            assert args.testing_set
            source_metafile_path = os.path.join(
                cfg.processed_dir,
                args.dataset,
                "{}.json".format(args.testing_set),
            )
            with open(source_metafile_path, "r") as f:
                self.metadata = json.load(f)

        self.cfg = cfg
        self.datasets = datasets
        self.data_root = processed_data_dir
        self.is_bigdata = is_bigdata
        self.source_dataset = args.dataset

        ######### Load source acoustic features #########
        if cfg.use_spkid:
            spk2id_path = os.path.join(self.data_root, cfg.spk2id)
            utt2sp_path = os.path.join(self.data_root, cfg.utt2spk)
            self.spk2id, self.utt2spk = get_spk_map(spk2id_path, utt2sp_path, datasets)

        # utt2lab
        self.utt2lab_path = {}
        for utt_info in self.metadata:
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)

            self.utt2lab_path[utt] = os.path.join(
                cfg.processed_dir,
                dataset,
                cfg.lab_dir,
                uid + ".txt",
            )

        self.speaker_map = {}
        spk2id_json_path = os.path.join(cfg.processed_dir, "spk2id.json")
        if os.path.exists(spk2id_json_path):
            with open(spk2id_json_path) as f:
                self.speaker_map = json.load(f)

    def __getitem__(self, index):
        single_feature = {}

        utt_info = self.metadata[index]
        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        # text
        with open(self.utt2lab_path[utt], "r") as f:
            phones = f.readlines()[0].strip()
        phones_ids = np.array(text_to_sequence(phones, self.cfg.text_cleaners))
        text_len = len(phones_ids)

        # speaker
        if len(self.speaker_map) > 0:
            speaker_id = self.speaker_map[utt_info["Singer"]]
        else:
            speaker_id = 0

        single_feature.update(
            {
                "texts": phones_ids,
                "spk_id": speaker_id,
                "text_len": text_len,
            }
        )

        return single_feature

    def __len__(self):
        return len(self.metadata)

    def get_metadata(self):
        with open(self.metafile_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)
        return metadata


class JetsTestCollator(BaseTestCollator):
    """Zero-pads model inputs and targets based on number of frames per step"""

    def __init__(self, cfg):
        self.cfg = cfg

    def __call__(self, batch):
        packed_batch_features = dict()

        # mel: [b, T, n_mels]
        # frame_pitch, frame_energy: [b, T]
        # target_len: [b]
        # spk_id: [b, 1]
        # mask: [b, T, 1]

        for key in batch[0].keys():
            if key == "target_len":
                packed_batch_features["target_len"] = torch.LongTensor(
                    [b["target_len"] for b in batch]
                )
                masks = [
                    torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
                ]
                packed_batch_features["mask"] = pad_sequence(
                    masks, batch_first=True, padding_value=0
                )
            elif key == "text_len":
                packed_batch_features["text_len"] = torch.LongTensor(
                    [b["text_len"] for b in batch]
                )
                masks = [
                    torch.ones((b["text_len"], 1), dtype=torch.long) for b in batch
                ]
                packed_batch_features["text_mask"] = pad_sequence(
                    masks, batch_first=True, padding_value=0
                )
            elif key == "spk_id":
                packed_batch_features["spk_id"] = torch.LongTensor(
                    [b["spk_id"] for b in batch]
                )
            else:
                values = [torch.from_numpy(b[key]) for b in batch]
                packed_batch_features[key] = pad_sequence(
                    values, batch_first=True, padding_value=0
                )

        return packed_batch_features