Beatpoettes1 / tests /modules /test_seanet.py
MantraDas's picture
Duplicate from facebook/MusicGen
f54eb92
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from itertools import product
import pytest
import torch
from audiocraft.modules.seanet import SEANetEncoder, SEANetDecoder, SEANetResnetBlock
from audiocraft.modules import StreamableConv1d, StreamableConvTranspose1d
class TestSEANetModel:
def test_base(self):
encoder = SEANetEncoder()
decoder = SEANetDecoder()
x = torch.randn(1, 1, 24000)
z = encoder(x)
assert list(z.shape) == [1, 128, 75], z.shape
y = decoder(z)
assert y.shape == x.shape, (x.shape, y.shape)
def test_causal(self):
encoder = SEANetEncoder(causal=True)
decoder = SEANetDecoder(causal=True)
x = torch.randn(1, 1, 24000)
z = encoder(x)
assert list(z.shape) == [1, 128, 75], z.shape
y = decoder(z)
assert y.shape == x.shape, (x.shape, y.shape)
def test_conv_skip_connection(self):
encoder = SEANetEncoder(true_skip=False)
decoder = SEANetDecoder(true_skip=False)
x = torch.randn(1, 1, 24000)
z = encoder(x)
assert list(z.shape) == [1, 128, 75], z.shape
y = decoder(z)
assert y.shape == x.shape, (x.shape, y.shape)
def test_seanet_encoder_decoder_final_act(self):
encoder = SEANetEncoder(true_skip=False)
decoder = SEANetDecoder(true_skip=False, final_activation='Tanh')
x = torch.randn(1, 1, 24000)
z = encoder(x)
assert list(z.shape) == [1, 128, 75], z.shape
y = decoder(z)
assert y.shape == x.shape, (x.shape, y.shape)
def _check_encoder_blocks_norm(self, encoder: SEANetEncoder, n_disable_blocks: int, norm: str):
n_blocks = 0
for layer in encoder.model:
if isinstance(layer, StreamableConv1d):
n_blocks += 1
assert layer.conv.norm_type == 'none' if n_blocks <= n_disable_blocks else norm
elif isinstance(layer, SEANetResnetBlock):
for resnet_layer in layer.block:
if isinstance(resnet_layer, StreamableConv1d):
# here we add + 1 to n_blocks as we increment n_blocks just after the block
assert resnet_layer.conv.norm_type == 'none' if (n_blocks + 1) <= n_disable_blocks else norm
def test_encoder_disable_norm(self):
n_residuals = [0, 1, 3]
disable_blocks = [0, 1, 2, 3, 4, 5, 6]
norms = ['weight_norm', 'none']
for n_res, disable_blocks, norm in product(n_residuals, disable_blocks, norms):
encoder = SEANetEncoder(n_residual_layers=n_res, norm=norm,
disable_norm_outer_blocks=disable_blocks)
self._check_encoder_blocks_norm(encoder, disable_blocks, norm)
def _check_decoder_blocks_norm(self, decoder: SEANetDecoder, n_disable_blocks: int, norm: str):
n_blocks = 0
for layer in decoder.model:
if isinstance(layer, StreamableConv1d):
n_blocks += 1
assert layer.conv.norm_type == 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm
elif isinstance(layer, StreamableConvTranspose1d):
n_blocks += 1
assert layer.convtr.norm_type == 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm
elif isinstance(layer, SEANetResnetBlock):
for resnet_layer in layer.block:
if isinstance(resnet_layer, StreamableConv1d):
assert resnet_layer.conv.norm_type == 'none' \
if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm
def test_decoder_disable_norm(self):
n_residuals = [0, 1, 3]
disable_blocks = [0, 1, 2, 3, 4, 5, 6]
norms = ['weight_norm', 'none']
for n_res, disable_blocks, norm in product(n_residuals, disable_blocks, norms):
decoder = SEANetDecoder(n_residual_layers=n_res, norm=norm,
disable_norm_outer_blocks=disable_blocks)
self._check_decoder_blocks_norm(decoder, disable_blocks, norm)
def test_disable_norm_raises_exception(self):
# Invalid disable_norm_outer_blocks values raise exceptions
with pytest.raises(AssertionError):
SEANetEncoder(disable_norm_outer_blocks=-1)
with pytest.raises(AssertionError):
SEANetEncoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7)
with pytest.raises(AssertionError):
SEANetDecoder(disable_norm_outer_blocks=-1)
with pytest.raises(AssertionError):
SEANetDecoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7)