akhaliq3
spaces demo
5019931
raw history blame
No virus
2.34 kB
import math
from typing import Callable
import torch
import torch.nn as nn
from torchlibrosa.stft import STFT
from bytesep.models.pytorch_modules import Base
def l1(output: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
    r"""Mean absolute error between ``output`` and ``target``.

    Args:
        output: torch.Tensor, predicted values
        target: torch.Tensor, reference values; must be broadcast-compatible
            with ``output`` for the subtraction below
        **kwargs: accepted and ignored

    Returns:
        loss: torch.Tensor, 0-dim tensor holding the mean of |output - target|
            taken over every element
    """
    difference = output - target
    return difference.abs().mean()
def l1_wav(output: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
    r"""L1 (mean absolute error) loss computed directly in the time-domain.

    Args:
        output: torch.Tensor, predicted time-domain signal
        target: torch.Tensor, reference time-domain signal
        **kwargs: accepted and ignored (they are not forwarded to ``l1``)

    Returns:
        loss: torch.Tensor, scalar mean of |output - target|
    """
    # In the time-domain the loss is just element-wise L1 on the raw samples.
    return l1(output=output, target=target)
class L1_Wav_L1_Sp(nn.Module, Base):
    def __init__(
        self,
        window_size: int = 2048,
        hop_size: int = 441,
        center: bool = True,
        pad_mode: str = "reflect",
        window: str = "hann",
    ):
        r"""L1 loss in the time-domain plus L1 loss on the spectrogram.

        All arguments configure the torchlibrosa ``STFT`` used to build the
        spectrograms; the defaults reproduce the original fixed settings.

        Args:
            window_size: int, passed as both ``n_fft`` and ``win_length``
            hop_size: int, passed as ``hop_length``
            center: bool, passed as ``center``
            pad_mode: str, passed as ``pad_mode``
            window: str, passed as ``window``
        """
        super(L1_Wav_L1_Sp, self).__init__()

        self.window_size = window_size

        # freeze_parameters=True asks torchlibrosa not to train the STFT's own
        # parameters, so the loss transform stays fixed during optimization.
        self.stft = STFT(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

    def forward(
        self, output: torch.Tensor, target: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        r"""L1 loss in the time-domain plus L1 loss on the spectrogram.

        Implemented as ``forward`` (not ``__call__``) so that calling the
        module instance goes through ``nn.Module.__call__`` and its hooks.

        Args:
            output: torch.Tensor, predicted time-domain signal
            target: torch.Tensor, reference time-domain signal
            **kwargs: accepted and ignored

        Returns:
            loss: torch.Tensor, sum of the time-domain and spectrogram L1 losses
        """
        # L1 loss in the time-domain.
        wav_loss = l1_wav(output, target)

        # L1 loss on spectrograms built by the inherited ``wav_to_spectrogram``
        # (from ``Base``, which reads ``self.stft``).
        # NOTE(review): eps=1e-8 is presumably a numerical floor inside
        # wav_to_spectrogram; confirm in Base.
        sp_loss = l1(
            self.wav_to_spectrogram(output, eps=1e-8),
            self.wav_to_spectrogram(target, eps=1e-8),
        )

        # Total loss: the two terms are weighted equally.
        return wav_loss + sp_loss
def get_loss_function(loss_type: str) -> Callable:
    r"""Get loss function.

    Args:
        loss_type: str, one of "l1_wav" or "l1_wav_l1_sp"

    Returns:
        loss function: Callable. For "l1_wav" this is the ``l1_wav`` function
            itself; for "l1_wav_l1_sp" it is a new ``L1_Wav_L1_Sp`` module
            instance, constructed on each call.

    Raises:
        NotImplementedError: if ``loss_type`` is not one of the names above.
    """
    if loss_type == "l1_wav":
        return l1_wav

    elif loss_type == "l1_wav_l1_sp":
        return L1_Wav_L1_Sp()

    else:
        # Name the bad value so a config typo is obvious from the traceback.
        raise NotImplementedError(
            f"Unknown loss_type {loss_type!r}; "
            f"expected one of 'l1_wav', 'l1_wav_l1_sp'."
        )