# Page header captured along with this file (not part of the module):
# yuancwang / init / b725c5a / raw / history blame / No virus / 2.43 kB
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from models.tts.naturalspeech2.wavenet import WaveNet
class DiffusionFlow(nn.Module):
    """Diffusion module that moves along a straight line between data and noise.

    A noisy sample is built as ``xt = (1 - t) * x0 + t * z`` where ``x0`` is the
    data, ``z`` is standard Gaussian noise and ``t`` is the diffusion step.
    The wrapped estimator's output is treated as a prediction of ``z - x0``.
    """

    def __init__(self, cfg):
        """Build the estimator from ``cfg.wavenet`` and store the config scalars.

        Args:
            cfg: config object exposing ``wavenet``, ``beta_min``, ``beta_max``,
                ``sigma`` and ``noise_factor`` attributes.
        """
        super().__init__()
        self.diff_estimator = WaveNet(cfg.wavenet)
        # Kept as attributes for callers/checkpoints; the methods below do not
        # read them.
        self.beta_min = cfg.beta_min
        self.beta_max = cfg.beta_max
        self.sigma = cfg.sigma
        self.noise_factor = cfg.noise_factor

    def forward(self, x, x_mask, cond, spk_query_emb, offset=1e-5):
        """Run one training pass at a random diffusion step per batch item.

        Args:
            x: (B, 128, T) clean input.
            x_mask: (B, T), mask is 0.
            cond: (B, T, 512)
            spk_query_emb: (B, 32, 512)
            offset: margin that keeps the sampled step inside
                ``[offset, 1 - offset]``.

        Returns:
            dict with keys ``"x0_pred"``, ``"noise_pred"``, ``"noise"`` and
            ``"flow_pred"``.
        """
        # One uniformly sampled step per batch item, clamped away from 0 and 1.
        diffusion_step = torch.rand(
            x.shape[0], dtype=x.dtype, device=x.device, requires_grad=False
        )
        diffusion_step = torch.clamp(diffusion_step, offset, 1.0 - offset)

        xt, noise = self.forward_diffusion(x0=x, diffusion_step=diffusion_step)

        # The estimator output stands for (noise - x0), so either end point can
        # be recovered from it when the other end point is known.
        flow_pred = self.diff_estimator(
            xt, x_mask, cond, diffusion_step, spk_query_emb
        )

        return {
            "x0_pred": noise - flow_pred,
            "noise_pred": x + flow_pred,
            "noise": noise,
            "flow_pred": flow_pred,
        }

    @torch.no_grad()
    def forward_diffusion(self, x0, diffusion_step):
        """Sample noise and interpolate between ``x0`` and that noise.

        Args:
            x0: (B, 128, T)
            diffusion_step: (B,)

        Returns:
            ``(xt, z)`` where ``z`` is Gaussian noise shaped like ``x0`` and
            ``xt = (1 - t) * x0 + t * z``.
        """
        # Two trailing singleton axes so the per-item step broadcasts over x0.
        time_step = diffusion_step.unsqueeze(-1).unsqueeze(-1)
        z = torch.randn(
            x0.shape, dtype=x0.dtype, device=x0.device, requires_grad=False
        )
        xt = (1 - time_step) * x0 + time_step * z
        return xt, z

    @torch.no_grad()
    def cal_dxt(self, xt, x_mask, cond, spk_query_emb, diffusion_step, h):
        """Return the update ``h * flow_pred`` for one step of size ``h``."""
        flow_pred = self.diff_estimator(
            xt, x_mask, cond, diffusion_step, spk_query_emb
        )  # z - x0 = x1 - x0
        return h * flow_pred

    @torch.no_grad()
    def reverse_diffusion(self, z, x_mask, cond, n_timesteps, spk_query_emb):
        """Integrate from noise ``z`` back towards data in equal steps.

        Args:
            z: starting noise; its first dimension is the batch size.
            x_mask: forwarded unchanged to the estimator.
            cond: forwarded unchanged to the estimator.
            n_timesteps: number of integration steps, must be >= 1.
            spk_query_emb: forwarded unchanged to the estimator.

        Returns:
            The sample after ``n_timesteps`` updates.

        Raises:
            ValueError: if ``n_timesteps`` is smaller than 1.
        """
        if n_timesteps < 1:
            # Previously 0 failed with a bare ZeroDivisionError and negative
            # values silently returned the input noise.
            raise ValueError(
                f"n_timesteps must be a positive integer, got {n_timesteps}"
            )

        h = 1.0 / n_timesteps
        xt = z
        for i in range(n_timesteps):
            # The estimator is queried at the midpoint of each interval,
            # walking from t close to 1 down to t close to 0.
            t = torch.full(
                (z.shape[0],),
                1.0 - (i + 0.5) * h,
                dtype=z.dtype,
                device=z.device,
            )
            xt = xt - self.cal_dxt(
                xt, x_mask, cond, spk_query_emb, diffusion_step=t, h=h
            )
        return xt