# File-viewer page header (not part of the program): Florian Lux | initial commit | cea6632 | raw | history blame | 4.28 kB
import torch
from Layers.ResidualBlock import HiFiGANResidualBlock as ResidualBlock
class HiFiGANGenerator(torch.nn.Module):
    """HiFiGAN generator that loads pretrained weights at construction time.

    The network is an input convolution, followed by a stack of upsampling
    stages (activation + transposed convolution, then the average of several
    residual blocks), followed by an output convolution with a Tanh.
    """

    def __init__(self,
                 path_to_weights,
                 in_channels=80,
                 out_channels=1,
                 channels=512,
                 kernel_size=7,
                 upsample_scales=(8, 6, 4, 4),
                 upsample_kernel_sizes=(16, 12, 8, 8),
                 resblock_kernel_sizes=(3, 7, 11),
                 resblock_dilations=[(1, 3, 5), (1, 3, 5), (1, 3, 5)],
                 use_additional_convs=True,
                 bias=True,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.1},
                 use_weight_norm=True, ):
        """Build the network and load its parameters from a checkpoint.

        Args:
            path_to_weights: Checkpoint path given to ``torch.load``; the
                loaded object must hold the state dict under ``"generator"``.
            in_channels: Channel count expected by the input convolution.
            out_channels: Channel count produced by the output convolution.
            channels: Width after the input convolution; each upsampling
                stage ``i`` maps ``channels // 2**i`` to
                ``channels // 2**(i + 1)`` channels.
            kernel_size: Odd kernel size of the input and output convolutions.
            upsample_scales: Stride of each transposed convolution.
            upsample_kernel_sizes: Kernel size of each transposed convolution;
                must have as many entries as ``upsample_scales``.
            resblock_kernel_sizes: Kernel size of each residual block that is
                created per upsampling stage.
            resblock_dilations: Dilations for each residual block; must have
                as many entries as ``resblock_kernel_sizes``.
            use_additional_convs: Forwarded to every residual block.
            bias: Forwarded to every residual block.
            nonlinear_activation: Name of a ``torch.nn`` activation class.
            nonlinear_activation_params: Keyword arguments for that class.
            use_weight_norm: Whether to wrap the convolutions in weight
                normalization before the checkpoint is loaded.

        Raises:
            AssertionError: If ``kernel_size`` is even or the paired sequence
                arguments differ in length.
        """
        super().__init__()
        assert kernel_size % 2 == 1, "Kernal size must be odd number."
        assert len(upsample_scales) == len(upsample_kernel_sizes)
        assert len(resblock_dilations) == len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_kernel_sizes)
        self.num_blocks = len(resblock_kernel_sizes)
        self.input_conv = torch.nn.Conv1d(in_channels,
                                          channels,
                                          kernel_size,
                                          1,
                                          padding=(kernel_size - 1) // 2, )
        self.upsamples = torch.nn.ModuleList()
        self.blocks = torch.nn.ModuleList()

        activation_cls = getattr(torch.nn, nonlinear_activation)
        # Width of the tensor leaving the last upsampling stage. Tracked
        # explicitly so the output convolution does not depend on a leaked
        # loop variable (which is undefined when there are no stages).
        stage_channels = channels
        for i, (up_kernel, up_scale) in enumerate(zip(upsample_kernel_sizes, upsample_scales)):
            stage_in = channels // (2 ** i)
            stage_channels = channels // (2 ** (i + 1))
            self.upsamples.append(
                torch.nn.Sequential(activation_cls(**nonlinear_activation_params),
                                    torch.nn.ConvTranspose1d(stage_in,
                                                             stage_channels,
                                                             up_kernel,
                                                             up_scale,
                                                             padding=(up_kernel - up_scale) // 2, ), ))
            for res_kernel, res_dilations in zip(resblock_kernel_sizes, resblock_dilations):
                self.blocks.append(ResidualBlock(kernel_size=res_kernel,
                                                 channels=stage_channels,
                                                 dilations=res_dilations,
                                                 bias=bias,
                                                 use_additional_convs=use_additional_convs,
                                                 nonlinear_activation=nonlinear_activation,
                                                 # Each block gets its own copy so the shared
                                                 # default dict can never be mutated through it.
                                                 nonlinear_activation_params=dict(nonlinear_activation_params), ))

        # The output activation is a LeakyReLU with torch's default arguments,
        # independent of ``nonlinear_activation``.
        self.output_conv = torch.nn.Sequential(
            torch.nn.LeakyReLU(),
            torch.nn.Conv1d(stage_channels,
                            out_channels,
                            kernel_size,
                            1,
                            padding=(kernel_size - 1) // 2, ),
            torch.nn.Tanh(), )

        # Weight norm has to be applied before loading so that the parameter
        # names of this module match the ones stored in the checkpoint.
        if use_weight_norm:
            self.apply_weight_norm()
        # NOTE(security): torch.load unpickles the file; only pass checkpoints
        # that come from a trusted source.
        self.load_state_dict(torch.load(path_to_weights, map_location='cpu')["generator"])

    def forward(self, c, normalize_before=False):
        """Run the generator on a single, unbatched input.

        Args:
            c: Input tensor without a batch dimension; a batch dimension of
                size one is added before the input convolution.
            normalize_before: If True, compute ``(c - self.mean) / self.scale``
                first. This class does not define ``mean`` or ``scale``
                itself, so they must have been attached to the instance.

        Returns:
            The network output with the leading batch dimension removed and,
            if it has size one, the channel dimension removed as well.

        Raises:
            AttributeError: If ``normalize_before`` is True but ``mean`` or
                ``scale`` is missing on this instance.
        """
        if normalize_before:
            if not (hasattr(self, "mean") and hasattr(self, "scale")):
                raise AttributeError("normalize_before=True requires the attributes 'mean' and 'scale' "
                                     "to be set on the generator, but they are not defined.")
            c = (c - self.mean) / self.scale
        c = self.input_conv(c.unsqueeze(0))
        for stage in range(self.num_upsamples):
            c = self.upsamples[stage](c)
            if self.num_blocks == 0:
                # No residual blocks configured: nothing to average, so the
                # upsampled tensor is passed on as is.
                continue
            first_block = stage * self.num_blocks
            block_sum = 0.0
            for offset in range(self.num_blocks):
                block_sum = block_sum + self.blocks[first_block + offset](c)
            # The residual blocks of one stage run in parallel on the same
            # input and their outputs are averaged.
            c = block_sum / self.num_blocks
        c = self.output_conv(c)
        return c.squeeze(0).squeeze(0)

    def remove_weight_norm(self):
        """Remove weight normalization from every submodule that has it."""

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:
                # Raised for modules that carry no weight norm; skip them.
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization to all 1D (transposed) convolutions."""

        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)