import tqdm
import torch
from utils.plotting import get_files
from scipy.io.wavfile import write
import numpy as np

# full-scale magnitude of 16-bit PCM audio, used to convert float waveforms to int16
MAX_WAV_VALUE = 32768.0


def validate(hp, args, generator, discriminator, valloader, stft_loss, criterion, writer, step):
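    """Run one pass over the validation set and log the results.

    Computes the multi-resolution STFT loss plus the adversarial and
    discriminator losses on each validation batch, logs the averages to
    `writer`, and, when `hp.data.eval_path` is set, synthesizes a waveform
    for every mel-spectrogram `.npy` file found there and writes it to disk.
    """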
    generator.eval()
    discriminator.eval()
    torch.backends.cudnn.benchmark = False

    loader = tqdm.tqdm(valloader, desc='Validation loop')
    loss_g_sum = 0.0
    loss_d_sum = 0.0
    # no gradients are needed during validation
    with torch.no_grad():
        for mel, audio in loader:
            mel = mel.cuda()
            audio = audio.cuda()  # (B, 1, T), e.g. torch.Size([1, 1, 212893])

            # run the generator; its output can be slightly longer than the
            # reference audio because of padding, so trim it before scoring
            fake_audio = generator(mel)  # (B, 1, T'), e.g. torch.Size([1, 1, 212992])
            disc_fake = discriminator(fake_audio[:, :, :audio.size(2)])
            disc_real = discriminator(audio)
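            # `disc_fake` / `disc_real` are lists with one entry per
            # (sub-)discriminator scale; each entry is a (feature maps, score)
            # tuple, which is what the loop below unpacks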
            adv_loss = 0.0
            loss_d_real = 0.0
            loss_d_fake = 0.0
            sc_loss, mag_loss = stft_loss(fake_audio[:, :, :audio.size(2)].squeeze(1), audio.squeeze(1))
            loss_g = sc_loss + mag_loss
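            # sc_loss / mag_loss are the spectral-convergence and
            # log-STFT-magnitude terms of the multi-resolution STFT loss;
            # the adversarial term is added to loss_g below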
            for (feats_fake, score_fake), (feats_real, score_real) in zip(disc_fake, disc_real):
                # generator adversarial loss: push the fake scores towards 1
                adv_loss += criterion(score_fake, torch.ones_like(score_fake))
                if hp.model.feat_loss:
                    # feature-matching loss between real and fake feature maps
                    for feat_f, feat_r in zip(feats_fake, feats_real):
                        adv_loss += hp.model.feat_match * torch.mean(torch.abs(feat_f - feat_r))
                loss_d_real += criterion(score_real, torch.ones_like(score_real))
                loss_d_fake += criterion(score_fake, torch.zeros_like(score_fake))

            # average the losses over the discriminator scales
            adv_loss = adv_loss / len(disc_fake)
            loss_d_real = loss_d_real / len(disc_real)
            loss_d_fake = loss_d_fake / len(disc_fake)
            loss_g += hp.model.lambda_adv * adv_loss
            loss_d = loss_d_real + loss_d_fake

            loss_g_sum += loss_g.item()
            loss_d_sum += loss_d.item()
            loader.set_description("g %.04f d %.04f ad %.04f | step %d" % (loss_g, loss_d, adv_loss, step))
    loss_g_avg = loss_g_sum / len(valloader.dataset)
    loss_d_avg = loss_d_sum / len(valloader.dataset)

    # log the averaged losses plus real/generated audio from the last batch
    audio = audio[0][0].cpu().detach().numpy()
    fake_audio = fake_audio[0][0].cpu().detach().numpy()
    writer.log_validation(loss_g_avg, loss_d_avg, adv_loss, generator, discriminator, audio, fake_audio, step)
    # synthesize audio for stand-alone evaluation mels, if provided
    if hp.data.eval_path is not None:
        mel_filenames = get_files(hp.data.eval_path, extension='.npy')
        for mel_path in mel_filenames:
            with torch.no_grad():
                mel = torch.from_numpy(np.load(mel_path))
                out_path = mel_path.replace('.npy', f'{step}.wav')  # e.g. 'sample.npy' -> 'sample100.wav' at step 100
                mel_name = mel_path.split("/")[-1].split(".")[0]
                if len(mel.shape) == 2:
                    mel = mel.unsqueeze(0)  # add a batch dimension: (n_mels, T) -> (1, n_mels, T)
                mel = mel.cuda()

                gen_audio = generator.inference(mel)
                gen_audio = gen_audio.squeeze()
                # discard the trailing samples produced by the inference padding
                gen_audio = gen_audio[:-(hp.audio.hop_length * 10)]
                writer.log_evaluation(gen_audio.cpu().detach().numpy(), step, mel_name)

                # convert to 16-bit PCM and write the WAV file
                gen_audio = MAX_WAV_VALUE * gen_audio
                gen_audio = gen_audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE - 1)
                gen_audio = gen_audio.short()
                gen_audio = gen_audio.cpu().detach().numpy()
                write(out_path, hp.audio.sampling_rate, gen_audio)
    # add evaluation code here
    # restore training settings
    torch.backends.cudnn.benchmark = True
    generator.train()
    discriminator.train()
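
# A minimal usage sketch (hypothetical names; an actual training loop would
# build these objects from the config before calling validate):
#
#     if step % hp.log.validation_interval == 0:
#         validate(hp, args, generator, discriminator, valloader,
#                  stft_loss, criterion, writer, step)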