Spaces:
Runtime error
Runtime error
File size: 7,669 Bytes
64db264 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 |
#coding:utf-8
import os
import paddle
from paddle import nn
from munch import Munch
from starganv2vc_paddle.transforms import build_transforms
import paddle.nn.functional as F
import numpy as np
def compute_d_loss(nets, args, x_real, y_org, y_trg, z_trg=None, x_ref=None, use_r1_reg=True, use_adv_cls=False, use_con_reg=False):
    """Compute the discriminator loss for one batch.

    Args:
        nets: container exposing ``discriminator``, ``mapping_network``,
            ``style_encoder``, ``f0_model`` and ``generator`` as attributes.
        args: mapping of hyper-parameters; wrapped in ``Munch`` and read for
            ``lambda_reg``, ``lambda_adv_cls`` and ``lambda_con_reg``.
        x_real: real input batch. Side effect: its ``stop_gradient`` flag is
            set to ``False``.
        y_org: labels passed to the discriminator together with ``x_real``.
        y_trg: target labels used to build and judge the fake batch.
        z_trg: latent code fed to ``mapping_network`` (give this OR ``x_ref``).
        x_ref: reference input fed to ``style_encoder`` (give this OR ``z_trg``).
        use_r1_reg: add the R1 gradient penalty on the real batch.
        use_adv_cls: add the adversarial classifier loss on the fake batch.
        use_con_reg: add consistency regularization using ``build_transforms``.

    Returns:
        ``(loss, stats)`` where ``loss`` is the weighted total tensor and
        ``stats`` is a ``Munch`` of Python scalars with keys ``real``,
        ``fake``, ``reg``, ``real_adv_cls`` and ``con_reg``.

    Raises:
        AssertionError: if both or neither of ``z_trg`` / ``x_ref`` are given.
    """
    hparams = Munch(args)
    # Exactly one style source (latent code or reference input) is allowed.
    assert (z_trg is None) ^ (x_ref is None)

    # ---- real batch ----
    # r1_reg differentiates the discriminator output w.r.t. x_real, so the
    # input itself must track gradients.
    x_real.stop_gradient = False
    real_logits = nets.discriminator(x_real, y_org)
    loss_real = adv_loss(real_logits, 1)

    # R1 regularization (https://arxiv.org/abs/1801.04406v4)
    loss_reg = (
        r1_reg(real_logits, x_real)
        if use_r1_reg
        else paddle.to_tensor([0.], dtype=paddle.float32)
    )

    # consistency regularization (bCR-GAN: https://arxiv.org/abs/2002.04724)
    loss_con_reg = paddle.to_tensor([0.], dtype=paddle.float32)
    if use_con_reg:
        augment = build_transforms()
        real_logits_aug = nets.discriminator(augment(x_real).detach(), y_org)
        loss_con_reg += F.smooth_l1_loss(real_logits, real_logits_aug)

    # ---- fake batch ----
    # The fake sample is produced without gradient tracking.
    with paddle.no_grad():
        s_trg = (
            nets.mapping_network(z_trg, y_trg)
            if z_trg is not None
            else nets.style_encoder(x_ref, y_trg)
        )
        gan_f0 = nets.f0_model.get_feature_GAN(x_real)
        x_fake = nets.generator(x_real, s_trg, masks=None, F0=gan_f0)

    fake_logits = nets.discriminator(x_fake, y_trg)
    loss_fake = adv_loss(fake_logits, 0)
    if use_con_reg:
        fake_logits_aug = nets.discriminator(augment(x_fake).detach(), y_trg)
        loss_con_reg += F.smooth_l1_loss(fake_logits, fake_logits_aug)

    # ---- adversarial classifier loss ----
    if use_adv_cls:
        cls_logits = nets.discriminator.classifier(x_fake)
        # Only samples whose target label differs from the original count.
        converted = y_org != y_trg
        loss_real_adv_cls = F.cross_entropy(cls_logits[converted], y_org[converted])
        if use_con_reg:
            cls_logits_aug = nets.discriminator.classifier(augment(x_fake).detach())
            loss_con_reg += F.smooth_l1_loss(cls_logits, cls_logits_aug)
    else:
        loss_real_adv_cls = paddle.zeros([1]).mean()

    loss = (
        loss_real
        + loss_fake
        + hparams.lambda_reg * loss_reg
        + hparams.lambda_adv_cls * loss_real_adv_cls
        + hparams.lambda_con_reg * loss_con_reg
    )

    stats = Munch(
        real=loss_real.item(),
        fake=loss_fake.item(),
        reg=loss_reg.item(),
        real_adv_cls=loss_real_adv_cls.item(),
        con_reg=loss_con_reg.item(),
    )
    return loss, stats
