"""HiFi-GAN style vocoder driven by a harmonic sine source (NSF).

Contains the generator (with its sine/harmonic source modules), the
multi-period and multi-scale discriminators, and the GAN loss helpers.
"""
import json
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm

from .env import AttrDict
from .utils import get_padding, init_weights

# Negative slope shared by (almost) every leaky ReLU in this file.
LRELU_SLOPE = 0.1


def load_model(model_path, device='cuda'):
    """Build a ``Generator`` from a checkpoint and put it in inference mode.

    The configuration is read from ``config.json`` located next to
    ``model_path`` (see ``load_config``).

    Args:
        model_path: path to a checkpoint holding a ``'generator'`` state dict.
        device: device the generator and the checkpoint tensors are mapped to.

    Returns:
        ``(generator, h)``: the generator in eval mode with weight norm
        removed, and the configuration object it was built from.
    """
    h = load_config(model_path)
    generator = Generator(h).to(device)

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from trusted sources.
    cp_dict = torch.load(model_path, map_location=device)
    generator.load_state_dict(cp_dict['generator'])
    generator.eval()
    generator.remove_weight_norm()
    return generator, h


def load_config(model_path):
    """Load ``config.json`` from the directory that contains ``model_path``.

    Args:
        model_path: path to a checkpoint; only its directory part is used.

    Returns:
        The parsed JSON configuration wrapped in an ``AttrDict``.
    """
    config_file = os.path.join(os.path.dirname(model_path), 'config.json')
    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    with open(config_file, encoding='utf-8') as f:
        return AttrDict(json.load(f))


def _wn_conv1d(channels, kernel_size, dilation):
    """Return a weight-normalised, stride-1 Conv1d keeping the channel count."""
    return weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation,
                              padding=get_padding(kernel_size, dilation)))


class ResBlock1(torch.nn.Module):
    """Residual block made of three (dilated conv, plain conv) pairs.

    Args:
        h: configuration object; stored on the module, not read here.
        channels: number of input and output channels of every convolution.
        kernel_size: kernel size of every convolution.
        dilation: dilations of the three convolutions in ``convs1``; exactly
            the first three entries are used.
    """

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()
        self.h = h
        self.convs1 = nn.ModuleList([
            _wn_conv1d(channels, kernel_size, dilation[0]),
            _wn_conv1d(channels, kernel_size, dilation[1]),
            _wn_conv1d(channels, kernel_size, dilation[2]),
        ])
        self.convs1.apply(init_weights)

        # The second convolution of each pair is never dilated.
        self.convs2 = nn.ModuleList([
            _wn_conv1d(channels, kernel_size, 1),
            _wn_conv1d(channels, kernel_size, 1),
            _wn_conv1d(channels, kernel_size, 1),
        ])
        self.convs2.apply(init_weights)

    def forward(self, x):
        """Apply the three residual pairs in sequence and return the result."""
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        """Strip the weight-norm reparametrisation from every convolution."""
        for layer in self.convs1:
            remove_weight_norm(layer)
        for layer in self.convs2:
            remove_weight_norm(layer)


