File size: 2,295 Bytes
5325fcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import random

import torch

from audiocraft.adversarial.discriminators import (
    MultiPeriodDiscriminator,
    MultiScaleDiscriminator,
    MultiScaleSTFTDiscriminator
)


class TestMultiPeriodDiscriminator:

    def test_mpd_discriminator(self):
        N, C, T = 2, 2, random.randrange(1, 100_000)
        t0 = torch.randn(N, C, T)
        periods = [1, 2, 3]
        mpd = MultiPeriodDiscriminator(periods=periods, in_channels=C)
        logits, fmaps = mpd(t0)

        assert len(logits) == len(periods)
        assert len(fmaps) == len(periods)
        assert all([logit.shape[0] == N and len(logit.shape) == 4 for logit in logits])
        assert all([feature.shape[0] == N for fmap in fmaps for feature in fmap])


class TestMultiScaleDiscriminator:

    def test_msd_discriminator(self):
        N, C, T = 2, 2, random.randrange(1, 100_000)
        t0 = torch.randn(N, C, T)

        scale_norms = ['weight_norm', 'weight_norm']
        msd = MultiScaleDiscriminator(scale_norms=scale_norms, in_channels=C)
        logits, fmaps = msd(t0)

        assert len(logits) == len(scale_norms)
        assert len(fmaps) == len(scale_norms)
        assert all([logit.shape[0] == N and len(logit.shape) == 3 for logit in logits])
        assert all([feature.shape[0] == N for fmap in fmaps for feature in fmap])


class TestMultiScaleStftDiscriminator:

    def test_msstftd_discriminator(self):
        N, C, T = 2, 2, random.randrange(1, 100_000)
        t0 = torch.randn(N, C, T)

        n_filters = 4
        n_ffts = [128, 256, 64]
        hop_lengths = [32, 64, 16]
        win_lengths = [128, 256, 64]

        msstftd = MultiScaleSTFTDiscriminator(filters=n_filters, n_ffts=n_ffts, hop_lengths=hop_lengths,
                                              win_lengths=win_lengths, in_channels=C)
        logits, fmaps = msstftd(t0)

        assert len(logits) == len(n_ffts)
        assert len(fmaps) == len(n_ffts)
        assert all([logit.shape[0] == N and len(logit.shape) == 4 for logit in logits])
        assert all([feature.shape[0] == N for fmap in fmaps for feature in fmap])