File size: 4,874 Bytes
f54eb92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from itertools import product

import pytest
import torch

from audiocraft.modules.seanet import SEANetEncoder, SEANetDecoder, SEANetResnetBlock
from audiocraft.modules import StreamableConv1d, StreamableConvTranspose1d


class TestSEANetModel:
    """Tests for SEANetEncoder / SEANetDecoder output shapes and per-block norm configuration."""

    def _check_roundtrip_shapes(self, encoder: SEANetEncoder, decoder: SEANetDecoder) -> None:
        """Encode a (1, 1, 24000) random signal and check latent and reconstruction shapes.

        The expected latent shape is [1, 128, 75]. For the configurations used in
        these tests, that corresponds to an overall downsampling factor of
        24000 / 75 = 320. The decoder must map the latent back to the input's shape.
        """
        x = torch.randn(1, 1, 24000)
        z = encoder(x)
        assert list(z.shape) == [1, 128, 75], z.shape
        y = decoder(z)
        assert y.shape == x.shape, (x.shape, y.shape)

    def test_base(self):
        self._check_roundtrip_shapes(SEANetEncoder(), SEANetDecoder())

    def test_causal(self):
        self._check_roundtrip_shapes(SEANetEncoder(causal=True), SEANetDecoder(causal=True))

    def test_conv_skip_connection(self):
        self._check_roundtrip_shapes(SEANetEncoder(true_skip=False), SEANetDecoder(true_skip=False))

    def test_seanet_encoder_decoder_final_act(self):
        self._check_roundtrip_shapes(SEANetEncoder(true_skip=False),
                                     SEANetDecoder(true_skip=False, final_activation='Tanh'))

    def _check_encoder_blocks_norm(self, encoder: SEANetEncoder, n_disable_blocks: int, norm: str) -> None:
        """Check that the first ``n_disable_blocks`` encoder blocks have norm disabled.

        A block is delimited by a top-level ``StreamableConv1d`` in ``encoder.model``.
        Residual blocks appear before the conv that closes their block, so they are
        attributed to the next block index. All other blocks must use ``norm``.
        """
        def expected_norm(block_idx: int) -> str:
            # Blocks are counted from the input side (1-based).
            return 'none' if block_idx <= n_disable_blocks else norm

        n_blocks = 0
        for layer in encoder.model:
            if isinstance(layer, StreamableConv1d):
                n_blocks += 1
                # Compute the expectation separately: `a == b if c else d` would parse as
                # `(a == b) if c else d` and silently pass whenever `c` is false.
                expected = expected_norm(n_blocks)
                assert layer.conv.norm_type == expected, (n_blocks, layer.conv.norm_type, expected)
            elif isinstance(layer, SEANetResnetBlock):
                # n_blocks is only incremented at the conv that ends the block,
                # so a residual block belongs to block n_blocks + 1.
                expected = expected_norm(n_blocks + 1)
                for resnet_layer in layer.block:
                    if isinstance(resnet_layer, StreamableConv1d):
                        assert resnet_layer.conv.norm_type == expected, \
                            (n_blocks + 1, resnet_layer.conv.norm_type, expected)
        # Guard against a traversal that silently skipped blocks.
        assert n_blocks == encoder.n_blocks, (n_blocks, encoder.n_blocks)

    def test_encoder_disable_norm(self):
        n_residuals = [0, 1, 3]
        disable_blocks = [0, 1, 2, 3, 4, 5, 6]
        norms = ['weight_norm', 'none']
        for n_res, n_disable, norm in product(n_residuals, disable_blocks, norms):
            encoder = SEANetEncoder(n_residual_layers=n_res, norm=norm,
                                    disable_norm_outer_blocks=n_disable)
            self._check_encoder_blocks_norm(encoder, n_disable, norm)

    def _check_decoder_blocks_norm(self, decoder: SEANetDecoder, n_disable_blocks: int, norm: str) -> None:
        """Check that the last ``n_disable_blocks`` decoder blocks have norm disabled.

        A block is delimited by a top-level ``StreamableConv1d`` or
        ``StreamableConvTranspose1d`` in ``decoder.model``. Residual blocks follow
        the layer that opened their block, so they share its index. All other
        blocks must use ``norm``.
        """
        def expected_norm(block_idx: int) -> str:
            # Norm is disabled on the blocks closest to the output side.
            return 'none' if (decoder.n_blocks - block_idx) < n_disable_blocks else norm

        n_blocks = 0
        for layer in decoder.model:
            if isinstance(layer, StreamableConv1d):
                n_blocks += 1
                expected = expected_norm(n_blocks)
                assert layer.conv.norm_type == expected, (n_blocks, layer.conv.norm_type, expected)
            elif isinstance(layer, StreamableConvTranspose1d):
                n_blocks += 1
                expected = expected_norm(n_blocks)
                assert layer.convtr.norm_type == expected, (n_blocks, layer.convtr.norm_type, expected)
            elif isinstance(layer, SEANetResnetBlock):
                expected = expected_norm(n_blocks)
                for resnet_layer in layer.block:
                    if isinstance(resnet_layer, StreamableConv1d):
                        assert resnet_layer.conv.norm_type == expected, \
                            (n_blocks, resnet_layer.conv.norm_type, expected)
        # Guard against a traversal that silently skipped blocks.
        assert n_blocks == decoder.n_blocks, (n_blocks, decoder.n_blocks)

    def test_decoder_disable_norm(self):
        n_residuals = [0, 1, 3]
        disable_blocks = [0, 1, 2, 3, 4, 5, 6]
        norms = ['weight_norm', 'none']
        for n_res, n_disable, norm in product(n_residuals, disable_blocks, norms):
            decoder = SEANetDecoder(n_residual_layers=n_res, norm=norm,
                                    disable_norm_outer_blocks=n_disable)
            self._check_decoder_blocks_norm(decoder, n_disable, norm)

    def test_disable_norm_raises_exception(self):
        # Out-of-range disable_norm_outer_blocks values are rejected at construction:
        # negative counts, and counts larger than the number of blocks
        # (7 with a 4-element ratio list).
        with pytest.raises(AssertionError):
            SEANetEncoder(disable_norm_outer_blocks=-1)

        with pytest.raises(AssertionError):
            SEANetEncoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7)

        with pytest.raises(AssertionError):
            SEANetDecoder(disable_norm_outer_blocks=-1)

        with pytest.raises(AssertionError):
            SEANetDecoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7)