File size: 521 Bytes
5238467
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch

from audiocraft.quantization.vq import ResidualVectorQuantizer


class TestResidualVectorQuantizer:
    """Smoke tests for `ResidualVectorQuantizer`."""

    def test_rvq(self):
        """Quantizing a random (1, 16, 2048) tensor yields `.x` of that same shape."""
        batch, dim, steps = 1, 16, 2048
        # The input is drawn before the quantizer is built, matching the
        # original order of RNG consumption.
        signal = torch.randn(batch, dim, steps)
        quantizer = ResidualVectorQuantizer(n_q=8, dimension=dim, bins=8)
        # NOTE(review): the second positional argument (1.) is presumably the
        # frame rate -- confirm against ResidualVectorQuantizer.forward.
        result = quantizer(signal, 1.)
        expected = torch.Size([batch, dim, steps])
        assert result.x.shape == expected