File size: 4,176 Bytes
5325fcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...modules import NormConv2d
from .base import MultiDiscriminator, MultiDiscriminatorOutputType


def get_padding(kernel_size: int, dilation: int = 1) -> int:
    """Return the one-sided padding that keeps a dilated convolution's length unchanged.

    The effective receptive span beyond the centre tap is ``kernel_size * dilation - dilation``.
    Half of it is used on each side. The half is computed with true division and then
    truncated toward zero by ``int``.
    """
    halo = kernel_size * dilation - dilation
    half_halo = halo / 2
    return int(half_halo)


class PeriodDiscriminator(nn.Module):
    """Period sub-discriminator.

    The input of shape ``[B, C, T]`` is right-padded (reflect mode) so that ``T`` becomes a
    multiple of ``period``. It is then viewed as ``[B, C, T // period, period]`` and processed
    by a stack of ``(k, 1)`` 2D convolutions, so the kernels only mix samples that are
    ``period`` steps apart.

    Args:
        period (int): Period between samples of audio.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        n_layers (int): Number of convolutional layers (not counting the output projection).
        kernel_sizes (sequence of int): Kernel sizes along the time axis. ``kernel_sizes[0]`` is
            used by the ``n_layers`` convolutions and ``kernel_sizes[1]`` by the output projection.
        stride (int): Stride along the time axis for every one of the ``n_layers`` convolutions
            except the last one, which (like the output projection) uses stride 1.
        filters (int): Base number of filters. Layer ``i`` has
            ``min(filters * filters_scale ** (i + 1), max_filters)`` output channels.
        filters_scale (int): Multiplier of number of filters as we increase depth.
        max_filters (int): Maximum number of filters.
        norm (str): Normalization method, forwarded to ``NormConv2d``.
        activation (str): Name of an activation class in ``torch.nn``.
        activation_params (dict, optional): Keyword arguments for the activation constructor.
            Defaults to ``{'negative_slope': 0.2}`` when ``None``.
    """
    def __init__(self, period: int, in_channels: int = 1, out_channels: int = 1,
                 n_layers: int = 5, kernel_sizes: tp.Sequence[int] = (5, 3), stride: int = 3,
                 filters: int = 8, filters_scale: int = 4, max_filters: int = 1024,
                 norm: str = 'weight_norm', activation: str = 'LeakyReLU',
                 activation_params: tp.Optional[dict] = None):
        super().__init__()
        if activation_params is None:
            # Built per call so no mutable default object is shared between instances.
            activation_params = {'negative_slope': 0.2}
        self.period = period
        self.n_layers = n_layers
        self.activation = getattr(torch.nn, activation)(**activation_params)
        self.convs = nn.ModuleList()
        in_chs = in_channels
        for i in range(self.n_layers):
            out_chs = min(filters * (filters_scale ** (i + 1)), max_filters)
            # The deepest layer keeps the time resolution; earlier layers downsample.
            eff_stride = 1 if i == self.n_layers - 1 else stride
            self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_sizes[0], 1), stride=(eff_stride, 1),
                                         padding=((kernel_sizes[0] - 1) // 2, 0), norm=norm))
            in_chs = out_chs
        self.conv_post = NormConv2d(in_chs, out_channels, kernel_size=(kernel_sizes[1], 1), stride=1,
                                    padding=((kernel_sizes[1] - 1) // 2, 0), norm=norm)

    def forward(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
        """Discriminate a batch of signals.

        Args:
            x (torch.Tensor): Input of shape ``[B, C, T]``.

        Returns:
            A tuple ``(logits, fmap)``. ``logits`` is the 4D output of the final projection.
            ``fmap`` holds the activated output of every convolution followed by ``logits`` itself.
        """
        fmap = []
        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            # 'reflect' padding requires n_pad < t (see torch.nn.functional.pad), so inputs
            # shorter than the period make F.pad raise.
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), 'reflect')
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for conv in self.convs:
            x = conv(x)
            x = self.activation(x)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)

        return x, fmap


class MultiPeriodDiscriminator(MultiDiscriminator):
    """Multi-Period (MPD) Discriminator.

    Builds one `PeriodDiscriminator` per entry of ``periods`` and runs each of them on the
    same input.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        periods (Sequence[int]): Periods between samples of audio for the sub-discriminators.
        **kwargs: Additional args for `PeriodDiscriminator`
    """
    def __init__(self, in_channels: int = 1, out_channels: int = 1,
                 periods: tp.Sequence[int] = (2, 3, 5, 7, 11), **kwargs):
        super().__init__()
        # Immutable default: a list default would be one object shared across all calls.
        self.discriminators = nn.ModuleList([
            PeriodDiscriminator(p, in_channels, out_channels, **kwargs) for p in periods
        ])

    @property
    def num_discriminators(self) -> int:
        """Number of period sub-discriminators."""
        return len(self.discriminators)

    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        """Run every sub-discriminator on ``x``.

        Returns:
            A tuple ``(logits, fmaps)``. Element ``i`` of each list is the output and the list of
            feature maps of the ``i``-th sub-discriminator, in the order of ``periods``.
        """
        logits = []
        fmaps = []
        for disc in self.discriminators:
            logit, fmap = disc(x)
            logits.append(logit)
            fmaps.append(fmap)
        return logits, fmaps