# -*- coding: utf-8 -*-
# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
"""Parallel WaveGAN Modules."""
import logging
import math
import numpy as np
import torch
from parallel_wavegan.layers import Conv1d
from parallel_wavegan.layers import Conv1d1x1
from parallel_wavegan.layers import upsample
from parallel_wavegan.layers import WaveNetResidualBlock as ResidualBlock
from parallel_wavegan import models
from parallel_wavegan.utils import read_hdf5
class ParallelWaveGANGenerator(torch.nn.Module):
"""Parallel WaveGAN Generator module."""
def __init__(
self,
in_channels=1,
out_channels=1,
kernel_size=3,
layers=30,
stacks=3,
residual_channels=64,
gate_channels=128,
skip_channels=64,
aux_channels=80,
aux_context_window=2,
dropout=0.0,
bias=True,
use_weight_norm=True,
use_causal_conv=False,
upsample_conditional_features=True,
upsample_net="ConvInUpsampleNetwork",
upsample_params={"upsample_scales": [4, 4, 4, 4]},
):
"""Initialize Parallel WaveGAN Generator module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
kernel_size (int): Kernel size of dilated convolution.
layers (int): Number of residual block layers.
            stacks (int): Number of stacks, i.e., dilation cycles.
residual_channels (int): Number of channels in residual conv.
gate_channels (int): Number of channels in gated conv.
skip_channels (int): Number of channels in skip conv.
aux_channels (int): Number of channels for auxiliary feature conv.
aux_context_window (int): Context window size for auxiliary feature.
dropout (float): Dropout rate. 0.0 means no dropout applied.
bias (bool): Whether to use bias parameter in conv layer.
use_weight_norm (bool): Whether to use weight norm.
If set to true, it will be applied to all of the conv layers.
use_causal_conv (bool): Whether to use causal structure.
upsample_conditional_features (bool): Whether to use upsampling network.
upsample_net (str): Upsampling network architecture.
upsample_params (dict): Upsampling network parameters.
"""
super(ParallelWaveGANGenerator, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.aux_channels = aux_channels
self.aux_context_window = aux_context_window
self.layers = layers
self.stacks = stacks
self.kernel_size = kernel_size
# check the number of layers and stacks
assert layers % stacks == 0
layers_per_stack = layers // stacks
# define first convolution
self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True)
# define conv + upsampling network
        if upsample_conditional_features:
            # copy to avoid mutating the shared default dict argument in-place
            upsample_params = dict(upsample_params)
            upsample_params.update(
                {
                    "use_causal_conv": use_causal_conv,
                }
            )
if upsample_net == "MelGANGenerator":
assert aux_context_window == 0
upsample_params.update(
{
"use_weight_norm": False, # not to apply twice
"use_final_nonlinear_activation": False,
}
)
self.upsample_net = getattr(models, upsample_net)(**upsample_params)
else:
if upsample_net == "ConvInUpsampleNetwork":
upsample_params.update(
{
"aux_channels": aux_channels,
"aux_context_window": aux_context_window,
}
)
self.upsample_net = getattr(upsample, upsample_net)(**upsample_params)
self.upsample_factor = np.prod(upsample_params["upsample_scales"])
else:
self.upsample_net = None
self.upsample_factor = 1
# define residual blocks
self.conv_layers = torch.nn.ModuleList()
for layer in range(layers):
dilation = 2 ** (layer % layers_per_stack)
conv = ResidualBlock(
kernel_size=kernel_size,
residual_channels=residual_channels,
gate_channels=gate_channels,
skip_channels=skip_channels,
aux_channels=aux_channels,
dilation=dilation,
dropout=dropout,
bias=bias,
use_causal_conv=use_causal_conv,
)
self.conv_layers += [conv]
# define output layers
self.last_conv_layers = torch.nn.ModuleList(
[
torch.nn.ReLU(inplace=True),
Conv1d1x1(skip_channels, skip_channels, bias=True),
torch.nn.ReLU(inplace=True),
Conv1d1x1(skip_channels, out_channels, bias=True),
]
)
# apply weight norm
if use_weight_norm:
self.apply_weight_norm()
def forward(self, x, c):
"""Calculate forward propagation.
Args:
x (Tensor): Input noise signal (B, 1, T).
            c (Tensor): Local conditioning auxiliary features (B, C, T').
        Returns:
            Tensor: Output tensor (B, out_channels, T).
"""
# perform upsampling
if c is not None and self.upsample_net is not None:
c = self.upsample_net(c)
assert c.size(-1) == x.size(-1)
# encode to hidden representation
x = self.first_conv(x)
skips = 0
for f in self.conv_layers:
x, h = f(x, c)
skips += h
skips *= math.sqrt(1.0 / len(self.conv_layers))
# apply final layers
x = skips
for f in self.last_conv_layers:
x = f(x)
return x
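    # Example shapes for a forward pass (illustrative, assuming the default
    # configuration, i.e. 80-dim features, aux_context_window=2, and a total
    # upsampling factor of 256; ``generator`` denotes an instance of this class):
    #
    #   x = torch.randn(1, 1, 100 * 256)     # noise, (B, 1, T)
    #   c = torch.randn(1, 80, 100 + 2 * 2)  # features, (B, C, T' + 2 * aux_context_window)
    #   y = generator(x, c)                  # output waveform, (B, 1, T)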
def remove_weight_norm(self):
"""Remove weight normalization module from all of the layers."""
def _remove_weight_norm(m):
try:
logging.debug(f"Weight norm is removed from {m}.")
torch.nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
def apply_weight_norm(self):
"""Apply weight normalization module from all of the layers."""
def _apply_weight_norm(m):
if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
torch.nn.utils.weight_norm(m)
logging.debug(f"Weight norm is applied to {m}.")
self.apply(_apply_weight_norm)
@staticmethod
def _get_receptive_field_size(
layers, stacks, kernel_size, dilation=lambda x: 2 ** x
):
assert layers % stacks == 0
layers_per_cycle = layers // stacks
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
return (kernel_size - 1) * sum(dilations) + 1
@property
def receptive_field_size(self):
"""Return receptive field size."""
return self._get_receptive_field_size(
self.layers, self.stacks, self.kernel_size
)
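    # Worked example with the default hyperparameters (layers=30, stacks=3,
    # kernel_size=3): each 10-layer cycle uses dilations 1, 2, 4, ..., 512,
    # summing to 1023 per cycle and 3069 over the 3 cycles, so the receptive
    # field is (3 - 1) * 3069 + 1 = 6139 samples.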
def register_stats(self, stats):
"""Register stats for de-normalization as buffer.
Args:
stats (str): Path of statistics file (".npy" or ".h5").
"""
assert stats.endswith(".h5") or stats.endswith(".npy")
if stats.endswith(".h5"):
mean = read_hdf5(stats, "mean").reshape(-1)
scale = read_hdf5(stats, "scale").reshape(-1)
else:
mean = np.load(stats)[0].reshape(-1)
scale = np.load(stats)[1].reshape(-1)
self.register_buffer("mean", torch.from_numpy(mean).float())
self.register_buffer("scale", torch.from_numpy(scale).float())
logging.info("Successfully registered stats as buffer.")
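    # Illustrative example (``mel_mean``, ``mel_scale``, and ``model`` are
    # hypothetical names): a compatible ".npy" file stacks the per-dimension
    # mean and scale, while an ".h5" file provides "mean" and "scale" datasets.
    #
    #   np.save("stats.npy", np.stack([mel_mean, mel_scale]))
    #   model.register_stats("stats.npy")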
def inference(self, c=None, x=None, normalize_before=False):
"""Perform inference.
Args:
            c (Union[Tensor, ndarray]): Local conditioning auxiliary features (T', C).
            x (Union[Tensor, ndarray]): Input noise signal (T, 1).
            normalize_before (bool): Whether to perform normalization.
        Returns:
            Tensor: Output tensor (T, out_channels).
"""
if x is not None:
if not isinstance(x, torch.Tensor):
x = torch.tensor(x, dtype=torch.float).to(
next(self.parameters()).device
)
x = x.transpose(1, 0).unsqueeze(0)
else:
assert c is not None
x = torch.randn(1, 1, len(c) * self.upsample_factor).to(
next(self.parameters()).device
)
if c is not None:
if not isinstance(c, torch.Tensor):
c = torch.tensor(c, dtype=torch.float).to(
next(self.parameters()).device
)
if normalize_before:
c = (c - self.mean) / self.scale
c = c.transpose(1, 0).unsqueeze(0)
c = torch.nn.ReplicationPad1d(self.aux_context_window)(c)
return self.forward(x, c).squeeze(0).transpose(1, 0)
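# Example usage of the generator (illustrative, assuming the default
# configuration, i.e. 80-dim mel features and an upsampling factor of 256):
#
#   generator = ParallelWaveGANGenerator()
#   mel = torch.randn(100, 80)        # (T', aux_channels) conditioning features
#   wav = generator.inference(c=mel)  # (T' * 256, out_channels) waveform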
class ParallelWaveGANDiscriminator(torch.nn.Module):
"""Parallel WaveGAN Discriminator module."""
def __init__(
self,
in_channels=1,
out_channels=1,
kernel_size=3,
layers=10,
conv_channels=64,
dilation_factor=1,
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.2},
bias=True,
use_weight_norm=True,
):
"""Initialize Parallel WaveGAN Discriminator module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of dilated convolution.
            layers (int): Number of conv layers.
            conv_channels (int): Number of channels in conv layers.
            dilation_factor (int): Dilation factor. For example, if dilation_factor = 2,
                the dilation will be 2, 4, 8, ..., and so on.
            nonlinear_activation (str): Nonlinear function after each conv.
            nonlinear_activation_params (dict): Nonlinear function parameters.
            bias (bool): Whether to use bias parameter in conv.
            use_weight_norm (bool): Whether to use weight norm.
If set to true, it will be applied to all of the conv layers.
"""
super(ParallelWaveGANDiscriminator, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "Even number kernel size is not supported."
assert dilation_factor > 0, "Dilation factor must be > 0."
self.conv_layers = torch.nn.ModuleList()
conv_in_channels = in_channels
for i in range(layers - 1):
if i == 0:
dilation = 1
else:
dilation = i if dilation_factor == 1 else dilation_factor ** i
conv_in_channels = conv_channels
padding = (kernel_size - 1) // 2 * dilation
conv_layer = [
Conv1d(
conv_in_channels,
conv_channels,
kernel_size=kernel_size,
padding=padding,
dilation=dilation,
bias=bias,
),
getattr(torch.nn, nonlinear_activation)(
inplace=True, **nonlinear_activation_params
),
]
self.conv_layers += conv_layer
padding = (kernel_size - 1) // 2
last_conv_layer = Conv1d(
conv_in_channels,
out_channels,
kernel_size=kernel_size,
padding=padding,
bias=bias,
)
self.conv_layers += [last_conv_layer]
# apply weight norm
if use_weight_norm:
self.apply_weight_norm()
def forward(self, x):
"""Calculate forward propagation.
Args:
            x (Tensor): Input signal (B, 1, T).
        Returns:
            Tensor: Output tensor (B, 1, T).
"""
for f in self.conv_layers:
x = f(x)
return x
def apply_weight_norm(self):
"""Apply weight normalization module from all of the layers."""
def _apply_weight_norm(m):
if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
torch.nn.utils.weight_norm(m)
logging.debug(f"Weight norm is applied to {m}.")
self.apply(_apply_weight_norm)
def remove_weight_norm(self):
"""Remove weight normalization module from all of the layers."""
def _remove_weight_norm(m):
try:
logging.debug(f"Weight norm is removed from {m}.")
torch.nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
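# Example usage of the discriminator (illustrative): it maps a waveform to
# per-sample scores at the same temporal resolution.
#
#   discriminator = ParallelWaveGANDiscriminator()
#   wav = torch.randn(1, 1, 16000)  # (B, 1, T) input waveform
#   score = discriminator(wav)      # (B, 1, T) per-sample scores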
class ResidualParallelWaveGANDiscriminator(torch.nn.Module):
"""Parallel WaveGAN Discriminator module."""
def __init__(
self,
in_channels=1,
out_channels=1,
kernel_size=3,
layers=30,
stacks=3,
residual_channels=64,
gate_channels=128,
skip_channels=64,
dropout=0.0,
bias=True,
use_weight_norm=True,
use_causal_conv=False,
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.2},
):
"""Initialize Parallel WaveGAN Discriminator module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
kernel_size (int): Kernel size of dilated convolution.
layers (int): Number of residual block layers.
            stacks (int): Number of stacks, i.e., dilation cycles.
residual_channels (int): Number of channels in residual conv.
gate_channels (int): Number of channels in gated conv.
skip_channels (int): Number of channels in skip conv.
dropout (float): Dropout rate. 0.0 means no dropout applied.
bias (bool): Whether to use bias parameter in conv.
use_weight_norm (bool): Whether to use weight norm.
If set to true, it will be applied to all of the conv layers.
use_causal_conv (bool): Whether to use causal structure.
            nonlinear_activation (str): Nonlinear function after each conv.
            nonlinear_activation_params (dict): Nonlinear function parameters.
"""
super(ResidualParallelWaveGANDiscriminator, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "Even number kernel size is not supported."
self.in_channels = in_channels
self.out_channels = out_channels
self.layers = layers
self.stacks = stacks
self.kernel_size = kernel_size
# check the number of layers and stacks
assert layers % stacks == 0
layers_per_stack = layers // stacks
# define first convolution
self.first_conv = torch.nn.Sequential(
Conv1d1x1(in_channels, residual_channels, bias=True),
getattr(torch.nn, nonlinear_activation)(
inplace=True, **nonlinear_activation_params
),
)
# define residual blocks
self.conv_layers = torch.nn.ModuleList()
for layer in range(layers):
dilation = 2 ** (layer % layers_per_stack)
conv = ResidualBlock(
kernel_size=kernel_size,
residual_channels=residual_channels,
gate_channels=gate_channels,
skip_channels=skip_channels,
aux_channels=-1,
dilation=dilation,
dropout=dropout,
bias=bias,
use_causal_conv=use_causal_conv,
)
self.conv_layers += [conv]
# define output layers
self.last_conv_layers = torch.nn.ModuleList(
[
getattr(torch.nn, nonlinear_activation)(
inplace=True, **nonlinear_activation_params
),
Conv1d1x1(skip_channels, skip_channels, bias=True),
getattr(torch.nn, nonlinear_activation)(
inplace=True, **nonlinear_activation_params
),
Conv1d1x1(skip_channels, out_channels, bias=True),
]
)
# apply weight norm
if use_weight_norm:
self.apply_weight_norm()
def forward(self, x):
"""Calculate forward propagation.
Args:
            x (Tensor): Input signal (B, 1, T).
        Returns:
            Tensor: Output tensor (B, 1, T).
"""
x = self.first_conv(x)
skips = 0
for f in self.conv_layers:
x, h = f(x, None)
skips += h
skips *= math.sqrt(1.0 / len(self.conv_layers))
# apply final layers
x = skips
for f in self.last_conv_layers:
x = f(x)
return x
def apply_weight_norm(self):
"""Apply weight normalization module from all of the layers."""
def _apply_weight_norm(m):
if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
torch.nn.utils.weight_norm(m)
logging.debug(f"Weight norm is applied to {m}.")
self.apply(_apply_weight_norm)
def remove_weight_norm(self):
"""Remove weight normalization module from all of the layers."""
def _remove_weight_norm(m):
try:
logging.debug(f"Weight norm is removed from {m}.")
torch.nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
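# Minimal smoke-test sketch using only the classes defined above and their
# default configurations; it checks the expected tensor shapes on random
# CPU inputs.
if __name__ == "__main__":
    generator = ParallelWaveGANGenerator()
    discriminator = ParallelWaveGANDiscriminator()
    residual_discriminator = ResidualParallelWaveGANDiscriminator()
    with torch.no_grad():
        # 100 frames of 80-dim features -> 100 * 256 waveform samples.
        mel = torch.randn(100, 80)
        wav = generator.inference(c=mel)
        assert wav.shape == (100 * int(generator.upsample_factor), 1)
        # Both discriminators preserve the temporal resolution of their input.
        scores = discriminator(wav.transpose(1, 0).unsqueeze(0))
        residual_scores = residual_discriminator(wav.transpose(1, 0).unsqueeze(0))
        assert scores.shape == residual_scores.shape == (1, 1, wav.size(0))
    print("Smoke test passed:", wav.shape, scores.shape, residual_scores.shape)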