class ResBlock2(torch.nn.Module):
    """Residual block made of two dilated convolutions.

    Args:
        h: configuration object; stored on the module, not read here.
        channels: number of input and output channels of every convolution.
        kernel_size: kernel size of every convolution.
        dilation: dilations of the two convolutions; exactly the first two
            entries are used.
    """

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super().__init__()
        self.h = h
        self.convs = nn.ModuleList([
            _wn_conv1d(channels, kernel_size, dilation[0]),
            _wn_conv1d(channels, kernel_size, dilation[1]),
        ])
        self.convs.apply(init_weights)

    def forward(self, x):
        """Apply both residual convolutions in sequence and return the result."""
        for conv in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = conv(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        """Strip the weight-norm reparametrisation from every convolution."""
        for layer in self.convs:
            remove_weight_norm(layer)


class SineGen(torch.nn.Module):
    """Sine generator producing a fundamental plus harmonic overtones from F0.

    SineGen(samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003,
            voiced_threshold=0)

    Args:
        samp_rate: sampling rate in Hz.
        harmonic_num: number of harmonic overtones (default 0).
        sine_amp: amplitude of the sine waveform (default 0.1).
        noise_std: std of the Gaussian noise added in voiced regions
            (default 0.003).
        voiced_threshold: F0 threshold for the voiced/unvoiced decision
            (default 0).
    """

    def __init__(self, samp_rate, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0):
        super().__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        # Fundamental + overtones.
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        """Return a mask that is 1 where ``f0`` exceeds the voiced threshold, else 0."""
        uv = torch.ones_like(f0)
        uv = uv * (f0 > self.voiced_threshold)
        return uv

    @torch.no_grad()
    def forward(self, f0, upp):
        """Generate sample-rate sine waves from frame-rate F0.

        Args:
            f0: tensor of shape (batch, length); a trailing axis is added
                internally. Values at or below ``voiced_threshold`` are
                treated as unvoiced.
            upp: integer upsampling factor from frame rate to sample rate.

        Returns:
            ``(sine_waves, uv, noise)``:
            sine_waves: (batch, length * upp, dim), voiced sines plus noise.
            uv: (batch, length * upp, 1), voiced/unvoiced mask.
            noise: the noise component, same shape as ``sine_waves``.
        """
        f0 = f0.unsqueeze(-1)
        # Frequency of every harmonic: f0 * [1, 2, ..., dim].
        fn = torch.multiply(f0, torch.arange(1, self.dim + 1, device=f0.device).reshape((1, 1, -1)))
        # Per-step phase increment in cycles. The "% 1" here means the
        # harmonic multiples cannot be optimised away in post-processing.
        rad_values = (fn / self.sampling_rate) % 1
        # Random initial phase per harmonic; the fundamental keeps phase 0.
        rand_ini = torch.rand(fn.shape[0], fn.shape[2], device=fn.device)
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
        is_half = rad_values.dtype is not torch.float32
        # Accumulate in float64. No "% 1" at this point: taking the modulo
        # here would prevent the later cumsum from being optimised.
        tmp_over_one = torch.cumsum(rad_values.double(), 1)
        if is_half:
            tmp_over_one = tmp_over_one.half()
        else:
            tmp_over_one = tmp_over_one.float()
        tmp_over_one *= upp
        # Upsample the accumulated phase linearly and the increments by
        # nearest neighbour, both by a factor of ``upp`` along the time axis.
        tmp_over_one = F.interpolate(
            tmp_over_one.transpose(2, 1), scale_factor=upp,
            mode='linear', align_corners=True
        ).transpose(2, 1)
        rad_values = F.interpolate(rad_values.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1)
        tmp_over_one %= 1
        # Wherever the wrapped phase decreases it crossed an integer
        # boundary; subtract one full cycle there.
        tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
        cumsum_shift = torch.zeros_like(rad_values)
        cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
        rad_values = rad_values.double()
        cumsum_shift = cumsum_shift.double()
        sine_waves = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi)
        if is_half:
            sine_waves = sine_waves.half()
        else:
            sine_waves = sine_waves.float()
        sine_waves = sine_waves * self.sine_amp
        uv = self._f02uv(f0)
        uv = F.interpolate(uv.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1)
        # Voiced steps get noise of std ``noise_std``; unvoiced steps get
        # noise of std ``sine_amp / 3``.
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)
        # Unvoiced steps carry noise only.
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise


