unpairedelectron07
commited on
Upload 4 files
Browse files- audiocraft/losses/balancer.py +136 -0
- audiocraft/losses/sisnr.py +97 -0
- audiocraft/losses/specloss.py +149 -0
- audiocraft/losses/stftloss.py +207 -0
audiocraft/losses/balancer.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import typing as tp
|
8 |
+
|
9 |
+
import flashy
|
10 |
+
import torch
|
11 |
+
from torch import autograd
|
12 |
+
|
13 |
+
|
14 |
+
class Balancer:
|
15 |
+
"""Loss balancer.
|
16 |
+
|
17 |
+
The loss balancer combines losses together to compute gradients for the backward.
|
18 |
+
Given `y = f(...)`, and a number of losses `l1(y, ...)`, `l2(y, ...)`, with `...`
|
19 |
+
not having any dependence on `f`, the balancer can efficiently normalize the partial gradients
|
20 |
+
`d l1 / d y`, `d l2 / dy` before summing them in order to achieve a desired ratio between
|
21 |
+
the losses. For instance if `weights = {'l1': 2, 'l2': 1}`, 66% of the gradient
|
22 |
+
going into `f(...)` will come from `l1` on average, and 33% from `l2`. This allows for an easy
|
23 |
+
interpration of the weights even if the intrisic scale of `l1`, `l2` ... is unknown.
|
24 |
+
|
25 |
+
Noting `g1 = d l1 / dy`, etc., the balanced gradient `G` will be
|
26 |
+
(with `avg` an exponential moving average over the updates),
|
27 |
+
|
28 |
+
G = sum_i total_norm * g_i / avg(||g_i||) * w_i / sum(w_i)
|
29 |
+
|
30 |
+
If `balance_grads` is False, this is deactivated, and instead the gradient will just be the
|
31 |
+
standard sum of the partial gradients with the given weights.
|
32 |
+
|
33 |
+
A call to the backward method of the balancer will compute the the partial gradients,
|
34 |
+
combining all the losses and potentially rescaling the gradients,
|
35 |
+
which can help stabilize the training and reason about multiple losses with varying scales.
|
36 |
+
The obtained gradient with respect to `y` is then back-propagated to `f(...)`.
|
37 |
+
|
38 |
+
Expected usage:
|
39 |
+
|
40 |
+
weights = {'loss_a': 1, 'loss_b': 4}
|
41 |
+
balancer = Balancer(weights, ...)
|
42 |
+
losses: dict = {}
|
43 |
+
losses['loss_a'] = compute_loss_a(x, y)
|
44 |
+
losses['loss_b'] = compute_loss_b(x, y)
|
45 |
+
if model.training():
|
46 |
+
effective_loss = balancer.backward(losses, x)
|
47 |
+
|
48 |
+
Args:
|
49 |
+
weights (dict[str, float]): Weight coefficient for each loss. The balancer expect the losses keys
|
50 |
+
from the backward method to match the weights keys to assign weight to each of the provided loss.
|
51 |
+
balance_grads (bool): Whether to rescale gradients so that weights reflect the fraction of the
|
52 |
+
overall gradient, rather than a constant multiplier.
|
53 |
+
total_norm (float): Reference norm when rescaling gradients, ignored otherwise.
|
54 |
+
emay_decay (float): EMA decay for averaging the norms.
|
55 |
+
per_batch_item (bool): Whether to compute the averaged norm per batch item or not. This only holds
|
56 |
+
when rescaling the gradients.
|
57 |
+
epsilon (float): Epsilon value for numerical stability.
|
58 |
+
monitor (bool): If True, stores in `self.metrics` the relative ratio between the norm of the gradients
|
59 |
+
coming from each loss, when calling `backward()`.
|
60 |
+
"""
|
61 |
+
def __init__(self, weights: tp.Dict[str, float], balance_grads: bool = True, total_norm: float = 1.,
|
62 |
+
ema_decay: float = 0.999, per_batch_item: bool = True, epsilon: float = 1e-12,
|
63 |
+
monitor: bool = False):
|
64 |
+
self.weights = weights
|
65 |
+
self.per_batch_item = per_batch_item
|
66 |
+
self.total_norm = total_norm or 1.
|
67 |
+
self.averager = flashy.averager(ema_decay or 1.)
|
68 |
+
self.epsilon = epsilon
|
69 |
+
self.monitor = monitor
|
70 |
+
self.balance_grads = balance_grads
|
71 |
+
self._metrics: tp.Dict[str, tp.Any] = {}
|
72 |
+
|
73 |
+
@property
|
74 |
+
def metrics(self):
|
75 |
+
return self._metrics
|
76 |
+
|
77 |
+
def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor) -> torch.Tensor:
|
78 |
+
"""Compute the backward and return the effective train loss, e.g. the loss obtained from
|
79 |
+
computing the effective weights. If `balance_grads` is True, the effective weights
|
80 |
+
are the one that needs to be applied to each gradient to respect the desired relative
|
81 |
+
scale of gradients coming from each loss.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
losses (Dict[str, torch.Tensor]): dictionary with the same keys as `self.weights`.
|
85 |
+
input (torch.Tensor): the input of the losses, typically the output of the model.
|
86 |
+
This should be the single point of dependence between the losses
|
87 |
+
and the model being trained.
|
88 |
+
"""
|
89 |
+
norms = {}
|
90 |
+
grads = {}
|
91 |
+
for name, loss in losses.items():
|
92 |
+
# Compute partial derivative of the less with respect to the input.
|
93 |
+
grad, = autograd.grad(loss, [input], retain_graph=True)
|
94 |
+
if self.per_batch_item:
|
95 |
+
# We do not average the gradient over the batch dimension.
|
96 |
+
dims = tuple(range(1, grad.dim()))
|
97 |
+
norm = grad.norm(dim=dims, p=2).mean()
|
98 |
+
else:
|
99 |
+
norm = grad.norm(p=2)
|
100 |
+
norms[name] = norm
|
101 |
+
grads[name] = grad
|
102 |
+
|
103 |
+
count = 1
|
104 |
+
if self.per_batch_item:
|
105 |
+
count = len(grad)
|
106 |
+
# Average norms across workers. Theoretically we should average the
|
107 |
+
# squared norm, then take the sqrt, but it worked fine like that.
|
108 |
+
avg_norms = flashy.distrib.average_metrics(self.averager(norms), count)
|
109 |
+
# We approximate the total norm of the gradient as the sums of the norms.
|
110 |
+
# Obviously this can be very incorrect if all gradients are aligned, but it works fine.
|
111 |
+
total = sum(avg_norms.values())
|
112 |
+
|
113 |
+
self._metrics = {}
|
114 |
+
if self.monitor:
|
115 |
+
# Store the ratio of the total gradient represented by each loss.
|
116 |
+
for k, v in avg_norms.items():
|
117 |
+
self._metrics[f'ratio_{k}'] = v / total
|
118 |
+
|
119 |
+
total_weights = sum([self.weights[k] for k in avg_norms])
|
120 |
+
assert total_weights > 0.
|
121 |
+
desired_ratios = {k: w / total_weights for k, w in self.weights.items()}
|
122 |
+
|
123 |
+
out_grad = torch.zeros_like(input)
|
124 |
+
effective_loss = torch.tensor(0., device=input.device, dtype=input.dtype)
|
125 |
+
for name, avg_norm in avg_norms.items():
|
126 |
+
if self.balance_grads:
|
127 |
+
# g_balanced = g / avg(||g||) * total_norm * desired_ratio
|
128 |
+
scale = desired_ratios[name] * self.total_norm / (self.epsilon + avg_norm)
|
129 |
+
else:
|
130 |
+
# We just do regular weighted sum of the gradients.
|
131 |
+
scale = self.weights[name]
|
132 |
+
out_grad.add_(grads[name], alpha=scale)
|
133 |
+
effective_loss += scale * losses[name].detach()
|
134 |
+
# Send the computed partial derivative with respect to the output of the model to the model.
|
135 |
+
input.backward(out_grad)
|
136 |
+
return effective_loss
|
audiocraft/losses/sisnr.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import math
|
8 |
+
import typing as tp
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch import nn
|
12 |
+
from torch.nn import functional as F
|
13 |
+
|
14 |
+
|
15 |
+
def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
|
16 |
+
"""Given input of size [*OT, T], output Tensor of size [*OT, F, K]
|
17 |
+
with K the kernel size, by extracting frames with the given stride.
|
18 |
+
This will pad the input so that `F = ceil(T / K)`.
|
19 |
+
see https://github.com/pytorch/pytorch/issues/60466
|
20 |
+
"""
|
21 |
+
*shape, length = a.shape
|
22 |
+
n_frames = math.ceil(length / stride)
|
23 |
+
tgt_length = (n_frames - 1) * stride + kernel_size
|
24 |
+
a = F.pad(a, (0, tgt_length - length))
|
25 |
+
strides = list(a.stride())
|
26 |
+
assert strides[-1] == 1, "data should be contiguous"
|
27 |
+
strides = strides[:-1] + [stride, 1]
|
28 |
+
return a.as_strided([*shape, n_frames, kernel_size], strides)
|
29 |
+
|
30 |
+
|
31 |
+
def _center(x: torch.Tensor) -> torch.Tensor:
|
32 |
+
return x - x.mean(-1, True)
|
33 |
+
|
34 |
+
|
35 |
+
def _norm2(x: torch.Tensor) -> torch.Tensor:
|
36 |
+
return x.pow(2).sum(-1, True)
|
37 |
+
|
38 |
+
|
39 |
+
class SISNR(nn.Module):
|
40 |
+
"""SISNR loss.
|
41 |
+
|
42 |
+
Input should be [B, C, T], output is scalar.
|
43 |
+
|
44 |
+
..Warning:: This function returns the opposite of the SI-SNR (e.g. `-1 * regular_SI_SNR`).
|
45 |
+
Consequently, lower scores are better in terms of reconstruction quality,
|
46 |
+
in particular, it should be negative if training goes well. This done this way so
|
47 |
+
that this module can also be used as a loss function for training model.
|
48 |
+
|
49 |
+
Args:
|
50 |
+
sample_rate (int): Sample rate.
|
51 |
+
segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on
|
52 |
+
entire audio only.
|
53 |
+
overlap (float): Overlap between chunks, i.e. 0.5 = 50 % overlap.
|
54 |
+
epsilon (float): Epsilon value for numerical stability.
|
55 |
+
"""
|
56 |
+
def __init__(
|
57 |
+
self,
|
58 |
+
sample_rate: int = 16000,
|
59 |
+
segment: tp.Optional[float] = 20,
|
60 |
+
overlap: float = 0.5,
|
61 |
+
epsilon: float = torch.finfo(torch.float32).eps,
|
62 |
+
):
|
63 |
+
super().__init__()
|
64 |
+
self.sample_rate = sample_rate
|
65 |
+
self.segment = segment
|
66 |
+
self.overlap = overlap
|
67 |
+
self.epsilon = epsilon
|
68 |
+
|
69 |
+
def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
|
70 |
+
B, C, T = ref_sig.shape
|
71 |
+
assert ref_sig.shape == out_sig.shape
|
72 |
+
|
73 |
+
if self.segment is None:
|
74 |
+
frame = T
|
75 |
+
stride = T
|
76 |
+
else:
|
77 |
+
frame = int(self.segment * self.sample_rate)
|
78 |
+
stride = int(frame * (1 - self.overlap))
|
79 |
+
|
80 |
+
epsilon = self.epsilon * frame # make epsilon prop to frame size.
|
81 |
+
|
82 |
+
gt = _unfold(ref_sig, frame, stride)
|
83 |
+
est = _unfold(out_sig, frame, stride)
|
84 |
+
if self.segment is None:
|
85 |
+
assert gt.shape[-1] == 1
|
86 |
+
|
87 |
+
gt = _center(gt)
|
88 |
+
est = _center(est)
|
89 |
+
dot = torch.einsum("bcft,bcft->bcf", gt, est)
|
90 |
+
|
91 |
+
proj = dot[:, :, :, None] * gt / (epsilon + _norm2(gt))
|
92 |
+
noise = est - proj
|
93 |
+
|
94 |
+
sisnr = 10 * (
|
95 |
+
torch.log10(epsilon + _norm2(proj)) - torch.log10(epsilon + _norm2(noise))
|
96 |
+
)
|
97 |
+
return -1 * sisnr[..., 0].mean()
|
audiocraft/losses/specloss.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import typing as tp
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
from torchaudio.transforms import MelSpectrogram
|
11 |
+
import torch
|
12 |
+
from torch import nn
|
13 |
+
from torch.nn import functional as F
|
14 |
+
|
15 |
+
from ..modules import pad_for_conv1d
|
16 |
+
|
17 |
+
|
18 |
+
class MelSpectrogramWrapper(nn.Module):
|
19 |
+
"""Wrapper around MelSpectrogram torchaudio transform providing proper padding
|
20 |
+
and additional post-processing including log scaling.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
n_mels (int): Number of mel bins.
|
24 |
+
n_fft (int): Number of fft.
|
25 |
+
hop_length (int): Hop size.
|
26 |
+
win_length (int): Window length.
|
27 |
+
n_mels (int): Number of mel bins.
|
28 |
+
sample_rate (int): Sample rate.
|
29 |
+
f_min (float or None): Minimum frequency.
|
30 |
+
f_max (float or None): Maximum frequency.
|
31 |
+
log (bool): Whether to scale with log.
|
32 |
+
normalized (bool): Whether to normalize the melspectrogram.
|
33 |
+
floor_level (float): Floor level based on human perception (default=1e-5).
|
34 |
+
"""
|
35 |
+
def __init__(self, n_fft: int = 1024, hop_length: int = 256, win_length: tp.Optional[int] = None,
|
36 |
+
n_mels: int = 80, sample_rate: float = 22050, f_min: float = 0.0, f_max: tp.Optional[float] = None,
|
37 |
+
log: bool = True, normalized: bool = False, floor_level: float = 1e-5):
|
38 |
+
super().__init__()
|
39 |
+
self.n_fft = n_fft
|
40 |
+
hop_length = int(hop_length)
|
41 |
+
self.hop_length = hop_length
|
42 |
+
self.mel_transform = MelSpectrogram(n_mels=n_mels, sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
|
43 |
+
win_length=win_length, f_min=f_min, f_max=f_max, normalized=normalized,
|
44 |
+
window_fn=torch.hann_window, center=False)
|
45 |
+
self.floor_level = floor_level
|
46 |
+
self.log = log
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
p = int((self.n_fft - self.hop_length) // 2)
|
50 |
+
if len(x.shape) == 2:
|
51 |
+
x = x.unsqueeze(1)
|
52 |
+
x = F.pad(x, (p, p), "reflect")
|
53 |
+
# Make sure that all the frames are full.
|
54 |
+
# The combination of `pad_for_conv1d` and the above padding
|
55 |
+
# will make the output of size ceil(T / hop).
|
56 |
+
x = pad_for_conv1d(x, self.n_fft, self.hop_length)
|
57 |
+
self.mel_transform.to(x.device)
|
58 |
+
mel_spec = self.mel_transform(x)
|
59 |
+
B, C, freqs, frame = mel_spec.shape
|
60 |
+
if self.log:
|
61 |
+
mel_spec = torch.log10(self.floor_level + mel_spec)
|
62 |
+
return mel_spec.reshape(B, C * freqs, frame)
|
63 |
+
|
64 |
+
|
65 |
+
class MelSpectrogramL1Loss(torch.nn.Module):
|
66 |
+
"""L1 Loss on MelSpectrogram.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
sample_rate (int): Sample rate.
|
70 |
+
n_fft (int): Number of fft.
|
71 |
+
hop_length (int): Hop size.
|
72 |
+
win_length (int): Window length.
|
73 |
+
n_mels (int): Number of mel bins.
|
74 |
+
f_min (float or None): Minimum frequency.
|
75 |
+
f_max (float or None): Maximum frequency.
|
76 |
+
log (bool): Whether to scale with log.
|
77 |
+
normalized (bool): Whether to normalize the melspectrogram.
|
78 |
+
floor_level (float): Floor level value based on human perception (default=1e-5).
|
79 |
+
"""
|
80 |
+
def __init__(self, sample_rate: int, n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024,
|
81 |
+
n_mels: int = 80, f_min: float = 0.0, f_max: tp.Optional[float] = None,
|
82 |
+
log: bool = True, normalized: bool = False, floor_level: float = 1e-5):
|
83 |
+
super().__init__()
|
84 |
+
self.l1 = torch.nn.L1Loss()
|
85 |
+
self.melspec = MelSpectrogramWrapper(n_fft=n_fft, hop_length=hop_length, win_length=win_length,
|
86 |
+
n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
|
87 |
+
log=log, normalized=normalized, floor_level=floor_level)
|
88 |
+
|
89 |
+
def forward(self, x, y):
|
90 |
+
self.melspec.to(x.device)
|
91 |
+
s_x = self.melspec(x)
|
92 |
+
s_y = self.melspec(y)
|
93 |
+
return self.l1(s_x, s_y)
|
94 |
+
|
95 |
+
|
96 |
+
class MultiScaleMelSpectrogramLoss(nn.Module):
|
97 |
+
"""Multi-Scale spectrogram loss (msspec).
|
98 |
+
|
99 |
+
Args:
|
100 |
+
sample_rate (int): Sample rate.
|
101 |
+
range_start (int): Power of 2 to use for the first scale.
|
102 |
+
range_stop (int): Power of 2 to use for the last scale.
|
103 |
+
n_mels (int): Number of mel bins.
|
104 |
+
f_min (float): Minimum frequency.
|
105 |
+
f_max (float or None): Maximum frequency.
|
106 |
+
normalized (bool): Whether to normalize the melspectrogram.
|
107 |
+
alphas (bool): Whether to use alphas as coefficients or not.
|
108 |
+
floor_level (float): Floor level value based on human perception (default=1e-5).
|
109 |
+
"""
|
110 |
+
def __init__(self, sample_rate: int, range_start: int = 6, range_end: int = 11,
|
111 |
+
n_mels: int = 64, f_min: float = 0.0, f_max: tp.Optional[float] = None,
|
112 |
+
normalized: bool = False, alphas: bool = True, floor_level: float = 1e-5):
|
113 |
+
super().__init__()
|
114 |
+
l1s = list()
|
115 |
+
l2s = list()
|
116 |
+
self.alphas = list()
|
117 |
+
self.total = 0
|
118 |
+
self.normalized = normalized
|
119 |
+
for i in range(range_start, range_end):
|
120 |
+
l1s.append(
|
121 |
+
MelSpectrogramWrapper(n_fft=2 ** i, hop_length=(2 ** i) / 4, win_length=2 ** i,
|
122 |
+
n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
|
123 |
+
log=False, normalized=normalized, floor_level=floor_level))
|
124 |
+
l2s.append(
|
125 |
+
MelSpectrogramWrapper(n_fft=2 ** i, hop_length=(2 ** i) / 4, win_length=2 ** i,
|
126 |
+
n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
|
127 |
+
log=True, normalized=normalized, floor_level=floor_level))
|
128 |
+
if alphas:
|
129 |
+
self.alphas.append(np.sqrt(2 ** i - 1))
|
130 |
+
else:
|
131 |
+
self.alphas.append(1)
|
132 |
+
self.total += self.alphas[-1] + 1
|
133 |
+
|
134 |
+
self.l1s = nn.ModuleList(l1s)
|
135 |
+
self.l2s = nn.ModuleList(l2s)
|
136 |
+
|
137 |
+
def forward(self, x, y):
|
138 |
+
loss = 0.0
|
139 |
+
self.l1s.to(x.device)
|
140 |
+
self.l2s.to(x.device)
|
141 |
+
for i in range(len(self.alphas)):
|
142 |
+
s_x_1 = self.l1s[i](x)
|
143 |
+
s_y_1 = self.l1s[i](y)
|
144 |
+
s_x_2 = self.l2s[i](x)
|
145 |
+
s_y_2 = self.l2s[i](y)
|
146 |
+
loss += F.l1_loss(s_x_1, s_y_1) + self.alphas[i] * F.mse_loss(s_x_2, s_y_2)
|
147 |
+
if self.normalized:
|
148 |
+
loss = loss / self.total
|
149 |
+
return loss
|
audiocraft/losses/stftloss.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# Adapted from MIT code under the original license
|
7 |
+
# Copyright 2019 Tomoki Hayashi
|
8 |
+
# MIT License (https://opensource.org/licenses/MIT)
|
9 |
+
import typing as tp
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from torch import nn
|
13 |
+
from torch.nn import functional as F
|
14 |
+
|
15 |
+
|
16 |
+
# TODO: Replace with torchaudio.STFT?
|
17 |
+
def _stft(x: torch.Tensor, fft_size: int, hop_length: int, win_length: int,
|
18 |
+
window: tp.Optional[torch.Tensor], normalized: bool) -> torch.Tensor:
|
19 |
+
"""Perform STFT and convert to magnitude spectrogram.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
x: Input signal tensor (B, C, T).
|
23 |
+
fft_size (int): FFT size.
|
24 |
+
hop_length (int): Hop size.
|
25 |
+
win_length (int): Window length.
|
26 |
+
window (torch.Tensor or None): Window function type.
|
27 |
+
normalized (bool): Whether to normalize the STFT or not.
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
torch.Tensor: Magnitude spectrogram (B, C, #frames, fft_size // 2 + 1).
|
31 |
+
"""
|
32 |
+
B, C, T = x.shape
|
33 |
+
x_stft = torch.stft(
|
34 |
+
x.view(-1, T), fft_size, hop_length, win_length, window,
|
35 |
+
normalized=normalized, return_complex=True,
|
36 |
+
)
|
37 |
+
x_stft = x_stft.view(B, C, *x_stft.shape[1:])
|
38 |
+
real = x_stft.real
|
39 |
+
imag = x_stft.imag
|
40 |
+
|
41 |
+
# NOTE(kan-bayashi): clamp is needed to avoid nan or inf
|
42 |
+
return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
|
43 |
+
|
44 |
+
|
45 |
+
class SpectralConvergenceLoss(nn.Module):
|
46 |
+
"""Spectral convergence loss.
|
47 |
+
"""
|
48 |
+
def __init__(self, epsilon: float = torch.finfo(torch.float32).eps):
|
49 |
+
super().__init__()
|
50 |
+
self.epsilon = epsilon
|
51 |
+
|
52 |
+
def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
|
53 |
+
"""Calculate forward propagation.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
x_mag: Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
57 |
+
y_mag: Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
58 |
+
Returns:
|
59 |
+
torch.Tensor: Spectral convergence loss value.
|
60 |
+
"""
|
61 |
+
return torch.norm(y_mag - x_mag, p="fro") / (torch.norm(y_mag, p="fro") + self.epsilon)
|
62 |
+
|
63 |
+
|
64 |
+
class LogSTFTMagnitudeLoss(nn.Module):
|
65 |
+
"""Log STFT magnitude loss.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
epsilon (float): Epsilon value for numerical stability.
|
69 |
+
"""
|
70 |
+
def __init__(self, epsilon: float = torch.finfo(torch.float32).eps):
|
71 |
+
super().__init__()
|
72 |
+
self.epsilon = epsilon
|
73 |
+
|
74 |
+
def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
|
75 |
+
"""Calculate forward propagation.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
x_mag (torch.Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
79 |
+
y_mag (torch.Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
80 |
+
Returns:
|
81 |
+
torch.Tensor: Log STFT magnitude loss value.
|
82 |
+
"""
|
83 |
+
return F.l1_loss(torch.log(self.epsilon + y_mag), torch.log(self.epsilon + x_mag))
|
84 |
+
|
85 |
+
|
86 |
+
class STFTLosses(nn.Module):
|
87 |
+
"""STFT losses.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
n_fft (int): Size of FFT.
|
91 |
+
hop_length (int): Hop length.
|
92 |
+
win_length (int): Window length.
|
93 |
+
window (str): Window function type.
|
94 |
+
normalized (bool): Whether to use normalized STFT or not.
|
95 |
+
epsilon (float): Epsilon for numerical stability.
|
96 |
+
"""
|
97 |
+
def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_length: int = 600,
|
98 |
+
window: str = "hann_window", normalized: bool = False,
|
99 |
+
epsilon: float = torch.finfo(torch.float32).eps):
|
100 |
+
super().__init__()
|
101 |
+
self.n_fft = n_fft
|
102 |
+
self.hop_length = hop_length
|
103 |
+
self.win_length = win_length
|
104 |
+
self.normalized = normalized
|
105 |
+
self.register_buffer("window", getattr(torch, window)(win_length))
|
106 |
+
self.spectral_convergenge_loss = SpectralConvergenceLoss(epsilon)
|
107 |
+
self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss(epsilon)
|
108 |
+
|
109 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
110 |
+
"""Calculate forward propagation.
|
111 |
+
|
112 |
+
Args:
|
113 |
+
x (torch.Tensor): Predicted signal (B, T).
|
114 |
+
y (torch.Tensor): Groundtruth signal (B, T).
|
115 |
+
Returns:
|
116 |
+
torch.Tensor: Spectral convergence loss value.
|
117 |
+
torch.Tensor: Log STFT magnitude loss value.
|
118 |
+
"""
|
119 |
+
x_mag = _stft(x, self.n_fft, self.hop_length,
|
120 |
+
self.win_length, self.window, self.normalized) # type: ignore
|
121 |
+
y_mag = _stft(y, self.n_fft, self.hop_length,
|
122 |
+
self.win_length, self.window, self.normalized) # type: ignore
|
123 |
+
sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
|
124 |
+
mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
|
125 |
+
|
126 |
+
return sc_loss, mag_loss
|
127 |
+
|
128 |
+
|
129 |
+
class STFTLoss(nn.Module):
|
130 |
+
"""Single Resolution STFT loss.
|
131 |
+
|
132 |
+
Args:
|
133 |
+
n_fft (int): Nb of FFT.
|
134 |
+
hop_length (int): Hop length.
|
135 |
+
win_length (int): Window length.
|
136 |
+
window (str): Window function type.
|
137 |
+
normalized (bool): Whether to use normalized STFT or not.
|
138 |
+
epsilon (float): Epsilon for numerical stability.
|
139 |
+
factor_sc (float): Coefficient for the spectral loss.
|
140 |
+
factor_mag (float): Coefficient for the magnitude loss.
|
141 |
+
"""
|
142 |
+
def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_length: int = 600,
|
143 |
+
window: str = "hann_window", normalized: bool = False,
|
144 |
+
factor_sc: float = 0.1, factor_mag: float = 0.1,
|
145 |
+
epsilon: float = torch.finfo(torch.float32).eps):
|
146 |
+
super().__init__()
|
147 |
+
self.loss = STFTLosses(n_fft, hop_length, win_length, window, normalized, epsilon)
|
148 |
+
self.factor_sc = factor_sc
|
149 |
+
self.factor_mag = factor_mag
|
150 |
+
|
151 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
152 |
+
"""Calculate forward propagation.
|
153 |
+
|
154 |
+
Args:
|
155 |
+
x (torch.Tensor): Predicted signal (B, T).
|
156 |
+
y (torch.Tensor): Groundtruth signal (B, T).
|
157 |
+
Returns:
|
158 |
+
torch.Tensor: Single resolution STFT loss.
|
159 |
+
"""
|
160 |
+
sc_loss, mag_loss = self.loss(x, y)
|
161 |
+
return self.factor_sc * sc_loss + self.factor_mag * mag_loss
|
162 |
+
|
163 |
+
|
164 |
+
class MRSTFTLoss(nn.Module):
|
165 |
+
"""Multi resolution STFT loss.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
n_ffts (Sequence[int]): Sequence of FFT sizes.
|
169 |
+
hop_lengths (Sequence[int]): Sequence of hop sizes.
|
170 |
+
win_lengths (Sequence[int]): Sequence of window lengths.
|
171 |
+
window (str): Window function type.
|
172 |
+
factor_sc (float): Coefficient for the spectral loss.
|
173 |
+
factor_mag (float): Coefficient for the magnitude loss.
|
174 |
+
normalized (bool): Whether to use normalized STFT or not.
|
175 |
+
epsilon (float): Epsilon for numerical stability.
|
176 |
+
"""
|
177 |
+
def __init__(self, n_ffts: tp.Sequence[int] = [1024, 2048, 512], hop_lengths: tp.Sequence[int] = [120, 240, 50],
|
178 |
+
win_lengths: tp.Sequence[int] = [600, 1200, 240], window: str = "hann_window",
|
179 |
+
factor_sc: float = 0.1, factor_mag: float = 0.1,
|
180 |
+
normalized: bool = False, epsilon: float = torch.finfo(torch.float32).eps):
|
181 |
+
super().__init__()
|
182 |
+
assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
|
183 |
+
self.stft_losses = torch.nn.ModuleList()
|
184 |
+
for fs, ss, wl in zip(n_ffts, hop_lengths, win_lengths):
|
185 |
+
self.stft_losses += [STFTLosses(fs, ss, wl, window, normalized, epsilon)]
|
186 |
+
self.factor_sc = factor_sc
|
187 |
+
self.factor_mag = factor_mag
|
188 |
+
|
189 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
190 |
+
"""Calculate forward propagation.
|
191 |
+
|
192 |
+
Args:
|
193 |
+
x (torch.Tensor): Predicted signal (B, T).
|
194 |
+
y (torch.Tensor): Groundtruth signal (B, T).
|
195 |
+
Returns:
|
196 |
+
torch.Tensor: Multi resolution STFT loss.
|
197 |
+
"""
|
198 |
+
sc_loss = torch.Tensor([0.0])
|
199 |
+
mag_loss = torch.Tensor([0.0])
|
200 |
+
for f in self.stft_losses:
|
201 |
+
sc_l, mag_l = f(x, y)
|
202 |
+
sc_loss += sc_l
|
203 |
+
mag_loss += mag_l
|
204 |
+
sc_loss /= len(self.stft_losses)
|
205 |
+
mag_loss /= len(self.stft_losses)
|
206 |
+
|
207 |
+
return self.factor_sc * sc_loss + self.factor_mag * mag_loss
|