def compute_g_loss(nets, args, x_real, y_org, y_trg, z_trgs=None, x_refs=None, use_adv_cls=False):
    """Compute the generator loss for one batch.

    Args:
        nets: container exposing ``mapping_network``, ``style_encoder``,
            ``f0_model``, ``asr_model``, ``generator`` and ``discriminator``.
        args: mapping of hyper-parameters; wrapped in ``Munch`` and read for
            ``norm_bias`` and the ``lambda_adv``, ``lambda_sty``,
            ``lambda_ds``, ``lambda_cyc``, ``lambda_norm``, ``lambda_asr``,
            ``lambda_f0``, ``lambda_f0_sty`` and ``lambda_adv_cls`` weights.
        x_real: real input batch.
        y_org: labels belonging to ``x_real``.
        y_trg: target labels for the conversion.
        z_trgs: pair of latent codes (give this OR ``x_refs``).
        x_refs: pair of reference inputs (give this OR ``z_trgs``).
        use_adv_cls: add the adversarial classifier loss; also disables the
            style-F0 term.

    Returns:
        ``(loss, stats)`` where ``loss`` is the weighted total tensor and
        ``stats`` is a ``Munch`` of Python scalars with keys ``adv``, ``sty``,
        ``ds``, ``cyc``, ``norm``, ``asr``, ``f0`` and ``adv_cls``. The
        style-F0 term contributes to ``loss`` but is not reported in ``stats``.

    Raises:
        AssertionError: if both or neither of ``z_trgs`` / ``x_refs`` are given.
    """
    hparams = Munch(args)
    # Exactly one style source (latent pair or reference pair) is allowed.
    assert (z_trgs is None) ^ (x_refs is None)
    if z_trgs is not None:
        z_trg, z_trg2 = z_trgs
    if x_refs is not None:
        x_ref, x_ref2 = x_refs

    # ---- style vector for the first conversion ----
    s_trg = (
        nets.mapping_network(z_trg, y_trg)
        if z_trgs is not None
        else nets.style_encoder(x_ref, y_trg)
    )

    # ---- ASR/F0 features of the real batch (no gradient tracking) ----
    with paddle.no_grad():
        f0_real, gan_f0_real, cyc_f0_real = nets.f0_model(x_real)
        asr_real = nets.asr_model.get_feature(x_real)

    # ---- adversarial loss ----
    x_fake = nets.generator(x_real, s_trg, masks=None, F0=gan_f0_real)
    fake_logits = nets.discriminator(x_fake, y_trg)
    loss_adv = adv_loss(fake_logits, 1)

    # ---- ASR/F0 features of the fake batch ----
    f0_fake, gan_f0_fake, _ = nets.f0_model(x_fake)
    asr_fake = nets.asr_model.get_feature(x_fake)

    # ---- norm consistency loss ----
    # Only the part of the log-norm gap exceeding norm_bias is penalised.
    fake_norm = log_norm(x_fake)
    real_norm = log_norm(x_real)
    norm_gap = paddle.abs(fake_norm - real_norm) - hparams.norm_bias
    loss_norm = (F.relu(norm_gap) ** 2).mean()

    # ---- F0 loss ----
    loss_f0 = f0_loss(f0_fake, f0_real)

    # ---- style F0 loss (style initialization) ----
    if x_refs is not None and hparams.lambda_f0_sty > 0 and not use_adv_cls:
        f0_sty, _, _ = nets.f0_model(x_ref)
        loss_f0_sty = F.l1_loss(compute_mean_f0(f0_fake), compute_mean_f0(f0_sty))
    else:
        loss_f0_sty = paddle.zeros([1]).mean()

    # ---- ASR loss ----
    loss_asr = F.smooth_l1_loss(asr_fake, asr_real)

    # ---- style reconstruction loss ----
    s_pred = nets.style_encoder(x_fake, y_trg)
    loss_sty = paddle.abs(s_pred - s_trg).mean()

    # ---- diversity sensitive loss ----
    s_trg2 = (
        nets.mapping_network(z_trg2, y_trg)
        if z_trgs is not None
        else nets.style_encoder(x_ref2, y_trg)
    )
    # The second fake is detached, so it acts as a constant in this term.
    x_fake2 = nets.generator(x_real, s_trg2, masks=None, F0=gan_f0_real).detach()
    _, gan_f0_fake2, _ = nets.f0_model(x_fake2)
    loss_ds = paddle.abs(x_fake - x_fake2).mean()
    loss_ds += F.smooth_l1_loss(gan_f0_fake, gan_f0_fake2.detach())

    # ---- cycle-consistency loss ----
    s_org = nets.style_encoder(x_real, y_org)
    x_rec = nets.generator(x_fake, s_org, masks=None, F0=gan_f0_fake)
    loss_cyc = paddle.abs(x_rec - x_real).mean()
    # F0 term of the cycle-consistency loss
    if hparams.lambda_f0 > 0:
        _, _, cyc_f0_rec = nets.f0_model(x_rec)
        loss_cyc += F.smooth_l1_loss(cyc_f0_rec, cyc_f0_real)
    # ASR term of the cycle-consistency loss
    if hparams.lambda_asr > 0:
        asr_rec = nets.asr_model.get_feature(x_rec)
        loss_cyc += F.smooth_l1_loss(asr_rec, asr_real)

    # ---- adversarial classifier loss ----
    if use_adv_cls:
        cls_logits = nets.discriminator.classifier(x_fake)
        # Only samples whose target label differs from the original count.
        converted = y_org != y_trg
        loss_adv_cls = F.cross_entropy(cls_logits[converted], y_trg[converted])
    else:
        loss_adv_cls = paddle.zeros([1]).mean()

    # The diversity term enters with a negative sign (it is maximised).
    loss = (
        hparams.lambda_adv * loss_adv
        + hparams.lambda_sty * loss_sty
        - hparams.lambda_ds * loss_ds
        + hparams.lambda_cyc * loss_cyc
        + hparams.lambda_norm * loss_norm
        + hparams.lambda_asr * loss_asr
        + hparams.lambda_f0 * loss_f0
        + hparams.lambda_f0_sty * loss_f0_sty
        + hparams.lambda_adv_cls * loss_adv_cls
    )

    stats = Munch(
        adv=loss_adv.item(),
        sty=loss_sty.item(),
        ds=loss_ds.item(),
        cyc=loss_cyc.item(),
        norm=loss_norm.item(),
        asr=loss_asr.item(),
        f0=loss_f0.item(),
        adv_cls=loss_adv_cls.item(),
    )
    return loss, stats