class SourceModuleHnNSF(torch.nn.Module):
    """Source module for hn-NSF: merges harmonic sines into one excitation.

    SourceModuleHnNSF(sampling_rate, harmonic_num=0, sine_amp=0.1,
                      add_noise_std=0.003, voiced_threshod=0)

    Args:
        sampling_rate: sampling rate in Hz.
        harmonic_num: number of harmonics above F0 (default 0).
        sine_amp: amplitude of the sine source signal (default 0.1).
        add_noise_std: std of the additive Gaussian noise (default 0.003).
            Note that the noise amplitude in unvoiced regions is decided by
            ``sine_amp``.
        voiced_threshod: threshold used for the voiced/unvoiced decision on
            F0 (default 0). The spelling is part of the public interface.
    """

    def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super().__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # Produces the sine waveforms.
        self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshod)

        # Merges the source harmonics into a single excitation.
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x, upp):
        """Return the merged sine excitation for F0 input ``x``.

        Args:
            x: F0 tensor, passed unchanged to ``SineGen.forward``.
            upp: integer upsampling factor, passed to ``SineGen.forward``.

        Returns:
            Tensor whose last axis has size 1: tanh of a linear combination
            of the harmonic sine waves.
        """
        sine_wavs, _, _ = self.l_sin_gen(x, upp)
        return self.l_tanh(self.l_linear(sine_wavs))


