akhaliq3
spaces demo
5019931
raw
history blame contribute delete
No virus
2.34 kB
import math
from typing import Callable
import torch
import torch.nn as nn
from torchlibrosa.stft import STFT
from bytesep.models.pytorch_modules import Base
def l1(output: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
    r"""Mean absolute error between two tensors.

    Args:
        output: torch.Tensor, the prediction
        target: torch.Tensor, the reference; must be broadcast-compatible
            with ``output`` because the two are subtracted element-wise

    Returns:
        loss: 0-dim torch.Tensor holding mean(|output - target|)

    Any extra keyword arguments are accepted and ignored, so this function
    can be used interchangeably with other loss callables in this module.
    """
    residual = output - target
    return residual.abs().mean()
def l1_wav(output: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
    r"""Time-domain L1 loss between two waveforms.

    Args:
        output: torch.Tensor, predicted waveform
        target: torch.Tensor, reference waveform

    Returns:
        loss: 0-dim torch.Tensor, the mean absolute sample-wise difference

    Extra keyword arguments are accepted and ignored.
    """
    # Waveforms are compared sample by sample, so this is exactly the
    # generic L1 loss; extra kwargs are intentionally not forwarded.
    loss = l1(output=output, target=target)
    return loss
class L1_Wav_L1_Sp(nn.Module, Base):
    r"""Sum of an L1 loss on waveforms and an L1 loss on their spectrograms."""

    def __init__(self, window_size: int = 2048, hop_size: int = 441):
        r"""L1 loss in the time-domain and L1 loss on the spectrogram.

        Args:
            window_size: int, STFT FFT size and window length (default 2048)
            hop_size: int, STFT hop length in samples (default 441)
        """
        super().__init__()

        self.window_size = window_size

        # freeze_parameters=True: the STFT acts as a fixed transform here,
        # its kernels are not meant to be trained.
        self.stft = STFT(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window="hann",
            center=True,
            pad_mode="reflect",
            freeze_parameters=True,
        )

    def forward(
        self, output: torch.Tensor, target: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        r"""L1 loss in the time-domain plus L1 loss on the spectrogram.

        Defined as ``forward`` (not ``__call__``) so that ``nn.Module.__call__``
        still runs and registered hooks are honoured; calling the instance as
        ``loss_fn(output, target)`` works exactly as before.

        Args:
            output: torch.Tensor, predicted waveform
            target: torch.Tensor, reference waveform

        Returns:
            loss: 0-dim torch.Tensor, wav L1 loss + spectrogram L1 loss
        """
        wav_loss = l1_wav(output, target)

        # wav_to_spectrogram is inherited from Base (not visible here);
        # presumably it uses self.stft — confirm. eps=1e-8 is passed through
        # to it, likely to keep magnitudes away from zero.
        sp_loss = l1(
            self.wav_to_spectrogram(output, eps=1e-8),
            self.wav_to_spectrogram(target, eps=1e-8),
        )

        return wav_loss + sp_loss
def get_loss_function(loss_type: str) -> Callable:
    r"""Get loss function.

    Args:
        loss_type: str, one of "l1_wav" or "l1_wav_l1_sp"

    Returns:
        loss function: Callable taking ``(output, target, **kwargs)`` and
            returning a loss tensor

    Raises:
        NotImplementedError: if ``loss_type`` is not a supported name
    """
    if loss_type == "l1_wav":
        return l1_wav

    if loss_type == "l1_wav_l1_sp":
        # A new module instance (with its own STFT) is built on each call.
        return L1_Wav_L1_Sp()

    raise NotImplementedError(
        f"Unsupported loss_type: {loss_type!r}; "
        f"expected 'l1_wav' or 'l1_wav_l1_sp'"
    )