|
|
|
|
|
|
|
|
|
|
|
|
|
import random |
|
|
|
import torch |
|
|
|
from audiocraft.losses import ( |
|
MelSpectrogramL1Loss, |
|
MultiScaleMelSpectrogramLoss, |
|
MRSTFTLoss, |
|
SISNR, |
|
STFTLoss, |
|
) |
|
|
|
|
|
def test_mel_l1_loss():
    """MelSpectrogramL1Loss returns tensors, and exactly zero for identical inputs."""
    batch, channels = 2, 2
    length = random.randrange(1000, 100_000)
    shape = (batch, channels, length)
    pred = torch.randn(*shape)
    target = torch.randn(*shape)

    criterion = MelSpectrogramL1Loss(sample_rate=22_050)
    diff_loss = criterion(pred, target)
    self_loss = criterion(pred, pred)

    # Both calls must produce tensors; checked in call order.
    for value in (diff_loss, self_loss):
        assert isinstance(value, torch.Tensor)
    # Comparing a signal with itself must give no penalty at all.
    assert self_loss.item() == 0.0
|
|
|
|
|
def test_msspec_loss():
    """MultiScaleMelSpectrogramLoss returns tensors and a zero loss for identical inputs."""
    length = random.randrange(1000, 100_000)
    # Two independent (2, 2, length) batches, drawn in order: estimate first.
    estimate = torch.randn(2, 2, length)
    reference = torch.randn(2, 2, length)

    msspec_loss = MultiScaleMelSpectrogramLoss(sample_rate=22_050)
    loss_diff = msspec_loss(estimate, reference)
    loss_self = msspec_loss(estimate, estimate)

    assert torch.is_tensor(loss_diff)
    assert torch.is_tensor(loss_self)
    assert loss_self.item() == 0.0
|
|
|
|
|
def test_mrstft_loss():
    """Smoke-test MRSTFTLoss: two random (2, 2, T) batches map to a tensor."""
    num_samples = random.randrange(1000, 100_000)
    inputs = torch.randn(2, 2, num_samples)
    targets = torch.randn(2, 2, num_samples)

    result = MRSTFTLoss()(inputs, targets)

    assert torch.is_tensor(result)
|
|
|
|
|
def test_sisnr_loss():
    """Smoke-test SISNR on two independent random (2, 2, T) batches."""
    shape = (2, 2, random.randrange(1000, 100_000))
    first = torch.randn(shape)
    second = torch.randn(shape)

    criterion = SISNR()
    output = criterion(first, second)

    assert torch.is_tensor(output)
|
|
|
|
|
def test_stft_loss():
    """Smoke-test STFTLoss: two random (N, C, T) batches map to a tensor."""
    N, C, T = 2, 2, random.randrange(1000, 100_000)
    t1 = torch.randn(N, C, T)
    t2 = torch.randn(N, C, T)

    # Named after the class under test (single-resolution STFT), not MRSTFT.
    stft = STFTLoss()
    loss = stft(t1, t2)

    assert isinstance(loss, torch.Tensor)
|
|