import tqdm
import torch
from utils.plotting import get_files
from scipy.io.wavfile import write
import numpy as np

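# Full-scale magnitude of 16-bit PCM; used to convert float audio in [-1, 1] to int16.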
MAX_WAV_VALUE = 32768.0

def validate(hp, args, generator, discriminator, valloader, stft_loss, criterion, writer, step):
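    """Run one pass over the validation set: compute generator/discriminator
    losses, log them, and optionally synthesise audio for evaluation mels."""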
    generator.eval()
    discriminator.eval()
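    # Disable cudnn autotuning since validation clip lengths vary.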
    torch.backends.cudnn.benchmark = False

    loader = tqdm.tqdm(valloader, desc='Validation loop')
    loss_g_sum = 0.0
    loss_d_sum = 0.0
    for mel, audio in loader:
        mel = mel.cuda()
        audio = audio.cuda()    # B, 1, T torch.Size([1, 1, 212893])

        # Generate a waveform from the mel spectrogram.
        fake_audio = generator(mel) # B, 1, T' torch.Size([1, 1, 212992])

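        # Score the generated audio (trimmed to the reference length) and the real audio.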
        disc_fake = discriminator(fake_audio[:, :, :audio.size(2)]) # B, 1, T torch.Size([1, 1, 212893])
        disc_real = discriminator(audio)

        adv_loss = 0.0
        loss_d_real = 0.0
        loss_d_fake = 0.0
        sc_loss, mag_loss = stft_loss(fake_audio[:, :, :audio.size(2)].squeeze(1), audio.squeeze(1))
        loss_g = sc_loss + mag_loss


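        # The multi-scale discriminator returns a (feature maps, score) pair per scale.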
        for (feats_fake, score_fake), (feats_real, score_real) in zip(disc_fake, disc_real):
            adv_loss += criterion(score_fake, torch.ones_like(score_fake))

            # Optional feature-matching loss over discriminator activations.
            if hp.model.feat_loss:
                for feat_f, feat_r in zip(feats_fake, feats_real):
                    adv_loss += hp.model.feat_match * torch.mean(torch.abs(feat_f - feat_r))
            loss_d_real += criterion(score_real, torch.ones_like(score_real))
            loss_d_fake += criterion(score_fake, torch.zeros_like(score_fake))
        # Average across discriminator scales.
        adv_loss = adv_loss / len(disc_fake)
        loss_d_real = loss_d_real / len(disc_real)
        loss_d_fake = loss_d_fake / len(disc_fake)
        loss_g += hp.model.lambda_adv * adv_loss
        loss_d = loss_d_real + loss_d_fake
        loss_g_sum += loss_g.item()
        loss_d_sum += loss_d.item()

        loader.set_description("g %.04f d %.04f ad %.04f | step %d" % (loss_g, loss_d, adv_loss, step))

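    # Per-sample loss averages (assumes a validation batch size of 1).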
    loss_g_avg = loss_g_sum / len(valloader.dataset)
    loss_d_avg = loss_d_sum / len(valloader.dataset)

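    # Log waveforms from the last validation batch for listening tests.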
    audio = audio[0][0].cpu().detach().numpy()
    fake_audio = fake_audio[0][0].cpu().detach().numpy()

    writer.log_validation(loss_g_avg, loss_d_avg, adv_loss, generator, discriminator, audio, fake_audio, step)
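
    # Optionally synthesise audio from evaluation mel spectrograms saved as .npy files.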
    if hp.data.eval_path is not None:
        mel_filename = get_files(hp.data.eval_path, extension='.npy')
        for j in range(len(mel_filename)):
            with torch.no_grad():
                mel = torch.from_numpy(np.load(mel_filename[j]))
                out_path = mel_filename[j].replace('.npy', f'{step}.wav')
                mel_name = mel_filename[j].split("/")[-1].split(".")[0]
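                # Add a batch dimension when the mel is 2-D (n_mels, T).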
                if len(mel.shape) == 2:
                    mel = mel.unsqueeze(0)
                mel = mel.cuda()
                gen_audio = generator.inference(mel)
                gen_audio = gen_audio.squeeze()
                # Trim the trailing samples, assumed to come from padding added in generator.inference.
                gen_audio = gen_audio[:-(hp.audio.hop_length * 10)]
                writer.log_evaluation(gen_audio.cpu().detach().numpy(), step, mel_name)
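                # Convert the float waveform to 16-bit PCM for the WAV file.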
                gen_audio = MAX_WAV_VALUE * gen_audio
                gen_audio = gen_audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE-1)
                gen_audio = gen_audio.short()
                gen_audio = gen_audio.cpu().detach().numpy()

                write(out_path, hp.audio.sampling_rate, gen_audio)

    # Restore training-time settings.
    torch.backends.cudnn.benchmark = True
    generator.train()
    discriminator.train()