# reach-vb's picture
# reach-vb HF staff
# Stereo demo update (#60)
# 5325fcc
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...modules import NormConv2d
from .base import MultiDiscriminator, MultiDiscriminatorOutputType
def get_padding(kernel_size: int, dilation: int = 1) -> int:
    """Compute the per-side padding matching a (possibly dilated) kernel.

    Args:
        kernel_size (int): Size of the convolution kernel.
        dilation (int): Dilation factor applied to the kernel.
    Returns:
        int: Half of `dilation * (kernel_size - 1)`, truncated toward zero.
    """
    # Distance between the first and last tap of the dilated kernel.
    dilated_extent = dilation * (kernel_size - 1)
    return int(dilated_extent / 2)
class PeriodDiscriminator(nn.Module):
    """Period sub-discriminator.

    The input signal is folded along time into a 2D map with `period` columns and
    passed through a stack of 2D convolutions whose kernels only span the folded
    time axis (kernel width 1 along the period axis).

    Args:
        period (int): Period between samples of audio.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        n_layers (int): Number of convolutional layers.
        kernel_sizes (sequence of int): Kernel sizes for convolutions. The first entry is used
            by the `n_layers` stacked convolutions, the second by the final convolution.
        stride (int): Stride for convolutions. The last stacked convolution always uses stride 1.
        filters (int): Initial number of filters in convolutions.
        filters_scale (int): Multiplier of number of filters as we increase depth.
        max_filters (int): Maximum number of filters.
        norm (str): Normalization method.
        activation (str): Activation function (name of a class in `torch.nn`).
        activation_params (dict, optional): Parameters to provide to the activation function.
            Defaults to `{'negative_slope': 0.2}` when not provided.
    """
    def __init__(self, period: int, in_channels: int = 1, out_channels: int = 1,
                 n_layers: int = 5, kernel_sizes: tp.Sequence[int] = (5, 3), stride: int = 3,
                 filters: int = 8, filters_scale: int = 4, max_filters: int = 1024,
                 norm: str = 'weight_norm', activation: str = 'LeakyReLU',
                 activation_params: tp.Optional[dict] = None):
        super().__init__()
        if activation_params is None:
            # Built per call so that no default dict object is shared between instances.
            activation_params = {'negative_slope': 0.2}
        self.period = period
        self.n_layers = n_layers
        self.activation = getattr(torch.nn, activation)(**activation_params)
        self.convs = nn.ModuleList()
        in_chs = in_channels
        for i in range(self.n_layers):
            # Channel count grows geometrically with depth, capped at `max_filters`.
            out_chs = min(filters * (filters_scale ** (i + 1)), max_filters)
            # The last stacked layer does not downsample.
            eff_stride = 1 if i == self.n_layers - 1 else stride
            self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_sizes[0], 1), stride=(eff_stride, 1),
                                         padding=((kernel_sizes[0] - 1) // 2, 0), norm=norm))
            in_chs = out_chs
        self.conv_post = NormConv2d(in_chs, out_channels, kernel_size=(kernel_sizes[1], 1), stride=1,
                                    padding=((kernel_sizes[1] - 1) // 2, 0), norm=norm)

    def forward(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
        """Run the sub-discriminator.

        Args:
            x (torch.Tensor): Input of shape `(batch, channels, time)`.
        Returns:
            tuple: The output of the final convolution (not flattened) and the list of
                feature maps: one per stacked convolution (after activation) followed by
                the final convolution output.
        """
        fmap = []
        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            # Right-pad so the time axis becomes a multiple of the period.
            # NOTE(review): torch documents reflect padding as requiring the pad to be smaller
            # than the input length; inputs that short are not specially handled here.
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), 'reflect')
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for conv in self.convs:
            x = conv(x)
            x = self.activation(x)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return x, fmap
class MultiPeriodDiscriminator(MultiDiscriminator):
    """Multi-Period (MPD) Discriminator.

    Holds one `PeriodDiscriminator` per entry of `periods` and applies each of them
    to the same input.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        periods (Sequence[int]): Periods between samples of audio for the sub-discriminators.
        **kwargs: Additional args for `PeriodDiscriminator`
    """
    def __init__(self, in_channels: int = 1, out_channels: int = 1,
                 periods: tp.Sequence[int] = (2, 3, 5, 7, 11), **kwargs):
        super().__init__()
        # The default is an immutable tuple so that no list object is shared between calls.
        self.discriminators = nn.ModuleList([
            PeriodDiscriminator(p, in_channels, out_channels, **kwargs) for p in periods
        ])

    @property
    def num_discriminators(self):
        """Number of period sub-discriminators held by this module."""
        return len(self.discriminators)

    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        """Apply every sub-discriminator to `x`.

        Args:
            x (torch.Tensor): Input passed unchanged to each sub-discriminator.
        Returns:
            MultiDiscriminatorOutputType: A list of logits and a list of feature-map lists,
                both ordered like `periods`.
        """
        logits = []
        fmaps = []
        for disc in self.discriminators:
            logit, fmap = disc(x)
            logits.append(logit)
            fmaps.append(fmap)
        return logits, fmaps