File size: 2,325 Bytes
00f2826
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import random
import typing as tp

import numpy as np
import torch
from audiocraft.models.multibanddiffusion import MultiBandDiffusion, DiffusionProcess
from audiocraft.models import EncodecModel, DiffusionUnet
from audiocraft.modules import SEANetEncoder, SEANetDecoder
from audiocraft.modules.diffusion_schedule import NoiseSchedule
from audiocraft.quantization import DummyQuantizer


class TestMBD:
    """Smoke tests for `MultiBandDiffusion` built from small, untrained components."""

    def _create_mbd(self,
                    sample_rate: int,
                    channels: int,
                    n_filters: int = 3,
                    n_residual_layers: int = 1,
                    ratios: tp.Optional[tp.List[int]] = None,
                    num_steps: int = 1000,
                    codec_dim: int = 128,
                    **kwargs):
        """Assemble a `MultiBandDiffusion` with a single diffusion process.

        Args:
            sample_rate (int): Sample rate forwarded to the `EncodecModel`.
            channels (int): Number of audio channels, used by the encoder, the decoder,
                the `EncodecModel` and as `chin` of the `DiffusionUnet`.
            n_filters (int): Forwarded to `SEANetEncoder` and `SEANetDecoder`.
            n_residual_layers (int): Forwarded to `SEANetEncoder` and `SEANetDecoder`.
            ratios (list of int, optional): Forwarded to `SEANetEncoder` and `SEANetDecoder`.
                Defaults to `[5, 4, 3, 2]` when left as None.
            num_steps (int): Number of steps shared by the `DiffusionUnet` and the `NoiseSchedule`.
            codec_dim (int): Used as `dimension` of the encoder/decoder and `codec_dim` of the
                `DiffusionUnet`.
            **kwargs: Extra keyword arguments forwarded to `EncodecModel`.
        Returns:
            MultiBandDiffusion: Model wrapping one `DiffusionProcess` and the codec model.
        """
        # Build a fresh list on every call: a list literal used as a default argument
        # would be one shared object, visible to every call and every module built from it.
        ratios = [5, 4, 3, 2] if ratios is None else list(ratios)
        # NOTE(review): the value passed as `frame_rate` is the product of the ratios;
        # confirm whether `sample_rate / prod(ratios)` was intended. Left unchanged.
        frame_rate = np.prod(ratios)
        encoder = SEANetEncoder(channels=channels, dimension=codec_dim, n_filters=n_filters,
                                n_residual_layers=n_residual_layers, ratios=ratios)
        decoder = SEANetDecoder(channels=channels, dimension=codec_dim, n_filters=n_filters,
                                n_residual_layers=n_residual_layers, ratios=ratios)
        quantizer = DummyQuantizer()
        compression_model = EncodecModel(encoder, decoder, quantizer, frame_rate=frame_rate,
                                         sample_rate=sample_rate, channels=channels, **kwargs)
        diffusion_model = DiffusionUnet(chin=channels, num_steps=num_steps, codec_dim=codec_dim)
        schedule = NoiseSchedule(device='cpu', num_steps=num_steps)
        DP = DiffusionProcess(model=diffusion_model, noise_schedule=schedule)
        mbd = MultiBandDiffusion(DPs=[DP], codec_model=compression_model)
        return mbd

    def test_model(self):
        """Check that `regenerate` returns a tensor with the same shape as its input."""
        # Seed both generators: `random` drives the sampled lengths, while torch's
        # global generator drives `torch.randn` below.
        random.seed(1234)
        torch.manual_seed(1234)
        sample_rate = 24_000
        channels = 1
        codec_dim = 128
        mbd = self._create_mbd(sample_rate=sample_rate, channels=channels, codec_dim=codec_dim)
        for _ in range(10):
            length = random.randrange(1, 10_000)
            x = torch.randn(2, channels, length)
            res = mbd.regenerate(x, sample_rate)
            assert res.shape == x.shape