# Copyright (c) 2022 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.

import torch
from alias_free_torch import Activation1d
from torch.nn import Conv1d
from torch.nn import ConvTranspose1d
from torch.nn import ModuleList
from torch.nn.utils import remove_weight_norm
from torch.nn.utils import weight_norm

from TrainingInterfaces.Spectrogram_to_Wave.BigVGAN.AMP import AMPBlock1
from TrainingInterfaces.Spectrogram_to_Wave.BigVGAN.Snake import SnakeBeta


class BigVGAN(torch.nn.Module):
    """
    Main BigVGAN generator: a HiFi-GAN-style stack of transposed-convolution
    upsamplers, each followed by residual blocks that use anti-aliased
    periodic (Snake) activations.
    """

    def __init__(self,
                 path_to_weights,
                 num_mels=80,
                 upsample_initial_channel=512,
                 upsample_rates=(8, 6, 4, 2),  # CAREFUL: the Avocodo discriminator assumes exactly 4 upsample scales, because it taps intermediate results.
                 upsample_kernel_sizes=(16, 12, 8, 4),
                 resblock_kernel_sizes=(3, 7, 11),
                 resblock_dilation_sizes=((1, 3, 5), (1, 3, 5), (1, 3, 5))):
        super().__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        # pre conv
        self.conv_pre = weight_norm(Conv1d(num_mels, upsample_initial_channel, 7, 1, padding=3))

        # transposed conv-based upsamplers (no anti-aliasing applied here)
        self.ups = ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(ModuleList([weight_norm(ConvTranspose1d(upsample_initial_channel // (2 ** i),
                                                                    upsample_initial_channel // (2 ** (i + 1)),
                                                                    k, u, padding=(k - u) // 2))]))
        # residual blocks using anti-aliased multi-periodicity composition modules (AMP);
        # num_kernels blocks are built per upsample stage, at that stage's channel width
        self.resblocks = ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(AMPBlock1(ch, k, d))

        # post conv, preceded by a final anti-aliased SnakeBeta activation
        activation_post = SnakeBeta(ch, alpha_logscale=True)
        self.activation_post = Activation1d(activation=activation_post)
        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        # weight initialization
        for i in range(len(self.ups)):
            self.ups[i].apply(init_weights)
        self.conv_post.apply(init_weights)

        # projections of intermediate feature maps to waveforms for the Avocodo
        # discriminator; unused during inference, but kept so that the keys of a
        # trained checkpoint match this module's state dict
        self.out_proj_x1 = torch.nn.Conv1d(512 // 4, 1, 7, 1, padding=3)
        self.out_proj_x2 = torch.nn.Conv1d(512 // 8, 1, 7, 1, padding=3)

        self.load_state_dict(torch.load(path_to_weights, map_location='cpu')["generator"])

    def forward(self, x):
        # x is a single mel spectrogram with shape (num_mels, frames)
        x = x.unsqueeze(0)
        # pre conv
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            # upsampling
            for i_up in range(len(self.ups[i])):
                x = self.ups[i][i_up](x)
            # AMP blocks: average the outputs of the num_kernels parallel resblocks
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        # post conv
        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x.squeeze()
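
    # Worked shape example for the default hyperparameters (a sketch, not
    # enforced anywhere in the code): an (80, T) input becomes (1, 80, T)
    # after unsqueeze and (1, 512, T) after conv_pre; the four upsample
    # stages then yield (1, 256, 8T), (1, 128, 48T), (1, 64, 192T) and
    # (1, 32, 384T); conv_post, tanh and squeeze produce a (384T,) waveform,
    # i.e. 8 * 6 * 4 * 2 = 384 samples per input frame.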

    def remove_weight_norm(self):
        # fuses the weight norm parametrization back into plain weights,
        # which slightly speeds up inference
        for l in self.ups:
            for l_i in l:
                remove_weight_norm(l_i)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)


def init_weights(m, mean=0.0, std=0.01):
    # initializes the weights of every conv layer in a module tree with Gaussian noise
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def apply_weight_norm(m):
    # re-registers weight norm on every conv layer, e.g. via model.apply(apply_weight_norm)
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        weight_norm(m)


def get_padding(kernel_size, dilation=1):
    # "same" padding for odd kernel sizes, e.g. kernel 7 with dilation 1 gives padding 3
    return (kernel_size * dilation - dilation) // 2
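

# Minimal usage sketch. Assumptions not present in this file: the checkpoint
# path below is hypothetical, and the checkpoint must be a dict whose
# "generator" entry holds this module's state dict, as __init__ expects.
if __name__ == "__main__":
    vocoder = BigVGAN(path_to_weights="Models/BigVGAN/best.pt")  # hypothetical path
    vocoder.remove_weight_norm()  # fuse weight norm for faster inference
    vocoder.eval()
    dummy_mel = torch.randn(80, 100)  # (num_mels, frames)
    with torch.no_grad():
        wave = vocoder(dummy_mel)
    print(wave.shape)  # torch.Size([38400]) -> 100 frames * 384 samples per frame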