import torch
import torch.nn as nn
import numpy as np
from torch.nn import Conv1d
from torch.nn import ConvTranspose1d
from torch.nn.utils import weight_norm
from torch.nn.utils import remove_weight_norm

from .nsf import SourceModuleHnNSF
from .bigv import init_weights, SnakeBeta, AMPBlock
from .alias import Activation1d
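

# SpeakerAdapter is a conditional LayerNorm: each frame of the content features is
# normalised over the channel dimension, then re-scaled and shifted with a scale and
# bias predicted from the speaker embedding.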
class SpeakerAdapter(nn.Module):

    def __init__(self,
                 speaker_dim,
                 adapter_dim,
                 epsilon=1e-5
                 ):
        super(SpeakerAdapter, self).__init__()
        self.speaker_dim = speaker_dim
        self.adapter_dim = adapter_dim
        self.epsilon = epsilon
        self.W_scale = nn.Linear(self.speaker_dim, self.adapter_dim)
        self.W_bias = nn.Linear(self.speaker_dim, self.adapter_dim)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.constant_(self.W_scale.weight, 0.0)
        torch.nn.init.constant_(self.W_scale.bias, 1.0)
        torch.nn.init.constant_(self.W_bias.weight, 0.0)
        torch.nn.init.constant_(self.W_bias.bias, 0.0)
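        # with the constants above, scale starts at 1 and bias at 0, so the adapter
        # initially behaves as a plain (speaker-independent) LayerNorm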

    def forward(self, x, speaker_embedding):
        x = x.transpose(1, -1)
        # per-frame statistics over the channel dimension
        mean = x.mean(dim=-1, keepdim=True)
        var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
        std = (var + self.epsilon).sqrt()
        y = (x - mean) / std
        # speaker-dependent scale and bias, broadcast over time
        scale = self.W_scale(speaker_embedding)
        bias = self.W_bias(speaker_embedding)
        y *= scale.unsqueeze(1)
        y += bias.unsqueeze(1)
        y = y.transpose(1, -1)
        return y


class Generator(torch.nn.Module):
    # This is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.

    def __init__(self, hp):
        super(Generator, self).__init__()
        self.hp = hp
        self.num_kernels = len(hp.gen.resblock_kernel_sizes)
        self.num_upsamples = len(hp.gen.upsample_rates)
        # speaker adapter: spk_dim (256) should change according to the speaker encoder you use
        self.adapter = SpeakerAdapter(hp.vits.spk_dim, hp.gen.upsample_input)
        # pre conv
        self.conv_pre = nn.utils.weight_norm(
            Conv1d(hp.gen.upsample_input, hp.gen.upsample_initial_channel, 7, 1, padding=3))
        # nsf: upsample f0 to the audio rate so the harmonic source matches the output length
        self.f0_upsamp = torch.nn.Upsample(
            scale_factor=np.prod(hp.gen.upsample_rates))
        self.m_source = SourceModuleHnNSF()
        self.noise_convs = nn.ModuleList()
        # transposed conv-based upsamplers. does not apply anti-aliasing
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(hp.gen.upsample_rates, hp.gen.upsample_kernel_sizes)):
            # print(f'ups: {i} {k}, {u}, {(k - u) // 2}')
            # base
            self.ups.append(nn.ModuleList([
                weight_norm(ConvTranspose1d(hp.gen.upsample_initial_channel // (2 ** i),
                                            hp.gen.upsample_initial_channel // (
                                                2 ** (i + 1)),
                                            k, u, padding=(k - u) // 2))
            ]))
            # nsf
            if i + 1 < len(hp.gen.upsample_rates):
                stride_f0 = np.prod(hp.gen.upsample_rates[i + 1:])
                stride_f0 = int(stride_f0)
                self.noise_convs.append(
                    Conv1d(
                        1,
                        hp.gen.upsample_initial_channel // (2 ** (i + 1)),
                        kernel_size=stride_f0 * 2,
                        stride=stride_f0,
                        padding=stride_f0 // 2,
                    )
                )
            else:
                self.noise_convs.append(
                    Conv1d(1, hp.gen.upsample_initial_channel //
                           (2 ** (i + 1)), kernel_size=1)
                )
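            # each noise conv strides the audio-rate harmonic source down by
            # prod(upsample_rates[i + 1:]) so it matches the feature rate after this
            # stage; at the last stage the rates already match, hence the 1x1 conv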

        # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = hp.gen.upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(hp.gen.resblock_kernel_sizes, hp.gen.resblock_dilation_sizes):
                self.resblocks.append(AMPBlock(hp, ch, k, d))

        # post conv
        activation_post = SnakeBeta(ch, alpha_logscale=True)
        self.activation_post = Activation1d(activation=activation_post)
        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))

        # weight initialization
        for i in range(len(self.ups)):
            self.ups[i].apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, spk, x, f0):
        # adapter: condition the content features on the speaker embedding
        x = self.adapter(x, spk)
        # nsf: build the audio-rate harmonic source from f0
        f0 = f0[:, None]
        f0 = self.f0_upsamp(f0).transpose(1, 2)
        har_source = self.m_source(f0)
        har_source = har_source.transpose(1, 2)
        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            # upsampling
            for i_up in range(len(self.ups[i])):
                x = self.ups[i][i_up](x)
            # nsf
            x_source = self.noise_convs[i](har_source)
            x = x + x_source
            # AMP blocks
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        # post conv
        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        for l in self.ups:
            for l_i in l:
                remove_weight_norm(l_i)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)

    def eval(self, inference=False):
        super(Generator, self).eval()
        # only strip weight norm for real inference, not for validation inside the training loop
        if inference:
            self.remove_weight_norm()

    def pitch2source(self, f0):
        f0 = f0[:, None]
        f0 = self.f0_upsamp(f0).transpose(1, 2)  # [1, len, 1]
        har_source = self.m_source(f0)
        har_source = har_source.transpose(1, 2)  # [1, 1, len]
        return har_source

    def source2wav(self, audio):
        MAX_WAV_VALUE = 32768.0
        audio = audio.squeeze()
        audio = MAX_WAV_VALUE * audio
        audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE - 1)
        audio = audio.short()
        return audio.cpu().detach().numpy()

    def inference(self, spk, x, har_source):
        # adapter (har_source is typically precomputed once via pitch2source)
        x = self.adapter(x, spk)
        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            # upsampling
            for i_up in range(len(self.ups[i])):
                x = self.ups[i][i_up](x)
            # nsf
            x_source = self.noise_convs[i](har_source)
            x = x + x_source
            # AMP blocks
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        # post conv
        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x
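

# Sketch of a typical inference path using the methods above (shapes follow the
# comments in pitch2source):
#   har = model.pitch2source(f0)            # f0 -> audio-rate harmonic source
#   wav = model.inference(spk, feat, har)   # content features + speaker -> waveform in [-1, 1]
#   pcm = model.source2wav(wav)             # float waveform -> int16 numpy audio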