diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..3a1b305b6c348347531535f1bc695882201c6c3c --- /dev/null +++ b/app.py @@ -0,0 +1,18 @@ +import os +import time +import numpy as np +import torch +from tqdm import tqdm +import torch.nn as nn +from collections import OrderedDict +import json + +from models.tta.autoencoder.autoencoder import AutoencoderKL +from models.tta.ldm.inference_utils.vocoder import Generator +from models.tta.ldm.audioldm import AudioLDM +from transformers import T5EncoderModel, AutoTokenizer +from diffusers import PNDMScheduler + +import matplotlib.pyplot as plt +from scipy.io.wavfile import write + diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/base/__init__.py b/models/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fe0221047a62e0b9b3ddd112c79a700c48834fd1 --- /dev/null +++ b/models/base/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .new_trainer import BaseTrainer +from .new_inference import BaseInference diff --git a/models/base/base_dataset.py b/models/base/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3486edfd22c0562be8460179fa8d13848da4192b --- /dev/null +++ b/models/base/base_dataset.py @@ -0,0 +1,344 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import numpy as np +import torch.utils.data +from torch.nn.utils.rnn import pad_sequence +from utils.data_utils import * +from processors.acoustic_extractor import cal_normalized_mel +from text import text_to_sequence +from text.text_token_collation import phoneIDCollation + + +class BaseDataset(torch.utils.data.Dataset): + def __init__(self, cfg, dataset, is_valid=False): + """ + Args: + cfg: config + dataset: dataset name + is_valid: whether to use train or valid dataset + """ + + assert isinstance(dataset, str) + + # self.data_root = processed_data_dir + self.cfg = cfg + + processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset) + meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file + self.metafile_path = os.path.join(processed_data_dir, meta_file) + self.metadata = self.get_metadata() + + """ + load spk2id and utt2spk from json file + spk2id: {spk1: 0, spk2: 1, ...} + utt2spk: {dataset_uid: spk1, ...} + """ + if cfg.preprocess.use_spkid: + spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id) + with open(spk2id_path, "r") as f: + self.spk2id = json.load(f) + + utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk) + self.utt2spk = dict() + with open(utt2spk_path, "r") as f: + for line in f.readlines(): + utt, spk = line.strip().split("\t") + self.utt2spk[utt] = spk + + if cfg.preprocess.use_uv: + self.utt2uv_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + self.utt2uv_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.uv_dir, + uid + ".npy", + ) + + if cfg.preprocess.use_frame_pitch: + self.utt2frame_pitch_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid 
= utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2frame_pitch_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.pitch_dir, + uid + ".npy", + ) + + if cfg.preprocess.use_frame_energy: + self.utt2frame_energy_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2frame_energy_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.energy_dir, + uid + ".npy", + ) + + if cfg.preprocess.use_mel: + self.utt2mel_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2mel_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.mel_dir, + uid + ".npy", + ) + + if cfg.preprocess.use_linear: + self.utt2linear_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2linear_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.linear_dir, + uid + ".npy", + ) + + if cfg.preprocess.use_audio: + self.utt2audio_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2audio_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.audio_dir, + uid + ".npy", + ) + elif cfg.preprocess.use_label: + self.utt2label_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2label_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.label_dir, + uid + ".npy", + ) + elif cfg.preprocess.use_one_hot: + self.utt2one_hot_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2one_hot_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.one_hot_dir, + uid + ".npy", + ) + + if cfg.preprocess.use_text or cfg.preprocess.use_phone: + self.utt2seq = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + if cfg.preprocess.use_text: + text = utt_info["Text"] + sequence = text_to_sequence(text, cfg.preprocess.text_cleaners) + elif cfg.preprocess.use_phone: + # load phoneme squence from phone file + phone_path = os.path.join( + processed_data_dir, cfg.preprocess.phone_dir, uid + ".phone" + ) + with open(phone_path, "r") as fin: + phones = fin.readlines() + assert len(phones) == 1 + phones = phones[0].strip() + phones_seq = phones.split(" ") + + phon_id_collator = phoneIDCollation(cfg, dataset=dataset) + sequence = phon_id_collator.get_phone_id_sequence(cfg, phones_seq) + + self.utt2seq[utt] = sequence + + def get_metadata(self): + with open(self.metafile_path, "r", encoding="utf-8") as f: + metadata = json.load(f) + + return metadata + + def get_dataset_name(self): + return self.metadata[0]["Dataset"] + + def __getitem__(self, index): + utt_info = self.metadata[index] + + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + single_feature = dict() + + if self.cfg.preprocess.use_spkid: + single_feature["spk_id"] = np.array( + [self.spk2id[self.utt2spk[utt]]], dtype=np.int32 + ) + + if self.cfg.preprocess.use_mel: + mel = 
np.load(self.utt2mel_path[utt]) + assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T] + if self.cfg.preprocess.use_min_max_norm_mel: + # do mel norm + mel = cal_normalized_mel(mel, utt_info["Dataset"], self.cfg.preprocess) + + if "target_len" not in single_feature.keys(): + single_feature["target_len"] = mel.shape[1] + single_feature["mel"] = mel.T # [T, n_mels] + + if self.cfg.preprocess.use_linear: + linear = np.load(self.utt2linear_path[utt]) + if "target_len" not in single_feature.keys(): + single_feature["target_len"] = linear.shape[1] + single_feature["linear"] = linear.T # [T, n_linear] + + if self.cfg.preprocess.use_frame_pitch: + frame_pitch_path = self.utt2frame_pitch_path[utt] + frame_pitch = np.load(frame_pitch_path) + if "target_len" not in single_feature.keys(): + single_feature["target_len"] = len(frame_pitch) + aligned_frame_pitch = align_length( + frame_pitch, single_feature["target_len"] + ) + single_feature["frame_pitch"] = aligned_frame_pitch + + if self.cfg.preprocess.use_uv: + frame_uv_path = self.utt2uv_path[utt] + frame_uv = np.load(frame_uv_path) + aligned_frame_uv = align_length(frame_uv, single_feature["target_len"]) + aligned_frame_uv = [ + 0 if frame_uv else 1 for frame_uv in aligned_frame_uv + ] + aligned_frame_uv = np.array(aligned_frame_uv) + single_feature["frame_uv"] = aligned_frame_uv + + if self.cfg.preprocess.use_frame_energy: + frame_energy_path = self.utt2frame_energy_path[utt] + frame_energy = np.load(frame_energy_path) + if "target_len" not in single_feature.keys(): + single_feature["target_len"] = len(frame_energy) + aligned_frame_energy = align_length( + frame_energy, single_feature["target_len"] + ) + single_feature["frame_energy"] = aligned_frame_energy + + if self.cfg.preprocess.use_audio: + audio = np.load(self.utt2audio_path[utt]) + single_feature["audio"] = audio + single_feature["audio_len"] = audio.shape[0] + + if self.cfg.preprocess.use_phone or self.cfg.preprocess.use_text: + single_feature["phone_seq"] = np.array(self.utt2seq[utt]) + single_feature["phone_len"] = len(self.utt2seq[utt]) + + return single_feature + + def __len__(self): + return len(self.metadata) + + +class BaseCollator(object): + """Zero-pads model inputs and targets based on number of frames per step""" + + def __init__(self, cfg): + self.cfg = cfg + + def __call__(self, batch): + packed_batch_features = dict() + + # mel: [b, T, n_mels] + # frame_pitch, frame_energy: [1, T] + # target_len: [1] + # spk_id: [b, 1] + # mask: [b, T, 1] + + for key in batch[0].keys(): + if key == "target_len": + packed_batch_features["target_len"] = torch.LongTensor( + [b["target_len"] for b in batch] + ) + masks = [ + torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch + ] + packed_batch_features["mask"] = pad_sequence( + masks, batch_first=True, padding_value=0 + ) + elif key == "phone_len": + packed_batch_features["phone_len"] = torch.LongTensor( + [b["phone_len"] for b in batch] + ) + masks = [ + torch.ones((b["phone_len"], 1), dtype=torch.long) for b in batch + ] + packed_batch_features["phn_mask"] = pad_sequence( + masks, batch_first=True, padding_value=0 + ) + elif key == "audio_len": + packed_batch_features["audio_len"] = torch.LongTensor( + [b["audio_len"] for b in batch] + ) + masks = [ + torch.ones((b["audio_len"], 1), dtype=torch.long) for b in batch + ] + else: + values = [torch.from_numpy(b[key]) for b in batch] + packed_batch_features[key] = pad_sequence( + values, batch_first=True, padding_value=0 + ) + return packed_batch_features + + +class 
BaseTestDataset(torch.utils.data.Dataset): + def __init__(self, cfg, args): + raise NotImplementedError + + def get_metadata(self): + raise NotImplementedError + + def __getitem__(self, index): + raise NotImplementedError + + def __len__(self): + return len(self.metadata) + + +class BaseTestCollator(object): + """Zero-pads model inputs and targets based on number of frames per step""" + + def __init__(self, cfg): + raise NotImplementedError + + def __call__(self, batch): + raise NotImplementedError diff --git a/models/base/base_inference.py b/models/base/base_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..2713f19a0d61f06bca1f01de5ccd8a3b4d2cc02f --- /dev/null +++ b/models/base/base_inference.py @@ -0,0 +1,220 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os +import re +import time +from pathlib import Path + +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from models.vocoders.vocoder_inference import synthesis +from torch.utils.data import DataLoader +from utils.util import set_all_random_seed +from utils.util import load_config + + +def parse_vocoder(vocoder_dir): + r"""Parse vocoder config""" + vocoder_dir = os.path.abspath(vocoder_dir) + ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")] + ckpt_list.sort(key=lambda x: int(x.stem), reverse=True) + ckpt_path = str(ckpt_list[0]) + vocoder_cfg = load_config(os.path.join(vocoder_dir, "args.json"), lowercase=True) + vocoder_cfg.model.bigvgan = vocoder_cfg.vocoder + return vocoder_cfg, ckpt_path + + +class BaseInference(object): + def __init__(self, cfg, args): + self.cfg = cfg + self.args = args + self.model_type = cfg.model_type + self.avg_rtf = list() + set_all_random_seed(10086) + os.makedirs(args.output_dir, exist_ok=True) + + if torch.cuda.is_available(): + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") + torch.set_num_threads(10) # inference on 1 core cpu. 
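+        # Note: torch.set_num_threads only limits intra-op CPU parallelism for the
+        # CPU fallback path; it does not constrain execution on the CUDA path above.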
+ + # Load acoustic model + self.model = self.create_model().to(self.device) + state_dict = self.load_state_dict() + self.load_model(state_dict) + self.model.eval() + + # Load vocoder model if necessary + if self.args.checkpoint_dir_vocoder is not None: + self.get_vocoder_info() + + def create_model(self): + raise NotImplementedError + + def load_state_dict(self): + self.checkpoint_file = self.args.checkpoint_file + if self.checkpoint_file is None: + assert self.args.checkpoint_dir is not None + checkpoint_path = os.path.join(self.args.checkpoint_dir, "checkpoint") + checkpoint_filename = open(checkpoint_path).readlines()[-1].strip() + self.checkpoint_file = os.path.join( + self.args.checkpoint_dir, checkpoint_filename + ) + + self.checkpoint_dir = os.path.split(self.checkpoint_file)[0] + + print("Restore acoustic model from {}".format(self.checkpoint_file)) + raw_state_dict = torch.load(self.checkpoint_file, map_location=self.device) + self.am_restore_step = re.findall(r"step-(.+?)_loss", self.checkpoint_file)[0] + + return raw_state_dict + + def load_model(self, model): + raise NotImplementedError + + def get_vocoder_info(self): + self.checkpoint_dir_vocoder = self.args.checkpoint_dir_vocoder + self.vocoder_cfg = os.path.join( + os.path.dirname(self.checkpoint_dir_vocoder), "args.json" + ) + self.cfg.vocoder = load_config(self.vocoder_cfg, lowercase=True) + self.vocoder_tag = self.checkpoint_dir_vocoder.split("/")[-2].split(":")[-1] + self.vocoder_steps = self.checkpoint_dir_vocoder.split("/")[-1].split(".")[0] + + def build_test_utt_data(self): + raise NotImplementedError + + def build_testdata_loader(self, args, target_speaker=None): + datasets, collate = self.build_test_dataset() + self.test_dataset = datasets(self.cfg, args, target_speaker) + self.test_collate = collate(self.cfg) + self.test_batch_size = min( + self.cfg.train.batch_size, len(self.test_dataset.metadata) + ) + test_loader = DataLoader( + self.test_dataset, + collate_fn=self.test_collate, + num_workers=self.args.num_workers, + batch_size=self.test_batch_size, + shuffle=False, + ) + return test_loader + + def inference_each_batch(self, batch_data): + raise NotImplementedError + + def inference_for_batches(self, args, target_speaker=None): + ###### Construct test_batch ###### + loader = self.build_testdata_loader(args, target_speaker) + + n_batch = len(loader) + now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())) + print( + "Model eval time: {}, batch_size = {}, n_batch = {}".format( + now, self.test_batch_size, n_batch + ) + ) + self.model.eval() + + ###### Inference for each batch ###### + pred_res = [] + with torch.no_grad(): + for i, batch_data in enumerate(loader if n_batch == 1 else tqdm(loader)): + # Put the data to device + for k, v in batch_data.items(): + batch_data[k] = batch_data[k].to(self.device) + + y_pred, stats = self.inference_each_batch(batch_data) + + pred_res += y_pred + + return pred_res + + def inference(self, feature): + raise NotImplementedError + + def synthesis_by_vocoder(self, pred): + audios_pred = synthesis( + self.vocoder_cfg, + self.checkpoint_dir_vocoder, + len(pred), + pred, + ) + return audios_pred + + def __call__(self, utt): + feature = self.build_test_utt_data(utt) + start_time = time.time() + with torch.no_grad(): + outputs = self.inference(feature)[0] + time_used = time.time() - start_time + rtf = time_used / ( + outputs.shape[1] + * self.cfg.preprocess.hop_size + / self.cfg.preprocess.sample_rate + ) + print("Time used: {:.3f}, RTF: {:.4f}".format(time_used, rtf)) 
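+        # RTF (real-time factor) = processing time / duration of the generated audio,
+        # where duration = n_frames * hop_size / sample_rate; RTF < 1 means synthesis
+        # runs faster than real time.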
+ self.avg_rtf.append(rtf) + audios = outputs.cpu().squeeze().numpy().reshape(-1, 1) + return audios + + +def base_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", default="config.json", help="json files for configurations." + ) + parser.add_argument("--use_ddp_inference", default=False) + parser.add_argument("--n_workers", default=1, type=int) + parser.add_argument("--local_rank", default=-1, type=int) + parser.add_argument( + "--batch_size", default=1, type=int, help="Batch size for inference" + ) + parser.add_argument( + "--num_workers", + default=1, + type=int, + help="Worker number for inference dataloader", + ) + parser.add_argument( + "--checkpoint_dir", + type=str, + default=None, + help="Checkpoint dir including model file and configuration", + ) + parser.add_argument( + "--checkpoint_file", help="checkpoint file", type=str, default=None + ) + parser.add_argument( + "--test_list", help="test utterance list for testing", type=str, default=None + ) + parser.add_argument( + "--checkpoint_dir_vocoder", + help="Vocoder's checkpoint dir including model file and configuration", + type=str, + default=None, + ) + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="Output dir for saving generated results", + ) + return parser + + +if __name__ == "__main__": + parser = base_parser() + args = parser.parse_args() + cfg = load_config(args.config) + + # Build inference + inference = BaseInference(cfg, args) + inference() diff --git a/models/base/base_sampler.py b/models/base/base_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..149d1437eb1d3a00ca8c9895b150b39b2a3635fa --- /dev/null +++ b/models/base/base_sampler.py @@ -0,0 +1,136 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +import random + +from torch.utils.data import ConcatDataset, Dataset +from torch.utils.data.sampler import ( + BatchSampler, + RandomSampler, + Sampler, + SequentialSampler, +) + + +class ScheduledSampler(Sampler): + """A sampler that samples data from a given concat-dataset. 
+ + Args: + concat_dataset (ConcatDataset): a concatenated dataset consisting of all datasets + batch_size (int): batch size + holistic_shuffle (bool): whether to shuffle the whole dataset or not + logger (logging.Logger): logger to print warning message + + Usage: + For cfg.train.batch_size = 3, cfg.train.holistic_shuffle = False, cfg.train.drop_last = True: + >>> list(ScheduledSampler(ConcatDataset([0, 1, 2], [3, 4, 5], [6, 7, 8]]))) + [3, 4, 5, 0, 1, 2, 6, 7, 8] + """ + + def __init__( + self, + concat_dataset, + batch_size, + holistic_shuffle, + logger=None, + loader_type="train", + ): + if not isinstance(concat_dataset, ConcatDataset): + raise ValueError( + "concat_dataset must be an instance of ConcatDataset, but got {}".format( + type(concat_dataset) + ) + ) + if not isinstance(batch_size, int): + raise ValueError( + "batch_size must be an integer, but got {}".format(type(batch_size)) + ) + if not isinstance(holistic_shuffle, bool): + raise ValueError( + "holistic_shuffle must be a boolean, but got {}".format( + type(holistic_shuffle) + ) + ) + + self.concat_dataset = concat_dataset + self.batch_size = batch_size + self.holistic_shuffle = holistic_shuffle + + affected_dataset_name = [] + affected_dataset_len = [] + for dataset in concat_dataset.datasets: + dataset_len = len(dataset) + dataset_name = dataset.get_dataset_name() + if dataset_len < batch_size: + affected_dataset_name.append(dataset_name) + affected_dataset_len.append(dataset_len) + + self.type = loader_type + for dataset_name, dataset_len in zip( + affected_dataset_name, affected_dataset_len + ): + if not loader_type == "valid": + logger.warning( + "The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format( + loader_type, dataset_name, dataset_len, batch_size + ) + ) + + def __len__(self): + # the number of batches with drop last + num_of_batches = sum( + [ + math.floor(len(dataset) / self.batch_size) + for dataset in self.concat_dataset.datasets + ] + ) + # if samples are not enough for one batch, we don't drop last + if self.type == "valid" and num_of_batches < 1: + return len(self.concat_dataset) + return num_of_batches * self.batch_size + + def __iter__(self): + iters = [] + for dataset in self.concat_dataset.datasets: + iters.append( + SequentialSampler(dataset).__iter__() + if not self.holistic_shuffle + else RandomSampler(dataset).__iter__() + ) + # e.g. 
[0, 200, 400] + init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1] + output_batches = [] + for dataset_idx in range(len(self.concat_dataset.datasets)): + cur_batch = [] + for idx in iters[dataset_idx]: + cur_batch.append(idx + init_indices[dataset_idx]) + if len(cur_batch) == self.batch_size: + output_batches.append(cur_batch) + cur_batch = [] + # if loader_type is valid, we don't need to drop last + if self.type == "valid" and len(cur_batch) > 0: + output_batches.append(cur_batch) + + # force drop last in training + random.shuffle(output_batches) + output_indices = [item for sublist in output_batches for item in sublist] + return iter(output_indices) + + +def build_samplers(concat_dataset: Dataset, cfg, logger, loader_type): + sampler = ScheduledSampler( + concat_dataset, + cfg.train.batch_size, + cfg.train.sampler.holistic_shuffle, + logger, + loader_type, + ) + batch_sampler = BatchSampler( + sampler, + cfg.train.batch_size, + cfg.train.sampler.drop_last if not loader_type == "valid" else False, + ) + return sampler, batch_sampler diff --git a/models/base/base_trainer.py b/models/base/base_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..8782216dc13ce5d9de05ae790faeb82cf7cfd501 --- /dev/null +++ b/models/base/base_trainer.py @@ -0,0 +1,348 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import collections +import json +import os +import sys +import time + +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data import ConcatDataset, DataLoader +from torch.utils.tensorboard import SummaryWriter + +from models.base.base_sampler import BatchSampler +from utils.util import ( + Logger, + remove_older_ckpt, + save_config, + set_all_random_seed, + ValueWindow, +) + + +class BaseTrainer(object): + def __init__(self, args, cfg): + self.args = args + self.log_dir = args.log_dir + self.cfg = cfg + + self.checkpoint_dir = os.path.join(args.log_dir, "checkpoints") + os.makedirs(self.checkpoint_dir, exist_ok=True) + if not cfg.train.ddp or args.local_rank == 0: + self.sw = SummaryWriter(os.path.join(args.log_dir, "events")) + self.logger = self.build_logger() + self.time_window = ValueWindow(50) + + self.step = 0 + self.epoch = -1 + self.max_epochs = self.cfg.train.epochs + self.max_steps = self.cfg.train.max_steps + + # set random seed & init distributed training + set_all_random_seed(self.cfg.train.random_seed) + if cfg.train.ddp: + dist.init_process_group(backend="nccl") + + if cfg.model_type not in ["AutoencoderKL", "AudioLDM"]: + self.singers = self.build_singers_lut() + + # setup data_loader + self.data_loader = self.build_data_loader() + + # setup model & enable distributed training + self.model = self.build_model() + print(self.model) + + if isinstance(self.model, dict): + for key, value in self.model.items(): + value.cuda(self.args.local_rank) + if key == "PQMF": + continue + if cfg.train.ddp: + self.model[key] = DistributedDataParallel( + value, device_ids=[self.args.local_rank] + ) + else: + self.model.cuda(self.args.local_rank) + if cfg.train.ddp: + self.model = DistributedDataParallel( + self.model, device_ids=[self.args.local_rank] + ) + + # create criterion + self.criterion = self.build_criterion() + if isinstance(self.criterion, dict): + for key, value in self.criterion.items(): + self.criterion[key].cuda(args.local_rank) + else: + 
self.criterion.cuda(self.args.local_rank) + + # optimizer + self.optimizer = self.build_optimizer() + self.scheduler = self.build_scheduler() + + # save config file + self.config_save_path = os.path.join(self.checkpoint_dir, "args.json") + + def build_logger(self): + log_file = os.path.join(self.checkpoint_dir, "train.log") + logger = Logger(log_file, level=self.args.log_level).logger + + return logger + + def build_dataset(self): + raise NotImplementedError + + def build_data_loader(self): + Dataset, Collator = self.build_dataset() + # build dataset instance for each dataset and combine them by ConcatDataset + datasets_list = [] + for dataset in self.cfg.dataset: + subdataset = Dataset(self.cfg, dataset, is_valid=False) + datasets_list.append(subdataset) + train_dataset = ConcatDataset(datasets_list) + + train_collate = Collator(self.cfg) + # TODO: multi-GPU training + if self.cfg.train.ddp: + raise NotImplementedError("DDP is not supported yet.") + + # sampler will provide indices to batch_sampler, which will perform batching and yield batch indices + batch_sampler = BatchSampler( + cfg=self.cfg, concat_dataset=train_dataset, dataset_list=datasets_list + ) + + # use batch_sampler argument instead of (sampler, shuffle, drop_last, batch_size) + train_loader = DataLoader( + train_dataset, + collate_fn=train_collate, + num_workers=self.args.num_workers, + batch_sampler=batch_sampler, + pin_memory=False, + ) + if not self.cfg.train.ddp or self.args.local_rank == 0: + datasets_list = [] + for dataset in self.cfg.dataset: + subdataset = Dataset(self.cfg, dataset, is_valid=True) + datasets_list.append(subdataset) + valid_dataset = ConcatDataset(datasets_list) + valid_collate = Collator(self.cfg) + batch_sampler = BatchSampler( + cfg=self.cfg, concat_dataset=valid_dataset, dataset_list=datasets_list + ) + valid_loader = DataLoader( + valid_dataset, + collate_fn=valid_collate, + num_workers=1, + batch_sampler=batch_sampler, + ) + else: + raise NotImplementedError("DDP is not supported yet.") + # valid_loader = None + data_loader = {"train": train_loader, "valid": valid_loader} + return data_loader + + def build_singers_lut(self): + # combine singers + if not os.path.exists(os.path.join(self.log_dir, self.cfg.preprocess.spk2id)): + singers = collections.OrderedDict() + else: + with open( + os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "r" + ) as singer_file: + singers = json.load(singer_file) + singer_count = len(singers) + for dataset in self.cfg.dataset: + singer_lut_path = os.path.join( + self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id + ) + with open(singer_lut_path, "r") as singer_lut_path: + singer_lut = json.load(singer_lut_path) + for singer in singer_lut.keys(): + if singer not in singers: + singers[singer] = singer_count + singer_count += 1 + with open( + os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "w" + ) as singer_file: + json.dump(singers, singer_file, indent=4, ensure_ascii=False) + print( + "singers have been dumped to {}".format( + os.path.join(self.log_dir, self.cfg.preprocess.spk2id) + ) + ) + return singers + + def build_model(self): + raise NotImplementedError() + + def build_optimizer(self): + raise NotImplementedError + + def build_scheduler(self): + raise NotImplementedError() + + def build_criterion(self): + raise NotImplementedError + + def get_state_dict(self): + raise NotImplementedError + + def save_config_file(self): + save_config(self.config_save_path, self.cfg) + + # TODO, save without module. 
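+    # ("module." refers to the key prefix that DistributedDataParallel adds to a
+    # wrapped model's state_dict; stripping it would make checkpoints loadable
+    # without the DDP wrapper.)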
+ def save_checkpoint(self, state_dict, saved_model_path): + torch.save(state_dict, saved_model_path) + + def load_checkpoint(self): + checkpoint_path = os.path.join(self.checkpoint_dir, "checkpoint") + assert os.path.exists(checkpoint_path) + checkpoint_filename = open(checkpoint_path).readlines()[-1].strip() + model_path = os.path.join(self.checkpoint_dir, checkpoint_filename) + assert os.path.exists(model_path) + if not self.cfg.train.ddp or self.args.local_rank == 0: + self.logger.info(f"Re(store) from {model_path}") + checkpoint = torch.load(model_path, map_location="cpu") + return checkpoint + + def load_model(self, checkpoint): + raise NotImplementedError + + def restore(self): + checkpoint = self.load_checkpoint() + self.load_model(checkpoint) + + def train_step(self, data): + raise NotImplementedError( + f"Need to implement function {sys._getframe().f_code.co_name} in " + f"your sub-class of {self.__class__.__name__}. " + ) + + @torch.no_grad() + def eval_step(self): + raise NotImplementedError( + f"Need to implement function {sys._getframe().f_code.co_name} in " + f"your sub-class of {self.__class__.__name__}. " + ) + + def write_summary(self, losses, stats): + raise NotImplementedError( + f"Need to implement function {sys._getframe().f_code.co_name} in " + f"your sub-class of {self.__class__.__name__}. " + ) + + def write_valid_summary(self, losses, stats): + raise NotImplementedError( + f"Need to implement function {sys._getframe().f_code.co_name} in " + f"your sub-class of {self.__class__.__name__}. " + ) + + def echo_log(self, losses, mode="Training"): + message = [ + "{} - Epoch {} Step {}: [{:.3f} s/step]".format( + mode, self.epoch + 1, self.step, self.time_window.average + ) + ] + + for key in sorted(losses.keys()): + if isinstance(losses[key], dict): + for k, v in losses[key].items(): + message.append( + str(k).split("/")[-1] + "=" + str(round(float(v), 5)) + ) + else: + message.append( + str(key).split("/")[-1] + "=" + str(round(float(losses[key]), 5)) + ) + self.logger.info(", ".join(message)) + + def eval_epoch(self): + self.logger.info("Validation...") + valid_losses = {} + for i, batch_data in enumerate(self.data_loader["valid"]): + for k, v in batch_data.items(): + if isinstance(v, torch.Tensor): + batch_data[k] = v.cuda() + valid_loss, valid_stats, total_valid_loss = self.eval_step(batch_data, i) + for key in valid_loss: + if key not in valid_losses: + valid_losses[key] = 0 + valid_losses[key] += valid_loss[key] + + # Add mel and audio to the Tensorboard + # Average loss + for key in valid_losses: + valid_losses[key] /= i + 1 + self.echo_log(valid_losses, "Valid") + return valid_losses, valid_stats + + def train_epoch(self): + for i, batch_data in enumerate(self.data_loader["train"]): + start_time = time.time() + # Put the data to cuda device + for k, v in batch_data.items(): + if isinstance(v, torch.Tensor): + batch_data[k] = v.cuda(self.args.local_rank) + + # Training step + train_losses, train_stats, total_loss = self.train_step(batch_data) + self.time_window.append(time.time() - start_time) + + if self.args.local_rank == 0 or not self.cfg.train.ddp: + if self.step % self.args.stdout_interval == 0: + self.echo_log(train_losses, "Training") + + if self.step % self.cfg.train.save_summary_steps == 0: + self.logger.info(f"Save summary as step {self.step}") + self.write_summary(train_losses, train_stats) + + if ( + self.step % self.cfg.train.save_checkpoints_steps == 0 + and self.step != 0 + ): + saved_model_name = "step-{:07d}_loss-{:.4f}.pt".format( + self.step, 
total_loss + ) + saved_model_path = os.path.join( + self.checkpoint_dir, saved_model_name + ) + saved_state_dict = self.get_state_dict() + self.save_checkpoint(saved_state_dict, saved_model_path) + self.save_config_file() + # keep max n models + remove_older_ckpt( + saved_model_name, + self.checkpoint_dir, + max_to_keep=self.cfg.train.keep_checkpoint_max, + ) + + if self.step != 0 and self.step % self.cfg.train.valid_interval == 0: + if isinstance(self.model, dict): + for key in self.model.keys(): + self.model[key].eval() + else: + self.model.eval() + # Evaluate one epoch and get average loss + valid_losses, valid_stats = self.eval_epoch() + if isinstance(self.model, dict): + for key in self.model.keys(): + self.model[key].train() + else: + self.model.train() + # Write validation losses to summary. + self.write_valid_summary(valid_losses, valid_stats) + self.step += 1 + + def train(self): + for epoch in range(max(0, self.epoch), self.max_epochs): + self.train_epoch() + self.epoch += 1 + if self.step > self.max_steps: + self.logger.info("Training finished!") + break diff --git a/models/base/new_dataset.py b/models/base/new_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2201bb4132ab86d1110092d7ab9e509296367a22 --- /dev/null +++ b/models/base/new_dataset.py @@ -0,0 +1,50 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os +from abc import abstractmethod +from pathlib import Path + +import json5 +import torch +import yaml + + +# TODO: for training and validating +class BaseDataset(torch.utils.data.Dataset): + r"""Base dataset for training and validating.""" + + def __init__(self, args, cfg, is_valid=False): + pass + + +class BaseTestDataset(torch.utils.data.Dataset): + r"""Test dataset for inference.""" + + def __init__(self, args=None, cfg=None, infer_type="from_dataset"): + assert infer_type in ["from_dataset", "from_file"] + + self.args = args + self.cfg = cfg + self.infer_type = infer_type + + @abstractmethod + def __getitem__(self, index): + pass + + def __len__(self): + return len(self.metadata) + + def get_metadata(self): + path = Path(self.args.source) + if path.suffix == ".json" or path.suffix == ".jsonc": + metadata = json5.load(open(self.args.source, "r")) + elif path.suffix == ".yaml" or path.suffix == ".yml": + metadata = yaml.full_load(open(self.args.source, "r")) + else: + raise ValueError(f"Unsupported file type: {path.suffix}") + + return metadata diff --git a/models/base/new_inference.py b/models/base/new_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..4813fca4aba192fb8737dd74f37f6d430e1909a4 --- /dev/null +++ b/models/base/new_inference.py @@ -0,0 +1,249 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import os +import random +import re +import time +from abc import abstractmethod +from pathlib import Path + +import accelerate +import json5 +import numpy as np +import torch +from accelerate.logging import get_logger +from torch.utils.data import DataLoader + +from models.vocoders.vocoder_inference import synthesis +from utils.io import save_audio +from utils.util import load_config +from utils.audio_slicer import is_silence + +EPS = 1.0e-12 + + +class BaseInference(object): + def __init__(self, args=None, cfg=None, infer_type="from_dataset"): + super().__init__() + + start = time.monotonic_ns() + self.args = args + self.cfg = cfg + + assert infer_type in ["from_dataset", "from_file"] + self.infer_type = infer_type + + # init with accelerate + self.accelerator = accelerate.Accelerator() + self.accelerator.wait_for_everyone() + + # Use accelerate logger for distributed inference + with self.accelerator.main_process_first(): + self.logger = get_logger("inference", log_level=args.log_level) + + # Log some info + self.logger.info("=" * 56) + self.logger.info("||\t\t" + "New inference process started." + "\t\t||") + self.logger.info("=" * 56) + self.logger.info("\n") + self.logger.debug(f"Using {args.log_level.upper()} logging level.") + + self.acoustics_dir = args.acoustics_dir + self.logger.debug(f"Acoustic dir: {args.acoustics_dir}") + self.vocoder_dir = args.vocoder_dir + self.logger.debug(f"Vocoder dir: {args.vocoder_dir}") + # should be in svc inferencer + # self.target_singer = args.target_singer + # self.logger.info(f"Target singers: {args.target_singer}") + # self.trans_key = args.trans_key + # self.logger.info(f"Trans key: {args.trans_key}") + + os.makedirs(args.output_dir, exist_ok=True) + + # set random seed + with self.accelerator.main_process_first(): + start = time.monotonic_ns() + self._set_random_seed(self.cfg.train.random_seed) + end = time.monotonic_ns() + self.logger.debug( + f"Setting random seed done in {(end - start) / 1e6:.2f}ms" + ) + self.logger.debug(f"Random seed: {self.cfg.train.random_seed}") + + # setup data_loader + with self.accelerator.main_process_first(): + self.logger.info("Building dataset...") + start = time.monotonic_ns() + self.test_dataloader = self._build_dataloader() + end = time.monotonic_ns() + self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms") + + # setup model + with self.accelerator.main_process_first(): + self.logger.info("Building model...") + start = time.monotonic_ns() + self.model = self._build_model() + end = time.monotonic_ns() + # self.logger.debug(self.model) + self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms") + + # init with accelerate + self.logger.info("Initializing accelerate...") + start = time.monotonic_ns() + self.accelerator = accelerate.Accelerator() + self.model = self.accelerator.prepare(self.model) + end = time.monotonic_ns() + self.accelerator.wait_for_everyone() + self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms") + + with self.accelerator.main_process_first(): + self.logger.info("Loading checkpoint...") + start = time.monotonic_ns() + # TODO: Also, suppose only use latest one yet + self.__load_model(os.path.join(args.acoustics_dir, "checkpoint")) + end = time.monotonic_ns() + self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms") + + self.model.eval() + self.accelerator.wait_for_everyone() + + ### Abstract methods ### + @abstractmethod + def _build_test_dataset(self): + pass + + @abstractmethod + def _build_model(self): + 
pass + + @abstractmethod + @torch.inference_mode() + def _inference_each_batch(self, batch_data): + pass + + ### Abstract methods end ### + + @torch.inference_mode() + def inference(self): + for i, batch in enumerate(self.test_dataloader): + y_pred = self._inference_each_batch(batch).cpu() + mel_min, mel_max = self.test_dataset.target_mel_extrema + y_pred = (y_pred + 1.0) / 2.0 * (mel_max - mel_min + EPS) + mel_min + y_ls = y_pred.chunk(self.test_batch_size) + tgt_ls = batch["target_len"].cpu().chunk(self.test_batch_size) + j = 0 + for it, l in zip(y_ls, tgt_ls): + l = l.item() + it = it.squeeze(0)[:l] + uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"] + torch.save(it, os.path.join(self.args.output_dir, f"{uid}.pt")) + j += 1 + + vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir) + + res = synthesis( + cfg=vocoder_cfg, + vocoder_weight_file=vocoder_ckpt, + n_samples=None, + pred=[ + torch.load( + os.path.join(self.args.output_dir, "{}.pt".format(i["Uid"])) + ).numpy(force=True) + for i in self.test_dataset.metadata + ], + ) + + output_audio_files = [] + for it, wav in zip(self.test_dataset.metadata, res): + uid = it["Uid"] + file = os.path.join(self.args.output_dir, f"{uid}.wav") + output_audio_files.append(file) + + wav = wav.numpy(force=True) + save_audio( + file, + wav, + self.cfg.preprocess.sample_rate, + add_silence=False, + turn_up=not is_silence(wav, self.cfg.preprocess.sample_rate), + ) + os.remove(os.path.join(self.args.output_dir, f"{uid}.pt")) + + return sorted(output_audio_files) + + # TODO: LEGACY CODE + def _build_dataloader(self): + datasets, collate = self._build_test_dataset() + self.test_dataset = datasets(self.args, self.cfg, self.infer_type) + self.test_collate = collate(self.cfg) + self.test_batch_size = min( + self.cfg.train.batch_size, len(self.test_dataset.metadata) + ) + test_dataloader = DataLoader( + self.test_dataset, + collate_fn=self.test_collate, + num_workers=1, + batch_size=self.test_batch_size, + shuffle=False, + ) + return test_dataloader + + def __load_model(self, checkpoint_dir: str = None, checkpoint_path: str = None): + r"""Load model from checkpoint. If checkpoint_path is None, it will + load the latest checkpoint in checkpoint_dir. If checkpoint_path is not + None, it will load the checkpoint specified by checkpoint_path. **Only use this + method after** ``accelerator.prepare()``. 
+ """ + if checkpoint_path is None: + ls = [] + for i in Path(checkpoint_dir).iterdir(): + if re.match(r"epoch-\d+_step-\d+_loss-[\d.]+", str(i.stem)): + ls.append(i) + ls.sort( + key=lambda x: int(x.stem.split("_")[-3].split("-")[-1]), reverse=True + ) + checkpoint_path = ls[0] + else: + checkpoint_path = Path(checkpoint_path) + self.accelerator.load_state(str(checkpoint_path)) + # set epoch and step + self.epoch = int(checkpoint_path.stem.split("_")[-3].split("-")[-1]) + self.step = int(checkpoint_path.stem.split("_")[-2].split("-")[-1]) + return str(checkpoint_path) + + @staticmethod + def _set_random_seed(seed): + r"""Set random seed for all possible random modules.""" + random.seed(seed) + np.random.seed(seed) + torch.random.manual_seed(seed) + + @staticmethod + def _parse_vocoder(vocoder_dir): + r"""Parse vocoder config""" + vocoder_dir = os.path.abspath(vocoder_dir) + ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")] + ckpt_list.sort(key=lambda x: int(x.stem), reverse=True) + ckpt_path = str(ckpt_list[0]) + vocoder_cfg = load_config( + os.path.join(vocoder_dir, "args.json"), lowercase=True + ) + return vocoder_cfg, ckpt_path + + @staticmethod + def __count_parameters(model): + return sum(p.numel() for p in model.parameters()) + + def __dump_cfg(self, path): + os.makedirs(os.path.dirname(path), exist_ok=True) + json5.dump( + self.cfg, + open(path, "w"), + indent=4, + sort_keys=True, + ensure_ascii=False, + quote_keys=True, + ) diff --git a/models/base/new_trainer.py b/models/base/new_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..5d013d2bc2f2e47e5c7646cac5c63cc88c04486b --- /dev/null +++ b/models/base/new_trainer.py @@ -0,0 +1,722 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os +import random +import shutil +import time +from abc import abstractmethod +from pathlib import Path + +import accelerate +import json5 +import numpy as np +import torch +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration +from torch.utils.data import ConcatDataset, DataLoader +from tqdm import tqdm + +from models.base.base_sampler import build_samplers +from optimizer.optimizers import NoamLR + + +class BaseTrainer(object): + r"""The base trainer for all tasks. Any trainer should inherit from this class.""" + + def __init__(self, args=None, cfg=None): + super().__init__() + + self.args = args + self.cfg = cfg + + cfg.exp_name = args.exp_name + + # init with accelerate + self._init_accelerator() + self.accelerator.wait_for_everyone() + + # Use accelerate logger for distributed training + with self.accelerator.main_process_first(): + self.logger = get_logger(args.exp_name, log_level=args.log_level) + + # Log some info + self.logger.info("=" * 56) + self.logger.info("||\t\t" + "New training process started." 
+ "\t\t||") + self.logger.info("=" * 56) + self.logger.info("\n") + self.logger.debug(f"Using {args.log_level.upper()} logging level.") + self.logger.info(f"Experiment name: {args.exp_name}") + self.logger.info(f"Experiment directory: {self.exp_dir}") + self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint") + if self.accelerator.is_main_process: + os.makedirs(self.checkpoint_dir, exist_ok=True) + self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}") + + # init counts + self.batch_count: int = 0 + self.step: int = 0 + self.epoch: int = 0 + self.max_epoch = ( + self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf") + ) + self.logger.info( + "Max epoch: {}".format( + self.max_epoch if self.max_epoch < float("inf") else "Unlimited" + ) + ) + + # Check values + if self.accelerator.is_main_process: + self.__check_basic_configs() + # Set runtime configs + self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride + self.checkpoints_path = [ + [] for _ in range(len(self.save_checkpoint_stride)) + ] + self.keep_last = [ + i if i > 0 else float("inf") for i in self.cfg.train.keep_last + ] + self.run_eval = self.cfg.train.run_eval + + # set random seed + with self.accelerator.main_process_first(): + start = time.monotonic_ns() + self._set_random_seed(self.cfg.train.random_seed) + end = time.monotonic_ns() + self.logger.debug( + f"Setting random seed done in {(end - start) / 1e6:.2f}ms" + ) + self.logger.debug(f"Random seed: {self.cfg.train.random_seed}") + + # setup data_loader + with self.accelerator.main_process_first(): + self.logger.info("Building dataset...") + start = time.monotonic_ns() + self.train_dataloader, self.valid_dataloader = self._build_dataloader() + end = time.monotonic_ns() + self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms") + + # setup model + with self.accelerator.main_process_first(): + self.logger.info("Building model...") + start = time.monotonic_ns() + self.model = self._build_model() + end = time.monotonic_ns() + self.logger.debug(self.model) + self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms") + self.logger.info( + f"Model parameters: {self.__count_parameters(self.model)/1e6:.2f}M" + ) + # optimizer & scheduler + with self.accelerator.main_process_first(): + self.logger.info("Building optimizer and scheduler...") + start = time.monotonic_ns() + self.optimizer = self.__build_optimizer() + self.scheduler = self.__build_scheduler() + end = time.monotonic_ns() + self.logger.info( + f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms" + ) + + # accelerate prepare + self.logger.info("Initializing accelerate...") + start = time.monotonic_ns() + ( + self.train_dataloader, + self.valid_dataloader, + self.model, + self.optimizer, + self.scheduler, + ) = self.accelerator.prepare( + self.train_dataloader, + self.valid_dataloader, + self.model, + self.optimizer, + self.scheduler, + ) + end = time.monotonic_ns() + self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms") + + # create criterion + with self.accelerator.main_process_first(): + self.logger.info("Building criterion...") + start = time.monotonic_ns() + self.criterion = self._build_criterion() + end = time.monotonic_ns() + self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms") + + # Resume or Finetune + with self.accelerator.main_process_first(): + if args.resume: + ## Automatically resume according to the current exprimental name + self.logger.info("Resuming from 
{}...".format(self.checkpoint_dir)) + start = time.monotonic_ns() + ckpt_path = self.__load_model( + checkpoint_dir=self.checkpoint_dir, resume_type=args.resume_type + ) + end = time.monotonic_ns() + self.logger.info( + f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms" + ) + self.checkpoints_path = json.load( + open(os.path.join(ckpt_path, "ckpts.json"), "r") + ) + elif args.resume_from_ckpt_path and args.resume_from_ckpt_path != "": + ## Resume from the given checkpoint path + if not os.path.exists(args.resume_from_ckpt_path): + raise ValueError( + "[Error] The resumed checkpoint path {} don't exist.".format( + args.resume_from_ckpt_path + ) + ) + + self.logger.info( + "Resuming from {}...".format(args.resume_from_ckpt_path) + ) + start = time.monotonic_ns() + ckpt_path = self.__load_model( + checkpoint_path=args.resume_from_ckpt_path, + resume_type=args.resume_type, + ) + end = time.monotonic_ns() + self.logger.info( + f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms" + ) + + # save config file path + self.config_save_path = os.path.join(self.exp_dir, "args.json") + + ### Following are abstract methods that should be implemented in child classes ### + @abstractmethod + def _build_dataset(self): + r"""Build dataset for model training/validating/evaluating.""" + pass + + @staticmethod + @abstractmethod + def _build_criterion(): + r"""Build criterion function for model loss calculation.""" + pass + + @abstractmethod + def _build_model(self): + r"""Build model for training/validating/evaluating.""" + pass + + @abstractmethod + def _forward_step(self, batch): + r"""One forward step of the neural network. This abstract method is trying to + unify ``_train_step`` and ``_valid_step`` and avoid redundant implementation. + However, for special case that using different forward step pattern for + training and validating, you could just override this method with ``pass`` and + implement ``_train_step`` and ``_valid_step`` separately. + """ + pass + + @abstractmethod + def _save_auxiliary_states(self): + r"""To save some auxiliary states when saving model's ckpt""" + pass + + ### Abstract methods end ### + + ### THIS IS MAIN ENTRY ### + def train_loop(self): + r"""Training loop. The public entry of training process.""" + # Wait everyone to prepare before we move on + self.accelerator.wait_for_everyone() + # dump config file + if self.accelerator.is_main_process: + self.__dump_cfg(self.config_save_path) + self.model.train() + self.optimizer.zero_grad() + # Wait to ensure good to go + self.accelerator.wait_for_everyone() + while self.epoch < self.max_epoch: + self.logger.info("\n") + self.logger.info("-" * 32) + self.logger.info("Epoch {}: ".format(self.epoch)) + + ### TODO: change the return values of _train_epoch() to a loss dict, or (total_loss, loss_dict) + ### It's inconvenient for the model with multiple losses + # Do training & validating epoch + train_loss = self._train_epoch() + self.logger.info(" |- Train/Loss: {:.6f}".format(train_loss)) + valid_loss = self._valid_epoch() + self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_loss)) + self.accelerator.log( + {"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss}, + step=self.epoch, + ) + + self.accelerator.wait_for_everyone() + # TODO: what is scheduler? + self.scheduler.step(valid_loss) # FIXME: use epoch track correct? 
+ + # Check if hit save_checkpoint_stride and run_eval + run_eval = False + if self.accelerator.is_main_process: + save_checkpoint = False + hit_dix = [] + for i, num in enumerate(self.save_checkpoint_stride): + if self.epoch % num == 0: + save_checkpoint = True + hit_dix.append(i) + run_eval |= self.run_eval[i] + + self.accelerator.wait_for_everyone() + if self.accelerator.is_main_process and save_checkpoint: + path = os.path.join( + self.checkpoint_dir, + "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format( + self.epoch, self.step, train_loss + ), + ) + self.tmp_checkpoint_save_path = path + self.accelerator.save_state(path) + print(f"save checkpoint in {path}") + json.dump( + self.checkpoints_path, + open(os.path.join(path, "ckpts.json"), "w"), + ensure_ascii=False, + indent=4, + ) + self._save_auxiliary_states() + + # Remove old checkpoints + to_remove = [] + for idx in hit_dix: + self.checkpoints_path[idx].append(path) + while len(self.checkpoints_path[idx]) > self.keep_last[idx]: + to_remove.append((idx, self.checkpoints_path[idx].pop(0))) + + # Search conflicts + total = set() + for i in self.checkpoints_path: + total |= set(i) + do_remove = set() + for idx, path in to_remove[::-1]: + if path in total: + self.checkpoints_path[idx].insert(0, path) + else: + do_remove.add(path) + + # Remove old checkpoints + for path in do_remove: + shutil.rmtree(path, ignore_errors=True) + self.logger.debug(f"Remove old checkpoint: {path}") + + self.accelerator.wait_for_everyone() + if run_eval: + # TODO: run evaluation + pass + + # Update info for each epoch + self.epoch += 1 + + # Finish training and save final checkpoint + self.accelerator.wait_for_everyone() + if self.accelerator.is_main_process: + self.accelerator.save_state( + os.path.join( + self.checkpoint_dir, + "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format( + self.epoch, self.step, valid_loss + ), + ) + ) + self._save_auxiliary_states() + + self.accelerator.end_training() + + ### Following are methods that can be used directly in child classes ### + def _train_epoch(self): + r"""Training epoch. Should return average loss of a batch (sample) over + one epoch. See ``train_loop`` for usage. + """ + self.model.train() + epoch_sum_loss: float = 0.0 + epoch_step: int = 0 + for batch in tqdm( + self.train_dataloader, + desc=f"Training Epoch {self.epoch}", + unit="batch", + colour="GREEN", + leave=False, + dynamic_ncols=True, + smoothing=0.04, + disable=not self.accelerator.is_main_process, + ): + # Do training step and BP + with self.accelerator.accumulate(self.model): + loss = self._train_step(batch) + self.accelerator.backward(loss) + self.optimizer.step() + self.optimizer.zero_grad() + self.batch_count += 1 + + # Update info for each step + # TODO: step means BP counts or batch counts? + if self.batch_count % self.cfg.train.gradient_accumulation_step == 0: + epoch_sum_loss += loss + self.accelerator.log( + { + "Step/Train Loss": loss, + "Step/Learning Rate": self.optimizer.param_groups[0]["lr"], + }, + step=self.step, + ) + self.step += 1 + epoch_step += 1 + + self.accelerator.wait_for_everyone() + return ( + epoch_sum_loss + / len(self.train_dataloader) + * self.cfg.train.gradient_accumulation_step + ) + + @torch.inference_mode() + def _valid_epoch(self): + r"""Testing epoch. Should return average loss of a batch (sample) over + one epoch. See ``train_loop`` for usage. 
+ """ + self.model.eval() + epoch_sum_loss = 0.0 + for batch in tqdm( + self.valid_dataloader, + desc=f"Validating Epoch {self.epoch}", + unit="batch", + colour="GREEN", + leave=False, + dynamic_ncols=True, + smoothing=0.04, + disable=not self.accelerator.is_main_process, + ): + batch_loss = self._valid_step(batch) + epoch_sum_loss += batch_loss.item() + + self.accelerator.wait_for_everyone() + return epoch_sum_loss / len(self.valid_dataloader) + + def _train_step(self, batch): + r"""Training forward step. Should return average loss of a sample over + one batch. Provoke ``_forward_step`` is recommended except for special case. + See ``_train_epoch`` for usage. + """ + return self._forward_step(batch) + + @torch.inference_mode() + def _valid_step(self, batch): + r"""Testing forward step. Should return average loss of a sample over + one batch. Provoke ``_forward_step`` is recommended except for special case. + See ``_test_epoch`` for usage. + """ + return self._forward_step(batch) + + def __load_model( + self, + checkpoint_dir: str = None, + checkpoint_path: str = None, + resume_type: str = "", + ): + r"""Load model from checkpoint. If checkpoint_path is None, it will + load the latest checkpoint in checkpoint_dir. If checkpoint_path is not + None, it will load the checkpoint specified by checkpoint_path. **Only use this + method after** ``accelerator.prepare()``. + """ + if checkpoint_path is None: + ls = [str(i) for i in Path(checkpoint_dir).glob("*")] + ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True) + checkpoint_path = ls[0] + self.logger.info("Resume from {}...".format(checkpoint_path)) + + if resume_type in ["resume", ""]: + # Load all the things, including model weights, optimizer, scheduler, and random states. + self.accelerator.load_state(input_dir=checkpoint_path) + + # set epoch and step + self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1 + self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1 + + elif resume_type == "finetune": + # Load only the model weights + accelerate.load_checkpoint_and_dispatch( + self.accelerator.unwrap_model(self.model), + os.path.join(checkpoint_path, "pytorch_model.bin"), + ) + self.logger.info("Load model weights for finetune...") + + else: + raise ValueError("Resume_type must be `resume` or `finetune`.") + + return checkpoint_path + + # TODO: LEGACY CODE + def _build_dataloader(self): + Dataset, Collator = self._build_dataset() + + # build dataset instance for each dataset and combine them by ConcatDataset + datasets_list = [] + for dataset in self.cfg.dataset: + subdataset = Dataset(self.cfg, dataset, is_valid=False) + datasets_list.append(subdataset) + train_dataset = ConcatDataset(datasets_list) + train_collate = Collator(self.cfg) + _, batch_sampler = build_samplers(train_dataset, self.cfg, self.logger, "train") + self.logger.debug(f"train batch_sampler: {list(batch_sampler)}") + self.logger.debug(f"length: {train_dataset.cumulative_sizes}") + # TODO: use config instead of (sampler, shuffle, drop_last, batch_size) + train_loader = DataLoader( + train_dataset, + collate_fn=train_collate, + batch_sampler=batch_sampler, + num_workers=self.cfg.train.dataloader.num_worker, + pin_memory=self.cfg.train.dataloader.pin_memory, + ) + + # Build valid dataloader + datasets_list = [] + for dataset in self.cfg.dataset: + subdataset = Dataset(self.cfg, dataset, is_valid=True) + datasets_list.append(subdataset) + valid_dataset = ConcatDataset(datasets_list) + valid_collate = Collator(self.cfg) + _, 
batch_sampler = build_samplers(valid_dataset, self.cfg, self.logger, "valid") + self.logger.debug(f"valid batch_sampler: {list(batch_sampler)}") + self.logger.debug(f"length: {valid_dataset.cumulative_sizes}") + valid_loader = DataLoader( + valid_dataset, + collate_fn=valid_collate, + batch_sampler=batch_sampler, + num_workers=self.cfg.train.dataloader.num_worker, + pin_memory=self.cfg.train.dataloader.pin_memory, + ) + return train_loader, valid_loader + + @staticmethod + def _set_random_seed(seed): + r"""Set random seed for all possible random modules.""" + random.seed(seed) + np.random.seed(seed) + torch.random.manual_seed(seed) + + def _check_nan(self, loss, y_pred, y_gt): + if torch.any(torch.isnan(loss)): + self.logger.fatal("Fatal Error: Training is down since loss has Nan!") + self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True) + if torch.any(torch.isnan(y_pred)): + self.logger.error( + f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True + ) + else: + self.logger.debug( + f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True + ) + if torch.any(torch.isnan(y_gt)): + self.logger.error( + f"y_gt has Nan: {torch.any(torch.isnan(y_gt))}", in_order=True + ) + else: + self.logger.debug( + f"y_gt has nan: {torch.any(torch.isnan(y_gt))}", in_order=True + ) + if torch.any(torch.isnan(y_pred)): + self.logger.error(f"y_pred: {y_pred}", in_order=True) + else: + self.logger.debug(f"y_pred: {y_pred}", in_order=True) + if torch.any(torch.isnan(y_gt)): + self.logger.error(f"y_gt: {y_gt}", in_order=True) + else: + self.logger.debug(f"y_gt: {y_gt}", in_order=True) + + # TODO: still OK to save tracking? + self.accelerator.end_training() + raise RuntimeError("Loss has Nan! See log for more info.") + + ### Protected methods end ### + + ## Following are private methods ## + ## !!! These are inconvenient for GAN-based model training. It'd be better to move these to svc_trainer.py if needed. 
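# The two private builders that follow dispatch on the lower-cased
# ``cfg.train.optimizer`` / ``cfg.train.scheduler`` strings and unpack the matching
# per-optimizer kwargs section of the config. A minimal table-driven sketch of the
# same dispatch (a standalone illustration only; the optimizer name and
# hyper-parameters in the usage lines are assumptions, not values required here):

import torch
import torch.nn as nn

_OPTIMIZERS = {
    "adam": torch.optim.Adam,
    "adamw": torch.optim.AdamW,
    "sgd": torch.optim.SGD,
    # ...the remaining names handled below can be registered the same way.
}


def build_optimizer(model: nn.Module, name: str, kwargs: dict) -> torch.optim.Optimizer:
    # Case-insensitive lookup, mirroring the if/elif chain in __build_optimizer.
    try:
        cls = _OPTIMIZERS[name.lower()]
    except KeyError:
        raise NotImplementedError(f"Optimizer {name} not supported yet!")
    return cls(model.parameters(), **kwargs)


# Usage (hypothetical hyper-parameters):
toy_model = nn.Linear(4, 4)
toy_optimizer = build_optimizer(toy_model, "AdamW", {"lr": 1e-4, "weight_decay": 1e-2})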
+ def __build_optimizer(self): + r"""Build optimizer for model.""" + # Make case-insensitive matching + if self.cfg.train.optimizer.lower() == "adadelta": + optimizer = torch.optim.Adadelta( + self.model.parameters(), **self.cfg.train.adadelta + ) + self.logger.info("Using Adadelta optimizer.") + elif self.cfg.train.optimizer.lower() == "adagrad": + optimizer = torch.optim.Adagrad( + self.model.parameters(), **self.cfg.train.adagrad + ) + self.logger.info("Using Adagrad optimizer.") + elif self.cfg.train.optimizer.lower() == "adam": + optimizer = torch.optim.Adam(self.model.parameters(), **self.cfg.train.adam) + self.logger.info("Using Adam optimizer.") + elif self.cfg.train.optimizer.lower() == "adamw": + optimizer = torch.optim.AdamW( + self.model.parameters(), **self.cfg.train.adamw + ) + elif self.cfg.train.optimizer.lower() == "sparseadam": + optimizer = torch.optim.SparseAdam( + self.model.parameters(), **self.cfg.train.sparseadam + ) + elif self.cfg.train.optimizer.lower() == "adamax": + optimizer = torch.optim.Adamax( + self.model.parameters(), **self.cfg.train.adamax + ) + elif self.cfg.train.optimizer.lower() == "asgd": + optimizer = torch.optim.ASGD(self.model.parameters(), **self.cfg.train.asgd) + elif self.cfg.train.optimizer.lower() == "lbfgs": + optimizer = torch.optim.LBFGS( + self.model.parameters(), **self.cfg.train.lbfgs + ) + elif self.cfg.train.optimizer.lower() == "nadam": + optimizer = torch.optim.NAdam( + self.model.parameters(), **self.cfg.train.nadam + ) + elif self.cfg.train.optimizer.lower() == "radam": + optimizer = torch.optim.RAdam( + self.model.parameters(), **self.cfg.train.radam + ) + elif self.cfg.train.optimizer.lower() == "rmsprop": + optimizer = torch.optim.RMSprop( + self.model.parameters(), **self.cfg.train.rmsprop + ) + elif self.cfg.train.optimizer.lower() == "rprop": + optimizer = torch.optim.Rprop( + self.model.parameters(), **self.cfg.train.rprop + ) + elif self.cfg.train.optimizer.lower() == "sgd": + optimizer = torch.optim.SGD(self.model.parameters(), **self.cfg.train.sgd) + else: + raise NotImplementedError( + f"Optimizer {self.cfg.train.optimizer} not supported yet!" 
+ ) + return optimizer + + def __build_scheduler(self): + r"""Build scheduler for optimizer.""" + # Make case-insensitive matching + if self.cfg.train.scheduler.lower() == "lambdalr": + scheduler = torch.optim.lr_scheduler.LambdaLR( + self.optimizer, **self.cfg.train.lambdalr + ) + elif self.cfg.train.scheduler.lower() == "multiplicativelr": + scheduler = torch.optim.lr_scheduler.MultiplicativeLR( + self.optimizer, **self.cfg.train.multiplicativelr + ) + elif self.cfg.train.scheduler.lower() == "steplr": + scheduler = torch.optim.lr_scheduler.StepLR( + self.optimizer, **self.cfg.train.steplr + ) + elif self.cfg.train.scheduler.lower() == "multisteplr": + scheduler = torch.optim.lr_scheduler.MultiStepLR( + self.optimizer, **self.cfg.train.multisteplr + ) + elif self.cfg.train.scheduler.lower() == "constantlr": + scheduler = torch.optim.lr_scheduler.ConstantLR( + self.optimizer, **self.cfg.train.constantlr + ) + elif self.cfg.train.scheduler.lower() == "linearlr": + scheduler = torch.optim.lr_scheduler.LinearLR( + self.optimizer, **self.cfg.train.linearlr + ) + elif self.cfg.train.scheduler.lower() == "exponentiallr": + scheduler = torch.optim.lr_scheduler.ExponentialLR( + self.optimizer, **self.cfg.train.exponentiallr + ) + elif self.cfg.train.scheduler.lower() == "polynomiallr": + scheduler = torch.optim.lr_scheduler.PolynomialLR( + self.optimizer, **self.cfg.train.polynomiallr + ) + elif self.cfg.train.scheduler.lower() == "cosineannealinglr": + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + self.optimizer, **self.cfg.train.cosineannealinglr + ) + elif self.cfg.train.scheduler.lower() == "sequentiallr": + scheduler = torch.optim.lr_scheduler.SequentialLR( + self.optimizer, **self.cfg.train.sequentiallr + ) + elif self.cfg.train.scheduler.lower() == "reducelronplateau": + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + self.optimizer, **self.cfg.train.reducelronplateau + ) + elif self.cfg.train.scheduler.lower() == "cycliclr": + scheduler = torch.optim.lr_scheduler.CyclicLR( + self.optimizer, **self.cfg.train.cycliclr + ) + elif self.cfg.train.scheduler.lower() == "onecyclelr": + scheduler = torch.optim.lr_scheduler.OneCycleLR( + self.optimizer, **self.cfg.train.onecyclelr + ) + elif self.cfg.train.scheduler.lower() == "cosineannearingwarmrestarts": + scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + self.optimizer, **self.cfg.train.cosineannearingwarmrestarts + ) + elif self.cfg.train.scheduler.lower() == "noamlr": + scheduler = NoamLR(self.optimizer, **self.cfg.train.lr_scheduler) + else: + raise NotImplementedError( + f"Scheduler {self.cfg.train.scheduler} not supported yet!" 
+ ) + return scheduler + + def _init_accelerator(self): + self.exp_dir = os.path.join( + os.path.abspath(self.cfg.log_dir), self.args.exp_name + ) + project_config = ProjectConfiguration( + project_dir=self.exp_dir, + logging_dir=os.path.join(self.exp_dir, "log"), + ) + self.accelerator = accelerate.Accelerator( + gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step, + log_with=self.cfg.train.tracker, + project_config=project_config, + ) + if self.accelerator.is_main_process: + os.makedirs(project_config.project_dir, exist_ok=True) + os.makedirs(project_config.logging_dir, exist_ok=True) + with self.accelerator.main_process_first(): + self.accelerator.init_trackers(self.args.exp_name) + + def __check_basic_configs(self): + if self.cfg.train.gradient_accumulation_step <= 0: + self.logger.fatal("Invalid gradient_accumulation_step value!") + self.logger.error( + f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive." + ) + self.accelerator.end_training() + raise ValueError( + f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive." + ) + # TODO: check other values + + @staticmethod + def __count_parameters(model): + model_param = 0.0 + if isinstance(model, dict): + for key, value in model.items(): + model_param += sum(p.numel() for p in model[key].parameters()) + else: + model_param = sum(p.numel() for p in model.parameters()) + return model_param + + def __dump_cfg(self, path): + os.makedirs(os.path.dirname(path), exist_ok=True) + json5.dump( + self.cfg, + open(path, "w"), + indent=4, + sort_keys=True, + ensure_ascii=False, + quote_keys=True, + ) + + ### Private methods end ### diff --git a/models/svc/__init__.py b/models/svc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/svc/base/__init__.py b/models/svc/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..38c2b1686db550b3b9892b8bc6e594cd847aafd1 --- /dev/null +++ b/models/svc/base/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .svc_inference import SVCInference +from .svc_trainer import SVCTrainer diff --git a/models/svc/base/svc_dataset.py b/models/svc/base/svc_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9a66c03c85ef5448f08a94718d395b519af9af74 --- /dev/null +++ b/models/svc/base/svc_dataset.py @@ -0,0 +1,437 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
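# ``_init_accelerator`` above hands ``cfg.train.gradient_accumulation_step`` to
# Accelerate, and ``_train_epoch`` only advances ``self.step`` (and logs) once every
# that many batches, i.e. once per effective parameter update. A minimal sketch of
# that accumulation pattern using only the public ``accelerate`` API (the toy model,
# data, and hyper-parameters are assumptions for illustration):

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accum = 4
accelerator = Accelerator(gradient_accumulation_steps=accum)

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loader = DataLoader(TensorDataset(torch.randn(32, 8), torch.randn(32, 1)), batch_size=2)
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

step = 0
for batch_count, (x, y) in enumerate(loader, start=1):
    with accelerator.accumulate(model):
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)  # gradients accumulate; they sync every `accum` batches
        optimizer.step()            # the prepared optimizer only applies updates on sync boundaries
        optimizer.zero_grad()
    if batch_count % accum == 0:    # mirrors `batch_count % gradient_accumulation_step == 0`
        step += 1                   # one logged "step" per effective update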
+ +import random +import torch +from torch.nn.utils.rnn import pad_sequence +import json +import os +import numpy as np +from utils.data_utils import * +from processors.acoustic_extractor import cal_normalized_mel, load_mel_extrema +from processors.content_extractor import ( + ContentvecExtractor, + WhisperExtractor, + WenetExtractor, +) +from models.base.base_dataset import ( + BaseCollator, + BaseDataset, +) +from models.base.new_dataset import BaseTestDataset + +EPS = 1.0e-12 + + +class SVCDataset(BaseDataset): + def __init__(self, cfg, dataset, is_valid=False): + BaseDataset.__init__(self, cfg, dataset, is_valid=is_valid) + + cfg = self.cfg + + if cfg.model.condition_encoder.use_whisper: + self.whisper_aligner = WhisperExtractor(self.cfg) + self.utt2whisper_path = load_content_feature_path( + self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.whisper_dir + ) + + if cfg.model.condition_encoder.use_contentvec: + self.contentvec_aligner = ContentvecExtractor(self.cfg) + self.utt2contentVec_path = load_content_feature_path( + self.metadata, + cfg.preprocess.processed_dir, + cfg.preprocess.contentvec_dir, + ) + + if cfg.model.condition_encoder.use_mert: + self.utt2mert_path = load_content_feature_path( + self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.mert_dir + ) + if cfg.model.condition_encoder.use_wenet: + self.wenet_aligner = WenetExtractor(self.cfg) + self.utt2wenet_path = load_content_feature_path( + self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.wenet_dir + ) + + def __getitem__(self, index): + single_feature = BaseDataset.__getitem__(self, index) + + utt_info = self.metadata[index] + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + if self.cfg.model.condition_encoder.use_whisper: + assert "target_len" in single_feature.keys() + aligned_whisper_feat = self.whisper_aligner.offline_align( + np.load(self.utt2whisper_path[utt]), single_feature["target_len"] + ) + single_feature["whisper_feat"] = aligned_whisper_feat + + if self.cfg.model.condition_encoder.use_contentvec: + assert "target_len" in single_feature.keys() + aligned_contentvec = self.contentvec_aligner.offline_align( + np.load(self.utt2contentVec_path[utt]), single_feature["target_len"] + ) + single_feature["contentvec_feat"] = aligned_contentvec + + if self.cfg.model.condition_encoder.use_mert: + assert "target_len" in single_feature.keys() + aligned_mert_feat = align_content_feature_length( + np.load(self.utt2mert_path[utt]), + single_feature["target_len"], + source_hop=self.cfg.preprocess.mert_hop_size, + ) + single_feature["mert_feat"] = aligned_mert_feat + + if self.cfg.model.condition_encoder.use_wenet: + assert "target_len" in single_feature.keys() + aligned_wenet_feat = self.wenet_aligner.offline_align( + np.load(self.utt2wenet_path[utt]), single_feature["target_len"] + ) + single_feature["wenet_feat"] = aligned_wenet_feat + + # print(single_feature.keys()) + # for k, v in single_feature.items(): + # if type(v) in [torch.Tensor, np.ndarray]: + # print(k, v.shape) + # else: + # print(k, v) + # exit() + + return self.clip_if_too_long(single_feature) + + def __len__(self): + return len(self.metadata) + + def random_select(self, feature_seq_len, max_seq_len, ending_ts=2812): + """ + ending_ts: to avoid invalid whisper features for over 30s audios + 2812 = 30 * 24000 // 256 + """ + ts = max(feature_seq_len - max_seq_len, 0) + ts = min(ts, ending_ts - max_seq_len) + + start = random.randint(0, ts) + end = start + max_seq_len + return start, end + 
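# A worked example of the window arithmetic in ``random_select`` above (re-stated as a
# free function so the snippet runs standalone; the 40 s utterance length is an assumption):

import random


def _random_select(feature_seq_len, max_seq_len, ending_ts=2812):
    # Same arithmetic as SVCDataset.random_select: 2812 frames ~= 30 s at 24 kHz / hop 256.
    ts = max(feature_seq_len - max_seq_len, 0)
    ts = min(ts, ending_ts - max_seq_len)
    start = random.randint(0, ts)
    return start, start + max_seq_len


random.seed(0)
# A ~40 s utterance is about 3750 frames; with max_seq_len=512 the window start is
# capped at 2812 - 512 = 2300, so the clip never reaches past the valid Whisper region.
start, end = _random_select(3750, 512)
assert 0 <= start <= 2300 and end - start == 512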
+ def clip_if_too_long(self, sample, max_seq_len=512): + """ + sample : + { + 'spk_id': (1,), + 'target_len': int + 'mel': (seq_len, dim), + 'frame_pitch': (seq_len,) + 'frame_energy': (seq_len,) + 'content_vector_feat': (seq_len, dim) + } + """ + + if sample["target_len"] <= max_seq_len: + return sample + + start, end = self.random_select(sample["target_len"], max_seq_len) + sample["target_len"] = end - start + + for k in sample.keys(): + if k == "audio": + # audio should be clipped in hop_size scale + sample[k] = sample[k][ + start + * self.cfg.preprocess.hop_size : end + * self.cfg.preprocess.hop_size + ] + elif k == "audio_len": + sample[k] = (end - start) * self.cfg.preprocess.hop_size + elif k not in ["spk_id", "target_len"]: + sample[k] = sample[k][start:end] + + return sample + + +class SVCCollator(BaseCollator): + """Zero-pads model inputs and targets based on number of frames per step""" + + def __init__(self, cfg): + BaseCollator.__init__(self, cfg) + + def __call__(self, batch): + parsed_batch_features = BaseCollator.__call__(self, batch) + return parsed_batch_features + + +class SVCTestDataset(BaseTestDataset): + def __init__(self, args, cfg, infer_type): + BaseTestDataset.__init__(self, args, cfg, infer_type) + self.metadata = self.get_metadata() + + target_singer = args.target_singer + self.cfg = cfg + self.trans_key = args.trans_key + assert type(target_singer) == str + + self.target_singer = target_singer.split("_")[-1] + self.target_dataset = target_singer.replace( + "_{}".format(self.target_singer), "" + ) + if cfg.preprocess.mel_min_max_norm: + self.target_mel_extrema = load_mel_extrema( + cfg.preprocess, self.target_dataset + ) + self.target_mel_extrema = torch.as_tensor( + self.target_mel_extrema[0] + ), torch.as_tensor(self.target_mel_extrema[1]) + + ######### Load source acoustic features ######### + if cfg.preprocess.use_spkid: + spk2id_path = os.path.join(args.acoustics_dir, cfg.preprocess.spk2id) + # utt2sp_path = os.path.join(self.data_root, cfg.preprocess.utt2spk) + + with open(spk2id_path, "r") as f: + self.spk2id = json.load(f) + # print("self.spk2id", self.spk2id) + + if cfg.preprocess.use_uv: + self.utt2uv_path = { + f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join( + cfg.preprocess.processed_dir, + utt_info["Dataset"], + cfg.preprocess.uv_dir, + utt_info["Uid"] + ".npy", + ) + for utt_info in self.metadata + } + + if cfg.preprocess.use_frame_pitch: + self.utt2frame_pitch_path = { + f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join( + cfg.preprocess.processed_dir, + utt_info["Dataset"], + cfg.preprocess.pitch_dir, + utt_info["Uid"] + ".npy", + ) + for utt_info in self.metadata + } + + # Target F0 median + target_f0_statistics_path = os.path.join( + cfg.preprocess.processed_dir, + self.target_dataset, + cfg.preprocess.pitch_dir, + "statistics.json", + ) + self.target_pitch_median = json.load(open(target_f0_statistics_path, "r"))[ + f"{self.target_dataset}_{self.target_singer}" + ]["voiced_positions"]["median"] + + # Source F0 median (if infer from file) + if infer_type == "from_file": + source_audio_name = cfg.inference.source_audio_name + source_f0_statistics_path = os.path.join( + cfg.preprocess.processed_dir, + source_audio_name, + cfg.preprocess.pitch_dir, + "statistics.json", + ) + self.source_pitch_median = json.load( + open(source_f0_statistics_path, "r") + )[f"{source_audio_name}_{source_audio_name}"]["voiced_positions"][ + "median" + ] + else: + self.source_pitch_median = None + + if cfg.preprocess.use_frame_energy: + 
self.utt2frame_energy_path = { + f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join( + cfg.preprocess.processed_dir, + utt_info["Dataset"], + cfg.preprocess.energy_dir, + utt_info["Uid"] + ".npy", + ) + for utt_info in self.metadata + } + + if cfg.preprocess.use_mel: + self.utt2mel_path = { + f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join( + cfg.preprocess.processed_dir, + utt_info["Dataset"], + cfg.preprocess.mel_dir, + utt_info["Uid"] + ".npy", + ) + for utt_info in self.metadata + } + + ######### Load source content features' path ######### + if cfg.model.condition_encoder.use_whisper: + self.whisper_aligner = WhisperExtractor(cfg) + self.utt2whisper_path = load_content_feature_path( + self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.whisper_dir + ) + + if cfg.model.condition_encoder.use_contentvec: + self.contentvec_aligner = ContentvecExtractor(cfg) + self.utt2contentVec_path = load_content_feature_path( + self.metadata, + cfg.preprocess.processed_dir, + cfg.preprocess.contentvec_dir, + ) + + if cfg.model.condition_encoder.use_mert: + self.utt2mert_path = load_content_feature_path( + self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.mert_dir + ) + if cfg.model.condition_encoder.use_wenet: + self.wenet_aligner = WenetExtractor(cfg) + self.utt2wenet_path = load_content_feature_path( + self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.wenet_dir + ) + + def __getitem__(self, index): + single_feature = {} + + utt_info = self.metadata[index] + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + source_dataset = self.metadata[index]["Dataset"] + + if self.cfg.preprocess.use_spkid: + single_feature["spk_id"] = np.array( + [self.spk2id[f"{self.target_dataset}_{self.target_singer}"]], + dtype=np.int32, + ) + + ######### Get Acoustic Features Item ######### + if self.cfg.preprocess.use_mel: + mel = np.load(self.utt2mel_path[utt]) + assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T] + if self.cfg.preprocess.use_min_max_norm_mel: + # mel norm + mel = cal_normalized_mel(mel, source_dataset, self.cfg.preprocess) + + if "target_len" not in single_feature.keys(): + single_feature["target_len"] = mel.shape[1] + single_feature["mel"] = mel.T # [T, n_mels] + + if self.cfg.preprocess.use_frame_pitch: + frame_pitch_path = self.utt2frame_pitch_path[utt] + frame_pitch = np.load(frame_pitch_path) + + if self.trans_key: + try: + self.trans_key = int(self.trans_key) + except: + pass + if type(self.trans_key) == int: + frame_pitch = transpose_key(frame_pitch, self.trans_key) + elif self.trans_key: + assert self.target_singer + + frame_pitch = pitch_shift_to_target( + frame_pitch, self.target_pitch_median, self.source_pitch_median + ) + + if "target_len" not in single_feature.keys(): + single_feature["target_len"] = len(frame_pitch) + aligned_frame_pitch = align_length( + frame_pitch, single_feature["target_len"] + ) + single_feature["frame_pitch"] = aligned_frame_pitch + + if self.cfg.preprocess.use_uv: + frame_uv_path = self.utt2uv_path[utt] + frame_uv = np.load(frame_uv_path) + aligned_frame_uv = align_length(frame_uv, single_feature["target_len"]) + aligned_frame_uv = [ + 0 if frame_uv else 1 for frame_uv in aligned_frame_uv + ] + aligned_frame_uv = np.array(aligned_frame_uv) + single_feature["frame_uv"] = aligned_frame_uv + + if self.cfg.preprocess.use_frame_energy: + frame_energy_path = self.utt2frame_energy_path[utt] + frame_energy = np.load(frame_energy_path) + if "target_len" not in 
single_feature.keys(): + single_feature["target_len"] = len(frame_energy) + aligned_frame_energy = align_length( + frame_energy, single_feature["target_len"] + ) + single_feature["frame_energy"] = aligned_frame_energy + + ######### Get Content Features Item ######### + if self.cfg.model.condition_encoder.use_whisper: + assert "target_len" in single_feature.keys() + aligned_whisper_feat = self.whisper_aligner.offline_align( + np.load(self.utt2whisper_path[utt]), single_feature["target_len"] + ) + single_feature["whisper_feat"] = aligned_whisper_feat + + if self.cfg.model.condition_encoder.use_contentvec: + assert "target_len" in single_feature.keys() + aligned_contentvec = self.contentvec_aligner.offline_align( + np.load(self.utt2contentVec_path[utt]), single_feature["target_len"] + ) + single_feature["contentvec_feat"] = aligned_contentvec + + if self.cfg.model.condition_encoder.use_mert: + assert "target_len" in single_feature.keys() + aligned_mert_feat = align_content_feature_length( + np.load(self.utt2mert_path[utt]), + single_feature["target_len"], + source_hop=self.cfg.preprocess.mert_hop_size, + ) + single_feature["mert_feat"] = aligned_mert_feat + + if self.cfg.model.condition_encoder.use_wenet: + assert "target_len" in single_feature.keys() + aligned_wenet_feat = self.wenet_aligner.offline_align( + np.load(self.utt2wenet_path[utt]), single_feature["target_len"] + ) + single_feature["wenet_feat"] = aligned_wenet_feat + + return single_feature + + def __len__(self): + return len(self.metadata) + + +class SVCTestCollator: + """Zero-pads model inputs and targets based on number of frames per step""" + + def __init__(self, cfg): + self.cfg = cfg + + def __call__(self, batch): + packed_batch_features = dict() + + # mel: [b, T, n_mels] + # frame_pitch, frame_energy: [1, T] + # target_len: [1] + # spk_id: [b, 1] + # mask: [b, T, 1] + + for key in batch[0].keys(): + if key == "target_len": + packed_batch_features["target_len"] = torch.LongTensor( + [b["target_len"] for b in batch] + ) + masks = [ + torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch + ] + packed_batch_features["mask"] = pad_sequence( + masks, batch_first=True, padding_value=0 + ) + else: + values = [torch.from_numpy(b[key]) for b in batch] + packed_batch_features[key] = pad_sequence( + values, batch_first=True, padding_value=0 + ) + + return packed_batch_features diff --git a/models/svc/base/svc_inference.py b/models/svc/base/svc_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..52f88d5d915e1616292c03927b4f51557351f58b --- /dev/null +++ b/models/svc/base/svc_inference.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from models.base.new_inference import BaseInference +from models.svc.base.svc_dataset import SVCTestCollator, SVCTestDataset + + +class SVCInference(BaseInference): + def __init__(self, args=None, cfg=None, infer_type="from_dataset"): + BaseInference.__init__(self, args, cfg, infer_type) + + def _build_test_dataset(self): + return SVCTestDataset, SVCTestCollator diff --git a/models/svc/base/svc_trainer.py b/models/svc/base/svc_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..a2a093a86712bb7ccfa786a6c18dd1683ffc013c --- /dev/null +++ b/models/svc/base/svc_trainer.py @@ -0,0 +1,111 @@ +# Copyright (c) 2023 Amphion. 
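# ``SVCTestCollator`` above right-pads every per-utterance feature to the longest item
# in the batch and builds a frame mask from ``target_len``. A small standalone sketch
# of that masking behaviour (the lengths and feature dimension are illustrative):

import torch
from torch.nn.utils.rnn import pad_sequence

batch = [
    {"target_len": 3, "mel": torch.ones(3, 2)},
    {"target_len": 5, "mel": torch.ones(5, 2)},
]

target_len = torch.LongTensor([b["target_len"] for b in batch])              # tensor([3, 5])
masks = [torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch]
mask = pad_sequence(masks, batch_first=True, padding_value=0)                # [2, 5, 1]
mel = pad_sequence([b["mel"] for b in batch], batch_first=True)              # [2, 5, 2]

assert mask.shape == (2, 5, 1) and mask[0, 3:].sum() == 0  # padded frames are masked out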
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os + +import torch +import torch.nn as nn + +from models.base.new_trainer import BaseTrainer +from models.svc.base.svc_dataset import SVCCollator, SVCDataset + + +class SVCTrainer(BaseTrainer): + r"""The base trainer for all SVC models. It inherits from BaseTrainer and implements + ``build_criterion``, ``_build_dataset`` and ``_build_singer_lut`` methods. You can inherit from this + class, and implement ``_build_model``, ``_forward_step``. + """ + + def __init__(self, args=None, cfg=None): + self.args = args + self.cfg = cfg + + self._init_accelerator() + + # Only for SVC tasks + with self.accelerator.main_process_first(): + self.singers = self._build_singer_lut() + + # Super init + BaseTrainer.__init__(self, args, cfg) + + # Only for SVC tasks + self.task_type = "SVC" + self.logger.info("Task type: {}".format(self.task_type)) + + ### Following are methods only for SVC tasks ### + # TODO: LEGACY CODE, NEED TO BE REFACTORED + def _build_dataset(self): + return SVCDataset, SVCCollator + + @staticmethod + def _build_criterion(): + criterion = nn.MSELoss(reduction="none") + return criterion + + @staticmethod + def _compute_loss(criterion, y_pred, y_gt, loss_mask): + """ + Args: + criterion: MSELoss(reduction='none') + y_pred, y_gt: (bs, seq_len, D) + loss_mask: (bs, seq_len, 1) + Returns: + loss: Tensor of shape [] + """ + + # (bs, seq_len, D) + loss = criterion(y_pred, y_gt) + # expand loss_mask to (bs, seq_len, D) + loss_mask = loss_mask.repeat(1, 1, loss.shape[-1]) + + loss = torch.sum(loss * loss_mask) / torch.sum(loss_mask) + return loss + + def _save_auxiliary_states(self): + """ + To save the singer's look-up table in the checkpoint saving path + """ + with open( + os.path.join(self.tmp_checkpoint_save_path, self.cfg.preprocess.spk2id), "w" + ) as f: + json.dump(self.singers, f, indent=4, ensure_ascii=False) + + def _build_singer_lut(self): + resumed_singer_path = None + if self.args.resume_from_ckpt_path and self.args.resume_from_ckpt_path != "": + resumed_singer_path = os.path.join( + self.args.resume_from_ckpt_path, self.cfg.preprocess.spk2id + ) + if os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)): + resumed_singer_path = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id) + + if resumed_singer_path: + with open(resumed_singer_path, "r") as f: + singers = json.load(f) + else: + singers = dict() + + for dataset in self.cfg.dataset: + singer_lut_path = os.path.join( + self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id + ) + with open(singer_lut_path, "r") as singer_lut_path: + singer_lut = json.load(singer_lut_path) + for singer in singer_lut.keys(): + if singer not in singers: + singers[singer] = len(singers) + + with open( + os.path.join(self.exp_dir, self.cfg.preprocess.spk2id), "w" + ) as singer_file: + json.dump(singers, singer_file, indent=4, ensure_ascii=False) + print( + "singers have been dumped to {}".format( + os.path.join(self.exp_dir, self.cfg.preprocess.spk2id) + ) + ) + return singers diff --git a/models/svc/comosvc/__init__.py b/models/svc/comosvc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..19f1cb162e95d8a992002beaa0c0d8bada9cddd5 --- /dev/null +++ b/models/svc/comosvc/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2023 Amphion. 
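# ``SVCTrainer._compute_loss`` above averages an element-wise MSE only over unmasked
# positions: the (bs, seq_len, 1) mask is repeated across the feature dimension, so
# padded frames contribute nothing. A small standalone check (shapes are illustrative):

import torch
import torch.nn as nn

criterion = nn.MSELoss(reduction="none")

bs, seq_len, d = 2, 4, 3
y_pred = torch.zeros(bs, seq_len, d)
y_gt = torch.ones(bs, seq_len, d)
loss_mask = torch.ones(bs, seq_len, 1)
loss_mask[:, 2:] = 0                      # pretend the last two frames are padding

loss = criterion(y_pred, y_gt)            # (bs, seq_len, d); every element is 1 here
loss_mask = loss_mask.repeat(1, 1, d)     # broadcast the mask over the feature dim
masked_loss = torch.sum(loss * loss_mask) / torch.sum(loss_mask)

assert masked_loss.item() == 1.0          # padding does not dilute the average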
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/models/svc/comosvc/comosvc.py b/models/svc/comosvc/comosvc.py new file mode 100644 index 0000000000000000000000000000000000000000..6cecd7a3f40f3a78f0df06ef2340159d321d6117 --- /dev/null +++ b/models/svc/comosvc/comosvc.py @@ -0,0 +1,377 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Adapted from https://github.com/zhenye234/CoMoSpeech""" + +import torch +import torch.nn as nn +import copy +import numpy as np +import math +from tqdm.auto import tqdm + +from utils.ssim import SSIM + +from models.svc.transformer.conformer import Conformer, BaseModule +from models.svc.diffusion.diffusion_wrapper import DiffusionWrapper +from models.svc.comosvc.utils import slice_segments, rand_ids_segments + + +class Consistency(nn.Module): + def __init__(self, cfg, distill=False): + super().__init__() + self.cfg = cfg + # self.denoise_fn = GradLogPEstimator2d(96) + self.denoise_fn = DiffusionWrapper(self.cfg) + self.cfg = cfg.model.comosvc + self.teacher = not distill + self.P_mean = self.cfg.P_mean + self.P_std = self.cfg.P_std + self.sigma_data = self.cfg.sigma_data + self.sigma_min = self.cfg.sigma_min + self.sigma_max = self.cfg.sigma_max + self.rho = self.cfg.rho + self.N = self.cfg.n_timesteps + self.ssim_loss = SSIM() + + # Time step discretization + step_indices = torch.arange(self.N) + # karras boundaries formula + t_steps = ( + self.sigma_min ** (1 / self.rho) + + step_indices + / (self.N - 1) + * (self.sigma_max ** (1 / self.rho) - self.sigma_min ** (1 / self.rho)) + ) ** self.rho + self.t_steps = torch.cat( + [torch.zeros_like(t_steps[:1]), self.round_sigma(t_steps)] + ) + + def init_consistency_training(self): + self.denoise_fn_ema = copy.deepcopy(self.denoise_fn) + self.denoise_fn_pretrained = copy.deepcopy(self.denoise_fn) + + def EDMPrecond(self, x, sigma, cond, denoise_fn, mask, spk=None): + """ + karras diffusion reverse process + + Args: + x: noisy mel-spectrogram [B x n_mel x L] + sigma: noise level [B x 1 x 1] + cond: output of conformer encoder [B x n_mel x L] + denoise_fn: denoiser neural network e.g. 
DilatedCNN + mask: mask of padded frames [B x n_mel x L] + + Returns: + denoised mel-spectrogram [B x n_mel x L] + """ + sigma = sigma.reshape(-1, 1, 1) + + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() + c_noise = sigma.log() / 4 + + x_in = c_in * x + x_in = x_in.transpose(1, 2) + x = x.transpose(1, 2) + cond = cond.transpose(1, 2) + F_x = denoise_fn(x_in, c_noise.squeeze(), cond) + # F_x = denoise_fn((c_in * x), mask, cond, c_noise.flatten()) + D_x = c_skip * x + c_out * (F_x) + D_x = D_x.transpose(1, 2) + return D_x + + def EDMLoss(self, x_start, cond, mask): + """ + compute loss for EDM model + + Args: + x_start: ground truth mel-spectrogram [B x n_mel x L] + cond: output of conformer encoder [B x n_mel x L] + mask: mask of padded frames [B x n_mel x L] + """ + rnd_normal = torch.randn([x_start.shape[0], 1, 1], device=x_start.device) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + + # follow Grad-TTS, start from Gaussian noise with mean cond and std I + noise = (torch.randn_like(x_start) + cond) * sigma + D_yn = self.EDMPrecond(x_start + noise, sigma, cond, self.denoise_fn, mask) + loss = weight * ((D_yn - x_start) ** 2) + loss = torch.sum(loss * mask) / torch.sum(mask) + return loss + + def round_sigma(self, sigma): + return torch.as_tensor(sigma) + + def edm_sampler( + self, + latents, + cond, + nonpadding, + num_steps=50, + sigma_min=0.002, + sigma_max=80, + rho=7, + S_churn=0, + S_min=0, + S_max=float("inf"), + S_noise=1, + # S_churn=40 ,S_min=0.05,S_max=50,S_noise=1.003,# S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, + # S_churn=30 ,S_min=0.01,S_max=30,S_noise=1.007, + # S_churn=30 ,S_min=0.01,S_max=1,S_noise=1.007, + # S_churn=80 ,S_min=0.05,S_max=50,S_noise=1.003, + ): + """ + karras diffusion sampler + + Args: + latents: noisy mel-spectrogram [B x n_mel x L] + cond: output of conformer encoder [B x n_mel x L] + nonpadding: mask of padded frames [B x n_mel x L] + num_steps: number of steps for diffusion inference + + Returns: + denoised mel-spectrogram [B x n_mel x L] + """ + # Time step discretization. + step_indices = torch.arange(num_steps, device=latents.device) + + num_steps = num_steps + 1 + t_steps = ( + sigma_max ** (1 / rho) + + step_indices + / (num_steps - 1) + * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) + ) ** rho + t_steps = torch.cat([self.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) + + # Main sampling loop. + x_next = latents * t_steps[0] + # wrap in tqdm for progress bar + bar = tqdm(enumerate(zip(t_steps[:-1], t_steps[1:]))) + for i, (t_cur, t_next) in bar: + x_cur = x_next + # Increase noise temporarily. + gamma = ( + min(S_churn / num_steps, np.sqrt(2) - 1) + if S_min <= t_cur <= S_max + else 0 + ) + t_hat = self.round_sigma(t_cur + gamma * t_cur) + t = torch.zeros((x_cur.shape[0], 1, 1), device=x_cur.device) + t[:, 0, 0] = t_hat + t_hat = t + x_hat = x_cur + ( + t_hat**2 - t_cur**2 + ).sqrt() * S_noise * torch.randn_like(x_cur) + # Euler step. 
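            # The two lines below are the first-order Euler step of the Karras (EDM)
            # sampler: for the denoiser D(., t), the probability-flow ODE is
            # dx/dt = (x - D(x, t)) / t, so
            #     d_cur  = (x_hat - D(x_hat, t_hat)) / t_hat   approximates dx/dt at t_hat
            #     x_next = x_hat + (t_next - t_hat) * d_cur    advances from t_hat to t_next
            # with t_next < t_hat, i.e. one step down the noise schedule.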
+ denoised = self.EDMPrecond(x_hat, t_hat, cond, self.denoise_fn, nonpadding) + d_cur = (x_hat - denoised) / t_hat + x_next = x_hat + (t_next - t_hat) * d_cur + + return x_next + + def CTLoss_D(self, y, cond, mask): + """ + compute loss for consistency distillation + + Args: + y: ground truth mel-spectrogram [B x n_mel x L] + cond: output of conformer encoder [B x n_mel x L] + mask: mask of padded frames [B x n_mel x L] + """ + with torch.no_grad(): + mu = 0.95 + for p, ema_p in zip( + self.denoise_fn.parameters(), self.denoise_fn_ema.parameters() + ): + ema_p.mul_(mu).add_(p, alpha=1 - mu) + + n = torch.randint(1, self.N, (y.shape[0],)) + z = torch.randn_like(y) + cond + + tn_1 = self.t_steps[n + 1].reshape(-1, 1, 1).to(y.device) + f_theta = self.EDMPrecond(y + tn_1 * z, tn_1, cond, self.denoise_fn, mask) + + with torch.no_grad(): + tn = self.t_steps[n].reshape(-1, 1, 1).to(y.device) + + # euler step + x_hat = y + tn_1 * z + denoised = self.EDMPrecond( + x_hat, tn_1, cond, self.denoise_fn_pretrained, mask + ) + d_cur = (x_hat - denoised) / tn_1 + y_tn = x_hat + (tn - tn_1) * d_cur + + f_theta_ema = self.EDMPrecond(y_tn, tn, cond, self.denoise_fn_ema, mask) + + # loss = (f_theta - f_theta_ema.detach()) ** 2 + # loss = torch.sum(loss * mask) / torch.sum(mask) + loss = self.ssim_loss(f_theta, f_theta_ema.detach()) + loss = torch.sum(loss * mask) / torch.sum(mask) + + return loss + + def get_t_steps(self, N): + N = N + 1 + step_indices = torch.arange(N) # , device=latents.device) + t_steps = ( + self.sigma_min ** (1 / self.rho) + + step_indices + / (N - 1) + * (self.sigma_max ** (1 / self.rho) - self.sigma_min ** (1 / self.rho)) + ) ** self.rho + + return t_steps.flip(0) + + def CT_sampler(self, latents, cond, nonpadding, t_steps=1): + """ + consistency distillation sampler + + Args: + latents: noisy mel-spectrogram [B x n_mel x L] + cond: output of conformer encoder [B x n_mel x L] + nonpadding: mask of padded frames [B x n_mel x L] + t_steps: number of steps for diffusion inference + + Returns: + denoised mel-spectrogram [B x n_mel x L] + """ + # one-step + if t_steps == 1: + t_steps = [80] + # multi-step + else: + t_steps = self.get_t_steps(t_steps) + + t_steps = torch.as_tensor(t_steps).to(latents.device) + latents = latents * t_steps[0] + _t = torch.zeros((latents.shape[0], 1, 1), device=latents.device) + _t[:, 0, 0] = t_steps + x = self.EDMPrecond(latents, _t, cond, self.denoise_fn_ema, nonpadding) + + for t in t_steps[1:-1]: + z = torch.randn_like(x) + cond + x_tn = x + (t**2 - self.sigma_min**2).sqrt() * z + _t = torch.zeros((x.shape[0], 1, 1), device=x.device) + _t[:, 0, 0] = t + t = _t + print(t) + x = self.EDMPrecond(x_tn, t, cond, self.denoise_fn_ema, nonpadding) + return x + + def forward(self, x, nonpadding, cond, t_steps=1, infer=False): + """ + calculate loss or sample mel-spectrogram + + Args: + x: + training: ground truth mel-spectrogram [B x n_mel x L] + inference: output of encoder [B x n_mel x L] + """ + if self.teacher: # teacher model -- karras diffusion + if not infer: + loss = self.EDMLoss(x, cond, nonpadding) + return loss + else: + shape = (cond.shape[0], self.cfg.n_mel, cond.shape[2]) + x = torch.randn(shape, device=x.device) + cond + x = self.edm_sampler(x, cond, nonpadding, t_steps) + + return x + else: # Consistency distillation + if not infer: + loss = self.CTLoss_D(x, cond, nonpadding) + return loss + + else: + shape = (cond.shape[0], self.cfg.n_mel, cond.shape[2]) + x = torch.randn(shape, device=x.device) + cond + x = self.CT_sampler(x, cond, nonpadding, 
t_steps=1) + + return x + + +class ComoSVC(BaseModule): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.cfg.model.comosvc.n_mel = self.cfg.preprocess.n_mel + self.distill = self.cfg.model.comosvc.distill + self.encoder = Conformer(self.cfg.model.comosvc) + self.decoder = Consistency(self.cfg, distill=self.distill) + self.ssim_loss = SSIM() + + @torch.no_grad() + def forward(self, x_mask, x, n_timesteps, temperature=1.0): + """ + Generates mel-spectrogram from pitch, content vector, energy. Returns: + 1. encoder outputs (from conformer) + 2. decoder outputs (from diffusion-based decoder) + + Args: + x_mask : mask of padded frames in mel-spectrogram. [B x L x n_mel] + x : output of encoder framework. [B x L x d_condition] + n_timesteps : number of steps to use for reverse diffusion in decoder. + temperature : controls variance of terminal distribution. + """ + + # Get encoder_outputs `mu_x` + mu_x = self.encoder(x, x_mask) + encoder_outputs = mu_x + + mu_x = mu_x.transpose(1, 2) + x_mask = x_mask.transpose(1, 2) + + # Generate sample by performing reverse dynamics + decoder_outputs = self.decoder( + mu_x, x_mask, mu_x, t_steps=n_timesteps, infer=True + ) + decoder_outputs = decoder_outputs.transpose(1, 2) + return encoder_outputs, decoder_outputs + + def compute_loss(self, x_mask, x, mel, out_size=None, skip_diff=False): + """ + Computes 2 losses: + 1. prior loss: loss between mel-spectrogram and encoder outputs. + 2. diffusion loss: loss between gaussian noise and its reconstruction by diffusion-based decoder. + + Args: + x_mask : mask of padded frames in mel-spectrogram. [B x L x n_mel] + x : output of encoder framework. [B x L x d_condition] + mel : ground truth mel-spectrogram. [B x L x n_mel] + """ + + mu_x = self.encoder(x, x_mask) + # prior loss + prior_loss = torch.sum( + 0.5 * ((mel - mu_x) ** 2 + math.log(2 * math.pi)) * x_mask + ) + prior_loss = prior_loss / (torch.sum(x_mask) * self.cfg.model.comosvc.n_mel) + # ssim loss + ssim_loss = self.ssim_loss(mu_x, mel) + ssim_loss = torch.sum(ssim_loss * x_mask) / torch.sum(x_mask) + + x_mask = x_mask.transpose(1, 2) + mu_x = mu_x.transpose(1, 2) + mel = mel.transpose(1, 2) + if not self.distill and skip_diff: + diff_loss = prior_loss.clone() + diff_loss.fill_(0) + + # Cut a small segment of mel-spectrogram in order to increase batch size + else: + if self.distill: + mu_y = mu_x.detach() + else: + mu_y = mu_x + mask_y = x_mask + + diff_loss = self.decoder(mel, mask_y, mu_y, infer=False) + + return ssim_loss, prior_loss, diff_loss diff --git a/models/svc/comosvc/comosvc_inference.py b/models/svc/comosvc/comosvc_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..2783ec7e468c367c7d2f5f8988ed1f7e272d4cb7 --- /dev/null +++ b/models/svc/comosvc/comosvc_inference.py @@ -0,0 +1,39 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
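# In ``ComoSVC.compute_loss`` above, the prior loss term
#     0.5 * ((mel - mu_x) ** 2 + log(2 * pi))
# is exactly the per-element negative log-likelihood of the mel frame under a
# unit-variance Gaussian centred at the encoder output (the Grad-TTS-style prior),
# summed over unmasked positions and then normalised. A small numerical check of that
# identity (shapes are illustrative):

import math

import torch
from torch.distributions import Normal

mu_x = torch.randn(2, 5, 80)   # encoder output, i.e. the "prior" mean
mel = torch.randn(2, 5, 80)    # ground-truth mel

elementwise = 0.5 * ((mel - mu_x) ** 2 + math.log(2 * math.pi))
nll = -Normal(mu_x, torch.ones_like(mu_x)).log_prob(mel)

assert torch.allclose(elementwise, nll, atol=1e-6)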
+ +import torch + +from models.svc.base import SVCInference +from modules.encoder.condition_encoder import ConditionEncoder +from models.svc.comosvc.comosvc import ComoSVC + + +class ComoSVCInference(SVCInference): + def __init__(self, args, cfg, infer_type="from_dataset"): + SVCInference.__init__(self, args, cfg, infer_type) + + def _build_model(self): + # TODO: sort out the config + self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min + self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max + self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder) + self.acoustic_mapper = ComoSVC(self.cfg) + if self.cfg.model.comosvc.distill: + self.acoustic_mapper.decoder.init_consistency_training() + model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper]) + return model + + def _inference_each_batch(self, batch_data): + device = self.accelerator.device + for k, v in batch_data.items(): + batch_data[k] = v.to(device) + + cond = self.condition_encoder(batch_data) + mask = batch_data["mask"] + encoder_pred, decoder_pred = self.acoustic_mapper( + mask, cond, self.cfg.inference.comosvc.inference_steps + ) + + return decoder_pred diff --git a/models/svc/comosvc/comosvc_trainer.py b/models/svc/comosvc/comosvc_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..5ba49fd4539b8ae351137a85595ff9cfba1f4677 --- /dev/null +++ b/models/svc/comosvc/comosvc_trainer.py @@ -0,0 +1,295 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import os +import json5 +from collections import OrderedDict +from tqdm import tqdm +import json +import shutil + +from models.svc.base import SVCTrainer +from modules.encoder.condition_encoder import ConditionEncoder +from models.svc.comosvc.comosvc import ComoSVC + + +class ComoSVCTrainer(SVCTrainer): + r"""The base trainer for all diffusion models. It inherits from SVCTrainer and + implements ``_build_model`` and ``_forward_step`` methods. + """ + + def __init__(self, args=None, cfg=None): + SVCTrainer.__init__(self, args, cfg) + self.distill = cfg.model.comosvc.distill + self.skip_diff = True + if self.distill: # and args.resume is None: + self.teacher_model_path = cfg.model.teacher_model_path + self.teacher_state_dict = self._load_teacher_state_dict() + self._load_teacher_model(self.teacher_state_dict) + self.acoustic_mapper.decoder.init_consistency_training() + + ### Following are methods only for comoSVC models ### + def _load_teacher_state_dict(self): + self.checkpoint_file = self.teacher_model_path + print("Load teacher acoustic model from {}".format(self.checkpoint_file)) + raw_state_dict = torch.load(self.checkpoint_file) # , map_location=self.device) + return raw_state_dict + + def _load_teacher_model(self, state_dict): + raw_dict = state_dict + clean_dict = OrderedDict() + for k, v in raw_dict.items(): + if k.startswith("module."): + clean_dict[k[7:]] = v + else: + clean_dict[k] = v + self.model.load_state_dict(clean_dict) + + def _build_model(self): + r"""Build the model for training. 
This function is called in ``__init__`` function.""" + + # TODO: sort out the config + self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min + self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max + self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder) + self.acoustic_mapper = ComoSVC(self.cfg) + model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper]) + return model + + def _forward_step(self, batch): + r"""Forward step for training and inference. This function is called + in ``_train_step`` & ``_test_step`` function. + """ + loss = {} + mask = batch["mask"] + mel_input = batch["mel"] + cond = self.condition_encoder(batch) + if self.distill: + cond = cond.detach() + self.skip_diff = True if self.step < self.cfg.train.fast_steps else False + ssim_loss, prior_loss, diff_loss = self.acoustic_mapper.compute_loss( + mask, cond, mel_input, skip_diff=self.skip_diff + ) + if self.distill: + loss["distil_loss"] = diff_loss + else: + loss["ssim_loss_encoder"] = ssim_loss + loss["prior_loss_encoder"] = prior_loss + loss["diffusion_loss_decoder"] = diff_loss + + return loss + + def _train_epoch(self): + r"""Training epoch. Should return average loss of a batch (sample) over + one epoch. See ``train_loop`` for usage. + """ + self.model.train() + epoch_sum_loss: float = 0.0 + epoch_step: int = 0 + for batch in tqdm( + self.train_dataloader, + desc=f"Training Epoch {self.epoch}", + unit="batch", + colour="GREEN", + leave=False, + dynamic_ncols=True, + smoothing=0.04, + disable=not self.accelerator.is_main_process, + ): + # Do training step and BP + with self.accelerator.accumulate(self.model): + loss = self._train_step(batch) + total_loss = 0 + for k, v in loss.items(): + total_loss += v + self.accelerator.backward(total_loss) + enc_grad_norm = torch.nn.utils.clip_grad_norm_( + self.acoustic_mapper.encoder.parameters(), max_norm=1 + ) + dec_grad_norm = torch.nn.utils.clip_grad_norm_( + self.acoustic_mapper.decoder.parameters(), max_norm=1 + ) + self.optimizer.step() + self.optimizer.zero_grad() + self.batch_count += 1 + + # Update info for each step + # TODO: step means BP counts or batch counts? + if self.batch_count % self.cfg.train.gradient_accumulation_step == 0: + epoch_sum_loss += total_loss + log_info = {} + for k, v in loss.items(): + key = "Step/Train Loss/{}".format(k) + log_info[key] = v + log_info["Step/Learning Rate"]: self.optimizer.param_groups[0]["lr"] + self.accelerator.log( + log_info, + step=self.step, + ) + self.step += 1 + epoch_step += 1 + + self.accelerator.wait_for_everyone() + return ( + epoch_sum_loss + / len(self.train_dataloader) + * self.cfg.train.gradient_accumulation_step, + loss, + ) + + def train_loop(self): + r"""Training loop. 
The public entry of training process.""" + # Wait everyone to prepare before we move on + self.accelerator.wait_for_everyone() + # dump config file + if self.accelerator.is_main_process: + self.__dump_cfg(self.config_save_path) + self.model.train() + self.optimizer.zero_grad() + # Wait to ensure good to go + self.accelerator.wait_for_everyone() + while self.epoch < self.max_epoch: + self.logger.info("\n") + self.logger.info("-" * 32) + self.logger.info("Epoch {}: ".format(self.epoch)) + + ### TODO: change the return values of _train_epoch() to a loss dict, or (total_loss, loss_dict) + ### It's inconvenient for the model with multiple losses + # Do training & validating epoch + train_loss, loss = self._train_epoch() + self.logger.info(" |- Train/Loss: {:.6f}".format(train_loss)) + for k, v in loss.items(): + self.logger.info(" |- Train/Loss/{}: {:.6f}".format(k, v)) + valid_loss = self._valid_epoch() + self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_loss)) + self.accelerator.log( + {"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss}, + step=self.epoch, + ) + + self.accelerator.wait_for_everyone() + # TODO: what is scheduler? + self.scheduler.step(valid_loss) # FIXME: use epoch track correct? + + # Check if hit save_checkpoint_stride and run_eval + run_eval = False + if self.accelerator.is_main_process: + save_checkpoint = False + hit_dix = [] + for i, num in enumerate(self.save_checkpoint_stride): + if self.epoch % num == 0: + save_checkpoint = True + hit_dix.append(i) + run_eval |= self.run_eval[i] + + self.accelerator.wait_for_everyone() + if ( + self.accelerator.is_main_process + and save_checkpoint + and (self.distill or not self.skip_diff) + ): + path = os.path.join( + self.checkpoint_dir, + "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format( + self.epoch, self.step, train_loss + ), + ) + self.accelerator.save_state(path) + json.dump( + self.checkpoints_path, + open(os.path.join(path, "ckpts.json"), "w"), + ensure_ascii=False, + indent=4, + ) + + # Remove old checkpoints + to_remove = [] + for idx in hit_dix: + self.checkpoints_path[idx].append(path) + while len(self.checkpoints_path[idx]) > self.keep_last[idx]: + to_remove.append((idx, self.checkpoints_path[idx].pop(0))) + + # Search conflicts + total = set() + for i in self.checkpoints_path: + total |= set(i) + do_remove = set() + for idx, path in to_remove[::-1]: + if path in total: + self.checkpoints_path[idx].insert(0, path) + else: + do_remove.add(path) + + # Remove old checkpoints + for path in do_remove: + shutil.rmtree(path, ignore_errors=True) + self.logger.debug(f"Remove old checkpoint: {path}") + + self.accelerator.wait_for_everyone() + if run_eval: + # TODO: run evaluation + pass + + # Update info for each epoch + self.epoch += 1 + + # Finish training and save final checkpoint + self.accelerator.wait_for_everyone() + if self.accelerator.is_main_process: + self.accelerator.save_state( + os.path.join( + self.checkpoint_dir, + "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format( + self.epoch, self.step, valid_loss + ), + ) + ) + self.accelerator.end_training() + + @torch.inference_mode() + def _valid_epoch(self): + r"""Testing epoch. Should return average loss of a batch (sample) over + one epoch. See ``train_loop`` for usage. 
+ """ + self.model.eval() + epoch_sum_loss = 0.0 + for batch in tqdm( + self.valid_dataloader, + desc=f"Validating Epoch {self.epoch}", + unit="batch", + colour="GREEN", + leave=False, + dynamic_ncols=True, + smoothing=0.04, + disable=not self.accelerator.is_main_process, + ): + batch_loss = self._valid_step(batch) + for k, v in batch_loss.items(): + epoch_sum_loss += v + + self.accelerator.wait_for_everyone() + return epoch_sum_loss / len(self.valid_dataloader) + + @staticmethod + def __count_parameters(model): + model_param = 0.0 + if isinstance(model, dict): + for key, value in model.items(): + model_param += sum(p.numel() for p in model[key].parameters()) + else: + model_param = sum(p.numel() for p in model.parameters()) + return model_param + + def __dump_cfg(self, path): + os.makedirs(os.path.dirname(path), exist_ok=True) + json5.dump( + self.cfg, + open(path, "w"), + indent=4, + sort_keys=True, + ensure_ascii=False, + quote_keys=True, + ) diff --git a/models/svc/comosvc/utils.py b/models/svc/comosvc/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3f576f9a237d0a22ddfdb160122b906da9bcf889 --- /dev/null +++ b/models/svc/comosvc/utils.py @@ -0,0 +1,31 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +def slice_segments(x, ids_str, segment_size=200): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + + +def rand_ids_segments(lengths, segment_size=200): + b = lengths.shape[0] + ids_str_max = lengths - segment_size + ids_str = (torch.rand([b]).to(device=lengths.device) * ids_str_max).to( + dtype=torch.long + ) + return ids_str + + +def fix_len_compatibility(length, num_downsamplings_in_unet=2): + while True: + if length % (2**num_downsamplings_in_unet) == 0: + return length + length += 1 diff --git a/models/svc/diffusion/__init__.py b/models/svc/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/svc/diffusion/diffusion_inference.py b/models/svc/diffusion/diffusion_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..69b7c899180fb080576c161d2184fba111c69b2a --- /dev/null +++ b/models/svc/diffusion/diffusion_inference.py @@ -0,0 +1,63 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler + +from models.svc.base import SVCInference +from models.svc.diffusion.diffusion_inference_pipeline import DiffusionInferencePipeline +from models.svc.diffusion.diffusion_wrapper import DiffusionWrapper +from modules.encoder.condition_encoder import ConditionEncoder + + +class DiffusionInference(SVCInference): + def __init__(self, args=None, cfg=None, infer_type="from_dataset"): + SVCInference.__init__(self, args, cfg, infer_type) + + settings = { + **cfg.model.diffusion.scheduler_settings, + **cfg.inference.diffusion.scheduler_settings, + } + settings.pop("num_inference_timesteps") + + if cfg.inference.diffusion.scheduler.lower() == "ddpm": + self.scheduler = DDPMScheduler(**settings) + self.logger.info("Using DDPM scheduler.") + elif cfg.inference.diffusion.scheduler.lower() == "ddim": + self.scheduler = DDIMScheduler(**settings) + self.logger.info("Using DDIM scheduler.") + elif cfg.inference.diffusion.scheduler.lower() == "pndm": + self.scheduler = PNDMScheduler(**settings) + self.logger.info("Using PNDM scheduler.") + else: + raise NotImplementedError( + "Unsupported scheduler type: {}".format( + cfg.inference.diffusion.scheduler.lower() + ) + ) + + self.pipeline = DiffusionInferencePipeline( + self.model[1], + self.scheduler, + cfg.inference.diffusion.scheduler_settings.num_inference_timesteps, + ) + + def _build_model(self): + self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min + self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max + self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder) + self.acoustic_mapper = DiffusionWrapper(self.cfg) + model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper]) + return model + + def _inference_each_batch(self, batch_data): + device = self.accelerator.device + for k, v in batch_data.items(): + batch_data[k] = v.to(device) + + conditioner = self.model[0](batch_data) + noise = torch.randn_like(batch_data["mel"], device=device) + y_pred = self.pipeline(noise, conditioner) + return y_pred diff --git a/models/svc/diffusion/diffusion_inference_pipeline.py b/models/svc/diffusion/diffusion_inference_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..e2461aada99179ac17a2aaffebdb24864af1f5ee --- /dev/null +++ b/models/svc/diffusion/diffusion_inference_pipeline.py @@ -0,0 +1,47 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from diffusers import DiffusionPipeline + + +class DiffusionInferencePipeline(DiffusionPipeline): + def __init__(self, network, scheduler, num_inference_timesteps=1000): + super().__init__() + + self.register_modules(network=network, scheduler=scheduler) + self.num_inference_timesteps = num_inference_timesteps + + @torch.inference_mode() + def __call__( + self, + initial_noise: torch.Tensor, + conditioner: torch.Tensor = None, + ): + r""" + Args: + initial_noise: The initial noise to be denoised. + conditioner:The conditioner. + n_inference_steps: The number of denoising steps. More denoising steps + usually lead to a higher quality at the expense of slower inference. 
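        A minimal illustrative call (the toy denoiser, shapes, and scheduler settings
        below are assumptions for the sketch, not values fixed by this pipeline)::

            import torch
            from diffusers import DDIMScheduler

            class _ToyDenoiser(torch.nn.Module):
                # Stands in for DiffusionWrapper: (mel, t, conditioner) -> predicted noise.
                def forward(self, mel, t, conditioner):
                    return torch.zeros_like(mel)

            scheduler = DDIMScheduler(num_train_timesteps=1000)
            pipeline = DiffusionInferencePipeline(
                _ToyDenoiser(), scheduler, num_inference_timesteps=10
            )
            noise = torch.randn(1, 600, 100)  # [B, T, n_mel]
            cond = torch.randn(1, 600, 384)   # [B, T, d_condition]
            mel = pipeline(noise, cond)       # denoised mel, clamped to [-1, 1]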
+ """ + + mel = initial_noise + batch_size = mel.size(0) + self.scheduler.set_timesteps(self.num_inference_timesteps) + + for t in self.progress_bar(self.scheduler.timesteps): + timestep = torch.full((batch_size,), t, device=mel.device, dtype=torch.long) + + # 1. predict noise model_output + model_output = self.network(mel, timestep, conditioner) + + # 2. denoise, compute previous step: x_t -> x_t-1 + mel = self.scheduler.step(model_output, t, mel).prev_sample + + # 3. clamp + mel = mel.clamp(-1.0, 1.0) + + return mel diff --git a/models/svc/diffusion/diffusion_trainer.py b/models/svc/diffusion/diffusion_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..6f5aeb56a825f84c57bb1d2ba9a5ff5a32d5f486 --- /dev/null +++ b/models/svc/diffusion/diffusion_trainer.py @@ -0,0 +1,88 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from diffusers import DDPMScheduler + +from models.svc.base import SVCTrainer +from modules.encoder.condition_encoder import ConditionEncoder +from .diffusion_wrapper import DiffusionWrapper + + +class DiffusionTrainer(SVCTrainer): + r"""The base trainer for all diffusion models. It inherits from SVCTrainer and + implements ``_build_model`` and ``_forward_step`` methods. + """ + + def __init__(self, args=None, cfg=None): + SVCTrainer.__init__(self, args, cfg) + + # Only for SVC tasks using diffusion + self.noise_scheduler = DDPMScheduler( + **self.cfg.model.diffusion.scheduler_settings, + ) + self.diffusion_timesteps = ( + self.cfg.model.diffusion.scheduler_settings.num_train_timesteps + ) + + ### Following are methods only for diffusion models ### + def _build_model(self): + r"""Build the model for training. This function is called in ``__init__`` function.""" + + # TODO: sort out the config + self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min + self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max + self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder) + self.acoustic_mapper = DiffusionWrapper(self.cfg) + model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper]) + + num_of_params_encoder = self.count_parameters(self.condition_encoder) + num_of_params_am = self.count_parameters(self.acoustic_mapper) + num_of_params = num_of_params_encoder + num_of_params_am + log = "Diffusion Model's Parameters: #Encoder is {:.2f}M, #Diffusion is {:.2f}M. The total is {:.2f}M".format( + num_of_params_encoder / 1e6, num_of_params_am / 1e6, num_of_params / 1e6 + ) + self.logger.info(log) + + return model + + def count_parameters(self, model): + model_param = 0.0 + if isinstance(model, dict): + for key, value in model.items(): + model_param += sum(p.numel() for p in model[key].parameters()) + else: + model_param = sum(p.numel() for p in model.parameters()) + return model_param + + def _forward_step(self, batch): + r"""Forward step for training and inference. This function is called + in ``_train_step`` & ``_test_step`` function. 
+ """ + + device = self.accelerator.device + + mel_input = batch["mel"] + noise = torch.randn_like(mel_input, device=device, dtype=torch.float32) + batch_size = mel_input.size(0) + timesteps = torch.randint( + 0, + self.diffusion_timesteps, + (batch_size,), + device=device, + dtype=torch.long, + ) + + noisy_mel = self.noise_scheduler.add_noise(mel_input, noise, timesteps) + conditioner = self.condition_encoder(batch) + + y_pred = self.acoustic_mapper(noisy_mel, timesteps, conditioner) + + # TODO: Predict noise or gt should be configurable + loss = self._compute_loss(self.criterion, y_pred, noise, batch["mask"]) + self._check_nan(loss, y_pred, noise) + + # FIXME: Clarify that we should not divide it with batch size here + return loss diff --git a/models/svc/diffusion/diffusion_wrapper.py b/models/svc/diffusion/diffusion_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..ef66c2b6b85ceb8fe7a2cf9b53c62edc6b3ef6bc --- /dev/null +++ b/models/svc/diffusion/diffusion_wrapper.py @@ -0,0 +1,73 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn + +from modules.diffusion import BiDilConv +from modules.encoder.position_encoder import PositionEncoder + + +class DiffusionWrapper(nn.Module): + def __init__(self, cfg): + super().__init__() + + self.cfg = cfg + self.diff_cfg = cfg.model.diffusion + + self.diff_encoder = PositionEncoder( + d_raw_emb=self.diff_cfg.step_encoder.dim_raw_embedding, + d_out=self.diff_cfg.bidilconv.base_channel, + d_mlp=self.diff_cfg.step_encoder.dim_hidden_layer, + activation_function=self.diff_cfg.step_encoder.activation, + n_layer=self.diff_cfg.step_encoder.num_layer, + max_period=self.diff_cfg.step_encoder.max_period, + ) + + # FIXME: Only support BiDilConv now for debug + if self.diff_cfg.model_type.lower() == "bidilconv": + self.neural_network = BiDilConv( + input_channel=self.cfg.preprocess.n_mel, **self.diff_cfg.bidilconv + ) + else: + raise ValueError( + f"Unsupported diffusion model type: {self.diff_cfg.model_type}" + ) + + def forward(self, x, t, c): + """ + Args: + x: [N, T, mel_band] of mel spectrogram + t: Diffusion time step with shape of [N] + c: [N, T, conditioner_size] of conditioner + + Returns: + [N, T, mel_band] of mel spectrogram + """ + + assert ( + x.size()[:-1] == c.size()[:-1] + ), "x mismatch with c, got \n x: {} \n c: {}".format(x.size(), c.size()) + assert x.size(0) == t.size( + 0 + ), "x mismatch with t, got \n x: {} \n t: {}".format(x.size(), t.size()) + assert t.dim() == 1, "t must be 1D tensor, got {}".format(t.dim()) + + N, T, mel_band = x.size() + + x = x.transpose(1, 2).contiguous() # [N, mel_band, T] + c = c.transpose(1, 2).contiguous() # [N, conditioner_size, T] + t = self.diff_encoder(t).contiguous() # [N, base_channel] + + h = self.neural_network(x, t, c) + h = h.transpose(1, 2).contiguous() # [N, T, mel_band] + + assert h.size() == ( + N, + T, + mel_band, + ), "h mismatch with input x, got \n h: {} \n x: {}".format( + h.size(), (N, T, mel_band) + ) + return h diff --git a/models/svc/transformer/__init__.py b/models/svc/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/svc/transformer/conformer.py b/models/svc/transformer/conformer.py new file mode 100644 index 0000000000000000000000000000000000000000..5e48019cfc17d5f3825ce989f4852cec55fe1daa --- /dev/null +++ 
b/models/svc/transformer/conformer.py @@ -0,0 +1,405 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +import torch +import numpy as np +import torch.nn as nn +from utils.util import convert_pad_shape + + +class BaseModule(torch.nn.Module): + def __init__(self): + super(BaseModule, self).__init__() + + @property + def nparams(self): + """ + Returns number of trainable parameters of the module. + """ + num_params = 0 + for name, param in self.named_parameters(): + if param.requires_grad: + num_params += np.prod(param.detach().cpu().numpy().shape) + return num_params + + def relocate_input(self, x: list): + """ + Relocates provided tensors to the same device set for the module. + """ + device = next(self.parameters()).device + for i in range(len(x)): + if isinstance(x[i], torch.Tensor) and x[i].device != device: + x[i] = x[i].to(device) + return x + + +class LayerNorm(BaseModule): + def __init__(self, channels, eps=1e-4): + super(LayerNorm, self).__init__() + self.channels = channels + self.eps = eps + + self.gamma = torch.nn.Parameter(torch.ones(channels)) + self.beta = torch.nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + n_dims = len(x.shape) + mean = torch.mean(x, 1, keepdim=True) + variance = torch.mean((x - mean) ** 2, 1, keepdim=True) + + x = (x - mean) * torch.rsqrt(variance + self.eps) + + shape = [1, -1] + [1] * (n_dims - 2) + x = x * self.gamma.view(*shape) + self.beta.view(*shape) + return x + + +class ConvReluNorm(BaseModule): + def __init__( + self, + in_channels, + hidden_channels, + out_channels, + kernel_size, + n_layers, + p_dropout, + eps=1e-5, + ): + super(ConvReluNorm, self).__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + self.eps = eps + + self.conv_layers = torch.nn.ModuleList() + self.conv_layers.append( + torch.nn.Conv1d( + in_channels, hidden_channels, kernel_size, padding=kernel_size // 2 + ) + ) + self.relu_drop = torch.nn.Sequential( + torch.nn.ReLU(), torch.nn.Dropout(p_dropout) + ) + for _ in range(n_layers - 1): + self.conv_layers.append( + torch.nn.Conv1d( + hidden_channels, + hidden_channels, + kernel_size, + padding=kernel_size // 2, + ) + ) + self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.instance_norm(x, x_mask) + x = self.relu_drop(x) + x = self.proj(x) + return x * x_mask + + def instance_norm(self, x, mask, return_mean_std=False): + mean, std = self.calc_mean_std(x, mask) + x = (x - mean) / std + if return_mean_std: + return x, mean, std + else: + return x + + def calc_mean_std(self, x, mask=None): + x = x * mask + B, C = x.shape[:2] + mn = x.view(B, C, -1).mean(-1) + sd = (x.view(B, C, -1).var(-1) + self.eps).sqrt() + mn = mn.view(B, C, *((len(x.shape) - 2) * [1])) + sd = sd.view(B, C, *((len(x.shape) - 2) * [1])) + return mn, sd + + +class MultiHeadAttention(BaseModule): + def __init__( + self, + channels, + out_channels, + n_heads, + window_size=None, + heads_share=True, + p_dropout=0.0, + proximal_bias=False, + proximal_init=False, + ): + super(MultiHeadAttention, self).__init__() + assert channels % n_heads == 0 + + self.channels = channels + 
self.out_channels = out_channels + self.n_heads = n_heads + self.window_size = window_size + self.heads_share = heads_share + self.proximal_bias = proximal_bias + self.p_dropout = p_dropout + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = torch.nn.Conv1d(channels, channels, 1) + self.conv_k = torch.nn.Conv1d(channels, channels, 1) + self.conv_v = torch.nn.Conv1d(channels, channels, 1) + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels**-0.5 + self.emb_rel_k = torch.nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev + ) + self.emb_rel_v = torch.nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev + ) + self.conv_o = torch.nn.Conv1d(channels, out_channels, 1) + self.drop = torch.nn.Dropout(p_dropout) + + torch.nn.init.xavier_uniform_(self.conv_q.weight) + torch.nn.init.xavier_uniform_(self.conv_k.weight) + if proximal_init: + self.conv_k.weight.data.copy_(self.conv_q.weight.data) + self.conv_k.bias.data.copy_(self.conv_q.bias.data) + torch.nn.init.xavier_uniform_(self.conv_v.weight) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) + if self.window_size is not None: + assert ( + t_s == t_t + ), "Relative attention is only available for self-attention." + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings) + rel_logits = self._relative_position_to_absolute_position(rel_logits) + scores_local = rel_logits / math.sqrt(self.k_channels) + scores = scores + scores_local + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." 
+ scores = scores + self._attention_bias_proximal(t_s).to( + device=scores.device, dtype=scores.dtype + ) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + p_attn = torch.nn.functional.softmax(scores, dim=-1) + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings( + self.emb_rel_v, t_s + ) + output = output + self._matmul_with_relative_values( + relative_weights, value_relative_embeddings + ) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + ret = torch.matmul(x, y.unsqueeze(0)) + return ret + + def _matmul_with_relative_keys(self, x, y): + ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + return ret + + def _get_relative_embeddings(self, relative_embeddings, length): + pad_length = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = torch.nn.functional.pad( + relative_embeddings, + convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), + ) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[ + :, slice_start_position:slice_end_position + ] + return used_relative_embeddings + + def _relative_position_to_absolute_position(self, x): + batch, heads, length, _ = x.size() + x = torch.nn.functional.pad( + x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]) + ) + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = torch.nn.functional.pad( + x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]) + ) + x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ + :, :, :length, length - 1 : + ] + return x_final + + def _absolute_position_to_relative_position(self, x): + batch, heads, length, _ = x.size() + x = torch.nn.functional.pad( + x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]) + ) + x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) + x_flat = torch.nn.functional.pad( + x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]) + ) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + def _attention_bias_proximal(self, length): + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(BaseModule): + def __init__( + self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0 + ): + super(FFN, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.conv_1 = torch.nn.Conv1d( + in_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) + self.conv_2 = torch.nn.Conv1d( + filter_channels, out_channels, kernel_size, padding=kernel_size // 2 + ) + self.drop = torch.nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + return x * x_mask + + +class Encoder(BaseModule): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads=2, + n_layers=6, + kernel_size=3, 
+ p_dropout=0.1, + window_size=4, + **kwargs + ): + super(Encoder, self).__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + + self.drop = torch.nn.Dropout(p_dropout) + self.attn_layers = torch.nn.ModuleList() + self.norm_layers_1 = torch.nn.ModuleList() + self.ffn_layers = torch.nn.ModuleList() + self.norm_layers_2 = torch.nn.ModuleList() + for _ in range(self.n_layers): + self.attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + window_size=window_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + for i in range(self.n_layers): + x = x * x_mask + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class Conformer(BaseModule): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.n_heads = self.cfg.n_heads + self.n_layers = self.cfg.n_layers + self.hidden_channels = self.cfg.input_dim + self.filter_channels = self.cfg.filter_channels + self.output_dim = self.cfg.output_dim + self.dropout = self.cfg.dropout + + self.conformer_encoder = Encoder( + self.hidden_channels, + self.filter_channels, + n_heads=self.n_heads, + n_layers=self.n_layers, + kernel_size=3, + p_dropout=self.dropout, + window_size=4, + ) + self.projection = nn.Conv1d(self.hidden_channels, self.output_dim, 1) + + def forward(self, x, x_mask): + """ + Args: + x: (N, seq_len, input_dim) + Returns: + output: (N, seq_len, output_dim) + """ + # (N, seq_len, d_model) + x = x.transpose(1, 2) + x_mask = x_mask.transpose(1, 2) + output = self.conformer_encoder(x, x_mask) + # (N, seq_len, output_dim) + output = self.projection(output) + output = output.transpose(1, 2) + return output diff --git a/models/svc/transformer/transformer.py b/models/svc/transformer/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..fd3cdb6c2d0fc93534d005b9f67a3058c9185c60 --- /dev/null +++ b/models/svc/transformer/transformer.py @@ -0,0 +1,82 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
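A quick shape check for the Conformer mapper defined in conformer.py above. The cfg fields mirror the ones Conformer.__init__ reads; the concrete sizes are illustrative assumptions rather than values from any shipped recipe, and the import assumes the package layout introduced by this patch.

from types import SimpleNamespace
import torch
from models.svc.transformer.conformer import Conformer

cfg = SimpleNamespace(
    n_heads=2, n_layers=4, input_dim=384,
    filter_channels=768, output_dim=100, dropout=0.1,
)
mapper = Conformer(cfg)
x = torch.randn(2, 128, 384)    # (N, seq_len, input_dim) conditioner features
x_mask = torch.ones(2, 128, 1)  # (N, seq_len, 1); 1.0 marks valid frames
y = mapper(x, x_mask)           # (N, seq_len, output_dim) predicted mel
assert y.shape == (2, 128, 100)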
+ +import math +import torch +import torch.nn as nn +from torch.nn import TransformerEncoder, TransformerEncoderLayer + + +class Transformer(nn.Module): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + + dropout = self.cfg.dropout + nhead = self.cfg.n_heads + nlayers = self.cfg.n_layers + input_dim = self.cfg.input_dim + output_dim = self.cfg.output_dim + + d_model = input_dim + self.pos_encoder = PositionalEncoding(d_model, dropout) + encoder_layers = TransformerEncoderLayer( + d_model, nhead, dropout=dropout, batch_first=True + ) + self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) + + self.output_mlp = nn.Linear(d_model, output_dim) + + def forward(self, x, mask=None): + """ + Args: + x: (N, seq_len, input_dim) + Returns: + output: (N, seq_len, output_dim) + """ + # (N, seq_len, d_model) + src = self.pos_encoder(x) + # model_stats["pos_embedding"] = x + # (N, seq_len, d_model) + output = self.transformer_encoder(src) + # (N, seq_len, output_dim) + output = self.output_mlp(output) + return output + + +class PositionalEncoding(nn.Module): + def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model) + ) + + # Assume that x is (seq_len, N, d) + # pe = torch.zeros(max_len, 1, d_model) + # pe[:, 0, 0::2] = torch.sin(position * div_term) + # pe[:, 0, 1::2] = torch.cos(position * div_term) + + # Assume that x in (N, seq_len, d) + pe = torch.zeros(1, max_len, d_model) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + + self.register_buffer("pe", pe) + + def forward(self, x): + """ + Args: + x: Tensor, shape [N, seq_len, d] + """ + # Old: Assume that x is (seq_len, N, d), and self.pe is (max_len, 1, d_model) + # x = x + self.pe[: x.size(0)] + + # Now: self.pe is (1, max_len, d) + x = x + self.pe[:, : x.size(1), :] + + return self.dropout(x) diff --git a/models/svc/transformer/transformer_inference.py b/models/svc/transformer/transformer_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..f6299c532aec6cb9283ee87ee9f0142f0b5c981b --- /dev/null +++ b/models/svc/transformer/transformer_inference.py @@ -0,0 +1,45 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
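The batch-first PositionalEncoding above registers a (1, max_len, d_model) sinusoid table and adds its first seq_len rows to the input, unlike the commented-out (seq_len, N, d) variant. A minimal sanity check with assumed sizes:

import torch
from models.svc.transformer.transformer import PositionalEncoding

pos_enc = PositionalEncoding(d_model=256, dropout=0.0, max_len=1000)
x = torch.zeros(4, 300, 256)   # (N, seq_len, d)
out = pos_enc(x)               # adds pe[:, :300, :]; dropout(p=0.0) is a no-op
assert out.shape == (4, 300, 256)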
+ +import os +import time +import numpy as np +import torch +from tqdm import tqdm +import torch.nn as nn +from collections import OrderedDict + +from models.svc.base import SVCInference +from modules.encoder.condition_encoder import ConditionEncoder +from models.svc.transformer.transformer import Transformer +from models.svc.transformer.conformer import Conformer + + +class TransformerInference(SVCInference): + def __init__(self, args=None, cfg=None, infer_type="from_dataset"): + SVCInference.__init__(self, args, cfg, infer_type) + + def _build_model(self): + self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min + self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max + self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder) + if self.cfg.model.transformer.type == "transformer": + self.acoustic_mapper = Transformer(self.cfg.model.transformer) + elif self.cfg.model.transformer.type == "conformer": + self.acoustic_mapper = Conformer(self.cfg.model.transformer) + else: + raise NotImplementedError + model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper]) + return model + + def _inference_each_batch(self, batch_data): + device = self.accelerator.device + for k, v in batch_data.items(): + batch_data[k] = v.to(device) + + condition = self.condition_encoder(batch_data) + y_pred = self.acoustic_mapper(condition, batch_data["mask"]) + + return y_pred diff --git a/models/svc/transformer/transformer_trainer.py b/models/svc/transformer/transformer_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..3633078475d26e708280bc354f091bb9ab01ae45 --- /dev/null +++ b/models/svc/transformer/transformer_trainer.py @@ -0,0 +1,52 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
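Both TransformerInference above and the trainer that follows select the acoustic mapper from cfg.model.transformer.type. A minimal sketch of the fields the two branches read; the values are placeholders, not taken from a released config:

from types import SimpleNamespace

transformer_cfg = SimpleNamespace(
    type="conformer",     # or "transformer"
    input_dim=384,        # must match the ConditionEncoder output size
    output_dim=100,       # number of mel bins to predict
    n_heads=2,
    n_layers=4,
    filter_channels=768,  # only read by the Conformer branch
    dropout=0.1,
)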
+ +import torch + +from models.svc.base import SVCTrainer +from modules.encoder.condition_encoder import ConditionEncoder +from models.svc.transformer.transformer import Transformer +from models.svc.transformer.conformer import Conformer +from utils.ssim import SSIM + + +class TransformerTrainer(SVCTrainer): + def __init__(self, args, cfg): + SVCTrainer.__init__(self, args, cfg) + self.ssim_loss = SSIM() + + def _build_model(self): + self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min + self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max + self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder) + if self.cfg.model.transformer.type == "transformer": + self.acoustic_mapper = Transformer(self.cfg.model.transformer) + elif self.cfg.model.transformer.type == "conformer": + self.acoustic_mapper = Conformer(self.cfg.model.transformer) + else: + raise NotImplementedError + model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper]) + return model + + def _forward_step(self, batch): + total_loss = 0 + device = self.accelerator.device + mel = batch["mel"] + mask = batch["mask"] + + condition = self.condition_encoder(batch) + mel_pred = self.acoustic_mapper(condition, mask) + + l1_loss = torch.sum(torch.abs(mel_pred - mel) * batch["mask"]) / torch.sum( + batch["mask"] + ) + self._check_nan(l1_loss, mel_pred, mel) + total_loss += l1_loss + ssim_loss = self.ssim_loss(mel_pred, mel) + ssim_loss = torch.sum(ssim_loss * batch["mask"]) / torch.sum(batch["mask"]) + self._check_nan(ssim_loss, mel_pred, mel) + total_loss += ssim_loss + + return total_loss diff --git a/models/svc/vits/__init__.py b/models/svc/vits/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/svc/vits/vits.py b/models/svc/vits/vits.py new file mode 100644 index 0000000000000000000000000000000000000000..346b4abe3a4fb9f5fbb0a48224cf665653c39cd5 --- /dev/null +++ b/models/svc/vits/vits.py @@ -0,0 +1,271 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
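TransformerTrainer._forward_step above sums a mask-normalized L1 term and an SSIM term. With the (B, T, 1) mask layout documented for SVC batches later in this patch, the mask broadcasts across the mel bins and the denominator counts valid frames. A standalone sketch of that reduction:

import torch

def masked_frame_mean(err, mask):
    # err: (N, T, n_mel); mask: (N, T, 1) with 1.0 for valid frames
    return torch.sum(err * mask) / torch.sum(mask)

mel_pred = torch.randn(2, 128, 100)
mel_gt = torch.randn(2, 128, 100)
mask = torch.ones(2, 128, 1)
l1_term = masked_frame_mean(torch.abs(mel_pred - mel_gt), mask)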
+ +# This code is modified from https://github.com/svc-develop-team/so-vits-svc/blob/4.1-Stable/models.py +import copy +import torch +from torch import nn +from torch.nn import functional as F + +from utils.util import * +from utils.f0 import f0_to_coarse + +from modules.transformer.attentions import Encoder +from models.tts.vits.vits import ResidualCouplingBlock, PosteriorEncoder +from models.vocoders.gan.generator.bigvgan import BigVGAN +from models.vocoders.gan.generator.hifigan import HiFiGAN +from models.vocoders.gan.generator.nsfhifigan import NSFHiFiGAN +from models.vocoders.gan.generator.melgan import MelGAN +from models.vocoders.gan.generator.apnet import APNet +from modules.encoder.condition_encoder import ConditionEncoder + + +def slice_pitch_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, idx_str:idx_end] + return ret + + +def rand_slice_segments_with_pitch(x, pitch, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + ret_pitch = slice_pitch_segments(pitch, ids_str, segment_size) + return ret, ret_pitch, ids_str + + +class ContentEncoder(nn.Module): + def __init__( + self, + out_channels, + hidden_channels, + kernel_size, + n_layers, + gin_channels=0, + filter_channels=None, + n_heads=None, + p_dropout=None, + ): + super().__init__() + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.f0_emb = nn.Embedding(256, hidden_channels) + + self.enc_ = Encoder( + hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout + ) + + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + # condition_encoder ver. 
+ def forward(self, x, x_mask, noice_scale=1): + x = self.enc_(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs) * noice_scale) * x_mask + + return z, m, logs, x_mask + + +class SynthesizerTrn(nn.Module): + """ + Synthesizer for Training + """ + + def __init__(self, spec_channels, segment_size, cfg): + super().__init__() + self.spec_channels = spec_channels + self.segment_size = segment_size + self.cfg = cfg + self.inter_channels = cfg.model.vits.inter_channels + self.hidden_channels = cfg.model.vits.hidden_channels + self.filter_channels = cfg.model.vits.filter_channels + self.n_heads = cfg.model.vits.n_heads + self.n_layers = cfg.model.vits.n_layers + self.kernel_size = cfg.model.vits.kernel_size + self.p_dropout = cfg.model.vits.p_dropout + self.ssl_dim = cfg.model.vits.ssl_dim + self.n_flow_layer = cfg.model.vits.n_flow_layer + self.gin_channels = cfg.model.vits.gin_channels + self.n_speakers = cfg.model.vits.n_speakers + + # f0 + self.n_bins = cfg.preprocess.pitch_bin + self.f0_min = cfg.preprocess.f0_min + self.f0_max = cfg.preprocess.f0_max + + # TODO: sort out the config + self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min + self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max + self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder) + + self.emb_g = nn.Embedding(self.n_speakers, self.gin_channels) + + self.enc_p = ContentEncoder( + self.inter_channels, + self.hidden_channels, + filter_channels=self.filter_channels, + n_heads=self.n_heads, + n_layers=self.n_layers, + kernel_size=self.kernel_size, + p_dropout=self.p_dropout, + ) + + assert cfg.model.generator in [ + "bigvgan", + "hifigan", + "melgan", + "nsfhifigan", + "apnet", + ] + self.dec_name = cfg.model.generator + temp_cfg = copy.deepcopy(cfg) + temp_cfg.preprocess.n_mel = self.inter_channels + if cfg.model.generator == "bigvgan": + temp_cfg.model.bigvgan = cfg.model.generator_config.bigvgan + self.dec = BigVGAN(temp_cfg) + elif cfg.model.generator == "hifigan": + temp_cfg.model.hifigan = cfg.model.generator_config.hifigan + self.dec = HiFiGAN(temp_cfg) + elif cfg.model.generator == "melgan": + temp_cfg.model.melgan = cfg.model.generator_config.melgan + self.dec = MelGAN(temp_cfg) + elif cfg.model.generator == "nsfhifigan": + temp_cfg.model.nsfhifigan = cfg.model.generator_config.nsfhifigan + self.dec = NSFHiFiGAN(temp_cfg) # TODO: nsf need f0 + elif cfg.model.generator == "apnet": + temp_cfg.model.apnet = cfg.model.generator_config.apnet + self.dec = APNet(temp_cfg) + + self.enc_q = PosteriorEncoder( + self.spec_channels, + self.inter_channels, + self.hidden_channels, + 5, + 1, + 16, + gin_channels=self.gin_channels, + ) + + self.flow = ResidualCouplingBlock( + self.inter_channels, + self.hidden_channels, + 5, + 1, + self.n_flow_layer, + gin_channels=self.gin_channels, + ) + + def forward(self, data): + """VitsSVC forward function. + + Args: + data (dict): condition data & audio data, including: + B: batch size, T: target length + { + "spk_id": [B, singer_table_size] + "target_len": [B] + "mask": [B, T, 1] + "mel": [B, T, n_mel] + "linear": [B, T, n_fft // 2 + 1] + "frame_pitch": [B, T] + "frame_uv": [B, T] + "audio": [B, audio_len] + "audio_len": [B] + "contentvec_feat": [B, T, contentvec_dim] + "whisper_feat": [B, T, whisper_dim] + ... 
+ } + """ + + # TODO: elegantly handle the dimensions + c = data["contentvec_feat"].transpose(1, 2) + spec = data["linear"].transpose(1, 2) + + g = data["spk_id"] + g = self.emb_g(g).transpose(1, 2) + + c_lengths = data["target_len"] + spec_lengths = data["target_len"] + f0 = data["frame_pitch"] + + x_mask = torch.unsqueeze(sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype) + # condition_encoder ver. + x = self.condition_encoder(data).transpose(1, 2) + + # prior encoder + z_ptemp, m_p, logs_p, _ = self.enc_p(x, x_mask) + # posterior encoder + z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g) + + # flow + z_p = self.flow(z, spec_mask, g=g) + z_slice, pitch_slice, ids_slice = rand_slice_segments_with_pitch( + z, f0, spec_lengths, self.segment_size + ) + + if self.dec_name == "nsfhifigan": + o = self.dec(z_slice, f0=f0.float()) + elif self.dec_name == "apnet": + _, _, _, _, o = self.dec(z_slice) + else: + o = self.dec(z_slice) + + outputs = { + "y_hat": o, + "ids_slice": ids_slice, + "x_mask": x_mask, + "z_mask": data["mask"].transpose(1, 2), + "z": z, + "z_p": z_p, + "m_p": m_p, + "logs_p": logs_p, + "m_q": m_q, + "logs_q": logs_q, + } + return outputs + + @torch.no_grad() + def infer(self, data, noise_scale=0.35, seed=52468): + # c, f0, uv, g + c = data["contentvec_feat"].transpose(1, 2) + f0 = data["frame_pitch"] + g = data["spk_id"] + + if c.device == torch.device("cuda"): + torch.cuda.manual_seed_all(seed) + else: + torch.manual_seed(seed) + + c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device) + + if g.dim() == 1: + g = g.unsqueeze(0) + g = self.emb_g(g).transpose(1, 2) + + x_mask = torch.unsqueeze(sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype) + # condition_encoder ver. + x = self.condition_encoder(data).transpose(1, 2) + + z_p, m_p, logs_p, c_mask = self.enc_p(x, x_mask, noice_scale=noise_scale) + z = self.flow(z_p, c_mask, g=g, reverse=True) + + if self.dec_name == "nsfhifigan": + o = self.dec(z * c_mask, f0=f0) + elif self.dec_name == "apnet": + _, _, _, _, o = self.dec(z * c_mask) + else: + o = self.dec(z * c_mask) + return o, f0 diff --git a/models/svc/vits/vits_inference.py b/models/svc/vits/vits_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..3180f8c8e45a4db61448c2e402d23e87026f9a37 --- /dev/null +++ b/models/svc/vits/vits_inference.py @@ -0,0 +1,84 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
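ContentEncoder.forward above splits the projected stats into a mean and a log standard deviation and draws the prior latent with a tunable noise scale; SynthesizerTrn.infer passes noise_scale=0.35 into it by default. The sampling step in isolation, with toy shapes and unit std:

import torch

m = torch.zeros(2, 192, 50)      # (B, inter_channels, T) mean
logs = torch.zeros(2, 192, 50)   # log std; exp(0) = 1
x_mask = torch.ones(2, 1, 50)    # (B, 1, T) frame mask
noise_scale = 0.35
z = (m + torch.randn_like(m) * torch.exp(logs) * noise_scale) * x_mask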
+ +import json +import os +import time +import numpy as np +from tqdm import tqdm +import torch + +from models.svc.base import SVCInference +from models.svc.vits.vits import SynthesizerTrn + +from models.svc.base.svc_dataset import SVCTestDataset, SVCTestCollator +from utils.io import save_audio +from utils.audio_slicer import is_silence + + +class VitsInference(SVCInference): + def __init__(self, args=None, cfg=None, infer_type="from_dataset"): + SVCInference.__init__(self, args, cfg) + + def _build_model(self): + net_g = SynthesizerTrn( + self.cfg.preprocess.n_fft // 2 + 1, + self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size, + self.cfg, + ) + self.model = net_g + return net_g + + def build_save_dir(self, dataset, speaker): + save_dir = os.path.join( + self.args.output_dir, + "svc_am_step-{}_{}".format(self.am_restore_step, self.args.mode), + ) + if dataset is not None: + save_dir = os.path.join(save_dir, "data_{}".format(dataset)) + if speaker != -1: + save_dir = os.path.join( + save_dir, + "spk_{}".format(speaker), + ) + os.makedirs(save_dir, exist_ok=True) + print("Saving to ", save_dir) + return save_dir + + @torch.inference_mode() + def inference(self): + res = [] + for i, batch in enumerate(self.test_dataloader): + pred_audio_list = self._inference_each_batch(batch) + for it, wav in zip(self.test_dataset.metadata, pred_audio_list): + uid = it["Uid"] + file = os.path.join(self.args.output_dir, f"{uid}.wav") + + wav = wav.numpy(force=True) + save_audio( + file, + wav, + self.cfg.preprocess.sample_rate, + add_silence=False, + turn_up=not is_silence(wav, self.cfg.preprocess.sample_rate), + ) + res.append(file) + return res + + def _inference_each_batch(self, batch_data, noise_scale=0.667): + device = self.accelerator.device + pred_res = [] + self.model.eval() + with torch.no_grad(): + # Put the data to device + # device = self.accelerator.device + for k, v in batch_data.items(): + batch_data[k] = v.to(device) + + audios, f0 = self.model.infer(batch_data, noise_scale=noise_scale) + + pred_res.extend(audios) + + return pred_res diff --git a/models/svc/vits/vits_trainer.py b/models/svc/vits/vits_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..618fd223cf5fe979c8164b3bba2cba9beec0b390 --- /dev/null +++ b/models/svc/vits/vits_trainer.py @@ -0,0 +1,483 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
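The VitsSVCTrainer that follows scores the generator with mel, KL, feature-matching, and adversarial terms. Its KL term (GeneratorLoss.kl_loss below) is the usual VITS prior/posterior KL averaged over valid frames; as a standalone sketch in the same [b, h, t] layout:

import torch

def kl_term(z_p, logs_q, m_p, logs_p, z_mask):
    # z_p: flowed posterior sample; (m_p, logs_p): prior stats;
    # logs_q: posterior log std; reduced over frames marked valid by z_mask
    kl = logs_p - logs_q - 0.5
    kl = kl + 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
    return torch.sum(kl * z_mask) / torch.sum(z_mask)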
+ +import torch +from torch.optim.lr_scheduler import ExponentialLR +from tqdm import tqdm + +# from models.svc.base import SVCTrainer +from models.svc.base.svc_dataset import SVCCollator, SVCDataset +from models.svc.vits.vits import * +from models.tts.base import TTSTrainer + +from utils.mel import mel_spectrogram_torch +import json + +from models.vocoders.gan.discriminator.mpd import ( + MultiPeriodDiscriminator_vits as MultiPeriodDiscriminator, +) + + +class VitsSVCTrainer(TTSTrainer): + def __init__(self, args, cfg): + self.args = args + self.cfg = cfg + self._init_accelerator() + # Only for SVC tasks + with self.accelerator.main_process_first(): + self.singers = self._build_singer_lut() + TTSTrainer.__init__(self, args, cfg) + + def _build_model(self): + net_g = SynthesizerTrn( + self.cfg.preprocess.n_fft // 2 + 1, + self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size, + # directly use cfg + self.cfg, + ) + net_d = MultiPeriodDiscriminator(self.cfg.model.vits.use_spectral_norm) + model = {"generator": net_g, "discriminator": net_d} + + return model + + def _build_dataset(self): + return SVCDataset, SVCCollator + + def _build_optimizer(self): + optimizer_g = torch.optim.AdamW( + self.model["generator"].parameters(), + self.cfg.train.learning_rate, + betas=self.cfg.train.AdamW.betas, + eps=self.cfg.train.AdamW.eps, + ) + optimizer_d = torch.optim.AdamW( + self.model["discriminator"].parameters(), + self.cfg.train.learning_rate, + betas=self.cfg.train.AdamW.betas, + eps=self.cfg.train.AdamW.eps, + ) + optimizer = {"optimizer_g": optimizer_g, "optimizer_d": optimizer_d} + + return optimizer + + def _build_scheduler(self): + scheduler_g = ExponentialLR( + self.optimizer["optimizer_g"], + gamma=self.cfg.train.lr_decay, + last_epoch=self.epoch - 1, + ) + scheduler_d = ExponentialLR( + self.optimizer["optimizer_d"], + gamma=self.cfg.train.lr_decay, + last_epoch=self.epoch - 1, + ) + + scheduler = {"scheduler_g": scheduler_g, "scheduler_d": scheduler_d} + return scheduler + + def _build_criterion(self): + class GeneratorLoss(nn.Module): + def __init__(self, cfg): + super(GeneratorLoss, self).__init__() + self.cfg = cfg + self.l1_loss = nn.L1Loss() + + def generator_loss(self, disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + dg = dg.float() + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + def feature_loss(self, fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + rl = rl.float().detach() + gl = gl.float() + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + def kl_loss(self, z_p, logs_q, m_p, logs_p, z_mask): + """ + z_p, logs_q: [b, h, t_t] + m_p, logs_p: [b, h, t_t] + """ + z_p = z_p.float() + logs_q = logs_q.float() + m_p = m_p.float() + logs_p = logs_p.float() + z_mask = z_mask.float() + + kl = logs_p - logs_q - 0.5 + kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) + kl = torch.sum(kl * z_mask) + l = kl / torch.sum(z_mask) + return l + + def forward( + self, + outputs_g, + outputs_d, + y_mel, + y_hat_mel, + ): + loss_g = {} + + # mel loss + loss_mel = self.l1_loss(y_mel, y_hat_mel) * self.cfg.train.c_mel + loss_g["loss_mel"] = loss_mel + + # kl loss + loss_kl = ( + self.kl_loss( + outputs_g["z_p"], + outputs_g["logs_q"], + outputs_g["m_p"], + outputs_g["logs_p"], + outputs_g["z_mask"], + ) + * self.cfg.train.c_kl + ) + loss_g["loss_kl"] = loss_kl + + # feature loss + loss_fm = self.feature_loss(outputs_d["fmap_rs"], outputs_d["fmap_gs"]) + 
loss_g["loss_fm"] = loss_fm + + # gan loss + loss_gen, losses_gen = self.generator_loss(outputs_d["y_d_hat_g"]) + loss_g["loss_gen"] = loss_gen + loss_g["loss_gen_all"] = loss_mel + loss_kl + loss_fm + loss_gen + + return loss_g + + class DiscriminatorLoss(nn.Module): + def __init__(self, cfg): + super(DiscriminatorLoss, self).__init__() + self.cfg = cfg + self.l1Loss = torch.nn.L1Loss(reduction="mean") + + def __call__(self, disc_real_outputs, disc_generated_outputs): + loss_d = {} + + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + dr = dr.float() + dg = dg.float() + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg**2) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + loss_d["loss_disc_all"] = loss + + return loss_d + + criterion = { + "generator": GeneratorLoss(self.cfg), + "discriminator": DiscriminatorLoss(self.cfg), + } + return criterion + + # Keep legacy unchanged + def write_summary( + self, + losses, + stats, + images={}, + audios={}, + audio_sampling_rate=24000, + tag="train", + ): + for key, value in losses.items(): + self.sw.add_scalar(tag + "/" + key, value, self.step) + self.sw.add_scalar( + "learning_rate", + self.optimizer["optimizer_g"].param_groups[0]["lr"], + self.step, + ) + + if len(images) != 0: + for key, value in images.items(): + self.sw.add_image(key, value, self.global_step, batchformats="HWC") + if len(audios) != 0: + for key, value in audios.items(): + self.sw.add_audio(key, value, self.global_step, audio_sampling_rate) + + def write_valid_summary( + self, losses, stats, images={}, audios={}, audio_sampling_rate=24000, tag="val" + ): + for key, value in losses.items(): + self.sw.add_scalar(tag + "/" + key, value, self.step) + + if len(images) != 0: + for key, value in images.items(): + self.sw.add_image(key, value, self.global_step, batchformats="HWC") + if len(audios) != 0: + for key, value in audios.items(): + self.sw.add_audio(key, value, self.global_step, audio_sampling_rate) + + def _get_state_dict(self): + state_dict = { + "generator": self.model["generator"].state_dict(), + "discriminator": self.model["discriminator"].state_dict(), + "optimizer_g": self.optimizer["optimizer_g"].state_dict(), + "optimizer_d": self.optimizer["optimizer_d"].state_dict(), + "scheduler_g": self.scheduler["scheduler_g"].state_dict(), + "scheduler_d": self.scheduler["scheduler_d"].state_dict(), + "step": self.step, + "epoch": self.epoch, + "batch_size": self.cfg.train.batch_size, + } + return state_dict + + def get_state_dict(self): + state_dict = { + "generator": self.model["generator"].state_dict(), + "discriminator": self.model["discriminator"].state_dict(), + "optimizer_g": self.optimizer["optimizer_g"].state_dict(), + "optimizer_d": self.optimizer["optimizer_d"].state_dict(), + "scheduler_g": self.scheduler["scheduler_g"].state_dict(), + "scheduler_d": self.scheduler["scheduler_d"].state_dict(), + "step": self.step, + "epoch": self.epoch, + "batch_size": self.cfg.train.batch_size, + } + return state_dict + + def load_model(self, checkpoint): + self.step = checkpoint["step"] + self.epoch = checkpoint["epoch"] + self.model["generator"].load_state_dict(checkpoint["generator"]) + self.model["discriminator"].load_state_dict(checkpoint["discriminator"]) + self.optimizer["optimizer_g"].load_state_dict(checkpoint["optimizer_g"]) + self.optimizer["optimizer_d"].load_state_dict(checkpoint["optimizer_d"]) + 
self.scheduler["scheduler_g"].load_state_dict(checkpoint["scheduler_g"]) + self.scheduler["scheduler_d"].load_state_dict(checkpoint["scheduler_d"]) + + @torch.inference_mode() + def _valid_step(self, batch): + r"""Testing forward step. Should return average loss of a sample over + one batch. Provoke ``_forward_step`` is recommended except for special case. + See ``_test_epoch`` for usage. + """ + + valid_losses = {} + total_loss = 0 + valid_stats = {} + + # Discriminator + # Generator output + outputs_g = self.model["generator"](batch) + + y_mel = slice_segments( + batch["mel"].transpose(1, 2), + outputs_g["ids_slice"], + self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size, + ) + y_hat_mel = mel_spectrogram_torch( + outputs_g["y_hat"].squeeze(1), self.cfg.preprocess + ) + y = slice_segments( + batch["audio"].unsqueeze(1), + outputs_g["ids_slice"] * self.cfg.preprocess.hop_size, + self.cfg.preprocess.segment_size, + ) + + # Discriminator output + outputs_d = self.model["discriminator"](y, outputs_g["y_hat"].detach()) + ## Discriminator loss + loss_d = self.criterion["discriminator"]( + outputs_d["y_d_hat_r"], outputs_d["y_d_hat_g"] + ) + valid_losses.update(loss_d) + + ## Generator + outputs_d = self.model["discriminator"](y, outputs_g["y_hat"]) + loss_g = self.criterion["generator"](outputs_g, outputs_d, y_mel, y_hat_mel) + valid_losses.update(loss_g) + + for item in valid_losses: + valid_losses[item] = valid_losses[item].item() + + total_loss = loss_g["loss_gen_all"] + loss_d["loss_disc_all"] + + return ( + total_loss.item(), + valid_losses, + valid_stats, + ) + + def _train_step(self, batch): + r"""Forward step for training and inference. This function is called + in ``_train_step`` & ``_test_step`` function. + """ + + train_losses = {} + total_loss = 0 + training_stats = {} + + ## Train Discriminator + # Generator output + outputs_g = self.model["generator"](batch) + + y_mel = slice_segments( + batch["mel"].transpose(1, 2), + outputs_g["ids_slice"], + self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size, + ) + y_hat_mel = mel_spectrogram_torch( + outputs_g["y_hat"].squeeze(1), self.cfg.preprocess + ) + + y = slice_segments( + # [1, 168418] -> [1, 1, 168418] + batch["audio"].unsqueeze(1), + outputs_g["ids_slice"] * self.cfg.preprocess.hop_size, + self.cfg.preprocess.segment_size, + ) + + # Discriminator output + outputs_d = self.model["discriminator"](y, outputs_g["y_hat"].detach()) + # Discriminator loss + loss_d = self.criterion["discriminator"]( + outputs_d["y_d_hat_r"], outputs_d["y_d_hat_g"] + ) + train_losses.update(loss_d) + + # BP and Grad Updated + self.optimizer["optimizer_d"].zero_grad() + self.accelerator.backward(loss_d["loss_disc_all"]) + self.optimizer["optimizer_d"].step() + + ## Train Generator + outputs_d = self.model["discriminator"](y, outputs_g["y_hat"]) + loss_g = self.criterion["generator"](outputs_g, outputs_d, y_mel, y_hat_mel) + train_losses.update(loss_g) + + # BP and Grad Updated + self.optimizer["optimizer_g"].zero_grad() + self.accelerator.backward(loss_g["loss_gen_all"]) + self.optimizer["optimizer_g"].step() + + for item in train_losses: + train_losses[item] = train_losses[item].item() + + total_loss = loss_g["loss_gen_all"] + loss_d["loss_disc_all"] + + return ( + total_loss.item(), + train_losses, + training_stats, + ) + + def _train_epoch(self): + r"""Training epoch. Should return average loss of a batch (sample) over + one epoch. See ``train_loop`` for usage. 
+ """ + epoch_sum_loss: float = 0.0 + epoch_losses: dict = {} + epoch_step: int = 0 + for batch in tqdm( + self.train_dataloader, + desc=f"Training Epoch {self.epoch}", + unit="batch", + colour="GREEN", + leave=False, + dynamic_ncols=True, + smoothing=0.04, + disable=not self.accelerator.is_main_process, + ): + # Do training step and BP + with self.accelerator.accumulate(self.model): + total_loss, train_losses, training_stats = self._train_step(batch) + self.batch_count += 1 + + # Update info for each step + if self.batch_count % self.cfg.train.gradient_accumulation_step == 0: + epoch_sum_loss += total_loss + for key, value in train_losses.items(): + if key not in epoch_losses.keys(): + epoch_losses[key] = value + else: + epoch_losses[key] += value + + self.accelerator.log( + { + "Step/Generator Loss": train_losses["loss_gen_all"], + "Step/Discriminator Loss": train_losses["loss_disc_all"], + "Step/Generator Learning Rate": self.optimizer[ + "optimizer_d" + ].param_groups[0]["lr"], + "Step/Discriminator Learning Rate": self.optimizer[ + "optimizer_g" + ].param_groups[0]["lr"], + }, + step=self.step, + ) + self.step += 1 + epoch_step += 1 + + self.accelerator.wait_for_everyone() + + epoch_sum_loss = ( + epoch_sum_loss + / len(self.train_dataloader) + * self.cfg.train.gradient_accumulation_step + ) + + for key in epoch_losses.keys(): + epoch_losses[key] = ( + epoch_losses[key] + / len(self.train_dataloader) + * self.cfg.train.gradient_accumulation_step + ) + + return epoch_sum_loss, epoch_losses + + def _build_singer_lut(self): + resumed_singer_path = None + if self.args.resume_from_ckpt_path and self.args.resume_from_ckpt_path != "": + resumed_singer_path = os.path.join( + self.args.resume_from_ckpt_path, self.cfg.preprocess.spk2id + ) + if os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)): + resumed_singer_path = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id) + + if resumed_singer_path: + with open(resumed_singer_path, "r") as f: + singers = json.load(f) + else: + singers = dict() + + for dataset in self.cfg.dataset: + singer_lut_path = os.path.join( + self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id + ) + with open(singer_lut_path, "r") as singer_lut_path: + singer_lut = json.load(singer_lut_path) + for singer in singer_lut.keys(): + if singer not in singers: + singers[singer] = len(singers) + + with open( + os.path.join(self.exp_dir, self.cfg.preprocess.spk2id), "w" + ) as singer_file: + json.dump(singers, singer_file, indent=4, ensure_ascii=False) + print( + "singers have been dumped to {}".format( + os.path.join(self.exp_dir, self.cfg.preprocess.spk2id) + ) + ) + return singers diff --git a/models/tta/autoencoder/__init__.py b/models/tta/autoencoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/tta/autoencoder/autoencoder.py b/models/tta/autoencoder/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d97bfdc959d41bb8edd38993c052b7ad07d8cd92 --- /dev/null +++ b/models/tta/autoencoder/autoencoder.py @@ -0,0 +1,405 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
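_train_epoch above only adds to epoch_sum_loss once per gradient-accumulation window, which is why the epoch average divides by the dataloader length and then multiplies back by gradient_accumulation_step. Worked numbers, assuming the batch count divides evenly:

num_batches = 100   # len(train_dataloader), illustrative
grad_accum = 2      # cfg.train.gradient_accumulation_step
contributions = num_batches // grad_accum   # 50 additions to epoch_sum_loss
# epoch_sum_loss / num_batches * grad_accum == epoch_sum_loss / contributions,
# i.e. the mean over the logged optimization steps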
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modules.distributions.distributions import DiagonalGaussianDistribution + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels): + return torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class Upsample2d(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Upsample1d(Upsample2d): + def __init__(self, in_channels, with_conv): + super().__init__(in_channels, with_conv) + if self.with_conv: + self.conv = torch.nn.Conv1d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + +class Downsample2d(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + self.pad = (0, 1, 0, 1) + else: + self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2) + + def forward(self, x): + if self.with_conv: # bp: check self.avgpool and self.pad + x = torch.nn.functional.pad(x, self.pad, mode="constant", value=0) + x = self.conv(x) + else: + x = self.avg_pool(x) + return x + + +class Downsample1d(Downsample2d): + def __init__(self, in_channels, with_conv): + super().__init__(in_channels, with_conv) + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + # TODO: can we replace it just with conv2d with padding 1? 
+ self.conv = torch.nn.Conv1d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + self.pad = (1, 1) + else: + self.avg_pool = nn.AvgPool1d(kernel_size=2, stride=2) + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class ResnetBlock1d(ResnetBlock): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512 + ): + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + conv_shortcut=conv_shortcut, + dropout=dropout, + ) + + self.conv1 = torch.nn.Conv1d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.conv2 = torch.nn.Conv1d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv1d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = torch.nn.Conv1d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + +class Encoder2d(nn.Module): + def __init__( + self, + *, + ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + dropout=0.0, + resamp_with_conv=True, + in_channels, + z_channels, + double_z=True, + **ignore_kwargs + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, dropout=dropout + ) + ) + block_in = block_out + down = nn.Module() + down.block = block + if i_level != self.num_resolutions - 1: + down.downsample = Downsample2d(block_in, resamp_with_conv) + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, dropout=dropout + ) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, dropout=dropout + ) + + # end + 
self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.block_2(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +# TODO: Encoder1d +class Encoder1d(Encoder2d): + ... + + +class Decoder2d(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + dropout=0.0, + resamp_with_conv=True, + in_channels, + z_channels, + give_pre_end=False, + **ignorekwargs + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.in_channels = in_channels + self.give_pre_end = give_pre_end + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + # self.z_shape = (1,z_channels,curr_res,curr_res) + # print("Working with z of shape {} = {} dimensions.".format( + # self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, dropout=dropout + ) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, dropout=dropout + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, dropout=dropout + ) + ) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample2d(block_in, resamp_with_conv) + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, z): + self.last_z_shape = z.shape + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.block_2(h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +# TODO: decoder1d +class Decoder1d(Decoder2d): + ... 
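Encoder2d above halves each spatial dimension at every resolution level except the last, so a four-entry ch_mult gives an 8x reduction, and double_z=True makes conv_out emit 2 * z_channels (mean and log variance) for the DiagonalGaussianDistribution used by AutoencoderKL below. A shape sketch with assumed hyperparameters, treating the mel spectrogram as a one-channel image:

import torch
from models.tta.autoencoder.autoencoder import Encoder2d

enc = Encoder2d(ch=64, ch_mult=(1, 2, 4, 8), num_res_blocks=2,
                in_channels=1, z_channels=4, double_z=True)
moments = enc(torch.randn(1, 1, 80, 624))   # (B, 1, n_mels, T)
assert moments.shape == (1, 8, 10, 78)      # (B, 2*z_channels, 80/8, 624/8)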
+ + +class AutoencoderKL(nn.Module): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.encoder = Encoder2d( + ch=cfg.ch, + ch_mult=cfg.ch_mult, + num_res_blocks=cfg.num_res_blocks, + in_channels=cfg.in_channels, + z_channels=cfg.z_channels, + double_z=cfg.double_z, + ) + self.decoder = Decoder2d( + ch=cfg.ch, + ch_mult=cfg.ch_mult, + num_res_blocks=cfg.num_res_blocks, + out_ch=cfg.out_ch, + z_channels=cfg.z_channels, + in_channels=None, + ) + assert self.cfg.double_z + + self.quant_conv = torch.nn.Conv2d(2 * cfg.z_channels, 2 * cfg.z_channels, 1) + self.post_quant_conv = torch.nn.Conv2d(cfg.z_channels, cfg.z_channels, 1) + self.embed_dim = cfg.z_channels + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_last_layer(self): + return self.decoder.conv_out.weight diff --git a/models/tta/autoencoder/autoencoder_dataset.py b/models/tta/autoencoder/autoencoder_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..75828eea1933d45c84a2395e9599c5ae5d7f597e --- /dev/null +++ b/models/tta/autoencoder/autoencoder_dataset.py @@ -0,0 +1,114 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import random +import torch +from torch.nn.utils.rnn import pad_sequence +from utils.data_utils import * +from models.base.base_dataset import ( + BaseCollator, + BaseDataset, + BaseTestDataset, + BaseTestCollator, +) +import librosa + + +class AutoencoderKLDataset(BaseDataset): + def __init__(self, cfg, dataset, is_valid=False): + BaseDataset.__init__(self, cfg, dataset, is_valid=is_valid) + + cfg = self.cfg + + # utt2melspec + if cfg.preprocess.use_melspec: + self.utt2melspec_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2melspec_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.melspec_dir, + uid + ".npy", + ) + + # utt2wav + if cfg.preprocess.use_wav: + self.utt2wav_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2wav_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.wav_dir, + uid + ".wav", + ) + + def __getitem__(self, index): + # melspec: (n_mels, T) + # wav: (T,) + + single_feature = BaseDataset.__getitem__(self, index) + + utt_info = self.metadata[index] + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + if self.cfg.preprocess.use_melspec: + single_feature["melspec"] = np.load(self.utt2melspec_path[utt]) + + if self.cfg.preprocess.use_wav: + wav, sr = librosa.load( + self.utt2wav_path[utt], sr=16000 + ) # hard coding for 16KHz... 
+ single_feature["wav"] = wav + + return single_feature + + def __len__(self): + return len(self.metadata) + + def __len__(self): + return len(self.metadata) + + +class AutoencoderKLCollator(BaseCollator): + def __init__(self, cfg): + BaseCollator.__init__(self, cfg) + + def __call__(self, batch): + # mel: (B, n_mels, T) + # wav (option): (B, T) + + packed_batch_features = dict() + + for key in batch[0].keys(): + if key == "melspec": + packed_batch_features["melspec"] = torch.from_numpy( + np.array([b["melspec"][:, :624] for b in batch]) + ) + + if key == "wav": + values = [torch.from_numpy(b[key]) for b in batch] + packed_batch_features[key] = pad_sequence( + values, batch_first=True, padding_value=0 + ) + + return packed_batch_features + + +class AutoencoderKLTestDataset(BaseTestDataset): + ... + + +class AutoencoderKLTestCollator(BaseTestCollator): + ... diff --git a/models/tta/autoencoder/autoencoder_loss.py b/models/tta/autoencoder/autoencoder_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..5916aa36ebdfba6f2608514767f3e3761b57269f --- /dev/null +++ b/models/tta/autoencoder/autoencoder_loss.py @@ -0,0 +1,305 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import functools +import torch.nn.functional as F + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1.0 - logits_real)) + loss_fake = torch.mean(F.relu(1.0 + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(F.softplus(-logits_real)) + torch.mean(F.softplus(logits_fake)) + ) + return d_loss + + +def adopt_weight(weight, global_step, threshold=0, value=0.0): + if global_step < threshold: + weight = value + return weight + + +class ActNorm(nn.Module): + def __init__( + self, num_features, logdet=False, affine=True, allow_reverse_init=False + ): + assert affine + super().__init__() + self.logdet = logdet + self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) + self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) + self.allow_reverse_init = allow_reverse_init + + self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) + + def initialize(self, input): + with torch.no_grad(): + flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) + mean = ( + flatten.mean(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + std = ( + flatten.std(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + + self.loc.data.copy_(-mean) + self.scale.data.copy_(1 / (std + 1e-6)) + + def forward(self, input, reverse=False): + if reverse: + return self.reverse(input) + if len(input.shape) == 2: + input = input[:, :, None, None] + squeeze = True + else: + squeeze = False + + _, _, height, width = input.shape + + if self.training and self.initialized.item() == 0: + self.initialize(input) + self.initialized.fill_(1) + + h = self.scale * (input + self.loc) + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + + if self.logdet: + log_abs = torch.log(torch.abs(self.scale)) + logdet = height * width * torch.sum(log_abs) + logdet = logdet * torch.ones(input.shape[0]).to(input) + return h, logdet + + return h + + def reverse(self, output): + if self.training and self.initialized.item() == 0: + if not self.allow_reverse_init: + raise RuntimeError( + "Initializing 
ActNorm in reverse direction is " + "disabled by default. Use allow_reverse_init=True to enable." + ) + else: + self.initialize(output) + self.initialized.fill_(1) + + if len(output.shape) == 2: + output = output[:, :, None, None] + squeeze = True + else: + squeeze = False + + h = output / self.scale - self.loc + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + return h + + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("BatchNorm") != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator as in Pix2Pix + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + + def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(NLayerDiscriminator, self).__init__() + if not use_actnorm: + norm_layer = nn.BatchNorm2d + else: + norm_layer = ActNorm + if ( + type(norm_layer) == functools.partial + ): # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func != nn.BatchNorm2d + else: + use_bias = norm_layer != nn.BatchNorm2d + + kw = 4 + padw = 1 + sequence = [ + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True), + ] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2**n, 8) + sequence += [ + nn.Conv2d( + ndf * nf_mult_prev, + ndf * nf_mult, + kernel_size=kw, + stride=2, + padding=padw, + bias=use_bias, + ), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + nf_mult_prev = nf_mult + nf_mult = min(2**n_layers, 8) + sequence += [ + nn.Conv2d( + ndf * nf_mult_prev, + ndf * nf_mult, + kernel_size=kw, + stride=1, + padding=padw, + bias=use_bias, + ), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + sequence += [ + nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) + ] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.main(input) + + +class AutoencoderLossWithDiscriminator(nn.Module): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.kl_weight = cfg.kl_weight + self.logvar = nn.Parameter(torch.ones(size=()) * cfg.logvar_init) + + self.discriminator = NLayerDiscriminator( + input_nc=cfg.disc_in_channels, + n_layers=cfg.disc_num_layers, + use_actnorm=cfg.use_actnorm, + ).apply(weights_init) + + self.discriminator_iter_start = cfg.disc_start + self.discriminator_weight = cfg.disc_weight + self.disc_factor = cfg.disc_factor + self.disc_loss = hinge_d_loss + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer): + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp( + d_weight, self.cfg.min_adapt_d_weight, self.cfg.max_adapt_d_weight + ).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward( 
+ self, + inputs, + reconstructions, + posteriors, + optimizer_idx, + global_step, + last_layer, + split="train", + weights=None, + ): + rec_loss = torch.abs( + inputs.contiguous() - reconstructions.contiguous() + ) # l1 loss + nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss + if weights is not None: + weighted_nll_loss = weights * nll_loss + # weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + weighted_nll_loss = torch.mean(weighted_nll_loss) + # nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + nll_loss = torch.mean(nll_loss) + kl_loss = posteriors.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + # ? kl_loss = torch.mean(kl_loss) + + # now the GAN part + if optimizer_idx == 0: + logits_fake = self.discriminator(reconstructions.contiguous()) + g_loss = -torch.mean(logits_fake) + + if self.disc_factor > 0.0: + try: + d_weight = self.calculate_adaptive_weight( + nll_loss, g_loss, last_layer=last_layer + ) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + else: + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight( + self.disc_factor, global_step, threshold=self.discriminator_iter_start + ) + + total_loss = ( + weighted_nll_loss + + self.kl_weight * kl_loss + + d_weight * disc_factor * g_loss + ) + + return { + "loss": total_loss, + "kl_loss": kl_loss, + "rec_loss": rec_loss.mean(), + "nll_loss": nll_loss, + "g_loss": g_loss, + "d_weight": d_weight, + "disc_factor": torch.tensor(disc_factor), + } + + if optimizer_idx == 1: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + + disc_factor = adopt_weight( + self.disc_factor, global_step, threshold=self.discriminator_iter_start + ) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + return { + "d_loss": d_loss, + "logits_real": logits_real.mean(), + "logits_fake": logits_fake.mean(), + } diff --git a/models/tta/autoencoder/autoencoder_trainer.py b/models/tta/autoencoder/autoencoder_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..1faf02fa26c9bdc69faf2344fc2f722336d68a71 --- /dev/null +++ b/models/tta/autoencoder/autoencoder_trainer.py @@ -0,0 +1,187 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
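The trainer defined in this file alternates autoencoder and discriminator updates, and the adversarial term only switches on once global_step reaches disc_start, via adopt_weight from the loss module above. A toy illustration of both mechanisms; the threshold and weight values here are illustrative, not the config defaults.

def adopt_weight(weight, global_step, threshold=0, value=0.0):
    # same logic as adopt_weight in autoencoder_loss.py above
    return value if global_step < threshold else weight

for step in range(4):
    optimizer_idx = step % 2                            # 0: autoencoder step, 1: discriminator step
    disc_factor = adopt_weight(1.0, step, threshold=2)  # GAN term disabled before step 2
    print(step, "ae" if optimizer_idx == 0 else "disc", "disc_factor =", disc_factor)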
+ +import torch +from models.base.base_trainer import BaseTrainer +from models.tta.autoencoder.autoencoder_dataset import ( + AutoencoderKLDataset, + AutoencoderKLCollator, +) +from models.tta.autoencoder.autoencoder import AutoencoderKL +from models.tta.autoencoder.autoencoder_loss import AutoencoderLossWithDiscriminator +from torch.optim import Adam, AdamW +from torch.optim.lr_scheduler import ReduceLROnPlateau +from torch.nn import MSELoss, L1Loss +import torch.nn.functional as F +from torch.utils.data import ConcatDataset, DataLoader + + +class AutoencoderKLTrainer(BaseTrainer): + def __init__(self, args, cfg): + BaseTrainer.__init__(self, args, cfg) + self.cfg = cfg + self.save_config_file() + + def build_dataset(self): + return AutoencoderKLDataset, AutoencoderKLCollator + + def build_optimizer(self): + opt_ae = torch.optim.AdamW(self.model.parameters(), **self.cfg.train.adam) + opt_disc = torch.optim.AdamW( + self.criterion.discriminator.parameters(), **self.cfg.train.adam + ) + optimizer = {"opt_ae": opt_ae, "opt_disc": opt_disc} + return optimizer + + def build_data_loader(self): + Dataset, Collator = self.build_dataset() + # build dataset instance for each dataset and combine them by ConcatDataset + datasets_list = [] + for dataset in self.cfg.dataset: + subdataset = Dataset(self.cfg, dataset, is_valid=False) + datasets_list.append(subdataset) + train_dataset = ConcatDataset(datasets_list) + + train_collate = Collator(self.cfg) + + # use batch_sampler argument instead of (sampler, shuffle, drop_last, batch_size) + train_loader = DataLoader( + train_dataset, + collate_fn=train_collate, + num_workers=self.args.num_workers, + batch_size=self.cfg.train.batch_size, + pin_memory=False, + ) + if not self.cfg.train.ddp or self.args.local_rank == 0: + datasets_list = [] + for dataset in self.cfg.dataset: + subdataset = Dataset(self.cfg, dataset, is_valid=True) + datasets_list.append(subdataset) + valid_dataset = ConcatDataset(datasets_list) + valid_collate = Collator(self.cfg) + + valid_loader = DataLoader( + valid_dataset, + collate_fn=valid_collate, + num_workers=1, + batch_size=self.cfg.train.batch_size, + ) + else: + raise NotImplementedError("DDP is not supported yet.") + # valid_loader = None + data_loader = {"train": train_loader, "valid": valid_loader} + return data_loader + + # TODO: check it... 
+ def build_scheduler(self): + return None + # return ReduceLROnPlateau(self.optimizer["opt_ae"], **self.cfg.train.lronPlateau) + + def write_summary(self, losses, stats): + for key, value in losses.items(): + self.sw.add_scalar(key, value, self.step) + + def write_valid_summary(self, losses, stats): + for key, value in losses.items(): + self.sw.add_scalar(key, value, self.step) + + def build_criterion(self): + return AutoencoderLossWithDiscriminator(self.cfg.model.loss) + + def get_state_dict(self): + if self.scheduler != None: + state_dict = { + "model": self.model.state_dict(), + "optimizer_ae": self.optimizer["opt_ae"].state_dict(), + "optimizer_disc": self.optimizer["opt_disc"].state_dict(), + "scheduler": self.scheduler.state_dict(), + "step": self.step, + "epoch": self.epoch, + "batch_size": self.cfg.train.batch_size, + } + else: + state_dict = { + "model": self.model.state_dict(), + "optimizer_ae": self.optimizer["opt_ae"].state_dict(), + "optimizer_disc": self.optimizer["opt_disc"].state_dict(), + "step": self.step, + "epoch": self.epoch, + "batch_size": self.cfg.train.batch_size, + } + return state_dict + + def load_model(self, checkpoint): + self.step = checkpoint["step"] + self.epoch = checkpoint["epoch"] + + self.model.load_state_dict(checkpoint["model"]) + self.optimizer["opt_ae"].load_state_dict(checkpoint["optimizer_ae"]) + self.optimizer["opt_disc"].load_state_dict(checkpoint["optimizer_disc"]) + if self.scheduler != None: + self.scheduler.load_state_dict(checkpoint["scheduler"]) + + def build_model(self): + self.model = AutoencoderKL(self.cfg.model.autoencoderkl) + return self.model + + # TODO: train step + def train_step(self, data): + global_step = self.step + optimizer_idx = global_step % 2 + + train_losses = {} + total_loss = 0 + train_states = {} + + inputs = data["melspec"].unsqueeze(1) # (B, 80, T) -> (B, 1, 80, T) + reconstructions, posterior = self.model(inputs) + # train_stats.update(stat) + + train_losses = self.criterion( + inputs=inputs, + reconstructions=reconstructions, + posteriors=posterior, + optimizer_idx=optimizer_idx, + global_step=global_step, + last_layer=self.model.get_last_layer(), + split="train", + ) + + if optimizer_idx == 0: + total_loss = train_losses["loss"] + self.optimizer["opt_ae"].zero_grad() + total_loss.backward() + self.optimizer["opt_ae"].step() + + else: + total_loss = train_losses["d_loss"] + self.optimizer["opt_disc"].zero_grad() + total_loss.backward() + self.optimizer["opt_disc"].step() + + for item in train_losses: + train_losses[item] = train_losses[item].item() + + return train_losses, train_states, total_loss.item() + + # TODO: eval step + @torch.no_grad() + def eval_step(self, data, index): + valid_loss = {} + total_valid_loss = 0 + valid_stats = {} + + inputs = data["melspec"].unsqueeze(1) # (B, 80, T) -> (B, 1, 80, T) + reconstructions, posterior = self.model(inputs) + + loss = F.l1_loss(inputs, reconstructions) + valid_loss["loss"] = loss + + total_valid_loss += loss + + for item in valid_loss: + valid_loss[item] = valid_loss[item].item() + + return valid_loss, valid_stats, total_valid_loss.item() diff --git a/models/tta/ldm/__init__.py b/models/tta/ldm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/tta/ldm/attention.py b/models/tta/ldm/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..8e241fc8be6456a1c3bc3b8d8efe648ea4e42740 --- /dev/null +++ b/models/tta/ldm/attention.py @@ -0,0 +1,329 @@ +# 
Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +def exists(val): + return val is not None + + +def uniq(arr): + return {el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. 
+ """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange( + qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 + ) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange( + out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w + ) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) + + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) + h_ = self.proj_out(h_) + + return x + h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + + sim = einsum("b i d, b j d -> b i j", q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, "b ... 
-> b (...)") + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, "b j -> (b h) () j", h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum("b i j, b j d -> b i d", attn, v) + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + ): + super().__init__() + self.attn1 = CrossAttention( + query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint( + self._forward, (x, context), self.parameters(), self.checkpoint + ) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + + def __init__( + self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None + ): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim + ) + for d in range(depth) + ] + ) + + self.proj_out = zero_module( + nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + ) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c") + for block in self.transformer_blocks: + x = block(x, context=context) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + x = self.proj_out(x) + return x + x_in diff --git a/models/tta/ldm/audioldm.py b/models/tta/ldm/audioldm.py new file mode 100644 index 0000000000000000000000000000000000000000..bf43d07ee6a793bde6acbd330b8606c50da26df9 --- /dev/null +++ b/models/tta/ldm/audioldm.py @@ -0,0 +1,926 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
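The UNet in this file conditions on the diffusion timestep through sinusoidal embeddings (timestep_embedding, defined below). The same formula written out standalone as a quick shape check; the dimension and timestep values are illustrative.

import math
import torch

def sinusoidal_embedding(timesteps, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps[:, None].float() * freqs[None]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb

print(sinusoidal_embedding(torch.tensor([0, 10, 500]), 128).shape)  # torch.Size([3, 128])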
+ +from abc import abstractmethod +from functools import partial +import math +from typing import Iterable + +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from einops import repeat + +from models.tta.ldm.attention import SpatialTransformer + +# from attention import SpatialTransformer + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + return embedding + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. 
+ Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial**2) * c + model.total_ops += torch.DoubleTensor([matmul_ops]) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) # [N x (H * C) x T] + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = torch.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + a = torch.einsum( + "bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length) + ) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = torch.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + a = torch.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 + ) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. 
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd( + dims, self.channels, self.out_channels, 3, padding=padding + ) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class TransposedUpsample(nn.Module): + "Learned 2x upsampling without padding" + + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d( + self.channels, self.out_channels, kernel_size=ks, stride=2 + ) + + def forward(self, x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. 
+ """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = torch.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint( + self._forward, (x,), self.parameters(), True + ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + # return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. 
+ """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + ): + super().__init__() + if use_spatial_transformer: + assert ( + context_dim is not None + ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." + + if context_dim is not None: + assert ( + use_spatial_transformer + ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." + from omegaconf.listconfig import ListConfig + + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert ( + num_head_channels != -1 + ), "Either num_heads or num_head_channels has to be set" + + if num_head_channels == -1: + assert ( + num_heads != -1 + ), "Either num_heads or num_head_channels has to be set" + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + nn.Linear(model_channels, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + 
context_dim=context_dim, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. 
+ :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + # print(h.shape, hs[-1].shape) + if h.shape != hs[-1].shape: + if h.shape[-1] > hs[-1].shape[-1]: + h = h[:, :, :, : hs[-1].shape[-1]] + if h.shape[-2] > hs[-1].shape[-2]: + h = h[:, :, : hs[-1].shape[-2], :] + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + # print(h.shape) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class AudioLDM(nn.Module): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.unet = UNetModel( + image_size=cfg.image_size, + in_channels=cfg.in_channels, + out_channels=cfg.out_channels, + model_channels=cfg.model_channels, + attention_resolutions=cfg.attention_resolutions, + num_res_blocks=cfg.num_res_blocks, + channel_mult=cfg.channel_mult, + num_heads=cfg.num_heads, + use_spatial_transformer=cfg.use_spatial_transformer, + transformer_depth=cfg.transformer_depth, + context_dim=cfg.context_dim, + use_checkpoint=cfg.use_checkpoint, + legacy=cfg.legacy, + ) + + def forward(self, x, timesteps=None, context=None, y=None): + x = self.unet(x=x, timesteps=timesteps, context=context, y=y) + return x diff --git a/models/tta/ldm/audioldm_dataset.py b/models/tta/ldm/audioldm_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..344eeb51a164118d16b1651963e364b02135d7de --- /dev/null +++ b/models/tta/ldm/audioldm_dataset.py @@ -0,0 +1,153 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
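The dataset in this file randomly blanks a caption with probability cond_mask_prob, so the model also sees unconditional examples; that is what makes classifier-free guidance possible at inference. The masking step in isolation, with an illustrative probability and caption (the real value comes from cfg.preprocess.cond_mask_prob):

import numpy as np

cond_mask_prob = 0.1                        # illustrative
caption = "a dog barking in the distance"   # illustrative
if np.random.choice([1, 0], p=[cond_mask_prob, 1 - cond_mask_prob]):
    caption = ""                            # unconditional training example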
+ +import random +import torch +from torch.nn.utils.rnn import pad_sequence +from utils.data_utils import * + + +from models.base.base_dataset import ( + BaseCollator, + BaseDataset, + BaseTestDataset, + BaseTestCollator, +) +import librosa + +from transformers import AutoTokenizer + + +class AudioLDMDataset(BaseDataset): + def __init__(self, cfg, dataset, is_valid=False): + BaseDataset.__init__(self, cfg, dataset, is_valid=is_valid) + + self.cfg = cfg + + # utt2melspec + if cfg.preprocess.use_melspec: + self.utt2melspec_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2melspec_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.melspec_dir, + uid + ".npy", + ) + + # utt2wav + if cfg.preprocess.use_wav: + self.utt2wav_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2wav_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.wav_dir, + uid + ".wav", + ) + + # utt2caption + if cfg.preprocess.use_caption: + self.utt2caption = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2caption[utt] = utt_info["Caption"] + + def __getitem__(self, index): + # melspec: (n_mels, T) + # wav: (T,) + + single_feature = BaseDataset.__getitem__(self, index) + + utt_info = self.metadata[index] + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + if self.cfg.preprocess.use_melspec: + single_feature["melspec"] = np.load(self.utt2melspec_path[utt]) + + if self.cfg.preprocess.use_wav: + wav, sr = librosa.load( + self.utt2wav_path[utt], sr=16000 + ) # hard coding for 16KHz... + single_feature["wav"] = wav + + if self.cfg.preprocess.use_caption: + cond_mask = np.random.choice( + [1, 0], + p=[ + self.cfg.preprocess.cond_mask_prob, + 1 - self.cfg.preprocess.cond_mask_prob, + ], + ) # (0.1, 0.9) + if cond_mask: + single_feature["caption"] = "" + else: + single_feature["caption"] = self.utt2caption[utt] + + return single_feature + + def __len__(self): + return len(self.metadata) + + +class AudioLDMCollator(BaseCollator): + def __init__(self, cfg): + BaseCollator.__init__(self, cfg) + + self.tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512) + + def __call__(self, batch): + # mel: (B, n_mels, T) + # wav (option): (B, T) + # text_input_ids: (B, L) + # text_attention_mask: (B, L) + + packed_batch_features = dict() + + for key in batch[0].keys(): + if key == "melspec": + packed_batch_features["melspec"] = torch.from_numpy( + np.array([b["melspec"][:, :624] for b in batch]) + ) + + if key == "wav": + values = [torch.from_numpy(b[key]) for b in batch] + packed_batch_features[key] = pad_sequence( + values, batch_first=True, padding_value=0 + ) + + if key == "caption": + captions = [b[key] for b in batch] + text_input = self.tokenizer( + captions, return_tensors="pt", truncation=True, padding="longest" + ) + text_input_ids = text_input["input_ids"] + text_attention_mask = text_input["attention_mask"] + + packed_batch_features["text_input_ids"] = text_input_ids + packed_batch_features["text_attention_mask"] = text_attention_mask + + return packed_batch_features + + +class AudioLDMTestDataset(BaseTestDataset): + ... + + +class AudioLDMTestCollator(BaseTestCollator): + ... 
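For reference, this is what the collator above produces for the "caption" key: t5-base token ids plus an attention mask, padded to the longest caption in the batch. The captions below are made up; the empty string mimics a masked caption.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
batch_captions = ["a dog barking", "", "rain falling on a tin roof"]
text_input = tokenizer(batch_captions, return_tensors="pt", truncation=True, padding="longest")
print(text_input["input_ids"].shape, text_input["attention_mask"].shape)  # both (3, L)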
diff --git a/models/tta/ldm/audioldm_inference.py b/models/tta/ldm/audioldm_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..a37b40639aeef4d4a8ade1324171e7be11009d8d --- /dev/null +++ b/models/tta/ldm/audioldm_inference.py @@ -0,0 +1,193 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import time +import numpy as np +import torch +from tqdm import tqdm +import torch.nn as nn +from collections import OrderedDict +import json + +from models.tta.autoencoder.autoencoder import AutoencoderKL +from models.tta.ldm.inference_utils.vocoder import Generator +from models.tta.ldm.audioldm import AudioLDM +from transformers import T5EncoderModel, AutoTokenizer +from diffusers import PNDMScheduler + +import matplotlib.pyplot as plt +from scipy.io.wavfile import write + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +class AudioLDMInference: + def __init__(self, args, cfg): + self.cfg = cfg + self.args = args + + self.build_autoencoderkl() + self.build_textencoder() + + self.model = self.build_model() + self.load_state_dict() + + self.build_vocoder() + + self.out_path = self.args.output_dir + self.out_mel_path = os.path.join(self.out_path, "mel") + self.out_wav_path = os.path.join(self.out_path, "wav") + os.makedirs(self.out_mel_path, exist_ok=True) + os.makedirs(self.out_wav_path, exist_ok=True) + + def build_autoencoderkl(self): + self.autoencoderkl = AutoencoderKL(self.cfg.model.autoencoderkl) + self.autoencoder_path = self.cfg.model.autoencoder_path + checkpoint = torch.load(self.autoencoder_path, map_location="cpu") + self.autoencoderkl.load_state_dict(checkpoint["model"]) + self.autoencoderkl.cuda(self.args.local_rank) + self.autoencoderkl.requires_grad_(requires_grad=False) + self.autoencoderkl.eval() + + def build_textencoder(self): + self.tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512) + self.text_encoder = T5EncoderModel.from_pretrained("t5-base") + self.text_encoder.cuda(self.args.local_rank) + self.text_encoder.requires_grad_(requires_grad=False) + self.text_encoder.eval() + + def build_vocoder(self): + config_file = os.path.join(self.args.vocoder_config_path) + with open(config_file) as f: + data = f.read() + json_config = json.loads(data) + h = AttrDict(json_config) + self.vocoder = Generator(h).to(self.args.local_rank) + checkpoint_dict = torch.load( + self.args.vocoder_path, map_location=self.args.local_rank + ) + self.vocoder.load_state_dict(checkpoint_dict["generator"]) + + def build_model(self): + self.model = AudioLDM(self.cfg.model.audioldm) + return self.model + + def load_state_dict(self): + self.checkpoint_path = self.args.checkpoint_path + checkpoint = torch.load(self.checkpoint_path, map_location="cpu") + self.model.load_state_dict(checkpoint["model"]) + self.model.cuda(self.args.local_rank) + + def get_text_embedding(self): + text = self.args.text + + prompt = [text] + + text_input = self.tokenizer( + prompt, + max_length=self.tokenizer.model_max_length, + truncation=True, + padding="do_not_pad", + return_tensors="pt", + ) + text_embeddings = self.text_encoder( + text_input.input_ids.to(self.args.local_rank) + )[0] + + max_length = text_input.input_ids.shape[-1] + uncond_input = self.tokenizer( + [""] * 1, padding="max_length", max_length=max_length, return_tensors="pt" + ) + uncond_embeddings = 
self.text_encoder( + uncond_input.input_ids.to(self.args.local_rank) + )[0] + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def inference(self): + text_embeddings = self.get_text_embedding() + print(text_embeddings.shape) + + num_steps = self.args.num_steps + guidance_scale = self.args.guidance_scale + + noise_scheduler = PNDMScheduler( + num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + skip_prk_steps=True, + set_alpha_to_one=False, + steps_offset=1, + prediction_type="epsilon", + ) + + noise_scheduler.set_timesteps(num_steps) + + latents = torch.randn( + ( + 1, + self.cfg.model.autoencoderkl.z_channels, + 80 // (2 ** (len(self.cfg.model.autoencoderkl.ch_mult) - 1)), + 624 // (2 ** (len(self.cfg.model.autoencoderkl.ch_mult) - 1)), + ) + ).to(self.args.local_rank) + + self.model.eval() + for t in tqdm(noise_scheduler.timesteps): + t = t.to(self.args.local_rank) + + # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. + latent_model_input = torch.cat([latents] * 2) + + latent_model_input = noise_scheduler.scale_model_input( + latent_model_input, timestep=t + ) + # print(latent_model_input.shape) + + # predict the noise residual + with torch.no_grad(): + noise_pred = self.model( + latent_model_input, torch.cat([t.unsqueeze(0)] * 2), text_embeddings + ) + + # perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = noise_scheduler.step(noise_pred, t, latents).prev_sample + # print(latents.shape) + + latents_out = latents + print(latents_out.shape) + + with torch.no_grad(): + mel_out = self.autoencoderkl.decode(latents_out) + print(mel_out.shape) + + melspec = mel_out[0, 0].cpu().detach().numpy() + plt.imsave(os.path.join(self.out_mel_path, self.args.text + ".png"), melspec) + + self.vocoder.eval() + self.vocoder.remove_weight_norm() + with torch.no_grad(): + melspec = np.expand_dims(melspec, 0) + melspec = torch.FloatTensor(melspec).to(self.args.local_rank) + + y = self.vocoder(melspec) + audio = y.squeeze() + audio = audio * 32768.0 + audio = audio.cpu().numpy().astype("int16") + + write(os.path.join(self.out_wav_path, self.args.text + ".wav"), 16000, audio) diff --git a/models/tta/ldm/audioldm_trainer.py b/models/tta/ldm/audioldm_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..bd4a241867b139f0ce314b9c78d053cf711a83df --- /dev/null +++ b/models/tta/ldm/audioldm_trainer.py @@ -0,0 +1,251 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
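+
+# Overview of the trainer below: a frozen AutoencoderKL encodes mel spectrograms
+# into latents, a frozen T5 encoder embeds the caption tokens, and the AudioLDM
+# UNet is trained with an MSE loss to predict the Gaussian noise that a
+# DDPMScheduler adds to those latents at a randomly sampled timestep.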
+ +from models.base.base_trainer import BaseTrainer +from diffusers import DDPMScheduler +from models.tta.ldm.audioldm_dataset import AudioLDMDataset, AudioLDMCollator +from models.tta.autoencoder.autoencoder import AutoencoderKL +from models.tta.ldm.audioldm import AudioLDM, UNetModel +import torch +import torch.nn as nn +from torch.nn import MSELoss, L1Loss +import torch.nn.functional as F +from torch.utils.data import ConcatDataset, DataLoader + +from transformers import T5EncoderModel +from diffusers import DDPMScheduler + + +class AudioLDMTrainer(BaseTrainer): + def __init__(self, args, cfg): + BaseTrainer.__init__(self, args, cfg) + self.cfg = cfg + + self.build_autoencoderkl() + self.build_textencoder() + self.nosie_scheduler = self.build_noise_scheduler() + + self.save_config_file() + + def build_autoencoderkl(self): + self.autoencoderkl = AutoencoderKL(self.cfg.model.autoencoderkl) + self.autoencoder_path = self.cfg.model.autoencoder_path + checkpoint = torch.load(self.autoencoder_path, map_location="cpu") + self.autoencoderkl.load_state_dict(checkpoint["model"]) + self.autoencoderkl.cuda(self.args.local_rank) + self.autoencoderkl.requires_grad_(requires_grad=False) + self.autoencoderkl.eval() + + def build_textencoder(self): + self.text_encoder = T5EncoderModel.from_pretrained("t5-base") + self.text_encoder.cuda(self.args.local_rank) + self.text_encoder.requires_grad_(requires_grad=False) + self.text_encoder.eval() + + def build_noise_scheduler(self): + nosie_scheduler = DDPMScheduler( + num_train_timesteps=self.cfg.model.noise_scheduler.num_train_timesteps, + beta_start=self.cfg.model.noise_scheduler.beta_start, + beta_end=self.cfg.model.noise_scheduler.beta_end, + beta_schedule=self.cfg.model.noise_scheduler.beta_schedule, + clip_sample=self.cfg.model.noise_scheduler.clip_sample, + # steps_offset=self.cfg.model.noise_scheduler.steps_offset, + # set_alpha_to_one=self.cfg.model.noise_scheduler.set_alpha_to_one, + # skip_prk_steps=self.cfg.model.noise_scheduler.skip_prk_steps, + prediction_type=self.cfg.model.noise_scheduler.prediction_type, + ) + return nosie_scheduler + + def build_dataset(self): + return AudioLDMDataset, AudioLDMCollator + + def build_data_loader(self): + Dataset, Collator = self.build_dataset() + # build dataset instance for each dataset and combine them by ConcatDataset + datasets_list = [] + for dataset in self.cfg.dataset: + subdataset = Dataset(self.cfg, dataset, is_valid=False) + datasets_list.append(subdataset) + train_dataset = ConcatDataset(datasets_list) + + train_collate = Collator(self.cfg) + + # use batch_sampler argument instead of (sampler, shuffle, drop_last, batch_size) + train_loader = DataLoader( + train_dataset, + collate_fn=train_collate, + num_workers=self.args.num_workers, + batch_size=self.cfg.train.batch_size, + pin_memory=False, + ) + if not self.cfg.train.ddp or self.args.local_rank == 0: + datasets_list = [] + for dataset in self.cfg.dataset: + subdataset = Dataset(self.cfg, dataset, is_valid=True) + datasets_list.append(subdataset) + valid_dataset = ConcatDataset(datasets_list) + valid_collate = Collator(self.cfg) + + valid_loader = DataLoader( + valid_dataset, + collate_fn=valid_collate, + num_workers=1, + batch_size=self.cfg.train.batch_size, + ) + else: + raise NotImplementedError("DDP is not supported yet.") + # valid_loader = None + data_loader = {"train": train_loader, "valid": valid_loader} + return data_loader + + def build_optimizer(self): + optimizer = torch.optim.AdamW(self.model.parameters(), **self.cfg.train.adam) + 
return optimizer + + # TODO: check it... + def build_scheduler(self): + return None + # return ReduceLROnPlateau(self.optimizer["opt_ae"], **self.cfg.train.lronPlateau) + + def write_summary(self, losses, stats): + for key, value in losses.items(): + self.sw.add_scalar(key, value, self.step) + + def write_valid_summary(self, losses, stats): + for key, value in losses.items(): + self.sw.add_scalar(key, value, self.step) + + def build_criterion(self): + criterion = nn.MSELoss(reduction="mean") + return criterion + + def get_state_dict(self): + if self.scheduler != None: + state_dict = { + "model": self.model.state_dict(), + "optimizer": self.optimizer.state_dict(), + "scheduler": self.scheduler.state_dict(), + "step": self.step, + "epoch": self.epoch, + "batch_size": self.cfg.train.batch_size, + } + else: + state_dict = { + "model": self.model.state_dict(), + "optimizer": self.optimizer.state_dict(), + "step": self.step, + "epoch": self.epoch, + "batch_size": self.cfg.train.batch_size, + } + return state_dict + + def load_model(self, checkpoint): + self.step = checkpoint["step"] + self.epoch = checkpoint["epoch"] + + self.model.load_state_dict(checkpoint["model"]) + self.optimizer.load_state_dict(checkpoint["optimizer"]) + if self.scheduler != None: + self.scheduler.load_state_dict(checkpoint["scheduler"]) + + def build_model(self): + self.model = AudioLDM(self.cfg.model.audioldm) + return self.model + + @torch.no_grad() + def mel_to_latent(self, melspec): + posterior = self.autoencoderkl.encode(melspec) + latent = posterior.sample() # (B, 4, 5, 78) + return latent + + @torch.no_grad() + def get_text_embedding(self, text_input_ids, text_attention_mask): + text_embedding = self.text_encoder( + input_ids=text_input_ids, attention_mask=text_attention_mask + ).last_hidden_state + return text_embedding # (B, T, 768) + + def train_step(self, data): + train_losses = {} + total_loss = 0 + train_stats = {} + + melspec = data["melspec"].unsqueeze(1) # (B, 80, T) -> (B, 1, 80, T) + latents = self.mel_to_latent(melspec) + + text_embedding = self.get_text_embedding( + data["text_input_ids"], data["text_attention_mask"] + ) + + noise = torch.randn_like(latents).float() + + bsz = latents.shape[0] + timesteps = torch.randint( + 0, + self.cfg.model.noise_scheduler.num_train_timesteps, + (bsz,), + device=latents.device, + ) + timesteps = timesteps.long() + + with torch.no_grad(): + noisy_latents = self.nosie_scheduler.add_noise(latents, noise, timesteps) + + model_pred = self.model( + noisy_latents, timesteps=timesteps, context=text_embedding + ) + + loss = self.criterion(model_pred, noise) + + train_losses["loss"] = loss + total_loss += loss + + self.optimizer.zero_grad() + total_loss.backward() + self.optimizer.step() + + for item in train_losses: + train_losses[item] = train_losses[item].item() + + return train_losses, train_stats, total_loss.item() + + # TODO: eval step + @torch.no_grad() + def eval_step(self, data, index): + valid_loss = {} + total_valid_loss = 0 + valid_stats = {} + + melspec = data["melspec"].unsqueeze(1) # (B, 80, T) -> (B, 1, 80, T) + latents = self.mel_to_latent(melspec) + + text_embedding = self.get_text_embedding( + data["text_input_ids"], data["text_attention_mask"] + ) + + noise = torch.randn_like(latents).float() + + bsz = latents.shape[0] + timesteps = torch.randint( + 0, + self.cfg.model.noise_scheduler.num_train_timesteps, + (bsz,), + device=latents.device, + ) + timesteps = timesteps.long() + + noisy_latents = self.nosie_scheduler.add_noise(latents, noise, timesteps) + + 
model_pred = self.model(noisy_latents, timesteps, text_embedding) + + loss = self.criterion(model_pred, noise) + valid_loss["loss"] = loss + + total_valid_loss += loss + + for item in valid_loss: + valid_loss[item] = valid_loss[item].item() + + return valid_loss, valid_stats, total_valid_loss.item() diff --git a/models/tta/ldm/inference_utils/utils.py b/models/tta/ldm/inference_utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fd61c0c262be55a9b98f12c1ce1043eeddfcc739 --- /dev/null +++ b/models/tta/ldm/inference_utils/utils.py @@ -0,0 +1,62 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import glob +import os +import matplotlib +import torch +from torch.nn.utils import weight_norm + +matplotlib.use("Agg") +import matplotlib.pylab as plt + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print("Loading '{}'".format(filepath)) + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def save_checkpoint(filepath, obj): + print("Saving checkpoint to {}".format(filepath)) + torch.save(obj, filepath) + print("Complete.") + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + "????????") + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return None + return sorted(cp_list)[-1] diff --git a/models/tta/ldm/inference_utils/vocoder.py b/models/tta/ldm/inference_utils/vocoder.py new file mode 100644 index 0000000000000000000000000000000000000000..19e17c1e2b3e20154305180705ccbf8b5e49c346 --- /dev/null +++ b/models/tta/ldm/inference_utils/vocoder.py @@ -0,0 +1,408 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
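+
+# The classes below follow the HiFi-GAN vocoder design: Generator upsamples an
+# 80-bin mel spectrogram to a waveform with transposed convolutions followed by
+# multi-receptive-field ResBlocks, while MultiPeriodDiscriminator and
+# MultiScaleDiscriminator (plus the feature/adversarial losses at the end of
+# this file) are only needed when the vocoder itself is being trained.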
+ +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +from models.tta.ldm.inference_utils.utils import get_padding, init_weights + +LRELU_SLOPE = 0.1 + + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.h = h + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Generator(torch.nn.Module): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = weight_norm( + Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3) + ) + resblock = ResBlock1 if h.resblock == "1" else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) + ): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + 
self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f( + Conv2d( + 1, + 32, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + 32, + 128, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + 128, + 512, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + 512, + 1024, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ] + ) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiPeriodDiscriminator, self).__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorP(2), + DiscriminatorP(3), + DiscriminatorP(5), + DiscriminatorP(7), + DiscriminatorP(11), + ] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class 
MultiScaleDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiScaleDiscriminator, self).__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ] + ) + self.meanpools = nn.ModuleList( + [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg**2) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses diff --git a/models/tts/naturalspeech2/ns2_dataset.py b/models/tts/naturalspeech2/ns2_dataset.py index 009176ce692f74921353492efe7f2207ab345b40..eea17ffa943e1d163f545d829c327f8e7decec1d 100644 --- a/models/tts/naturalspeech2/ns2_dataset.py +++ b/models/tts/naturalspeech2/ns2_dataset.py @@ -21,13 +21,11 @@ class NS2Dataset(torch.utils.data.Dataset): assert isinstance(dataset, str) processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset) - # for example: /home/v-detaixin/LibriTTS/processed_data; train-full meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file # train.json self.metafile_path = os.path.join(processed_data_dir, meta_file) - # /home/v-detaixin/LibriTTS/processed_data/train-full/train.json self.metadata = self.get_metadata() diff --git a/models/vocoders/autoregressive/autoregressive_vocoder_dataset.py b/models/vocoders/autoregressive/autoregressive_vocoder_dataset.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/vocoders/autoregressive/autoregressive_vocoder_inference.py b/models/vocoders/autoregressive/autoregressive_vocoder_inference.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/vocoders/autoregressive/autoregressive_vocoder_trainer.py b/models/vocoders/autoregressive/autoregressive_vocoder_trainer.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/vocoders/autoregressive/wavenet/conv.py b/models/vocoders/autoregressive/wavenet/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..a095aad5d7203f6e5fb5a4d585b894e34dbe63c7 --- /dev/null +++ b/models/vocoders/autoregressive/wavenet/conv.py @@ -0,0 +1,66 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
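+
+# Conv1d below extends nn.Conv1d with an incremental_forward() used during
+# autoregressive sampling: it caches the last kw + (kw - 1) * (dilation - 1)
+# input frames in a per-layer buffer and linearizes the kernel, so generating
+# each new sample costs one F.linear call instead of a full convolution over
+# the whole history.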
+ +from torch import nn +from torch.nn import functional as F + + +class Conv1d(nn.Conv1d): + """Extended nn.Conv1d for incremental dilated convolutions""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.clear_buffer() + self._linearized_weight = None + self.register_backward_hook(self._clear_linearized_weight) + + def incremental_forward(self, input): + # input (B, T, C) + # run forward pre hooks + for hook in self._forward_pre_hooks.values(): + hook(self, input) + + # reshape weight + weight = self._get_linearized_weight() + kw = self.kernel_size[0] + dilation = self.dilation[0] + + bsz = input.size(0) + if kw > 1: + input = input.data + if self.input_buffer is None: + self.input_buffer = input.new( + bsz, kw + (kw - 1) * (dilation - 1), input.size(2) + ) + self.input_buffer.zero_() + else: + # shift buffer + self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone() + # append next input + self.input_buffer[:, -1, :] = input[:, -1, :] + input = self.input_buffer + if dilation > 1: + input = input[:, 0::dilation, :].contiguous() + output = F.linear(input.view(bsz, -1), weight, self.bias) + return output.view(bsz, 1, -1) + + def clear_buffer(self): + self.input_buffer = None + + def _get_linearized_weight(self): + if self._linearized_weight is None: + kw = self.kernel_size[0] + # nn.Conv1d + if self.weight.size() == (self.out_channels, self.in_channels, kw): + weight = self.weight.transpose(1, 2).contiguous() + else: + # fairseq.modules.conv_tbc.ConvTBC + weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous() + assert weight.size() == (self.out_channels, kw, self.in_channels) + self._linearized_weight = weight.view(self.out_channels, -1) + return self._linearized_weight + + def _clear_linearized_weight(self, *args): + self._linearized_weight = None diff --git a/models/vocoders/autoregressive/wavenet/modules.py b/models/vocoders/autoregressive/wavenet/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..13d51e52a50af3bc1f7fe9627aeae8d2b1b28b7d --- /dev/null +++ b/models/vocoders/autoregressive/wavenet/modules.py @@ -0,0 +1,152 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import math + +from torch import nn +from torch.nn import functional as F + +from .conv import Conv1d as conv_Conv1d + + +def Conv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs): + m = conv_Conv1d(in_channels, out_channels, kernel_size, **kwargs) + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + return nn.utils.weight_norm(m) + + +def Conv1d1x1(in_channels, out_channels, bias=True): + return Conv1d( + in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias + ) + + +def _conv1x1_forward(conv, x, is_incremental): + if is_incremental: + x = conv.incremental_forward(x) + else: + x = conv(x) + return x + + +class ResidualConv1dGLU(nn.Module): + """Residual dilated conv1d + Gated linear unit + + Args: + residual_channels (int): Residual input / output channels + gate_channels (int): Gated activation channels. + kernel_size (int): Kernel size of convolution layers. + skip_out_channels (int): Skip connection channels. If None, set to same + as ``residual_channels``. + cin_channels (int): Local conditioning channels. If negative value is + set, local conditioning is disabled. + dropout (float): Dropout probability. 
+ padding (int): Padding for convolution layers. If None, proper padding + is computed depends on dilation and kernel_size. + dilation (int): Dilation factor. + """ + + def __init__( + self, + residual_channels, + gate_channels, + kernel_size, + skip_out_channels=None, + cin_channels=-1, + dropout=1 - 0.95, + padding=None, + dilation=1, + causal=True, + bias=True, + *args, + **kwargs, + ): + super(ResidualConv1dGLU, self).__init__() + self.dropout = dropout + + if skip_out_channels is None: + skip_out_channels = residual_channels + if padding is None: + # no future time stamps available + if causal: + padding = (kernel_size - 1) * dilation + else: + padding = (kernel_size - 1) // 2 * dilation + self.causal = causal + + self.conv = Conv1d( + residual_channels, + gate_channels, + kernel_size, + padding=padding, + dilation=dilation, + bias=bias, + *args, + **kwargs, + ) + + # mel conditioning + self.conv1x1c = Conv1d1x1(cin_channels, gate_channels, bias=False) + + gate_out_channels = gate_channels // 2 + self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias) + self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_out_channels, bias=bias) + + def forward(self, x, c=None): + return self._forward(x, c, False) + + def incremental_forward(self, x, c=None): + return self._forward(x, c, True) + + def clear_buffer(self): + for c in [ + self.conv, + self.conv1x1_out, + self.conv1x1_skip, + self.conv1x1c, + ]: + if c is not None: + c.clear_buffer() + + def _forward(self, x, c, is_incremental): + """Forward + + Args: + x (Tensor): B x C x T + c (Tensor): B x C x T, Mel conditioning features + Returns: + Tensor: output + """ + residual = x + x = F.dropout(x, p=self.dropout, training=self.training) + if is_incremental: + splitdim = -1 + x = self.conv.incremental_forward(x) + else: + splitdim = 1 + x = self.conv(x) + # remove future time steps + x = x[:, :, : residual.size(-1)] if self.causal else x + + a, b = x.split(x.size(splitdim) // 2, dim=splitdim) + + assert self.conv1x1c is not None + c = _conv1x1_forward(self.conv1x1c, c, is_incremental) + ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim) + a, b = a + ca, b + cb + + x = torch.tanh(a) * torch.sigmoid(b) + + # For skip connection + s = _conv1x1_forward(self.conv1x1_skip, x, is_incremental) + + # For residual connection + x = _conv1x1_forward(self.conv1x1_out, x, is_incremental) + + x = (x + residual) * math.sqrt(0.5) + return x, s diff --git a/models/vocoders/autoregressive/wavenet/upsample.py b/models/vocoders/autoregressive/wavenet/upsample.py new file mode 100644 index 0000000000000000000000000000000000000000..b664302cd56545f1709a4f1874ebadd8e9375a9c --- /dev/null +++ b/models/vocoders/autoregressive/wavenet/upsample.py @@ -0,0 +1,109 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
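+
+# UpsampleNetwork stretches local conditioning (mel) features along time with
+# nearest-neighbour interpolation and smooths them with Conv2d layers whose
+# weights start as box filters; np.prod(upsample_scales) is expected to equal
+# the hop size, so one mel frame expands to exactly hop_size samples. For
+# example (assumed scales / hop of 4 * 4 * 4 * 4 = 256):
+#
+#     net = UpsampleNetwork(upsample_scales=[4, 4, 4, 4], cin_channels=80)
+#     net(torch.randn(1, 80, 100)).shape   # -> torch.Size([1, 80, 25600])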
+ +import torch + +import numpy as np + +from torch import nn +from torch.nn import functional as F + + +class Stretch2d(nn.Module): + def __init__(self, x_scale, y_scale, mode="nearest"): + super(Stretch2d, self).__init__() + self.x_scale = x_scale + self.y_scale = y_scale + self.mode = mode + + def forward(self, x): + return F.interpolate( + x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode + ) + + +def _get_activation(upsample_activation): + nonlinear = getattr(nn, upsample_activation) + return nonlinear + + +class UpsampleNetwork(nn.Module): + def __init__( + self, + upsample_scales, + upsample_activation="none", + upsample_activation_params={}, + mode="nearest", + freq_axis_kernel_size=1, + cin_pad=0, + cin_channels=128, + ): + super(UpsampleNetwork, self).__init__() + self.up_layers = nn.ModuleList() + total_scale = np.prod(upsample_scales) + self.indent = cin_pad * total_scale + for scale in upsample_scales: + freq_axis_padding = (freq_axis_kernel_size - 1) // 2 + k_size = (freq_axis_kernel_size, scale * 2 + 1) + padding = (freq_axis_padding, scale) + stretch = Stretch2d(scale, 1, mode) + conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False) + conv.weight.data.fill_(1.0 / np.prod(k_size)) + conv = nn.utils.weight_norm(conv) + self.up_layers.append(stretch) + self.up_layers.append(conv) + if upsample_activation != "none": + nonlinear = _get_activation(upsample_activation) + self.up_layers.append(nonlinear(**upsample_activation_params)) + + def forward(self, c): + """ + Args: + c : B x C x T + """ + + # B x 1 x C x T + c = c.unsqueeze(1) + for f in self.up_layers: + c = f(c) + # B x C x T + c = c.squeeze(1) + + if self.indent > 0: + c = c[:, :, self.indent : -self.indent] + return c + + +class ConvInUpsampleNetwork(nn.Module): + def __init__( + self, + upsample_scales, + upsample_activation="none", + upsample_activation_params={}, + mode="nearest", + freq_axis_kernel_size=1, + cin_pad=0, + cin_channels=128, + ): + super(ConvInUpsampleNetwork, self).__init__() + # To capture wide-context information in conditional features + # meaningless if cin_pad == 0 + ks = 2 * cin_pad + 1 + self.conv_in = nn.Conv1d( + cin_channels, cin_channels, kernel_size=ks, padding=cin_pad, bias=False + ) + self.upsample = UpsampleNetwork( + upsample_scales, + upsample_activation, + upsample_activation_params, + mode, + freq_axis_kernel_size, + cin_pad=cin_pad, + cin_channels=cin_channels, + ) + + def forward(self, c): + c_up = self.upsample(self.conv_in(c)) + return c_up diff --git a/models/vocoders/autoregressive/wavenet/wavenet.py b/models/vocoders/autoregressive/wavenet/wavenet.py new file mode 100644 index 0000000000000000000000000000000000000000..d63f22c2600fd0f83e5bdf339ebb121b3d2f35e6 --- /dev/null +++ b/models/vocoders/autoregressive/wavenet/wavenet.py @@ -0,0 +1,170 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +from torch import nn +from torch.nn import functional as F + +from .modules import Conv1d1x1, ResidualConv1dGLU +from .upsample import ConvInUpsampleNetwork + + +def receptive_field_size( + total_layers, num_cycles, kernel_size, dilation=lambda x: 2**x +): + """Compute receptive field size + + Args: + total_layers (int): total layers + num_cycles (int): cycles + kernel_size (int): kernel size + dilation (lambda): lambda to compute dilation factor. ``lambda x : 1`` + to disable dilated convolution. 
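+
+    Example:
+        With total_layers=24, num_cycles=4 and kernel_size=3, each cycle uses
+        dilations 1, 2, 4, 8, 16, 32 (sum 63), so the receptive field is
+        (3 - 1) * 4 * 63 + 1 = 505 samples.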
+ + Returns: + int: receptive field size in sample + + """ + assert total_layers % num_cycles == 0 + + layers_per_cycle = total_layers // num_cycles + dilations = [dilation(i % layers_per_cycle) for i in range(total_layers)] + return (kernel_size - 1) * sum(dilations) + 1 + + +class WaveNet(nn.Module): + """The WaveNet model that supports local and global conditioning. + + Args: + out_channels (int): Output channels. If input_type is mu-law quantized + one-hot vecror. this must equal to the quantize channels. Other wise + num_mixtures x 3 (pi, mu, log_scale). + layers (int): Number of total layers + stacks (int): Number of dilation cycles + residual_channels (int): Residual input / output channels + gate_channels (int): Gated activation channels. + skip_out_channels (int): Skip connection channels. + kernel_size (int): Kernel size of convolution layers. + dropout (float): Dropout probability. + input_dim (int): Number of mel-spec dimension. + upsample_scales (list): List of upsample scale. + ``np.prod(upsample_scales)`` must equal to hop size. Used only if + upsample_conditional_features is enabled. + freq_axis_kernel_size (int): Freq-axis kernel_size for transposed + convolution layers for upsampling. If you only care about time-axis + upsampling, set this to 1. + scalar_input (Bool): If True, scalar input ([-1, 1]) is expected, otherwise + quantized one-hot vector is expected.. + """ + + def __init__(self, cfg): + super(WaveNet, self).__init__() + self.cfg = cfg + self.scalar_input = self.cfg.VOCODER.SCALAR_INPUT + self.out_channels = self.cfg.VOCODER.OUT_CHANNELS + self.cin_channels = self.cfg.VOCODER.INPUT_DIM + self.residual_channels = self.cfg.VOCODER.RESIDUAL_CHANNELS + self.layers = self.cfg.VOCODER.LAYERS + self.stacks = self.cfg.VOCODER.STACKS + self.gate_channels = self.cfg.VOCODER.GATE_CHANNELS + self.kernel_size = self.cfg.VOCODER.KERNEL_SIZE + self.skip_out_channels = self.cfg.VOCODER.SKIP_OUT_CHANNELS + self.dropout = self.cfg.VOCODER.DROPOUT + self.upsample_scales = self.cfg.VOCODER.UPSAMPLE_SCALES + self.mel_frame_pad = self.cfg.VOCODER.MEL_FRAME_PAD + + assert self.layers % self.stacks == 0 + + layers_per_stack = self.layers // self.stacks + if self.scalar_input: + self.first_conv = Conv1d1x1(1, self.residual_channels) + else: + self.first_conv = Conv1d1x1(self.out_channels, self.residual_channels) + + self.conv_layers = nn.ModuleList() + for layer in range(self.layers): + dilation = 2 ** (layer % layers_per_stack) + conv = ResidualConv1dGLU( + self.residual_channels, + self.gate_channels, + kernel_size=self.kernel_size, + skip_out_channels=self.skip_out_channels, + bias=True, + dilation=dilation, + dropout=self.dropout, + cin_channels=self.cin_channels, + ) + self.conv_layers.append(conv) + + self.last_conv_layers = nn.ModuleList( + [ + nn.ReLU(inplace=True), + Conv1d1x1(self.skip_out_channels, self.skip_out_channels), + nn.ReLU(inplace=True), + Conv1d1x1(self.skip_out_channels, self.out_channels), + ] + ) + + self.upsample_net = ConvInUpsampleNetwork( + upsample_scales=self.upsample_scales, + cin_pad=self.mel_frame_pad, + cin_channels=self.cin_channels, + ) + + self.receptive_field = receptive_field_size( + self.layers, self.stacks, self.kernel_size + ) + + def forward(self, x, mel, softmax=False): + """Forward step + + Args: + x (Tensor): One-hot encoded audio signal, shape (B x C x T) + mel (Tensor): Local conditioning features, + shape (B x cin_channels x T) + softmax (bool): Whether applies softmax or not. 
+ + Returns: + Tensor: output, shape B x out_channels x T + """ + B, _, T = x.size() + + mel = self.upsample_net(mel) + assert mel.shape[-1] == x.shape[-1] + + x = self.first_conv(x) + skips = 0 + for f in self.conv_layers: + x, h = f(x, mel) + skips += h + skips *= math.sqrt(1.0 / len(self.conv_layers)) + + x = skips + for f in self.last_conv_layers: + x = f(x) + + x = F.softmax(x, dim=1) if softmax else x + + return x + + def clear_buffer(self): + self.first_conv.clear_buffer() + for f in self.conv_layers: + f.clear_buffer() + for f in self.last_conv_layers: + try: + f.clear_buffer() + except AttributeError: + pass + + def make_generation_fast_(self): + def remove_weight_norm(m): + try: + nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(remove_weight_norm) diff --git a/models/vocoders/autoregressive/wavernn/wavernn.py b/models/vocoders/autoregressive/wavernn/wavernn.py new file mode 100644 index 0000000000000000000000000000000000000000..c7475fa8fe8b4575bf714e615349582ff98bbc27 --- /dev/null +++ b/models/vocoders/autoregressive/wavernn/wavernn.py @@ -0,0 +1,188 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np + + +class ResBlock(nn.Module): + def __init__(self, dims): + super().__init__() + self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False) + self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False) + self.batch_norm1 = nn.BatchNorm1d(dims) + self.batch_norm2 = nn.BatchNorm1d(dims) + + def forward(self, x): + residual = x + x = self.conv1(x) + x = self.batch_norm1(x) + x = F.relu(x) + x = self.conv2(x) + x = self.batch_norm2(x) + x = x + residual + return x + + +class MelResNet(nn.Module): + def __init__(self, res_blocks, in_dims, compute_dims, res_out_dims, pad): + super().__init__() + kernel_size = pad * 2 + 1 + self.conv_in = nn.Conv1d( + in_dims, compute_dims, kernel_size=kernel_size, bias=False + ) + self.batch_norm = nn.BatchNorm1d(compute_dims) + self.layers = nn.ModuleList() + for i in range(res_blocks): + self.layers.append(ResBlock(compute_dims)) + self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1) + + def forward(self, x): + x = self.conv_in(x) + x = self.batch_norm(x) + x = F.relu(x) + for f in self.layers: + x = f(x) + x = self.conv_out(x) + return x + + +class Stretch2d(nn.Module): + def __init__(self, x_scale, y_scale): + super().__init__() + self.x_scale = x_scale + self.y_scale = y_scale + + def forward(self, x): + b, c, h, w = x.size() + x = x.unsqueeze(-1).unsqueeze(3) + x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale) + return x.view(b, c, h * self.y_scale, w * self.x_scale) + + +class UpsampleNetwork(nn.Module): + def __init__( + self, feat_dims, upsample_scales, compute_dims, res_blocks, res_out_dims, pad + ): + super().__init__() + total_scale = np.cumproduct(upsample_scales)[-1] + self.indent = pad * total_scale + self.resnet = MelResNet(res_blocks, feat_dims, compute_dims, res_out_dims, pad) + self.resnet_stretch = Stretch2d(total_scale, 1) + self.up_layers = nn.ModuleList() + for scale in upsample_scales: + kernel_size = (1, scale * 2 + 1) + padding = (0, scale) + stretch = Stretch2d(scale, 1) + conv = nn.Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False) + conv.weight.data.fill_(1.0 / kernel_size[1]) + self.up_layers.append(stretch) + 
self.up_layers.append(conv) + + def forward(self, m): + aux = self.resnet(m).unsqueeze(1) + aux = self.resnet_stretch(aux) + aux = aux.squeeze(1) + m = m.unsqueeze(1) + for f in self.up_layers: + m = f(m) + m = m.squeeze(1)[:, :, self.indent : -self.indent] + return m.transpose(1, 2), aux.transpose(1, 2) + + +class WaveRNN(nn.Module): + def __init__(self, cfg): + super().__init__() + + self.cfg = cfg + self.pad = self.cfg.VOCODER.MEL_FRAME_PAD + + if self.cfg.VOCODER.MODE == "mu_law_quantize": + self.n_classes = 2**self.cfg.VOCODER.BITS + elif self.cfg.VOCODER.MODE == "mu_law" or self.cfg.VOCODER: + self.n_classes = 30 + + self._to_flatten = [] + + self.rnn_dims = self.cfg.VOCODER.RNN_DIMS + self.aux_dims = self.cfg.VOCODER.RES_OUT_DIMS // 4 + self.hop_length = self.cfg.VOCODER.HOP_LENGTH + self.fc_dims = self.cfg.VOCODER.FC_DIMS + self.upsample_factors = self.cfg.VOCODER.UPSAMPLE_FACTORS + self.feat_dims = self.cfg.VOCODER.INPUT_DIM + self.compute_dims = self.cfg.VOCODER.COMPUTE_DIMS + self.res_out_dims = self.cfg.VOCODER.RES_OUT_DIMS + self.res_blocks = self.cfg.VOCODER.RES_BLOCKS + + self.upsample = UpsampleNetwork( + self.feat_dims, + self.upsample_factors, + self.compute_dims, + self.res_blocks, + self.res_out_dims, + self.pad, + ) + self.I = nn.Linear(self.feat_dims + self.aux_dims + 1, self.rnn_dims) + + self.rnn1 = nn.GRU(self.rnn_dims, self.rnn_dims, batch_first=True) + self.rnn2 = nn.GRU( + self.rnn_dims + self.aux_dims, self.rnn_dims, batch_first=True + ) + self._to_flatten += [self.rnn1, self.rnn2] + + self.fc1 = nn.Linear(self.rnn_dims + self.aux_dims, self.fc_dims) + self.fc2 = nn.Linear(self.fc_dims + self.aux_dims, self.fc_dims) + self.fc3 = nn.Linear(self.fc_dims, self.n_classes) + + self.num_params() + + self._flatten_parameters() + + def forward(self, x, mels): + device = next(self.parameters()).device + + self._flatten_parameters() + + batch_size = x.size(0) + h1 = torch.zeros(1, batch_size, self.rnn_dims, device=device) + h2 = torch.zeros(1, batch_size, self.rnn_dims, device=device) + mels, aux = self.upsample(mels) + + aux_idx = [self.aux_dims * i for i in range(5)] + a1 = aux[:, :, aux_idx[0] : aux_idx[1]] + a2 = aux[:, :, aux_idx[1] : aux_idx[2]] + a3 = aux[:, :, aux_idx[2] : aux_idx[3]] + a4 = aux[:, :, aux_idx[3] : aux_idx[4]] + + x = torch.cat([x.unsqueeze(-1), mels, a1], dim=2) + x = self.I(x) + res = x + x, _ = self.rnn1(x, h1) + + x = x + res + res = x + x = torch.cat([x, a2], dim=2) + x, _ = self.rnn2(x, h2) + + x = x + res + x = torch.cat([x, a3], dim=2) + x = F.relu(self.fc1(x)) + + x = torch.cat([x, a4], dim=2) + x = F.relu(self.fc2(x)) + return self.fc3(x) + + def num_params(self, print_out=True): + parameters = filter(lambda p: p.requires_grad, self.parameters()) + parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 + if print_out: + print("Trainable Parameters: %.3fM" % parameters) + return parameters + + def _flatten_parameters(self): + [m.flatten_parameters() for m in self._to_flatten] diff --git a/models/vocoders/diffusion/diffusion_vocoder_dataset.py b/models/vocoders/diffusion/diffusion_vocoder_dataset.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/vocoders/diffusion/diffusion_vocoder_inference.py b/models/vocoders/diffusion/diffusion_vocoder_inference.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/vocoders/diffusion/diffusion_vocoder_trainer.py 
b/models/vocoders/diffusion/diffusion_vocoder_trainer.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/vocoders/diffusion/diffwave/diffwave.py b/models/vocoders/diffusion/diffwave/diffwave.py new file mode 100644 index 0000000000000000000000000000000000000000..c9379b0b622c6da8a754f2cc87fd7723eacfa995 --- /dev/null +++ b/models/vocoders/diffusion/diffwave/diffwave.py @@ -0,0 +1,173 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from math import sqrt + + +Linear = nn.Linear +ConvTranspose2d = nn.ConvTranspose2d + + +def Conv1d(*args, **kwargs): + layer = nn.Conv1d(*args, **kwargs) + nn.init.kaiming_normal_(layer.weight) + return layer + + +@torch.jit.script +def silu(x): + return x * torch.sigmoid(x) + + +class DiffusionEmbedding(nn.Module): + def __init__(self, max_steps): + super().__init__() + self.register_buffer( + "embedding", self._build_embedding(max_steps), persistent=False + ) + self.projection1 = Linear(128, 512) + self.projection2 = Linear(512, 512) + + def forward(self, diffusion_step): + if diffusion_step.dtype in [torch.int32, torch.int64]: + x = self.embedding[diffusion_step] + else: + x = self._lerp_embedding(diffusion_step) + x = self.projection1(x) + x = silu(x) + x = self.projection2(x) + x = silu(x) + return x + + def _lerp_embedding(self, t): + low_idx = torch.floor(t).long() + high_idx = torch.ceil(t).long() + low = self.embedding[low_idx] + high = self.embedding[high_idx] + return low + (high - low) * (t - low_idx) + + def _build_embedding(self, max_steps): + steps = torch.arange(max_steps).unsqueeze(1) # [T,1] + dims = torch.arange(64).unsqueeze(0) # [1,64] + table = steps * 10.0 ** (dims * 4.0 / 63.0) # [T,64] + table = torch.cat([torch.sin(table), torch.cos(table)], dim=1) + return table + + +class SpectrogramUpsampler(nn.Module): + def __init__(self, upsample_factors): + super().__init__() + self.conv1 = ConvTranspose2d( + 1, + 1, + [3, upsample_factors[0] * 2], + stride=[1, upsample_factors[0]], + padding=[1, upsample_factors[0] // 2], + ) + self.conv2 = ConvTranspose2d( + 1, + 1, + [3, upsample_factors[1] * 2], + stride=[1, upsample_factors[1]], + padding=[1, upsample_factors[1] // 2], + ) + + def forward(self, x): + x = torch.unsqueeze(x, 1) + x = self.conv1(x) + x = F.leaky_relu(x, 0.4) + x = self.conv2(x) + x = F.leaky_relu(x, 0.4) + x = torch.squeeze(x, 1) + return x + + +class ResidualBlock(nn.Module): + def __init__(self, n_mels, residual_channels, dilation): + super().__init__() + self.dilated_conv = Conv1d( + residual_channels, + 2 * residual_channels, + 3, + padding=dilation, + dilation=dilation, + ) + self.diffusion_projection = Linear(512, residual_channels) + + self.conditioner_projection = Conv1d(n_mels, 2 * residual_channels, 1) + + self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1) + + def forward(self, x, diffusion_step, conditioner): + diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) + y = x + diffusion_step + + conditioner = self.conditioner_projection(conditioner) + y = self.dilated_conv(y) + conditioner + + gate, filter = torch.chunk(y, 2, dim=1) + y = torch.sigmoid(gate) * torch.tanh(filter) + + y = self.output_projection(y) + residual, skip = torch.chunk(y, 2, dim=1) + return (x + residual) / 
sqrt(2.0), skip + + +class DiffWave(nn.Module): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.cfg.VOCODER.NOISE_SCHEDULE = np.linspace( + self.cfg.VOCODER.NOISE_SCHEDULE_FACTORS[0], + self.cfg.VOCODER.NOISE_SCHEDULE_FACTORS[1], + self.cfg.VOCODER.NOISE_SCHEDULE_FACTORS[2], + ).tolist() + self.input_projection = Conv1d(1, self.cfg.VOCODER.RESIDUAL_CHANNELS, 1) + self.diffusion_embedding = DiffusionEmbedding( + len(self.cfg.VOCODER.NOISE_SCHEDULE) + ) + self.spectrogram_upsampler = SpectrogramUpsampler( + self.cfg.VOCODER.UPSAMPLE_FACTORS + ) + + self.residual_layers = nn.ModuleList( + [ + ResidualBlock( + self.cfg.VOCODER.INPUT_DIM, + self.cfg.VOCODER.RESIDUAL_CHANNELS, + 2 ** (i % self.cfg.VOCODER.DILATION_CYCLE_LENGTH), + ) + for i in range(self.cfg.VOCODER.RESIDUAL_LAYERS) + ] + ) + self.skip_projection = Conv1d( + self.cfg.VOCODER.RESIDUAL_CHANNELS, self.cfg.VOCODER.RESIDUAL_CHANNELS, 1 + ) + self.output_projection = Conv1d(self.cfg.VOCODER.RESIDUAL_CHANNELS, 1, 1) + nn.init.zeros_(self.output_projection.weight) + + def forward(self, audio, diffusion_step, spectrogram): + x = audio.unsqueeze(1) + x = self.input_projection(x) + x = F.relu(x) + + diffusion_step = self.diffusion_embedding(diffusion_step) + spectrogram = self.spectrogram_upsampler(spectrogram) + + skip = None + for layer in self.residual_layers: + x, skip_connection = layer(x, diffusion_step, spectrogram) + skip = skip_connection if skip is None else skip_connection + skip + + x = skip / sqrt(len(self.residual_layers)) + x = self.skip_projection(x) + x = F.relu(x) + x = self.output_projection(x) + return x diff --git a/models/vocoders/dsp/world/world.py b/models/vocoders/dsp/world/world.py new file mode 100644 index 0000000000000000000000000000000000000000..59f28e8e896f883fe6ce243dfb7f254e78fd09c6 --- /dev/null +++ b/models/vocoders/dsp/world/world.py @@ -0,0 +1,183 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# 1. Extract WORLD features including F0, AP, SP +# 2. 
Transform between SP and MCEP +import torchaudio +import pyworld as pw +import numpy as np +import torch +import diffsptk +import os +from tqdm import tqdm +import pickle +import json +import re +import torchaudio + +from cuhkszsvc.configs.config_parse import get_wav_path, get_wav_file_path +from utils.io import has_existed + + +def get_mcep_params(fs): + """Hyperparameters of transformation between SP and MCEP + + Reference: + https://github.com/CSTR-Edinburgh/merlin/blob/master/misc/scripts/vocoder/world_v2/copy_synthesis.sh + + """ + if fs in [44100, 48000]: + fft_size = 2048 + alpha = 0.77 + if fs in [16000]: + fft_size = 1024 + alpha = 0.58 + return fft_size, alpha + + +def extract_world_features(wave_file, fs, frameshift): + # waveform: (1, seq) + waveform, sample_rate = torchaudio.load(wave_file) + if sample_rate != fs: + waveform = torchaudio.functional.resample( + waveform, orig_freq=sample_rate, new_freq=fs + ) + # x: (seq,) + x = np.array(torch.clamp(waveform[0], -1.0, 1.0), dtype=np.double) + + _f0, t = pw.dio(x, fs, frame_period=frameshift) # raw pitch extractor + f0 = pw.stonemask(x, _f0, t, fs) # pitch refinement + sp = pw.cheaptrick(x, f0, t, fs) # extract smoothed spectrogram + ap = pw.d4c(x, f0, t, fs) # extract aperiodicity + + return f0, sp, ap, fs + + +def sp2mcep(x, mcsize, fs): + fft_size, alpha = get_mcep_params(fs) + x = torch.as_tensor(x, dtype=torch.float) + + tmp = diffsptk.ScalarOperation("SquareRoot")(x) + tmp = diffsptk.ScalarOperation("Multiplication", 32768.0)(tmp) + mgc = diffsptk.MelCepstralAnalysis( + cep_order=mcsize - 1, fft_length=fft_size, alpha=alpha, n_iter=1 + )(tmp) + return mgc.numpy() + + +def mcep2sp(x, mcsize, fs): + fft_size, alpha = get_mcep_params(fs) + x = torch.as_tensor(x, dtype=torch.float) + + tmp = diffsptk.MelGeneralizedCepstrumToSpectrum( + alpha=alpha, + cep_order=mcsize - 1, + fft_length=fft_size, + )(x) + tmp = diffsptk.ScalarOperation("Division", 32768.0)(tmp) + sp = diffsptk.ScalarOperation("Power", 2)(tmp) + return sp.double().numpy() + + +def extract_mcep_features_of_dataset( + output_path, dataset_path, dataset, mcsize, fs, frameshift, splits=None +): + output_dir = os.path.join(output_path, dataset, "mcep/{}".format(fs)) + + if not splits: + splits = ["train", "test"] if dataset != "m4singer" else ["test"] + + for dataset_type in splits: + print("-" * 20) + print("Dataset: {}, {}".format(dataset, dataset_type)) + + output_file = os.path.join(output_dir, "{}.pkl".format(dataset_type)) + if has_existed(output_file): + continue + + # Extract SP features + print("\nExtracting SP featuers...") + sp_features = get_world_features_of_dataset( + output_path, dataset_path, dataset, dataset_type, fs, frameshift + ) + + # SP to MCEP + print("\nTransform SP to MCEP...") + mcep_features = [sp2mcep(sp, mcsize=mcsize, fs=fs) for sp in tqdm(sp_features)] + + # Save + os.makedirs(output_dir, exist_ok=True) + with open(output_file, "wb") as f: + pickle.dump(mcep_features, f) + + +def get_world_features_of_dataset( + output_path, + dataset_path, + dataset, + dataset_type, + fs, + frameshift, + save_sp_feature=False, +): + data_dir = os.path.join(output_path, dataset) + wave_dir = get_wav_path(dataset_path, dataset) + + # Dataset + dataset_file = os.path.join(data_dir, "{}.json".format(dataset_type)) + if not os.path.exists(dataset_file): + print("File {} has not existed.".format(dataset_file)) + return None + + with open(dataset_file, "r") as f: + datasets = json.load(f) + + # Save dir + f0_dir = os.path.join(output_path, dataset, "f0") + 
os.makedirs(f0_dir, exist_ok=True) + + # Extract + f0_features = [] + sp_features = [] + for utt in tqdm(datasets): + wave_file = get_wav_file_path(dataset, wave_dir, utt) + f0, sp, _, _ = extract_world_features(wave_file, fs, frameshift) + + sp_features.append(sp) + f0_features.append(f0) + + # Save sp + if save_sp_feature: + sp_dir = os.path.join(output_path, dataset, "sp") + os.makedirs(sp_dir, exist_ok=True) + with open(os.path.join(sp_dir, "{}.pkl".format(dataset_type)), "wb") as f: + pickle.dump(sp_features, f) + + # F0 statistics + f0_statistics_file = os.path.join(f0_dir, "{}_f0.pkl".format(dataset_type)) + f0_statistics(f0_features, f0_statistics_file) + + return sp_features + + +def f0_statistics(f0_features, path): + print("\nF0 statistics...") + + total_f0 = [] + for f0 in tqdm(f0_features): + total_f0 += [f for f in f0 if f != 0] + + mean = sum(total_f0) / len(total_f0) + print("Min = {}, Max = {}, Mean = {}".format(min(total_f0), max(total_f0), mean)) + + with open(path, "wb") as f: + pickle.dump([mean, total_f0], f) + + +def world_synthesis(f0, sp, ap, fs, frameshift): + y = pw.synthesize( + f0, sp, ap, fs, frame_period=frameshift + ) # synthesize an utterance using the parameters + return y diff --git a/models/vocoders/flow/flow_vocoder_dataset.py b/models/vocoders/flow/flow_vocoder_dataset.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/vocoders/flow/flow_vocoder_inference.py b/models/vocoders/flow/flow_vocoder_inference.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/vocoders/flow/flow_vocoder_trainer.py b/models/vocoders/flow/flow_vocoder_trainer.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/vocoders/flow/waveglow/waveglow.py b/models/vocoders/flow/waveglow/waveglow.py new file mode 100644 index 0000000000000000000000000000000000000000..13e2a1bf8f5e3c3d47a031ceec87e4ff111cd5fe --- /dev/null +++ b/models/vocoders/flow/waveglow/waveglow.py @@ -0,0 +1,249 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch.autograd import Variable +import torch.nn.functional as F + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +class Invertible1x1Conv(torch.nn.Module): + """ + The layer outputs both the convolution, and the log determinant + of its weight matrix. 
If reverse=True it does convolution with + inverse + """ + + def __init__(self, c): + super(Invertible1x1Conv, self).__init__() + self.conv = torch.nn.Conv1d( + c, c, kernel_size=1, stride=1, padding=0, bias=False + ) + + # Sample a random orthonormal matrix to initialize weights + W = torch.linalg.qr(torch.FloatTensor(c, c).normal_())[0] + + # Ensure determinant is 1.0 not -1.0 + if torch.det(W) < 0: + W[:, 0] = -1 * W[:, 0] + W = W.view(c, c, 1) + self.conv.weight.data = W + + def forward(self, z, reverse=False): + # shape + batch_size, group_size, n_of_groups = z.size() + + W = self.conv.weight.squeeze() + + if reverse: + if not hasattr(self, "W_inverse"): + # Reverse computation + W_inverse = W.float().inverse() + W_inverse = Variable(W_inverse[..., None]) + if z.type() == "torch.cuda.HalfTensor": + W_inverse = W_inverse.half() + self.W_inverse = W_inverse + z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0) + return z + else: + # Forward computation + log_det_W = batch_size * n_of_groups * torch.logdet(W) + z = self.conv(z) + return z, log_det_W + + +class WN(torch.nn.Module): + """ + This is the WaveNet like layer for the affine coupling. The primary difference + from WaveNet is the convolutions need not be causal. There is also no dilation + size reset. The dilation only doubles on each layer + """ + + def __init__( + self, n_in_channels, n_mel_channels, n_layers, n_channels, kernel_size + ): + super(WN, self).__init__() + assert kernel_size % 2 == 1 + assert n_channels % 2 == 0 + self.n_layers = n_layers + self.n_channels = n_channels + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + + start = torch.nn.Conv1d(n_in_channels, n_channels, 1) + start = torch.nn.utils.weight_norm(start, name="weight") + self.start = start + + # Initializing last layer to 0 makes the affine coupling layers + # do nothing at first. 
This helps with training stability + end = torch.nn.Conv1d(n_channels, 2 * n_in_channels, 1) + end.weight.data.zero_() + end.bias.data.zero_() + self.end = end + + cond_layer = torch.nn.Conv1d(n_mel_channels, 2 * n_channels * n_layers, 1) + self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") + + for i in range(n_layers): + dilation = 2**i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = torch.nn.Conv1d( + n_channels, + 2 * n_channels, + kernel_size, + dilation=dilation, + padding=padding, + ) + in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") + self.in_layers.append(in_layer) + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * n_channels + else: + res_skip_channels = n_channels + res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1) + res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") + self.res_skip_layers.append(res_skip_layer) + + def forward(self, forward_input): + audio, spect = forward_input + audio = self.start(audio) + output = torch.zeros_like(audio) + n_channels_tensor = torch.IntTensor([self.n_channels]) + + spect = self.cond_layer(spect) + + for i in range(self.n_layers): + spect_offset = i * 2 * self.n_channels + acts = fused_add_tanh_sigmoid_multiply( + self.in_layers[i](audio), + spect[:, spect_offset : spect_offset + 2 * self.n_channels, :], + n_channels_tensor, + ) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + audio = audio + res_skip_acts[:, : self.n_channels, :] + output = output + res_skip_acts[:, self.n_channels :, :] + else: + output = output + res_skip_acts + + return self.end(output) + + +class WaveGlow(torch.nn.Module): + def __init__(self, cfg): + super(WaveGlow, self).__init__() + + self.cfg = cfg + + self.upsample = torch.nn.ConvTranspose1d( + self.cfg.VOCODER.INPUT_DIM, + self.cfg.VOCODER.INPUT_DIM, + 1024, + stride=256, + ) + assert self.cfg.VOCODER.N_GROUP % 2 == 0 + self.n_flows = self.cfg.VOCODER.N_FLOWS + self.n_group = self.cfg.VOCODER.N_GROUP + self.n_early_every = self.cfg.VOCODER.N_EARLY_EVERY + self.n_early_size = self.cfg.VOCODER.N_EARLY_SIZE + self.WN = torch.nn.ModuleList() + self.convinv = torch.nn.ModuleList() + + n_half = int(self.cfg.VOCODER.N_GROUP / 2) + + # Set up layers with the right sizes based on how many dimensions + # have been output already + n_remaining_channels = self.cfg.VOCODER.N_GROUP + for k in range(self.cfg.VOCODER.N_FLOWS): + if k % self.n_early_every == 0 and k > 0: + n_half = n_half - int(self.n_early_size / 2) + n_remaining_channels = n_remaining_channels - self.n_early_size + self.convinv.append(Invertible1x1Conv(n_remaining_channels)) + self.WN.append( + WN( + n_half, + self.cfg.VOCODER.INPUT_DIM * self.cfg.VOCODER.N_GROUP, + self.cfg.VOCODER.N_LAYERS, + self.cfg.VOCODER.N_CHANNELS, + self.cfg.VOCODER.KERNEL_SIZE, + ) + ) + self.n_remaining_channels = n_remaining_channels # Useful during inference + + def forward(self, forward_input): + """ + forward_input[0] = mel_spectrogram: batch x n_mel_channels x frames + forward_input[1] = audio: batch x time + """ + spect, audio = forward_input + + # Upsample spectrogram to size of audio + spect = self.upsample(spect) + assert spect.size(2) >= audio.size(1) + if spect.size(2) > audio.size(1): + spect = spect[:, :, : audio.size(1)] + + spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3) + spect = ( + spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1) + ) + + audio = audio.unfold(1, 
self.n_group, self.n_group).permute(0, 2, 1) + output_audio = [] + log_s_list = [] + log_det_W_list = [] + + for k in range(self.n_flows): + if k % self.n_early_every == 0 and k > 0: + output_audio.append(audio[:, : self.n_early_size, :]) + audio = audio[:, self.n_early_size :, :] + + audio, log_det_W = self.convinv[k](audio) + log_det_W_list.append(log_det_W) + + n_half = int(audio.size(1) / 2) + audio_0 = audio[:, :n_half, :] + audio_1 = audio[:, n_half:, :] + + output = self.WN[k]((audio_0, spect)) + log_s = output[:, n_half:, :] + b = output[:, :n_half, :] + audio_1 = torch.exp(log_s) * audio_1 + b + log_s_list.append(log_s) + + audio = torch.cat([audio_0, audio_1], 1) + + output_audio.append(audio) + return torch.cat(output_audio, 1), log_s_list, log_det_W_list + + @staticmethod + def remove_weightnorm(model): + waveglow = model + for WN in waveglow.WN: + WN.start = torch.nn.utils.remove_weight_norm(WN.start) + WN.in_layers = remove(WN.in_layers) + WN.cond_layer = torch.nn.utils.remove_weight_norm(WN.cond_layer) + WN.res_skip_layers = remove(WN.res_skip_layers) + return waveglow + + +def remove(conv_list): + new_conv_list = torch.nn.ModuleList() + for old_conv in conv_list: + old_conv = torch.nn.utils.remove_weight_norm(old_conv) + new_conv_list.append(old_conv) + return new_conv_list diff --git a/models/vocoders/gan/discriminator/mpd.py b/models/vocoders/gan/discriminator/mpd.py new file mode 100644 index 0000000000000000000000000000000000000000..f28711d18847a106a998cab90871fe6303a4fd08 --- /dev/null +++ b/models/vocoders/gan/discriminator/mpd.py @@ -0,0 +1,269 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv2d, Conv1d +from torch.nn.utils import weight_norm, spectral_norm +from torch import nn +from modules.vocoder_blocks import * + +LRELU_SLOPE = 0.1 + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, cfg, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.d_mult = cfg.model.mpd.discriminator_channel_mult_factor + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f( + Conv2d( + 1, + int(32 * self.d_mult), + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + int(32 * self.d_mult), + int(128 * self.d_mult), + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + int(128 * self.d_mult), + int(512 * self.d_mult), + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + int(512 * self.d_mult), + int(1024 * self.d_mult), + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + int(1024 * self.d_mult), + int(1024 * self.d_mult), + (kernel_size, 1), + (stride, 1), + padding=(2, 0), + ) + ), + ] + ) + self.conv_post = norm_f( + Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)) + ) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + + x = 
self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, cfg): + super(MultiPeriodDiscriminator, self).__init__() + self.mpd_reshapes = cfg.model.mpd.mpd_reshapes + print("mpd_reshapes: {}".format(self.mpd_reshapes)) + discriminators = [ + DiscriminatorP(cfg, rs, use_spectral_norm=cfg.model.mpd.use_spectral_norm) + for rs in self.mpd_reshapes + ] + self.discriminators = nn.ModuleList(discriminators) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +# TODO: merge with DiscriminatorP (lmxue, yicheng) +class DiscriminatorP_vits(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP_vits, self).__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f( + Conv2d( + 1, + 32, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 32, + 128, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 128, + 512, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 512, + 1024, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 1024, + 1024, + (kernel_size, 1), + 1, + padding=(get_padding(kernel_size, 1), 0), + ) + ), + ] + ) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +# TODO: merge with MultiPeriodDiscriminator (lmxue, yicheng) +class MultiPeriodDiscriminator_vits(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(MultiPeriodDiscriminator_vits, self).__init__() + periods = [2, 3, 5, 7, 11] + + discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + discs = discs + [ + DiscriminatorP_vits(i, 
use_spectral_norm=use_spectral_norm) for i in periods + ] + self.discriminators = nn.ModuleList(discs) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + outputs = { + "y_d_hat_r": y_d_rs, + "y_d_hat_g": y_d_gs, + "fmap_rs": fmap_rs, + "fmap_gs": fmap_gs, + } + + return outputs diff --git a/models/vocoders/gan/discriminator/mrd.py b/models/vocoders/gan/discriminator/mrd.py new file mode 100644 index 0000000000000000000000000000000000000000..38ee80bfbf82b6aa63c80dbc2c6ffed8cb50a924 --- /dev/null +++ b/models/vocoders/gan/discriminator/mrd.py @@ -0,0 +1,160 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +from torch import nn + +LRELU_SLOPE = 0.1 + + +# This code is a refined MRD adopted from BigVGAN under the MIT License +# https://github.com/NVIDIA/BigVGAN + + +class DiscriminatorR(nn.Module): + def __init__(self, cfg, resolution): + super().__init__() + + self.resolution = resolution + assert ( + len(self.resolution) == 3 + ), "MRD layer requires list with len=3, got {}".format(self.resolution) + self.lrelu_slope = LRELU_SLOPE + + norm_f = ( + weight_norm if cfg.model.mrd.use_spectral_norm == False else spectral_norm + ) + if cfg.model.mrd.mrd_override: + print( + "INFO: overriding MRD use_spectral_norm as {}".format( + cfg.model.mrd.mrd_use_spectral_norm + ) + ) + norm_f = ( + weight_norm + if cfg.model.mrd.mrd_use_spectral_norm == False + else spectral_norm + ) + self.d_mult = cfg.model.mrd.discriminator_channel_mult_factor + if cfg.model.mrd.mrd_override: + print( + "INFO: overriding mrd channel multiplier as {}".format( + cfg.model.mrd.mrd_channel_mult + ) + ) + self.d_mult = cfg.model.mrd.mrd_channel_mult + + self.convs = nn.ModuleList( + [ + norm_f(nn.Conv2d(1, int(32 * self.d_mult), (3, 9), padding=(1, 4))), + norm_f( + nn.Conv2d( + int(32 * self.d_mult), + int(32 * self.d_mult), + (3, 9), + stride=(1, 2), + padding=(1, 4), + ) + ), + norm_f( + nn.Conv2d( + int(32 * self.d_mult), + int(32 * self.d_mult), + (3, 9), + stride=(1, 2), + padding=(1, 4), + ) + ), + norm_f( + nn.Conv2d( + int(32 * self.d_mult), + int(32 * self.d_mult), + (3, 9), + stride=(1, 2), + padding=(1, 4), + ) + ), + norm_f( + nn.Conv2d( + int(32 * self.d_mult), + int(32 * self.d_mult), + (3, 3), + padding=(1, 1), + ) + ), + ] + ) + self.conv_post = norm_f( + nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)) + ) + + def forward(self, x): + fmap = [] + + x = self.spectrogram(x) + x = x.unsqueeze(1) + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, self.lrelu_slope) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + def spectrogram(self, x): + n_fft, hop_length, win_length = self.resolution + x = F.pad( + x, + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + x = x.squeeze(1) + x = torch.stft( + x, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + center=False, + return_complex=True, + ) + x = torch.view_as_real(x) # [B, F, TT, 2] + mag = torch.norm(x, p=2, dim=-1) # [B, F, TT] + + return mag + + 
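For reference, the resolution-specific magnitude spectrogram that DiscriminatorR feeds into its 2-D convolutions can be reproduced in isolation. The function below is a minimal standalone sketch of the spectrogram method above; the (n_fft, hop_length, win_length) triple of (1024, 120, 600), the batch size, and the one-second input length are illustrative assumptions, not values taken from a project config.

import torch
import torch.nn.functional as F


def magnitude_spectrogram(x, resolution):
    # Reflection-pad the waveform, take a non-centered STFT, and return the
    # complex magnitude, mirroring DiscriminatorR.spectrogram above.
    n_fft, hop_length, win_length = resolution
    pad = int((n_fft - hop_length) / 2)
    x = F.pad(x, (pad, pad), mode="reflect")   # pad the time axis on both sides
    x = x.squeeze(1)                           # [B, 1, T] -> [B, T]
    spec = torch.stft(
        x,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        center=False,
        return_complex=True,
    )
    spec = torch.view_as_real(spec)            # [B, F, frames, 2]
    return torch.norm(spec, p=2, dim=-1)       # [B, F, frames]


# Example with an assumed resolution and two one-second fake waveforms
mag = magnitude_spectrogram(torch.randn(2, 1, 24000), (1024, 120, 600))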
+class MultiResolutionDiscriminator(nn.Module): + def __init__(self, cfg, debug=False): + super().__init__() + self.resolutions = cfg.model.mrd.resolutions + assert ( + len(self.resolutions) == 3 + ), "MRD requires list of list with len=3, each element having a list with len=3. got {}".format( + self.resolutions + ) + self.discriminators = nn.ModuleList( + [DiscriminatorR(cfg, resolution) for resolution in self.resolutions] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(x=y) + y_d_g, fmap_g = d(x=y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs diff --git a/models/vocoders/gan/discriminator/msd.py b/models/vocoders/gan/discriminator/msd.py new file mode 100644 index 0000000000000000000000000000000000000000..4c1556aea581878dcbe10f7a3bdebc33a4972e2c --- /dev/null +++ b/models/vocoders/gan/discriminator/msd.py @@ -0,0 +1,88 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, AvgPool1d +from torch.nn.utils import weight_norm, spectral_norm +from torch import nn +from modules.vocoder_blocks import * + + +LRELU_SLOPE = 0.1 + + +class DiscriminatorS(nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(nn.Module): + def __init__(self, cfg): + super(MultiScaleDiscriminator, self).__init__() + + self.cfg = cfg + + self.discriminators = nn.ModuleList( + [ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ] + ) + + self.meanpools = nn.ModuleList( + [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs diff --git a/models/vocoders/gan/discriminator/mssbcqtd.py b/models/vocoders/gan/discriminator/mssbcqtd.py new file mode 100644 index 0000000000000000000000000000000000000000..213de5441754944a360707e99a3734ad035d9077 --- /dev/null +++ b/models/vocoders/gan/discriminator/mssbcqtd.py @@ -0,0 +1,182 @@ +# Copyright (c) 2023 Amphion. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch import nn +from modules.vocoder_blocks import * + +from einops import rearrange +import torchaudio.transforms as T + +from nnAudio import features + +LRELU_SLOPE = 0.1 + + +class DiscriminatorCQT(nn.Module): + def __init__(self, cfg, hop_length, n_octaves, bins_per_octave): + super(DiscriminatorCQT, self).__init__() + self.cfg = cfg + + self.filters = cfg.model.mssbcqtd.filters + self.max_filters = cfg.model.mssbcqtd.max_filters + self.filters_scale = cfg.model.mssbcqtd.filters_scale + self.kernel_size = (3, 9) + self.dilations = cfg.model.mssbcqtd.dilations + self.stride = (1, 2) + + self.in_channels = cfg.model.mssbcqtd.in_channels + self.out_channels = cfg.model.mssbcqtd.out_channels + self.fs = cfg.preprocess.sample_rate + self.hop_length = hop_length + self.n_octaves = n_octaves + self.bins_per_octave = bins_per_octave + + self.cqt_transform = features.cqt.CQT2010v2( + sr=self.fs * 2, + hop_length=self.hop_length, + n_bins=self.bins_per_octave * self.n_octaves, + bins_per_octave=self.bins_per_octave, + output_format="Complex", + pad_mode="constant", + ) + + self.conv_pres = nn.ModuleList() + for i in range(self.n_octaves): + self.conv_pres.append( + NormConv2d( + self.in_channels * 2, + self.in_channels * 2, + kernel_size=self.kernel_size, + padding=get_2d_padding(self.kernel_size), + ) + ) + + self.convs = nn.ModuleList() + + self.convs.append( + NormConv2d( + self.in_channels * 2, + self.filters, + kernel_size=self.kernel_size, + padding=get_2d_padding(self.kernel_size), + ) + ) + + in_chs = min(self.filters_scale * self.filters, self.max_filters) + for i, dilation in enumerate(self.dilations): + out_chs = min( + (self.filters_scale ** (i + 1)) * self.filters, self.max_filters + ) + self.convs.append( + NormConv2d( + in_chs, + out_chs, + kernel_size=self.kernel_size, + stride=self.stride, + dilation=(dilation, 1), + padding=get_2d_padding(self.kernel_size, (dilation, 1)), + norm="weight_norm", + ) + ) + in_chs = out_chs + out_chs = min( + (self.filters_scale ** (len(self.dilations) + 1)) * self.filters, + self.max_filters, + ) + self.convs.append( + NormConv2d( + in_chs, + out_chs, + kernel_size=(self.kernel_size[0], self.kernel_size[0]), + padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])), + norm="weight_norm", + ) + ) + + self.conv_post = NormConv2d( + out_chs, + self.out_channels, + kernel_size=(self.kernel_size[0], self.kernel_size[0]), + padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])), + norm="weight_norm", + ) + + self.activation = torch.nn.LeakyReLU(negative_slope=LRELU_SLOPE) + self.resample = T.Resample(orig_freq=self.fs, new_freq=self.fs * 2) + + def forward(self, x): + fmap = [] + + x = self.resample(x) + + z = self.cqt_transform(x) + + z_amplitude = z[:, :, :, 0].unsqueeze(1) + z_phase = z[:, :, :, 1].unsqueeze(1) + + z = torch.cat([z_amplitude, z_phase], dim=1) + z = rearrange(z, "b c w t -> b c t w") + + latent_z = [] + for i in range(self.n_octaves): + latent_z.append( + self.conv_pres[i]( + z[ + :, + :, + :, + i * self.bins_per_octave : (i + 1) * self.bins_per_octave, + ] + ) + ) + latent_z = torch.cat(latent_z, dim=-1) + + for i, l in enumerate(self.convs): + latent_z = l(latent_z) + + latent_z = self.activation(latent_z) + fmap.append(latent_z) + + latent_z = self.conv_post(latent_z) + + return latent_z, fmap + + +class 
MultiScaleSubbandCQTDiscriminator(nn.Module): + def __init__(self, cfg): + super(MultiScaleSubbandCQTDiscriminator, self).__init__() + + self.cfg = cfg + + self.discriminators = nn.ModuleList( + [ + DiscriminatorCQT( + cfg, + hop_length=cfg.model.mssbcqtd.hop_lengths[i], + n_octaves=cfg.model.mssbcqtd.n_octaves[i], + bins_per_octave=cfg.model.mssbcqtd.bins_per_octaves[i], + ) + for i in range(len(cfg.model.mssbcqtd.hop_lengths)) + ] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for disc in self.discriminators: + y_d_r, fmap_r = disc(y) + y_d_g, fmap_g = disc(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs diff --git a/models/vocoders/gan/discriminator/msstftd.py b/models/vocoders/gan/discriminator/msstftd.py new file mode 100644 index 0000000000000000000000000000000000000000..83dedb78848d2d73ac667e7a191f05de1ed7bf21 --- /dev/null +++ b/models/vocoders/gan/discriminator/msstftd.py @@ -0,0 +1,226 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# This code is adopted from META's Encodec under MIT License +# https://github.com/facebookresearch/encodec + +"""MS-STFT discriminator, provided here for reference.""" + +import typing as tp + +import torchaudio +import torch +from torch import nn +from einops import rearrange + +from modules.vocoder_blocks import * + + +FeatureMapType = tp.List[torch.Tensor] +LogitsType = torch.Tensor +DiscriminatorOutput = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]] + + +def get_2d_padding( + kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1) +): + return ( + ((kernel_size[0] - 1) * dilation[0]) // 2, + ((kernel_size[1] - 1) * dilation[1]) // 2, + ) + + +class DiscriminatorSTFT(nn.Module): + """STFT sub-discriminator. + Args: + filters (int): Number of filters in convolutions + in_channels (int): Number of input channels. Default: 1 + out_channels (int): Number of output channels. Default: 1 + n_fft (int): Size of FFT for each scale. Default: 1024 + hop_length (int): Length of hop between STFT windows for each scale. Default: 256 + kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)`` + stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)`` + dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]`` + win_length (int): Window size for each scale. Default: 1024 + normalized (bool): Whether to normalize by magnitude after stft. Default: True + norm (str): Normalization method. Default: `'weight_norm'` + activation (str): Activation function. Default: `'LeakyReLU'` + activation_params (dict): Parameters to provide to the activation function. + growth (int): Growth factor for the filters. 
Default: 1 + """ + + def __init__( + self, + filters: int, + in_channels: int = 1, + out_channels: int = 1, + n_fft: int = 1024, + hop_length: int = 256, + win_length: int = 1024, + max_filters: int = 1024, + filters_scale: int = 1, + kernel_size: tp.Tuple[int, int] = (3, 9), + dilations: tp.List = [1, 2, 4], + stride: tp.Tuple[int, int] = (1, 2), + normalized: bool = True, + norm: str = "weight_norm", + activation: str = "LeakyReLU", + activation_params: dict = {"negative_slope": 0.2}, + ): + super().__init__() + assert len(kernel_size) == 2 + assert len(stride) == 2 + self.filters = filters + self.in_channels = in_channels + self.out_channels = out_channels + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.normalized = normalized + self.activation = getattr(torch.nn, activation)(**activation_params) + self.spec_transform = torchaudio.transforms.Spectrogram( + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window_fn=torch.hann_window, + normalized=self.normalized, + center=False, + pad_mode=None, + power=None, + ) + spec_channels = 2 * self.in_channels + self.convs = nn.ModuleList() + self.convs.append( + NormConv2d( + spec_channels, + self.filters, + kernel_size=kernel_size, + padding=get_2d_padding(kernel_size), + ) + ) + in_chs = min(filters_scale * self.filters, max_filters) + for i, dilation in enumerate(dilations): + out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters) + self.convs.append( + NormConv2d( + in_chs, + out_chs, + kernel_size=kernel_size, + stride=stride, + dilation=(dilation, 1), + padding=get_2d_padding(kernel_size, (dilation, 1)), + norm=norm, + ) + ) + in_chs = out_chs + out_chs = min( + (filters_scale ** (len(dilations) + 1)) * self.filters, max_filters + ) + self.convs.append( + NormConv2d( + in_chs, + out_chs, + kernel_size=(kernel_size[0], kernel_size[0]), + padding=get_2d_padding((kernel_size[0], kernel_size[0])), + norm=norm, + ) + ) + self.conv_post = NormConv2d( + out_chs, + self.out_channels, + kernel_size=(kernel_size[0], kernel_size[0]), + padding=get_2d_padding((kernel_size[0], kernel_size[0])), + norm=norm, + ) + + def forward(self, x: torch.Tensor): + """Discriminator STFT Module is the sub module of MultiScaleSTFTDiscriminator. + + Args: + x (torch.Tensor): input tensor of shape [B, 1, Time] + + Returns: + z: z is the output of the last convolutional layer of shape + fmap: fmap is the list of feature maps of every convolutional layer of shape + """ + fmap = [] + z = self.spec_transform(x) # [B, 2, Freq, Frames, 2] + z = torch.cat([z.real, z.imag], dim=1) + z = rearrange(z, "b c w t -> b c t w") + for i, layer in enumerate(self.convs): + z = layer(z) + + z = self.activation(z) + fmap.append(z) + z = self.conv_post(z) + return z, fmap + + +class MultiScaleSTFTDiscriminator(nn.Module): + """Multi-Scale STFT (MS-STFT) discriminator. + Args: + filters (int): Number of filters in convolutions + in_channels (int): Number of input channels. Default: 1 + out_channels (int): Number of output channels. 
Default: 1 + n_ffts (Sequence[int]): Size of FFT for each scale + hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale + win_lengths (Sequence[int]): Window size for each scale + **kwargs: additional args for STFTDiscriminator + """ + + def __init__( + self, + cfg, + in_channels: int = 1, + out_channels: int = 1, + n_ffts: tp.List[int] = [1024, 2048, 512], + hop_lengths: tp.List[int] = [256, 512, 256], + win_lengths: tp.List[int] = [1024, 2048, 512], + **kwargs, + ): + self.cfg = cfg + super().__init__() + assert len(n_ffts) == len(hop_lengths) == len(win_lengths) + self.discriminators = nn.ModuleList( + [ + DiscriminatorSTFT( + filters=self.cfg.model.msstftd.filters, + in_channels=in_channels, + out_channels=out_channels, + n_fft=n_ffts[i], + win_length=win_lengths[i], + hop_length=hop_lengths[i], + **kwargs, + ) + for i in range(len(n_ffts)) + ] + ) + self.num_discriminators = len(self.discriminators) + + def forward(self, y, y_hat) -> DiscriminatorOutput: + """Multi-Scale STFT (MS-STFT) discriminator. + + Args: + x (torch.Tensor): input waveform + + Returns: + logits: list of every discriminator's output + fmaps: list of every discriminator's feature maps, + each feature maps is a list of Discriminator STFT's every layer + """ + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for disc in self.discriminators: + y_d_r, fmap_r = disc(y) + y_d_g, fmap_g = disc(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs diff --git a/models/vocoders/gan/gan_vocoder_dataset.py b/models/vocoders/gan/gan_vocoder_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5bf87c371647a44fb5bcae33701eda65616e5fd7 --- /dev/null +++ b/models/vocoders/gan/gan_vocoder_dataset.py @@ -0,0 +1,205 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
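As a small usage sketch for the MultiScaleSTFTDiscriminator defined in msstftd.py above: it only reads model.msstftd.filters from the config, so a SimpleNamespace stand-in is enough to exercise it. The filter count of 32 and the one-second waveforms are illustrative assumptions rather than project defaults.

from types import SimpleNamespace

import torch

from models.vocoders.gan.discriminator.msstftd import MultiScaleSTFTDiscriminator

# Stand-in config exposing only the field the discriminator actually uses
cfg = SimpleNamespace(model=SimpleNamespace(msstftd=SimpleNamespace(filters=32)))
msstftd = MultiScaleSTFTDiscriminator(cfg)

y = torch.randn(2, 1, 16000)       # reference waveform  [B, 1, T]
y_hat = torch.randn(2, 1, 16000)   # generated waveform  [B, 1, T]

# One logit tensor and one feature-map list per STFT scale (three with the default n_ffts)
y_d_rs, y_d_gs, fmap_rs, fmap_gs = msstftd(y, y_hat)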
+ +import torch +import random + +import numpy as np + +from torch.nn import functional as F + +from torch.nn.utils.rnn import pad_sequence +from utils.data_utils import * +from models.vocoders.vocoder_dataset import VocoderDataset + + +class GANVocoderDataset(VocoderDataset): + def __init__(self, cfg, dataset, is_valid=False): + """ + Args: + cfg: config + dataset: dataset name + is_valid: whether to use train or valid dataset + """ + super().__init__(cfg, dataset, is_valid) + + eval_index = random.randint(0, len(self.metadata) - 1) + eval_utt_info = self.metadata[eval_index] + eval_utt = "{}_{}".format(eval_utt_info["Dataset"], eval_utt_info["Uid"]) + self.eval_audio = np.load(self.utt2audio_path[eval_utt]) + if cfg.preprocess.use_mel: + self.eval_mel = np.load(self.utt2mel_path[eval_utt]) + if cfg.preprocess.use_frame_pitch: + self.eval_pitch = np.load(self.utt2frame_pitch_path[eval_utt]) + + def __getitem__(self, index): + utt_info = self.metadata[index] + + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + single_feature = dict() + + if self.cfg.preprocess.use_mel: + mel = np.load(self.utt2mel_path[utt]) + assert mel.shape[0] == self.cfg.preprocess.n_mel + + if "target_len" not in single_feature.keys(): + single_feature["target_len"] = mel.shape[1] + + if single_feature["target_len"] <= self.cfg.preprocess.cut_mel_frame: + mel = np.pad( + mel, + ((0, 0), (0, self.cfg.preprocess.cut_mel_frame - mel.shape[-1])), + mode="constant", + ) + else: + if "start" not in single_feature.keys(): + start = random.randint( + 0, mel.shape[-1] - self.cfg.preprocess.cut_mel_frame + ) + end = start + self.cfg.preprocess.cut_mel_frame + single_feature["start"] = start + single_feature["end"] = end + mel = mel[:, single_feature["start"] : single_feature["end"]] + single_feature["mel"] = mel + + if self.cfg.preprocess.use_frame_pitch: + frame_pitch = np.load(self.utt2frame_pitch_path[utt]) + if "target_len" not in single_feature.keys(): + single_feature["target_len"] = len(frame_pitch) + aligned_frame_pitch = align_length( + frame_pitch, single_feature["target_len"] + ) + + if single_feature["target_len"] <= self.cfg.preprocess.cut_mel_frame: + aligned_frame_pitch = np.pad( + aligned_frame_pitch, + ( + ( + 0, + self.cfg.preprocess.cut_mel_frame + * self.cfg.preprocess.hop_size + - audio.shape[-1], + ) + ), + mode="constant", + ) + else: + if "start" not in single_feature.keys(): + start = random.randint( + 0, + aligned_frame_pitch.shape[-1] + - self.cfg.preprocess.cut_mel_frame, + ) + end = start + self.cfg.preprocess.cut_mel_frame + single_feature["start"] = start + single_feature["end"] = end + aligned_frame_pitch = aligned_frame_pitch[ + single_feature["start"] : single_feature["end"] + ] + single_feature["frame_pitch"] = aligned_frame_pitch + + if self.cfg.preprocess.use_audio: + audio = np.load(self.utt2audio_path[utt]) + + assert "target_len" in single_feature.keys() + + if ( + audio.shape[-1] + <= self.cfg.preprocess.cut_mel_frame * self.cfg.preprocess.hop_size + ): + audio = np.pad( + audio, + ( + ( + 0, + self.cfg.preprocess.cut_mel_frame + * self.cfg.preprocess.hop_size + - audio.shape[-1], + ) + ), + mode="constant", + ) + else: + if "start" not in single_feature.keys(): + audio = audio[ + 0 : self.cfg.preprocess.cut_mel_frame + * self.cfg.preprocess.hop_size + ] + else: + audio = audio[ + single_feature["start"] + * self.cfg.preprocess.hop_size : single_feature["end"] + * self.cfg.preprocess.hop_size, + ] + single_feature["audio"] = audio + + if 
self.cfg.preprocess.use_amplitude_phase: + logamp = np.load(self.utt2logamp_path[utt]) + pha = np.load(self.utt2pha_path[utt]) + rea = np.load(self.utt2rea_path[utt]) + imag = np.load(self.utt2imag_path[utt]) + + assert "target_len" in single_feature.keys() + + if single_feature["target_len"] <= self.cfg.preprocess.cut_mel_frame: + logamp = np.pad( + logamp, + ((0, 0), (0, self.cfg.preprocess.cut_mel_frame - mel.shape[-1])), + mode="constant", + ) + pha = np.pad( + pha, + ((0, 0), (0, self.cfg.preprocess.cut_mel_frame - mel.shape[-1])), + mode="constant", + ) + rea = np.pad( + rea, + ((0, 0), (0, self.cfg.preprocess.cut_mel_frame - mel.shape[-1])), + mode="constant", + ) + imag = np.pad( + imag, + ((0, 0), (0, self.cfg.preprocess.cut_mel_frame - mel.shape[-1])), + mode="constant", + ) + else: + logamp = logamp[:, single_feature["start"] : single_feature["end"]] + pha = pha[:, single_feature["start"] : single_feature["end"]] + rea = rea[:, single_feature["start"] : single_feature["end"]] + imag = imag[:, single_feature["start"] : single_feature["end"]] + single_feature["logamp"] = logamp + single_feature["pha"] = pha + single_feature["rea"] = rea + single_feature["imag"] = imag + + return single_feature + + +class GANVocoderCollator(object): + """Zero-pads model inputs and targets based on number of frames per step""" + + def __init__(self, cfg): + self.cfg = cfg + + def __call__(self, batch): + packed_batch_features = dict() + + # mel: [b, n_mels, frame] + # frame_pitch: [b, frame] + # audios: [b, frame * hop_size] + + for key in batch[0].keys(): + if key in ["target_len", "start", "end"]: + continue + else: + values = [torch.from_numpy(b[key]) for b in batch] + packed_batch_features[key] = pad_sequence( + values, batch_first=True, padding_value=0 + ) + + return packed_batch_features diff --git a/models/vocoders/gan/gan_vocoder_inference.py b/models/vocoders/gan/gan_vocoder_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..cb69631662dedf4fc73a29f675f0a4bc361b03ec --- /dev/null +++ b/models/vocoders/gan/gan_vocoder_inference.py @@ -0,0 +1,110 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
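The GANVocoderDataset above crops training examples so that frame-level features and the waveform stay aligned: a random window of cut_mel_frame mel frames is paired with the audio slice covering exactly those frames. The toy sizes below (80 mel bins, hop_size of 256, cut_mel_frame of 32, a 100-frame utterance) are assumptions used only to demonstrate the indexing.

import numpy as np

hop_size, cut_mel_frame = 256, 32                 # assumed preprocessing values
mel = np.random.randn(80, 100)                    # [n_mel, frames]
audio = np.random.randn(100 * hop_size)           # waveform covering the same frames

start = np.random.randint(0, mel.shape[-1] - cut_mel_frame)
end = start + cut_mel_frame

mel_crop = mel[:, start:end]                            # [80, cut_mel_frame]
audio_crop = audio[start * hop_size : end * hop_size]   # [cut_mel_frame * hop_size]

# The cropped waveform covers exactly the cropped mel frames
assert mel_crop.shape[-1] * hop_size == audio_crop.shape[-1]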
+ +import torch + +from utils.util import pad_mels_to_tensors, pad_f0_to_tensors + + +def vocoder_inference(cfg, model, mels, f0s=None, device=None, fast_inference=False): + """Inference the vocoder + Args: + mels: A tensor of mel-specs with the shape (batch_size, num_mels, frames) + Returns: + audios: A tensor of audios with the shape (batch_size, seq_len) + """ + model.eval() + + with torch.no_grad(): + mels = mels.to(device) + if f0s != None: + f0s = f0s.to(device) + + if f0s == None and not cfg.preprocess.extract_amplitude_phase: + output = model.forward(mels) + elif cfg.preprocess.extract_amplitude_phase: + ( + _, + _, + _, + _, + output, + ) = model.forward(mels) + else: + output = model.forward(mels, f0s) + + return output.squeeze(1).detach().cpu() + + +def synthesis_audios(cfg, model, mels, f0s=None, batch_size=None, fast_inference=False): + """Inference the vocoder + Args: + mels: A list of mel-specs + Returns: + audios: A list of audios + """ + # Get the device + device = next(model.parameters()).device + + audios = [] + + # Pad the given list into tensors + mel_batches, mel_frames = pad_mels_to_tensors(mels, batch_size) + if f0s != None: + f0_batches = pad_f0_to_tensors(f0s, batch_size) + + if f0s == None: + for mel_batch, mel_frame in zip(mel_batches, mel_frames): + for i in range(mel_batch.shape[0]): + mel = mel_batch[i] + frame = mel_frame[i] + audio = vocoder_inference( + cfg, + model, + mel.unsqueeze(0), + device=device, + fast_inference=fast_inference, + ).squeeze(0) + + # # Apply fade_out to make the sound more natural + # fade_out = torch.linspace( + # 1, 0, steps=20 * model.cfg.preprocess.hop_size + # ).cpu() + + # calculate the audio length + audio_length = frame * model.cfg.preprocess.hop_size + audio = audio[:audio_length] + + # audio[-20 * model.cfg.preprocess.hop_size :] *= fade_out + + audios.append(audio) + else: + for mel_batch, f0_batch, mel_frame in zip(mel_batches, f0_batches, mel_frames): + for i in range(mel_batch.shape[0]): + mel = mel_batch[i] + f0 = f0_batch[i] + frame = mel_frame[i] + audio = vocoder_inference( + cfg, + model, + mel.unsqueeze(0), + f0s=f0.unsqueeze(0), + device=device, + fast_inference=fast_inference, + ).squeeze(0) + + # # Apply fade_out to make the sound more natural + # fade_out = torch.linspace( + # 1, 0, steps=20 * model.cfg.preprocess.hop_size + # ).cpu() + + # calculate the audio length + audio_length = frame * model.cfg.preprocess.hop_length + audio = audio[:audio_length] + + # audio[-20 * model.cfg.preprocess.hop_size :] *= fade_out + + audios.append(audio) + return audios diff --git a/models/vocoders/gan/gan_vocoder_trainer.py b/models/vocoders/gan/gan_vocoder_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..2fb9c8f03a7de14d0162bfd671d33b76890293a5 --- /dev/null +++ b/models/vocoders/gan/gan_vocoder_trainer.py @@ -0,0 +1,1112 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
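synthesis_audios in gan_vocoder_inference.py above pads variable-length mel spectrograms into batches, runs the vocoder, and trims each output back to frame_count * hop_size samples. The sketch below reproduces only that pad-and-trim bookkeeping with pad_sequence and a dummy vocoder output; it is not the actual pad_mels_to_tensors implementation from utils.util, and the mel size, hop size, and frame counts are assumed.

import torch
from torch.nn.utils.rnn import pad_sequence

hop_size = 256                                         # assumed hop size
mels = [torch.randn(80, n) for n in (95, 120, 64)]     # variable-length [n_mel, frames]
frames = [m.shape[-1] for m in mels]                   # keep the true frame counts

# Pad along the frame axis into a single [B, n_mel, max_frames] batch
batch = pad_sequence([m.transpose(0, 1) for m in mels], batch_first=True).transpose(1, 2)

# Dummy vocoder output: hop_size waveform samples per (possibly padded) frame
fake_audio = torch.randn(batch.shape[0], batch.shape[-1] * hop_size)

# Trim each utterance back to the samples that correspond to its real frames
audios = [fake_audio[i, : frames[i] * hop_size] for i in range(len(mels))]
assert [a.shape[0] // hop_size for a in audios] == frames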
+ +import os +import sys +import time +import torch +import json +import itertools +import accelerate +import torch.distributed as dist +import torch.nn.functional as F +from tqdm import tqdm +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch.utils.tensorboard import SummaryWriter + +from torch.optim import AdamW +from torch.optim.lr_scheduler import ExponentialLR + +from librosa.filters import mel as librosa_mel_fn + +from accelerate.logging import get_logger +from pathlib import Path + +from utils.io import save_audio +from utils.data_utils import * +from utils.util import ( + Logger, + ValueWindow, + remove_older_ckpt, + set_all_random_seed, + save_config, +) +from utils.mel import extract_mel_features +from models.vocoders.vocoder_trainer import VocoderTrainer +from models.vocoders.gan.gan_vocoder_dataset import ( + GANVocoderDataset, + GANVocoderCollator, +) + +from models.vocoders.gan.generator.bigvgan import BigVGAN +from models.vocoders.gan.generator.hifigan import HiFiGAN +from models.vocoders.gan.generator.melgan import MelGAN +from models.vocoders.gan.generator.nsfhifigan import NSFHiFiGAN +from models.vocoders.gan.generator.apnet import APNet + +from models.vocoders.gan.discriminator.mpd import MultiPeriodDiscriminator +from models.vocoders.gan.discriminator.mrd import MultiResolutionDiscriminator +from models.vocoders.gan.discriminator.mssbcqtd import MultiScaleSubbandCQTDiscriminator +from models.vocoders.gan.discriminator.msd import MultiScaleDiscriminator +from models.vocoders.gan.discriminator.msstftd import MultiScaleSTFTDiscriminator + +from models.vocoders.gan.gan_vocoder_inference import vocoder_inference + +supported_generators = { + "bigvgan": BigVGAN, + "hifigan": HiFiGAN, + "melgan": MelGAN, + "nsfhifigan": NSFHiFiGAN, + "apnet": APNet, +} + +supported_discriminators = { + "mpd": MultiPeriodDiscriminator, + "msd": MultiScaleDiscriminator, + "mrd": MultiResolutionDiscriminator, + "msstftd": MultiScaleSTFTDiscriminator, + "mssbcqtd": MultiScaleSubbandCQTDiscriminator, +} + + +class GANVocoderTrainer(VocoderTrainer): + def __init__(self, args, cfg): + super().__init__() + + self.args = args + self.cfg = cfg + + cfg.exp_name = args.exp_name + + # Init accelerator + self._init_accelerator() + self.accelerator.wait_for_everyone() + + # Init logger + with self.accelerator.main_process_first(): + self.logger = get_logger(args.exp_name, log_level=args.log_level) + + self.logger.info("=" * 56) + self.logger.info("||\t\t" + "New training process started." 
+ "\t\t||") + self.logger.info("=" * 56) + self.logger.info("\n") + self.logger.debug(f"Using {args.log_level.upper()} logging level.") + self.logger.info(f"Experiment name: {args.exp_name}") + self.logger.info(f"Experiment directory: {self.exp_dir}") + self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint") + if self.accelerator.is_main_process: + os.makedirs(self.checkpoint_dir, exist_ok=True) + self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}") + + # Init training status + self.batch_count: int = 0 + self.step: int = 0 + self.epoch: int = 0 + + self.max_epoch = ( + self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf") + ) + self.logger.info( + "Max epoch: {}".format( + self.max_epoch if self.max_epoch < float("inf") else "Unlimited" + ) + ) + + # Check potential erorrs + if self.accelerator.is_main_process: + self._check_basic_configs() + self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride + self.checkpoints_path = [ + [] for _ in range(len(self.save_checkpoint_stride)) + ] + self.run_eval = self.cfg.train.run_eval + + # Set random seed + with self.accelerator.main_process_first(): + start = time.monotonic_ns() + self._set_random_seed(self.cfg.train.random_seed) + end = time.monotonic_ns() + self.logger.debug( + f"Setting random seed done in {(end - start) / 1e6:.2f}ms" + ) + self.logger.debug(f"Random seed: {self.cfg.train.random_seed}") + + # Build dataloader + with self.accelerator.main_process_first(): + self.logger.info("Building dataset...") + start = time.monotonic_ns() + self.train_dataloader, self.valid_dataloader = self._build_dataloader() + end = time.monotonic_ns() + self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms") + + # Build model + with self.accelerator.main_process_first(): + self.logger.info("Building model...") + start = time.monotonic_ns() + self.generator, self.discriminators = self._build_model() + end = time.monotonic_ns() + self.logger.debug(self.generator) + for _, discriminator in self.discriminators.items(): + self.logger.debug(discriminator) + self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms") + self.logger.info(f"Model parameters: {self._count_parameters()/1e6:.2f}M") + + # Build optimizers and schedulers + with self.accelerator.main_process_first(): + self.logger.info("Building optimizer and scheduler...") + start = time.monotonic_ns() + ( + self.generator_optimizer, + self.discriminator_optimizer, + ) = self._build_optimizer() + ( + self.generator_scheduler, + self.discriminator_scheduler, + ) = self._build_scheduler() + end = time.monotonic_ns() + self.logger.info( + f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms" + ) + + # Accelerator preparing + self.logger.info("Initializing accelerate...") + start = time.monotonic_ns() + ( + self.train_dataloader, + self.valid_dataloader, + self.generator, + self.generator_optimizer, + self.discriminator_optimizer, + self.generator_scheduler, + self.discriminator_scheduler, + ) = self.accelerator.prepare( + self.train_dataloader, + self.valid_dataloader, + self.generator, + self.generator_optimizer, + self.discriminator_optimizer, + self.generator_scheduler, + self.discriminator_scheduler, + ) + for key, discriminator in self.discriminators.items(): + self.discriminators[key] = self.accelerator.prepare_model(discriminator) + end = time.monotonic_ns() + self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms") + + # Build criterions + with 
self.accelerator.main_process_first(): + self.logger.info("Building criterion...") + start = time.monotonic_ns() + self.criterions = self._build_criterion() + end = time.monotonic_ns() + self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms") + + # Resume checkpoints + with self.accelerator.main_process_first(): + if args.resume_type: + self.logger.info("Resuming from checkpoint...") + start = time.monotonic_ns() + ckpt_path = Path(args.checkpoint) + if self._is_valid_pattern(ckpt_path.parts[-1]): + ckpt_path = self._load_model( + None, args.checkpoint, args.resume_type + ) + else: + ckpt_path = self._load_model( + args.checkpoint, resume_type=args.resume_type + ) + end = time.monotonic_ns() + self.logger.info( + f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms" + ) + self.checkpoints_path = json.load( + open(os.path.join(ckpt_path, "ckpts.json"), "r") + ) + + self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint") + if self.accelerator.is_main_process: + os.makedirs(self.checkpoint_dir, exist_ok=True) + self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}") + + # Save config + self.config_save_path = os.path.join(self.exp_dir, "args.json") + + def _build_dataset(self): + return GANVocoderDataset, GANVocoderCollator + + def _build_criterion(self): + class feature_criterion(torch.nn.Module): + def __init__(self, cfg): + super(feature_criterion, self).__init__() + self.cfg = cfg + self.l1Loss = torch.nn.L1Loss(reduction="mean") + self.l2Loss = torch.nn.MSELoss(reduction="mean") + self.relu = torch.nn.ReLU() + + def __call__(self, fmap_r, fmap_g): + loss = 0 + + if self.cfg.model.generator in [ + "hifigan", + "nsfhifigan", + "bigvgan", + "apnet", + ]: + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + loss = loss * 2 + elif self.cfg.model.generator in ["melgan"]: + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += self.l1Loss(rl, gl) + + loss = loss * 10 + elif self.cfg.model.generator in ["codec"]: + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss = loss + self.l1Loss(rl, gl) / torch.mean( + torch.abs(rl) + ) + + KL_scale = len(fmap_r) * len(fmap_r[0]) + + loss = 3 * loss / KL_scale + else: + raise NotImplementedError + + return loss + + class discriminator_criterion(torch.nn.Module): + def __init__(self, cfg): + super(discriminator_criterion, self).__init__() + self.cfg = cfg + self.l1Loss = torch.nn.L1Loss(reduction="mean") + self.l2Loss = torch.nn.MSELoss(reduction="mean") + self.relu = torch.nn.ReLU() + + def __call__(self, disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + + if self.cfg.model.generator in [ + "hifigan", + "nsfhifigan", + "bigvgan", + "apnet", + ]: + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg**2) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + elif self.cfg.model.generator in ["melgan"]: + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean(self.relu(1 - dr)) + g_loss = torch.mean(self.relu(1 + dg)) + loss = loss + r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + elif self.cfg.model.generator in ["codec"]: + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean(self.relu(1 - dr)) + g_loss = torch.mean(self.relu(1 + dg)) + loss = loss + r_loss + g_loss + 
r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + loss = loss / len(disc_real_outputs) + else: + raise NotImplementedError + + return loss, r_losses, g_losses + + class generator_criterion(torch.nn.Module): + def __init__(self, cfg): + super(generator_criterion, self).__init__() + self.cfg = cfg + self.l1Loss = torch.nn.L1Loss(reduction="mean") + self.l2Loss = torch.nn.MSELoss(reduction="mean") + self.relu = torch.nn.ReLU() + + def __call__(self, disc_outputs): + loss = 0 + gen_losses = [] + + if self.cfg.model.generator in [ + "hifigan", + "nsfhifigan", + "bigvgan", + "apnet", + ]: + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + elif self.cfg.model.generator in ["melgan"]: + for dg in disc_outputs: + l = -torch.mean(dg) + gen_losses.append(l) + loss += l + elif self.cfg.model.generator in ["codec"]: + for dg in disc_outputs: + l = torch.mean(self.relu(1 - dg)) / len(disc_outputs) + gen_losses.append(l) + loss += l + else: + raise NotImplementedError + + return loss, gen_losses + + class mel_criterion(torch.nn.Module): + def __init__(self, cfg): + super(mel_criterion, self).__init__() + self.cfg = cfg + self.l1Loss = torch.nn.L1Loss(reduction="mean") + self.l2Loss = torch.nn.MSELoss(reduction="mean") + self.relu = torch.nn.ReLU() + + def __call__(self, y_gt, y_pred): + loss = 0 + + if self.cfg.model.generator in [ + "hifigan", + "nsfhifigan", + "bigvgan", + "melgan", + "codec", + "apnet", + ]: + y_gt_mel = extract_mel_features(y_gt, self.cfg.preprocess) + y_pred_mel = extract_mel_features( + y_pred.squeeze(1), self.cfg.preprocess + ) + + loss = self.l1Loss(y_gt_mel, y_pred_mel) * 45 + else: + raise NotImplementedError + + return loss + + class wav_criterion(torch.nn.Module): + def __init__(self, cfg): + super(wav_criterion, self).__init__() + self.cfg = cfg + self.l1Loss = torch.nn.L1Loss(reduction="mean") + self.l2Loss = torch.nn.MSELoss(reduction="mean") + self.relu = torch.nn.ReLU() + + def __call__(self, y_gt, y_pred): + loss = 0 + + if self.cfg.model.generator in [ + "hifigan", + "nsfhifigan", + "bigvgan", + "apnet", + ]: + loss = self.l2Loss(y_gt, y_pred.squeeze(1)) * 100 + elif self.cfg.model.generator in ["melgan"]: + loss = self.l1Loss(y_gt, y_pred.squeeze(1)) / 10 + elif self.cfg.model.generator in ["codec"]: + loss = self.l1Loss(y_gt, y_pred.squeeze(1)) + self.l2Loss( + y_gt, y_pred.squeeze(1) + ) + loss /= 10 + else: + raise NotImplementedError + + return loss + + class phase_criterion(torch.nn.Module): + def __init__(self, cfg): + super(phase_criterion, self).__init__() + self.cfg = cfg + self.l1Loss = torch.nn.L1Loss(reduction="mean") + self.l2Loss = torch.nn.MSELoss(reduction="mean") + self.relu = torch.nn.ReLU() + + def __call__(self, phase_gt, phase_pred): + n_fft = self.cfg.preprocess.n_fft + frames = phase_gt.size()[-1] + + GD_matrix = ( + torch.triu(torch.ones(n_fft // 2 + 1, n_fft // 2 + 1), diagonal=1) + - torch.triu(torch.ones(n_fft // 2 + 1, n_fft // 2 + 1), diagonal=2) + - torch.eye(n_fft // 2 + 1) + ) + GD_matrix = GD_matrix.to(phase_pred.device) + + GD_r = torch.matmul(phase_gt.permute(0, 2, 1), GD_matrix) + GD_g = torch.matmul(phase_pred.permute(0, 2, 1), GD_matrix) + + PTD_matrix = ( + torch.triu(torch.ones(frames, frames), diagonal=1) + - torch.triu(torch.ones(frames, frames), diagonal=2) + - torch.eye(frames) + ) + PTD_matrix = PTD_matrix.to(phase_pred.device) + + PTD_r = torch.matmul(phase_gt, PTD_matrix) + PTD_g = torch.matmul(phase_pred, PTD_matrix) + + IP_loss = 
torch.mean(-torch.cos(phase_gt - phase_pred)) + GD_loss = torch.mean(-torch.cos(GD_r - GD_g)) + PTD_loss = torch.mean(-torch.cos(PTD_r - PTD_g)) + + return 100 * (IP_loss + GD_loss + PTD_loss) + + class amplitude_criterion(torch.nn.Module): + def __init__(self, cfg): + super(amplitude_criterion, self).__init__() + self.cfg = cfg + self.l1Loss = torch.nn.L1Loss(reduction="mean") + self.l2Loss = torch.nn.MSELoss(reduction="mean") + self.relu = torch.nn.ReLU() + + def __call__(self, log_amplitude_gt, log_amplitude_pred): + amplitude_loss = self.l2Loss(log_amplitude_gt, log_amplitude_pred) + + return 45 * amplitude_loss + + class consistency_criterion(torch.nn.Module): + def __init__(self, cfg): + super(consistency_criterion, self).__init__() + self.cfg = cfg + self.l1Loss = torch.nn.L1Loss(reduction="mean") + self.l2Loss = torch.nn.MSELoss(reduction="mean") + self.relu = torch.nn.ReLU() + + def __call__( + self, + rea_gt, + rea_pred, + rea_pred_final, + imag_gt, + imag_pred, + imag_pred_final, + ): + C_loss = torch.mean( + torch.mean( + (rea_pred - rea_pred_final) ** 2 + + (imag_pred - imag_pred_final) ** 2, + (1, 2), + ) + ) + + L_R = self.l1Loss(rea_gt, rea_pred) + L_I = self.l1Loss(imag_gt, imag_pred) + + return 20 * (C_loss + 2.25 * (L_R + L_I)) + + criterions = dict() + for key in self.cfg.train.criterions: + if key == "feature": + criterions["feature"] = feature_criterion(self.cfg) + elif key == "discriminator": + criterions["discriminator"] = discriminator_criterion(self.cfg) + elif key == "generator": + criterions["generator"] = generator_criterion(self.cfg) + elif key == "mel": + criterions["mel"] = mel_criterion(self.cfg) + elif key == "wav": + criterions["wav"] = wav_criterion(self.cfg) + elif key == "phase": + criterions["phase"] = phase_criterion(self.cfg) + elif key == "amplitude": + criterions["amplitude"] = amplitude_criterion(self.cfg) + elif key == "consistency": + criterions["consistency"] = consistency_criterion(self.cfg) + else: + raise NotImplementedError + + return criterions + + def _build_model(self): + generator = supported_generators[self.cfg.model.generator](self.cfg) + discriminators = dict() + for key in self.cfg.model.discriminators: + discriminators[key] = supported_discriminators[key](self.cfg) + + return generator, discriminators + + def _build_optimizer(self): + optimizer_params_generator = [dict(params=self.generator.parameters())] + generator_optimizer = AdamW( + optimizer_params_generator, + lr=self.cfg.train.adamw.lr, + betas=(self.cfg.train.adamw.adam_b1, self.cfg.train.adamw.adam_b2), + ) + + optimizer_params_discriminator = [] + for discriminator in self.discriminators.keys(): + optimizer_params_discriminator.append( + dict(params=self.discriminators[discriminator].parameters()) + ) + discriminator_optimizer = AdamW( + optimizer_params_discriminator, + lr=self.cfg.train.adamw.lr, + betas=(self.cfg.train.adamw.adam_b1, self.cfg.train.adamw.adam_b2), + ) + + return generator_optimizer, discriminator_optimizer + + def _build_scheduler(self): + discriminator_scheduler = ExponentialLR( + self.discriminator_optimizer, + gamma=self.cfg.train.exponential_lr.lr_decay, + last_epoch=self.epoch - 1, + ) + + generator_scheduler = ExponentialLR( + self.generator_optimizer, + gamma=self.cfg.train.exponential_lr.lr_decay, + last_epoch=self.epoch - 1, + ) + + return generator_scheduler, discriminator_scheduler + + def train_loop(self): + """Training process""" + self.accelerator.wait_for_everyone() + + # Dump config + if self.accelerator.is_main_process: + 
self._dump_cfg(self.config_save_path) + self.generator.train() + for key in self.discriminators.keys(): + self.discriminators[key].train() + self.generator_optimizer.zero_grad() + self.discriminator_optimizer.zero_grad() + + # Sync and start training + self.accelerator.wait_for_everyone() + while self.epoch < self.max_epoch: + self.logger.info("\n") + self.logger.info("-" * 32) + self.logger.info("Epoch {}: ".format(self.epoch)) + + # Train and Validate + train_total_loss, train_losses = self._train_epoch() + for key, loss in train_losses.items(): + self.logger.info(" |- Train/{} Loss: {:.6f}".format(key, loss)) + self.accelerator.log( + {"Epoch/Train {} Loss".format(key): loss}, + step=self.epoch, + ) + valid_total_loss, valid_losses = self._valid_epoch() + for key, loss in valid_losses.items(): + self.logger.info(" |- Valid/{} Loss: {:.6f}".format(key, loss)) + self.accelerator.log( + {"Epoch/Valid {} Loss".format(key): loss}, + step=self.epoch, + ) + self.accelerator.log( + { + "Epoch/Train Total Loss": train_total_loss, + "Epoch/Valid Total Loss": valid_total_loss, + }, + step=self.epoch, + ) + + # Update scheduler + self.accelerator.wait_for_everyone() + self.generator_scheduler.step() + self.discriminator_scheduler.step() + + # Check save checkpoint interval + run_eval = False + if self.accelerator.is_main_process: + save_checkpoint = False + for i, num in enumerate(self.save_checkpoint_stride): + if self.epoch % num == 0: + save_checkpoint = True + run_eval |= self.run_eval[i] + + # Save checkpoints + self.accelerator.wait_for_everyone() + if self.accelerator.is_main_process and save_checkpoint: + path = os.path.join( + self.checkpoint_dir, + "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format( + self.epoch, self.step, valid_total_loss + ), + ) + self.accelerator.save_state(path) + json.dump( + self.checkpoints_path, + open(os.path.join(path, "ckpts.json"), "w"), + ensure_ascii=False, + indent=4, + ) + + # Save eval audios + self.accelerator.wait_for_everyone() + if self.accelerator.is_main_process and run_eval: + for i in range(len(self.valid_dataloader.dataset.eval_audios)): + if self.cfg.preprocess.use_frame_pitch: + eval_audio = self._inference( + self.valid_dataloader.dataset.eval_mels[i], + eval_pitch=self.valid_dataloader.dataset.eval_pitchs[i], + use_pitch=True, + ) + else: + eval_audio = self._inference( + self.valid_dataloader.dataset.eval_mels[i] + ) + path = os.path.join( + self.checkpoint_dir, + "epoch-{:04d}_step-{:07d}_loss-{:.6f}_eval_audio_{}.wav".format( + self.epoch, + self.step, + valid_total_loss, + self.valid_dataloader.dataset.eval_dataset_names[i], + ), + ) + path_gt = os.path.join( + self.checkpoint_dir, + "epoch-{:04d}_step-{:07d}_loss-{:.6f}_eval_audio_{}_gt.wav".format( + self.epoch, + self.step, + valid_total_loss, + self.valid_dataloader.dataset.eval_dataset_names[i], + ), + ) + save_audio(path, eval_audio, self.cfg.preprocess.sample_rate) + save_audio( + path_gt, + self.valid_dataloader.dataset.eval_audios[i], + self.cfg.preprocess.sample_rate, + ) + + self.accelerator.wait_for_everyone() + + self.epoch += 1 + + # Finish training + self.accelerator.wait_for_everyone() + path = os.path.join( + self.checkpoint_dir, + "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format( + self.epoch, self.step, valid_total_loss + ), + ) + self.accelerator.save_state(path) + + def _train_epoch(self): + """Training epoch. Should return average loss of a batch (sample) over + one epoch. See ``train_loop`` for usage. 
+ """ + self.generator.train() + for key, _ in self.discriminators.items(): + self.discriminators[key].train() + + epoch_losses: dict = {} + epoch_total_loss: int = 0 + + for batch in tqdm( + self.train_dataloader, + desc=f"Training Epoch {self.epoch}", + unit="batch", + colour="GREEN", + leave=False, + dynamic_ncols=True, + smoothing=0.04, + disable=not self.accelerator.is_main_process, + ): + # Get losses + total_loss, losses = self._train_step(batch) + self.batch_count += 1 + + # Log info + if self.batch_count % self.cfg.train.gradient_accumulation_step == 0: + self.accelerator.log( + { + "Step/Generator Learning Rate": self.generator_optimizer.param_groups[ + 0 + ][ + "lr" + ], + "Step/Discriminator Learning Rate": self.discriminator_optimizer.param_groups[ + 0 + ][ + "lr" + ], + }, + step=self.step, + ) + for key, _ in losses.items(): + self.accelerator.log( + { + "Step/Train {} Loss".format(key): losses[key], + }, + step=self.step, + ) + + if not epoch_losses: + epoch_losses = losses + else: + for key, value in losses.items(): + epoch_losses[key] += value + epoch_total_loss += total_loss + self.step += 1 + + # Get and log total losses + self.accelerator.wait_for_everyone() + epoch_total_loss = ( + epoch_total_loss + / len(self.train_dataloader) + * self.cfg.train.gradient_accumulation_step + ) + for key in epoch_losses.keys(): + epoch_losses[key] = ( + epoch_losses[key] + / len(self.train_dataloader) + * self.cfg.train.gradient_accumulation_step + ) + return epoch_total_loss, epoch_losses + + def _train_step(self, data): + """Training forward step. Should return average loss of a sample over + one batch. Provoke ``_forward_step`` is recommended except for special case. + See ``_train_epoch`` for usage. + """ + # Init losses + train_losses = {} + total_loss = 0 + + generator_losses = {} + generator_total_loss = 0 + discriminator_losses = {} + discriminator_total_loss = 0 + + # Use input feature to get predictions + mel_input = data["mel"] + audio_gt = data["audio"] + + if self.cfg.preprocess.extract_amplitude_phase: + logamp_gt = data["logamp"] + pha_gt = data["pha"] + rea_gt = data["rea"] + imag_gt = data["imag"] + + if self.cfg.preprocess.use_frame_pitch: + pitch_input = data["frame_pitch"] + + if self.cfg.preprocess.use_frame_pitch: + pitch_input = pitch_input.float() + audio_pred = self.generator.forward(mel_input, pitch_input) + elif self.cfg.preprocess.extract_amplitude_phase: + ( + logamp_pred, + pha_pred, + rea_pred, + imag_pred, + audio_pred, + ) = self.generator.forward(mel_input) + from utils.mel import amplitude_phase_spectrum + + _, _, rea_pred_final, imag_pred_final = amplitude_phase_spectrum( + audio_pred.squeeze(1), self.cfg.preprocess + ) + else: + audio_pred = self.generator.forward(mel_input) + + # Calculate and BP Discriminator losses + self.discriminator_optimizer.zero_grad() + for key, _ in self.discriminators.items(): + y_r, y_g, _, _ = self.discriminators[key].forward( + audio_gt.unsqueeze(1), audio_pred.detach() + ) + ( + discriminator_losses["{}_discriminator".format(key)], + _, + _, + ) = self.criterions["discriminator"](y_r, y_g) + discriminator_total_loss += discriminator_losses[ + "{}_discriminator".format(key) + ] + + self.accelerator.backward(discriminator_total_loss) + self.discriminator_optimizer.step() + + # Calculate and BP Generator losses + self.generator_optimizer.zero_grad() + for key, _ in self.discriminators.items(): + y_r, y_g, f_r, f_g = self.discriminators[key].forward( + audio_gt.unsqueeze(1), audio_pred + ) + 
generator_losses["{}_feature".format(key)] = self.criterions["feature"]( + f_r, f_g + ) + generator_losses["{}_generator".format(key)], _ = self.criterions[ + "generator" + ](y_g) + generator_total_loss += generator_losses["{}_feature".format(key)] + generator_total_loss += generator_losses["{}_generator".format(key)] + + if "mel" in self.criterions.keys(): + generator_losses["mel"] = self.criterions["mel"](audio_gt, audio_pred) + generator_total_loss += generator_losses["mel"] + + if "wav" in self.criterions.keys(): + generator_losses["wav"] = self.criterions["wav"](audio_gt, audio_pred) + generator_total_loss += generator_losses["wav"] + + if "amplitude" in self.criterions.keys(): + generator_losses["amplitude"] = self.criterions["amplitude"]( + logamp_gt, logamp_pred + ) + generator_total_loss += generator_losses["amplitude"] + + if "phase" in self.criterions.keys(): + generator_losses["phase"] = self.criterions["phase"](pha_gt, pha_pred) + generator_total_loss += generator_losses["phase"] + + if "consistency" in self.criterions.keys(): + generator_losses["consistency"] = self.criterions["consistency"]( + rea_gt, rea_pred, rea_pred_final, imag_gt, imag_pred, imag_pred_final + ) + generator_total_loss += generator_losses["consistency"] + + self.accelerator.backward(generator_total_loss) + self.generator_optimizer.step() + + # Get the total losses + total_loss = discriminator_total_loss + generator_total_loss + train_losses.update(discriminator_losses) + train_losses.update(generator_losses) + + for key, _ in train_losses.items(): + train_losses[key] = train_losses[key].item() + + return total_loss.item(), train_losses + + def _valid_epoch(self): + """Testing epoch. Should return average loss of a batch (sample) over + one epoch. See ``train_loop`` for usage. + """ + self.generator.eval() + for key, _ in self.discriminators.items(): + self.discriminators[key].eval() + + epoch_losses: dict = {} + epoch_total_loss: int = 0 + + for batch in tqdm( + self.valid_dataloader, + desc=f"Validating Epoch {self.epoch}", + unit="batch", + colour="GREEN", + leave=False, + dynamic_ncols=True, + smoothing=0.04, + disable=not self.accelerator.is_main_process, + ): + # Get losses + total_loss, losses = self._valid_step(batch) + + # Log info + for key, _ in losses.items(): + self.accelerator.log( + { + "Step/Valid {} Loss".format(key): losses[key], + }, + step=self.step, + ) + + if not epoch_losses: + epoch_losses = losses + else: + for key, value in losses.items(): + epoch_losses[key] += value + epoch_total_loss += total_loss + + # Get and log total losses + self.accelerator.wait_for_everyone() + epoch_total_loss = epoch_total_loss / len(self.valid_dataloader) + for key in epoch_losses.keys(): + epoch_losses[key] = epoch_losses[key] / len(self.valid_dataloader) + return epoch_total_loss, epoch_losses + + def _valid_step(self, data): + """Testing forward step. Should return average loss of a sample over + one batch. Provoke ``_forward_step`` is recommended except for special case. + See ``_test_epoch`` for usage. 
+        """
+        # Init losses
+        valid_losses = {}
+        total_loss = 0
+
+        generator_losses = {}
+        generator_total_loss = 0
+        discriminator_losses = {}
+        discriminator_total_loss = 0
+
+        # Use feature inputs to get the predicted audio
+        mel_input = data["mel"]
+        audio_gt = data["audio"]
+
+        if self.cfg.preprocess.extract_amplitude_phase:
+            logamp_gt = data["logamp"]
+            pha_gt = data["pha"]
+            rea_gt = data["rea"]
+            imag_gt = data["imag"]
+
+        if self.cfg.preprocess.use_frame_pitch:
+            pitch_input = data["frame_pitch"]
+
+        if self.cfg.preprocess.use_frame_pitch:
+            pitch_input = pitch_input.float()
+            audio_pred = self.generator.forward(mel_input, pitch_input)
+        elif self.cfg.preprocess.extract_amplitude_phase:
+            (
+                logamp_pred,
+                pha_pred,
+                rea_pred,
+                imag_pred,
+                audio_pred,
+            ) = self.generator.forward(mel_input)
+            from utils.mel import amplitude_phase_spectrum
+
+            _, _, rea_pred_final, imag_pred_final = amplitude_phase_spectrum(
+                audio_pred.squeeze(1), self.cfg.preprocess
+            )
+        else:
+            audio_pred = self.generator.forward(mel_input)
+
+        # Get Discriminator losses
+        for key, _ in self.discriminators.items():
+            y_r, y_g, _, _ = self.discriminators[key].forward(
+                audio_gt.unsqueeze(1), audio_pred
+            )
+            (
+                discriminator_losses["{}_discriminator".format(key)],
+                _,
+                _,
+            ) = self.criterions["discriminator"](y_r, y_g)
+            discriminator_total_loss += discriminator_losses[
+                "{}_discriminator".format(key)
+            ]
+
+        # Get Generator losses
+        for key, _ in self.discriminators.items():
+            y_r, y_g, f_r, f_g = self.discriminators[key].forward(
+                audio_gt.unsqueeze(1), audio_pred
+            )
+            generator_losses["{}_feature".format(key)] = self.criterions["feature"](
+                f_r, f_g
+            )
+            generator_losses["{}_generator".format(key)], _ = self.criterions[
+                "generator"
+            ](y_g)
+            generator_total_loss += generator_losses["{}_feature".format(key)]
+            generator_total_loss += generator_losses["{}_generator".format(key)]
+
+        if "mel" in self.criterions.keys():
+            generator_losses["mel"] = self.criterions["mel"](audio_gt, audio_pred)
+            generator_total_loss += generator_losses["mel"]
+
+        if "wav" in self.criterions.keys():
+            generator_losses["wav"] = self.criterions["wav"](audio_gt, audio_pred)
+            generator_total_loss += generator_losses["wav"]
+
+        if "amplitude" in self.criterions.keys():
+            generator_losses["amplitude"] = self.criterions["amplitude"](
+                logamp_gt, logamp_pred
+            )
+            generator_total_loss += generator_losses["amplitude"]
+
+        if "phase" in self.criterions.keys():
+            generator_losses["phase"] = self.criterions["phase"](pha_gt, pha_pred)
+            generator_total_loss += generator_losses["phase"]
+
+        if "consistency" in self.criterions.keys():
+            generator_losses["consistency"] = self.criterions["consistency"](
+                rea_gt,
+                rea_pred,
+                rea_pred_final,
+                imag_gt,
+                imag_pred,
+                imag_pred_final,
+            )
+            generator_total_loss += generator_losses["consistency"]
+
+        total_loss = discriminator_total_loss + generator_total_loss
+        valid_losses.update(discriminator_losses)
+        valid_losses.update(generator_losses)
+
+        for item in valid_losses:
+            valid_losses[item] = valid_losses[item].item()
+
+        return total_loss.item(), valid_losses
+
+    def
_inference(self, eval_mel, eval_pitch=None, use_pitch=False): + """Inference during training for test audios.""" + if use_pitch: + eval_pitch = align_length(eval_pitch, eval_mel.shape[1]) + eval_audio = vocoder_inference( + self.cfg, + self.generator, + torch.from_numpy(eval_mel).unsqueeze(0), + f0s=torch.from_numpy(eval_pitch).unsqueeze(0).float(), + device=next(self.generator.parameters()).device, + ).squeeze(0) + else: + eval_audio = vocoder_inference( + self.cfg, + self.generator, + torch.from_numpy(eval_mel).unsqueeze(0), + device=next(self.generator.parameters()).device, + ).squeeze(0) + return eval_audio + + def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"): + """Load model from checkpoint. If checkpoint_path is None, it will + load the latest checkpoint in checkpoint_dir. If checkpoint_path is not + None, it will load the checkpoint specified by checkpoint_path. **Only use this + method after** ``accelerator.prepare()``. + """ + if checkpoint_path is None: + ls = [str(i) for i in Path(checkpoint_dir).glob("*")] + ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True) + checkpoint_path = ls[0] + if resume_type == "resume": + self.accelerator.load_state(checkpoint_path) + self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1 + self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1 + elif resume_type == "finetune": + accelerate.load_checkpoint_and_dispatch( + self.accelerator.unwrap_model(self.generator), + os.path.join(checkpoint_path, "pytorch_model.bin"), + ) + for key, _ in self.discriminators.items(): + accelerate.load_checkpoint_and_dispatch( + self.accelerator.unwrap_model(self.discriminators[key]), + os.path.join(checkpoint_path, "pytorch_model.bin"), + ) + self.logger.info("Load model weights for finetune SUCCESS!") + else: + raise ValueError("Unsupported resume type: {}".format(resume_type)) + return checkpoint_path + + def _count_parameters(self): + result = sum(p.numel() for p in self.generator.parameters()) + for _, discriminator in self.discriminators.items(): + result += sum(p.numel() for p in discriminator.parameters()) + return result diff --git a/models/vocoders/gan/generator/__init__.py b/models/vocoders/gan/generator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/vocoders/gan/generator/apnet.py b/models/vocoders/gan/generator/apnet.py new file mode 100644 index 0000000000000000000000000000000000000000..9d529bbda7dd89857df9c54f1e60e873a5c9fc48 --- /dev/null +++ b/models/vocoders/gan/generator/apnet.py @@ -0,0 +1,399 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d +from torch.nn.utils import weight_norm, spectral_norm +from modules.vocoder_blocks import * + +LRELU_SLOPE = 0.1 + + +class ISTFT(nn.Module): + """ + Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with + windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges. + See issue: https://github.com/pytorch/pytorch/issues/62323 + Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs. + The NOLA constraint is met as we trim padded samples anyway. 
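+    With "same" padding, the inverse-FFT frames are overlap-added with
+    ``torch.nn.functional.fold``, normalized by the summed squared-window
+    envelope, and ``(win_length - hop_length) // 2`` samples are trimmed from
+    each edge (see ``forward`` below).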
+ + Args: + n_fft (int): Size of Fourier transform. + hop_length (int): The distance between neighboring sliding window frames. + win_length (int): The size of window frame and STFT filter. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__( + self, + n_fft: int, + hop_length: int, + win_length: int, + padding: str = "same", + ): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + + def forward(self, spec: torch.Tensor, window) -> torch.Tensor: + """ + Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram. + + Args: + spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size, + N is the number of frequency bins, and T is the number of time frames. + + Returns: + Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal. + """ + if self.padding == "center": + # Fallback to pytorch native implementation + return torch.istft( + spec, + self.n_fft, + self.hop_length, + self.win_length, + window, + center=True, + ) + elif self.padding == "same": + pad = (self.win_length - self.hop_length) // 2 + else: + raise ValueError("Padding must be 'center' or 'same'.") + + assert spec.dim() == 3, "Expected a 3D tensor as input" + B, N, T = spec.shape + + # Inverse FFT + ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") + ifft = ifft * window[None, :, None] + + # Overlap and Add + output_size = (T - 1) * self.hop_length + self.win_length + y = torch.nn.functional.fold( + ifft, + output_size=(1, output_size), + kernel_size=(1, self.win_length), + stride=(1, self.hop_length), + )[:, 0, 0, pad:-pad] + + # Window envelope + window_sq = window.square().expand(1, T, -1).transpose(1, 2) + window_envelope = torch.nn.functional.fold( + window_sq, + output_size=(1, output_size), + kernel_size=(1, self.win_length), + stride=(1, self.hop_length), + ).squeeze()[pad:-pad] + + # Normalize + assert (window_envelope > 1e-11).all() + y = y / window_envelope + + return y + + +# The ASP and PSP Module are adopted from APNet under the MIT License +# https://github.com/YangAi520/APNet/blob/main/models.py + + +class ASPResBlock(torch.nn.Module): + def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ASPResBlock, self).__init__() + self.cfg = cfg + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, 
x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + +class PSPResBlock(torch.nn.Module): + def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3, 5)): + super(PSPResBlock, self).__init__() + self.cfg = cfg + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + +class APNet(torch.nn.Module): + def __init__(self, cfg): + super(APNet, self).__init__() + self.cfg = cfg + self.ASP_num_kernels = len(cfg.model.apnet.ASP_resblock_kernel_sizes) + self.PSP_num_kernels = len(cfg.model.apnet.PSP_resblock_kernel_sizes) + + self.ASP_input_conv = weight_norm( + Conv1d( + cfg.preprocess.n_mel, + cfg.model.apnet.ASP_channel, + cfg.model.apnet.ASP_input_conv_kernel_size, + 1, + padding=get_padding(cfg.model.apnet.ASP_input_conv_kernel_size, 1), + ) + ) + self.PSP_input_conv = weight_norm( + Conv1d( + cfg.preprocess.n_mel, + cfg.model.apnet.PSP_channel, + cfg.model.apnet.PSP_input_conv_kernel_size, + 1, + padding=get_padding(cfg.model.apnet.PSP_input_conv_kernel_size, 1), + ) + ) + + self.ASP_ResNet = nn.ModuleList() + for j, (k, d) in enumerate( + zip( + cfg.model.apnet.ASP_resblock_kernel_sizes, + cfg.model.apnet.ASP_resblock_dilation_sizes, + ) + ): + self.ASP_ResNet.append(ASPResBlock(cfg, cfg.model.apnet.ASP_channel, k, d)) + + self.PSP_ResNet = nn.ModuleList() + for j, (k, d) in enumerate( + zip( + cfg.model.apnet.PSP_resblock_kernel_sizes, + cfg.model.apnet.PSP_resblock_dilation_sizes, + ) + ): + self.PSP_ResNet.append(PSPResBlock(cfg, cfg.model.apnet.PSP_channel, k, d)) + + self.ASP_output_conv = weight_norm( + Conv1d( + cfg.model.apnet.ASP_channel, + cfg.preprocess.n_fft // 2 + 1, + cfg.model.apnet.ASP_output_conv_kernel_size, + 1, + padding=get_padding(cfg.model.apnet.ASP_output_conv_kernel_size, 1), + ) + ) + self.PSP_output_R_conv = weight_norm( + Conv1d( + cfg.model.apnet.PSP_channel, + cfg.preprocess.n_fft // 2 + 1, + cfg.model.apnet.PSP_output_R_conv_kernel_size, + 1, + padding=get_padding(cfg.model.apnet.PSP_output_R_conv_kernel_size, 1), + ) + ) + self.PSP_output_I_conv = weight_norm( + Conv1d( + cfg.model.apnet.PSP_channel, + cfg.preprocess.n_fft // 2 + 1, + cfg.model.apnet.PSP_output_I_conv_kernel_size, + 1, + padding=get_padding(cfg.model.apnet.PSP_output_I_conv_kernel_size, 1), + ) + ) + + self.iSTFT = ISTFT( + 
self.cfg.preprocess.n_fft, + hop_length=self.cfg.preprocess.hop_size, + win_length=self.cfg.preprocess.win_size, + ) + + self.ASP_output_conv.apply(init_weights) + self.PSP_output_R_conv.apply(init_weights) + self.PSP_output_I_conv.apply(init_weights) + + def forward(self, mel): + logamp = self.ASP_input_conv(mel) + logamps = None + for j in range(self.ASP_num_kernels): + if logamps is None: + logamps = self.ASP_ResNet[j](logamp) + else: + logamps += self.ASP_ResNet[j](logamp) + logamp = logamps / self.ASP_num_kernels + logamp = F.leaky_relu(logamp) + logamp = self.ASP_output_conv(logamp) + + pha = self.PSP_input_conv(mel) + phas = None + for j in range(self.PSP_num_kernels): + if phas is None: + phas = self.PSP_ResNet[j](pha) + else: + phas += self.PSP_ResNet[j](pha) + pha = phas / self.PSP_num_kernels + pha = F.leaky_relu(pha) + R = self.PSP_output_R_conv(pha) + I = self.PSP_output_I_conv(pha) + + pha = torch.atan2(I, R) + + rea = torch.exp(logamp) * torch.cos(pha) + imag = torch.exp(logamp) * torch.sin(pha) + + spec = torch.cat((rea.unsqueeze(-1), imag.unsqueeze(-1)), -1) + + spec = torch.view_as_complex(spec) + + audio = self.iSTFT.forward( + spec, torch.hann_window(self.cfg.preprocess.win_size).to(mel.device) + ) + + return logamp, pha, rea, imag, audio.unsqueeze(1) diff --git a/models/vocoders/gan/generator/bigvgan.py b/models/vocoders/gan/generator/bigvgan.py new file mode 100644 index 0000000000000000000000000000000000000000..c7658d31d59efe613aee3e7bf8089e91c8f484af --- /dev/null +++ b/models/vocoders/gan/generator/bigvgan.py @@ -0,0 +1,341 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +import torch.nn as nn + +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import weight_norm, remove_weight_norm + +from modules.vocoder_blocks import * +from modules.activation_functions import * +from modules.anti_aliasing import * + +LRELU_SLOPE = 0.1 + +# The AMPBlock Module is adopted from BigVGAN under the MIT License +# https://github.com/NVIDIA/BigVGAN + + +class AMPBlock1(torch.nn.Module): + def __init__( + self, cfg, channels, kernel_size=3, dilation=(1, 3, 5), activation=None + ): + super(AMPBlock1, self).__init__() + self.cfg = cfg + + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + self.num_layers = len(self.convs1) + len( + self.convs2 + ) # total number of conv layers + + if ( + activation == "snake" + ): # periodic nonlinearity with snake function and anti-aliasing + self.activations = 
nn.ModuleList( + [ + Activation1d( + activation=Snake( + channels, alpha_logscale=cfg.model.bigvgan.snake_logscale + ) + ) + for _ in range(self.num_layers) + ] + ) + elif ( + activation == "snakebeta" + ): # periodic nonlinearity with snakebeta function and anti-aliasing + self.activations = nn.ModuleList( + [ + Activation1d( + activation=SnakeBeta( + channels, alpha_logscale=cfg.model.bigvgan.snake_logscale + ) + ) + for _ in range(self.num_layers) + ] + ) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + def forward(self, x): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): + xt = a1(x) + xt = c1(xt) + xt = a2(xt) + xt = c2(xt) + x = xt + x + + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class AMPBlock2(torch.nn.Module): + def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3), activation=None): + super(AMPBlock2, self).__init__() + self.cfg = cfg + + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + self.num_layers = len(self.convs) # total number of conv layers + + if ( + activation == "snake" + ): # periodic nonlinearity with snake function and anti-aliasing + self.activations = nn.ModuleList( + [ + Activation1d( + activation=Snake( + channels, alpha_logscale=cfg.model.bigvgan.snake_logscale + ) + ) + for _ in range(self.num_layers) + ] + ) + elif ( + activation == "snakebeta" + ): # periodic nonlinearity with snakebeta function and anti-aliasing + self.activations = nn.ModuleList( + [ + Activation1d( + activation=SnakeBeta( + channels, alpha_logscale=cfg.model.bigvgan.snake_logscale + ) + ) + for _ in range(self.num_layers) + ] + ) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." 
+ ) + + def forward(self, x): + for c, a in zip(self.convs, self.activations): + xt = a(x) + xt = c(xt) + x = xt + x + + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class BigVGAN(torch.nn.Module): + def __init__(self, cfg): + super(BigVGAN, self).__init__() + self.cfg = cfg + + self.num_kernels = len(cfg.model.bigvgan.resblock_kernel_sizes) + self.num_upsamples = len(cfg.model.bigvgan.upsample_rates) + + # Conv pre to boost channels + self.conv_pre = weight_norm( + Conv1d( + cfg.preprocess.n_mel, + cfg.model.bigvgan.upsample_initial_channel, + 7, + 1, + padding=3, + ) + ) + + resblock = AMPBlock1 if cfg.model.bigvgan.resblock == "1" else AMPBlock2 + + # Upsamplers + self.ups = nn.ModuleList() + for i, (u, k) in enumerate( + zip( + cfg.model.bigvgan.upsample_rates, + cfg.model.bigvgan.upsample_kernel_sizes, + ) + ): + self.ups.append( + nn.ModuleList( + [ + weight_norm( + ConvTranspose1d( + cfg.model.bigvgan.upsample_initial_channel // (2**i), + cfg.model.bigvgan.upsample_initial_channel + // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ] + ) + ) + + # Res Blocks with AMP and Anti-aliasing + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = cfg.model.bigvgan.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip( + cfg.model.bigvgan.resblock_kernel_sizes, + cfg.model.bigvgan.resblock_dilation_sizes, + ) + ): + self.resblocks.append( + resblock(cfg, ch, k, d, activation=cfg.model.bigvgan.activation) + ) + + # Conv post for result + if cfg.model.bigvgan.activation == "snake": + activation_post = Snake(ch, alpha_logscale=cfg.model.bigvgan.snake_logscale) + self.activation_post = Activation1d(activation=activation_post) + elif cfg.model.bigvgan.activation == "snakebeta": + activation_post = SnakeBeta( + ch, alpha_logscale=cfg.model.bigvgan.snake_logscale + ) + self.activation_post = Activation1d(activation=activation_post) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + + # Weight Norm + for i in range(len(self.ups)): + self.ups[i].apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + + for i in range(self.num_upsamples): + for i_up in range(len(self.ups[i])): + x = self.ups[i][i_up](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + + x = self.activation_post(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + for l_i in l: + remove_weight_norm(l_i) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) diff --git a/models/vocoders/gan/generator/hifigan.py b/models/vocoders/gan/generator/hifigan.py new file mode 100644 index 0000000000000000000000000000000000000000..9f5f32498f5eb6441db787b0ae204a1eeff36aa3 --- /dev/null +++ b/models/vocoders/gan/generator/hifigan.py @@ -0,0 +1,449 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
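+
+# The generator maps a mel spectrogram with ``cfg.preprocess.n_mel`` channels to a
+# waveform: a stack of transposed convolutions with strides
+# ``cfg.model.hifigan.upsample_rates`` (their product is the total upsampling
+# factor, typically the mel hop size), each followed by ``num_kernels``
+# multi-receptive-field ResBlocks whose outputs are averaged, and a final
+# ``tanh``-bounded output convolution.
+#
+# Minimal usage sketch (illustrative only; assumes a parsed ``cfg`` exposing the
+# fields referenced in ``HiFiGAN.__init__``):
+#
+#     model = HiFiGAN(cfg)
+#     mel = torch.randn(1, cfg.preprocess.n_mel, 100)  # (B, n_mel, T)
+#     wav = model(mel)  # (B, 1, ~T * prod(upsample_rates))
+#     model.remove_weight_norm()  # strip weight norm before inference/export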
+ +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import weight_norm, remove_weight_norm +from modules.vocoder_blocks import * + + +LRELU_SLOPE = 0.1 + + +class ResBlock1(torch.nn.Module): + def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.cfg = cfg + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.cfg = cfg + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class HiFiGAN(torch.nn.Module): + def __init__(self, cfg): + super(HiFiGAN, self).__init__() + self.cfg = cfg + self.num_kernels = len(self.cfg.model.hifigan.resblock_kernel_sizes) + self.num_upsamples = len(self.cfg.model.hifigan.upsample_rates) + self.conv_pre = weight_norm( + Conv1d( + cfg.preprocess.n_mel, + self.cfg.model.hifigan.upsample_initial_channel, + 7, + 1, + padding=3, + ) + ) + resblock = ResBlock1 if self.cfg.model.hifigan.resblock == "1" else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate( + zip( + self.cfg.model.hifigan.upsample_rates, + self.cfg.model.hifigan.upsample_kernel_sizes, + ) + ): + self.ups.append( + weight_norm( + ConvTranspose1d( + self.cfg.model.hifigan.upsample_initial_channel // (2**i), + self.cfg.model.hifigan.upsample_initial_channel + // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = self.cfg.model.hifigan.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip( + 
self.cfg.model.hifigan.resblock_kernel_sizes, + self.cfg.model.hifigan.resblock_dilation_sizes, + ) + ): + self.resblocks.append(resblock(self.cfg, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +# todo: merge with ResBlock1 (lmxue, yicheng) +class ResBlock1_vits(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1_vits, self).__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x, x_mask=None): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c2(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +# todo: merge with ResBlock2 (lmxue, yicheng) +class ResBlock2_vits(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2_vits, self).__init__() + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x, x_mask=None): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +# todo: merge with HiFiGAN (lmxue, yicheng) 
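+# Unlike ``HiFiGAN`` above, this variant is configured with explicit constructor
+# arguments instead of a cfg object, and it supports global conditioning: when
+# ``gin_channels > 0`` a 1x1 convolution projects the condition ``g`` and adds it
+# to the pre-conv output (``x = x + self.cond(g)``). Its ``conv_pre`` and
+# ``conv_post`` are plain convolutions without weight norm.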
+class HiFiGAN_vits(torch.nn.Module): + def __init__( + self, + initial_channel, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=0, + ): + super(HiFiGAN_vits, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = Conv1d( + initial_channel, upsample_initial_channel, 7, 1, padding=3 + ) + resblock = ResBlock1_vits if resblock == "1" else ResBlock2_vits + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(resblock_kernel_sizes, resblock_dilation_sizes) + ): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + def forward(self, x, g=None): + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() diff --git a/models/vocoders/gan/generator/melgan.py b/models/vocoders/gan/generator/melgan.py new file mode 100644 index 0000000000000000000000000000000000000000..d13c5fe6c0844ff7d753ed14324a92b34cc7798c --- /dev/null +++ b/models/vocoders/gan/generator/melgan.py @@ -0,0 +1,104 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
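+
+# The generator upsamples mel frames to waveform samples by
+# ``hop_length = prod(cfg.model.melgan.ratios)``: each ratio ``r`` is realized by
+# a weight-normalized ConvTranspose1d with stride ``r`` followed by
+# ``cfg.model.melgan.n_residual_layers`` dilated ResnetBlocks, while the channel
+# width starts at ``2 ** len(ratios) * ngf`` and halves after every stage.
+#
+# Illustrative numbers (not taken from any shipped config): with
+# ratios = [8, 8, 2, 2] and ngf = 32, one mel frame becomes 8 * 8 * 2 * 2 = 256
+# samples, and the first stage starts from 2 ** 4 * 32 = 512 channels.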
+ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from torch.nn.utils import weight_norm + +# This code is adopted from MelGAN under the MIT License +# https://github.com/descriptinc/melgan-neurips + + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(0.0, 0.02) + elif classname.find("BatchNorm2d") != -1: + m.weight.data.normal_(1.0, 0.02) + m.bias.data.fill_(0) + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +class ResnetBlock(nn.Module): + def __init__(self, dim, dilation=1): + super().__init__() + self.block = nn.Sequential( + nn.LeakyReLU(0.2), + nn.ReflectionPad1d(dilation), + WNConv1d(dim, dim, kernel_size=3, dilation=dilation), + nn.LeakyReLU(0.2), + WNConv1d(dim, dim, kernel_size=1), + ) + self.shortcut = WNConv1d(dim, dim, kernel_size=1) + + def forward(self, x): + return self.shortcut(x) + self.block(x) + + +class MelGAN(nn.Module): + def __init__(self, cfg): + super().__init__() + + self.cfg = cfg + + self.hop_length = np.prod(self.cfg.model.melgan.ratios) + mult = int(2 ** len(self.cfg.model.melgan.ratios)) + + model = [ + nn.ReflectionPad1d(3), + WNConv1d( + self.cfg.preprocess.n_mel, + mult * self.cfg.model.melgan.ngf, + kernel_size=7, + padding=0, + ), + ] + + # Upsample to raw audio scale + for i, r in enumerate(self.cfg.model.melgan.ratios): + model += [ + nn.LeakyReLU(0.2), + WNConvTranspose1d( + mult * self.cfg.model.melgan.ngf, + mult * self.cfg.model.melgan.ngf // 2, + kernel_size=r * 2, + stride=r, + padding=r // 2 + r % 2, + output_padding=r % 2, + ), + ] + + for j in range(self.cfg.model.melgan.n_residual_layers): + model += [ + ResnetBlock(mult * self.cfg.model.melgan.ngf // 2, dilation=3**j) + ] + + mult //= 2 + + model += [ + nn.LeakyReLU(0.2), + nn.ReflectionPad1d(3), + WNConv1d(self.cfg.model.melgan.ngf, 1, kernel_size=7, padding=0), + nn.Tanh(), + ] + + self.model = nn.Sequential(*model) + self.apply(weights_init) + + def forward(self, x): + return self.model(x) diff --git a/models/vocoders/gan/generator/nsfhifigan.py b/models/vocoders/gan/generator/nsfhifigan.py new file mode 100644 index 0000000000000000000000000000000000000000..2db7f6d88b09525decc444a14115e6a63e548485 --- /dev/null +++ b/models/vocoders/gan/generator/nsfhifigan.py @@ -0,0 +1,283 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
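+
+# Compared with the plain HiFiGAN generator, this variant adds a neural
+# source-filter excitation: ``SourceModuleHnNSF`` turns a frame-level F0 contour
+# into a harmonic sine source upsampled by ``upp = prod(upsample_rates)``, and at
+# each upsampling stage a ``noise_convs`` layer resamples that source to the
+# current resolution so it can be added to the feature map before the ResBlocks.
+#
+# Rough call sketch (illustrative; assumes a parsed ``cfg`` and inputs aligned
+# frame-by-frame):
+#
+#     model = NSFHiFiGAN(cfg)
+#     wav = model(mel, f0)  # mel: (B, n_mel, T); f0: frame-level pitch aligned with mel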
+ +import torch + +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm + +from modules.neural_source_filter import * +from modules.vocoder_blocks import * + + +LRELU_SLOPE = 0.1 + + +class ResBlock1(nn.Module): + def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.cfg = cfg + + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(nn.Module): + def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock1, self).__init__() + self.cfg = cfg + + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +# This NSF Module is adopted from Xin Wang's NSF under the MIT License +# https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts + + +class SourceModuleHnNSF(nn.Module): + def __init__( + self, fs, harmonic_num=0, amp=0.1, noise_std=0.003, voiced_threshold=0 + ): + super(SourceModuleHnNSF, self).__init__() + + self.amp = amp + self.noise_std = noise_std + self.l_sin_gen = SineGen(fs, harmonic_num, amp, noise_std, voiced_threshold) + + self.l_linear = nn.Linear(harmonic_num + 1, 1) + self.l_tanh = nn.Tanh() + + def forward(self, x, upp): + sine_wavs, uv, _ = self.l_sin_gen(x, upp) + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + return sine_merge + + +class NSFHiFiGAN(nn.Module): + def __init__(self, cfg): + super(NSFHiFiGAN, self).__init__() + + self.cfg = cfg + self.num_kernels = len(self.cfg.model.nsfhifigan.resblock_kernel_sizes) + self.num_upsamples = len(self.cfg.model.nsfhifigan.upsample_rates) + self.m_source = SourceModuleHnNSF( + 
fs=self.cfg.preprocess.sample_rate,
+            harmonic_num=self.cfg.model.nsfhifigan.harmonic_num,
+        )
+        self.noise_convs = nn.ModuleList()
+        self.conv_pre = weight_norm(
+            Conv1d(
+                self.cfg.preprocess.n_mel,
+                self.cfg.model.nsfhifigan.upsample_initial_channel,
+                7,
+                1,
+                padding=3,
+            )
+        )
+
+        resblock = ResBlock1 if self.cfg.model.nsfhifigan.resblock == "1" else ResBlock2
+
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(
+            zip(
+                self.cfg.model.nsfhifigan.upsample_rates,
+                self.cfg.model.nsfhifigan.upsample_kernel_sizes,
+            )
+        ):
+            c_cur = self.cfg.model.nsfhifigan.upsample_initial_channel // (2 ** (i + 1))
+            self.ups.append(
+                weight_norm(
+                    ConvTranspose1d(
+                        self.cfg.model.nsfhifigan.upsample_initial_channel // (2**i),
+                        self.cfg.model.nsfhifigan.upsample_initial_channel
+                        // (2 ** (i + 1)),
+                        k,
+                        u,
+                        padding=(k - u) // 2,
+                    )
+                )
+            )
+            if i + 1 < len(self.cfg.model.nsfhifigan.upsample_rates):
+                stride_f0 = int(
+                    np.prod(self.cfg.model.nsfhifigan.upsample_rates[i + 1 :])
+                )
+                self.noise_convs.append(
+                    Conv1d(
+                        1,
+                        c_cur,
+                        kernel_size=stride_f0 * 2,
+                        stride=stride_f0,
+                        padding=stride_f0 // 2,
+                    )
+                )
+            else:
+                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
+
+        self.resblocks = nn.ModuleList()
+        ch = self.cfg.model.nsfhifigan.upsample_initial_channel
+        for i in range(len(self.ups)):
+            ch //= 2
+            for j, (k, d) in enumerate(
+                zip(
+                    self.cfg.model.nsfhifigan.resblock_kernel_sizes,
+                    self.cfg.model.nsfhifigan.resblock_dilation_sizes,
+                )
+            ):
+                self.resblocks.append(resblock(cfg, ch, k, d))
+
+        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+
+        self.ups.apply(init_weights)
+        self.conv_post.apply(init_weights)
+        self.upp = int(np.prod(self.cfg.model.nsfhifigan.upsample_rates))
+
+    def forward(self, x, f0):
+        har_source = self.m_source(f0, self.upp).transpose(1, 2)
+        x = self.conv_pre(x)
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            x = self.ups[i](x)
+            x_source = self.noise_convs[i](har_source)
+
+            # Align lengths before fusing the harmonic source with the features
+            length = min(x.shape[-1], x_source.shape[-1])
+            x = x[:, :, :length]
+            x_source = x_source[:, :, :length]
+
+            x = x + x_source
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
+        x = F.leaky_relu(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
diff --git a/models/vocoders/gan/generator/sifigan.py b/models/vocoders/gan/generator/sifigan.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/vocoders/vocoder_dataset.py b/models/vocoders/vocoder_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..7df17b97ba7a4f770f01971324126eca4a2db272
--- /dev/null
+++ b/models/vocoders/vocoder_dataset.py
@@ -0,0 +1,264 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
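+
+# Each item is a dict of per-utterance features loaded from ``.npy`` files:
+# ``"mel"`` with shape (n_mel, T), optionally ``"frame_pitch"`` aligned to T,
+# ``"audio"``, and ``"target_len"`` (the mel frame count). ``VocoderCollator``
+# pads these into batch tensors (mels are transposed to (T, n_mel) before
+# padding) and derives a ``"mask"`` from ``"target_len"``.
+#
+# Minimal loading sketch (illustrative; "ljspeech" stands in for any preprocessed
+# dataset name):
+#
+#     dataset = VocoderDataset(cfg, "ljspeech", is_valid=False)
+#     loader = torch.utils.data.DataLoader(
+#         dataset, batch_size=16, shuffle=True, collate_fn=VocoderCollator(cfg)
+#     )
+#     batch = next(iter(loader))  # e.g. batch["mel"]: (B, T_max, n_mel)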
+ +from typing import Iterable +import torch +import numpy as np +import torch.utils.data +from torch.nn.utils.rnn import pad_sequence +from utils.data_utils import * +from torch.utils.data import ConcatDataset, Dataset + + +class VocoderDataset(torch.utils.data.Dataset): + def __init__(self, cfg, dataset, is_valid=False): + """ + Args: + cfg: config + dataset: dataset name + is_valid: whether to use train or valid dataset + """ + assert isinstance(dataset, str) + + processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset) + + meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file + self.metafile_path = os.path.join(processed_data_dir, meta_file) + self.metadata = self.get_metadata() + + self.data_root = processed_data_dir + self.cfg = cfg + + if cfg.preprocess.use_audio: + self.utt2audio_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2audio_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.audio_dir, + uid + ".npy", + ) + elif cfg.preprocess.use_label: + self.utt2label_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2label_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.label_dir, + uid + ".npy", + ) + elif cfg.preprocess.use_one_hot: + self.utt2one_hot_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2one_hot_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.one_hot_dir, + uid + ".npy", + ) + + if cfg.preprocess.use_mel: + self.utt2mel_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2mel_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.mel_dir, + uid + ".npy", + ) + + if cfg.preprocess.use_frame_pitch: + self.utt2frame_pitch_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2frame_pitch_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.pitch_dir, + uid + ".npy", + ) + + if cfg.preprocess.use_uv: + self.utt2uv_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + self.utt2uv_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.uv_dir, + uid + ".npy", + ) + + if cfg.preprocess.use_amplitude_phase: + self.utt2logamp_path = {} + self.utt2pha_path = {} + self.utt2rea_path = {} + self.utt2imag_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + self.utt2logamp_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.log_amplitude_dir, + uid + ".npy", + ) + self.utt2pha_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.phase_dir, + uid + ".npy", + ) + self.utt2rea_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.real_dir, + uid + ".npy", + ) + self.utt2imag_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.imaginary_dir, + uid + ".npy", + ) + + def 
__getitem__(self, index): + utt_info = self.metadata[index] + + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + single_feature = dict() + + if self.cfg.preprocess.use_mel: + mel = np.load(self.utt2mel_path[utt]) + assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T] + + if "target_len" not in single_feature.keys(): + single_feature["target_len"] = mel.shape[1] + + single_feature["mel"] = mel + + if self.cfg.preprocess.use_frame_pitch: + frame_pitch = np.load(self.utt2frame_pitch_path[utt]) + + if "target_len" not in single_feature.keys(): + single_feature["target_len"] = len(frame_pitch) + + aligned_frame_pitch = align_length( + frame_pitch, single_feature["target_len"] + ) + + single_feature["frame_pitch"] = aligned_frame_pitch + + if self.cfg.preprocess.use_audio: + audio = np.load(self.utt2audio_path[utt]) + + single_feature["audio"] = audio + + return single_feature + + def get_metadata(self): + with open(self.metafile_path, "r", encoding="utf-8") as f: + metadata = json.load(f) + + return metadata + + def get_dataset_name(self): + return self.metadata[0]["Dataset"] + + def __len__(self): + return len(self.metadata) + + +class VocoderConcatDataset(ConcatDataset): + def __init__(self, datasets: Iterable[Dataset], full_audio_inference=False): + """Concatenate a series of datasets with their random inference audio merged.""" + super().__init__(datasets) + + self.cfg = self.datasets[0].cfg + + self.metadata = [] + + # Merge metadata + for dataset in self.datasets: + self.metadata += dataset.metadata + + # Merge random inference features + if full_audio_inference: + self.eval_audios = [] + self.eval_dataset_names = [] + if self.cfg.preprocess.use_mel: + self.eval_mels = [] + if self.cfg.preprocess.use_frame_pitch: + self.eval_pitchs = [] + for dataset in self.datasets: + self.eval_audios.append(dataset.eval_audio) + self.eval_dataset_names.append(dataset.get_dataset_name()) + if self.cfg.preprocess.use_mel: + self.eval_mels.append(dataset.eval_mel) + if self.cfg.preprocess.use_frame_pitch: + self.eval_pitchs.append(dataset.eval_pitch) + + +class VocoderCollator(object): + """Zero-pads model inputs and targets based on number of frames per step""" + + def __init__(self, cfg): + self.cfg = cfg + + def __call__(self, batch): + packed_batch_features = dict() + + # mel: [b, n_mels, frame] + # frame_pitch: [b, frame] + # audios: [b, frame * hop_size] + + for key in batch[0].keys(): + if key == "target_len": + packed_batch_features["target_len"] = torch.LongTensor( + [b["target_len"] for b in batch] + ) + masks = [ + torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch + ] + packed_batch_features["mask"] = pad_sequence( + masks, batch_first=True, padding_value=0 + ) + elif key == "mel": + values = [torch.from_numpy(b[key]).T for b in batch] + packed_batch_features[key] = pad_sequence( + values, batch_first=True, padding_value=0 + ) + else: + values = [torch.from_numpy(b[key]) for b in batch] + packed_batch_features[key] = pad_sequence( + values, batch_first=True, padding_value=0 + ) + + return packed_batch_features diff --git a/models/vocoders/vocoder_inference.py b/models/vocoders/vocoder_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..dfd09ee6aa44c544c51a62c0e014cca4260cc6a8 --- /dev/null +++ b/models/vocoders/vocoder_inference.py @@ -0,0 +1,488 @@ +# Copyright (c) 2023 Amphion. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import torch +import json +import json5 +import time +import accelerate +import random +import numpy as np +import shutil + +from pathlib import Path +from tqdm import tqdm +from glob import glob +from accelerate.logging import get_logger +from torch.utils.data import DataLoader + +from models.vocoders.vocoder_dataset import ( + VocoderDataset, + VocoderCollator, + VocoderConcatDataset, +) + +from models.vocoders.gan.generator import bigvgan, hifigan, melgan, nsfhifigan, apnet +from models.vocoders.flow.waveglow import waveglow +from models.vocoders.diffusion.diffwave import diffwave +from models.vocoders.autoregressive.wavenet import wavenet +from models.vocoders.autoregressive.wavernn import wavernn +from models.vocoders.gan import gan_vocoder_inference +from utils.io import save_audio + +_vocoders = { + "diffwave": diffwave.DiffWave, + "wavernn": wavernn.WaveRNN, + "wavenet": wavenet.WaveNet, + "waveglow": waveglow.WaveGlow, + "nsfhifigan": nsfhifigan.NSFHiFiGAN, + "bigvgan": bigvgan.BigVGAN, + "hifigan": hifigan.HiFiGAN, + "melgan": melgan.MelGAN, + "apnet": apnet.APNet, +} + +_vocoder_infer_funcs = { + # "world": world_inference.synthesis_audios, + # "wavernn": wavernn_inference.synthesis_audios, + # "wavenet": wavenet_inference.synthesis_audios, + # "diffwave": diffwave_inference.synthesis_audios, + "nsfhifigan": gan_vocoder_inference.synthesis_audios, + "bigvgan": gan_vocoder_inference.synthesis_audios, + "melgan": gan_vocoder_inference.synthesis_audios, + "hifigan": gan_vocoder_inference.synthesis_audios, + "apnet": gan_vocoder_inference.synthesis_audios, +} + + +class VocoderInference(object): + def __init__(self, args=None, cfg=None, infer_type="from_dataset"): + super().__init__() + + start = time.monotonic_ns() + self.args = args + self.cfg = cfg + self.infer_type = infer_type + + # Init accelerator + self.accelerator = accelerate.Accelerator() + self.accelerator.wait_for_everyone() + + # Get logger + with self.accelerator.main_process_first(): + self.logger = get_logger("inference", log_level=args.log_level) + + # Log some info + self.logger.info("=" * 56) + self.logger.info("||\t\t" + "New inference process started." 
+ "\t\t||") + self.logger.info("=" * 56) + self.logger.info("\n") + + self.vocoder_dir = args.vocoder_dir + self.logger.debug(f"Vocoder dir: {args.vocoder_dir}") + + os.makedirs(args.output_dir, exist_ok=True) + if os.path.exists(os.path.join(args.output_dir, "pred")): + shutil.rmtree(os.path.join(args.output_dir, "pred")) + if os.path.exists(os.path.join(args.output_dir, "gt")): + shutil.rmtree(os.path.join(args.output_dir, "gt")) + os.makedirs(os.path.join(args.output_dir, "pred"), exist_ok=True) + os.makedirs(os.path.join(args.output_dir, "gt"), exist_ok=True) + + # Set random seed + with self.accelerator.main_process_first(): + start = time.monotonic_ns() + self._set_random_seed(self.cfg.train.random_seed) + end = time.monotonic_ns() + self.logger.debug( + f"Setting random seed done in {(end - start) / 1e6:.2f}ms" + ) + self.logger.debug(f"Random seed: {self.cfg.train.random_seed}") + + # Setup inference mode + if self.infer_type == "infer_from_dataset": + self.cfg.dataset = self.args.infer_datasets + elif self.infer_type == "infer_from_feature": + self._build_tmp_dataset_from_feature() + self.cfg.dataset = ["tmp"] + elif self.infer_type == "infer_from_audio": + self._build_tmp_dataset_from_audio() + self.cfg.dataset = ["tmp"] + + # Setup data loader + with self.accelerator.main_process_first(): + self.logger.info("Building dataset...") + start = time.monotonic_ns() + self.test_dataloader = self._build_dataloader() + end = time.monotonic_ns() + self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms") + + # Build model + with self.accelerator.main_process_first(): + self.logger.info("Building model...") + start = time.monotonic_ns() + self.model = self._build_model() + end = time.monotonic_ns() + self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms") + + # Init with accelerate + self.logger.info("Initializing accelerate...") + start = time.monotonic_ns() + self.accelerator = accelerate.Accelerator() + (self.model, self.test_dataloader) = self.accelerator.prepare( + self.model, self.test_dataloader + ) + end = time.monotonic_ns() + self.accelerator.wait_for_everyone() + self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms") + + with self.accelerator.main_process_first(): + self.logger.info("Loading checkpoint...") + start = time.monotonic_ns() + if os.path.isdir(args.vocoder_dir): + if os.path.isdir(os.path.join(args.vocoder_dir, "checkpoint")): + self._load_model(os.path.join(args.vocoder_dir, "checkpoint")) + else: + self._load_model(os.path.join(args.vocoder_dir)) + else: + self._load_model(os.path.join(args.vocoder_dir)) + end = time.monotonic_ns() + self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms") + + self.model.eval() + self.accelerator.wait_for_everyone() + + def _build_tmp_dataset_from_feature(self): + if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")): + shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp")) + + utts = [] + mels = glob(os.path.join(self.args.feature_folder, "mels", "*.npy")) + for i, mel in enumerate(mels): + uid = mel.split("/")[-1].split(".")[0] + utt = {"Dataset": "tmp", "Uid": uid, "index": i} + utts.append(utt) + + os.makedirs(os.path.join(self.cfg.preprocess.processed_dir, "tmp")) + with open( + os.path.join(self.cfg.preprocess.processed_dir, "tmp", "test.json"), "w" + ) as f: + json.dump(utts, f) + + meta_info = {"dataset": "tmp", "test": {"size": len(utts)}} + + with open( + os.path.join(self.cfg.preprocess.processed_dir, "tmp", 
"meta_info.json"), + "w", + ) as f: + json.dump(meta_info, f) + + features = glob(os.path.join(self.args.feature_folder, "*")) + for feature in features: + feature_name = feature.split("/")[-1] + if os.path.isfile(feature): + continue + shutil.copytree( + os.path.join(self.args.feature_folder, feature_name), + os.path.join(self.cfg.preprocess.processed_dir, "tmp", feature_name), + ) + + def _build_tmp_dataset_from_audio(self): + if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")): + shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp")) + + utts = [] + audios = glob(os.path.join(self.args.audio_folder, "*")) + for i, audio in enumerate(audios): + uid = audio.split("/")[-1].split(".")[0] + utt = {"Dataset": "tmp", "Uid": uid, "index": i, "Path": audio} + utts.append(utt) + + os.makedirs(os.path.join(self.cfg.preprocess.processed_dir, "tmp")) + with open( + os.path.join(self.cfg.preprocess.processed_dir, "tmp", "test.json"), "w" + ) as f: + json.dump(utts, f) + + meta_info = {"dataset": "tmp", "test": {"size": len(utts)}} + + with open( + os.path.join(self.cfg.preprocess.processed_dir, "tmp", "meta_info.json"), + "w", + ) as f: + json.dump(meta_info, f) + + from processors import acoustic_extractor + + acoustic_extractor.extract_utt_acoustic_features_serial( + utts, os.path.join(self.cfg.preprocess.processed_dir, "tmp"), self.cfg + ) + + def _build_test_dataset(self): + return VocoderDataset, VocoderCollator + + def _build_model(self): + model = _vocoders[self.cfg.model.generator](self.cfg) + return model + + def _build_dataloader(self): + """Build dataloader which merges a series of datasets.""" + Dataset, Collator = self._build_test_dataset() + + datasets_list = [] + for dataset in self.cfg.dataset: + subdataset = Dataset(self.cfg, dataset, is_valid=True) + datasets_list.append(subdataset) + test_dataset = VocoderConcatDataset(datasets_list, full_audio_inference=False) + test_collate = Collator(self.cfg) + test_batch_size = min(self.cfg.inference.batch_size, len(test_dataset)) + test_dataloader = DataLoader( + test_dataset, + collate_fn=test_collate, + num_workers=1, + batch_size=test_batch_size, + shuffle=False, + ) + self.test_batch_size = test_batch_size + self.test_dataset = test_dataset + return test_dataloader + + def _load_model(self, checkpoint_dir, from_multi_gpu=False): + """Load model from checkpoint. If a folder is given, it will + load the latest checkpoint in checkpoint_dir. If a path is given + it will load the checkpoint specified by checkpoint_path. + **Only use this method after** ``accelerator.prepare()``. 
+ """ + if os.path.isdir(checkpoint_dir): + if "epoch" in checkpoint_dir and "step" in checkpoint_dir: + checkpoint_path = checkpoint_dir + else: + # Load the latest accelerator state dicts + ls = [ + str(i) + for i in Path(checkpoint_dir).glob("*") + if not "audio" in str(i) + ] + ls.sort( + key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True + ) + checkpoint_path = ls[0] + accelerate.load_checkpoint_and_dispatch( + self.accelerator.unwrap_model(self.model), + os.path.join(checkpoint_path, "pytorch_model.bin"), + ) + return str(checkpoint_path) + else: + # Load old .pt checkpoints + if self.cfg.model.generator in [ + "bigvgan", + "hifigan", + "melgan", + "nsfhifigan", + ]: + ckpt = torch.load( + checkpoint_dir, + map_location=torch.device("cuda") + if torch.cuda.is_available() + else torch.device("cpu"), + ) + if from_multi_gpu: + pretrained_generator_dict = ckpt["generator_state_dict"] + generator_dict = self.model.state_dict() + + new_generator_dict = { + k.split("module.")[-1]: v + for k, v in pretrained_generator_dict.items() + if ( + k.split("module.")[-1] in generator_dict + and v.shape == generator_dict[k.split("module.")[-1]].shape + ) + } + + generator_dict.update(new_generator_dict) + + self.model.load_state_dict(generator_dict) + else: + self.model.load_state_dict(ckpt["generator_state_dict"]) + else: + self.model.load_state_dict(torch.load(checkpoint_dir)["state_dict"]) + return str(checkpoint_dir) + + def inference(self): + """Inference via batches""" + for i, batch in tqdm(enumerate(self.test_dataloader)): + if self.cfg.preprocess.use_frame_pitch: + audio_pred = self.model.forward( + batch["mel"].transpose(-1, -2), batch["frame_pitch"].float() + ).cpu() + elif self.cfg.preprocess.extract_amplitude_phase: + audio_pred = self.model.forward(batch["mel"].transpose(-1, -2))[-1] + else: + audio_pred = ( + self.model.forward(batch["mel"].transpose(-1, -2)).detach().cpu() + ) + audio_ls = audio_pred.chunk(self.test_batch_size) + audio_gt_ls = batch["audio"].cpu().chunk(self.test_batch_size) + length_ls = batch["target_len"].cpu().chunk(self.test_batch_size) + j = 0 + for it, it_gt, l in zip(audio_ls, audio_gt_ls, length_ls): + l = l.item() + it = it.squeeze(0).squeeze(0)[: l * self.cfg.preprocess.hop_size] + it_gt = it_gt.squeeze(0)[: l * self.cfg.preprocess.hop_size] + uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"] + save_audio( + os.path.join(self.args.output_dir, "pred", "{}.wav").format(uid), + it, + self.cfg.preprocess.sample_rate, + ) + save_audio( + os.path.join(self.args.output_dir, "gt", "{}.wav").format(uid), + it_gt, + self.cfg.preprocess.sample_rate, + ) + j += 1 + + if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")): + shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp")) + + def _set_random_seed(self, seed): + """Set random seed for all possible random modules.""" + random.seed(seed) + np.random.seed(seed) + torch.random.manual_seed(seed) + + def _count_parameters(self, model): + return sum(p.numel() for p in model.parameters()) + + def _dump_cfg(self, path): + os.makedirs(os.path.dirname(path), exist_ok=True) + json5.dump( + self.cfg, + open(path, "w"), + indent=4, + sort_keys=True, + ensure_ascii=False, + quote_keys=True, + ) + + +def load_nnvocoder( + cfg, + vocoder_name, + weights_file, + from_multi_gpu=False, +): + """Load the specified vocoder. + cfg: the vocoder config filer. + weights_file: a folder or a .pt path. 
+ from_multi_gpu: automatically remove the "module" string in state dicts if "True". + """ + print("Loading Vocoder from Weights file: {}".format(weights_file)) + + # Build model + model = _vocoders[vocoder_name](cfg) + if not os.path.isdir(weights_file): + # Load from .pt file + if vocoder_name in ["bigvgan", "hifigan", "melgan", "nsfhifigan"]: + ckpt = torch.load( + weights_file, + map_location=torch.device("cuda") + if torch.cuda.is_available() + else torch.device("cpu"), + ) + if from_multi_gpu: + pretrained_generator_dict = ckpt["generator_state_dict"] + generator_dict = model.state_dict() + + new_generator_dict = { + k.split("module.")[-1]: v + for k, v in pretrained_generator_dict.items() + if ( + k.split("module.")[-1] in generator_dict + and v.shape == generator_dict[k.split("module.")[-1]].shape + ) + } + + generator_dict.update(new_generator_dict) + + model.load_state_dict(generator_dict) + else: + model.load_state_dict(ckpt["generator_state_dict"]) + else: + model.load_state_dict(torch.load(weights_file)["state_dict"]) + else: + # Load from accelerator state dict + weights_file = os.path.join(weights_file, "checkpoint") + ls = [str(i) for i in Path(weights_file).glob("*") if not "audio" in str(i)] + ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True) + checkpoint_path = ls[0] + accelerator = accelerate.Accelerator() + model = accelerator.prepare(model) + accelerator.load_state(checkpoint_path) + + if torch.cuda.is_available(): + model = model.cuda() + + model = model.eval() + return model + + +def tensorize(data, device, n_samples): + """ + data: a list of numpy array + """ + assert type(data) == list + if n_samples: + data = data[:n_samples] + data = [torch.as_tensor(x, device=device) for x in data] + return data + + +def synthesis( + cfg, + vocoder_weight_file, + n_samples, + pred, + f0s=None, + batch_size=64, + fast_inference=False, +): + """Synthesis audios from a given vocoder and series of given features. + cfg: vocoder config. + vocoder_weight_file: a folder of accelerator state dict or a path to the .pt file. + pred: a list of numpy arrays. [(seq_len1, acoustic_features_dim), (seq_len2, acoustic_features_dim), ...] + """ + + vocoder_name = cfg.model.generator + + print("Synthesis audios using {} vocoder...".format(vocoder_name)) + + ###### TODO: World Vocoder Refactor ###### + # if vocoder_name == "world": + # world_inference.synthesis_audios( + # cfg, dataset_name, split, n_samples, pred, save_dir, tag + # ) + # return + + # ====== Loading neural vocoder model ====== + vocoder = load_nnvocoder( + cfg, vocoder_name, weights_file=vocoder_weight_file, from_multi_gpu=True + ) + device = next(vocoder.parameters()).device + + # ====== Inference for predicted acoustic features ====== + # pred: (frame_len, n_mels) -> (n_mels, frame_len) + mels_pred = tensorize([p.T for p in pred], device, n_samples) + print("For predicted mels, #sample = {}...".format(len(mels_pred))) + audios_pred = _vocoder_infer_funcs[vocoder_name]( + cfg, + vocoder, + mels_pred, + f0s=f0s, + batch_size=batch_size, + fast_inference=fast_inference, + ) + return audios_pred diff --git a/models/vocoders/vocoder_sampler.py b/models/vocoders/vocoder_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..9d29f88a291dcf7386cadaeae0d990c8e76ebf98 --- /dev/null +++ b/models/vocoders/vocoder_sampler.py @@ -0,0 +1,126 @@ +# Copyright (c) 2023 Amphion. 
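To show how the synthesis() entry point above is meant to be driven, here is a hedged usage sketch. The checkpoint path, config path, and mel shapes are placeholders; utils.util.load_config is assumed to be the repo's config loader (an assumption, not confirmed by this diff), and save_audio is the helper already imported by the inference module.

# Illustrative sketch (not part of the diff): synthesizing audio from predicted mels.
import numpy as np

from models.vocoders.vocoder_inference import synthesis
from utils.io import save_audio
from utils.util import load_config  # assumption: Amphion-style config loader

# The config must expose cfg.model.generator (a GAN vocoder such as "hifigan" or
# "bigvgan", since only those are registered in _vocoder_infer_funcs above) plus
# the preprocess fields used by the inference function. The path is a placeholder.
cfg = load_config("egs/vocoder/gan/hifigan/exp_config.json")

# Predicted acoustic features: a list of (seq_len, n_mels) numpy arrays,
# exactly the layout synthesis() documents for `pred`.
pred_mels = [
    np.random.randn(200, cfg.preprocess.n_mel).astype(np.float32),
    np.random.randn(150, cfg.preprocess.n_mel).astype(np.float32),
]

audios = synthesis(
    cfg,
    vocoder_weight_file="ckpt/vocoder_exp",  # accelerator checkpoint folder or a .pt file
    n_samples=None,                          # keep every sample in `pred`
    pred=pred_mels,
    f0s=None,                                # only needed for pitch-conditioned vocoders (e.g. nsfhifigan)
    batch_size=16,
)

for i, wav in enumerate(audios):
    save_audio("pred_{}.wav".format(i), wav, cfg.preprocess.sample_rate)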
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +import random + +from torch.utils.data import ConcatDataset, Dataset +from torch.utils.data.sampler import ( + BatchSampler, + RandomSampler, + Sampler, + SequentialSampler, +) + + +class ScheduledSampler(Sampler): + """A sampler that samples data from a given concat-dataset. + + Args: + concat_dataset (ConcatDataset): a concatenated dataset consisting of all datasets + batch_size (int): batch size + holistic_shuffle (bool): whether to shuffle the whole dataset or not + logger (logging.Logger): logger to print warning message + + Usage: + For cfg.train.batch_size = 3, cfg.train.holistic_shuffle = False, cfg.train.drop_last = True: + >>> list(ScheduledSampler(ConcatDataset([0, 1, 2], [3, 4, 5], [6, 7, 8]]))) + [3, 4, 5, 0, 1, 2, 6, 7, 8] + """ + + def __init__( + self, concat_dataset, batch_size, holistic_shuffle, logger=None, type="train" + ): + if not isinstance(concat_dataset, ConcatDataset): + raise ValueError( + "concat_dataset must be an instance of ConcatDataset, but got {}".format( + type(concat_dataset) + ) + ) + if not isinstance(batch_size, int): + raise ValueError( + "batch_size must be an integer, but got {}".format(type(batch_size)) + ) + if not isinstance(holistic_shuffle, bool): + raise ValueError( + "holistic_shuffle must be a boolean, but got {}".format( + type(holistic_shuffle) + ) + ) + + self.concat_dataset = concat_dataset + self.batch_size = batch_size + self.holistic_shuffle = holistic_shuffle + + affected_dataset_name = [] + affected_dataset_len = [] + for dataset in concat_dataset.datasets: + dataset_len = len(dataset) + dataset_name = dataset.get_dataset_name() + if dataset_len < batch_size: + affected_dataset_name.append(dataset_name) + affected_dataset_len.append(dataset_len) + + self.type = type + for dataset_name, dataset_len in zip( + affected_dataset_name, affected_dataset_len + ): + if not type == "valid": + logger.warning( + "The {} dataset {} has a length of {}, which is smaller than the batch size {}. 
This may cause unexpected behavior.".format( + type, dataset_name, dataset_len, batch_size + ) + ) + + def __len__(self): + # the number of batches with drop last + num_of_batches = sum( + [ + math.floor(len(dataset) / self.batch_size) + for dataset in self.concat_dataset.datasets + ] + ) + return num_of_batches * self.batch_size + + def __iter__(self): + iters = [] + for dataset in self.concat_dataset.datasets: + iters.append( + SequentialSampler(dataset).__iter__() + if self.holistic_shuffle + else RandomSampler(dataset).__iter__() + ) + init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1] + output_batches = [] + for dataset_idx in range(len(self.concat_dataset.datasets)): + cur_batch = [] + for idx in iters[dataset_idx]: + cur_batch.append(idx + init_indices[dataset_idx]) + if len(cur_batch) == self.batch_size: + output_batches.append(cur_batch) + cur_batch = [] + if self.type == "valid" and len(cur_batch) > 0: + output_batches.append(cur_batch) + cur_batch = [] + # force drop last in training + random.shuffle(output_batches) + output_indices = [item for sublist in output_batches for item in sublist] + return iter(output_indices) + + +def build_samplers(concat_dataset: Dataset, cfg, logger, type): + sampler = ScheduledSampler( + concat_dataset, + cfg.train.batch_size, + cfg.train.sampler.holistic_shuffle, + logger, + type, + ) + batch_sampler = BatchSampler( + sampler, + cfg.train.batch_size, + cfg.train.sampler.drop_last if not type == "valid" else False, + ) + return sampler, batch_sampler diff --git a/models/vocoders/vocoder_trainer.py b/models/vocoders/vocoder_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..5821e735a64f07fcf9c782712670e24ce6a91c04 --- /dev/null +++ b/models/vocoders/vocoder_trainer.py @@ -0,0 +1,180 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
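The Usage note in the ScheduledSampler docstring above is only schematic (it omits the required batch_size and holistic_shuffle arguments), so here is a small self-contained sketch with toy datasets showing the property the sampler actually guarantees: indices are emitted in dataset-pure, batch-sized groups, so every batch drawn through BatchSampler comes from a single sub-dataset. ToyDataset is a stand-in invented for this sketch; it only provides the two things the sampler needs (__len__ and get_dataset_name).

# Illustrative sketch (not part of the diff): ScheduledSampler on toy datasets.
import logging

from torch.utils.data import ConcatDataset, Dataset
from torch.utils.data.sampler import BatchSampler

from models.vocoders.vocoder_sampler import ScheduledSampler


class ToyDataset(Dataset):
    def __init__(self, name, size):
        self.name, self.size = name, size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return idx

    def get_dataset_name(self):
        return self.name


concat = ConcatDataset([ToyDataset("a", 7), ToyDataset("b", 5)])
sampler = ScheduledSampler(
    concat, batch_size=3, holistic_shuffle=False,
    logger=logging.getLogger("toy"), type="train",
)

# Indices 0-6 belong to "a" and 7-11 to "b" (offsets come from cumulative_sizes),
# and every 3-element batch contains indices from exactly one of them.
batches = list(BatchSampler(sampler, batch_size=3, drop_last=True))
print(batches)       # e.g. [[7, 9, 8], [2, 0, 5], [4, 6, 1]] (order is reshuffled each epoch)
print(len(sampler))  # 9: (floor(7/3) + floor(5/3)) * batch_size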
+ +import os +import random +from pathlib import Path +import re + +import accelerate +import json5 +import numpy as np +import torch +from accelerate.utils import ProjectConfiguration +from torch.utils.data import DataLoader +from tqdm import tqdm + +from models.vocoders.vocoder_dataset import VocoderConcatDataset +from models.vocoders.vocoder_sampler import build_samplers + + +class VocoderTrainer: + def __init__(self): + super().__init__() + + def _init_accelerator(self): + """Initialize the accelerator components.""" + self.exp_dir = os.path.join( + os.path.abspath(self.cfg.log_dir), self.args.exp_name + ) + project_config = ProjectConfiguration( + project_dir=self.exp_dir, logging_dir=os.path.join(self.exp_dir, "log") + ) + self.accelerator = accelerate.Accelerator( + gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step, + log_with=self.cfg.train.tracker, + project_config=project_config, + ) + if self.accelerator.is_main_process: + os.makedirs(project_config.project_dir, exist_ok=True) + os.makedirs(project_config.logging_dir, exist_ok=True) + with self.accelerator.main_process_first(): + self.accelerator.init_trackers(self.args.exp_name) + + def _build_dataset(self): + pass + + def _build_criterion(self): + pass + + def _build_model(self): + pass + + def _build_dataloader(self): + """Build dataloader which merges a series of datasets.""" + # Build dataset instance for each dataset and combine them by ConcatDataset + Dataset, Collator = self._build_dataset() + + # Build train set + datasets_list = [] + for dataset in self.cfg.dataset: + subdataset = Dataset(self.cfg, dataset, is_valid=False) + datasets_list.append(subdataset) + train_dataset = VocoderConcatDataset(datasets_list, full_audio_inference=True) + train_collate = Collator(self.cfg) + _, batch_sampler = build_samplers(train_dataset, self.cfg, self.logger, "train") + train_loader = DataLoader( + train_dataset, + collate_fn=train_collate, + batch_sampler=batch_sampler, + num_workers=self.cfg.train.dataloader.num_worker, + pin_memory=self.cfg.train.dataloader.pin_memory, + ) + + # Build test set + datasets_list = [] + for dataset in self.cfg.dataset: + subdataset = Dataset(self.cfg, dataset, is_valid=True) + datasets_list.append(subdataset) + valid_dataset = VocoderConcatDataset(datasets_list, full_audio_inference=True) + valid_collate = Collator(self.cfg) + _, batch_sampler = build_samplers(valid_dataset, self.cfg, self.logger, "train") + valid_loader = DataLoader( + valid_dataset, + collate_fn=valid_collate, + batch_sampler=batch_sampler, + num_workers=self.cfg.train.dataloader.num_worker, + pin_memory=self.cfg.train.dataloader.pin_memory, + ) + return train_loader, valid_loader + + def _build_optimizer(self): + pass + + def _build_scheduler(self): + pass + + def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"): + """Load model from checkpoint. If a folder is given, it will + load the latest checkpoint in checkpoint_dir. If a path is given + it will load the checkpoint specified by checkpoint_path. + **Only use this method after** ``accelerator.prepare()``. 
+ """ + if checkpoint_path is None: + ls = [str(i) for i in Path(checkpoint_dir).glob("*")] + ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True) + checkpoint_path = ls[0] + if resume_type == "resume": + self.accelerator.load_state(checkpoint_path) + elif resume_type == "finetune": + accelerate.load_checkpoint_and_dispatch( + self.accelerator.unwrap_model(self.model), + os.path.join(checkpoint_path, "pytorch_model.bin"), + ) + self.logger.info("Load model weights for finetune SUCCESS!") + else: + raise ValueError("Unsupported resume type: {}".format(resume_type)) + self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1 + self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1 + return checkpoint_path + + def train_loop(self): + pass + + def _train_epoch(self): + pass + + def _valid_epoch(self): + pass + + def _train_step(self): + pass + + def _valid_step(self): + pass + + def _inference(self): + pass + + def _set_random_seed(self, seed): + """Set random seed for all possible random modules.""" + random.seed(seed) + np.random.seed(seed) + torch.random.manual_seed(seed) + + def _check_nan(self, loss): + if torch.any(torch.isnan(loss)): + self.logger.fatal("Fatal Error: NaN!") + self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True) + + def _check_basic_configs(self): + if self.cfg.train.gradient_accumulation_step <= 0: + self.logger.fatal("Invalid gradient_accumulation_step value!") + self.logger.error( + f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive." + ) + self.accelerator.end_training() + raise ValueError( + f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive." + ) + + def _count_parameters(self): + pass + + def _dump_cfg(self, path): + os.makedirs(os.path.dirname(path), exist_ok=True) + json5.dump( + self.cfg, + open(path, "w"), + indent=4, + sort_keys=True, + ensure_ascii=False, + quote_keys=True, + ) + + def _is_valid_pattern(self, directory_name): + directory_name = str(directory_name) + pattern = r"^epoch-\d{4}_step-\d{7}_loss-\d{1}\.\d{6}" + return re.match(pattern, directory_name) is not None