# for norm consistency loss
def log_norm(x, mean=-4, std=4, axis=2):
    """Return the log of the norm of ``exp(x * std + mean)`` along ``axis``.

    Pipeline: normalized log mel -> mel -> norm -> log(norm).

    Args:
        x: input tensor (a normalized log-mel, per the pipeline above).
        mean: offset added after scaling, undoing the normalization.
        std: scale applied to ``x`` before the offset.
        axis: axis over which the norm is taken.

    Returns:
        Tensor holding ``log(norm(exp(x * std + mean), axis))``.
    """
    denormalized = x * std + mean
    mel = paddle.exp(denormalized)
    return paddle.log(mel.norm(axis=axis))
# for adversarial loss
def adv_loss(logits, target):
    """Binary cross-entropy (with logits) against a constant 0/1 target.

    Args:
        logits: discriminator output; flattened to 1-D when it has more than
            one dimension.
        target: ``1`` or ``0``, the label assigned to every logit.

    Returns:
        The loss tensor from ``F.binary_cross_entropy_with_logits``.

    Raises:
        AssertionError: if ``target`` is not 0 or 1.
    """
    assert target in (1, 0)
    flat = logits.reshape([-1]) if len(logits.shape) > 1 else logits
    # Labels are built from the unclipped logits; only shape/dtype matter here.
    labels = paddle.full_like(flat, fill_value=target)
    # Clamp to [-10, 10] to prevent nan.
    bounded = flat.clip(min=-10, max=10)
    return F.binary_cross_entropy_with_logits(bounded, labels)
# for R1 regularization loss
def r1_reg(d_out, x_in):
    """Zero-centered gradient penalty of ``d_out`` with respect to ``x_in``.

    Args:
        d_out: discriminator output computed from ``x_in``.
        x_in: input tensor that tracks gradients; its first dimension is
            treated as the batch dimension.

    Returns:
        ``0.5 * mean_over_batch(sum_over_rest(grad ** 2))``.

    Raises:
        AssertionError: if the gradient's shape differs from ``x_in``'s.
    """
    n_samples = x_in.shape[0]
    # create_graph=True keeps the penalty itself differentiable.
    grads = paddle.grad(
        outputs=d_out.sum(),
        inputs=x_in,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )
    squared = grads[0].pow(2)
    assert squared.shape == x_in.shape
    per_sample = squared.reshape((n_samples, -1)).sum(1)
    return 0.5 * per_sample.mean(0)
# for F0 consistency loss
def compute_mean_f0(f0):
    """Broadcast the mean of ``f0`` over its last axis back to ``f0``'s shape.

    The previous implementation expanded the reduced tensor to
    ``(f0.shape[-1], f0_mean.shape[0])`` and transposed it with a fixed
    ``(1, 0)`` permutation, which only works for 2-D input. Keeping the
    reduced axis and expanding to ``f0.shape`` yields the same values for
    2-D input and also supports 1-D input and inputs with more than two
    dimensions.

    Args:
        f0: tensor with at least one dimension; the mean is taken over the
            last axis.

    Returns:
        Tensor with the same shape as ``f0`` in which every element is the
        mean of the last-axis slice it belongs to.
    """
    # keepdim=True leaves a size-1 trailing axis that expand() can broadcast.
    f0_mean = f0.mean(axis=-1, keepdim=True)
    return f0_mean.expand(f0.shape)
def f0_loss(x_f0, y_f0):
    """L1 loss between mean-normalized F0 tensors.

    Each input is divided element-wise by its own last-axis mean (as returned
    by ``compute_mean_f0``) before the two are compared.

    Args:
        x_f0: predicted F0.
        y_f0: target F0.

    NOTE(review): an earlier docstring gave the shape of both arguments as
    (B, 1, M, L); the accepted shape is whatever ``compute_mean_f0`` supports
    -- confirm against the F0 model output.

    Returns:
        The loss tensor from ``F.l1_loss``.
    """
    pred_relative = x_f0 / compute_mean_f0(x_f0)
    target_relative = y_f0 / compute_mean_f0(y_f0)
    return F.l1_loss(pred_relative, target_relative)