import torch

from Layers.ResidualBlock import HiFiGANResidualBlock as ResidualBlock


class HiFiGANGenerator(torch.nn.Module):
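    """HiFi-GAN generator (Kong et al., 2020).

    Upsamples a mel spectrogram to a raw waveform with a stack of transposed
    convolutions, each followed by a bank of parallel residual blocks whose
    outputs are averaged (multi-receptive-field fusion).
    """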

    def __init__(self,
                 path_to_weights,
                 in_channels=80,
                 out_channels=1,
                 channels=512,
                 kernel_size=7,
                 upsample_scales=(8, 6, 4, 4),
                 upsample_kernel_sizes=(16, 12, 8, 8),
                 resblock_kernel_sizes=(3, 7, 11),
                 resblock_dilations=((1, 3, 5), (1, 3, 5), (1, 3, 5)),
                 use_additional_convs=True,
                 bias=True,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.1},
                 use_weight_norm=True, ):
        super().__init__()
        assert kernel_size % 2 == 1, "Kernel size must be an odd number."
        assert len(upsample_scales) == len(upsample_kernel_sizes)
        assert len(resblock_dilations) == len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_kernel_sizes)
        self.num_blocks = len(resblock_kernel_sizes)
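        # pre-conv: project the mel features up to the full model width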
        self.input_conv = torch.nn.Conv1d(in_channels,
                                          channels,
                                          kernel_size,
                                          1,
                                          padding=(kernel_size - 1) // 2, )
        self.upsamples = torch.nn.ModuleList()
        self.blocks = torch.nn.ModuleList()
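        # each upsampling stage halves the channel count and is followed by
        # num_blocks parallel residual blocks (one per resblock kernel size)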
        for i in range(len(upsample_kernel_sizes)):
            self.upsamples += [
                torch.nn.Sequential(getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
                                    torch.nn.ConvTranspose1d(channels // (2 ** i),
                                                             channels // (2 ** (i + 1)),
                                                             upsample_kernel_sizes[i],
                                                             upsample_scales[i],
                                                             padding=(upsample_kernel_sizes[i] - upsample_scales[i]) // 2, ), )]
            for j in range(len(resblock_kernel_sizes)):
                self.blocks += [ResidualBlock(kernel_size=resblock_kernel_sizes[j],
                                              channels=channels // (2 ** (i + 1)),
                                              dilations=resblock_dilations[j],
                                              bias=bias,
                                              use_additional_convs=use_additional_convs,
                                              nonlinear_activation=nonlinear_activation,
                                              nonlinear_activation_params=nonlinear_activation_params, )]
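        # post-conv: map down to the waveform channel, bounded to [-1, 1] by tanh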
        self.output_conv = torch.nn.Sequential(
            torch.nn.LeakyReLU(),
            torch.nn.Conv1d(channels // (2 ** (i + 1)),
                            out_channels,
                            kernel_size,
                            1,
                            padding=(kernel_size - 1) // 2, ),
            torch.nn.Tanh(), )
        if use_weight_norm:
            # apply weight norm before loading so that parameter names match
            # the weight-normalized checkpoint (weight_g / weight_v)
            self.apply_weight_norm()
        self.load_state_dict(torch.load(path_to_weights, map_location='cpu')["generator"])

    def forward(self, c, normalize_before=False):
        """Convert a mel spectrogram of shape (in_channels, frames) into a waveform."""
        if normalize_before:
            # NOTE: self.mean and self.scale are not defined in this class;
            # they must be attached externally (e.g. registered as buffers
            # from the checkpoint's feature statistics) before calling
            # forward with normalize_before=True.
            c = (c - self.mean) / self.scale
        c = self.input_conv(c.unsqueeze(0))  # add a batch dimension
        for i in range(self.num_upsamples):
            c = self.upsamples[i](c)
            # multi-receptive-field fusion: average the outputs of the
            # parallel residual blocks attached to this upsampling stage
            cs = 0.0
            for j in range(self.num_blocks):
                cs = cs + self.blocks[i * self.num_blocks + j](c)
            c = cs / self.num_blocks
        c = self.output_conv(c)
        return c.squeeze(0).squeeze(0)  # drop batch and channel dims -> (num_samples,)

    def remove_weight_norm(self):
        """Remove weight normalization from all layers (typically done before inference)."""

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # the module has no weight norm hook
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization to every Conv1d and ConvTranspose1d layer."""

        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)
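

if __name__ == "__main__":
    # Minimal usage sketch. "hifigan.pt" is a placeholder path: the checkpoint
    # is assumed to be a dict with a "generator" key, as required by __init__
    # above, and a random tensor stands in for a real 80-bin mel spectrogram.
    vocoder = HiFiGANGenerator(path_to_weights="hifigan.pt")
    vocoder.remove_weight_norm()  # fuse weight norm for inference
    vocoder.eval()
    mel = torch.randn(80, 100)  # (in_channels, frames)
    with torch.no_grad():
        wave = vocoder(mel)
    # total upsampling factor = prod(upsample_scales) = 8 * 6 * 4 * 4 = 768,
    # so 100 frames yield 76800 samples
    print(wave.shape)  # torch.Size([76800])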