# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn

from audiocraft.modules.activations import CustomGLU


class TestActivations:
    """Tests for the activation modules in ``audiocraft.modules.activations``."""

    def test_custom_glu_calculation(self):
        """Check the output shape and gated value of ``CustomGLU``.

        Two constant tensors (all 2s and all -1s) are concatenated along the
        last dimension and fed to a ``CustomGLU`` wrapping ``nn.Identity``.
        The output must have the shape of a single half and, since the gate
        is ``a * f(b)`` with ``f`` the identity, every element must be -2.
        """
        half_shape = (4, 8, 8)
        glu = CustomGLU(nn.Identity())

        # Build the two halves as constant tensors so the expected product
        # is known exactly: 2 * identity(-1) == -2.
        value_half = torch.full(half_shape, 2.0)
        gate_half = torch.full(half_shape, -1.0)
        stacked = torch.cat([value_half, gate_half], dim=-1)

        gated = glu(stacked)

        # ensure all dimensions match the shape of one half
        assert gated.shape == half_shape
        # ensure the gating was calculated correctly: a * f(b)
        assert (gated == -2).all().item()