Spaces:
Running
Running
File size: 2,529 Bytes
3dd84f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
from dataclasses import asdict
from typing import List, Sequence

import torch
import torch.nn as nn

from utils.audio import LogMelSpectrogram
from config import MelConfig
# Adapted from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py under the MIT license.
class MultiScaleMelSpectrogramLoss(nn.Module):
    """Summed L1 loss between log-mel spectrograms at several resolutions.

    One ``LogMelSpectrogram`` is built per ``(n_mels, window_length)`` pair.
    ``n_fft`` and ``win_length`` are set to the window length and
    ``hop_length`` to ``window_length // 4``. All other ``MelConfig`` fields
    keep their defaults. ``forward`` sums the per-scale L1 losses.

    Args:
        n_mels: Number of mel bins for each scale.
        window_lengths: FFT / window size for each scale. Paired one-to-one
            with ``n_mels``.

    Raises:
        ValueError: If ``n_mels`` and ``window_lengths`` differ in length,
            or if no scales are given.
    """

    def __init__(
        self,
        n_mels: Sequence[int] = (5, 10, 20, 40, 80, 160, 320),
        window_lengths: Sequence[int] = (32, 64, 128, 256, 512, 1024, 2048),
    ):
        super().__init__()
        # Explicit checks rather than `assert`, which is stripped under `python -O`.
        if len(n_mels) != len(window_lengths):
            raise ValueError("n_mels and window_lengths must have the same length")
        # With no scales, forward() would return the int 0 instead of a tensor loss.
        if len(n_mels) == 0:
            raise ValueError("at least one (n_mels, window_length) scale is required")
        self.mel_transforms = nn.ModuleList(self._get_transforms(n_mels, window_lengths))
        self.loss_fn = nn.L1Loss()

    def _get_transforms(self, n_mels, window_lengths):
        """Build one LogMelSpectrogram per (n_mels, window_length) pair."""
        return [
            LogMelSpectrogram(
                **asdict(
                    MelConfig(
                        n_mels=n_mel,
                        n_fft=win_length,
                        win_length=win_length,
                        hop_length=win_length // 4,
                    )
                )
            )
            for n_mel, win_length in zip(n_mels, window_lengths)
        ]

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the sum over scales of the L1 loss between the log-mels of x and y."""
        return sum(
            self.loss_fn(mel_transform(x), mel_transform(y))
            for mel_transform in self.mel_transforms
        )
class SingleScaleMelSpectrogramLoss(nn.Module):
    """L1 loss between log-mel spectrograms built from the default MelConfig."""

    def __init__(self):
        super().__init__()
        mel_kwargs = asdict(MelConfig())
        self.mel_transform = LogMelSpectrogram(**mel_kwargs)
        self.loss_fn = nn.L1Loss()
        # Announce on stdout which mel loss variant was constructed.
        print('using single mel loss')

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the L1 loss between the log-mels of ``x`` and ``y``."""
        mel_x = self.mel_transform(x)
        mel_y = self.mel_transform(y)
        return self.loss_fn(mel_x, mel_y)
def feature_loss(fmap_r, fmap_g):
    """Feature-matching loss between real and generated feature maps.

    Both arguments are walked two levels deep and paired with ``zip``, so
    surplus entries on either side are ignored. For each paired feature the
    mean absolute difference is taken. The terms are summed, and the total
    is scaled by 2.

    Returns:
        The scaled sum, or ``0`` when there is nothing to pair.
    """
    pair_terms = (
        torch.mean(torch.abs(real_feat - fake_feat))
        for real_layers, fake_layers in zip(fmap_r, fmap_g)
        for real_feat, fake_feat in zip(real_layers, fake_layers)
    )
    return sum(pair_terms) * 2
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Least-squares discriminator loss over paired real/generated outputs.

    For each zipped pair, the real term is ``mean((1 - real) ** 2)``, which
    pushes real outputs toward 1. The generated term is ``mean(fake ** 2)``,
    which pushes generated outputs toward 0.

    Returns:
        A tuple ``(total, real_losses, generated_losses)``:
        - ``total``: sum of all real + generated terms (``0`` if there are
          no pairs).
        - ``real_losses``: per-pair real terms as Python floats (via
          ``.item()``).
        - ``generated_losses``: per-pair generated terms as Python floats.
    """
    real_terms = []
    fake_terms = []
    for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs):
        real_terms.append(torch.mean((1 - real_out) ** 2))
        fake_terms.append(torch.mean(fake_out ** 2))
    total = sum(r_term + g_term for r_term, g_term in zip(real_terms, fake_terms))
    return total, [t.item() for t in real_terms], [t.item() for t in fake_terms]
def generator_loss(disc_outputs):
    """Least-squares generator loss over discriminator outputs on generated data.

    Each output contributes ``mean((1 - output) ** 2)``, which pushes the
    discriminator's score on generated samples toward 1.

    Returns:
        A tuple ``(total, per_output_losses)``. ``total`` is the sum of the
        terms (``0`` when ``disc_outputs`` is empty). ``per_output_losses``
        holds the individual terms as tensors, not floats.
    """
    gen_losses = [torch.mean((1 - dg) ** 2) for dg in disc_outputs]
    return sum(gen_losses), gen_losses