class Generator(torch.nn.Module):
    """Upsampling generator conditioned on a harmonic source signal.

    Args:
        h: configuration object. Attributes read here: ``sampling_rate``,
            ``num_mels``, ``upsample_initial_channel``, ``upsample_rates``,
            ``upsample_kernel_sizes``, ``resblock``,
            ``resblock_kernel_sizes`` and ``resblock_dilation_sizes``.
    """

    def __init__(self, h):
        super().__init__()
        self.h = h

        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.m_source = SourceModuleHnNSF(
            sampling_rate=h.sampling_rate,
            harmonic_num=8
        )
        self.noise_convs = nn.ModuleList()
        self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
        resblock = ResBlock1 if h.resblock == '1' else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            # The channel count halves at every upsampling stage.
            c_prev = h.upsample_initial_channel // (2 ** i)
            c_cur = h.upsample_initial_channel // (2 ** (i + 1))
            self.ups.append(weight_norm(
                ConvTranspose1d(c_prev, c_cur, k, u, padding=(k - u) // 2)))
            if i + 1 < len(h.upsample_rates):
                # Downsample the sample-rate source by the product of the
                # remaining upsample rates so it lines up with this stage.
                stride_f0 = int(np.prod(h.upsample_rates[i + 1:]))
                self.noise_convs.append(Conv1d(
                    1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
            else:
                # Last stage already runs at the source's rate.
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))

        self.resblocks = nn.ModuleList()
        ch = h.upsample_initial_channel
        for _ in range(len(self.ups)):
            ch //= 2
            for k, d in zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes):
                self.resblocks.append(resblock(h, ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        # Total upsampling factor of the generator.
        self.upp = int(np.prod(h.upsample_rates))

    def forward(self, x, f0):
        """Synthesise a waveform from features ``x`` and pitch ``f0``.

        Args:
            x: input features with ``h.num_mels`` channels on axis 1.
            f0: F0 contour, forwarded to the source module. Presumably one
                value per frame of ``x`` -- TODO confirm against callers.

        Returns:
            Single-channel tensor with values in [-1, 1] (tanh output).
        """
        har_source = self.m_source(f0, self.upp).transpose(1, 2)
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            # Inject the harmonic source at this stage's resolution.
            x_source = self.noise_convs[i](har_source)
            x = x + x_source
            # Average the outputs of this stage's parallel residual blocks.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        # NOTE(review): uses leaky_relu's default slope, not LRELU_SLOPE.
        # Left unchanged because altering it would change model outputs.
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        """Strip weight norm from the upsamplers, resblocks and pre/post convs."""
        print('Removing weight norm...')
        for layer in self.ups:
            remove_weight_norm(layer)
        for block in self.resblocks:
            block.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)


class DiscriminatorP(torch.nn.Module):
    """Period discriminator: folds the signal by ``period`` and applies 2-D convs.

    Args:
        period: length of the folded axis.
        kernel_size: kernel height of the stacked convolutions.
        stride: stride (along the first spatial axis) of the first four convs.
        use_spectral_norm: use spectral norm instead of weight norm if truthy.
    """

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super().__init__()
        self.period = period
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        # NOTE(review): the padding is computed for a kernel size of 5
        # regardless of ``kernel_size``; kept as-is to preserve behaviour.
        pad = (get_padding(5, 1), 0)
        self.convs = nn.ModuleList([
            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=pad)),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
        ])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Score ``x`` of shape (batch, channels, time).

        Returns:
            ``(score, fmap)``: the flattened output of the last convolution
            and the list of every layer's activations (last one included).
        """
        fmap = []

        # 1-D to 2-D: right-pad the time axis to a multiple of the period,
        # then fold it into (time // period, period).
        b, c, t = x.shape
        if t % self.period != 0:
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for conv in self.convs:
            x = conv(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    """Bank of ``DiscriminatorP`` instances, one per period.

    Args:
        periods: iterable of periods; defaults to ``[2, 3, 5, 7, 11]``.
    """

    def __init__(self, periods=None):
        super().__init__()
        self.periods = periods if periods is not None else [2, 3, 5, 7, 11]
        self.discriminators = nn.ModuleList()
        for period in self.periods:
            self.discriminators.append(DiscriminatorP(period))

    def forward(self, y, y_hat):
        """Run every discriminator on the real ``y`` and generated ``y_hat``.

        Returns:
            ``(y_d_rs, y_d_gs, fmap_rs, fmap_gs)``: per-discriminator scores
            and feature maps for the real and the generated input.
        """
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for d in self.discriminators:
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorS(torch.nn.Module):
    """Scale discriminator: a stack of (grouped) 1-D convolutions.

    Args:
        use_spectral_norm: use spectral norm instead of weight norm if truthy.
    """

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList([
            norm_f(Conv1d(1, 128, 15, 1, padding=7)),
            norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
            norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
            norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """Score ``x``.

        Returns:
            ``(score, fmap)``: the flattened output of the last convolution
            and the list of every layer's activations (last one included).
        """
        fmap = []
        for conv in self.convs:
            x = conv(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiScaleDiscriminator(torch.nn.Module):
    """Three ``DiscriminatorS`` applied to progressively average-pooled audio.

    The first discriminator uses spectral norm, the other two weight norm.
    """

    def __init__(self):
        super().__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=True),
            DiscriminatorS(),
            DiscriminatorS(),
        ])
        self.meanpools = nn.ModuleList([
            AvgPool1d(4, 2, padding=2),
            AvgPool1d(4, 2, padding=2)
        ])

    def forward(self, y, y_hat):
        """Run every discriminator on the real ``y`` and generated ``y_hat``.

        Both inputs are average-pooled once more before each discriminator
        after the first.

        Returns:
            ``(y_d_rs, y_d_gs, fmap_rs, fmap_gs)``: per-discriminator scores
            and feature maps for the real and the generated input.
        """
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            if i != 0:
                y = self.meanpools[i - 1](y)
                y_hat = self.meanpools[i - 1](y_hat)
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


def feature_loss(fmap_r, fmap_g):
    """Return twice the summed mean absolute difference between feature maps.

    Args:
        fmap_r: per-discriminator lists of feature maps for real input.
        fmap_g: per-discriminator lists of feature maps for generated input.

    Returns:
        The loss; the integer ``0`` if there are no feature maps.
    """
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl - gl))

    return loss * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Least-squares discriminator loss: real pushed to 1, generated to 0.

    Returns:
        ``(loss, r_losses, g_losses)``: the total loss and the per-
        discriminator real / generated losses as Python floats.
    """
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg ** 2)
        loss += (r_loss + g_loss)
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    """Least-squares generator loss: generated scores pushed to 1.

    Returns:
        ``(loss, gen_losses)``: the total loss and the per-discriminator
        losses. Unlike ``discriminator_loss``, these are kept as tensors.
    """
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        term = torch.mean((1 - dg) ** 2)
        gen_losses.append(term)
        loss += term

    return loss, gen_losses