# Copyright (c) 2023 Amphion. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import os import torch import json import json5 import time import accelerate import random import numpy as np import shutil from pathlib import Path from tqdm import tqdm from glob import glob from accelerate.logging import get_logger from torch.utils.data import DataLoader from models.vocoders.vocoder_dataset import ( VocoderDataset, VocoderCollator, VocoderConcatDataset, ) from models.vocoders.gan.generator import bigvgan, hifigan, melgan, nsfhifigan, apnet from models.vocoders.flow.waveglow import waveglow from models.vocoders.diffusion.diffwave import diffwave from models.vocoders.autoregressive.wavenet import wavenet from models.vocoders.autoregressive.wavernn import wavernn from models.vocoders.gan import gan_vocoder_inference from utils.io import save_audio _vocoders = { "diffwave": diffwave.DiffWave, "wavernn": wavernn.WaveRNN, "wavenet": wavenet.WaveNet, "waveglow": waveglow.WaveGlow, "nsfhifigan": nsfhifigan.NSFHiFiGAN, "bigvgan": bigvgan.BigVGAN, "hifigan": hifigan.HiFiGAN, "melgan": melgan.MelGAN, "apnet": apnet.APNet, } _vocoder_infer_funcs = { # "world": world_inference.synthesis_audios, # "wavernn": wavernn_inference.synthesis_audios, # "wavenet": wavenet_inference.synthesis_audios, # "diffwave": diffwave_inference.synthesis_audios, "nsfhifigan": gan_vocoder_inference.synthesis_audios, "bigvgan": gan_vocoder_inference.synthesis_audios, "melgan": gan_vocoder_inference.synthesis_audios, "hifigan": gan_vocoder_inference.synthesis_audios, "apnet": gan_vocoder_inference.synthesis_audios, } class VocoderInference(object): def __init__(self, args=None, cfg=None, infer_type="from_dataset"): super().__init__() start = time.monotonic_ns() self.args = args self.cfg = cfg self.infer_type = infer_type # Init accelerator self.accelerator = accelerate.Accelerator() self.accelerator.wait_for_everyone() # Get logger with self.accelerator.main_process_first(): self.logger = get_logger("inference", log_level=args.log_level) # Log some info self.logger.info("=" * 56) self.logger.info("||\t\t" + "New inference process started." + "\t\t||") self.logger.info("=" * 56) self.logger.info("\n") self.vocoder_dir = args.vocoder_dir self.logger.debug(f"Vocoder dir: {args.vocoder_dir}") os.makedirs(args.output_dir, exist_ok=True) if os.path.exists(os.path.join(args.output_dir, "pred")): shutil.rmtree(os.path.join(args.output_dir, "pred")) if os.path.exists(os.path.join(args.output_dir, "gt")): shutil.rmtree(os.path.join(args.output_dir, "gt")) os.makedirs(os.path.join(args.output_dir, "pred"), exist_ok=True) os.makedirs(os.path.join(args.output_dir, "gt"), exist_ok=True) # Set random seed with self.accelerator.main_process_first(): start = time.monotonic_ns() self._set_random_seed(self.cfg.train.random_seed) end = time.monotonic_ns() self.logger.debug( f"Setting random seed done in {(end - start) / 1e6:.2f}ms" ) self.logger.debug(f"Random seed: {self.cfg.train.random_seed}") # Setup inference mode if self.infer_type == "infer_from_dataset": self.cfg.dataset = self.args.infer_datasets elif self.infer_type == "infer_from_feature": self._build_tmp_dataset_from_feature() self.cfg.dataset = ["tmp"] elif self.infer_type == "infer_from_audio": self._build_tmp_dataset_from_audio() self.cfg.dataset = ["tmp"] # Setup data loader with self.accelerator.main_process_first(): self.logger.info("Building dataset...") start = time.monotonic_ns() self.test_dataloader = self._build_dataloader() end = time.monotonic_ns() self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms") # Build model with self.accelerator.main_process_first(): self.logger.info("Building model...") start = time.monotonic_ns() self.model = self._build_model() end = time.monotonic_ns() self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms") # Init with accelerate self.logger.info("Initializing accelerate...") start = time.monotonic_ns() self.accelerator = accelerate.Accelerator() (self.model, self.test_dataloader) = self.accelerator.prepare( self.model, self.test_dataloader ) end = time.monotonic_ns() self.accelerator.wait_for_everyone() self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms") with self.accelerator.main_process_first(): self.logger.info("Loading checkpoint...") start = time.monotonic_ns() if os.path.isdir(args.vocoder_dir): if os.path.isdir(os.path.join(args.vocoder_dir, "checkpoint")): self._load_model(os.path.join(args.vocoder_dir, "checkpoint")) else: self._load_model(os.path.join(args.vocoder_dir)) else: self._load_model(os.path.join(args.vocoder_dir)) end = time.monotonic_ns() self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms") self.model.eval() self.accelerator.wait_for_everyone() def _build_tmp_dataset_from_feature(self): if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")): shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp")) utts = [] mels = glob(os.path.join(self.args.feature_folder, "mels", "*.npy")) for i, mel in enumerate(mels): uid = mel.split("/")[-1].split(".")[0] utt = {"Dataset": "tmp", "Uid": uid, "index": i} utts.append(utt) os.makedirs(os.path.join(self.cfg.preprocess.processed_dir, "tmp")) with open( os.path.join(self.cfg.preprocess.processed_dir, "tmp", "test.json"), "w" ) as f: json.dump(utts, f) meta_info = {"dataset": "tmp", "test": {"size": len(utts)}} with open( os.path.join(self.cfg.preprocess.processed_dir, "tmp", "meta_info.json"), "w", ) as f: json.dump(meta_info, f) features = glob(os.path.join(self.args.feature_folder, "*")) for feature in features: feature_name = feature.split("/")[-1] if os.path.isfile(feature): continue shutil.copytree( os.path.join(self.args.feature_folder, feature_name), os.path.join(self.cfg.preprocess.processed_dir, "tmp", feature_name), ) def _build_tmp_dataset_from_audio(self): if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")): shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp")) utts = [] audios = glob(os.path.join(self.args.audio_folder, "*")) for i, audio in enumerate(audios): uid = audio.split("/")[-1].split(".")[0] utt = {"Dataset": "tmp", "Uid": uid, "index": i, "Path": audio} utts.append(utt) os.makedirs(os.path.join(self.cfg.preprocess.processed_dir, "tmp")) with open( os.path.join(self.cfg.preprocess.processed_dir, "tmp", "test.json"), "w" ) as f: json.dump(utts, f) meta_info = {"dataset": "tmp", "test": {"size": len(utts)}} with open( os.path.join(self.cfg.preprocess.processed_dir, "tmp", "meta_info.json"), "w", ) as f: json.dump(meta_info, f) from processors import acoustic_extractor acoustic_extractor.extract_utt_acoustic_features_serial( utts, os.path.join(self.cfg.preprocess.processed_dir, "tmp"), self.cfg ) def _build_test_dataset(self): return VocoderDataset, VocoderCollator def _build_model(self): model = _vocoders[self.cfg.model.generator](self.cfg) return model def _build_dataloader(self): """Build dataloader which merges a series of datasets.""" Dataset, Collator = self._build_test_dataset() datasets_list = [] for dataset in self.cfg.dataset: subdataset = Dataset(self.cfg, dataset, is_valid=True) datasets_list.append(subdataset) test_dataset = VocoderConcatDataset(datasets_list, full_audio_inference=False) test_collate = Collator(self.cfg) test_batch_size = min(self.cfg.inference.batch_size, len(test_dataset)) test_dataloader = DataLoader( test_dataset, collate_fn=test_collate, num_workers=1, batch_size=test_batch_size, shuffle=False, ) self.test_batch_size = test_batch_size self.test_dataset = test_dataset return test_dataloader def _load_model(self, checkpoint_dir, from_multi_gpu=False): """Load model from checkpoint. If a folder is given, it will load the latest checkpoint in checkpoint_dir. If a path is given it will load the checkpoint specified by checkpoint_path. **Only use this method after** ``accelerator.prepare()``. """ if os.path.isdir(checkpoint_dir): if "epoch" in checkpoint_dir and "step" in checkpoint_dir: checkpoint_path = checkpoint_dir else: # Load the latest accelerator state dicts ls = [ str(i) for i in Path(checkpoint_dir).glob("*") if not "audio" in str(i) ] ls.sort( key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True ) checkpoint_path = ls[0] accelerate.load_checkpoint_and_dispatch( self.accelerator.unwrap_model(self.model), os.path.join(checkpoint_path, "pytorch_model.bin"), ) return str(checkpoint_path) else: # Load old .pt checkpoints if self.cfg.model.generator in [ "bigvgan", "hifigan", "melgan", "nsfhifigan", ]: ckpt = torch.load( checkpoint_dir, map_location=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"), ) if from_multi_gpu: pretrained_generator_dict = ckpt["generator_state_dict"] generator_dict = self.model.state_dict() new_generator_dict = { k.split("module.")[-1]: v for k, v in pretrained_generator_dict.items() if ( k.split("module.")[-1] in generator_dict and v.shape == generator_dict[k.split("module.")[-1]].shape ) } generator_dict.update(new_generator_dict) self.model.load_state_dict(generator_dict) else: self.model.load_state_dict(ckpt["generator_state_dict"]) else: self.model.load_state_dict(torch.load(checkpoint_dir)["state_dict"]) return str(checkpoint_dir) def inference(self): """Inference via batches""" for i, batch in tqdm(enumerate(self.test_dataloader)): if self.cfg.preprocess.use_frame_pitch: audio_pred = self.model.forward( batch["mel"].transpose(-1, -2), batch["frame_pitch"].float() ).cpu() elif self.cfg.preprocess.extract_amplitude_phase: audio_pred = self.model.forward(batch["mel"].transpose(-1, -2))[-1] else: audio_pred = ( self.model.forward(batch["mel"].transpose(-1, -2)).detach().cpu() ) audio_ls = audio_pred.chunk(self.test_batch_size) audio_gt_ls = batch["audio"].cpu().chunk(self.test_batch_size) length_ls = batch["target_len"].cpu().chunk(self.test_batch_size) j = 0 for it, it_gt, l in zip(audio_ls, audio_gt_ls, length_ls): l = l.item() it = it.squeeze(0).squeeze(0)[: l * self.cfg.preprocess.hop_size] it_gt = it_gt.squeeze(0)[: l * self.cfg.preprocess.hop_size] uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"] save_audio( os.path.join(self.args.output_dir, "pred", "{}.wav").format(uid), it, self.cfg.preprocess.sample_rate, ) save_audio( os.path.join(self.args.output_dir, "gt", "{}.wav").format(uid), it_gt, self.cfg.preprocess.sample_rate, ) j += 1 if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")): shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp")) def _set_random_seed(self, seed): """Set random seed for all possible random modules.""" random.seed(seed) np.random.seed(seed) torch.random.manual_seed(seed) def _count_parameters(self, model): return sum(p.numel() for p in model.parameters()) def _dump_cfg(self, path): os.makedirs(os.path.dirname(path), exist_ok=True) json5.dump( self.cfg, open(path, "w"), indent=4, sort_keys=True, ensure_ascii=False, quote_keys=True, ) def load_nnvocoder( cfg, vocoder_name, weights_file, from_multi_gpu=False, ): """Load the specified vocoder. cfg: the vocoder config filer. weights_file: a folder or a .pt path. from_multi_gpu: automatically remove the "module" string in state dicts if "True". """ print("Loading Vocoder from Weights file: {}".format(weights_file)) # Build model model = _vocoders[vocoder_name](cfg) if not os.path.isdir(weights_file): # Load from .pt file if vocoder_name in ["bigvgan", "hifigan", "melgan", "nsfhifigan"]: ckpt = torch.load( weights_file, map_location=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"), ) if from_multi_gpu: pretrained_generator_dict = ckpt["generator_state_dict"] generator_dict = model.state_dict() new_generator_dict = { k.split("module.")[-1]: v for k, v in pretrained_generator_dict.items() if ( k.split("module.")[-1] in generator_dict and v.shape == generator_dict[k.split("module.")[-1]].shape ) } generator_dict.update(new_generator_dict) model.load_state_dict(generator_dict) else: model.load_state_dict(ckpt["generator_state_dict"]) else: model.load_state_dict(torch.load(weights_file)["state_dict"]) else: # Load from accelerator state dict weights_file = os.path.join(weights_file, "checkpoint") ls = [str(i) for i in Path(weights_file).glob("*") if not "audio" in str(i)] ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True) checkpoint_path = ls[0] accelerator = accelerate.Accelerator() model = accelerator.prepare(model) accelerator.load_state(checkpoint_path) if torch.cuda.is_available(): model = model.cuda() model = model.eval() return model def tensorize(data, device, n_samples): """ data: a list of numpy array """ assert type(data) == list if n_samples: data = data[:n_samples] data = [torch.as_tensor(x, device=device) for x in data] return data def synthesis( cfg, vocoder_weight_file, n_samples, pred, f0s=None, batch_size=64, fast_inference=False, ): """Synthesis audios from a given vocoder and series of given features. cfg: vocoder config. vocoder_weight_file: a folder of accelerator state dict or a path to the .pt file. pred: a list of numpy arrays. [(seq_len1, acoustic_features_dim), (seq_len2, acoustic_features_dim), ...] """ vocoder_name = cfg.model.generator print("Synthesis audios using {} vocoder...".format(vocoder_name)) ###### TODO: World Vocoder Refactor ###### # if vocoder_name == "world": # world_inference.synthesis_audios( # cfg, dataset_name, split, n_samples, pred, save_dir, tag # ) # return # ====== Loading neural vocoder model ====== vocoder = load_nnvocoder( cfg, vocoder_name, weights_file=vocoder_weight_file, from_multi_gpu=True ) device = next(vocoder.parameters()).device # ====== Inference for predicted acoustic features ====== # pred: (frame_len, n_mels) -> (n_mels, frame_len) mels_pred = tensorize([p.T for p in pred], device, n_samples) print("For predicted mels, #sample = {}...".format(len(mels_pred))) audios_pred = _vocoder_infer_funcs[vocoder_name]( cfg, vocoder, mels_pred, f0s=f0s, batch_size=batch_size, fast_inference=fast_inference, ) return audios_pred