unpairedelectron07 commited on
Commit
034e769
·
verified ·
1 Parent(s): fbc2435

Upload 4 files

Browse files
audiocraft/losses/balancer.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+
9
+ import flashy
10
+ import torch
11
+ from torch import autograd
12
+
13
+
14
+ class Balancer:
15
+ """Loss balancer.
16
+
17
+ The loss balancer combines losses together to compute gradients for the backward.
18
+ Given `y = f(...)`, and a number of losses `l1(y, ...)`, `l2(y, ...)`, with `...`
19
+ not having any dependence on `f`, the balancer can efficiently normalize the partial gradients
20
+ `d l1 / d y`, `d l2 / dy` before summing them in order to achieve a desired ratio between
21
+ the losses. For instance if `weights = {'l1': 2, 'l2': 1}`, 66% of the gradient
22
+ going into `f(...)` will come from `l1` on average, and 33% from `l2`. This allows for an easy
23
+ interpration of the weights even if the intrisic scale of `l1`, `l2` ... is unknown.
24
+
25
+ Noting `g1 = d l1 / dy`, etc., the balanced gradient `G` will be
26
+ (with `avg` an exponential moving average over the updates),
27
+
28
+ G = sum_i total_norm * g_i / avg(||g_i||) * w_i / sum(w_i)
29
+
30
+ If `balance_grads` is False, this is deactivated, and instead the gradient will just be the
31
+ standard sum of the partial gradients with the given weights.
32
+
33
+ A call to the backward method of the balancer will compute the the partial gradients,
34
+ combining all the losses and potentially rescaling the gradients,
35
+ which can help stabilize the training and reason about multiple losses with varying scales.
36
+ The obtained gradient with respect to `y` is then back-propagated to `f(...)`.
37
+
38
+ Expected usage:
39
+
40
+ weights = {'loss_a': 1, 'loss_b': 4}
41
+ balancer = Balancer(weights, ...)
42
+ losses: dict = {}
43
+ losses['loss_a'] = compute_loss_a(x, y)
44
+ losses['loss_b'] = compute_loss_b(x, y)
45
+ if model.training():
46
+ effective_loss = balancer.backward(losses, x)
47
+
48
+ Args:
49
+ weights (dict[str, float]): Weight coefficient for each loss. The balancer expect the losses keys
50
+ from the backward method to match the weights keys to assign weight to each of the provided loss.
51
+ balance_grads (bool): Whether to rescale gradients so that weights reflect the fraction of the
52
+ overall gradient, rather than a constant multiplier.
53
+ total_norm (float): Reference norm when rescaling gradients, ignored otherwise.
54
+ emay_decay (float): EMA decay for averaging the norms.
55
+ per_batch_item (bool): Whether to compute the averaged norm per batch item or not. This only holds
56
+ when rescaling the gradients.
57
+ epsilon (float): Epsilon value for numerical stability.
58
+ monitor (bool): If True, stores in `self.metrics` the relative ratio between the norm of the gradients
59
+ coming from each loss, when calling `backward()`.
60
+ """
61
+ def __init__(self, weights: tp.Dict[str, float], balance_grads: bool = True, total_norm: float = 1.,
62
+ ema_decay: float = 0.999, per_batch_item: bool = True, epsilon: float = 1e-12,
63
+ monitor: bool = False):
64
+ self.weights = weights
65
+ self.per_batch_item = per_batch_item
66
+ self.total_norm = total_norm or 1.
67
+ self.averager = flashy.averager(ema_decay or 1.)
68
+ self.epsilon = epsilon
69
+ self.monitor = monitor
70
+ self.balance_grads = balance_grads
71
+ self._metrics: tp.Dict[str, tp.Any] = {}
72
+
73
+ @property
74
+ def metrics(self):
75
+ return self._metrics
76
+
77
+ def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor) -> torch.Tensor:
78
+ """Compute the backward and return the effective train loss, e.g. the loss obtained from
79
+ computing the effective weights. If `balance_grads` is True, the effective weights
80
+ are the one that needs to be applied to each gradient to respect the desired relative
81
+ scale of gradients coming from each loss.
82
+
83
+ Args:
84
+ losses (Dict[str, torch.Tensor]): dictionary with the same keys as `self.weights`.
85
+ input (torch.Tensor): the input of the losses, typically the output of the model.
86
+ This should be the single point of dependence between the losses
87
+ and the model being trained.
88
+ """
89
+ norms = {}
90
+ grads = {}
91
+ for name, loss in losses.items():
92
+ # Compute partial derivative of the less with respect to the input.
93
+ grad, = autograd.grad(loss, [input], retain_graph=True)
94
+ if self.per_batch_item:
95
+ # We do not average the gradient over the batch dimension.
96
+ dims = tuple(range(1, grad.dim()))
97
+ norm = grad.norm(dim=dims, p=2).mean()
98
+ else:
99
+ norm = grad.norm(p=2)
100
+ norms[name] = norm
101
+ grads[name] = grad
102
+
103
+ count = 1
104
+ if self.per_batch_item:
105
+ count = len(grad)
106
+ # Average norms across workers. Theoretically we should average the
107
+ # squared norm, then take the sqrt, but it worked fine like that.
108
+ avg_norms = flashy.distrib.average_metrics(self.averager(norms), count)
109
+ # We approximate the total norm of the gradient as the sums of the norms.
110
+ # Obviously this can be very incorrect if all gradients are aligned, but it works fine.
111
+ total = sum(avg_norms.values())
112
+
113
+ self._metrics = {}
114
+ if self.monitor:
115
+ # Store the ratio of the total gradient represented by each loss.
116
+ for k, v in avg_norms.items():
117
+ self._metrics[f'ratio_{k}'] = v / total
118
+
119
+ total_weights = sum([self.weights[k] for k in avg_norms])
120
+ assert total_weights > 0.
121
+ desired_ratios = {k: w / total_weights for k, w in self.weights.items()}
122
+
123
+ out_grad = torch.zeros_like(input)
124
+ effective_loss = torch.tensor(0., device=input.device, dtype=input.dtype)
125
+ for name, avg_norm in avg_norms.items():
126
+ if self.balance_grads:
127
+ # g_balanced = g / avg(||g||) * total_norm * desired_ratio
128
+ scale = desired_ratios[name] * self.total_norm / (self.epsilon + avg_norm)
129
+ else:
130
+ # We just do regular weighted sum of the gradients.
131
+ scale = self.weights[name]
132
+ out_grad.add_(grads[name], alpha=scale)
133
+ effective_loss += scale * losses[name].detach()
134
+ # Send the computed partial derivative with respect to the output of the model to the model.
135
+ input.backward(out_grad)
136
+ return effective_loss
audiocraft/losses/sisnr.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import typing as tp
9
+
10
+ import torch
11
+ from torch import nn
12
+ from torch.nn import functional as F
13
+
14
+
15
+ def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
16
+ """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
17
+ with K the kernel size, by extracting frames with the given stride.
18
+ This will pad the input so that `F = ceil(T / K)`.
19
+ see https://github.com/pytorch/pytorch/issues/60466
20
+ """
21
+ *shape, length = a.shape
22
+ n_frames = math.ceil(length / stride)
23
+ tgt_length = (n_frames - 1) * stride + kernel_size
24
+ a = F.pad(a, (0, tgt_length - length))
25
+ strides = list(a.stride())
26
+ assert strides[-1] == 1, "data should be contiguous"
27
+ strides = strides[:-1] + [stride, 1]
28
+ return a.as_strided([*shape, n_frames, kernel_size], strides)
29
+
30
+
31
+ def _center(x: torch.Tensor) -> torch.Tensor:
32
+ return x - x.mean(-1, True)
33
+
34
+
35
+ def _norm2(x: torch.Tensor) -> torch.Tensor:
36
+ return x.pow(2).sum(-1, True)
37
+
38
+
39
+ class SISNR(nn.Module):
40
+ """SISNR loss.
41
+
42
+ Input should be [B, C, T], output is scalar.
43
+
44
+ ..Warning:: This function returns the opposite of the SI-SNR (e.g. `-1 * regular_SI_SNR`).
45
+ Consequently, lower scores are better in terms of reconstruction quality,
46
+ in particular, it should be negative if training goes well. This done this way so
47
+ that this module can also be used as a loss function for training model.
48
+
49
+ Args:
50
+ sample_rate (int): Sample rate.
51
+ segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on
52
+ entire audio only.
53
+ overlap (float): Overlap between chunks, i.e. 0.5 = 50 % overlap.
54
+ epsilon (float): Epsilon value for numerical stability.
55
+ """
56
+ def __init__(
57
+ self,
58
+ sample_rate: int = 16000,
59
+ segment: tp.Optional[float] = 20,
60
+ overlap: float = 0.5,
61
+ epsilon: float = torch.finfo(torch.float32).eps,
62
+ ):
63
+ super().__init__()
64
+ self.sample_rate = sample_rate
65
+ self.segment = segment
66
+ self.overlap = overlap
67
+ self.epsilon = epsilon
68
+
69
+ def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
70
+ B, C, T = ref_sig.shape
71
+ assert ref_sig.shape == out_sig.shape
72
+
73
+ if self.segment is None:
74
+ frame = T
75
+ stride = T
76
+ else:
77
+ frame = int(self.segment * self.sample_rate)
78
+ stride = int(frame * (1 - self.overlap))
79
+
80
+ epsilon = self.epsilon * frame # make epsilon prop to frame size.
81
+
82
+ gt = _unfold(ref_sig, frame, stride)
83
+ est = _unfold(out_sig, frame, stride)
84
+ if self.segment is None:
85
+ assert gt.shape[-1] == 1
86
+
87
+ gt = _center(gt)
88
+ est = _center(est)
89
+ dot = torch.einsum("bcft,bcft->bcf", gt, est)
90
+
91
+ proj = dot[:, :, :, None] * gt / (epsilon + _norm2(gt))
92
+ noise = est - proj
93
+
94
+ sisnr = 10 * (
95
+ torch.log10(epsilon + _norm2(proj)) - torch.log10(epsilon + _norm2(noise))
96
+ )
97
+ return -1 * sisnr[..., 0].mean()
audiocraft/losses/specloss.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+
9
+ import numpy as np
10
+ from torchaudio.transforms import MelSpectrogram
11
+ import torch
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+
15
+ from ..modules import pad_for_conv1d
16
+
17
+
18
+ class MelSpectrogramWrapper(nn.Module):
19
+ """Wrapper around MelSpectrogram torchaudio transform providing proper padding
20
+ and additional post-processing including log scaling.
21
+
22
+ Args:
23
+ n_mels (int): Number of mel bins.
24
+ n_fft (int): Number of fft.
25
+ hop_length (int): Hop size.
26
+ win_length (int): Window length.
27
+ n_mels (int): Number of mel bins.
28
+ sample_rate (int): Sample rate.
29
+ f_min (float or None): Minimum frequency.
30
+ f_max (float or None): Maximum frequency.
31
+ log (bool): Whether to scale with log.
32
+ normalized (bool): Whether to normalize the melspectrogram.
33
+ floor_level (float): Floor level based on human perception (default=1e-5).
34
+ """
35
+ def __init__(self, n_fft: int = 1024, hop_length: int = 256, win_length: tp.Optional[int] = None,
36
+ n_mels: int = 80, sample_rate: float = 22050, f_min: float = 0.0, f_max: tp.Optional[float] = None,
37
+ log: bool = True, normalized: bool = False, floor_level: float = 1e-5):
38
+ super().__init__()
39
+ self.n_fft = n_fft
40
+ hop_length = int(hop_length)
41
+ self.hop_length = hop_length
42
+ self.mel_transform = MelSpectrogram(n_mels=n_mels, sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
43
+ win_length=win_length, f_min=f_min, f_max=f_max, normalized=normalized,
44
+ window_fn=torch.hann_window, center=False)
45
+ self.floor_level = floor_level
46
+ self.log = log
47
+
48
+ def forward(self, x):
49
+ p = int((self.n_fft - self.hop_length) // 2)
50
+ if len(x.shape) == 2:
51
+ x = x.unsqueeze(1)
52
+ x = F.pad(x, (p, p), "reflect")
53
+ # Make sure that all the frames are full.
54
+ # The combination of `pad_for_conv1d` and the above padding
55
+ # will make the output of size ceil(T / hop).
56
+ x = pad_for_conv1d(x, self.n_fft, self.hop_length)
57
+ self.mel_transform.to(x.device)
58
+ mel_spec = self.mel_transform(x)
59
+ B, C, freqs, frame = mel_spec.shape
60
+ if self.log:
61
+ mel_spec = torch.log10(self.floor_level + mel_spec)
62
+ return mel_spec.reshape(B, C * freqs, frame)
63
+
64
+
65
+ class MelSpectrogramL1Loss(torch.nn.Module):
66
+ """L1 Loss on MelSpectrogram.
67
+
68
+ Args:
69
+ sample_rate (int): Sample rate.
70
+ n_fft (int): Number of fft.
71
+ hop_length (int): Hop size.
72
+ win_length (int): Window length.
73
+ n_mels (int): Number of mel bins.
74
+ f_min (float or None): Minimum frequency.
75
+ f_max (float or None): Maximum frequency.
76
+ log (bool): Whether to scale with log.
77
+ normalized (bool): Whether to normalize the melspectrogram.
78
+ floor_level (float): Floor level value based on human perception (default=1e-5).
79
+ """
80
+ def __init__(self, sample_rate: int, n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024,
81
+ n_mels: int = 80, f_min: float = 0.0, f_max: tp.Optional[float] = None,
82
+ log: bool = True, normalized: bool = False, floor_level: float = 1e-5):
83
+ super().__init__()
84
+ self.l1 = torch.nn.L1Loss()
85
+ self.melspec = MelSpectrogramWrapper(n_fft=n_fft, hop_length=hop_length, win_length=win_length,
86
+ n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
87
+ log=log, normalized=normalized, floor_level=floor_level)
88
+
89
+ def forward(self, x, y):
90
+ self.melspec.to(x.device)
91
+ s_x = self.melspec(x)
92
+ s_y = self.melspec(y)
93
+ return self.l1(s_x, s_y)
94
+
95
+
96
+ class MultiScaleMelSpectrogramLoss(nn.Module):
97
+ """Multi-Scale spectrogram loss (msspec).
98
+
99
+ Args:
100
+ sample_rate (int): Sample rate.
101
+ range_start (int): Power of 2 to use for the first scale.
102
+ range_stop (int): Power of 2 to use for the last scale.
103
+ n_mels (int): Number of mel bins.
104
+ f_min (float): Minimum frequency.
105
+ f_max (float or None): Maximum frequency.
106
+ normalized (bool): Whether to normalize the melspectrogram.
107
+ alphas (bool): Whether to use alphas as coefficients or not.
108
+ floor_level (float): Floor level value based on human perception (default=1e-5).
109
+ """
110
+ def __init__(self, sample_rate: int, range_start: int = 6, range_end: int = 11,
111
+ n_mels: int = 64, f_min: float = 0.0, f_max: tp.Optional[float] = None,
112
+ normalized: bool = False, alphas: bool = True, floor_level: float = 1e-5):
113
+ super().__init__()
114
+ l1s = list()
115
+ l2s = list()
116
+ self.alphas = list()
117
+ self.total = 0
118
+ self.normalized = normalized
119
+ for i in range(range_start, range_end):
120
+ l1s.append(
121
+ MelSpectrogramWrapper(n_fft=2 ** i, hop_length=(2 ** i) / 4, win_length=2 ** i,
122
+ n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
123
+ log=False, normalized=normalized, floor_level=floor_level))
124
+ l2s.append(
125
+ MelSpectrogramWrapper(n_fft=2 ** i, hop_length=(2 ** i) / 4, win_length=2 ** i,
126
+ n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
127
+ log=True, normalized=normalized, floor_level=floor_level))
128
+ if alphas:
129
+ self.alphas.append(np.sqrt(2 ** i - 1))
130
+ else:
131
+ self.alphas.append(1)
132
+ self.total += self.alphas[-1] + 1
133
+
134
+ self.l1s = nn.ModuleList(l1s)
135
+ self.l2s = nn.ModuleList(l2s)
136
+
137
+ def forward(self, x, y):
138
+ loss = 0.0
139
+ self.l1s.to(x.device)
140
+ self.l2s.to(x.device)
141
+ for i in range(len(self.alphas)):
142
+ s_x_1 = self.l1s[i](x)
143
+ s_y_1 = self.l1s[i](y)
144
+ s_x_2 = self.l2s[i](x)
145
+ s_y_2 = self.l2s[i](y)
146
+ loss += F.l1_loss(s_x_1, s_y_1) + self.alphas[i] * F.mse_loss(s_x_2, s_y_2)
147
+ if self.normalized:
148
+ loss = loss / self.total
149
+ return loss
audiocraft/losses/stftloss.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Adapted from MIT code under the original license
7
+ # Copyright 2019 Tomoki Hayashi
8
+ # MIT License (https://opensource.org/licenses/MIT)
9
+ import typing as tp
10
+
11
+ import torch
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+
15
+
16
+ # TODO: Replace with torchaudio.STFT?
17
+ def _stft(x: torch.Tensor, fft_size: int, hop_length: int, win_length: int,
18
+ window: tp.Optional[torch.Tensor], normalized: bool) -> torch.Tensor:
19
+ """Perform STFT and convert to magnitude spectrogram.
20
+
21
+ Args:
22
+ x: Input signal tensor (B, C, T).
23
+ fft_size (int): FFT size.
24
+ hop_length (int): Hop size.
25
+ win_length (int): Window length.
26
+ window (torch.Tensor or None): Window function type.
27
+ normalized (bool): Whether to normalize the STFT or not.
28
+
29
+ Returns:
30
+ torch.Tensor: Magnitude spectrogram (B, C, #frames, fft_size // 2 + 1).
31
+ """
32
+ B, C, T = x.shape
33
+ x_stft = torch.stft(
34
+ x.view(-1, T), fft_size, hop_length, win_length, window,
35
+ normalized=normalized, return_complex=True,
36
+ )
37
+ x_stft = x_stft.view(B, C, *x_stft.shape[1:])
38
+ real = x_stft.real
39
+ imag = x_stft.imag
40
+
41
+ # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
42
+ return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
43
+
44
+
45
+ class SpectralConvergenceLoss(nn.Module):
46
+ """Spectral convergence loss.
47
+ """
48
+ def __init__(self, epsilon: float = torch.finfo(torch.float32).eps):
49
+ super().__init__()
50
+ self.epsilon = epsilon
51
+
52
+ def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
53
+ """Calculate forward propagation.
54
+
55
+ Args:
56
+ x_mag: Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
57
+ y_mag: Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
58
+ Returns:
59
+ torch.Tensor: Spectral convergence loss value.
60
+ """
61
+ return torch.norm(y_mag - x_mag, p="fro") / (torch.norm(y_mag, p="fro") + self.epsilon)
62
+
63
+
64
+ class LogSTFTMagnitudeLoss(nn.Module):
65
+ """Log STFT magnitude loss.
66
+
67
+ Args:
68
+ epsilon (float): Epsilon value for numerical stability.
69
+ """
70
+ def __init__(self, epsilon: float = torch.finfo(torch.float32).eps):
71
+ super().__init__()
72
+ self.epsilon = epsilon
73
+
74
+ def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
75
+ """Calculate forward propagation.
76
+
77
+ Args:
78
+ x_mag (torch.Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
79
+ y_mag (torch.Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
80
+ Returns:
81
+ torch.Tensor: Log STFT magnitude loss value.
82
+ """
83
+ return F.l1_loss(torch.log(self.epsilon + y_mag), torch.log(self.epsilon + x_mag))
84
+
85
+
86
+ class STFTLosses(nn.Module):
87
+ """STFT losses.
88
+
89
+ Args:
90
+ n_fft (int): Size of FFT.
91
+ hop_length (int): Hop length.
92
+ win_length (int): Window length.
93
+ window (str): Window function type.
94
+ normalized (bool): Whether to use normalized STFT or not.
95
+ epsilon (float): Epsilon for numerical stability.
96
+ """
97
+ def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_length: int = 600,
98
+ window: str = "hann_window", normalized: bool = False,
99
+ epsilon: float = torch.finfo(torch.float32).eps):
100
+ super().__init__()
101
+ self.n_fft = n_fft
102
+ self.hop_length = hop_length
103
+ self.win_length = win_length
104
+ self.normalized = normalized
105
+ self.register_buffer("window", getattr(torch, window)(win_length))
106
+ self.spectral_convergenge_loss = SpectralConvergenceLoss(epsilon)
107
+ self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss(epsilon)
108
+
109
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
110
+ """Calculate forward propagation.
111
+
112
+ Args:
113
+ x (torch.Tensor): Predicted signal (B, T).
114
+ y (torch.Tensor): Groundtruth signal (B, T).
115
+ Returns:
116
+ torch.Tensor: Spectral convergence loss value.
117
+ torch.Tensor: Log STFT magnitude loss value.
118
+ """
119
+ x_mag = _stft(x, self.n_fft, self.hop_length,
120
+ self.win_length, self.window, self.normalized) # type: ignore
121
+ y_mag = _stft(y, self.n_fft, self.hop_length,
122
+ self.win_length, self.window, self.normalized) # type: ignore
123
+ sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
124
+ mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
125
+
126
+ return sc_loss, mag_loss
127
+
128
+
129
+ class STFTLoss(nn.Module):
130
+ """Single Resolution STFT loss.
131
+
132
+ Args:
133
+ n_fft (int): Nb of FFT.
134
+ hop_length (int): Hop length.
135
+ win_length (int): Window length.
136
+ window (str): Window function type.
137
+ normalized (bool): Whether to use normalized STFT or not.
138
+ epsilon (float): Epsilon for numerical stability.
139
+ factor_sc (float): Coefficient for the spectral loss.
140
+ factor_mag (float): Coefficient for the magnitude loss.
141
+ """
142
+ def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_length: int = 600,
143
+ window: str = "hann_window", normalized: bool = False,
144
+ factor_sc: float = 0.1, factor_mag: float = 0.1,
145
+ epsilon: float = torch.finfo(torch.float32).eps):
146
+ super().__init__()
147
+ self.loss = STFTLosses(n_fft, hop_length, win_length, window, normalized, epsilon)
148
+ self.factor_sc = factor_sc
149
+ self.factor_mag = factor_mag
150
+
151
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
152
+ """Calculate forward propagation.
153
+
154
+ Args:
155
+ x (torch.Tensor): Predicted signal (B, T).
156
+ y (torch.Tensor): Groundtruth signal (B, T).
157
+ Returns:
158
+ torch.Tensor: Single resolution STFT loss.
159
+ """
160
+ sc_loss, mag_loss = self.loss(x, y)
161
+ return self.factor_sc * sc_loss + self.factor_mag * mag_loss
162
+
163
+
164
+ class MRSTFTLoss(nn.Module):
165
+ """Multi resolution STFT loss.
166
+
167
+ Args:
168
+ n_ffts (Sequence[int]): Sequence of FFT sizes.
169
+ hop_lengths (Sequence[int]): Sequence of hop sizes.
170
+ win_lengths (Sequence[int]): Sequence of window lengths.
171
+ window (str): Window function type.
172
+ factor_sc (float): Coefficient for the spectral loss.
173
+ factor_mag (float): Coefficient for the magnitude loss.
174
+ normalized (bool): Whether to use normalized STFT or not.
175
+ epsilon (float): Epsilon for numerical stability.
176
+ """
177
+ def __init__(self, n_ffts: tp.Sequence[int] = [1024, 2048, 512], hop_lengths: tp.Sequence[int] = [120, 240, 50],
178
+ win_lengths: tp.Sequence[int] = [600, 1200, 240], window: str = "hann_window",
179
+ factor_sc: float = 0.1, factor_mag: float = 0.1,
180
+ normalized: bool = False, epsilon: float = torch.finfo(torch.float32).eps):
181
+ super().__init__()
182
+ assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
183
+ self.stft_losses = torch.nn.ModuleList()
184
+ for fs, ss, wl in zip(n_ffts, hop_lengths, win_lengths):
185
+ self.stft_losses += [STFTLosses(fs, ss, wl, window, normalized, epsilon)]
186
+ self.factor_sc = factor_sc
187
+ self.factor_mag = factor_mag
188
+
189
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
190
+ """Calculate forward propagation.
191
+
192
+ Args:
193
+ x (torch.Tensor): Predicted signal (B, T).
194
+ y (torch.Tensor): Groundtruth signal (B, T).
195
+ Returns:
196
+ torch.Tensor: Multi resolution STFT loss.
197
+ """
198
+ sc_loss = torch.Tensor([0.0])
199
+ mag_loss = torch.Tensor([0.0])
200
+ for f in self.stft_losses:
201
+ sc_l, mag_l = f(x, y)
202
+ sc_loss += sc_l
203
+ mag_loss += mag_l
204
+ sc_loss /= len(self.stft_losses)
205
+ mag_loss /= len(self.stft_losses)
206
+
207
+ return self.factor_sc * sc_loss + self.factor_mag * mag_loss