import os

import numpy as np
import torch
import tqdm
from scipy.io.wavfile import write

from utils.plotting import get_files
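
# Peak amplitude of 16-bit signed PCM audio; used to scale the float
# waveform into integer range before writing WAV files.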
MAX_WAV_VALUE = 32768.0


def validate(hp, args, generator, discriminator, valloader, stft_loss, criterion, writer, step):
    """Run one validation pass: compute losses, log samples, and synthesize eval audio."""
    generator.eval()
    discriminator.eval()
    torch.backends.cudnn.benchmark = False

    loader = tqdm.tqdm(valloader, desc='Validation loop')
    loss_g_sum = 0.0
    loss_d_sum = 0.0

    with torch.no_grad():
        for mel, audio in loader:
            mel = mel.cuda()
            audio = audio.cuda()  # B, 1, T e.g. torch.Size([1, 1, 212893])

            # The generator output may be slightly longer than the reference
            # audio, so it is trimmed to the reference length before scoring.
            fake_audio = generator(mel)  # B, 1, T' e.g. torch.Size([1, 1, 212992])
            disc_fake = discriminator(fake_audio[:, :, :audio.size(2)])
            disc_real = discriminator(audio)

            adv_loss = 0.0
            loss_d_real = 0.0
            loss_d_fake = 0.0

            # Multi-resolution STFT loss between generated and real waveforms.
            sc_loss, mag_loss = stft_loss(fake_audio[:, :, :audio.size(2)].squeeze(1), audio.squeeze(1))
            loss_g = sc_loss + mag_loss

            # Accumulate adversarial and discriminator losses over every
            # (feature maps, score) pair from the multi-scale discriminator,
            # optionally adding a feature-matching term.
            for (feats_fake, score_fake), (feats_real, score_real) in zip(disc_fake, disc_real):
                adv_loss += criterion(score_fake, torch.ones_like(score_fake))
                if hp.model.feat_loss:
                    for feat_f, feat_r in zip(feats_fake, feats_real):
                        adv_loss += hp.model.feat_match * torch.mean(torch.abs(feat_f - feat_r))
                loss_d_real += criterion(score_real, torch.ones_like(score_real))
                loss_d_fake += criterion(score_fake, torch.zeros_like(score_fake))
            # Average each loss over the number of discriminator scales.
            adv_loss = adv_loss / len(disc_fake)
            loss_d_real = loss_d_real / len(disc_real)
            loss_d_fake = loss_d_fake / len(disc_fake)

            loss_g += hp.model.lambda_adv * adv_loss
            loss_d = loss_d_real + loss_d_fake
            loss_g_sum += loss_g.item()
            loss_d_sum += loss_d.item()

            loader.set_description("g %.04f d %.04f ad %.04f | step %d" % (loss_g, loss_d, adv_loss, step))
    loss_g_avg = loss_g_sum / len(valloader.dataset)
    loss_d_avg = loss_d_sum / len(valloader.dataset)

    # Log averaged losses plus the last real/generated pair for inspection.
    audio = audio[0][0].cpu().detach().numpy()
    fake_audio = fake_audio[0][0].cpu().detach().numpy()
    writer.log_validation(loss_g_avg, loss_d_avg, adv_loss, generator, discriminator, audio, fake_audio, step)
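
    # Optionally synthesize audio for a fixed set of evaluation mel
    # spectrograms so progress can be compared across checkpoints.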
    if hp.data.eval_path is not None:
        mel_filenames = get_files(hp.data.eval_path, extension='.npy')
        for mel_path in mel_filenames:
            with torch.no_grad():
                mel = torch.from_numpy(np.load(mel_path))
                out_path = mel_path.replace('.npy', f'{step}.wav')
                mel_name = os.path.splitext(os.path.basename(mel_path))[0]
                if len(mel.shape) == 2:
                    mel = mel.unsqueeze(0)  # add a batch dimension
                mel = mel.cuda()

                gen_audio = generator.inference(mel)
                gen_audio = gen_audio.squeeze()
                # Trim the tail introduced by the zero-padding used during inference.
                gen_audio = gen_audio[:-(hp.audio.hop_length * 10)]
                writer.log_evaluation(gen_audio.cpu().detach().numpy(), step, mel_name)

                # Scale to the 16-bit PCM range and write to disk.
                gen_audio = MAX_WAV_VALUE * gen_audio
                gen_audio = gen_audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE - 1)
                gen_audio = gen_audio.short().cpu().detach().numpy()
                write(out_path, hp.audio.sampling_rate, gen_audio)
    # TODO: add further evaluation code here

    # Restore training-time settings before returning to the training loop.
    torch.backends.cudnn.benchmark = True
    generator.train()
    discriminator.train()
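
# A minimal usage sketch (an assumption, not part of the original file): a
# training loop might call validate() periodically, e.g.
#
#     if step % hp.log.validation_interval == 0:
#         validate(hp, args, generator, discriminator, valloader,
#                  stft_loss, criterion, writer, step)
#
# where hp.log.validation_interval is a hypothetical config field.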