import logging
import os
import pathlib
import random
import sys
from typing import Dict

import lightning.pytorch as pl
import matplotlib
import numpy as np
import torch.utils.data
from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_only
from matplotlib import pyplot as plt
from torch import nn
from torch.utils.data import Dataset
from torchmetrics import Metric, MeanMetric

import utils
from models.univnet.univnet import UnivNet
# from models.lvc_ddspgan.lvc_ddspgan import DDSPgan
# from models.nsf_HiFigan.models import Generator, AttrDict, MultiScaleDiscriminator, MultiPeriodDiscriminator
from modules.loss.univloss import univloss
from modules.univ_D.discriminator import MultiPeriodDiscriminator, MultiResSpecDiscriminator
from training.base_task_gan import GanBaseTask
from utils.wav2mel import PitchAdjustableMelSpectrogram


def spec_to_figure(spec, vmin=None, vmax=None):
    if isinstance(spec, torch.Tensor):
        spec = spec.cpu().numpy()
    fig = plt.figure(figsize=(12, 9), dpi=100)
    plt.pcolor(spec.T, vmin=vmin, vmax=vmax)
    plt.tight_layout()
    return fig


class nsf_HiFigan_dataset(Dataset):
    def __init__(self, config: dict, data_dir, infer=False):
        super().__init__()
        self.config = config
        self.data_dir = data_dir if isinstance(data_dir, pathlib.Path) else pathlib.Path(data_dir)
        with open(self.data_dir, 'r', encoding='utf8') as f:
            fills = f.read().strip().split('\n')
        self.data_index = fills
        self.infer = infer
        self.volume_aug = self.config['volume_aug']
        self.volume_aug_prob = self.config['volume_aug_prob'] if not infer else 0

    def __getitem__(self, index):
        data_path = self.data_index[index]
        data = np.load(data_path)
        return {'f0': data['f0'], 'spectrogram': data['mel'], 'audio': data['audio']}

    def __len__(self):
        return len(self.data_index)

    def collater(self, minibatch):
        samples_per_frame = self.config['hop_size']
        if self.infer:
            crop_mel_frames = 0
        else:
            crop_mel_frames = self.config['crop_mel_frames']
        for record in minibatch:
            # Filter out records that aren't long enough.
            if len(record['spectrogram']) < crop_mel_frames:
                del record['spectrogram']
                del record['audio']
                del record['f0']
                continue
            # Pick a random window of crop_mel_frames frames. Sampling up to
            # shape[0] - crop_mel_frames (rather than shape[0] - 1 - crop_mel_frames)
            # keeps the last frame reachable and avoids a ValueError when the
            # record is exactly crop_mel_frames long.
            start = random.randint(0, record['spectrogram'].shape[0] - crop_mel_frames)
            end = start + crop_mel_frames
            if self.infer:
                record['spectrogram'] = record['spectrogram'].T
            else:
                record['spectrogram'] = record['spectrogram'][start:end].T
                record['f0'] = record['f0'][start:end]

            # Convert the frame window to sample indices.
            start *= samples_per_frame
            end *= samples_per_frame
            if self.infer:
                # Trim the audio to a whole number of frames, then zero-pad it
                # back so its length is exactly frames * hop_size.
                cty = len(record['spectrogram'].T) * samples_per_frame
                record['audio'] = record['audio'][:cty]
                record['audio'] = np.pad(
                    record['audio'],
                    (0, len(record['spectrogram'].T) * samples_per_frame - len(record['audio'])),
                    mode='constant')
            else:
                record['audio'] = record['audio'][start:end]
                record['audio'] = np.pad(record['audio'], (0, (end - start) - len(record['audio'])),
                                         mode='constant')

        if self.volume_aug:
            for record in minibatch:
                if random.random() < self.volume_aug_prob:
                    audio = record['audio']
                    audio_mel = record['spectrogram']
                    # Random gain in the natural-log domain. Mel magnitudes are
                    # linear in waveform amplitude, so scaling the audio by
                    # exp(shift) adds the same shift to a natural-log mel.
                    # max_shift keeps the scaled audio from exceeding 1.0.
                    max_amp = float(np.max(np.abs(audio))) + 1e-5
                    max_shift = min(3, np.log(1 / max_amp))
                    log_mel_shift = random.uniform(-3, max_shift)
                    audio *= np.exp(log_mel_shift)
                    audio_mel += log_mel_shift
                    audio_mel = np.clip(audio_mel, np.log(1e-5), None)
                    record['audio'] = audio
                    record['spectrogram'] = audio_mel

        audio = np.stack([record['audio'] for record in minibatch if 'audio' in record])
        spectrogram = np.stack([record['spectrogram'] for record in minibatch if 'spectrogram' in record])
        f0 = np.stack([record['f0'] for record in minibatch if 'f0' in record])
        return {
            'audio': torch.from_numpy(audio).unsqueeze(1),
            'mel': torch.from_numpy(spectrogram),
            'f0': torch.from_numpy(f0),
        }
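
# A sketch of how this dataset is wired into a torch DataLoader: batching has
# to go through `collater` (as `collate_fn`), since cropping, padding, and
# volume augmentation happen per batch rather than per item. The config
# values and index path below are illustrative assumptions, not repo values:
#
#     from torch.utils.data import DataLoader
#
#     cfg = {'hop_size': 512, 'crop_mel_frames': 64,
#            'volume_aug': True, 'volume_aug_prob': 0.5}
#     ds = nsf_HiFigan_dataset(config=cfg, data_dir='data/train.txt')
#     dl = DataLoader(ds, batch_size=8, shuffle=True, collate_fn=ds.collater)
#     # each batch: audio (B, 1, T), mel (B, n_mels, frames), f0 (B, frames)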
class stftlog:
    def __init__(self,
                 n_fft=2048,
                 win_length=2048,
                 hop_length=512,
                 center=False):
        self.n_fft = n_fft
        self.win_size = win_length
        self.hop_length = hop_length
        self.center = center
        # One cached Hann window per device, keyed by device string.
        self.hann_window = {}

    def exc(self, y):
        hann_window_key = f"{y.device}"
        if hann_window_key not in self.hann_window:
            self.hann_window[hann_window_key] = torch.hann_window(
                self.win_size, device=y.device
            )

        # Reflect-pad by (win - hop) // 2 on each side so that, with
        # center=False, a waveform whose length is a multiple of hop_length
        # yields exactly len(y) // hop_length frames.
        y = torch.nn.functional.pad(
            y.unsqueeze(1),
            (
                int((self.win_size - self.hop_length) // 2),
                int((self.win_size - self.hop_length + 1) // 2),
            ),
            mode="reflect",
        )
        y = y.squeeze(1)

        spec = torch.stft(
            y,
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_size,
            window=self.hann_window[hann_window_key],
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        ).abs()
        return spec
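
# Shape sanity check for stftlog, as a sketch (defaults n_fft=2048,
# win_length=2048, hop_length=512; the input length 8192 is an arbitrary
# multiple of the hop):
#
#     y = torch.randn(1, 8192)            # (batch, samples)
#     spec = stftlog().exc(y)             # linear magnitude, no log applied
#     assert spec.shape == (1, 1025, 16)  # (batch, n_fft // 2 + 1, 8192 // 512)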
class univnet_task(GanBaseTask):
    def __init__(self, config):
        super().__init__(config)
        self.TF = PitchAdjustableMelSpectrogram(f_min=0, f_max=None, n_mels=256)
        self.logged_gt_wav = set()
        self.stft = stftlog()
        upmel = config['model_args'].get('upmel')
        self.upmel = upmel
        # if upmel is not None:
        #     self.noisec = config['model_args']['cond_in_channels'] * upmel
        # else:
        self.noisec = config['model_args']['cond_in_channels']

    def build_dataset(self):
        self.train_dataset = nsf_HiFigan_dataset(
            config=self.config,
            data_dir=pathlib.Path(self.config['DataIndexPath']) / self.config['train_set_name'])
        self.valid_dataset = nsf_HiFigan_dataset(
            config=self.config,
            data_dir=pathlib.Path(self.config['DataIndexPath']) / self.config['valid_set_name'],
            infer=True)

    def build_model(self):
        # cfg = self.config['model_args']
        # cfg.update({'sampling_rate': self.config['audio_sample_rate'],
        #             'num_mels': self.config['audio_num_mel_bins'],
        #             'hop_size': self.config['hop_size']})
        # h = AttrDict(cfg)
        self.generator = UnivNet(self.config,
                                 use_weight_norm=self.config['model_args'].get('use_weight_norm', True))
        self.discriminator = nn.ModuleDict({
            'mrd': MultiResSpecDiscriminator(
                fft_sizes=self.config['model_args'].get('mrd_fft_sizes', [1024, 2048, 512]),
                hop_sizes=self.config['model_args'].get('mrd_hop_sizes', [120, 240, 50]),
                win_lengths=self.config['model_args'].get('mrd_win_lengths', [600, 1200, 240]),
            ),
            'mpd': MultiPeriodDiscriminator(periods=self.config['model_args']['discriminator_periods']),
        })

    def build_losses_and_metrics(self):
        self.mix_loss = univloss(self.config)

    def Gforward(self, sample, infer=False):
        """Run the generator: draw a noise sequence and decode it
        conditioned on the mel spectrogram."""
        mel = sample['mel']
        if self.upmel is not None:
            x = torch.randn(mel.size()[0], self.noisec, mel.size()[-1] * self.upmel).to(mel)
        else:
            x = torch.randn(mel.size()[0], self.noisec, mel.size()[-1]).to(mel)
        wav = self.generator(x=x, c=mel)
        return {'audio': wav}

    def Dforward(self, Goutput):
        mrd_out, mrd_feature = self.discriminator['mrd'](Goutput)
        mpd_out, mpd_feature = self.discriminator['mpd'](Goutput)
        return (mrd_out, mrd_feature), (mpd_out, mpd_feature)

    # def _training_step(self, sample, batch_idx):
    #     """
    #     :return: total loss: torch.Tensor, loss_log: dict, other_log: dict
    #     """
    #     log_diet = {}
    #     opt_g, opt_d = self.optimizers()
    #
    #     # forward generator
    #     Goutput = self.Gforward(sample=sample)  # y_g_hat = Goutput
    #
    #     # forward discriminator
    #     Dfake = self.Dforward(Goutput=Goutput['audio'].detach())
    #     Dtrue = self.Dforward(Goutput=sample['audio'])  # y = sample['audio']
    #     Dloss, Dlog = self.mix_loss.Dloss(Dfake=Dfake, Dtrue=Dtrue)
    #     log_diet.update(Dlog)
    #
    #     # optimize discriminator
    #     opt_d.zero_grad()  # clean discriminator grad
    #     self.manual_backward(Dloss)
    #     opt_d.step()
    #
    #     # optimize generator
    #     GDfake = self.Dforward(Goutput=Goutput['audio'])
    #     GDtrue = self.Dforward(Goutput=sample['audio'])
    #     GDloss, GDlog = self.mix_loss.GDloss(GDfake=GDfake, GDtrue=GDtrue)
    #     log_diet.update(GDlog)
    #     Auxloss, Auxlog = self.mix_loss.Auxloss(Goutput=Goutput, sample=sample)
    #     log_diet.update(Auxlog)
    #     Gloss = GDloss + Auxloss
    #
    #     opt_g.zero_grad()  # clean generator grad
    #     self.manual_backward(Gloss)
    #     opt_g.step()
    #
    #     return log_diet
    def _validation_step(self, sample, batch_idx):
        wav = self.Gforward(sample)['audio']
        with torch.no_grad():
            # self.TF = self.TF.cpu()
            # mels = torch.log10(torch.clamp(self.TF(wav.squeeze(0).cpu().float()), min=1e-5))
            # GTmels = torch.log10(torch.clamp(self.TF(sample['audio'].squeeze(0).cpu().float()), min=1e-5))
            stfts = self.stft.exc(wav.squeeze(0).cpu().float())
            Gstfts = self.stft.exc(sample['audio'].squeeze(0).cpu().float())
            Gstfts_log10 = torch.log10(torch.clamp(Gstfts, min=1e-7))
            stfts_log10 = torch.log10(torch.clamp(stfts, min=1e-7))
            # Natural-log variants, only needed by the disabled plot below:
            # Gstfts_log = torch.log(torch.clamp(Gstfts, min=1e-7))
            # stfts_log = torch.log(torch.clamp(stfts, min=1e-7))

            # self.plot_mel(batch_idx, GTmels.transpose(1, 2), mels.transpose(1, 2), name=f'diffmel_{batch_idx}')
            self.plot_mel(batch_idx, Gstfts_log10.transpose(1, 2), stfts_log10.transpose(1, 2),
                          name=f'HIFImel_{batch_idx}/log10')
            # self.plot_mel(batch_idx, Gstfts_log.transpose(1, 2), stfts_log.transpose(1, 2),
            #               name=f'HIFImel_{batch_idx}/log')

            self.logger.experiment.add_audio(f'diff_{batch_idx}_', wav,
                                             sample_rate=self.config['audio_sample_rate'],
                                             global_step=self.global_step)
            # Log each ground-truth clip only once per run.
            if batch_idx not in self.logged_gt_wav:
                self.logger.experiment.add_audio(f'gt_{batch_idx}_', sample['audio'],
                                                 sample_rate=self.config['audio_sample_rate'],
                                                 global_step=self.global_step)
                self.logged_gt_wav.add(batch_idx)

        return {'l1loss': nn.L1Loss()(wav, sample['audio'])}, 1

    def plot_mel(self, batch_idx, spec, spec_out, name=None):
        name = f'mel_{batch_idx}' if name is None else name
        vmin = self.config['mel_vmin']
        vmax = self.config['mel_vmax']
        # Panels side by side: |GT - pred| error map, ground truth, prediction.
        spec_cat = torch.cat([(spec_out - spec).abs() + vmin, spec, spec_out], -1)
        self.logger.experiment.add_figure(name, spec_to_figure(spec_cat[0], vmin, vmax), self.global_step)
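

# The config keys this task reads, collected as an illustrative sketch; the
# values are assumptions (real values come from the project's YAML config,
# and GanBaseTask needs further optimizer/dataloader settings not shown):
#
#     {
#         'DataIndexPath': 'data/index', 'train_set_name': 'train.txt',
#         'valid_set_name': 'valid.txt', 'audio_sample_rate': 44100,
#         'hop_size': 512, 'crop_mel_frames': 64,
#         'volume_aug': True, 'volume_aug_prob': 0.5,
#         'mel_vmin': -6.0, 'mel_vmax': 1.5,
#         'model_args': {
#             'cond_in_channels': 100, 'upmel': None, 'use_weight_norm': True,
#             'discriminator_periods': [2, 3, 5, 7, 11],
#             # optional: 'mrd_fft_sizes', 'mrd_hop_sizes', 'mrd_win_lengths'
#         },
#     }


if __name__ == "__main__":
    # Smoke check of the volume-augmentation identity used in
    # nsf_HiFigan_dataset.collater: a magnitude spectrum is linear in
    # waveform amplitude, so scaling audio by exp(d) adds d to its
    # natural-log spectrum (the np.log(1e-5) clamp in collater implies the
    # cached mels are natural-log magnitude mels).
    _audio = np.random.uniform(-0.5, 0.5, 4096)
    _d = 0.7
    _log_spec = np.log(np.abs(np.fft.rfft(_audio)))
    _log_spec_scaled = np.log(np.abs(np.fft.rfft(_audio * np.exp(_d))))
    assert np.allclose(_log_spec_scaled, _log_spec + _d, atol=1e-6)
    print('volume-aug identity holds')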