# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch.nn as nn

from modules.general.utils import Conv1d, zero_module
from .residual_block import ResidualBlock
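

# Illustrative helpers (hypothetical; not part of Amphion and not used by
# BiDilConv). A minimal sketch of the dilation schedule BiDilConv builds and
# the receptive field it implies. It assumes each ResidualBlock applies exactly
# one dilated convolution of size ``conv_kernel_size`` along time and that all
# other convolutions are 1x1.
def bidilconv_dilations(n_res_block, dilation_cycle_length):
    """Dilation of each residual block: 1, 2, 4, ..., restarting every
    ``dilation_cycle_length`` blocks (same expression as BiDilConv.__init__)."""
    return [2 ** (i % dilation_cycle_length) for i in range(n_res_block)]


def bidilconv_receptive_field(n_res_block, conv_kernel_size, dilation_cycle_length):
    """Receptive field, in frames, of the stacked dilated convolutions.

    A convolution with kernel size k and dilation d widens the receptive field
    by (k - 1) * d frames; 1x1 convolutions (k = 1) add nothing.
    """
    dilations = bidilconv_dilations(n_res_block, dilation_cycle_length)
    return 1 + (conv_kernel_size - 1) * sum(dilations)


# Usage, with a hypothetical configuration of 20 blocks, kernel size 3 and cycle
# length 4: the dilations are 1, 2, 4, 8 repeated five times (sum 75), so
# bidilconv_receptive_field(20, 3, 4) returns 1 + 2 * 75 = 151.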


class BiDilConv(nn.Module):
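    r"""Dilated CNN with residual and skip connections; the default diffusion decoder.

    Args:
        input_channel: Number of input channels.
        base_channel: Number of hidden channels used inside the residual blocks.
        n_res_block: Number of residual blocks.
        conv_kernel_size: Kernel size of the dilated convolutions in the residual blocks.
        dilation_cycle_length: Number of blocks after which the dilation pattern
            repeats; block ``i`` uses dilation ``2 ** (i % dilation_cycle_length)``.
        conditioner_size: Number of channels of the conditioning context
            (passed to each residual block as ``d_context``).
        output_channel: Number of output channels. Defaults to ``input_channel``
            when not positive.
    """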

    def __init__(
        self,
        input_channel,
        base_channel,
        n_res_block,
        conv_kernel_size,
        dilation_cycle_length,
        conditioner_size,
        output_channel: int = -1,
    ):
        super().__init__()

        self.input_channel = input_channel
        self.base_channel = base_channel
        self.n_res_block = n_res_block
        self.conv_kernel_size = conv_kernel_size
        self.dilation_cycle_length = dilation_cycle_length
        self.conditioner_size = conditioner_size
        self.output_channel = output_channel if output_channel > 0 else input_channel

        # 1x1 convolution projecting the input to ``base_channel`` channels.
        self.input = nn.Sequential(
            Conv1d(
                input_channel,
                base_channel,
                1,
            ),
            nn.ReLU(),
        )

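        # Residual blocks with dilations 1, 2, 4, ..., restarting every
        # ``dilation_cycle_length`` blocks.
        self.residual_blocks = nn.ModuleList(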
            [
                ResidualBlock(
                    channels=base_channel,
                    kernel_size=conv_kernel_size,
                    dilation=2 ** (i % dilation_cycle_length),
                    d_context=conditioner_size,
                )
                for i in range(n_res_block)
            ]
        )

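        # Output head: 1x1 conv, ReLU, then a 1x1 projection whose parameters are
        # zeroed by ``zero_module``, so the network outputs zeros at initialization.
        self.out_proj = nn.Sequential(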
            Conv1d(
                base_channel,
                base_channel,
                1,
            ),
            nn.ReLU(),
            zero_module(
                Conv1d(
                    base_channel,
                    self.output_channel,
                    1,
                ),
            ),
        )

    def forward(self, x, y, context=None):
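        """
        Args:
            x: Noisy mel-spectrogram with shape [B x ``input_channel`` x L]
                (``input_channel`` is ``n_mel`` for mel-spectrograms).
            y: FiLM embeddings with shape [B x ``base_channel``].
            context: Conditioning context with shape [B x ``conditioner_size`` x L].
                Defaults to None.

        Returns:
            Tensor with shape [B x ``output_channel`` x L].
        """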

        h = self.input(x)

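        # Sum the skip outputs of all blocks and scale by 1 / sqrt(n_res_block)
        # so the output magnitude stays roughly independent of depth.
        skip = None
        for block in self.residual_blocks:
            h, skip_connection = block(h, y, context)
            skip = skip_connection if skip is None else skip_connection + skip

        out = skip / math.sqrt(self.n_res_block)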

        out = self.out_proj(out)

        return out
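

# Illustrative helper (hypothetical; not part of Amphion and not used by
# BiDilConv). It isolates the skip aggregation in BiDilConv.forward: the skip
# outputs of all blocks are summed and divided by sqrt(n_res_block). Assuming
# N independent skip signals with unit variance, the sum has variance N, so the
# 1 / sqrt(N) factor keeps the output variance near 1 whatever the depth;
# DiffWave normalizes its summed skip connections the same way. Works on floats
# or on torch tensors of matching shape.
def aggregate_skips(skips):
    """Sum a non-empty sequence of skip outputs and scale by 1 / sqrt(len(skips))."""
    if len(skips) == 0:
        raise ValueError("aggregate_skips needs at least one skip output")
    total = None
    for skip in skips:
        total = skip if total is None else total + skip
    # ``len(skips) ** 0.5`` equals ``math.sqrt(len(skips))``.
    return total / len(skips) ** 0.5


# Usage: aggregate_skips([1.0, 1.0, 1.0, 1.0]) returns 4.0 / 2.0 = 2.0.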