Spaces:
Runtime error
Runtime error
anthonyrusso
commited on
Commit
•
f1e9197
1
Parent(s):
b3ff8a5
upload audiocraft
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- audiocraft/__init__.py +26 -0
- audiocraft/adversarial/__init__.py +22 -0
- audiocraft/adversarial/discriminators/__init__.py +10 -0
- audiocraft/adversarial/discriminators/base.py +34 -0
- audiocraft/adversarial/discriminators/mpd.py +106 -0
- audiocraft/adversarial/discriminators/msd.py +126 -0
- audiocraft/adversarial/discriminators/msstftd.py +134 -0
- audiocraft/adversarial/losses.py +228 -0
- audiocraft/data/__init__.py +10 -0
- audiocraft/data/audio.py +216 -0
- audiocraft/data/audio_dataset.py +587 -0
- audiocraft/data/audio_utils.py +176 -0
- audiocraft/data/info_audio_dataset.py +110 -0
- audiocraft/data/music_dataset.py +270 -0
- audiocraft/data/sound_dataset.py +330 -0
- audiocraft/data/zip.py +76 -0
- audiocraft/environment.py +176 -0
- audiocraft/grids/__init__.py +6 -0
- audiocraft/grids/_base_explorers.py +80 -0
- audiocraft/grids/audiogen/__init__.py +6 -0
- audiocraft/grids/audiogen/audiogen_base_16khz.py +23 -0
- audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py +68 -0
- audiocraft/grids/compression/__init__.py +6 -0
- audiocraft/grids/compression/_explorers.py +55 -0
- audiocraft/grids/compression/debug.py +31 -0
- audiocraft/grids/compression/encodec_audiogen_16khz.py +29 -0
- audiocraft/grids/compression/encodec_base_24khz.py +28 -0
- audiocraft/grids/compression/encodec_musicgen_32khz.py +34 -0
- audiocraft/grids/diffusion/4_bands_base_32khz.py +27 -0
- audiocraft/grids/diffusion/__init__.py +6 -0
- audiocraft/grids/diffusion/_explorers.py +66 -0
- audiocraft/grids/musicgen/__init__.py +6 -0
- audiocraft/grids/musicgen/_explorers.py +93 -0
- audiocraft/grids/musicgen/musicgen_base_32khz.py +43 -0
- audiocraft/grids/musicgen/musicgen_base_cached_32khz.py +67 -0
- audiocraft/grids/musicgen/musicgen_clapemb_32khz.py +32 -0
- audiocraft/grids/musicgen/musicgen_melody_32khz.py +65 -0
- audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py +99 -0
- audiocraft/losses/__init__.py +21 -0
- audiocraft/losses/balancer.py +136 -0
- audiocraft/losses/sisnr.py +92 -0
- audiocraft/losses/specloss.py +149 -0
- audiocraft/losses/stftloss.py +207 -0
- audiocraft/metrics/__init__.py +14 -0
- audiocraft/metrics/chroma_cosinesim.py +72 -0
- audiocraft/metrics/clap_consistency.py +84 -0
- audiocraft/metrics/fad.py +329 -0
- audiocraft/metrics/kld.py +220 -0
- audiocraft/metrics/rvm.py +110 -0
- audiocraft/metrics/visqol.py +216 -0
audiocraft/__init__.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""
|
7 |
+
AudioCraft is a general framework for training audio generative models.
|
8 |
+
At the moment we provide the training code for:
|
9 |
+
|
10 |
+
- [MusicGen](https://arxiv.org/abs/2306.05284), a state-of-the-art
|
11 |
+
text-to-music and melody+text autoregressive generative model.
|
12 |
+
For the solver, see `audiocraft.solvers.musicgen.MusicGenSolver`, and for the model,
|
13 |
+
`audiocraft.models.musicgen.MusicGen`.
|
14 |
+
- [AudioGen](https://arxiv.org/abs/2209.15352), a state-of-the-art
|
15 |
+
text-to-general-audio generative model.
|
16 |
+
- [EnCodec](https://arxiv.org/abs/2210.13438), efficient and high fidelity
|
17 |
+
neural audio codec which provides an excellent tokenizer for autoregressive language models.
|
18 |
+
See `audiocraft.solvers.compression.CompressionSolver`, and `audiocraft.models.encodec.EncodecModel`.
|
19 |
+
- [MultiBandDiffusion](TODO), alternative diffusion-based decoder compatible with EnCodec that
|
20 |
+
improves the perceived quality and reduces the artifacts coming from adversarial decoders.
|
21 |
+
"""
|
22 |
+
|
23 |
+
# flake8: noqa
|
24 |
+
from . import data, modules, models
|
25 |
+
|
26 |
+
__version__ = '1.0.0'
|
audiocraft/adversarial/__init__.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Adversarial losses and discriminator architectures."""
|
7 |
+
|
8 |
+
# flake8: noqa
|
9 |
+
from .discriminators import (
|
10 |
+
MultiPeriodDiscriminator,
|
11 |
+
MultiScaleDiscriminator,
|
12 |
+
MultiScaleSTFTDiscriminator
|
13 |
+
)
|
14 |
+
from .losses import (
|
15 |
+
AdversarialLoss,
|
16 |
+
AdvLossType,
|
17 |
+
get_adv_criterion,
|
18 |
+
get_fake_criterion,
|
19 |
+
get_real_criterion,
|
20 |
+
FeatLossType,
|
21 |
+
FeatureMatchingLoss
|
22 |
+
)
|
audiocraft/adversarial/discriminators/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# flake8: noqa
|
8 |
+
from .mpd import MultiPeriodDiscriminator
|
9 |
+
from .msd import MultiScaleDiscriminator
|
10 |
+
from .msstftd import MultiScaleSTFTDiscriminator
|
audiocraft/adversarial/discriminators/base.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from abc import ABC, abstractmethod
|
8 |
+
import typing as tp
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
|
13 |
+
|
14 |
+
FeatureMapType = tp.List[torch.Tensor]
|
15 |
+
LogitsType = torch.Tensor
|
16 |
+
MultiDiscriminatorOutputType = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
|
17 |
+
|
18 |
+
|
19 |
+
class MultiDiscriminator(ABC, nn.Module):
|
20 |
+
"""Base implementation for discriminators composed of sub-discriminators acting at different scales.
|
21 |
+
"""
|
22 |
+
def __init__(self):
|
23 |
+
super().__init__()
|
24 |
+
|
25 |
+
@abstractmethod
|
26 |
+
def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
|
27 |
+
...
|
28 |
+
|
29 |
+
@property
|
30 |
+
@abstractmethod
|
31 |
+
def num_discriminators(self) -> int:
|
32 |
+
"""Number of discriminators.
|
33 |
+
"""
|
34 |
+
...
|
audiocraft/adversarial/discriminators/mpd.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import typing as tp
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
from ...modules import NormConv2d
|
14 |
+
from .base import MultiDiscriminator, MultiDiscriminatorOutputType
|
15 |
+
|
16 |
+
|
17 |
+
def get_padding(kernel_size: int, dilation: int = 1) -> int:
|
18 |
+
return int((kernel_size * dilation - dilation) / 2)
|
19 |
+
|
20 |
+
|
21 |
+
class PeriodDiscriminator(nn.Module):
|
22 |
+
"""Period sub-discriminator.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
period (int): Period between samples of audio.
|
26 |
+
in_channels (int): Number of input channels.
|
27 |
+
out_channels (int): Number of output channels.
|
28 |
+
n_layers (int): Number of convolutional layers.
|
29 |
+
kernel_sizes (list of int): Kernel sizes for convolutions.
|
30 |
+
stride (int): Stride for convolutions.
|
31 |
+
filters (int): Initial number of filters in convolutions.
|
32 |
+
filters_scale (int): Multiplier of number of filters as we increase depth.
|
33 |
+
max_filters (int): Maximum number of filters.
|
34 |
+
norm (str): Normalization method.
|
35 |
+
activation (str): Activation function.
|
36 |
+
activation_params (dict): Parameters to provide to the activation function.
|
37 |
+
"""
|
38 |
+
def __init__(self, period: int, in_channels: int = 1, out_channels: int = 1,
|
39 |
+
n_layers: int = 5, kernel_sizes: tp.List[int] = [5, 3], stride: int = 3,
|
40 |
+
filters: int = 8, filters_scale: int = 4, max_filters: int = 1024,
|
41 |
+
norm: str = 'weight_norm', activation: str = 'LeakyReLU',
|
42 |
+
activation_params: dict = {'negative_slope': 0.2}):
|
43 |
+
super().__init__()
|
44 |
+
self.period = period
|
45 |
+
self.n_layers = n_layers
|
46 |
+
self.activation = getattr(torch.nn, activation)(**activation_params)
|
47 |
+
self.convs = nn.ModuleList()
|
48 |
+
in_chs = in_channels
|
49 |
+
for i in range(self.n_layers):
|
50 |
+
out_chs = min(filters * (filters_scale ** (i + 1)), max_filters)
|
51 |
+
eff_stride = 1 if i == self.n_layers - 1 else stride
|
52 |
+
self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_sizes[0], 1), stride=(eff_stride, 1),
|
53 |
+
padding=((kernel_sizes[0] - 1) // 2, 0), norm=norm))
|
54 |
+
in_chs = out_chs
|
55 |
+
self.conv_post = NormConv2d(in_chs, out_channels, kernel_size=(kernel_sizes[1], 1), stride=1,
|
56 |
+
padding=((kernel_sizes[1] - 1) // 2, 0), norm=norm)
|
57 |
+
|
58 |
+
def forward(self, x: torch.Tensor):
|
59 |
+
fmap = []
|
60 |
+
# 1d to 2d
|
61 |
+
b, c, t = x.shape
|
62 |
+
if t % self.period != 0: # pad first
|
63 |
+
n_pad = self.period - (t % self.period)
|
64 |
+
x = F.pad(x, (0, n_pad), 'reflect')
|
65 |
+
t = t + n_pad
|
66 |
+
x = x.view(b, c, t // self.period, self.period)
|
67 |
+
|
68 |
+
for conv in self.convs:
|
69 |
+
x = conv(x)
|
70 |
+
x = self.activation(x)
|
71 |
+
fmap.append(x)
|
72 |
+
x = self.conv_post(x)
|
73 |
+
fmap.append(x)
|
74 |
+
# x = torch.flatten(x, 1, -1)
|
75 |
+
|
76 |
+
return x, fmap
|
77 |
+
|
78 |
+
|
79 |
+
class MultiPeriodDiscriminator(MultiDiscriminator):
|
80 |
+
"""Multi-Period (MPD) Discriminator.
|
81 |
+
|
82 |
+
Args:
|
83 |
+
in_channels (int): Number of input channels.
|
84 |
+
out_channels (int): Number of output channels.
|
85 |
+
periods (Sequence[int]): Periods between samples of audio for the sub-discriminators.
|
86 |
+
**kwargs: Additional args for `PeriodDiscriminator`
|
87 |
+
"""
|
88 |
+
def __init__(self, in_channels: int = 1, out_channels: int = 1,
|
89 |
+
periods: tp.Sequence[int] = [2, 3, 5, 7, 11], **kwargs):
|
90 |
+
super().__init__()
|
91 |
+
self.discriminators = nn.ModuleList([
|
92 |
+
PeriodDiscriminator(p, in_channels, out_channels, **kwargs) for p in periods
|
93 |
+
])
|
94 |
+
|
95 |
+
@property
|
96 |
+
def num_discriminators(self):
|
97 |
+
return len(self.discriminators)
|
98 |
+
|
99 |
+
def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
|
100 |
+
logits = []
|
101 |
+
fmaps = []
|
102 |
+
for disc in self.discriminators:
|
103 |
+
logit, fmap = disc(x)
|
104 |
+
logits.append(logit)
|
105 |
+
fmaps.append(fmap)
|
106 |
+
return logits, fmaps
|
audiocraft/adversarial/discriminators/msd.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import typing as tp
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
|
13 |
+
from ...modules import NormConv1d
|
14 |
+
from .base import MultiDiscriminator, MultiDiscriminatorOutputType
|
15 |
+
|
16 |
+
|
17 |
+
class ScaleDiscriminator(nn.Module):
|
18 |
+
"""Waveform sub-discriminator.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
in_channels (int): Number of input channels.
|
22 |
+
out_channels (int): Number of output channels.
|
23 |
+
kernel_sizes (Sequence[int]): Kernel sizes for first and last convolutions.
|
24 |
+
filters (int): Number of initial filters for convolutions.
|
25 |
+
max_filters (int): Maximum number of filters.
|
26 |
+
downsample_scales (Sequence[int]): Scale for downsampling implemented as strided convolutions.
|
27 |
+
inner_kernel_sizes (Sequence[int] or None): Kernel sizes for inner convolutions.
|
28 |
+
groups (Sequence[int] or None): Groups for inner convolutions.
|
29 |
+
strides (Sequence[int] or None): Strides for inner convolutions.
|
30 |
+
paddings (Sequence[int] or None): Paddings for inner convolutions.
|
31 |
+
norm (str): Normalization method.
|
32 |
+
activation (str): Activation function.
|
33 |
+
activation_params (dict): Parameters to provide to the activation function.
|
34 |
+
pad (str): Padding for initial convolution.
|
35 |
+
pad_params (dict): Parameters to provide to the padding module.
|
36 |
+
"""
|
37 |
+
def __init__(self, in_channels=1, out_channels=1, kernel_sizes: tp.Sequence[int] = [5, 3],
|
38 |
+
filters: int = 16, max_filters: int = 1024, downsample_scales: tp.Sequence[int] = [4, 4, 4, 4],
|
39 |
+
inner_kernel_sizes: tp.Optional[tp.Sequence[int]] = None, groups: tp.Optional[tp.Sequence[int]] = None,
|
40 |
+
strides: tp.Optional[tp.Sequence[int]] = None, paddings: tp.Optional[tp.Sequence[int]] = None,
|
41 |
+
norm: str = 'weight_norm', activation: str = 'LeakyReLU',
|
42 |
+
activation_params: dict = {'negative_slope': 0.2}, pad: str = 'ReflectionPad1d',
|
43 |
+
pad_params: dict = {}):
|
44 |
+
super().__init__()
|
45 |
+
assert len(kernel_sizes) == 2
|
46 |
+
assert kernel_sizes[0] % 2 == 1
|
47 |
+
assert kernel_sizes[1] % 2 == 1
|
48 |
+
assert (inner_kernel_sizes is None or len(inner_kernel_sizes) == len(downsample_scales))
|
49 |
+
assert (groups is None or len(groups) == len(downsample_scales))
|
50 |
+
assert (strides is None or len(strides) == len(downsample_scales))
|
51 |
+
assert (paddings is None or len(paddings) == len(downsample_scales))
|
52 |
+
self.activation = getattr(torch.nn, activation)(**activation_params)
|
53 |
+
self.convs = nn.ModuleList()
|
54 |
+
self.convs.append(
|
55 |
+
nn.Sequential(
|
56 |
+
getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params),
|
57 |
+
NormConv1d(in_channels, filters, kernel_size=np.prod(kernel_sizes), stride=1, norm=norm)
|
58 |
+
)
|
59 |
+
)
|
60 |
+
|
61 |
+
in_chs = filters
|
62 |
+
for i, downsample_scale in enumerate(downsample_scales):
|
63 |
+
out_chs = min(in_chs * downsample_scale, max_filters)
|
64 |
+
default_kernel_size = downsample_scale * 10 + 1
|
65 |
+
default_stride = downsample_scale
|
66 |
+
default_padding = (default_kernel_size - 1) // 2
|
67 |
+
default_groups = in_chs // 4
|
68 |
+
self.convs.append(
|
69 |
+
NormConv1d(in_chs, out_chs,
|
70 |
+
kernel_size=inner_kernel_sizes[i] if inner_kernel_sizes else default_kernel_size,
|
71 |
+
stride=strides[i] if strides else default_stride,
|
72 |
+
groups=groups[i] if groups else default_groups,
|
73 |
+
padding=paddings[i] if paddings else default_padding,
|
74 |
+
norm=norm))
|
75 |
+
in_chs = out_chs
|
76 |
+
|
77 |
+
out_chs = min(in_chs * 2, max_filters)
|
78 |
+
self.convs.append(NormConv1d(in_chs, out_chs, kernel_size=kernel_sizes[0], stride=1,
|
79 |
+
padding=(kernel_sizes[0] - 1) // 2, norm=norm))
|
80 |
+
self.conv_post = NormConv1d(out_chs, out_channels, kernel_size=kernel_sizes[1], stride=1,
|
81 |
+
padding=(kernel_sizes[1] - 1) // 2, norm=norm)
|
82 |
+
|
83 |
+
def forward(self, x: torch.Tensor):
|
84 |
+
fmap = []
|
85 |
+
for layer in self.convs:
|
86 |
+
x = layer(x)
|
87 |
+
x = self.activation(x)
|
88 |
+
fmap.append(x)
|
89 |
+
x = self.conv_post(x)
|
90 |
+
fmap.append(x)
|
91 |
+
# x = torch.flatten(x, 1, -1)
|
92 |
+
return x, fmap
|
93 |
+
|
94 |
+
|
95 |
+
class MultiScaleDiscriminator(MultiDiscriminator):
|
96 |
+
"""Multi-Scale (MSD) Discriminator,
|
97 |
+
|
98 |
+
Args:
|
99 |
+
in_channels (int): Number of input channels.
|
100 |
+
out_channels (int): Number of output channels.
|
101 |
+
downsample_factor (int): Downsampling factor between the different scales.
|
102 |
+
scale_norms (Sequence[str]): Normalization for each sub-discriminator.
|
103 |
+
**kwargs: Additional args for ScaleDiscriminator.
|
104 |
+
"""
|
105 |
+
def __init__(self, in_channels: int = 1, out_channels: int = 1, downsample_factor: int = 2,
|
106 |
+
scale_norms: tp.Sequence[str] = ['weight_norm', 'weight_norm', 'weight_norm'], **kwargs):
|
107 |
+
super().__init__()
|
108 |
+
self.discriminators = nn.ModuleList([
|
109 |
+
ScaleDiscriminator(in_channels, out_channels, norm=norm, **kwargs) for norm in scale_norms
|
110 |
+
])
|
111 |
+
self.downsample = nn.AvgPool1d(downsample_factor * 2, downsample_factor, padding=downsample_factor)
|
112 |
+
|
113 |
+
@property
|
114 |
+
def num_discriminators(self):
|
115 |
+
return len(self.discriminators)
|
116 |
+
|
117 |
+
def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
|
118 |
+
logits = []
|
119 |
+
fmaps = []
|
120 |
+
for i, disc in enumerate(self.discriminators):
|
121 |
+
if i != 0:
|
122 |
+
self.downsample(x)
|
123 |
+
logit, fmap = disc(x)
|
124 |
+
logits.append(logit)
|
125 |
+
fmaps.append(fmap)
|
126 |
+
return logits, fmaps
|
audiocraft/adversarial/discriminators/msstftd.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import typing as tp
|
8 |
+
|
9 |
+
import torchaudio
|
10 |
+
import torch
|
11 |
+
from torch import nn
|
12 |
+
from einops import rearrange
|
13 |
+
|
14 |
+
from ...modules import NormConv2d
|
15 |
+
from .base import MultiDiscriminator, MultiDiscriminatorOutputType
|
16 |
+
|
17 |
+
|
18 |
+
def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)):
|
19 |
+
return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2)
|
20 |
+
|
21 |
+
|
22 |
+
class DiscriminatorSTFT(nn.Module):
|
23 |
+
"""STFT sub-discriminator.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
filters (int): Number of filters in convolutions.
|
27 |
+
in_channels (int): Number of input channels.
|
28 |
+
out_channels (int): Number of output channels.
|
29 |
+
n_fft (int): Size of FFT for each scale.
|
30 |
+
hop_length (int): Length of hop between STFT windows for each scale.
|
31 |
+
kernel_size (tuple of int): Inner Conv2d kernel sizes.
|
32 |
+
stride (tuple of int): Inner Conv2d strides.
|
33 |
+
dilations (list of int): Inner Conv2d dilation on the time dimension.
|
34 |
+
win_length (int): Window size for each scale.
|
35 |
+
normalized (bool): Whether to normalize by magnitude after stft.
|
36 |
+
norm (str): Normalization method.
|
37 |
+
activation (str): Activation function.
|
38 |
+
activation_params (dict): Parameters to provide to the activation function.
|
39 |
+
growth (int): Growth factor for the filters.
|
40 |
+
"""
|
41 |
+
def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
|
42 |
+
n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024,
|
43 |
+
filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4],
|
44 |
+
stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm',
|
45 |
+
activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}):
|
46 |
+
super().__init__()
|
47 |
+
assert len(kernel_size) == 2
|
48 |
+
assert len(stride) == 2
|
49 |
+
self.filters = filters
|
50 |
+
self.in_channels = in_channels
|
51 |
+
self.out_channels = out_channels
|
52 |
+
self.n_fft = n_fft
|
53 |
+
self.hop_length = hop_length
|
54 |
+
self.win_length = win_length
|
55 |
+
self.normalized = normalized
|
56 |
+
self.activation = getattr(torch.nn, activation)(**activation_params)
|
57 |
+
self.spec_transform = torchaudio.transforms.Spectrogram(
|
58 |
+
n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window,
|
59 |
+
normalized=self.normalized, center=False, pad_mode=None, power=None)
|
60 |
+
spec_channels = 2 * self.in_channels
|
61 |
+
self.convs = nn.ModuleList()
|
62 |
+
self.convs.append(
|
63 |
+
NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size))
|
64 |
+
)
|
65 |
+
in_chs = min(filters_scale * self.filters, max_filters)
|
66 |
+
for i, dilation in enumerate(dilations):
|
67 |
+
out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
|
68 |
+
self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride,
|
69 |
+
dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)),
|
70 |
+
norm=norm))
|
71 |
+
in_chs = out_chs
|
72 |
+
out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters)
|
73 |
+
self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]),
|
74 |
+
padding=get_2d_padding((kernel_size[0], kernel_size[0])),
|
75 |
+
norm=norm))
|
76 |
+
self.conv_post = NormConv2d(out_chs, self.out_channels,
|
77 |
+
kernel_size=(kernel_size[0], kernel_size[0]),
|
78 |
+
padding=get_2d_padding((kernel_size[0], kernel_size[0])),
|
79 |
+
norm=norm)
|
80 |
+
|
81 |
+
def forward(self, x: torch.Tensor):
|
82 |
+
fmap = []
|
83 |
+
z = self.spec_transform(x) # [B, 2, Freq, Frames, 2]
|
84 |
+
z = torch.cat([z.real, z.imag], dim=1)
|
85 |
+
z = rearrange(z, 'b c w t -> b c t w')
|
86 |
+
for i, layer in enumerate(self.convs):
|
87 |
+
z = layer(z)
|
88 |
+
z = self.activation(z)
|
89 |
+
fmap.append(z)
|
90 |
+
z = self.conv_post(z)
|
91 |
+
return z, fmap
|
92 |
+
|
93 |
+
|
94 |
+
class MultiScaleSTFTDiscriminator(MultiDiscriminator):
|
95 |
+
"""Multi-Scale STFT (MS-STFT) discriminator.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
filters (int): Number of filters in convolutions.
|
99 |
+
in_channels (int): Number of input channels.
|
100 |
+
out_channels (int): Number of output channels.
|
101 |
+
sep_channels (bool): Separate channels to distinct samples for stereo support.
|
102 |
+
n_ffts (Sequence[int]): Size of FFT for each scale.
|
103 |
+
hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale.
|
104 |
+
win_lengths (Sequence[int]): Window size for each scale.
|
105 |
+
**kwargs: Additional args for STFTDiscriminator.
|
106 |
+
"""
|
107 |
+
def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, sep_channels: bool = False,
|
108 |
+
n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128],
|
109 |
+
win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs):
|
110 |
+
super().__init__()
|
111 |
+
assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
|
112 |
+
self.sep_channels = sep_channels
|
113 |
+
self.discriminators = nn.ModuleList([
|
114 |
+
DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels,
|
115 |
+
n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs)
|
116 |
+
for i in range(len(n_ffts))
|
117 |
+
])
|
118 |
+
|
119 |
+
@property
|
120 |
+
def num_discriminators(self):
|
121 |
+
return len(self.discriminators)
|
122 |
+
|
123 |
+
def _separate_channels(self, x: torch.Tensor) -> torch.Tensor:
|
124 |
+
B, C, T = x.shape
|
125 |
+
return x.view(-1, 1, T)
|
126 |
+
|
127 |
+
def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
|
128 |
+
logits = []
|
129 |
+
fmaps = []
|
130 |
+
for disc in self.discriminators:
|
131 |
+
logit, fmap = disc(x)
|
132 |
+
logits.append(logit)
|
133 |
+
fmaps.append(fmap)
|
134 |
+
return logits, fmaps
|
audiocraft/adversarial/losses.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Utility module to handle adversarial losses without requiring to mess up the main training loop.
|
9 |
+
"""
|
10 |
+
|
11 |
+
import typing as tp
|
12 |
+
|
13 |
+
import flashy
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
|
18 |
+
|
19 |
+
ADVERSARIAL_LOSSES = ['mse', 'hinge', 'hinge2']
|
20 |
+
|
21 |
+
|
22 |
+
AdvLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor], torch.Tensor]]
|
23 |
+
FeatLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
|
24 |
+
|
25 |
+
|
26 |
+
class AdversarialLoss(nn.Module):
|
27 |
+
"""Adversary training wrapper.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
adversary (nn.Module): The adversary module will be used to estimate the logits given the fake and real samples.
|
31 |
+
We assume here the adversary output is ``Tuple[List[torch.Tensor], List[List[torch.Tensor]]]``
|
32 |
+
where the first item is a list of logits and the second item is a list of feature maps.
|
33 |
+
optimizer (torch.optim.Optimizer): Optimizer used for training the given module.
|
34 |
+
loss (AdvLossType): Loss function for generator training.
|
35 |
+
loss_real (AdvLossType): Loss function for adversarial training on logits from real samples.
|
36 |
+
loss_fake (AdvLossType): Loss function for adversarial training on logits from fake samples.
|
37 |
+
loss_feat (FeatLossType): Feature matching loss function for generator training.
|
38 |
+
normalize (bool): Whether to normalize by number of sub-discriminators.
|
39 |
+
|
40 |
+
Example of usage:
|
41 |
+
adv_loss = AdversarialLoss(adversaries, optimizer, loss, loss_real, loss_fake)
|
42 |
+
for real in loader:
|
43 |
+
noise = torch.randn(...)
|
44 |
+
fake = model(noise)
|
45 |
+
adv_loss.train_adv(fake, real)
|
46 |
+
loss, _ = adv_loss(fake, real)
|
47 |
+
loss.backward()
|
48 |
+
"""
|
49 |
+
def __init__(self,
|
50 |
+
adversary: nn.Module,
|
51 |
+
optimizer: torch.optim.Optimizer,
|
52 |
+
loss: AdvLossType,
|
53 |
+
loss_real: AdvLossType,
|
54 |
+
loss_fake: AdvLossType,
|
55 |
+
loss_feat: tp.Optional[FeatLossType] = None,
|
56 |
+
normalize: bool = True):
|
57 |
+
super().__init__()
|
58 |
+
self.adversary: nn.Module = adversary
|
59 |
+
flashy.distrib.broadcast_model(self.adversary)
|
60 |
+
self.optimizer = optimizer
|
61 |
+
self.loss = loss
|
62 |
+
self.loss_real = loss_real
|
63 |
+
self.loss_fake = loss_fake
|
64 |
+
self.loss_feat = loss_feat
|
65 |
+
self.normalize = normalize
|
66 |
+
|
67 |
+
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
68 |
+
# Add the optimizer state dict inside our own.
|
69 |
+
super()._save_to_state_dict(destination, prefix, keep_vars)
|
70 |
+
destination[prefix + 'optimizer'] = self.optimizer.state_dict()
|
71 |
+
return destination
|
72 |
+
|
73 |
+
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
74 |
+
# Load optimizer state.
|
75 |
+
self.optimizer.load_state_dict(state_dict.pop(prefix + 'optimizer'))
|
76 |
+
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
77 |
+
|
78 |
+
def get_adversary_pred(self, x):
|
79 |
+
"""Run adversary model, validating expected output format."""
|
80 |
+
logits, fmaps = self.adversary(x)
|
81 |
+
assert isinstance(logits, list) and all([isinstance(t, torch.Tensor) for t in logits]), \
|
82 |
+
f'Expecting a list of tensors as logits but {type(logits)} found.'
|
83 |
+
assert isinstance(fmaps, list), f'Expecting a list of features maps but {type(fmaps)} found.'
|
84 |
+
for fmap in fmaps:
|
85 |
+
assert isinstance(fmap, list) and all([isinstance(f, torch.Tensor) for f in fmap]), \
|
86 |
+
f'Expecting a list of tensors as feature maps but {type(fmap)} found.'
|
87 |
+
return logits, fmaps
|
88 |
+
|
89 |
+
def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
|
90 |
+
"""Train the adversary with the given fake and real example.
|
91 |
+
|
92 |
+
We assume the adversary output is the following format: Tuple[List[torch.Tensor], List[List[torch.Tensor]]].
|
93 |
+
The first item being the logits and second item being a list of feature maps for each sub-discriminator.
|
94 |
+
|
95 |
+
This will automatically synchronize gradients (with `flashy.distrib.eager_sync_model`)
|
96 |
+
and call the optimizer.
|
97 |
+
"""
|
98 |
+
loss = torch.tensor(0., device=fake.device)
|
99 |
+
all_logits_fake_is_fake, _ = self.get_adversary_pred(fake.detach())
|
100 |
+
all_logits_real_is_fake, _ = self.get_adversary_pred(real.detach())
|
101 |
+
n_sub_adversaries = len(all_logits_fake_is_fake)
|
102 |
+
for logit_fake_is_fake, logit_real_is_fake in zip(all_logits_fake_is_fake, all_logits_real_is_fake):
|
103 |
+
loss += self.loss_fake(logit_fake_is_fake) + self.loss_real(logit_real_is_fake)
|
104 |
+
|
105 |
+
if self.normalize:
|
106 |
+
loss /= n_sub_adversaries
|
107 |
+
|
108 |
+
self.optimizer.zero_grad()
|
109 |
+
with flashy.distrib.eager_sync_model(self.adversary):
|
110 |
+
loss.backward()
|
111 |
+
self.optimizer.step()
|
112 |
+
|
113 |
+
return loss
|
114 |
+
|
115 |
+
def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
116 |
+
"""Return the loss for the generator, i.e. trying to fool the adversary,
|
117 |
+
and feature matching loss if provided.
|
118 |
+
"""
|
119 |
+
adv = torch.tensor(0., device=fake.device)
|
120 |
+
feat = torch.tensor(0., device=fake.device)
|
121 |
+
with flashy.utils.readonly(self.adversary):
|
122 |
+
all_logits_fake_is_fake, all_fmap_fake = self.get_adversary_pred(fake)
|
123 |
+
all_logits_real_is_fake, all_fmap_real = self.get_adversary_pred(real)
|
124 |
+
n_sub_adversaries = len(all_logits_fake_is_fake)
|
125 |
+
for logit_fake_is_fake in all_logits_fake_is_fake:
|
126 |
+
adv += self.loss(logit_fake_is_fake)
|
127 |
+
if self.loss_feat:
|
128 |
+
for fmap_fake, fmap_real in zip(all_fmap_fake, all_fmap_real):
|
129 |
+
feat += self.loss_feat(fmap_fake, fmap_real)
|
130 |
+
|
131 |
+
if self.normalize:
|
132 |
+
adv /= n_sub_adversaries
|
133 |
+
feat /= n_sub_adversaries
|
134 |
+
|
135 |
+
return adv, feat
|
136 |
+
|
137 |
+
|
138 |
+
def get_adv_criterion(loss_type: str) -> tp.Callable:
|
139 |
+
assert loss_type in ADVERSARIAL_LOSSES
|
140 |
+
if loss_type == 'mse':
|
141 |
+
return mse_loss
|
142 |
+
elif loss_type == 'hinge':
|
143 |
+
return hinge_loss
|
144 |
+
elif loss_type == 'hinge2':
|
145 |
+
return hinge2_loss
|
146 |
+
raise ValueError('Unsupported loss')
|
147 |
+
|
148 |
+
|
149 |
+
def get_fake_criterion(loss_type: str) -> tp.Callable:
|
150 |
+
assert loss_type in ADVERSARIAL_LOSSES
|
151 |
+
if loss_type == 'mse':
|
152 |
+
return mse_fake_loss
|
153 |
+
elif loss_type in ['hinge', 'hinge2']:
|
154 |
+
return hinge_fake_loss
|
155 |
+
raise ValueError('Unsupported loss')
|
156 |
+
|
157 |
+
|
158 |
+
def get_real_criterion(loss_type: str) -> tp.Callable:
|
159 |
+
assert loss_type in ADVERSARIAL_LOSSES
|
160 |
+
if loss_type == 'mse':
|
161 |
+
return mse_real_loss
|
162 |
+
elif loss_type in ['hinge', 'hinge2']:
|
163 |
+
return hinge_real_loss
|
164 |
+
raise ValueError('Unsupported loss')
|
165 |
+
|
166 |
+
|
167 |
+
def mse_real_loss(x: torch.Tensor) -> torch.Tensor:
|
168 |
+
return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
|
169 |
+
|
170 |
+
|
171 |
+
def mse_fake_loss(x: torch.Tensor) -> torch.Tensor:
|
172 |
+
return F.mse_loss(x, torch.tensor(0., device=x.device).expand_as(x))
|
173 |
+
|
174 |
+
|
175 |
+
def hinge_real_loss(x: torch.Tensor) -> torch.Tensor:
|
176 |
+
return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
|
177 |
+
|
178 |
+
|
179 |
+
def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor:
|
180 |
+
return -torch.mean(torch.min(-x - 1, torch.tensor(0., device=x.device).expand_as(x)))
|
181 |
+
|
182 |
+
|
183 |
+
def mse_loss(x: torch.Tensor) -> torch.Tensor:
|
184 |
+
if x.numel() == 0:
|
185 |
+
return torch.tensor([0.0], device=x.device)
|
186 |
+
return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
|
187 |
+
|
188 |
+
|
189 |
+
def hinge_loss(x: torch.Tensor) -> torch.Tensor:
|
190 |
+
if x.numel() == 0:
|
191 |
+
return torch.tensor([0.0], device=x.device)
|
192 |
+
return -x.mean()
|
193 |
+
|
194 |
+
|
195 |
+
def hinge2_loss(x: torch.Tensor) -> torch.Tensor:
|
196 |
+
if x.numel() == 0:
|
197 |
+
return torch.tensor([0.0])
|
198 |
+
return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
|
199 |
+
|
200 |
+
|
201 |
+
class FeatureMatchingLoss(nn.Module):
|
202 |
+
"""Feature matching loss for adversarial training.
|
203 |
+
|
204 |
+
Args:
|
205 |
+
loss (nn.Module): Loss to use for feature matching (default=torch.nn.L1).
|
206 |
+
normalize (bool): Whether to normalize the loss.
|
207 |
+
by number of feature maps.
|
208 |
+
"""
|
209 |
+
def __init__(self, loss: nn.Module = torch.nn.L1Loss(), normalize: bool = True):
|
210 |
+
super().__init__()
|
211 |
+
self.loss = loss
|
212 |
+
self.normalize = normalize
|
213 |
+
|
214 |
+
def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List[torch.Tensor]) -> torch.Tensor:
|
215 |
+
assert len(fmap_fake) == len(fmap_real) and len(fmap_fake) > 0
|
216 |
+
feat_loss = torch.tensor(0., device=fmap_fake[0].device)
|
217 |
+
feat_scale = torch.tensor(0., device=fmap_fake[0].device)
|
218 |
+
n_fmaps = 0
|
219 |
+
for (feat_fake, feat_real) in zip(fmap_fake, fmap_real):
|
220 |
+
assert feat_fake.shape == feat_real.shape
|
221 |
+
n_fmaps += 1
|
222 |
+
feat_loss += self.loss(feat_fake, feat_real)
|
223 |
+
feat_scale += torch.mean(torch.abs(feat_real))
|
224 |
+
|
225 |
+
if self.normalize:
|
226 |
+
feat_loss /= n_fmaps
|
227 |
+
|
228 |
+
return feat_loss
|
audiocraft/data/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Audio loading and writing support. Datasets for raw audio
|
7 |
+
or also including some metadata."""
|
8 |
+
|
9 |
+
# flake8: noqa
|
10 |
+
from . import audio, audio_dataset, info_audio_dataset, music_dataset, sound_dataset
|
audiocraft/data/audio.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Audio IO methods are defined in this module (info, read, write),
|
9 |
+
We rely on av library for faster read when possible, otherwise on torchaudio.
|
10 |
+
"""
|
11 |
+
|
12 |
+
from dataclasses import dataclass
|
13 |
+
from pathlib import Path
|
14 |
+
import logging
|
15 |
+
import typing as tp
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import soundfile
|
19 |
+
import torch
|
20 |
+
from torch.nn import functional as F
|
21 |
+
import torchaudio as ta
|
22 |
+
|
23 |
+
import av
|
24 |
+
|
25 |
+
from .audio_utils import f32_pcm, i16_pcm, normalize_audio
|
26 |
+
|
27 |
+
|
28 |
+
_av_initialized = False
|
29 |
+
|
30 |
+
|
31 |
+
def _init_av():
|
32 |
+
global _av_initialized
|
33 |
+
if _av_initialized:
|
34 |
+
return
|
35 |
+
logger = logging.getLogger('libav.mp3')
|
36 |
+
logger.setLevel(logging.ERROR)
|
37 |
+
_av_initialized = True
|
38 |
+
|
39 |
+
|
40 |
+
@dataclass(frozen=True)
|
41 |
+
class AudioFileInfo:
|
42 |
+
sample_rate: int
|
43 |
+
duration: float
|
44 |
+
channels: int
|
45 |
+
|
46 |
+
|
47 |
+
def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
|
48 |
+
_init_av()
|
49 |
+
with av.open(str(filepath)) as af:
|
50 |
+
stream = af.streams.audio[0]
|
51 |
+
sample_rate = stream.codec_context.sample_rate
|
52 |
+
duration = float(stream.duration * stream.time_base)
|
53 |
+
channels = stream.channels
|
54 |
+
return AudioFileInfo(sample_rate, duration, channels)
|
55 |
+
|
56 |
+
|
57 |
+
def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
|
58 |
+
info = soundfile.info(filepath)
|
59 |
+
return AudioFileInfo(info.samplerate, info.duration, info.channels)
|
60 |
+
|
61 |
+
|
62 |
+
def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
|
63 |
+
# torchaudio no longer returns useful duration informations for some formats like mp3s.
|
64 |
+
filepath = Path(filepath)
|
65 |
+
if filepath.suffix in ['.flac', '.ogg']: # TODO: Validate .ogg can be safely read with av_info
|
66 |
+
# ffmpeg has some weird issue with flac.
|
67 |
+
return _soundfile_info(filepath)
|
68 |
+
else:
|
69 |
+
return _av_info(filepath)
|
70 |
+
|
71 |
+
|
72 |
+
def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
|
73 |
+
"""FFMPEG-based audio file reading using PyAV bindings.
|
74 |
+
Soundfile cannot read mp3 and av_read is more efficient than torchaudio.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
filepath (str or Path): Path to audio file to read.
|
78 |
+
seek_time (float): Time at which to start reading in the file.
|
79 |
+
duration (float): Duration to read from the file. If set to -1, the whole file is read.
|
80 |
+
Returns:
|
81 |
+
tuple of torch.Tensor, int: Tuple containing audio data and sample rate
|
82 |
+
"""
|
83 |
+
_init_av()
|
84 |
+
with av.open(str(filepath)) as af:
|
85 |
+
stream = af.streams.audio[0]
|
86 |
+
sr = stream.codec_context.sample_rate
|
87 |
+
num_frames = int(sr * duration) if duration >= 0 else -1
|
88 |
+
frame_offset = int(sr * seek_time)
|
89 |
+
# we need a small negative offset otherwise we get some edge artifact
|
90 |
+
# from the mp3 decoder.
|
91 |
+
af.seek(int(max(0, (seek_time - 0.1)) / stream.time_base), stream=stream)
|
92 |
+
frames = []
|
93 |
+
length = 0
|
94 |
+
for frame in af.decode(streams=stream.index):
|
95 |
+
current_offset = int(frame.rate * frame.pts * frame.time_base)
|
96 |
+
strip = max(0, frame_offset - current_offset)
|
97 |
+
buf = torch.from_numpy(frame.to_ndarray())
|
98 |
+
if buf.shape[0] != stream.channels:
|
99 |
+
buf = buf.view(-1, stream.channels).t()
|
100 |
+
buf = buf[:, strip:]
|
101 |
+
frames.append(buf)
|
102 |
+
length += buf.shape[1]
|
103 |
+
if num_frames > 0 and length >= num_frames:
|
104 |
+
break
|
105 |
+
assert frames
|
106 |
+
# If the above assert fails, it is likely because we seeked past the end of file point,
|
107 |
+
# in which case ffmpeg returns a single frame with only zeros, and a weird timestamp.
|
108 |
+
# This will need proper debugging, in due time.
|
109 |
+
wav = torch.cat(frames, dim=1)
|
110 |
+
assert wav.shape[0] == stream.channels
|
111 |
+
if num_frames > 0:
|
112 |
+
wav = wav[:, :num_frames]
|
113 |
+
return f32_pcm(wav), sr
|
114 |
+
|
115 |
+
|
116 |
+
def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
|
117 |
+
duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
|
118 |
+
"""Read audio by picking the most appropriate backend tool based on the audio format.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
filepath (str or Path): Path to audio file to read.
|
122 |
+
seek_time (float): Time at which to start reading in the file.
|
123 |
+
duration (float): Duration to read from the file. If set to -1, the whole file is read.
|
124 |
+
pad (bool): Pad output audio if not reaching expected duration.
|
125 |
+
Returns:
|
126 |
+
tuple of torch.Tensor, int: Tuple containing audio data and sample rate.
|
127 |
+
"""
|
128 |
+
fp = Path(filepath)
|
129 |
+
if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg
|
130 |
+
# There is some bug with ffmpeg and reading flac
|
131 |
+
info = _soundfile_info(filepath)
|
132 |
+
frames = -1 if duration <= 0 else int(duration * info.sample_rate)
|
133 |
+
frame_offset = int(seek_time * info.sample_rate)
|
134 |
+
wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
|
135 |
+
assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}"
|
136 |
+
wav = torch.from_numpy(wav).t().contiguous()
|
137 |
+
if len(wav.shape) == 1:
|
138 |
+
wav = torch.unsqueeze(wav, 0)
|
139 |
+
elif (
|
140 |
+
fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
|
141 |
+
and duration <= 0 and seek_time == 0
|
142 |
+
):
|
143 |
+
# Torchaudio is faster if we load an entire file at once.
|
144 |
+
wav, sr = ta.load(fp)
|
145 |
+
else:
|
146 |
+
wav, sr = _av_read(filepath, seek_time, duration)
|
147 |
+
if pad and duration > 0:
|
148 |
+
expected_frames = int(duration * sr)
|
149 |
+
wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
|
150 |
+
return wav, sr
|
151 |
+
|
152 |
+
|
153 |
+
def audio_write(stem_name: tp.Union[str, Path],
|
154 |
+
wav: torch.Tensor, sample_rate: int,
|
155 |
+
format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
|
156 |
+
strategy: str = 'peak', peak_clip_headroom_db: float = 1,
|
157 |
+
rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
|
158 |
+
loudness_compressor: bool = False,
|
159 |
+
log_clipping: bool = True, make_parent_dir: bool = True,
|
160 |
+
add_suffix: bool = True) -> Path:
|
161 |
+
"""Convenience function for saving audio to disk. Returns the filename the audio was written to.
|
162 |
+
|
163 |
+
Args:
|
164 |
+
stem_name (str or Path): Filename without extension which will be added automatically.
|
165 |
+
format (str): Either "wav" or "mp3".
|
166 |
+
mp3_rate (int): kbps when using mp3s.
|
167 |
+
normalize (bool): if `True` (default), normalizes according to the prescribed
|
168 |
+
strategy (see after). If `False`, the strategy is only used in case clipping
|
169 |
+
would happen.
|
170 |
+
strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
|
171 |
+
i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
|
172 |
+
with extra headroom to avoid clipping. 'clip' just clips.
|
173 |
+
peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
|
174 |
+
rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
|
175 |
+
than the `peak_clip` one to avoid further clipping.
|
176 |
+
loudness_headroom_db (float): Target loudness for loudness normalization.
|
177 |
+
loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
|
178 |
+
when strategy is 'loudness' log_clipping (bool): If True, basic logging on stderr when clipping still
|
179 |
+
occurs despite strategy (only for 'rms').
|
180 |
+
make_parent_dir (bool): Make parent directory if it doesn't exist.
|
181 |
+
Returns:
|
182 |
+
Path: Path of the saved audio.
|
183 |
+
"""
|
184 |
+
assert wav.dtype.is_floating_point, "wav is not floating point"
|
185 |
+
if wav.dim() == 1:
|
186 |
+
wav = wav[None]
|
187 |
+
elif wav.dim() > 2:
|
188 |
+
raise ValueError("Input wav should be at most 2 dimension.")
|
189 |
+
assert wav.isfinite().all()
|
190 |
+
wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
|
191 |
+
rms_headroom_db, loudness_headroom_db, loudness_compressor,
|
192 |
+
log_clipping=log_clipping, sample_rate=sample_rate,
|
193 |
+
stem_name=str(stem_name))
|
194 |
+
kwargs: dict = {}
|
195 |
+
if format == 'mp3':
|
196 |
+
suffix = '.mp3'
|
197 |
+
kwargs.update({"compression": mp3_rate})
|
198 |
+
elif format == 'wav':
|
199 |
+
wav = i16_pcm(wav)
|
200 |
+
suffix = '.wav'
|
201 |
+
kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16})
|
202 |
+
else:
|
203 |
+
raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
|
204 |
+
if not add_suffix:
|
205 |
+
suffix = ''
|
206 |
+
path = Path(str(stem_name) + suffix)
|
207 |
+
if make_parent_dir:
|
208 |
+
path.parent.mkdir(exist_ok=True, parents=True)
|
209 |
+
try:
|
210 |
+
ta.save(path, wav, sample_rate, **kwargs)
|
211 |
+
except Exception:
|
212 |
+
if path.exists():
|
213 |
+
# we do not want to leave half written files around.
|
214 |
+
path.unlink()
|
215 |
+
raise
|
216 |
+
return path
|
audiocraft/data/audio_dataset.py
ADDED
@@ -0,0 +1,587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""AudioDataset support. In order to handle a larger number of files
|
7 |
+
without having to scan again the folders, we precompute some metadata
|
8 |
+
(filename, sample rate, duration), and use that to efficiently sample audio segments.
|
9 |
+
"""
|
10 |
+
import argparse
|
11 |
+
import copy
|
12 |
+
from concurrent.futures import ThreadPoolExecutor, Future
|
13 |
+
from dataclasses import dataclass, fields
|
14 |
+
from contextlib import ExitStack
|
15 |
+
from functools import lru_cache
|
16 |
+
import gzip
|
17 |
+
import json
|
18 |
+
import logging
|
19 |
+
import os
|
20 |
+
from pathlib import Path
|
21 |
+
import random
|
22 |
+
import sys
|
23 |
+
import typing as tp
|
24 |
+
|
25 |
+
import torch
|
26 |
+
import torch.nn.functional as F
|
27 |
+
|
28 |
+
from .audio import audio_read, audio_info
|
29 |
+
from .audio_utils import convert_audio
|
30 |
+
from .zip import PathInZip
|
31 |
+
|
32 |
+
try:
|
33 |
+
import dora
|
34 |
+
except ImportError:
|
35 |
+
dora = None # type: ignore
|
36 |
+
|
37 |
+
|
38 |
+
@dataclass(order=True)
|
39 |
+
class BaseInfo:
|
40 |
+
|
41 |
+
@classmethod
|
42 |
+
def _dict2fields(cls, dictionary: dict):
|
43 |
+
return {
|
44 |
+
field.name: dictionary[field.name]
|
45 |
+
for field in fields(cls) if field.name in dictionary
|
46 |
+
}
|
47 |
+
|
48 |
+
@classmethod
|
49 |
+
def from_dict(cls, dictionary: dict):
|
50 |
+
_dictionary = cls._dict2fields(dictionary)
|
51 |
+
return cls(**_dictionary)
|
52 |
+
|
53 |
+
def to_dict(self):
|
54 |
+
return {
|
55 |
+
field.name: self.__getattribute__(field.name)
|
56 |
+
for field in fields(self)
|
57 |
+
}
|
58 |
+
|
59 |
+
|
60 |
+
@dataclass(order=True)
|
61 |
+
class AudioMeta(BaseInfo):
|
62 |
+
path: str
|
63 |
+
duration: float
|
64 |
+
sample_rate: int
|
65 |
+
amplitude: tp.Optional[float] = None
|
66 |
+
weight: tp.Optional[float] = None
|
67 |
+
# info_path is used to load additional information about the audio file that is stored in zip files.
|
68 |
+
info_path: tp.Optional[PathInZip] = None
|
69 |
+
|
70 |
+
@classmethod
|
71 |
+
def from_dict(cls, dictionary: dict):
|
72 |
+
base = cls._dict2fields(dictionary)
|
73 |
+
if 'info_path' in base and base['info_path'] is not None:
|
74 |
+
base['info_path'] = PathInZip(base['info_path'])
|
75 |
+
return cls(**base)
|
76 |
+
|
77 |
+
def to_dict(self):
|
78 |
+
d = super().to_dict()
|
79 |
+
if d['info_path'] is not None:
|
80 |
+
d['info_path'] = str(d['info_path'])
|
81 |
+
return d
|
82 |
+
|
83 |
+
|
84 |
+
@dataclass(order=True)
|
85 |
+
class SegmentInfo(BaseInfo):
|
86 |
+
meta: AudioMeta
|
87 |
+
seek_time: float
|
88 |
+
# The following values are given once the audio is processed, e.g.
|
89 |
+
# at the target sample rate and target number of channels.
|
90 |
+
n_frames: int # actual number of frames without padding
|
91 |
+
total_frames: int # total number of frames, padding included
|
92 |
+
sample_rate: int # actual sample rate
|
93 |
+
channels: int # number of audio channels.
|
94 |
+
|
95 |
+
|
96 |
+
DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
|
97 |
+
|
98 |
+
logger = logging.getLogger(__name__)
|
99 |
+
|
100 |
+
|
101 |
+
def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
|
102 |
+
"""AudioMeta from a path to an audio file.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
file_path (str): Resolved path of valid audio file.
|
106 |
+
minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
|
107 |
+
Returns:
|
108 |
+
AudioMeta: Audio file path and its metadata.
|
109 |
+
"""
|
110 |
+
info = audio_info(file_path)
|
111 |
+
amplitude: tp.Optional[float] = None
|
112 |
+
if not minimal:
|
113 |
+
wav, sr = audio_read(file_path)
|
114 |
+
amplitude = wav.abs().max().item()
|
115 |
+
return AudioMeta(file_path, info.duration, info.sample_rate, amplitude)
|
116 |
+
|
117 |
+
|
118 |
+
def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
|
119 |
+
"""If Dora is available as a dependency, try to resolve potential relative paths
|
120 |
+
in list of AudioMeta. This method is expected to be used when loading meta from file.
|
121 |
+
|
122 |
+
Args:
|
123 |
+
m (AudioMeta): Audio meta to resolve.
|
124 |
+
fast (bool): If True, uses a really fast check for determining if a file
|
125 |
+
is already absolute or not. Only valid on Linux/Mac.
|
126 |
+
Returns:
|
127 |
+
AudioMeta: Audio meta with resolved path.
|
128 |
+
"""
|
129 |
+
def is_abs(m):
|
130 |
+
if fast:
|
131 |
+
return str(m)[0] == '/'
|
132 |
+
else:
|
133 |
+
os.path.isabs(str(m))
|
134 |
+
|
135 |
+
if not dora:
|
136 |
+
return m
|
137 |
+
|
138 |
+
if not is_abs(m.path):
|
139 |
+
m.path = dora.git_save.to_absolute_path(m.path)
|
140 |
+
if m.info_path is not None and not is_abs(m.info_path.zip_path):
|
141 |
+
m.info_path.zip_path = dora.git_save.to_absolute_path(m.path)
|
142 |
+
return m
|
143 |
+
|
144 |
+
|
145 |
+
def find_audio_files(path: tp.Union[Path, str],
|
146 |
+
exts: tp.List[str] = DEFAULT_EXTS,
|
147 |
+
resolve: bool = True,
|
148 |
+
minimal: bool = True,
|
149 |
+
progress: bool = False,
|
150 |
+
workers: int = 0) -> tp.List[AudioMeta]:
|
151 |
+
"""Build a list of AudioMeta from a given path,
|
152 |
+
collecting relevant audio files and fetching meta info.
|
153 |
+
|
154 |
+
Args:
|
155 |
+
path (str or Path): Path to folder containing audio files.
|
156 |
+
exts (list of str): List of file extensions to consider for audio files.
|
157 |
+
minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
|
158 |
+
progress (bool): Whether to log progress on audio files collection.
|
159 |
+
workers (int): number of parallel workers, if 0, use only the current thread.
|
160 |
+
Returns:
|
161 |
+
list of AudioMeta: List of audio file path and its metadata.
|
162 |
+
"""
|
163 |
+
audio_files = []
|
164 |
+
futures: tp.List[Future] = []
|
165 |
+
pool: tp.Optional[ThreadPoolExecutor] = None
|
166 |
+
with ExitStack() as stack:
|
167 |
+
if workers > 0:
|
168 |
+
pool = ThreadPoolExecutor(workers)
|
169 |
+
stack.enter_context(pool)
|
170 |
+
|
171 |
+
if progress:
|
172 |
+
print("Finding audio files...")
|
173 |
+
for root, folders, files in os.walk(path, followlinks=True):
|
174 |
+
for file in files:
|
175 |
+
full_path = Path(root) / file
|
176 |
+
if full_path.suffix.lower() in exts:
|
177 |
+
audio_files.append(full_path)
|
178 |
+
if pool is not None:
|
179 |
+
futures.append(pool.submit(_get_audio_meta, str(audio_files[-1]), minimal))
|
180 |
+
if progress:
|
181 |
+
print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr)
|
182 |
+
|
183 |
+
if progress:
|
184 |
+
print("Getting audio metadata...")
|
185 |
+
meta: tp.List[AudioMeta] = []
|
186 |
+
for idx, file_path in enumerate(audio_files):
|
187 |
+
try:
|
188 |
+
if pool is None:
|
189 |
+
m = _get_audio_meta(str(file_path), minimal)
|
190 |
+
else:
|
191 |
+
m = futures[idx].result()
|
192 |
+
if resolve:
|
193 |
+
m = _resolve_audio_meta(m)
|
194 |
+
except Exception as err:
|
195 |
+
print("Error with", str(file_path), err, file=sys.stderr)
|
196 |
+
continue
|
197 |
+
meta.append(m)
|
198 |
+
if progress:
|
199 |
+
print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr)
|
200 |
+
meta.sort()
|
201 |
+
return meta
|
202 |
+
|
203 |
+
|
204 |
+
def load_audio_meta(path: tp.Union[str, Path],
|
205 |
+
resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
|
206 |
+
"""Load list of AudioMeta from an optionally compressed json file.
|
207 |
+
|
208 |
+
Args:
|
209 |
+
path (str or Path): Path to JSON file.
|
210 |
+
resolve (bool): Whether to resolve the path from AudioMeta (default=True).
|
211 |
+
fast (bool): activates some tricks to make things faster.
|
212 |
+
Returns:
|
213 |
+
list of AudioMeta: List of audio file path and its total duration.
|
214 |
+
"""
|
215 |
+
open_fn = gzip.open if str(path).lower().endswith('.gz') else open
|
216 |
+
with open_fn(path, 'rb') as fp: # type: ignore
|
217 |
+
lines = fp.readlines()
|
218 |
+
meta = []
|
219 |
+
for line in lines:
|
220 |
+
d = json.loads(line)
|
221 |
+
m = AudioMeta.from_dict(d)
|
222 |
+
if resolve:
|
223 |
+
m = _resolve_audio_meta(m, fast=fast)
|
224 |
+
meta.append(m)
|
225 |
+
return meta
|
226 |
+
|
227 |
+
|
228 |
+
def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
|
229 |
+
"""Save the audio metadata to the file pointer as json.
|
230 |
+
|
231 |
+
Args:
|
232 |
+
path (str or Path): Path to JSON file.
|
233 |
+
metadata (list of BaseAudioMeta): List of audio meta to save.
|
234 |
+
"""
|
235 |
+
Path(path).parent.mkdir(exist_ok=True, parents=True)
|
236 |
+
open_fn = gzip.open if str(path).lower().endswith('.gz') else open
|
237 |
+
with open_fn(path, 'wb') as fp: # type: ignore
|
238 |
+
for m in meta:
|
239 |
+
json_str = json.dumps(m.to_dict()) + '\n'
|
240 |
+
json_bytes = json_str.encode('utf-8')
|
241 |
+
fp.write(json_bytes)
|
242 |
+
|
243 |
+
|
244 |
+
class AudioDataset:
|
245 |
+
"""Base audio dataset.
|
246 |
+
|
247 |
+
The dataset takes a list of AudioMeta and create a dataset composed of segments of audio
|
248 |
+
and potentially additional information, by creating random segments from the list of audio
|
249 |
+
files referenced in the metadata and applying minimal data pre-processing such as resampling,
|
250 |
+
mixing of channels, padding, etc.
|
251 |
+
|
252 |
+
If no segment_duration value is provided, the AudioDataset will return the full wav for each
|
253 |
+
audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
|
254 |
+
duration, applying padding if required.
|
255 |
+
|
256 |
+
By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
|
257 |
+
allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
|
258 |
+
original audio meta.
|
259 |
+
|
260 |
+
Note that you can call `start_epoch(epoch)` in order to get
|
261 |
+
a deterministic "randomization" for `shuffle=True`.
|
262 |
+
For a given epoch and dataset index, this will always return the same extract.
|
263 |
+
You can get back some diversity by setting the `shuffle_seed` param.
|
264 |
+
|
265 |
+
Args:
|
266 |
+
meta (list of AudioMeta): List of audio files metadata.
|
267 |
+
segment_duration (float, optional): Optional segment duration of audio to load.
|
268 |
+
If not specified, the dataset will load the full audio segment from the file.
|
269 |
+
shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
|
270 |
+
sample_rate (int): Target sample rate of the loaded audio samples.
|
271 |
+
channels (int): Target number of channels of the loaded audio samples.
|
272 |
+
sample_on_duration (bool): Set to `True` to sample segments with probability
|
273 |
+
dependent on audio file duration. This is only used if `segment_duration` is provided.
|
274 |
+
sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
|
275 |
+
`AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
|
276 |
+
of the file duration and file weight. This is only used if `segment_duration` is provided.
|
277 |
+
min_segment_ratio (float): Minimum segment ratio to use when the audio file
|
278 |
+
is shorter than the desired segment.
|
279 |
+
max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
|
280 |
+
return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
|
281 |
+
min_audio_duration (float, optional): Minimum audio file duration, in seconds, if provided
|
282 |
+
audio shorter than this will be filtered out.
|
283 |
+
max_audio_duration (float, optional): Maximal audio file duration in seconds, if provided
|
284 |
+
audio longer than this will be filtered out.
|
285 |
+
shuffle_seed (int): can be used to further randomize
|
286 |
+
load_wav (bool): if False, skip loading the wav but returns a tensor of 0
|
287 |
+
with the expected segment_duration (which must be provided if load_wav is False).
|
288 |
+
permutation_on_files (bool): only if `sample_on_weight` and `sample_on_duration`
|
289 |
+
are False. Will ensure a permutation on files when going through the dataset.
|
290 |
+
In that case the epoch number must be provided in order for the model
|
291 |
+
to continue the permutation across epochs. In that case, it is assumed
|
292 |
+
that `num_samples = total_batch_size * num_updates_per_epoch`, with
|
293 |
+
`total_batch_size` the overall batch size accounting for all gpus.
|
294 |
+
"""
|
295 |
+
def __init__(self,
|
296 |
+
meta: tp.List[AudioMeta],
|
297 |
+
segment_duration: tp.Optional[float] = None,
|
298 |
+
shuffle: bool = True,
|
299 |
+
num_samples: int = 10_000,
|
300 |
+
sample_rate: int = 48_000,
|
301 |
+
channels: int = 2,
|
302 |
+
pad: bool = True,
|
303 |
+
sample_on_duration: bool = True,
|
304 |
+
sample_on_weight: bool = True,
|
305 |
+
min_segment_ratio: float = 0.5,
|
306 |
+
max_read_retry: int = 10,
|
307 |
+
return_info: bool = False,
|
308 |
+
min_audio_duration: tp.Optional[float] = None,
|
309 |
+
max_audio_duration: tp.Optional[float] = None,
|
310 |
+
shuffle_seed: int = 0,
|
311 |
+
load_wav: bool = True,
|
312 |
+
permutation_on_files: bool = False,
|
313 |
+
):
|
314 |
+
assert len(meta) > 0, "No audio meta provided to AudioDataset. Please check loading of audio meta."
|
315 |
+
assert segment_duration is None or segment_duration > 0
|
316 |
+
assert segment_duration is None or min_segment_ratio >= 0
|
317 |
+
self.segment_duration = segment_duration
|
318 |
+
self.min_segment_ratio = min_segment_ratio
|
319 |
+
self.max_audio_duration = max_audio_duration
|
320 |
+
self.min_audio_duration = min_audio_duration
|
321 |
+
if self.min_audio_duration is not None and self.max_audio_duration is not None:
|
322 |
+
assert self.min_audio_duration <= self.max_audio_duration
|
323 |
+
self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
|
324 |
+
assert len(self.meta) # Fail fast if all data has been filtered.
|
325 |
+
self.total_duration = sum(d.duration for d in self.meta)
|
326 |
+
|
327 |
+
if segment_duration is None:
|
328 |
+
num_samples = len(self.meta)
|
329 |
+
self.num_samples = num_samples
|
330 |
+
self.shuffle = shuffle
|
331 |
+
self.sample_rate = sample_rate
|
332 |
+
self.channels = channels
|
333 |
+
self.pad = pad
|
334 |
+
self.sample_on_weight = sample_on_weight
|
335 |
+
self.sample_on_duration = sample_on_duration
|
336 |
+
self.sampling_probabilities = self._get_sampling_probabilities()
|
337 |
+
self.max_read_retry = max_read_retry
|
338 |
+
self.return_info = return_info
|
339 |
+
self.shuffle_seed = shuffle_seed
|
340 |
+
self.current_epoch: tp.Optional[int] = None
|
341 |
+
self.load_wav = load_wav
|
342 |
+
if not load_wav:
|
343 |
+
assert segment_duration is not None
|
344 |
+
self.permutation_on_files = permutation_on_files
|
345 |
+
if permutation_on_files:
|
346 |
+
assert not self.sample_on_duration
|
347 |
+
assert not self.sample_on_weight
|
348 |
+
assert self.shuffle
|
349 |
+
|
350 |
+
def start_epoch(self, epoch: int):
|
351 |
+
self.current_epoch = epoch
|
352 |
+
|
353 |
+
def __len__(self):
|
354 |
+
return self.num_samples
|
355 |
+
|
356 |
+
def _get_sampling_probabilities(self, normalized: bool = True):
|
357 |
+
"""Return the sampling probabilities for each file inside `self.meta`."""
|
358 |
+
scores: tp.List[float] = []
|
359 |
+
for file_meta in self.meta:
|
360 |
+
score = 1.
|
361 |
+
if self.sample_on_weight and file_meta.weight is not None:
|
362 |
+
score *= file_meta.weight
|
363 |
+
if self.sample_on_duration:
|
364 |
+
score *= file_meta.duration
|
365 |
+
scores.append(score)
|
366 |
+
probabilities = torch.tensor(scores)
|
367 |
+
if normalized:
|
368 |
+
probabilities /= probabilities.sum()
|
369 |
+
return probabilities
|
370 |
+
|
371 |
+
@staticmethod
|
372 |
+
@lru_cache(16)
|
373 |
+
def _get_file_permutation(num_files: int, permutation_index: int, base_seed: int):
|
374 |
+
# Used to keep the most recent files permutation in memory implicitely.
|
375 |
+
# will work unless someone is using a lot of Datasets in parallel.
|
376 |
+
rng = torch.Generator()
|
377 |
+
rng.manual_seed(base_seed + permutation_index)
|
378 |
+
return torch.randperm(num_files, generator=rng)
|
379 |
+
|
380 |
+
def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta:
|
381 |
+
"""Sample a given file from `self.meta`. Can be overridden in subclasses.
|
382 |
+
This is only called if `segment_duration` is not None.
|
383 |
+
|
384 |
+
You must use the provided random number generator `rng` for reproducibility.
|
385 |
+
You can further make use of the index accessed.
|
386 |
+
"""
|
387 |
+
if self.permutation_on_files:
|
388 |
+
assert self.current_epoch is not None
|
389 |
+
total_index = self.current_epoch * len(self) + index
|
390 |
+
permutation_index = total_index // len(self.meta)
|
391 |
+
relative_index = total_index % len(self.meta)
|
392 |
+
permutation = AudioDataset._get_file_permutation(
|
393 |
+
len(self.meta), permutation_index, self.shuffle_seed)
|
394 |
+
file_index = permutation[relative_index]
|
395 |
+
return self.meta[file_index]
|
396 |
+
|
397 |
+
if not self.sample_on_weight and not self.sample_on_duration:
|
398 |
+
file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
|
399 |
+
else:
|
400 |
+
file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())
|
401 |
+
|
402 |
+
return self.meta[file_index]
|
403 |
+
|
404 |
+
def _audio_read(self, path: str, seek_time: float = 0, duration: float = -1):
|
405 |
+
# Override this method in subclass if needed.
|
406 |
+
if self.load_wav:
|
407 |
+
return audio_read(path, seek_time, duration, pad=False)
|
408 |
+
else:
|
409 |
+
assert self.segment_duration is not None
|
410 |
+
n_frames = int(self.sample_rate * self.segment_duration)
|
411 |
+
return torch.zeros(self.channels, n_frames), self.sample_rate
|
412 |
+
|
413 |
+
def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
|
414 |
+
if self.segment_duration is None:
|
415 |
+
file_meta = self.meta[index]
|
416 |
+
out, sr = audio_read(file_meta.path)
|
417 |
+
out = convert_audio(out, sr, self.sample_rate, self.channels)
|
418 |
+
n_frames = out.shape[-1]
|
419 |
+
segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
|
420 |
+
sample_rate=self.sample_rate, channels=out.shape[0])
|
421 |
+
else:
|
422 |
+
rng = torch.Generator()
|
423 |
+
if self.shuffle:
|
424 |
+
# We use index, plus extra randomness, either totally random if we don't know the epoch.
|
425 |
+
# otherwise we make use of the epoch number and optional shuffle_seed.
|
426 |
+
if self.current_epoch is None:
|
427 |
+
rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
|
428 |
+
else:
|
429 |
+
rng.manual_seed(index + self.num_samples * (self.current_epoch + self.shuffle_seed))
|
430 |
+
else:
|
431 |
+
# We only use index
|
432 |
+
rng.manual_seed(index)
|
433 |
+
|
434 |
+
for retry in range(self.max_read_retry):
|
435 |
+
file_meta = self.sample_file(index, rng)
|
436 |
+
# We add some variance in the file position even if audio file is smaller than segment
|
437 |
+
# without ending up with empty segments
|
438 |
+
max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
|
439 |
+
seek_time = torch.rand(1, generator=rng).item() * max_seek
|
440 |
+
try:
|
441 |
+
out, sr = audio_read(file_meta.path, seek_time, self.segment_duration, pad=False)
|
442 |
+
out = convert_audio(out, sr, self.sample_rate, self.channels)
|
443 |
+
n_frames = out.shape[-1]
|
444 |
+
target_frames = int(self.segment_duration * self.sample_rate)
|
445 |
+
if self.pad:
|
446 |
+
out = F.pad(out, (0, target_frames - n_frames))
|
447 |
+
segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
|
448 |
+
sample_rate=self.sample_rate, channels=out.shape[0])
|
449 |
+
except Exception as exc:
|
450 |
+
logger.warning("Error opening file %s: %r", file_meta.path, exc)
|
451 |
+
if retry == self.max_read_retry - 1:
|
452 |
+
raise
|
453 |
+
else:
|
454 |
+
break
|
455 |
+
|
456 |
+
if self.return_info:
|
457 |
+
# Returns the wav and additional information on the wave segment
|
458 |
+
return out, segment_info
|
459 |
+
else:
|
460 |
+
return out
|
461 |
+
|
462 |
+
def collater(self, samples):
|
463 |
+
"""The collater function has to be provided to the dataloader
|
464 |
+
if AudioDataset has return_info=True in order to properly collate
|
465 |
+
the samples of a batch.
|
466 |
+
"""
|
467 |
+
if self.segment_duration is None and len(samples) > 1:
|
468 |
+
assert self.pad, "Must allow padding when batching examples of different durations."
|
469 |
+
|
470 |
+
# In this case the audio reaching the collater is of variable length as segment_duration=None.
|
471 |
+
to_pad = self.segment_duration is None and self.pad
|
472 |
+
if to_pad:
|
473 |
+
max_len = max([wav.shape[-1] for wav, _ in samples])
|
474 |
+
|
475 |
+
def _pad_wav(wav):
|
476 |
+
return F.pad(wav, (0, max_len - wav.shape[-1]))
|
477 |
+
|
478 |
+
if self.return_info:
|
479 |
+
if len(samples) > 0:
|
480 |
+
assert len(samples[0]) == 2
|
481 |
+
assert isinstance(samples[0][0], torch.Tensor)
|
482 |
+
assert isinstance(samples[0][1], SegmentInfo)
|
483 |
+
|
484 |
+
wavs = [wav for wav, _ in samples]
|
485 |
+
segment_infos = [copy.deepcopy(info) for _, info in samples]
|
486 |
+
|
487 |
+
if to_pad:
|
488 |
+
# Each wav could be of a different duration as they are not segmented.
|
489 |
+
for i in range(len(samples)):
|
490 |
+
# Determines the total length of the signal with padding, so we update here as we pad.
|
491 |
+
segment_infos[i].total_frames = max_len
|
492 |
+
wavs[i] = _pad_wav(wavs[i])
|
493 |
+
|
494 |
+
wav = torch.stack(wavs)
|
495 |
+
return wav, segment_infos
|
496 |
+
else:
|
497 |
+
assert isinstance(samples[0], torch.Tensor)
|
498 |
+
if to_pad:
|
499 |
+
samples = [_pad_wav(s) for s in samples]
|
500 |
+
return torch.stack(samples)
|
501 |
+
|
502 |
+
def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
|
503 |
+
"""Filters out audio files with audio durations that will not allow to sample examples from them."""
|
504 |
+
orig_len = len(meta)
|
505 |
+
|
506 |
+
# Filter data that is too short.
|
507 |
+
if self.min_audio_duration is not None:
|
508 |
+
meta = [m for m in meta if m.duration >= self.min_audio_duration]
|
509 |
+
|
510 |
+
# Filter data that is too long.
|
511 |
+
if self.max_audio_duration is not None:
|
512 |
+
meta = [m for m in meta if m.duration <= self.max_audio_duration]
|
513 |
+
|
514 |
+
filtered_len = len(meta)
|
515 |
+
removed_percentage = 100*(1-float(filtered_len)/orig_len)
|
516 |
+
msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
|
517 |
+
if removed_percentage < 10:
|
518 |
+
logging.debug(msg)
|
519 |
+
else:
|
520 |
+
logging.warning(msg)
|
521 |
+
return meta
|
522 |
+
|
523 |
+
@classmethod
|
524 |
+
def from_meta(cls, root: tp.Union[str, Path], **kwargs):
|
525 |
+
"""Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.
|
526 |
+
|
527 |
+
Args:
|
528 |
+
root (str or Path): Path to root folder containing audio files.
|
529 |
+
kwargs: Additional keyword arguments for the AudioDataset.
|
530 |
+
"""
|
531 |
+
root = Path(root)
|
532 |
+
if root.is_dir():
|
533 |
+
if (root / 'data.jsonl').exists():
|
534 |
+
root = root / 'data.jsonl'
|
535 |
+
elif (root / 'data.jsonl.gz').exists():
|
536 |
+
root = root / 'data.jsonl.gz'
|
537 |
+
else:
|
538 |
+
raise ValueError("Don't know where to read metadata from in the dir. "
|
539 |
+
"Expecting either a data.jsonl or data.jsonl.gz file but none found.")
|
540 |
+
meta = load_audio_meta(root)
|
541 |
+
return cls(meta, **kwargs)
|
542 |
+
|
543 |
+
@classmethod
|
544 |
+
def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
|
545 |
+
exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
|
546 |
+
"""Instantiate AudioDataset from a path containing (possibly nested) audio files.
|
547 |
+
|
548 |
+
Args:
|
549 |
+
root (str or Path): Path to root folder containing audio files.
|
550 |
+
minimal_meta (bool): Whether to only load minimal metadata or not.
|
551 |
+
exts (list of str): Extensions for audio files.
|
552 |
+
kwargs: Additional keyword arguments for the AudioDataset.
|
553 |
+
"""
|
554 |
+
root = Path(root)
|
555 |
+
if root.is_file():
|
556 |
+
meta = load_audio_meta(root, resolve=True)
|
557 |
+
else:
|
558 |
+
meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True)
|
559 |
+
return cls(meta, **kwargs)
|
560 |
+
|
561 |
+
|
562 |
+
def main():
|
563 |
+
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
|
564 |
+
parser = argparse.ArgumentParser(
|
565 |
+
prog='audio_dataset',
|
566 |
+
description='Generate .jsonl files by scanning a folder.')
|
567 |
+
parser.add_argument('root', help='Root folder with all the audio files')
|
568 |
+
parser.add_argument('output_meta_file',
|
569 |
+
help='Output file to store the metadata, ')
|
570 |
+
parser.add_argument('--complete',
|
571 |
+
action='store_false', dest='minimal', default=True,
|
572 |
+
help='Retrieve all metadata, even the one that are expansive '
|
573 |
+
'to compute (e.g. normalization).')
|
574 |
+
parser.add_argument('--resolve',
|
575 |
+
action='store_true', default=False,
|
576 |
+
help='Resolve the paths to be absolute and with no symlinks.')
|
577 |
+
parser.add_argument('--workers',
|
578 |
+
default=10, type=int,
|
579 |
+
help='Number of workers.')
|
580 |
+
args = parser.parse_args()
|
581 |
+
meta = find_audio_files(args.root, DEFAULT_EXTS, progress=True,
|
582 |
+
resolve=args.resolve, minimal=args.minimal, workers=args.workers)
|
583 |
+
save_audio_meta(args.output_meta_file, meta)
|
584 |
+
|
585 |
+
|
586 |
+
if __name__ == '__main__':
|
587 |
+
main()
|
audiocraft/data/audio_utils.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Various utilities for audio convertion (pcm format, sample rate and channels),
|
7 |
+
and volume normalization."""
|
8 |
+
import sys
|
9 |
+
import typing as tp
|
10 |
+
|
11 |
+
import julius
|
12 |
+
import torch
|
13 |
+
import torchaudio
|
14 |
+
|
15 |
+
|
16 |
+
def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
|
17 |
+
"""Convert audio to the given number of channels.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
wav (torch.Tensor): Audio wave of shape [B, C, T].
|
21 |
+
channels (int): Expected number of channels as output.
|
22 |
+
Returns:
|
23 |
+
torch.Tensor: Downmixed or unchanged audio wave [B, C, T].
|
24 |
+
"""
|
25 |
+
*shape, src_channels, length = wav.shape
|
26 |
+
if src_channels == channels:
|
27 |
+
pass
|
28 |
+
elif channels == 1:
|
29 |
+
# Case 1:
|
30 |
+
# The caller asked 1-channel audio, and the stream has multiple
|
31 |
+
# channels, downmix all channels.
|
32 |
+
wav = wav.mean(dim=-2, keepdim=True)
|
33 |
+
elif src_channels == 1:
|
34 |
+
# Case 2:
|
35 |
+
# The caller asked for multiple channels, but the input file has
|
36 |
+
# a single channel, replicate the audio over all channels.
|
37 |
+
wav = wav.expand(*shape, channels, length)
|
38 |
+
elif src_channels >= channels:
|
39 |
+
# Case 3:
|
40 |
+
# The caller asked for multiple channels, and the input file has
|
41 |
+
# more channels than requested. In that case return the first channels.
|
42 |
+
wav = wav[..., :channels, :]
|
43 |
+
else:
|
44 |
+
# Case 4: What is a reasonable choice here?
|
45 |
+
raise ValueError('The audio file has less channels than requested but is not mono.')
|
46 |
+
return wav
|
47 |
+
|
48 |
+
|
49 |
+
def convert_audio(wav: torch.Tensor, from_rate: float,
|
50 |
+
to_rate: float, to_channels: int) -> torch.Tensor:
|
51 |
+
"""Convert audio to new sample rate and number of audio channels."""
|
52 |
+
wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
|
53 |
+
wav = convert_audio_channels(wav, to_channels)
|
54 |
+
return wav
|
55 |
+
|
56 |
+
|
57 |
+
def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
|
58 |
+
loudness_compressor: bool = False, energy_floor: float = 2e-3):
|
59 |
+
"""Normalize an input signal to a user loudness in dB LKFS.
|
60 |
+
Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
wav (torch.Tensor): Input multichannel audio data.
|
64 |
+
sample_rate (int): Sample rate.
|
65 |
+
loudness_headroom_db (float): Target loudness of the output in dB LUFS.
|
66 |
+
loudness_compressor (bool): Uses tanh for soft clipping.
|
67 |
+
energy_floor (float): anything below that RMS level will not be rescaled.
|
68 |
+
Returns:
|
69 |
+
torch.Tensor: Loudness normalized output data.
|
70 |
+
"""
|
71 |
+
energy = wav.pow(2).mean().sqrt().item()
|
72 |
+
if energy < energy_floor:
|
73 |
+
return wav
|
74 |
+
transform = torchaudio.transforms.Loudness(sample_rate)
|
75 |
+
input_loudness_db = transform(wav).item()
|
76 |
+
# calculate the gain needed to scale to the desired loudness level
|
77 |
+
delta_loudness = -loudness_headroom_db - input_loudness_db
|
78 |
+
gain = 10.0 ** (delta_loudness / 20.0)
|
79 |
+
output = gain * wav
|
80 |
+
if loudness_compressor:
|
81 |
+
output = torch.tanh(output)
|
82 |
+
assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
|
83 |
+
return output
|
84 |
+
|
85 |
+
|
86 |
+
def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None:
|
87 |
+
"""Utility function to clip the audio with logging if specified."""
|
88 |
+
max_scale = wav.abs().max()
|
89 |
+
if log_clipping and max_scale > 1:
|
90 |
+
clamp_prob = (wav.abs() > 1).float().mean().item()
|
91 |
+
print(f"CLIPPING {stem_name or ''} happening with proba (a bit of clipping is okay):",
|
92 |
+
clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr)
|
93 |
+
wav.clamp_(-1, 1)
|
94 |
+
|
95 |
+
|
96 |
+
def normalize_audio(wav: torch.Tensor, normalize: bool = True,
|
97 |
+
strategy: str = 'peak', peak_clip_headroom_db: float = 1,
|
98 |
+
rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
|
99 |
+
loudness_compressor: bool = False, log_clipping: bool = False,
|
100 |
+
sample_rate: tp.Optional[int] = None,
|
101 |
+
stem_name: tp.Optional[str] = None) -> torch.Tensor:
|
102 |
+
"""Normalize the audio according to the prescribed strategy (see after).
|
103 |
+
|
104 |
+
Args:
|
105 |
+
wav (torch.Tensor): Audio data.
|
106 |
+
normalize (bool): if `True` (default), normalizes according to the prescribed
|
107 |
+
strategy (see after). If `False`, the strategy is only used in case clipping
|
108 |
+
would happen.
|
109 |
+
strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
|
110 |
+
i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
|
111 |
+
with extra headroom to avoid clipping. 'clip' just clips.
|
112 |
+
peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
|
113 |
+
rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
|
114 |
+
than the `peak_clip` one to avoid further clipping.
|
115 |
+
loudness_headroom_db (float): Target loudness for loudness normalization.
|
116 |
+
loudness_compressor (bool): If True, uses tanh based soft clipping.
|
117 |
+
log_clipping (bool): If True, basic logging on stderr when clipping still
|
118 |
+
occurs despite strategy (only for 'rms').
|
119 |
+
sample_rate (int): Sample rate for the audio data (required for loudness).
|
120 |
+
stem_name (str, optional): Stem name for clipping logging.
|
121 |
+
Returns:
|
122 |
+
torch.Tensor: Normalized audio.
|
123 |
+
"""
|
124 |
+
scale_peak = 10 ** (-peak_clip_headroom_db / 20)
|
125 |
+
scale_rms = 10 ** (-rms_headroom_db / 20)
|
126 |
+
if strategy == 'peak':
|
127 |
+
rescaling = (scale_peak / wav.abs().max())
|
128 |
+
if normalize or rescaling < 1:
|
129 |
+
wav = wav * rescaling
|
130 |
+
elif strategy == 'clip':
|
131 |
+
wav = wav.clamp(-scale_peak, scale_peak)
|
132 |
+
elif strategy == 'rms':
|
133 |
+
mono = wav.mean(dim=0)
|
134 |
+
rescaling = scale_rms / mono.pow(2).mean().sqrt()
|
135 |
+
if normalize or rescaling < 1:
|
136 |
+
wav = wav * rescaling
|
137 |
+
_clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
|
138 |
+
elif strategy == 'loudness':
|
139 |
+
assert sample_rate is not None, "Loudness normalization requires sample rate."
|
140 |
+
wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
|
141 |
+
_clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
|
142 |
+
else:
|
143 |
+
assert wav.abs().max() < 1
|
144 |
+
assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
|
145 |
+
return wav
|
146 |
+
|
147 |
+
|
148 |
+
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
|
149 |
+
"""Convert audio to float 32 bits PCM format.
|
150 |
+
"""
|
151 |
+
if wav.dtype.is_floating_point:
|
152 |
+
return wav
|
153 |
+
elif wav.dtype == torch.int16:
|
154 |
+
return wav.float() / 2**15
|
155 |
+
elif wav.dtype == torch.int32:
|
156 |
+
return wav.float() / 2**31
|
157 |
+
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
|
158 |
+
|
159 |
+
|
160 |
+
def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
|
161 |
+
"""Convert audio to int 16 bits PCM format.
|
162 |
+
|
163 |
+
..Warning:: There exist many formula for doing this conversion. None are perfect
|
164 |
+
due to the asymmetry of the int16 range. One either have possible clipping, DC offset,
|
165 |
+
or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
|
166 |
+
it is possible that `i16_pcm(f32_pcm)) != Identity`.
|
167 |
+
"""
|
168 |
+
if wav.dtype.is_floating_point:
|
169 |
+
assert wav.abs().max() <= 1
|
170 |
+
candidate = (wav * 2 ** 15).round()
|
171 |
+
if candidate.max() >= 2 ** 15: # clipping would occur
|
172 |
+
candidate = (wav * (2 ** 15 - 1)).round()
|
173 |
+
return candidate.short()
|
174 |
+
else:
|
175 |
+
assert wav.dtype == torch.int16
|
176 |
+
return wav
|
audiocraft/data/info_audio_dataset.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Base classes for the datasets that also provide non-audio metadata,
|
7 |
+
e.g. description, text transcription etc.
|
8 |
+
"""
|
9 |
+
from dataclasses import dataclass
|
10 |
+
import logging
|
11 |
+
import math
|
12 |
+
import re
|
13 |
+
import typing as tp
|
14 |
+
|
15 |
+
import torch
|
16 |
+
|
17 |
+
from .audio_dataset import AudioDataset, AudioMeta
|
18 |
+
from ..environment import AudioCraftEnvironment
|
19 |
+
from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes
|
20 |
+
|
21 |
+
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
|
26 |
+
"""Monkey-patch meta to match cluster specificities."""
|
27 |
+
meta.path = AudioCraftEnvironment.apply_dataset_mappers(meta.path)
|
28 |
+
if meta.info_path is not None:
|
29 |
+
meta.info_path.zip_path = AudioCraftEnvironment.apply_dataset_mappers(meta.info_path.zip_path)
|
30 |
+
return meta
|
31 |
+
|
32 |
+
|
33 |
+
def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
|
34 |
+
"""Monkey-patch all meta to match cluster specificities."""
|
35 |
+
return [_clusterify_meta(m) for m in meta]
|
36 |
+
|
37 |
+
|
38 |
+
@dataclass
|
39 |
+
class AudioInfo(SegmentWithAttributes):
|
40 |
+
"""Dummy SegmentInfo with empty attributes.
|
41 |
+
|
42 |
+
The InfoAudioDataset is expected to return metadata that inherits
|
43 |
+
from SegmentWithAttributes class and can return conditioning attributes.
|
44 |
+
|
45 |
+
This basically guarantees all datasets will be compatible with current
|
46 |
+
solver that contain conditioners requiring this.
|
47 |
+
"""
|
48 |
+
audio_tokens: tp.Optional[torch.Tensor] = None # populated when using cached batch for training a LM.
|
49 |
+
|
50 |
+
def to_condition_attributes(self) -> ConditioningAttributes:
|
51 |
+
return ConditioningAttributes()
|
52 |
+
|
53 |
+
|
54 |
+
class InfoAudioDataset(AudioDataset):
|
55 |
+
"""AudioDataset that always returns metadata as SegmentWithAttributes along with the audio waveform.
|
56 |
+
|
57 |
+
See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments.
|
58 |
+
"""
|
59 |
+
def __init__(self, meta: tp.List[AudioMeta], **kwargs):
|
60 |
+
super().__init__(clusterify_all_meta(meta), **kwargs)
|
61 |
+
|
62 |
+
def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
|
63 |
+
if not self.return_info:
|
64 |
+
wav = super().__getitem__(index)
|
65 |
+
assert isinstance(wav, torch.Tensor)
|
66 |
+
return wav
|
67 |
+
wav, meta = super().__getitem__(index)
|
68 |
+
return wav, AudioInfo(**meta.to_dict())
|
69 |
+
|
70 |
+
|
71 |
+
def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]:
|
72 |
+
"""Preprocess a single keyword or possible a list of keywords."""
|
73 |
+
if isinstance(value, list):
|
74 |
+
return get_keyword_list(value)
|
75 |
+
else:
|
76 |
+
return get_keyword(value)
|
77 |
+
|
78 |
+
|
79 |
+
def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
|
80 |
+
"""Preprocess a single keyword."""
|
81 |
+
if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
|
82 |
+
return None
|
83 |
+
else:
|
84 |
+
return value.strip()
|
85 |
+
|
86 |
+
|
87 |
+
def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
|
88 |
+
"""Preprocess a single keyword."""
|
89 |
+
if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
|
90 |
+
return None
|
91 |
+
else:
|
92 |
+
return value.strip().lower()
|
93 |
+
|
94 |
+
|
95 |
+
def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]:
|
96 |
+
"""Preprocess a list of keywords."""
|
97 |
+
if isinstance(values, str):
|
98 |
+
values = [v.strip() for v in re.split(r'[,\s]', values)]
|
99 |
+
elif isinstance(values, float) and math.isnan(values):
|
100 |
+
values = []
|
101 |
+
if not isinstance(values, list):
|
102 |
+
logger.debug(f"Unexpected keyword list {values}")
|
103 |
+
values = [str(values)]
|
104 |
+
|
105 |
+
kws = [get_keyword(v) for v in values]
|
106 |
+
kw_list = [k for k in kws if k is not None]
|
107 |
+
if len(kw_list) == 0:
|
108 |
+
return None
|
109 |
+
else:
|
110 |
+
return kw_list
|
audiocraft/data/music_dataset.py
ADDED
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Dataset of music tracks with rich metadata.
|
7 |
+
"""
|
8 |
+
from dataclasses import dataclass, field, fields, replace
|
9 |
+
import gzip
|
10 |
+
import json
|
11 |
+
import logging
|
12 |
+
from pathlib import Path
|
13 |
+
import random
|
14 |
+
import typing as tp
|
15 |
+
|
16 |
+
import torch
|
17 |
+
|
18 |
+
from .info_audio_dataset import (
|
19 |
+
InfoAudioDataset,
|
20 |
+
AudioInfo,
|
21 |
+
get_keyword_list,
|
22 |
+
get_keyword,
|
23 |
+
get_string
|
24 |
+
)
|
25 |
+
from ..modules.conditioners import (
|
26 |
+
ConditioningAttributes,
|
27 |
+
JointEmbedCondition,
|
28 |
+
WavCondition,
|
29 |
+
)
|
30 |
+
from ..utils.utils import warn_once
|
31 |
+
|
32 |
+
|
33 |
+
logger = logging.getLogger(__name__)
|
34 |
+
|
35 |
+
|
36 |
+
@dataclass
|
37 |
+
class MusicInfo(AudioInfo):
|
38 |
+
"""Segment info augmented with music metadata.
|
39 |
+
"""
|
40 |
+
# music-specific metadata
|
41 |
+
title: tp.Optional[str] = None
|
42 |
+
artist: tp.Optional[str] = None # anonymized artist id, used to ensure no overlap between splits
|
43 |
+
key: tp.Optional[str] = None
|
44 |
+
bpm: tp.Optional[float] = None
|
45 |
+
genre: tp.Optional[str] = None
|
46 |
+
moods: tp.Optional[list] = None
|
47 |
+
keywords: tp.Optional[list] = None
|
48 |
+
description: tp.Optional[str] = None
|
49 |
+
name: tp.Optional[str] = None
|
50 |
+
instrument: tp.Optional[str] = None
|
51 |
+
# original wav accompanying the metadata
|
52 |
+
self_wav: tp.Optional[WavCondition] = None
|
53 |
+
# dict mapping attributes names to tuple of wav, text and metadata
|
54 |
+
joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
|
55 |
+
|
56 |
+
@property
|
57 |
+
def has_music_meta(self) -> bool:
|
58 |
+
return self.name is not None
|
59 |
+
|
60 |
+
def to_condition_attributes(self) -> ConditioningAttributes:
|
61 |
+
out = ConditioningAttributes()
|
62 |
+
for _field in fields(self):
|
63 |
+
key, value = _field.name, getattr(self, _field.name)
|
64 |
+
if key == 'self_wav':
|
65 |
+
out.wav[key] = value
|
66 |
+
elif key == 'joint_embed':
|
67 |
+
for embed_attribute, embed_cond in value.items():
|
68 |
+
out.joint_embed[embed_attribute] = embed_cond
|
69 |
+
else:
|
70 |
+
if isinstance(value, list):
|
71 |
+
value = ' '.join(value)
|
72 |
+
out.text[key] = value
|
73 |
+
return out
|
74 |
+
|
75 |
+
@staticmethod
|
76 |
+
def attribute_getter(attribute):
|
77 |
+
if attribute == 'bpm':
|
78 |
+
preprocess_func = get_bpm
|
79 |
+
elif attribute == 'key':
|
80 |
+
preprocess_func = get_musical_key
|
81 |
+
elif attribute in ['moods', 'keywords']:
|
82 |
+
preprocess_func = get_keyword_list
|
83 |
+
elif attribute in ['genre', 'name', 'instrument']:
|
84 |
+
preprocess_func = get_keyword
|
85 |
+
elif attribute in ['title', 'artist', 'description']:
|
86 |
+
preprocess_func = get_string
|
87 |
+
else:
|
88 |
+
preprocess_func = None
|
89 |
+
return preprocess_func
|
90 |
+
|
91 |
+
@classmethod
|
92 |
+
def from_dict(cls, dictionary: dict, fields_required: bool = False):
|
93 |
+
_dictionary: tp.Dict[str, tp.Any] = {}
|
94 |
+
|
95 |
+
# allow a subset of attributes to not be loaded from the dictionary
|
96 |
+
# these attributes may be populated later
|
97 |
+
post_init_attributes = ['self_wav', 'joint_embed']
|
98 |
+
optional_fields = ['keywords']
|
99 |
+
|
100 |
+
for _field in fields(cls):
|
101 |
+
if _field.name in post_init_attributes:
|
102 |
+
continue
|
103 |
+
elif _field.name not in dictionary:
|
104 |
+
if fields_required and _field.name not in optional_fields:
|
105 |
+
raise KeyError(f"Unexpected missing key: {_field.name}")
|
106 |
+
else:
|
107 |
+
preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
|
108 |
+
value = dictionary[_field.name]
|
109 |
+
if preprocess_func:
|
110 |
+
value = preprocess_func(value)
|
111 |
+
_dictionary[_field.name] = value
|
112 |
+
return cls(**_dictionary)
|
113 |
+
|
114 |
+
|
115 |
+
def augment_music_info_description(music_info: MusicInfo, merge_text_p: float = 0.,
|
116 |
+
drop_desc_p: float = 0., drop_other_p: float = 0.) -> MusicInfo:
|
117 |
+
"""Augment MusicInfo description with additional metadata fields and potential dropout.
|
118 |
+
Additional textual attributes are added given probability 'merge_text_conditions_p' and
|
119 |
+
the original textual description is dropped from the augmented description given probability drop_desc_p.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
music_info (MusicInfo): The music metadata to augment.
|
123 |
+
merge_text_p (float): Probability of merging additional metadata to the description.
|
124 |
+
If provided value is 0, then no merging is performed.
|
125 |
+
drop_desc_p (float): Probability of dropping the original description on text merge.
|
126 |
+
if provided value is 0, then no drop out is performed.
|
127 |
+
drop_other_p (float): Probability of dropping the other fields used for text augmentation.
|
128 |
+
Returns:
|
129 |
+
MusicInfo: The MusicInfo with augmented textual description.
|
130 |
+
"""
|
131 |
+
def is_valid_field(field_name: str, field_value: tp.Any) -> bool:
|
132 |
+
valid_field_name = field_name in ['key', 'bpm', 'genre', 'moods', 'instrument', 'keywords']
|
133 |
+
valid_field_value = field_value is not None and isinstance(field_value, (int, float, str, list))
|
134 |
+
keep_field = random.uniform(0, 1) < drop_other_p
|
135 |
+
return valid_field_name and valid_field_value and keep_field
|
136 |
+
|
137 |
+
def process_value(v: tp.Any) -> str:
|
138 |
+
if isinstance(v, (int, float, str)):
|
139 |
+
return str(v)
|
140 |
+
if isinstance(v, list):
|
141 |
+
return ", ".join(v)
|
142 |
+
else:
|
143 |
+
raise ValueError(f"Unknown type for text value! ({type(v), v})")
|
144 |
+
|
145 |
+
description = music_info.description
|
146 |
+
|
147 |
+
metadata_text = ""
|
148 |
+
if random.uniform(0, 1) < merge_text_p:
|
149 |
+
meta_pairs = [f'{_field.name}: {process_value(getattr(music_info, _field.name))}'
|
150 |
+
for _field in fields(music_info) if is_valid_field(_field.name, getattr(music_info, _field.name))]
|
151 |
+
random.shuffle(meta_pairs)
|
152 |
+
metadata_text = ". ".join(meta_pairs)
|
153 |
+
description = description if not random.uniform(0, 1) < drop_desc_p else None
|
154 |
+
logger.debug(f"Applying text augmentation on MMI info. description: {description}, metadata: {metadata_text}")
|
155 |
+
|
156 |
+
if description is None:
|
157 |
+
description = metadata_text if len(metadata_text) > 1 else None
|
158 |
+
else:
|
159 |
+
description = ". ".join([description.rstrip('.'), metadata_text])
|
160 |
+
description = description.strip() if description else None
|
161 |
+
|
162 |
+
music_info = replace(music_info)
|
163 |
+
music_info.description = description
|
164 |
+
return music_info
|
165 |
+
|
166 |
+
|
167 |
+
class Paraphraser:
|
168 |
+
def __init__(self, paraphrase_source: tp.Union[str, Path], paraphrase_p: float = 0.):
|
169 |
+
self.paraphrase_p = paraphrase_p
|
170 |
+
open_fn = gzip.open if str(paraphrase_source).lower().endswith('.gz') else open
|
171 |
+
with open_fn(paraphrase_source, 'rb') as f: # type: ignore
|
172 |
+
self.paraphrase_source = json.loads(f.read())
|
173 |
+
logger.info(f"loaded paraphrasing source from: {paraphrase_source}")
|
174 |
+
|
175 |
+
def sample_paraphrase(self, audio_path: str, description: str):
|
176 |
+
if random.random() >= self.paraphrase_p:
|
177 |
+
return description
|
178 |
+
info_path = Path(audio_path).with_suffix('.json')
|
179 |
+
if info_path not in self.paraphrase_source:
|
180 |
+
warn_once(logger, f"{info_path} not in paraphrase source!")
|
181 |
+
return description
|
182 |
+
new_desc = random.choice(self.paraphrase_source[info_path])
|
183 |
+
logger.debug(f"{description} -> {new_desc}")
|
184 |
+
return new_desc
|
185 |
+
|
186 |
+
|
187 |
+
class MusicDataset(InfoAudioDataset):
|
188 |
+
"""Music dataset is an AudioDataset with music-related metadata.
|
189 |
+
|
190 |
+
Args:
|
191 |
+
info_fields_required (bool): Whether to enforce having required fields.
|
192 |
+
merge_text_p (float): Probability of merging additional metadata to the description.
|
193 |
+
drop_desc_p (float): Probability of dropping the original description on text merge.
|
194 |
+
drop_other_p (float): Probability of dropping the other fields used for text augmentation.
|
195 |
+
joint_embed_attributes (list[str]): A list of attributes for which joint embedding metadata is returned.
|
196 |
+
paraphrase_source (str, optional): Path to the .json or .json.gz file containing the
|
197 |
+
paraphrases for the description. The json should be a dict with keys are the
|
198 |
+
original info path (e.g. track_path.json) and each value is a list of possible
|
199 |
+
paraphrased.
|
200 |
+
paraphrase_p (float): probability of taking a paraphrase.
|
201 |
+
|
202 |
+
See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
|
203 |
+
"""
|
204 |
+
def __init__(self, *args, info_fields_required: bool = True,
|
205 |
+
merge_text_p: float = 0., drop_desc_p: float = 0., drop_other_p: float = 0.,
|
206 |
+
joint_embed_attributes: tp.List[str] = [],
|
207 |
+
paraphrase_source: tp.Optional[str] = None, paraphrase_p: float = 0,
|
208 |
+
**kwargs):
|
209 |
+
kwargs['return_info'] = True # We require the info for each song of the dataset.
|
210 |
+
super().__init__(*args, **kwargs)
|
211 |
+
self.info_fields_required = info_fields_required
|
212 |
+
self.merge_text_p = merge_text_p
|
213 |
+
self.drop_desc_p = drop_desc_p
|
214 |
+
self.drop_other_p = drop_other_p
|
215 |
+
self.joint_embed_attributes = joint_embed_attributes
|
216 |
+
self.paraphraser = None
|
217 |
+
if paraphrase_source is not None:
|
218 |
+
self.paraphraser = Paraphraser(paraphrase_source, paraphrase_p)
|
219 |
+
|
220 |
+
def __getitem__(self, index):
|
221 |
+
wav, info = super().__getitem__(index)
|
222 |
+
info_data = info.to_dict()
|
223 |
+
music_info_path = Path(info.meta.path).with_suffix('.json')
|
224 |
+
|
225 |
+
if Path(music_info_path).exists():
|
226 |
+
with open(music_info_path, 'r') as json_file:
|
227 |
+
music_data = json.load(json_file)
|
228 |
+
music_data.update(info_data)
|
229 |
+
music_info = MusicInfo.from_dict(music_data, fields_required=self.info_fields_required)
|
230 |
+
if self.paraphraser is not None:
|
231 |
+
music_info.description = self.paraphraser.sample(music_info.meta.path, music_info.description)
|
232 |
+
if self.merge_text_p:
|
233 |
+
music_info = augment_music_info_description(
|
234 |
+
music_info, self.merge_text_p, self.drop_desc_p, self.drop_other_p)
|
235 |
+
else:
|
236 |
+
music_info = MusicInfo.from_dict(info_data, fields_required=False)
|
237 |
+
|
238 |
+
music_info.self_wav = WavCondition(
|
239 |
+
wav=wav[None], length=torch.tensor([info.n_frames]),
|
240 |
+
sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
|
241 |
+
|
242 |
+
for att in self.joint_embed_attributes:
|
243 |
+
att_value = getattr(music_info, att)
|
244 |
+
joint_embed_cond = JointEmbedCondition(
|
245 |
+
wav[None], [att_value], torch.tensor([info.n_frames]),
|
246 |
+
sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
|
247 |
+
music_info.joint_embed[att] = joint_embed_cond
|
248 |
+
|
249 |
+
return wav, music_info
|
250 |
+
|
251 |
+
|
252 |
+
def get_musical_key(value: tp.Optional[str]) -> tp.Optional[str]:
|
253 |
+
"""Preprocess key keywords, discarding them if there are multiple key defined."""
|
254 |
+
if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
|
255 |
+
return None
|
256 |
+
elif ',' in value:
|
257 |
+
# For now, we discard when multiple keys are defined separated with comas
|
258 |
+
return None
|
259 |
+
else:
|
260 |
+
return value.strip().lower()
|
261 |
+
|
262 |
+
|
263 |
+
def get_bpm(value: tp.Optional[str]) -> tp.Optional[float]:
|
264 |
+
"""Preprocess to a float."""
|
265 |
+
if value is None:
|
266 |
+
return None
|
267 |
+
try:
|
268 |
+
return float(value)
|
269 |
+
except ValueError:
|
270 |
+
return None
|
audiocraft/data/sound_dataset.py
ADDED
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Dataset of audio with a simple description.
|
7 |
+
"""
|
8 |
+
|
9 |
+
from dataclasses import dataclass, fields, replace
|
10 |
+
import json
|
11 |
+
from pathlib import Path
|
12 |
+
import random
|
13 |
+
import typing as tp
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
|
18 |
+
from .info_audio_dataset import (
|
19 |
+
InfoAudioDataset,
|
20 |
+
get_keyword_or_keyword_list
|
21 |
+
)
|
22 |
+
from ..modules.conditioners import (
|
23 |
+
ConditioningAttributes,
|
24 |
+
SegmentWithAttributes,
|
25 |
+
WavCondition,
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
EPS = torch.finfo(torch.float32).eps
|
30 |
+
TARGET_LEVEL_LOWER = -35
|
31 |
+
TARGET_LEVEL_UPPER = -15
|
32 |
+
|
33 |
+
|
34 |
+
@dataclass
|
35 |
+
class SoundInfo(SegmentWithAttributes):
|
36 |
+
"""Segment info augmented with Sound metadata.
|
37 |
+
"""
|
38 |
+
description: tp.Optional[str] = None
|
39 |
+
self_wav: tp.Optional[torch.Tensor] = None
|
40 |
+
|
41 |
+
@property
|
42 |
+
def has_sound_meta(self) -> bool:
|
43 |
+
return self.description is not None
|
44 |
+
|
45 |
+
def to_condition_attributes(self) -> ConditioningAttributes:
|
46 |
+
out = ConditioningAttributes()
|
47 |
+
|
48 |
+
for _field in fields(self):
|
49 |
+
key, value = _field.name, getattr(self, _field.name)
|
50 |
+
if key == 'self_wav':
|
51 |
+
out.wav[key] = value
|
52 |
+
else:
|
53 |
+
out.text[key] = value
|
54 |
+
return out
|
55 |
+
|
56 |
+
@staticmethod
|
57 |
+
def attribute_getter(attribute):
|
58 |
+
if attribute == 'description':
|
59 |
+
preprocess_func = get_keyword_or_keyword_list
|
60 |
+
else:
|
61 |
+
preprocess_func = None
|
62 |
+
return preprocess_func
|
63 |
+
|
64 |
+
@classmethod
|
65 |
+
def from_dict(cls, dictionary: dict, fields_required: bool = False):
|
66 |
+
_dictionary: tp.Dict[str, tp.Any] = {}
|
67 |
+
|
68 |
+
# allow a subset of attributes to not be loaded from the dictionary
|
69 |
+
# these attributes may be populated later
|
70 |
+
post_init_attributes = ['self_wav']
|
71 |
+
|
72 |
+
for _field in fields(cls):
|
73 |
+
if _field.name in post_init_attributes:
|
74 |
+
continue
|
75 |
+
elif _field.name not in dictionary:
|
76 |
+
if fields_required:
|
77 |
+
raise KeyError(f"Unexpected missing key: {_field.name}")
|
78 |
+
else:
|
79 |
+
preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
|
80 |
+
value = dictionary[_field.name]
|
81 |
+
if preprocess_func:
|
82 |
+
value = preprocess_func(value)
|
83 |
+
_dictionary[_field.name] = value
|
84 |
+
return cls(**_dictionary)
|
85 |
+
|
86 |
+
|
87 |
+
class SoundDataset(InfoAudioDataset):
|
88 |
+
"""Sound audio dataset: Audio dataset with environmental sound-specific metadata.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
info_fields_required (bool): Whether all the mandatory metadata fields should be in the loaded metadata.
|
92 |
+
external_metadata_source (tp.Optional[str]): Folder containing JSON metadata for the corresponding dataset.
|
93 |
+
The metadata files contained in this folder are expected to match the stem of the audio file with
|
94 |
+
a json extension.
|
95 |
+
aug_p (float): Probability of performing audio mixing augmentation on the batch.
|
96 |
+
mix_p (float): Proportion of batch items that are mixed together when applying audio mixing augmentation.
|
97 |
+
mix_snr_low (int): Lowerbound for SNR value sampled for mixing augmentation.
|
98 |
+
mix_snr_high (int): Upperbound for SNR value sampled for mixing augmentation.
|
99 |
+
mix_min_overlap (float): Minimum overlap between audio files when performing mixing augmentation.
|
100 |
+
kwargs: Additional arguments for AudioDataset.
|
101 |
+
|
102 |
+
See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
|
103 |
+
"""
|
104 |
+
def __init__(
|
105 |
+
self,
|
106 |
+
*args,
|
107 |
+
info_fields_required: bool = True,
|
108 |
+
external_metadata_source: tp.Optional[str] = None,
|
109 |
+
aug_p: float = 0.,
|
110 |
+
mix_p: float = 0.,
|
111 |
+
mix_snr_low: int = -5,
|
112 |
+
mix_snr_high: int = 5,
|
113 |
+
mix_min_overlap: float = 0.5,
|
114 |
+
**kwargs
|
115 |
+
):
|
116 |
+
kwargs['return_info'] = True # We require the info for each song of the dataset.
|
117 |
+
super().__init__(*args, **kwargs)
|
118 |
+
self.info_fields_required = info_fields_required
|
119 |
+
self.external_metadata_source = external_metadata_source
|
120 |
+
self.aug_p = aug_p
|
121 |
+
self.mix_p = mix_p
|
122 |
+
if self.aug_p > 0:
|
123 |
+
assert self.mix_p > 0, "Expecting some mixing proportion mix_p if aug_p > 0"
|
124 |
+
assert self.channels == 1, "SoundDataset with audio mixing considers only monophonic audio"
|
125 |
+
self.mix_snr_low = mix_snr_low
|
126 |
+
self.mix_snr_high = mix_snr_high
|
127 |
+
self.mix_min_overlap = mix_min_overlap
|
128 |
+
|
129 |
+
def _get_info_path(self, path: tp.Union[str, Path]) -> Path:
|
130 |
+
"""Get path of JSON with metadata (description, etc.).
|
131 |
+
If there exists a JSON with the same name as 'path.name', then it will be used.
|
132 |
+
Else, such JSON will be searched for in an external json source folder if it exists.
|
133 |
+
"""
|
134 |
+
info_path = Path(path).with_suffix('.json')
|
135 |
+
if Path(info_path).exists():
|
136 |
+
return info_path
|
137 |
+
elif self.external_metadata_source and (Path(self.external_metadata_source) / info_path.name).exists():
|
138 |
+
return Path(self.external_metadata_source) / info_path.name
|
139 |
+
else:
|
140 |
+
raise Exception(f"Unable to find a metadata JSON for path: {path}")
|
141 |
+
|
142 |
+
def __getitem__(self, index):
|
143 |
+
wav, info = super().__getitem__(index)
|
144 |
+
info_data = info.to_dict()
|
145 |
+
info_path = self._get_info_path(info.meta.path)
|
146 |
+
if Path(info_path).exists():
|
147 |
+
with open(info_path, 'r') as json_file:
|
148 |
+
sound_data = json.load(json_file)
|
149 |
+
sound_data.update(info_data)
|
150 |
+
sound_info = SoundInfo.from_dict(sound_data, fields_required=self.info_fields_required)
|
151 |
+
# if there are multiple descriptions, sample one randomly
|
152 |
+
if isinstance(sound_info.description, list):
|
153 |
+
sound_info.description = random.choice(sound_info.description)
|
154 |
+
else:
|
155 |
+
sound_info = SoundInfo.from_dict(info_data, fields_required=False)
|
156 |
+
|
157 |
+
sound_info.self_wav = WavCondition(
|
158 |
+
wav=wav[None], length=torch.tensor([info.n_frames]),
|
159 |
+
sample_rate=[sound_info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
|
160 |
+
|
161 |
+
return wav, sound_info
|
162 |
+
|
163 |
+
def collater(self, samples):
|
164 |
+
# when training, audio mixing is performed in the collate function
|
165 |
+
wav, sound_info = super().collater(samples) # SoundDataset always returns infos
|
166 |
+
if self.aug_p > 0:
|
167 |
+
wav, sound_info = mix_samples(wav, sound_info, self.aug_p, self.mix_p,
|
168 |
+
snr_low=self.mix_snr_low, snr_high=self.mix_snr_high,
|
169 |
+
min_overlap=self.mix_min_overlap)
|
170 |
+
return wav, sound_info
|
171 |
+
|
172 |
+
|
173 |
+
def rms_f(x: torch.Tensor) -> torch.Tensor:
|
174 |
+
return (x ** 2).mean(1).pow(0.5)
|
175 |
+
|
176 |
+
|
177 |
+
def normalize(audio: torch.Tensor, target_level: int = -25) -> torch.Tensor:
|
178 |
+
"""Normalize the signal to the target level."""
|
179 |
+
rms = rms_f(audio)
|
180 |
+
scalar = 10 ** (target_level / 20) / (rms + EPS)
|
181 |
+
audio = audio * scalar.unsqueeze(1)
|
182 |
+
return audio
|
183 |
+
|
184 |
+
|
185 |
+
def is_clipped(audio: torch.Tensor, clipping_threshold: float = 0.99) -> torch.Tensor:
|
186 |
+
return (abs(audio) > clipping_threshold).any(1)
|
187 |
+
|
188 |
+
|
189 |
+
def mix_pair(src: torch.Tensor, dst: torch.Tensor, min_overlap: float) -> torch.Tensor:
|
190 |
+
start = random.randint(0, int(src.shape[1] * (1 - min_overlap)))
|
191 |
+
remainder = src.shape[1] - start
|
192 |
+
if dst.shape[1] > remainder:
|
193 |
+
src[:, start:] = src[:, start:] + dst[:, :remainder]
|
194 |
+
else:
|
195 |
+
src[:, start:start+dst.shape[1]] = src[:, start:start+dst.shape[1]] + dst
|
196 |
+
return src
|
197 |
+
|
198 |
+
|
199 |
+
def snr_mixer(clean: torch.Tensor, noise: torch.Tensor, snr: int, min_overlap: float,
|
200 |
+
target_level: int = -25, clipping_threshold: float = 0.99) -> torch.Tensor:
|
201 |
+
"""Function to mix clean speech and noise at various SNR levels.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
clean (torch.Tensor): Clean audio source to mix, of shape [B, T].
|
205 |
+
noise (torch.Tensor): Noise audio source to mix, of shape [B, T].
|
206 |
+
snr (int): SNR level when mixing.
|
207 |
+
min_overlap (float): Minimum overlap between the two mixed sources.
|
208 |
+
target_level (int): Gain level in dB.
|
209 |
+
clipping_threshold (float): Threshold for clipping the audio.
|
210 |
+
Returns:
|
211 |
+
torch.Tensor: The mixed audio, of shape [B, T].
|
212 |
+
"""
|
213 |
+
if clean.shape[1] > noise.shape[1]:
|
214 |
+
noise = torch.nn.functional.pad(noise, (0, clean.shape[1] - noise.shape[1]))
|
215 |
+
else:
|
216 |
+
noise = noise[:, :clean.shape[1]]
|
217 |
+
|
218 |
+
# normalizing to -25 dB FS
|
219 |
+
clean = clean / (clean.max(1)[0].abs().unsqueeze(1) + EPS)
|
220 |
+
clean = normalize(clean, target_level)
|
221 |
+
rmsclean = rms_f(clean)
|
222 |
+
|
223 |
+
noise = noise / (noise.max(1)[0].abs().unsqueeze(1) + EPS)
|
224 |
+
noise = normalize(noise, target_level)
|
225 |
+
rmsnoise = rms_f(noise)
|
226 |
+
|
227 |
+
# set the noise level for a given SNR
|
228 |
+
noisescalar = (rmsclean / (10 ** (snr / 20)) / (rmsnoise + EPS)).unsqueeze(1)
|
229 |
+
noisenewlevel = noise * noisescalar
|
230 |
+
|
231 |
+
# mix noise and clean speech
|
232 |
+
noisyspeech = mix_pair(clean, noisenewlevel, min_overlap)
|
233 |
+
|
234 |
+
# randomly select RMS value between -15 dBFS and -35 dBFS and normalize noisyspeech with that value
|
235 |
+
# there is a chance of clipping that might happen with very less probability, which is not a major issue.
|
236 |
+
noisy_rms_level = np.random.randint(TARGET_LEVEL_LOWER, TARGET_LEVEL_UPPER)
|
237 |
+
rmsnoisy = rms_f(noisyspeech)
|
238 |
+
scalarnoisy = (10 ** (noisy_rms_level / 20) / (rmsnoisy + EPS)).unsqueeze(1)
|
239 |
+
noisyspeech = noisyspeech * scalarnoisy
|
240 |
+
clean = clean * scalarnoisy
|
241 |
+
noisenewlevel = noisenewlevel * scalarnoisy
|
242 |
+
|
243 |
+
# final check to see if there are any amplitudes exceeding +/- 1. If so, normalize all the signals accordingly
|
244 |
+
clipped = is_clipped(noisyspeech)
|
245 |
+
if clipped.any():
|
246 |
+
noisyspeech_maxamplevel = noisyspeech[clipped].max(1)[0].abs().unsqueeze(1) / (clipping_threshold - EPS)
|
247 |
+
noisyspeech[clipped] = noisyspeech[clipped] / noisyspeech_maxamplevel
|
248 |
+
|
249 |
+
return noisyspeech
|
250 |
+
|
251 |
+
|
252 |
+
def snr_mix(src: torch.Tensor, dst: torch.Tensor, snr_low: int, snr_high: int, min_overlap: float):
|
253 |
+
if snr_low == snr_high:
|
254 |
+
snr = snr_low
|
255 |
+
else:
|
256 |
+
snr = np.random.randint(snr_low, snr_high)
|
257 |
+
mix = snr_mixer(src, dst, snr, min_overlap)
|
258 |
+
return mix
|
259 |
+
|
260 |
+
|
261 |
+
def mix_text(src_text: str, dst_text: str):
|
262 |
+
"""Mix text from different sources by concatenating them."""
|
263 |
+
if src_text == dst_text:
|
264 |
+
return src_text
|
265 |
+
return src_text + " " + dst_text
|
266 |
+
|
267 |
+
|
268 |
+
def mix_samples(wavs: torch.Tensor, infos: tp.List[SoundInfo], aug_p: float, mix_p: float,
|
269 |
+
snr_low: int, snr_high: int, min_overlap: float):
|
270 |
+
"""Mix samples within a batch, summing the waveforms and concatenating the text infos.
|
271 |
+
|
272 |
+
Args:
|
273 |
+
wavs (torch.Tensor): Audio tensors of shape [B, C, T].
|
274 |
+
infos (list[SoundInfo]): List of SoundInfo items corresponding to the audio.
|
275 |
+
aug_p (float): Augmentation probability.
|
276 |
+
mix_p (float): Proportion of items in the batch to mix (and merge) together.
|
277 |
+
snr_low (int): Lowerbound for sampling SNR.
|
278 |
+
snr_high (int): Upperbound for sampling SNR.
|
279 |
+
min_overlap (float): Minimum overlap between mixed samples.
|
280 |
+
Returns:
|
281 |
+
tuple[torch.Tensor, list[SoundInfo]]: A tuple containing the mixed wavs
|
282 |
+
and mixed SoundInfo for the given batch.
|
283 |
+
"""
|
284 |
+
# no mixing to perform within the batch
|
285 |
+
if mix_p == 0:
|
286 |
+
return wavs, infos
|
287 |
+
|
288 |
+
if random.uniform(0, 1) < aug_p:
|
289 |
+
# perform all augmentations on waveforms as [B, T]
|
290 |
+
# randomly picking pairs of audio to mix
|
291 |
+
assert wavs.size(1) == 1, f"Mix samples requires monophonic audio but C={wavs.size(1)}"
|
292 |
+
wavs = wavs.mean(dim=1, keepdim=False)
|
293 |
+
B, T = wavs.shape
|
294 |
+
k = int(mix_p * B)
|
295 |
+
mixed_sources_idx = torch.randperm(B)[:k]
|
296 |
+
mixed_targets_idx = torch.randperm(B)[:k]
|
297 |
+
aug_wavs = snr_mix(
|
298 |
+
wavs[mixed_sources_idx],
|
299 |
+
wavs[mixed_targets_idx],
|
300 |
+
snr_low,
|
301 |
+
snr_high,
|
302 |
+
min_overlap,
|
303 |
+
)
|
304 |
+
# mixing textual descriptions in metadata
|
305 |
+
descriptions = [info.description for info in infos]
|
306 |
+
aug_infos = []
|
307 |
+
for i, j in zip(mixed_sources_idx, mixed_targets_idx):
|
308 |
+
text = mix_text(descriptions[i], descriptions[j])
|
309 |
+
m = replace(infos[i])
|
310 |
+
m.description = text
|
311 |
+
aug_infos.append(m)
|
312 |
+
|
313 |
+
# back to [B, C, T]
|
314 |
+
aug_wavs = aug_wavs.unsqueeze(1)
|
315 |
+
assert aug_wavs.shape[0] > 0, "Samples mixing returned empty batch."
|
316 |
+
assert aug_wavs.dim() == 3, f"Returned wav should be [B, C, T] but dim = {aug_wavs.dim()}"
|
317 |
+
assert aug_wavs.shape[0] == len(aug_infos), "Mismatch between number of wavs and infos in the batch"
|
318 |
+
|
319 |
+
return aug_wavs, aug_infos # [B, C, T]
|
320 |
+
else:
|
321 |
+
# randomly pick samples in the batch to match
|
322 |
+
# the batch size when performing audio mixing
|
323 |
+
B, C, T = wavs.shape
|
324 |
+
k = int(mix_p * B)
|
325 |
+
wav_idx = torch.randperm(B)[:k]
|
326 |
+
wavs = wavs[wav_idx]
|
327 |
+
infos = [infos[i] for i in wav_idx]
|
328 |
+
assert wavs.shape[0] == len(infos), "Mismatch between number of wavs and infos in the batch"
|
329 |
+
|
330 |
+
return wavs, infos # [B, C, T]
|
audiocraft/data/zip.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Utility for reading some info from inside a zip file.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import typing
|
10 |
+
import zipfile
|
11 |
+
|
12 |
+
from dataclasses import dataclass
|
13 |
+
from functools import lru_cache
|
14 |
+
from typing_extensions import Literal
|
15 |
+
|
16 |
+
|
17 |
+
DEFAULT_SIZE = 32
|
18 |
+
MODE = Literal['r', 'w', 'x', 'a']
|
19 |
+
|
20 |
+
|
21 |
+
@dataclass(order=True)
|
22 |
+
class PathInZip:
|
23 |
+
"""Hold a path of file within a zip file.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
path (str): The convention is <path_to_zip>:<relative_path_inside_zip>.
|
27 |
+
Let's assume there is a zip file /some/location/foo.zip
|
28 |
+
and inside of it is a json file located at /data/file1.json,
|
29 |
+
Then we expect path = "/some/location/foo.zip:/data/file1.json".
|
30 |
+
"""
|
31 |
+
|
32 |
+
INFO_PATH_SEP = ':'
|
33 |
+
zip_path: str
|
34 |
+
file_path: str
|
35 |
+
|
36 |
+
def __init__(self, path: str) -> None:
|
37 |
+
split_path = path.split(self.INFO_PATH_SEP)
|
38 |
+
assert len(split_path) == 2
|
39 |
+
self.zip_path, self.file_path = split_path
|
40 |
+
|
41 |
+
@classmethod
|
42 |
+
def from_paths(cls, zip_path: str, file_path: str):
|
43 |
+
return cls(zip_path + cls.INFO_PATH_SEP + file_path)
|
44 |
+
|
45 |
+
def __str__(self) -> str:
|
46 |
+
return self.zip_path + self.INFO_PATH_SEP + self.file_path
|
47 |
+
|
48 |
+
|
49 |
+
def _open_zip(path: str, mode: MODE = 'r'):
|
50 |
+
return zipfile.ZipFile(path, mode)
|
51 |
+
|
52 |
+
|
53 |
+
_cached_open_zip = lru_cache(DEFAULT_SIZE)(_open_zip)
|
54 |
+
|
55 |
+
|
56 |
+
def set_zip_cache_size(max_size: int):
|
57 |
+
"""Sets the maximal LRU caching for zip file opening.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
max_size (int): the maximal LRU cache.
|
61 |
+
"""
|
62 |
+
global _cached_open_zip
|
63 |
+
_cached_open_zip = lru_cache(max_size)(_open_zip)
|
64 |
+
|
65 |
+
|
66 |
+
def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
|
67 |
+
"""Opens a file stored inside a zip and returns a file-like object.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
path_in_zip (PathInZip): A PathInZip object representing the file to return a file-like object of.
|
71 |
+
mode (str): The mode in which to open the file with.
|
72 |
+
Returns:
|
73 |
+
A file-like object for PathInZip.
|
74 |
+
"""
|
75 |
+
zf = _cached_open_zip(path_in_zip.zip_path)
|
76 |
+
return zf.open(path_in_zip.file_path)
|
audiocraft/environment.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Provides cluster and tools configuration across clusters (slurm, dora, utilities).
|
9 |
+
"""
|
10 |
+
|
11 |
+
import logging
|
12 |
+
import os
|
13 |
+
from pathlib import Path
|
14 |
+
import re
|
15 |
+
import typing as tp
|
16 |
+
|
17 |
+
import omegaconf
|
18 |
+
|
19 |
+
from .utils.cluster import _guess_cluster_type
|
20 |
+
|
21 |
+
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
class AudioCraftEnvironment:
|
26 |
+
"""Environment configuration for teams and clusters.
|
27 |
+
|
28 |
+
AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment
|
29 |
+
or declared variable and the loaded team configuration. Additionally, the AudioCraftEnvironment
|
30 |
+
provides pointers to a reference folder resolved automatically across clusters that is shared across team members,
|
31 |
+
allowing to share sigs or other files to run jobs. Finally, it provides dataset mappers to automatically
|
32 |
+
map dataset file paths to new locations across clusters, allowing to use the same manifest of files across cluters.
|
33 |
+
|
34 |
+
The cluster type is identified automatically and base configuration file is read from config/teams.yaml.
|
35 |
+
Use the following environment variables to specify the cluster, team or configuration:
|
36 |
+
|
37 |
+
AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type
|
38 |
+
cannot be inferred automatically.
|
39 |
+
AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the teams configuration.
|
40 |
+
If not set, configuration is read from config/teams.yaml.
|
41 |
+
AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team.
|
42 |
+
Cluster configuration are shared across teams to match compute allocation,
|
43 |
+
specify your cluster configuration in the configuration file under a key mapping
|
44 |
+
your team name.
|
45 |
+
"""
|
46 |
+
_instance = None
|
47 |
+
DEFAULT_TEAM = "default"
|
48 |
+
|
49 |
+
def __init__(self) -> None:
|
50 |
+
"""Loads configuration."""
|
51 |
+
self.team: str = os.getenv("AUDIOCRAFT_TEAM", self.DEFAULT_TEAM)
|
52 |
+
cluster_type = _guess_cluster_type()
|
53 |
+
cluster = os.getenv(
|
54 |
+
"AUDIOCRAFT_CLUSTER", cluster_type.value
|
55 |
+
)
|
56 |
+
logger.info("Detecting cluster type %s", cluster_type)
|
57 |
+
|
58 |
+
self.cluster: str = cluster
|
59 |
+
|
60 |
+
config_path = os.getenv(
|
61 |
+
"AUDIOCRAFT_CONFIG",
|
62 |
+
Path(__file__)
|
63 |
+
.parent.parent.joinpath("config/teams", self.team)
|
64 |
+
.with_suffix(".yaml"),
|
65 |
+
)
|
66 |
+
self.config = omegaconf.OmegaConf.load(config_path)
|
67 |
+
self._dataset_mappers = []
|
68 |
+
cluster_config = self._get_cluster_config()
|
69 |
+
if "dataset_mappers" in cluster_config:
|
70 |
+
for pattern, repl in cluster_config["dataset_mappers"].items():
|
71 |
+
regex = re.compile(pattern)
|
72 |
+
self._dataset_mappers.append((regex, repl))
|
73 |
+
|
74 |
+
def _get_cluster_config(self) -> omegaconf.DictConfig:
|
75 |
+
assert isinstance(self.config, omegaconf.DictConfig)
|
76 |
+
return self.config[self.cluster]
|
77 |
+
|
78 |
+
@classmethod
|
79 |
+
def instance(cls):
|
80 |
+
if cls._instance is None:
|
81 |
+
cls._instance = cls()
|
82 |
+
return cls._instance
|
83 |
+
|
84 |
+
@classmethod
|
85 |
+
def reset(cls):
|
86 |
+
"""Clears the environment and forces a reload on next invocation."""
|
87 |
+
cls._instance = None
|
88 |
+
|
89 |
+
@classmethod
|
90 |
+
def get_team(cls) -> str:
|
91 |
+
"""Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var.
|
92 |
+
If not defined, defaults to "labs".
|
93 |
+
"""
|
94 |
+
return cls.instance().team
|
95 |
+
|
96 |
+
@classmethod
|
97 |
+
def get_cluster(cls) -> str:
|
98 |
+
"""Gets the detected cluster.
|
99 |
+
This value can be overridden by the AUDIOCRAFT_CLUSTER env var.
|
100 |
+
"""
|
101 |
+
return cls.instance().cluster
|
102 |
+
|
103 |
+
@classmethod
|
104 |
+
def get_dora_dir(cls) -> Path:
|
105 |
+
"""Gets the path to the dora directory for the current team and cluster.
|
106 |
+
Value is overridden by the AUDIOCRAFT_DORA_DIR env var.
|
107 |
+
"""
|
108 |
+
cluster_config = cls.instance()._get_cluster_config()
|
109 |
+
dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"])
|
110 |
+
logger.warning(f"Dora directory: {dora_dir}")
|
111 |
+
return Path(dora_dir)
|
112 |
+
|
113 |
+
@classmethod
|
114 |
+
def get_reference_dir(cls) -> Path:
|
115 |
+
"""Gets the path to the reference directory for the current team and cluster.
|
116 |
+
Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var.
|
117 |
+
"""
|
118 |
+
cluster_config = cls.instance()._get_cluster_config()
|
119 |
+
return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"]))
|
120 |
+
|
121 |
+
@classmethod
|
122 |
+
def get_slurm_exclude(cls) -> tp.Optional[str]:
|
123 |
+
"""Get the list of nodes to exclude for that cluster."""
|
124 |
+
cluster_config = cls.instance()._get_cluster_config()
|
125 |
+
return cluster_config.get("slurm_exclude")
|
126 |
+
|
127 |
+
@classmethod
|
128 |
+
def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str:
|
129 |
+
"""Gets the requested partitions for the current team and cluster as a comma-separated string.
|
130 |
+
|
131 |
+
Args:
|
132 |
+
partition_types (list[str], optional): partition types to retrieve. Values must be
|
133 |
+
from ['global', 'team']. If not provided, the global partition is returned.
|
134 |
+
"""
|
135 |
+
if not partition_types:
|
136 |
+
partition_types = ["global"]
|
137 |
+
|
138 |
+
cluster_config = cls.instance()._get_cluster_config()
|
139 |
+
partitions = [
|
140 |
+
cluster_config["partitions"][partition_type]
|
141 |
+
for partition_type in partition_types
|
142 |
+
]
|
143 |
+
return ",".join(partitions)
|
144 |
+
|
145 |
+
@classmethod
|
146 |
+
def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path:
|
147 |
+
"""Converts reference placeholder in path with configured reference dir to resolve paths.
|
148 |
+
|
149 |
+
Args:
|
150 |
+
path (str or Path): Path to resolve.
|
151 |
+
Returns:
|
152 |
+
Path: Resolved path.
|
153 |
+
"""
|
154 |
+
path = str(path)
|
155 |
+
|
156 |
+
if path.startswith("//reference"):
|
157 |
+
reference_dir = cls.get_reference_dir()
|
158 |
+
logger.warn(f"Reference directory: {reference_dir}")
|
159 |
+
assert (
|
160 |
+
reference_dir.exists() and reference_dir.is_dir()
|
161 |
+
), f"Reference directory does not exist: {reference_dir}."
|
162 |
+
path = re.sub("^//reference", str(reference_dir), path)
|
163 |
+
|
164 |
+
return Path(path)
|
165 |
+
|
166 |
+
@classmethod
|
167 |
+
def apply_dataset_mappers(cls, path: str) -> str:
|
168 |
+
"""Applies dataset mapping regex rules as defined in the configuration.
|
169 |
+
If no rules are defined, the path is returned as-is.
|
170 |
+
"""
|
171 |
+
instance = cls.instance()
|
172 |
+
|
173 |
+
for pattern, repl in instance._dataset_mappers:
|
174 |
+
path = pattern.sub(repl, path)
|
175 |
+
|
176 |
+
return path
|
audiocraft/grids/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Dora Grids."""
|
audiocraft/grids/_base_explorers.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from abc import ABC, abstractmethod
|
8 |
+
import time
|
9 |
+
import typing as tp
|
10 |
+
from dora import Explorer
|
11 |
+
import treetable as tt
|
12 |
+
|
13 |
+
|
14 |
+
def get_sheep_ping(sheep) -> tp.Optional[str]:
|
15 |
+
"""Return the amount of time since the Sheep made some update
|
16 |
+
to its log. Returns a str using the relevant time unit."""
|
17 |
+
ping = None
|
18 |
+
if sheep.log is not None and sheep.log.exists():
|
19 |
+
delta = time.time() - sheep.log.stat().st_mtime
|
20 |
+
if delta > 3600 * 24:
|
21 |
+
ping = f'{delta / (3600 * 24):.1f}d'
|
22 |
+
elif delta > 3600:
|
23 |
+
ping = f'{delta / (3600):.1f}h'
|
24 |
+
elif delta > 60:
|
25 |
+
ping = f'{delta / 60:.1f}m'
|
26 |
+
else:
|
27 |
+
ping = f'{delta:.1f}s'
|
28 |
+
return ping
|
29 |
+
|
30 |
+
|
31 |
+
class BaseExplorer(ABC, Explorer):
|
32 |
+
"""Base explorer for AudioCraft grids.
|
33 |
+
|
34 |
+
All task specific solvers are expected to implement the `get_grid_metrics`
|
35 |
+
method to specify logic about metrics to display for a given task.
|
36 |
+
|
37 |
+
If additional stages are used, the child explorer must define how to handle
|
38 |
+
these new stages in the `process_history` and `process_sheep` methods.
|
39 |
+
"""
|
40 |
+
def stages(self):
|
41 |
+
return ["train", "valid", "evaluate"]
|
42 |
+
|
43 |
+
def get_grid_meta(self):
|
44 |
+
"""Returns the list of Meta information to display for each XP/job.
|
45 |
+
"""
|
46 |
+
return [
|
47 |
+
tt.leaf("index", align=">"),
|
48 |
+
tt.leaf("name", wrap=140),
|
49 |
+
tt.leaf("state"),
|
50 |
+
tt.leaf("sig", align=">"),
|
51 |
+
tt.leaf("sid", align="<"),
|
52 |
+
]
|
53 |
+
|
54 |
+
@abstractmethod
|
55 |
+
def get_grid_metrics(self):
|
56 |
+
"""Return the metrics that should be displayed in the tracking table.
|
57 |
+
"""
|
58 |
+
...
|
59 |
+
|
60 |
+
def process_sheep(self, sheep, history):
|
61 |
+
train = {
|
62 |
+
"epoch": len(history),
|
63 |
+
}
|
64 |
+
parts = {"train": train}
|
65 |
+
for metrics in history:
|
66 |
+
for key, sub in metrics.items():
|
67 |
+
part = parts.get(key, {})
|
68 |
+
if 'duration' in sub:
|
69 |
+
# Convert to minutes for readability.
|
70 |
+
sub['duration'] = sub['duration'] / 60.
|
71 |
+
part.update(sub)
|
72 |
+
parts[key] = part
|
73 |
+
ping = get_sheep_ping(sheep)
|
74 |
+
if ping is not None:
|
75 |
+
for name in self.stages():
|
76 |
+
if name not in parts:
|
77 |
+
parts[name] = {}
|
78 |
+
# Add the ping to each part for convenience.
|
79 |
+
parts[name]['ping'] = ping
|
80 |
+
return parts
|
audiocraft/grids/audiogen/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""AudioGen grids."""
|
audiocraft/grids/audiogen/audiogen_base_16khz.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from ..musicgen._explorers import LMExplorer
|
8 |
+
from ...environment import AudioCraftEnvironment
|
9 |
+
|
10 |
+
|
11 |
+
@LMExplorer
|
12 |
+
def explorer(launcher):
|
13 |
+
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
14 |
+
launcher.slurm_(gpus=64, partition=partitions)
|
15 |
+
launcher.bind_(solver='audiogen/audiogen_base_16khz')
|
16 |
+
# replace this by the desired environmental sound dataset
|
17 |
+
launcher.bind_(dset='internal/sounds_16khz')
|
18 |
+
|
19 |
+
fsdp = {'autocast': False, 'fsdp.use': True}
|
20 |
+
medium = {'model/lm/model_scale': 'medium'}
|
21 |
+
|
22 |
+
launcher.bind_(fsdp)
|
23 |
+
launcher(medium)
|
audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Evaluation with objective metrics for the pretrained AudioGen models.
|
9 |
+
This grid takes signature from the training grid and runs evaluation-only stage.
|
10 |
+
|
11 |
+
When running the grid for the first time, please use:
|
12 |
+
REGEN=1 dora grid audiogen.audiogen_pretrained_16khz_eval
|
13 |
+
and re-use the REGEN=1 option when the grid is changed to force regenerating it.
|
14 |
+
|
15 |
+
Note that you need the proper metrics external libraries setup to use all
|
16 |
+
the objective metrics activated in this grid. Refer to the README for more information.
|
17 |
+
"""
|
18 |
+
|
19 |
+
import os
|
20 |
+
|
21 |
+
from ..musicgen._explorers import GenerationEvalExplorer
|
22 |
+
from ...environment import AudioCraftEnvironment
|
23 |
+
from ... import train
|
24 |
+
|
25 |
+
|
26 |
+
def eval(launcher, batch_size: int = 32):
|
27 |
+
opts = {
|
28 |
+
'dset': 'audio/audiocaps_16khz',
|
29 |
+
'solver/audiogen/evaluation': 'objective_eval',
|
30 |
+
'execute_only': 'evaluate',
|
31 |
+
'+dataset.evaluate.batch_size': batch_size,
|
32 |
+
'+metrics.fad.tf.batch_size': 32,
|
33 |
+
}
|
34 |
+
# binary for FAD computation: replace this path with your own path
|
35 |
+
metrics_opts = {
|
36 |
+
'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research'
|
37 |
+
}
|
38 |
+
opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
|
39 |
+
opt2 = {'transformer_lm.two_step_cfg': True}
|
40 |
+
|
41 |
+
sub = launcher.bind(opts)
|
42 |
+
sub.bind_(metrics_opts)
|
43 |
+
|
44 |
+
# base objective metrics
|
45 |
+
sub(opt1, opt2)
|
46 |
+
|
47 |
+
|
48 |
+
@GenerationEvalExplorer
|
49 |
+
def explorer(launcher):
|
50 |
+
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
51 |
+
launcher.slurm_(gpus=4, partition=partitions)
|
52 |
+
|
53 |
+
if 'REGEN' not in os.environ:
|
54 |
+
folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1]
|
55 |
+
with launcher.job_array():
|
56 |
+
for sig in folder.iterdir():
|
57 |
+
if not sig.is_symlink():
|
58 |
+
continue
|
59 |
+
xp = train.main.get_xp_from_sig(sig.name)
|
60 |
+
launcher(xp.argv)
|
61 |
+
return
|
62 |
+
|
63 |
+
audiogen_base = launcher.bind(solver="audiogen/audiogen_base_16khz")
|
64 |
+
audiogen_base.bind_({'autocast': False, 'fsdp.use': True})
|
65 |
+
|
66 |
+
audiogen_base_medium = audiogen_base.bind({'continue_from': '//pretrained/facebook/audiogen-medium'})
|
67 |
+
audiogen_base_medium.bind_({'model/lm/model_scale': 'medium'})
|
68 |
+
eval(audiogen_base_medium, batch_size=128)
|
audiocraft/grids/compression/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""EnCodec grids."""
|
audiocraft/grids/compression/_explorers.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import treetable as tt
|
8 |
+
|
9 |
+
from .._base_explorers import BaseExplorer
|
10 |
+
|
11 |
+
|
12 |
+
class CompressionExplorer(BaseExplorer):
|
13 |
+
eval_metrics = ["sisnr", "visqol"]
|
14 |
+
|
15 |
+
def stages(self):
|
16 |
+
return ["train", "valid", "evaluate"]
|
17 |
+
|
18 |
+
def get_grid_meta(self):
|
19 |
+
"""Returns the list of Meta information to display for each XP/job.
|
20 |
+
"""
|
21 |
+
return [
|
22 |
+
tt.leaf("index", align=">"),
|
23 |
+
tt.leaf("name", wrap=140),
|
24 |
+
tt.leaf("state"),
|
25 |
+
tt.leaf("sig", align=">"),
|
26 |
+
]
|
27 |
+
|
28 |
+
def get_grid_metrics(self):
|
29 |
+
"""Return the metrics that should be displayed in the tracking table.
|
30 |
+
"""
|
31 |
+
return [
|
32 |
+
tt.group(
|
33 |
+
"train",
|
34 |
+
[
|
35 |
+
tt.leaf("epoch"),
|
36 |
+
tt.leaf("bandwidth", ".2f"),
|
37 |
+
tt.leaf("adv", ".4f"),
|
38 |
+
tt.leaf("d_loss", ".4f"),
|
39 |
+
],
|
40 |
+
align=">",
|
41 |
+
),
|
42 |
+
tt.group(
|
43 |
+
"valid",
|
44 |
+
[
|
45 |
+
tt.leaf("bandwidth", ".2f"),
|
46 |
+
tt.leaf("adv", ".4f"),
|
47 |
+
tt.leaf("msspec", ".4f"),
|
48 |
+
tt.leaf("sisnr", ".2f"),
|
49 |
+
],
|
50 |
+
align=">",
|
51 |
+
),
|
52 |
+
tt.group(
|
53 |
+
"evaluate", [tt.leaf(name, ".3f") for name in self.eval_metrics], align=">"
|
54 |
+
),
|
55 |
+
]
|
audiocraft/grids/compression/debug.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Grid search file, simply list all the exp you want in `explorer`.
|
9 |
+
Any new exp added there will be scheduled.
|
10 |
+
You can cancel and experiment by commenting its line.
|
11 |
+
|
12 |
+
This grid is a minimal example for debugging compression task
|
13 |
+
and how to override parameters directly in a grid.
|
14 |
+
Learn more about dora grids: https://github.com/facebookresearch/dora
|
15 |
+
"""
|
16 |
+
|
17 |
+
from ._explorers import CompressionExplorer
|
18 |
+
from ...environment import AudioCraftEnvironment
|
19 |
+
|
20 |
+
|
21 |
+
@CompressionExplorer
|
22 |
+
def explorer(launcher):
|
23 |
+
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
24 |
+
launcher.slurm_(gpus=2, partition=partitions)
|
25 |
+
launcher.bind_(solver='compression/debug')
|
26 |
+
|
27 |
+
with launcher.job_array():
|
28 |
+
# base debug task using config from solver=compression/debug
|
29 |
+
launcher()
|
30 |
+
# we can override parameters in the grid to launch additional xps
|
31 |
+
launcher({'rvq.bins': 2048, 'rvq.n_q': 4})
|
audiocraft/grids/compression/encodec_audiogen_16khz.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Grid search file, simply list all the exp you want in `explorer`.
|
9 |
+
Any new exp added there will be scheduled.
|
10 |
+
You can cancel and experiment by commenting its line.
|
11 |
+
|
12 |
+
This grid shows how to train the new AudioGen EnCodec model at 16 kHz.
|
13 |
+
"""
|
14 |
+
|
15 |
+
from ._explorers import CompressionExplorer
|
16 |
+
from ...environment import AudioCraftEnvironment
|
17 |
+
|
18 |
+
|
19 |
+
@CompressionExplorer
|
20 |
+
def explorer(launcher):
|
21 |
+
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
22 |
+
launcher.slurm_(gpus=8, partition=partitions)
|
23 |
+
# use configuration for AudioGen's EnCodec model trained on monophonic audio sampled at 16 kHz
|
24 |
+
# AudioGen's EnCodec is trained with a total stride of 320 leading to a frame rate of 50 hz
|
25 |
+
launcher.bind_(solver='compression/encodec_audiogen_16khz')
|
26 |
+
# replace this by the desired sound dataset
|
27 |
+
launcher.bind_(dset='internal/sounds_16khz')
|
28 |
+
# launch xp
|
29 |
+
launcher()
|
audiocraft/grids/compression/encodec_base_24khz.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Grid search file, simply list all the exp you want in `explorer`.
|
9 |
+
Any new exp added there will be scheduled.
|
10 |
+
You can cancel and experiment by commenting its line.
|
11 |
+
|
12 |
+
This grid shows how to train a base causal EnCodec model at 24 kHz.
|
13 |
+
"""
|
14 |
+
|
15 |
+
from ._explorers import CompressionExplorer
|
16 |
+
from ...environment import AudioCraftEnvironment
|
17 |
+
|
18 |
+
|
19 |
+
@CompressionExplorer
|
20 |
+
def explorer(launcher):
|
21 |
+
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
22 |
+
launcher.slurm_(gpus=8, partition=partitions)
|
23 |
+
# base causal EnCodec trained on monophonic audio sampled at 24 kHz
|
24 |
+
launcher.bind_(solver='compression/encodec_base_24khz')
|
25 |
+
# replace this by the desired dataset
|
26 |
+
launcher.bind_(dset='audio/example')
|
27 |
+
# launch xp
|
28 |
+
launcher()
|
audiocraft/grids/compression/encodec_musicgen_32khz.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Grid search file, simply list all the exp you want in `explorer`.
|
9 |
+
Any new exp added there will be scheduled.
|
10 |
+
You can cancel and experiment by commenting its line.
|
11 |
+
|
12 |
+
This grid shows how to train a MusicGen EnCodec model at 32 kHz.
|
13 |
+
"""
|
14 |
+
|
15 |
+
from ._explorers import CompressionExplorer
|
16 |
+
from ...environment import AudioCraftEnvironment
|
17 |
+
|
18 |
+
|
19 |
+
@CompressionExplorer
|
20 |
+
def explorer(launcher):
|
21 |
+
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
22 |
+
launcher.slurm_(gpus=8, partition=partitions)
|
23 |
+
# use configuration for MusicGen's EnCodec model trained on monophonic audio sampled at 32 kHz
|
24 |
+
# MusicGen's EnCodec is trained with a total stride of 640 leading to a frame rate of 50 hz
|
25 |
+
launcher.bind_(solver='compression/encodec_musicgen_32khz')
|
26 |
+
# replace this by the desired music dataset
|
27 |
+
launcher.bind_(dset='internal/music_400k_32khz')
|
28 |
+
# launch xp
|
29 |
+
launcher()
|
30 |
+
launcher({
|
31 |
+
'metrics.visqol.bin': '/data/home/jadecopet/local/usr/opt/visqol',
|
32 |
+
'label': 'visqol',
|
33 |
+
'evaluate.metrics.visqol': True
|
34 |
+
})
|
audiocraft/grids/diffusion/4_bands_base_32khz.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Training of the 4 diffusion models described in
|
9 |
+
"From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion"
|
10 |
+
(paper link).
|
11 |
+
"""
|
12 |
+
|
13 |
+
from ._explorers import DiffusionExplorer
|
14 |
+
|
15 |
+
|
16 |
+
@DiffusionExplorer
|
17 |
+
def explorer(launcher):
|
18 |
+
launcher.slurm_(gpus=4, partition='learnfair')
|
19 |
+
|
20 |
+
launcher.bind_({'solver': 'diffusion/default',
|
21 |
+
'dset': 'internal/music_10k_32khz'})
|
22 |
+
|
23 |
+
with launcher.job_array():
|
24 |
+
launcher({'filter.use': True, 'filter.idx_band': 0, "processor.use": False, 'processor.power_std': 0.4})
|
25 |
+
launcher({'filter.use': True, 'filter.idx_band': 1, "processor.use": False, 'processor.power_std': 0.4})
|
26 |
+
launcher({'filter.use': True, 'filter.idx_band': 2, "processor.use": True, 'processor.power_std': 0.4})
|
27 |
+
launcher({'filter.use': True, 'filter.idx_band': 3, "processor.use": True, 'processor.power_std': 0.75})
|
audiocraft/grids/diffusion/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Diffusion grids."""
|
audiocraft/grids/diffusion/_explorers.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import treetable as tt
|
8 |
+
|
9 |
+
from .._base_explorers import BaseExplorer
|
10 |
+
|
11 |
+
|
12 |
+
class DiffusionExplorer(BaseExplorer):
|
13 |
+
eval_metrics = ["sisnr", "visqol"]
|
14 |
+
|
15 |
+
def stages(self):
|
16 |
+
return ["train", "valid", "valid_ema", "evaluate", "evaluate_ema"]
|
17 |
+
|
18 |
+
def get_grid_meta(self):
|
19 |
+
"""Returns the list of Meta information to display for each XP/job.
|
20 |
+
"""
|
21 |
+
return [
|
22 |
+
tt.leaf("index", align=">"),
|
23 |
+
tt.leaf("name", wrap=140),
|
24 |
+
tt.leaf("state"),
|
25 |
+
tt.leaf("sig", align=">"),
|
26 |
+
]
|
27 |
+
|
28 |
+
def get_grid_metrics(self):
|
29 |
+
"""Return the metrics that should be displayed in the tracking table.
|
30 |
+
"""
|
31 |
+
return [
|
32 |
+
tt.group(
|
33 |
+
"train",
|
34 |
+
[
|
35 |
+
tt.leaf("epoch"),
|
36 |
+
tt.leaf("loss", ".3%"),
|
37 |
+
],
|
38 |
+
align=">",
|
39 |
+
),
|
40 |
+
tt.group(
|
41 |
+
"valid",
|
42 |
+
[
|
43 |
+
tt.leaf("loss", ".3%"),
|
44 |
+
# tt.leaf("loss_0", ".3%"),
|
45 |
+
],
|
46 |
+
align=">",
|
47 |
+
),
|
48 |
+
tt.group(
|
49 |
+
"valid_ema",
|
50 |
+
[
|
51 |
+
tt.leaf("loss", ".3%"),
|
52 |
+
# tt.leaf("loss_0", ".3%"),
|
53 |
+
],
|
54 |
+
align=">",
|
55 |
+
),
|
56 |
+
tt.group(
|
57 |
+
"evaluate", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"),
|
58 |
+
tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"),
|
59 |
+
tt.leaf("rvm_3", ".4f"), ], align=">"
|
60 |
+
),
|
61 |
+
tt.group(
|
62 |
+
"evaluate_ema", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"),
|
63 |
+
tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"),
|
64 |
+
tt.leaf("rvm_3", ".4f")], align=">"
|
65 |
+
),
|
66 |
+
]
|
audiocraft/grids/musicgen/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""MusicGen grids."""
|
audiocraft/grids/musicgen/_explorers.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import typing as tp
|
8 |
+
|
9 |
+
import treetable as tt
|
10 |
+
|
11 |
+
from .._base_explorers import BaseExplorer
|
12 |
+
|
13 |
+
|
14 |
+
class LMExplorer(BaseExplorer):
|
15 |
+
eval_metrics: tp.List[str] = []
|
16 |
+
|
17 |
+
def stages(self) -> tp.List[str]:
|
18 |
+
return ['train', 'valid']
|
19 |
+
|
20 |
+
def get_grid_metrics(self):
|
21 |
+
"""Return the metrics that should be displayed in the tracking table."""
|
22 |
+
return [
|
23 |
+
tt.group(
|
24 |
+
'train',
|
25 |
+
[
|
26 |
+
tt.leaf('epoch'),
|
27 |
+
tt.leaf('duration', '.1f'), # duration in minutes
|
28 |
+
tt.leaf('ping'),
|
29 |
+
tt.leaf('ce', '.4f'), # cross entropy
|
30 |
+
tt.leaf("ppl", '.3f'), # perplexity
|
31 |
+
],
|
32 |
+
align='>',
|
33 |
+
),
|
34 |
+
tt.group(
|
35 |
+
'valid',
|
36 |
+
[
|
37 |
+
tt.leaf('ce', '.4f'),
|
38 |
+
tt.leaf('ppl', '.3f'),
|
39 |
+
tt.leaf('best_ppl', '.3f'),
|
40 |
+
],
|
41 |
+
align='>',
|
42 |
+
),
|
43 |
+
]
|
44 |
+
|
45 |
+
def process_sheep(self, sheep, history):
|
46 |
+
parts = super().process_sheep(sheep, history)
|
47 |
+
|
48 |
+
track_by = {'ppl': 'lower'} # values should be in ['lower', 'higher']
|
49 |
+
best_metrics = {k: (1 if v == 'lower' else -1) * float('inf') for k, v in track_by.items()}
|
50 |
+
|
51 |
+
def comparator(mode, a, b):
|
52 |
+
return a < b if mode == 'lower' else a > b
|
53 |
+
|
54 |
+
for metrics in history:
|
55 |
+
for key, sub in metrics.items():
|
56 |
+
for metric in track_by:
|
57 |
+
# for the validation set, keep track of best metrics (ppl in this example)
|
58 |
+
# this is so we can conveniently compare metrics between runs in the grid
|
59 |
+
if key == 'valid' and metric in sub and comparator(
|
60 |
+
track_by[metric], sub[metric], best_metrics[metric]
|
61 |
+
):
|
62 |
+
best_metrics[metric] = sub[metric]
|
63 |
+
|
64 |
+
if 'valid' in parts:
|
65 |
+
parts['valid'].update({f'best_{k}': v for k, v in best_metrics.items()})
|
66 |
+
return parts
|
67 |
+
|
68 |
+
|
69 |
+
class GenerationEvalExplorer(BaseExplorer):
|
70 |
+
eval_metrics: tp.List[str] = []
|
71 |
+
|
72 |
+
def stages(self) -> tp.List[str]:
|
73 |
+
return ['evaluate']
|
74 |
+
|
75 |
+
def get_grid_metrics(self):
|
76 |
+
"""Return the metrics that should be displayed in the tracking table."""
|
77 |
+
return [
|
78 |
+
tt.group(
|
79 |
+
'evaluate',
|
80 |
+
[
|
81 |
+
tt.leaf('epoch', '.3f'),
|
82 |
+
tt.leaf('duration', '.1f'),
|
83 |
+
tt.leaf('ping'),
|
84 |
+
tt.leaf('ce', '.4f'),
|
85 |
+
tt.leaf('ppl', '.3f'),
|
86 |
+
tt.leaf('fad', '.3f'),
|
87 |
+
tt.leaf('kld', '.3f'),
|
88 |
+
tt.leaf('text_consistency', '.3f'),
|
89 |
+
tt.leaf('chroma_cosine', '.3f'),
|
90 |
+
],
|
91 |
+
align='>',
|
92 |
+
),
|
93 |
+
]
|
audiocraft/grids/musicgen/musicgen_base_32khz.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from ._explorers import LMExplorer
|
8 |
+
from ...environment import AudioCraftEnvironment
|
9 |
+
|
10 |
+
|
11 |
+
@LMExplorer
|
12 |
+
def explorer(launcher):
|
13 |
+
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
14 |
+
launcher.slurm_(gpus=32, partition=partitions)
|
15 |
+
launcher.bind_(solver='musicgen/musicgen_base_32khz')
|
16 |
+
# replace this by the desired music dataset
|
17 |
+
launcher.bind_(dset='internal/music_400k_32khz')
|
18 |
+
|
19 |
+
fsdp = {'autocast': False, 'fsdp.use': True}
|
20 |
+
medium = {'model/lm/model_scale': 'medium'}
|
21 |
+
large = {'model/lm/model_scale': 'large'}
|
22 |
+
|
23 |
+
cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
|
24 |
+
wd_low = {'conditioners.description.t5.word_dropout': 0.2}
|
25 |
+
|
26 |
+
adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
|
27 |
+
|
28 |
+
launcher.bind_(fsdp)
|
29 |
+
|
30 |
+
launcher.slurm_(gpus=32).bind_(label='32gpus')
|
31 |
+
with launcher.job_array():
|
32 |
+
sub = launcher.bind()
|
33 |
+
sub()
|
34 |
+
|
35 |
+
launcher.slurm_(gpus=64).bind_(label='64gpus')
|
36 |
+
with launcher.job_array():
|
37 |
+
sub = launcher.bind()
|
38 |
+
sub(medium, adam)
|
39 |
+
|
40 |
+
launcher.slurm_(gpus=96).bind_(label='96gpus')
|
41 |
+
with launcher.job_array():
|
42 |
+
sub = launcher.bind()
|
43 |
+
sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
|
audiocraft/grids/musicgen/musicgen_base_cached_32khz.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from ._explorers import LMExplorer
|
8 |
+
from ...environment import AudioCraftEnvironment
|
9 |
+
|
10 |
+
|
11 |
+
@LMExplorer
|
12 |
+
def explorer(launcher):
|
13 |
+
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
14 |
+
launcher.slurm_(gpus=32, partition=partitions)
|
15 |
+
launcher.bind_(solver='musicgen/musicgen_base_32khz')
|
16 |
+
# replace this by the desired music dataset
|
17 |
+
launcher.bind_(dset='internal/music_400k_32khz')
|
18 |
+
|
19 |
+
fsdp = {'autocast': False, 'fsdp.use': True}
|
20 |
+
medium = {'model/lm/model_scale': 'medium'}
|
21 |
+
large = {'model/lm/model_scale': 'large'}
|
22 |
+
|
23 |
+
cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
|
24 |
+
wd_low = {'conditioners.description.t5.word_dropout': 0.2}
|
25 |
+
|
26 |
+
adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
|
27 |
+
|
28 |
+
# BEGINNING OF CACHE WRITING JOBS.
|
29 |
+
cache_write = {
|
30 |
+
'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k',
|
31 |
+
'cache.write': True,
|
32 |
+
'generate.every': 500,
|
33 |
+
'evaluate.every': 500,
|
34 |
+
'logging.log_updates': 50,
|
35 |
+
}
|
36 |
+
|
37 |
+
cache_sub = launcher.bind({'model/lm/model_scale': 'xsmall', 'conditioner': 'none'})
|
38 |
+
cache_sub.bind_({'deadlock.use': True})
|
39 |
+
cache_sub.slurm_(gpus=8)
|
40 |
+
with launcher.job_array():
|
41 |
+
num_shards = 10 # total number of jobs running in parallel.
|
42 |
+
for shard in range(0, num_shards):
|
43 |
+
launcher(cache_write, {'cache.write_num_shards': num_shards, 'cache.write_shard': shard})
|
44 |
+
|
45 |
+
# REMOVE THE FOLLOWING RETURN STATEMENT ONCE THE ABOVE JOBS ARE DONE,
|
46 |
+
# OR SUFFICIENTLY AHEAD.
|
47 |
+
return
|
48 |
+
|
49 |
+
cache = {
|
50 |
+
'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k',
|
51 |
+
}
|
52 |
+
launcher.bind_(fsdp, cache)
|
53 |
+
|
54 |
+
launcher.slurm_(gpus=32).bind_(label='32gpus')
|
55 |
+
with launcher.job_array():
|
56 |
+
sub = launcher.bind()
|
57 |
+
sub()
|
58 |
+
|
59 |
+
launcher.slurm_(gpus=64).bind_(label='64gpus')
|
60 |
+
with launcher.job_array():
|
61 |
+
sub = launcher.bind()
|
62 |
+
sub(medium, adam)
|
63 |
+
|
64 |
+
launcher.slurm_(gpus=96).bind_(label='96gpus')
|
65 |
+
with launcher.job_array():
|
66 |
+
sub = launcher.bind()
|
67 |
+
sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
|
audiocraft/grids/musicgen/musicgen_clapemb_32khz.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from ._explorers import LMExplorer
|
8 |
+
from ...environment import AudioCraftEnvironment
|
9 |
+
|
10 |
+
|
11 |
+
@LMExplorer
|
12 |
+
def explorer(launcher):
|
13 |
+
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
14 |
+
launcher.slurm_(gpus=32, partition=partitions)
|
15 |
+
launcher.bind_(solver='musicgen/musicgen_base_32khz')
|
16 |
+
# replace this by the desired music dataset
|
17 |
+
launcher.bind_(dset='internal/music_400k_32khz')
|
18 |
+
launcher.bind_(conditioner='clapemb2music')
|
19 |
+
|
20 |
+
fsdp = {'autocast': False, 'fsdp.use': True}
|
21 |
+
cache_path = {'conditioners.description.clap.cache_path':
|
22 |
+
'/fsx-audio-craft-llm/jadecopet/experiments/audiocraft/caches/clap_embed_music'}
|
23 |
+
text_wav_training_opt = {'conditioners.description.clap.text_p': 0.5}
|
24 |
+
|
25 |
+
launcher.bind_(fsdp)
|
26 |
+
|
27 |
+
launcher.slurm_(gpus=32).bind_(label='32gpus')
|
28 |
+
with launcher.job_array():
|
29 |
+
launcher()
|
30 |
+
launcher(text_wav_training_opt)
|
31 |
+
launcher(cache_path)
|
32 |
+
launcher(cache_path, text_wav_training_opt)
|
audiocraft/grids/musicgen/musicgen_melody_32khz.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from ._explorers import LMExplorer
|
8 |
+
from ...environment import AudioCraftEnvironment
|
9 |
+
|
10 |
+
|
11 |
+
@LMExplorer
|
12 |
+
def explorer(launcher):
|
13 |
+
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
14 |
+
launcher.slurm_(gpus=32, partition=partitions)
|
15 |
+
launcher.bind_(solver='musicgen/musicgen_melody_32khz')
|
16 |
+
# replace this by the desired music dataset
|
17 |
+
launcher.bind_(dset='internal/music_400k_32khz')
|
18 |
+
|
19 |
+
fsdp = {'autocast': False, 'fsdp.use': True}
|
20 |
+
medium = {'model/lm/model_scale': 'medium'}
|
21 |
+
large = {'model/lm/model_scale': 'large'}
|
22 |
+
|
23 |
+
cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
|
24 |
+
wd_low = {'conditioners.description.t5.word_dropout': 0.2}
|
25 |
+
|
26 |
+
adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
|
27 |
+
|
28 |
+
cache_path = {'conditioners.self_wav.chroma_stem.cache_path':
|
29 |
+
'/fsx-audio-craft-llm/jadecopet/experiments/audiocraft/caches/chroma_stem'}
|
30 |
+
|
31 |
+
# CACHE GENERATION JOBS
|
32 |
+
n_cache_gen_jobs = 4
|
33 |
+
gen_sub = launcher.slurm(gpus=1)
|
34 |
+
gen_sub.bind_(
|
35 |
+
cache_path, {
|
36 |
+
# the cache is always computed over the whole file, so duration doesn't matter here.
|
37 |
+
'dataset.segment_duration': 2.,
|
38 |
+
'dataset.batch_size': 8,
|
39 |
+
'dataset.train.permutation_on_files': True, # try to not repeat files.
|
40 |
+
'optim.epochs': 10,
|
41 |
+
'model/lm/model_scale': 'xsmall',
|
42 |
+
|
43 |
+
})
|
44 |
+
with gen_sub.job_array():
|
45 |
+
for gen_job in range(n_cache_gen_jobs):
|
46 |
+
gen_sub({'dataset.train.shuffle_seed': gen_job})
|
47 |
+
|
48 |
+
# ACTUAL TRAINING JOBS.
|
49 |
+
launcher.bind_(fsdp)
|
50 |
+
|
51 |
+
launcher.slurm_(gpus=32).bind_(label='32gpus')
|
52 |
+
with launcher.job_array():
|
53 |
+
sub = launcher.bind()
|
54 |
+
sub()
|
55 |
+
sub(cache_path)
|
56 |
+
|
57 |
+
launcher.slurm_(gpus=64).bind_(label='64gpus')
|
58 |
+
with launcher.job_array():
|
59 |
+
sub = launcher.bind()
|
60 |
+
sub(medium, adam)
|
61 |
+
|
62 |
+
launcher.slurm_(gpus=96).bind_(label='96gpus')
|
63 |
+
with launcher.job_array():
|
64 |
+
sub = launcher.bind()
|
65 |
+
sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
|
audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Evaluation with objective metrics for the pretrained MusicGen models.
|
9 |
+
This grid takes signature from the training grid and runs evaluation-only stage.
|
10 |
+
|
11 |
+
When running the grid for the first time, please use:
|
12 |
+
REGEN=1 dora grid musicgen.musicgen_pretrained_32khz_eval
|
13 |
+
and re-use the REGEN=1 option when the grid is changed to force regenerating it.
|
14 |
+
|
15 |
+
Note that you need the proper metrics external libraries setup to use all
|
16 |
+
the objective metrics activated in this grid. Refer to the README for more information.
|
17 |
+
"""
|
18 |
+
|
19 |
+
import os
|
20 |
+
|
21 |
+
from ._explorers import GenerationEvalExplorer
|
22 |
+
from ...environment import AudioCraftEnvironment
|
23 |
+
from ... import train
|
24 |
+
|
25 |
+
|
26 |
+
def eval(launcher, batch_size: int = 32, eval_melody: bool = False):
|
27 |
+
opts = {
|
28 |
+
'dset': 'audio/musiccaps_32khz',
|
29 |
+
'solver/musicgen/evaluation': 'objective_eval',
|
30 |
+
'execute_only': 'evaluate',
|
31 |
+
'+dataset.evaluate.batch_size': batch_size,
|
32 |
+
'+metrics.fad.tf.batch_size': 16,
|
33 |
+
}
|
34 |
+
# chroma-specific evaluation
|
35 |
+
chroma_opts = {
|
36 |
+
'dset': 'internal/music_400k_32khz',
|
37 |
+
'dataset.evaluate.segment_duration': 30,
|
38 |
+
'dataset.evaluate.num_samples': 1000,
|
39 |
+
'evaluate.metrics.chroma_cosine': True,
|
40 |
+
'evaluate.metrics.fad': False,
|
41 |
+
'evaluate.metrics.kld': False,
|
42 |
+
'evaluate.metrics.text_consistency': False,
|
43 |
+
}
|
44 |
+
# binary for FAD computation: replace this path with your own path
|
45 |
+
metrics_opts = {
|
46 |
+
'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research'
|
47 |
+
}
|
48 |
+
opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
|
49 |
+
opt2 = {'transformer_lm.two_step_cfg': True}
|
50 |
+
|
51 |
+
sub = launcher.bind(opts)
|
52 |
+
sub.bind_(metrics_opts)
|
53 |
+
|
54 |
+
# base objective metrics
|
55 |
+
sub(opt1, opt2)
|
56 |
+
|
57 |
+
if eval_melody:
|
58 |
+
# chroma-specific metrics
|
59 |
+
sub(opt1, opt2, chroma_opts)
|
60 |
+
|
61 |
+
|
62 |
+
@GenerationEvalExplorer
|
63 |
+
def explorer(launcher):
|
64 |
+
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
65 |
+
launcher.slurm_(gpus=4, partition=partitions)
|
66 |
+
|
67 |
+
if 'REGEN' not in os.environ:
|
68 |
+
folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1]
|
69 |
+
with launcher.job_array():
|
70 |
+
for sig in folder.iterdir():
|
71 |
+
if not sig.is_symlink():
|
72 |
+
continue
|
73 |
+
xp = train.main.get_xp_from_sig(sig.name)
|
74 |
+
launcher(xp.argv)
|
75 |
+
return
|
76 |
+
|
77 |
+
with launcher.job_array():
|
78 |
+
musicgen_base = launcher.bind(solver="musicgen/musicgen_base_32khz")
|
79 |
+
musicgen_base.bind_({'autocast': False, 'fsdp.use': True})
|
80 |
+
|
81 |
+
# base musicgen models
|
82 |
+
musicgen_base_small = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-small'})
|
83 |
+
eval(musicgen_base_small, batch_size=128)
|
84 |
+
|
85 |
+
musicgen_base_medium = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-medium'})
|
86 |
+
musicgen_base_medium.bind_({'model/lm/model_scale': 'medium'})
|
87 |
+
eval(musicgen_base_medium, batch_size=128)
|
88 |
+
|
89 |
+
musicgen_base_large = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-large'})
|
90 |
+
musicgen_base_large.bind_({'model/lm/model_scale': 'large'})
|
91 |
+
eval(musicgen_base_large, batch_size=128)
|
92 |
+
|
93 |
+
# melody musicgen model
|
94 |
+
musicgen_melody = launcher.bind(solver="musicgen/musicgen_melody_32khz")
|
95 |
+
musicgen_melody.bind_({'autocast': False, 'fsdp.use': True})
|
96 |
+
|
97 |
+
musicgen_melody_medium = musicgen_melody.bind({'continue_from': '//pretrained/facebook/musicgen-melody'})
|
98 |
+
musicgen_melody_medium.bind_({'model/lm/model_scale': 'medium'})
|
99 |
+
eval(musicgen_melody_medium, batch_size=128, eval_melody=True)
|
audiocraft/losses/__init__.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Loss related classes and functions. In particular the loss balancer from
|
7 |
+
EnCodec, and the usual spectral losses."""
|
8 |
+
|
9 |
+
# flake8: noqa
|
10 |
+
from .balancer import Balancer
|
11 |
+
from .sisnr import SISNR
|
12 |
+
from .stftloss import (
|
13 |
+
LogSTFTMagnitudeLoss,
|
14 |
+
MRSTFTLoss,
|
15 |
+
SpectralConvergenceLoss,
|
16 |
+
STFTLoss
|
17 |
+
)
|
18 |
+
from .specloss import (
|
19 |
+
MelSpectrogramL1Loss,
|
20 |
+
MultiScaleMelSpectrogramLoss,
|
21 |
+
)
|
audiocraft/losses/balancer.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import typing as tp
|
8 |
+
|
9 |
+
import flashy
|
10 |
+
import torch
|
11 |
+
from torch import autograd
|
12 |
+
|
13 |
+
|
14 |
+
class Balancer:
|
15 |
+
"""Loss balancer.
|
16 |
+
|
17 |
+
The loss balancer combines losses together to compute gradients for the backward.
|
18 |
+
Given `y = f(...)`, and a number of losses `l1(y, ...)`, `l2(y, ...)`, with `...`
|
19 |
+
not having any dependence on `f`, the balancer can efficiently normalize the partial gradients
|
20 |
+
`d l1 / d y`, `d l2 / dy` before summing them in order to achieve a desired ratio between
|
21 |
+
the losses. For instance if `weights = {'l1': 2, 'l2': 1}`, 66% of the gradient
|
22 |
+
going into `f(...)` will come from `l1` on average, and 33% from `l2`. This allows for an easy
|
23 |
+
interpration of the weights even if the intrisic scale of `l1`, `l2` ... is unknown.
|
24 |
+
|
25 |
+
Noting `g1 = d l1 / dy`, etc., the balanced gradient `G` will be
|
26 |
+
(with `avg` an exponential moving average over the updates),
|
27 |
+
|
28 |
+
G = sum_i total_norm * g_i / avg(||g_i||) * w_i / sum(w_i)
|
29 |
+
|
30 |
+
If `balance_grads` is False, this is deactivated, and instead the gradient will just be the
|
31 |
+
standard sum of the partial gradients with the given weights.
|
32 |
+
|
33 |
+
A call to the backward method of the balancer will compute the the partial gradients,
|
34 |
+
combining all the losses and potentially rescaling the gradients,
|
35 |
+
which can help stabilize the training and reason about multiple losses with varying scales.
|
36 |
+
The obtained gradient with respect to `y` is then back-propagated to `f(...)`.
|
37 |
+
|
38 |
+
Expected usage:
|
39 |
+
|
40 |
+
weights = {'loss_a': 1, 'loss_b': 4}
|
41 |
+
balancer = Balancer(weights, ...)
|
42 |
+
losses: dict = {}
|
43 |
+
losses['loss_a'] = compute_loss_a(x, y)
|
44 |
+
losses['loss_b'] = compute_loss_b(x, y)
|
45 |
+
if model.training():
|
46 |
+
effective_loss = balancer.backward(losses, x)
|
47 |
+
|
48 |
+
Args:
|
49 |
+
weights (dict[str, float]): Weight coefficient for each loss. The balancer expect the losses keys
|
50 |
+
from the backward method to match the weights keys to assign weight to each of the provided loss.
|
51 |
+
balance_grads (bool): Whether to rescale gradients so that weights reflect the fraction of the
|
52 |
+
overall gradient, rather than a constant multiplier.
|
53 |
+
total_norm (float): Reference norm when rescaling gradients, ignored otherwise.
|
54 |
+
emay_decay (float): EMA decay for averaging the norms.
|
55 |
+
per_batch_item (bool): Whether to compute the averaged norm per batch item or not. This only holds
|
56 |
+
when rescaling the gradients.
|
57 |
+
epsilon (float): Epsilon value for numerical stability.
|
58 |
+
monitor (bool): If True, stores in `self.metrics` the relative ratio between the norm of the gradients
|
59 |
+
coming from each loss, when calling `backward()`.
|
60 |
+
"""
|
61 |
+
def __init__(self, weights: tp.Dict[str, float], balance_grads: bool = True, total_norm: float = 1.,
|
62 |
+
ema_decay: float = 0.999, per_batch_item: bool = True, epsilon: float = 1e-12,
|
63 |
+
monitor: bool = False):
|
64 |
+
self.weights = weights
|
65 |
+
self.per_batch_item = per_batch_item
|
66 |
+
self.total_norm = total_norm or 1.
|
67 |
+
self.averager = flashy.averager(ema_decay or 1.)
|
68 |
+
self.epsilon = epsilon
|
69 |
+
self.monitor = monitor
|
70 |
+
self.balance_grads = balance_grads
|
71 |
+
self._metrics: tp.Dict[str, tp.Any] = {}
|
72 |
+
|
73 |
+
@property
|
74 |
+
def metrics(self):
|
75 |
+
return self._metrics
|
76 |
+
|
77 |
+
def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor) -> torch.Tensor:
|
78 |
+
"""Compute the backward and return the effective train loss, e.g. the loss obtained from
|
79 |
+
computing the effective weights. If `balance_grads` is True, the effective weights
|
80 |
+
are the one that needs to be applied to each gradient to respect the desired relative
|
81 |
+
scale of gradients coming from each loss.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
losses (Dict[str, torch.Tensor]): dictionary with the same keys as `self.weights`.
|
85 |
+
input (torch.Tensor): the input of the losses, typically the output of the model.
|
86 |
+
This should be the single point of dependence between the losses
|
87 |
+
and the model being trained.
|
88 |
+
"""
|
89 |
+
norms = {}
|
90 |
+
grads = {}
|
91 |
+
for name, loss in losses.items():
|
92 |
+
# Compute partial derivative of the less with respect to the input.
|
93 |
+
grad, = autograd.grad(loss, [input], retain_graph=True)
|
94 |
+
if self.per_batch_item:
|
95 |
+
# We do not average the gradient over the batch dimension.
|
96 |
+
dims = tuple(range(1, grad.dim()))
|
97 |
+
norm = grad.norm(dim=dims, p=2).mean()
|
98 |
+
else:
|
99 |
+
norm = grad.norm(p=2)
|
100 |
+
norms[name] = norm
|
101 |
+
grads[name] = grad
|
102 |
+
|
103 |
+
count = 1
|
104 |
+
if self.per_batch_item:
|
105 |
+
count = len(grad)
|
106 |
+
# Average norms across workers. Theoretically we should average the
|
107 |
+
# squared norm, then take the sqrt, but it worked fine like that.
|
108 |
+
avg_norms = flashy.distrib.average_metrics(self.averager(norms), count)
|
109 |
+
# We approximate the total norm of the gradient as the sums of the norms.
|
110 |
+
# Obviously this can be very incorrect if all gradients are aligned, but it works fine.
|
111 |
+
total = sum(avg_norms.values())
|
112 |
+
|
113 |
+
self._metrics = {}
|
114 |
+
if self.monitor:
|
115 |
+
# Store the ratio of the total gradient represented by each loss.
|
116 |
+
for k, v in avg_norms.items():
|
117 |
+
self._metrics[f'ratio_{k}'] = v / total
|
118 |
+
|
119 |
+
total_weights = sum([self.weights[k] for k in avg_norms])
|
120 |
+
assert total_weights > 0.
|
121 |
+
desired_ratios = {k: w / total_weights for k, w in self.weights.items()}
|
122 |
+
|
123 |
+
out_grad = torch.zeros_like(input)
|
124 |
+
effective_loss = torch.tensor(0., device=input.device, dtype=input.dtype)
|
125 |
+
for name, avg_norm in avg_norms.items():
|
126 |
+
if self.balance_grads:
|
127 |
+
# g_balanced = g / avg(||g||) * total_norm * desired_ratio
|
128 |
+
scale = desired_ratios[name] * self.total_norm / (self.epsilon + avg_norm)
|
129 |
+
else:
|
130 |
+
# We just do regular weighted sum of the gradients.
|
131 |
+
scale = self.weights[name]
|
132 |
+
out_grad.add_(grads[name], alpha=scale)
|
133 |
+
effective_loss += scale * losses[name].detach()
|
134 |
+
# Send the computed partial derivative with respect to the output of the model to the model.
|
135 |
+
input.backward(out_grad)
|
136 |
+
return effective_loss
|
audiocraft/losses/sisnr.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import math
|
8 |
+
import typing as tp
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch import nn
|
12 |
+
from torch.nn import functional as F
|
13 |
+
|
14 |
+
|
15 |
+
def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
|
16 |
+
"""Given input of size [*OT, T], output Tensor of size [*OT, F, K]
|
17 |
+
with K the kernel size, by extracting frames with the given stride.
|
18 |
+
This will pad the input so that `F = ceil(T / K)`.
|
19 |
+
see https://github.com/pytorch/pytorch/issues/60466
|
20 |
+
"""
|
21 |
+
*shape, length = a.shape
|
22 |
+
n_frames = math.ceil(length / stride)
|
23 |
+
tgt_length = (n_frames - 1) * stride + kernel_size
|
24 |
+
a = F.pad(a, (0, tgt_length - length))
|
25 |
+
strides = list(a.stride())
|
26 |
+
assert strides[-1] == 1, "data should be contiguous"
|
27 |
+
strides = strides[:-1] + [stride, 1]
|
28 |
+
return a.as_strided([*shape, n_frames, kernel_size], strides)
|
29 |
+
|
30 |
+
|
31 |
+
def _center(x: torch.Tensor) -> torch.Tensor:
|
32 |
+
return x - x.mean(-1, True)
|
33 |
+
|
34 |
+
|
35 |
+
def _norm2(x: torch.Tensor) -> torch.Tensor:
|
36 |
+
return x.pow(2).sum(-1, True)
|
37 |
+
|
38 |
+
|
39 |
+
class SISNR(nn.Module):
|
40 |
+
"""SISNR loss.
|
41 |
+
|
42 |
+
Input should be [B, C, T], output is scalar.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
sample_rate (int): Sample rate.
|
46 |
+
segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on
|
47 |
+
entire audio only.
|
48 |
+
overlap (float): Overlap between chunks, i.e. 0.5 = 50 % overlap.
|
49 |
+
epsilon (float): Epsilon value for numerical stability.
|
50 |
+
"""
|
51 |
+
def __init__(
|
52 |
+
self,
|
53 |
+
sample_rate: int = 16000,
|
54 |
+
segment: tp.Optional[float] = 20,
|
55 |
+
overlap: float = 0.5,
|
56 |
+
epsilon: float = torch.finfo(torch.float32).eps,
|
57 |
+
):
|
58 |
+
super().__init__()
|
59 |
+
self.sample_rate = sample_rate
|
60 |
+
self.segment = segment
|
61 |
+
self.overlap = overlap
|
62 |
+
self.epsilon = epsilon
|
63 |
+
|
64 |
+
def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
|
65 |
+
B, C, T = ref_sig.shape
|
66 |
+
assert ref_sig.shape == out_sig.shape
|
67 |
+
|
68 |
+
if self.segment is None:
|
69 |
+
frame = T
|
70 |
+
stride = T
|
71 |
+
else:
|
72 |
+
frame = int(self.segment * self.sample_rate)
|
73 |
+
stride = int(frame * (1 - self.overlap))
|
74 |
+
|
75 |
+
epsilon = self.epsilon * frame # make epsilon prop to frame size.
|
76 |
+
|
77 |
+
gt = _unfold(ref_sig, frame, stride)
|
78 |
+
est = _unfold(out_sig, frame, stride)
|
79 |
+
if self.segment is None:
|
80 |
+
assert gt.shape[-1] == 1
|
81 |
+
|
82 |
+
gt = _center(gt)
|
83 |
+
est = _center(est)
|
84 |
+
dot = torch.einsum("bcft,bcft->bcf", gt, est)
|
85 |
+
|
86 |
+
proj = dot[:, :, :, None] * gt / (epsilon + _norm2(gt))
|
87 |
+
noise = est - proj
|
88 |
+
|
89 |
+
sisnr = 10 * (
|
90 |
+
torch.log10(epsilon + _norm2(proj)) - torch.log10(epsilon + _norm2(noise))
|
91 |
+
)
|
92 |
+
return -1 * sisnr[..., 0].mean()
|
audiocraft/losses/specloss.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import typing as tp
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
from torchaudio.transforms import MelSpectrogram
|
11 |
+
import torch
|
12 |
+
from torch import nn
|
13 |
+
from torch.nn import functional as F
|
14 |
+
|
15 |
+
from ..modules import pad_for_conv1d
|
16 |
+
|
17 |
+
|
18 |
+
class MelSpectrogramWrapper(nn.Module):
|
19 |
+
"""Wrapper around MelSpectrogram torchaudio transform providing proper padding
|
20 |
+
and additional post-processing including log scaling.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
n_mels (int): Number of mel bins.
|
24 |
+
n_fft (int): Number of fft.
|
25 |
+
hop_length (int): Hop size.
|
26 |
+
win_length (int): Window length.
|
27 |
+
n_mels (int): Number of mel bins.
|
28 |
+
sample_rate (int): Sample rate.
|
29 |
+
f_min (float or None): Minimum frequency.
|
30 |
+
f_max (float or None): Maximum frequency.
|
31 |
+
log (bool): Whether to scale with log.
|
32 |
+
normalized (bool): Whether to normalize the melspectrogram.
|
33 |
+
floor_level (float): Floor level based on human perception (default=1e-5).
|
34 |
+
"""
|
35 |
+
def __init__(self, n_fft: int = 1024, hop_length: int = 256, win_length: tp.Optional[int] = None,
|
36 |
+
n_mels: int = 80, sample_rate: float = 22050, f_min: float = 0.0, f_max: tp.Optional[float] = None,
|
37 |
+
log: bool = True, normalized: bool = False, floor_level: float = 1e-5):
|
38 |
+
super().__init__()
|
39 |
+
self.n_fft = n_fft
|
40 |
+
hop_length = int(hop_length)
|
41 |
+
self.hop_length = hop_length
|
42 |
+
self.mel_transform = MelSpectrogram(n_mels=n_mels, sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
|
43 |
+
win_length=win_length, f_min=f_min, f_max=f_max, normalized=normalized,
|
44 |
+
window_fn=torch.hann_window, center=False)
|
45 |
+
self.floor_level = floor_level
|
46 |
+
self.log = log
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
p = int((self.n_fft - self.hop_length) // 2)
|
50 |
+
if len(x.shape) == 2:
|
51 |
+
x = x.unsqueeze(1)
|
52 |
+
x = F.pad(x, (p, p), "reflect")
|
53 |
+
# Make sure that all the frames are full.
|
54 |
+
# The combination of `pad_for_conv1d` and the above padding
|
55 |
+
# will make the output of size ceil(T / hop).
|
56 |
+
x = pad_for_conv1d(x, self.n_fft, self.hop_length)
|
57 |
+
self.mel_transform.to(x.device)
|
58 |
+
mel_spec = self.mel_transform(x)
|
59 |
+
B, C, freqs, frame = mel_spec.shape
|
60 |
+
if self.log:
|
61 |
+
mel_spec = torch.log10(self.floor_level + mel_spec)
|
62 |
+
return mel_spec.reshape(B, C * freqs, frame)
|
63 |
+
|
64 |
+
|
65 |
+
class MelSpectrogramL1Loss(torch.nn.Module):
|
66 |
+
"""L1 Loss on MelSpectrogram.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
sample_rate (int): Sample rate.
|
70 |
+
n_fft (int): Number of fft.
|
71 |
+
hop_length (int): Hop size.
|
72 |
+
win_length (int): Window length.
|
73 |
+
n_mels (int): Number of mel bins.
|
74 |
+
f_min (float or None): Minimum frequency.
|
75 |
+
f_max (float or None): Maximum frequency.
|
76 |
+
log (bool): Whether to scale with log.
|
77 |
+
normalized (bool): Whether to normalize the melspectrogram.
|
78 |
+
floor_level (float): Floor level value based on human perception (default=1e-5).
|
79 |
+
"""
|
80 |
+
def __init__(self, sample_rate: int, n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024,
|
81 |
+
n_mels: int = 80, f_min: float = 0.0, f_max: tp.Optional[float] = None,
|
82 |
+
log: bool = True, normalized: bool = False, floor_level: float = 1e-5):
|
83 |
+
super().__init__()
|
84 |
+
self.l1 = torch.nn.L1Loss()
|
85 |
+
self.melspec = MelSpectrogramWrapper(n_fft=n_fft, hop_length=hop_length, win_length=win_length,
|
86 |
+
n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
|
87 |
+
log=log, normalized=normalized, floor_level=floor_level)
|
88 |
+
|
89 |
+
def forward(self, x, y):
|
90 |
+
self.melspec.to(x.device)
|
91 |
+
s_x = self.melspec(x)
|
92 |
+
s_y = self.melspec(y)
|
93 |
+
return self.l1(s_x, s_y)
|
94 |
+
|
95 |
+
|
96 |
+
class MultiScaleMelSpectrogramLoss(nn.Module):
|
97 |
+
"""Multi-Scale spectrogram loss (msspec).
|
98 |
+
|
99 |
+
Args:
|
100 |
+
sample_rate (int): Sample rate.
|
101 |
+
range_start (int): Power of 2 to use for the first scale.
|
102 |
+
range_stop (int): Power of 2 to use for the last scale.
|
103 |
+
n_mels (int): Number of mel bins.
|
104 |
+
f_min (float): Minimum frequency.
|
105 |
+
f_max (float or None): Maximum frequency.
|
106 |
+
normalized (bool): Whether to normalize the melspectrogram.
|
107 |
+
alphas (bool): Whether to use alphas as coefficients or not.
|
108 |
+
floor_level (float): Floor level value based on human perception (default=1e-5).
|
109 |
+
"""
|
110 |
+
def __init__(self, sample_rate: int, range_start: int = 6, range_end: int = 11,
|
111 |
+
n_mels: int = 64, f_min: float = 0.0, f_max: tp.Optional[float] = None,
|
112 |
+
normalized: bool = False, alphas: bool = True, floor_level: float = 1e-5):
|
113 |
+
super().__init__()
|
114 |
+
l1s = list()
|
115 |
+
l2s = list()
|
116 |
+
self.alphas = list()
|
117 |
+
self.total = 0
|
118 |
+
self.normalized = normalized
|
119 |
+
for i in range(range_start, range_end):
|
120 |
+
l1s.append(
|
121 |
+
MelSpectrogramWrapper(n_fft=2 ** i, hop_length=(2 ** i) / 4, win_length=2 ** i,
|
122 |
+
n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
|
123 |
+
log=False, normalized=normalized, floor_level=floor_level))
|
124 |
+
l2s.append(
|
125 |
+
MelSpectrogramWrapper(n_fft=2 ** i, hop_length=(2 ** i) / 4, win_length=2 ** i,
|
126 |
+
n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
|
127 |
+
log=True, normalized=normalized, floor_level=floor_level))
|
128 |
+
if alphas:
|
129 |
+
self.alphas.append(np.sqrt(2 ** i - 1))
|
130 |
+
else:
|
131 |
+
self.alphas.append(1)
|
132 |
+
self.total += self.alphas[-1] + 1
|
133 |
+
|
134 |
+
self.l1s = nn.ModuleList(l1s)
|
135 |
+
self.l2s = nn.ModuleList(l2s)
|
136 |
+
|
137 |
+
def forward(self, x, y):
|
138 |
+
loss = 0.0
|
139 |
+
self.l1s.to(x.device)
|
140 |
+
self.l2s.to(x.device)
|
141 |
+
for i in range(len(self.alphas)):
|
142 |
+
s_x_1 = self.l1s[i](x)
|
143 |
+
s_y_1 = self.l1s[i](y)
|
144 |
+
s_x_2 = self.l2s[i](x)
|
145 |
+
s_y_2 = self.l2s[i](y)
|
146 |
+
loss += F.l1_loss(s_x_1, s_y_1) + self.alphas[i] * F.mse_loss(s_x_2, s_y_2)
|
147 |
+
if self.normalized:
|
148 |
+
loss = loss / self.total
|
149 |
+
return loss
|
audiocraft/losses/stftloss.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# Adapted from MIT code under the original license
|
7 |
+
# Copyright 2019 Tomoki Hayashi
|
8 |
+
# MIT License (https://opensource.org/licenses/MIT)
|
9 |
+
import typing as tp
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from torch import nn
|
13 |
+
from torch.nn import functional as F
|
14 |
+
|
15 |
+
|
16 |
+
# TODO: Replace with torchaudio.STFT?
|
17 |
+
def _stft(x: torch.Tensor, fft_size: int, hop_length: int, win_length: int,
|
18 |
+
window: tp.Optional[torch.Tensor], normalized: bool) -> torch.Tensor:
|
19 |
+
"""Perform STFT and convert to magnitude spectrogram.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
x: Input signal tensor (B, C, T).
|
23 |
+
fft_size (int): FFT size.
|
24 |
+
hop_length (int): Hop size.
|
25 |
+
win_length (int): Window length.
|
26 |
+
window (torch.Tensor or None): Window function type.
|
27 |
+
normalized (bool): Whether to normalize the STFT or not.
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
torch.Tensor: Magnitude spectrogram (B, C, #frames, fft_size // 2 + 1).
|
31 |
+
"""
|
32 |
+
B, C, T = x.shape
|
33 |
+
x_stft = torch.stft(
|
34 |
+
x.view(-1, T), fft_size, hop_length, win_length, window,
|
35 |
+
normalized=normalized, return_complex=True,
|
36 |
+
)
|
37 |
+
x_stft = x_stft.view(B, C, *x_stft.shape[1:])
|
38 |
+
real = x_stft.real
|
39 |
+
imag = x_stft.imag
|
40 |
+
|
41 |
+
# NOTE(kan-bayashi): clamp is needed to avoid nan or inf
|
42 |
+
return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
|
43 |
+
|
44 |
+
|
45 |
+
class SpectralConvergenceLoss(nn.Module):
|
46 |
+
"""Spectral convergence loss.
|
47 |
+
"""
|
48 |
+
def __init__(self, epsilon: float = torch.finfo(torch.float32).eps):
|
49 |
+
super().__init__()
|
50 |
+
self.epsilon = epsilon
|
51 |
+
|
52 |
+
def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
|
53 |
+
"""Calculate forward propagation.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
x_mag: Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
57 |
+
y_mag: Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
58 |
+
Returns:
|
59 |
+
torch.Tensor: Spectral convergence loss value.
|
60 |
+
"""
|
61 |
+
return torch.norm(y_mag - x_mag, p="fro") / (torch.norm(y_mag, p="fro") + self.epsilon)
|
62 |
+
|
63 |
+
|
64 |
+
class LogSTFTMagnitudeLoss(nn.Module):
|
65 |
+
"""Log STFT magnitude loss.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
epsilon (float): Epsilon value for numerical stability.
|
69 |
+
"""
|
70 |
+
def __init__(self, epsilon: float = torch.finfo(torch.float32).eps):
|
71 |
+
super().__init__()
|
72 |
+
self.epsilon = epsilon
|
73 |
+
|
74 |
+
def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
|
75 |
+
"""Calculate forward propagation.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
x_mag (torch.Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
79 |
+
y_mag (torch.Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
80 |
+
Returns:
|
81 |
+
torch.Tensor: Log STFT magnitude loss value.
|
82 |
+
"""
|
83 |
+
return F.l1_loss(torch.log(self.epsilon + y_mag), torch.log(self.epsilon + x_mag))
|
84 |
+
|
85 |
+
|
86 |
+
class STFTLosses(nn.Module):
|
87 |
+
"""STFT losses.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
n_fft (int): Size of FFT.
|
91 |
+
hop_length (int): Hop length.
|
92 |
+
win_length (int): Window length.
|
93 |
+
window (str): Window function type.
|
94 |
+
normalized (bool): Whether to use normalized STFT or not.
|
95 |
+
epsilon (float): Epsilon for numerical stability.
|
96 |
+
"""
|
97 |
+
def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_length: int = 600,
|
98 |
+
window: str = "hann_window", normalized: bool = False,
|
99 |
+
epsilon: float = torch.finfo(torch.float32).eps):
|
100 |
+
super().__init__()
|
101 |
+
self.n_fft = n_fft
|
102 |
+
self.hop_length = hop_length
|
103 |
+
self.win_length = win_length
|
104 |
+
self.normalized = normalized
|
105 |
+
self.register_buffer("window", getattr(torch, window)(win_length))
|
106 |
+
self.spectral_convergenge_loss = SpectralConvergenceLoss(epsilon)
|
107 |
+
self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss(epsilon)
|
108 |
+
|
109 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
110 |
+
"""Calculate forward propagation.
|
111 |
+
|
112 |
+
Args:
|
113 |
+
x (torch.Tensor): Predicted signal (B, T).
|
114 |
+
y (torch.Tensor): Groundtruth signal (B, T).
|
115 |
+
Returns:
|
116 |
+
torch.Tensor: Spectral convergence loss value.
|
117 |
+
torch.Tensor: Log STFT magnitude loss value.
|
118 |
+
"""
|
119 |
+
x_mag = _stft(x, self.n_fft, self.hop_length,
|
120 |
+
self.win_length, self.window, self.normalized) # type: ignore
|
121 |
+
y_mag = _stft(y, self.n_fft, self.hop_length,
|
122 |
+
self.win_length, self.window, self.normalized) # type: ignore
|
123 |
+
sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
|
124 |
+
mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
|
125 |
+
|
126 |
+
return sc_loss, mag_loss
|
127 |
+
|
128 |
+
|
129 |
+
class STFTLoss(nn.Module):
|
130 |
+
"""Single Resolution STFT loss.
|
131 |
+
|
132 |
+
Args:
|
133 |
+
n_fft (int): Nb of FFT.
|
134 |
+
hop_length (int): Hop length.
|
135 |
+
win_length (int): Window length.
|
136 |
+
window (str): Window function type.
|
137 |
+
normalized (bool): Whether to use normalized STFT or not.
|
138 |
+
epsilon (float): Epsilon for numerical stability.
|
139 |
+
factor_sc (float): Coefficient for the spectral loss.
|
140 |
+
factor_mag (float): Coefficient for the magnitude loss.
|
141 |
+
"""
|
142 |
+
def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_length: int = 600,
|
143 |
+
window: str = "hann_window", normalized: bool = False,
|
144 |
+
factor_sc: float = 0.1, factor_mag: float = 0.1,
|
145 |
+
epsilon: float = torch.finfo(torch.float32).eps):
|
146 |
+
super().__init__()
|
147 |
+
self.loss = STFTLosses(n_fft, hop_length, win_length, window, normalized, epsilon)
|
148 |
+
self.factor_sc = factor_sc
|
149 |
+
self.factor_mag = factor_mag
|
150 |
+
|
151 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
152 |
+
"""Calculate forward propagation.
|
153 |
+
|
154 |
+
Args:
|
155 |
+
x (torch.Tensor): Predicted signal (B, T).
|
156 |
+
y (torch.Tensor): Groundtruth signal (B, T).
|
157 |
+
Returns:
|
158 |
+
torch.Tensor: Single resolution STFT loss.
|
159 |
+
"""
|
160 |
+
sc_loss, mag_loss = self.loss(x, y)
|
161 |
+
return self.factor_sc * sc_loss + self.factor_mag * mag_loss
|
162 |
+
|
163 |
+
|
164 |
+
class MRSTFTLoss(nn.Module):
|
165 |
+
"""Multi resolution STFT loss.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
n_ffts (Sequence[int]): Sequence of FFT sizes.
|
169 |
+
hop_lengths (Sequence[int]): Sequence of hop sizes.
|
170 |
+
win_lengths (Sequence[int]): Sequence of window lengths.
|
171 |
+
window (str): Window function type.
|
172 |
+
factor_sc (float): Coefficient for the spectral loss.
|
173 |
+
factor_mag (float): Coefficient for the magnitude loss.
|
174 |
+
normalized (bool): Whether to use normalized STFT or not.
|
175 |
+
epsilon (float): Epsilon for numerical stability.
|
176 |
+
"""
|
177 |
+
def __init__(self, n_ffts: tp.Sequence[int] = [1024, 2048, 512], hop_lengths: tp.Sequence[int] = [120, 240, 50],
|
178 |
+
win_lengths: tp.Sequence[int] = [600, 1200, 240], window: str = "hann_window",
|
179 |
+
factor_sc: float = 0.1, factor_mag: float = 0.1,
|
180 |
+
normalized: bool = False, epsilon: float = torch.finfo(torch.float32).eps):
|
181 |
+
super().__init__()
|
182 |
+
assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
|
183 |
+
self.stft_losses = torch.nn.ModuleList()
|
184 |
+
for fs, ss, wl in zip(n_ffts, hop_lengths, win_lengths):
|
185 |
+
self.stft_losses += [STFTLosses(fs, ss, wl, window, normalized, epsilon)]
|
186 |
+
self.factor_sc = factor_sc
|
187 |
+
self.factor_mag = factor_mag
|
188 |
+
|
189 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
190 |
+
"""Calculate forward propagation.
|
191 |
+
|
192 |
+
Args:
|
193 |
+
x (torch.Tensor): Predicted signal (B, T).
|
194 |
+
y (torch.Tensor): Groundtruth signal (B, T).
|
195 |
+
Returns:
|
196 |
+
torch.Tensor: Multi resolution STFT loss.
|
197 |
+
"""
|
198 |
+
sc_loss = torch.Tensor([0.0])
|
199 |
+
mag_loss = torch.Tensor([0.0])
|
200 |
+
for f in self.stft_losses:
|
201 |
+
sc_l, mag_l = f(x, y)
|
202 |
+
sc_loss += sc_l
|
203 |
+
mag_loss += mag_l
|
204 |
+
sc_loss /= len(self.stft_losses)
|
205 |
+
mag_loss /= len(self.stft_losses)
|
206 |
+
|
207 |
+
return self.factor_sc * sc_loss + self.factor_mag * mag_loss
|
audiocraft/metrics/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Metrics like CLAP score, FAD, KLD, Visqol, Chroma similarity, etc.
|
7 |
+
"""
|
8 |
+
# flake8: noqa
|
9 |
+
from .clap_consistency import CLAPTextConsistencyMetric, TextConsistencyMetric
|
10 |
+
from .chroma_cosinesim import ChromaCosineSimilarityMetric
|
11 |
+
from .fad import FrechetAudioDistanceMetric
|
12 |
+
from .kld import KLDivergenceMetric, PasstKLDivergenceMetric
|
13 |
+
from .rvm import RelativeVolumeMel
|
14 |
+
from .visqol import ViSQOL
|
audiocraft/metrics/chroma_cosinesim.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torchmetrics
|
9 |
+
|
10 |
+
from ..data.audio_utils import convert_audio
|
11 |
+
from ..modules.chroma import ChromaExtractor
|
12 |
+
|
13 |
+
|
14 |
+
class ChromaCosineSimilarityMetric(torchmetrics.Metric):
|
15 |
+
"""Chroma cosine similarity metric.
|
16 |
+
|
17 |
+
This metric extracts a chromagram for a reference waveform and
|
18 |
+
a generated waveform and compares each frame using the cosine similarity
|
19 |
+
function. The output is the mean cosine similarity.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
sample_rate (int): Sample rate used by the chroma extractor.
|
23 |
+
n_chroma (int): Number of chroma used by the chroma extractor.
|
24 |
+
radix2_exp (int): Exponent for the chroma extractor.
|
25 |
+
argmax (bool): Whether the chroma extractor uses argmax.
|
26 |
+
eps (float): Epsilon for cosine similarity computation.
|
27 |
+
"""
|
28 |
+
def __init__(self, sample_rate: int, n_chroma: int, radix2_exp: int, argmax: bool, eps: float = 1e-8):
|
29 |
+
super().__init__()
|
30 |
+
self.chroma_sample_rate = sample_rate
|
31 |
+
self.n_chroma = n_chroma
|
32 |
+
self.eps = eps
|
33 |
+
self.chroma_extractor = ChromaExtractor(sample_rate=self.chroma_sample_rate, n_chroma=self.n_chroma,
|
34 |
+
radix2_exp=radix2_exp, argmax=argmax)
|
35 |
+
self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
|
36 |
+
self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")
|
37 |
+
|
38 |
+
def update(self, preds: torch.Tensor, targets: torch.Tensor,
|
39 |
+
sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
|
40 |
+
"""Compute cosine similarity between chromagrams and accumulate scores over the dataset."""
|
41 |
+
if preds.size(0) == 0:
|
42 |
+
return
|
43 |
+
|
44 |
+
assert preds.shape == targets.shape, (
|
45 |
+
f"Preds and target shapes mismatch: preds={preds.shape}, targets={targets.shape}")
|
46 |
+
assert preds.size(0) == sizes.size(0), (
|
47 |
+
f"Number of items in preds ({preds.shape}) mismatch ",
|
48 |
+
f"with sizes ({sizes.shape})")
|
49 |
+
assert preds.size(0) == sample_rates.size(0), (
|
50 |
+
f"Number of items in preds ({preds.shape}) mismatch ",
|
51 |
+
f"with sample_rates ({sample_rates.shape})")
|
52 |
+
assert torch.all(sample_rates == sample_rates[0].item()), "All sample rates are not the same in the batch"
|
53 |
+
|
54 |
+
device = self.weight.device
|
55 |
+
preds, targets = preds.to(device), targets.to(device) # type: ignore
|
56 |
+
sample_rate = sample_rates[0].item()
|
57 |
+
preds = convert_audio(preds, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
|
58 |
+
targets = convert_audio(targets, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
|
59 |
+
gt_chroma = self.chroma_extractor(targets)
|
60 |
+
gen_chroma = self.chroma_extractor(preds)
|
61 |
+
chroma_lens = (sizes / self.chroma_extractor.winhop).ceil().int()
|
62 |
+
for i in range(len(gt_chroma)):
|
63 |
+
t = int(chroma_lens[i].item())
|
64 |
+
cosine_sim = torch.nn.functional.cosine_similarity(
|
65 |
+
gt_chroma[i, :t], gen_chroma[i, :t], dim=1, eps=self.eps)
|
66 |
+
self.cosine_sum += cosine_sim.sum(dim=0) # type: ignore
|
67 |
+
self.weight += torch.tensor(t) # type: ignore
|
68 |
+
|
69 |
+
def compute(self) -> float:
|
70 |
+
"""Computes the average cosine similarty across all generated/target chromagrams pairs."""
|
71 |
+
assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0" # type: ignore
|
72 |
+
return (self.cosine_sum / self.weight).item() # type: ignore
|
audiocraft/metrics/clap_consistency.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from pathlib import Path
|
8 |
+
import typing as tp
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torchmetrics
|
12 |
+
from transformers import RobertaTokenizer # type: ignore
|
13 |
+
|
14 |
+
from ..data.audio_utils import convert_audio
|
15 |
+
from ..environment import AudioCraftEnvironment
|
16 |
+
from ..utils.utils import load_clap_state_dict
|
17 |
+
|
18 |
+
try:
|
19 |
+
import laion_clap # type: ignore
|
20 |
+
except ImportError:
|
21 |
+
laion_clap = None
|
22 |
+
|
23 |
+
|
24 |
+
class TextConsistencyMetric(torchmetrics.Metric):
|
25 |
+
"""Text consistency metric measuring consistency between audio and text pairs."""
|
26 |
+
|
27 |
+
def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
|
28 |
+
raise NotImplementedError("implement how to update the metric from the audio and text pairs.")
|
29 |
+
|
30 |
+
def compute(self):
|
31 |
+
raise NotImplementedError("implement how to compute the final metric score.")
|
32 |
+
|
33 |
+
|
34 |
+
class CLAPTextConsistencyMetric(TextConsistencyMetric):
|
35 |
+
"""Text consistency metric relying on Contrastive Language-Audio Pretraining (CLAP).
|
36 |
+
|
37 |
+
This metric is similar to the MuLan Cycle Consistency from MusicLM (https://arxiv.org/pdf/2301.11325.pdf)
|
38 |
+
or the CLAP score used in Make-An-Audio (https://arxiv.org/pdf/2301.12661v1.pdf).
|
39 |
+
|
40 |
+
As a joint audio-text embedding model, a pretrained CLAP model can be used to quantify the
|
41 |
+
similarity between audio-text pairs. We compute the CLAP embeddings from the text descriptions as
|
42 |
+
well as the generated audio based on them, and define the MCC metric as the average cosine similarity
|
43 |
+
between these embeddings.
|
44 |
+
|
45 |
+
Model implementation & pre-trained checkpoints: https://github.com/LAION-AI/CLAP
|
46 |
+
"""
|
47 |
+
def __init__(self, model_path: tp.Union[str, Path], model_arch: str = 'HTSAT-tiny', enable_fusion: bool = False):
|
48 |
+
super().__init__()
|
49 |
+
if laion_clap is None:
|
50 |
+
raise ImportError("Please install CLAP to compute text consistency: 'pip install laion_clap'")
|
51 |
+
self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
|
52 |
+
self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")
|
53 |
+
self._initialize_model(model_path, model_arch, enable_fusion)
|
54 |
+
|
55 |
+
def _initialize_model(self, model_path: tp.Union[str, Path], model_arch: str, enable_fusion: bool):
|
56 |
+
model_path = AudioCraftEnvironment.resolve_reference_path(model_path)
|
57 |
+
self.tokenize = RobertaTokenizer.from_pretrained('roberta-base')
|
58 |
+
self.model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
|
59 |
+
self.model_sample_rate = 48_000
|
60 |
+
load_clap_state_dict(self.model, model_path)
|
61 |
+
self.model.eval()
|
62 |
+
|
63 |
+
def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
|
64 |
+
# we use the default params from CLAP module here as well
|
65 |
+
return self.tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
|
66 |
+
|
67 |
+
def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
|
68 |
+
"""Compute cosine similarity between audio and text pairs and accumulate scores over the dataset."""
|
69 |
+
assert audio.size(0) == len(text), "Number of audio and text samples should match"
|
70 |
+
assert torch.all(sample_rates == sample_rates[0].item()), "All items in batch should have the same sample rate"
|
71 |
+
sample_rate = int(sample_rates[0].item())
|
72 |
+
# convert audio batch to 48kHz monophonic audio with no channel dimension: [B, C, T] -> [B, T]
|
73 |
+
audio = convert_audio(audio, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1).mean(dim=1)
|
74 |
+
audio_embeddings = self.model.get_audio_embedding_from_data(audio, use_tensor=True)
|
75 |
+
text_embeddings = self.model.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
|
76 |
+
# cosine similarity between the text and the audio embedding
|
77 |
+
cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_embeddings, dim=1, eps=1e-8)
|
78 |
+
self.cosine_sum += cosine_sim.sum(dim=0)
|
79 |
+
self.weight += torch.tensor(cosine_sim.size(0))
|
80 |
+
|
81 |
+
def compute(self):
|
82 |
+
"""Computes the average cosine similarty across all audio/text pairs."""
|
83 |
+
assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0" # type: ignore
|
84 |
+
return (self.cosine_sum / self.weight).item() # type: ignore
|
audiocraft/metrics/fad.py
ADDED
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import logging
|
8 |
+
from pathlib import Path
|
9 |
+
import os
|
10 |
+
import subprocess
|
11 |
+
import tempfile
|
12 |
+
import typing as tp
|
13 |
+
|
14 |
+
from audiocraft.data.audio import audio_write
|
15 |
+
from audiocraft.data.audio_utils import convert_audio
|
16 |
+
import flashy
|
17 |
+
import torch
|
18 |
+
import torchmetrics
|
19 |
+
|
20 |
+
from ..environment import AudioCraftEnvironment
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
VGGISH_SAMPLE_RATE = 16_000
|
26 |
+
VGGISH_CHANNELS = 1
|
27 |
+
|
28 |
+
|
29 |
+
class FrechetAudioDistanceMetric(torchmetrics.Metric):
|
30 |
+
"""Fréchet Audio Distance computation based on official TensorFlow implementation from Google Research.
|
31 |
+
|
32 |
+
From: D.C. Dowson & B.V. Landau The Fréchet distance between
|
33 |
+
multivariate normal distributions
|
34 |
+
https://doi.org/10.1016/0047-259X(82)90077-X
|
35 |
+
The Fréchet distance between two multivariate gaussians,
|
36 |
+
`X ~ N(mu_x, sigma_x)` and `Y ~ N(mu_y, sigma_y)`, is `d^2`.
|
37 |
+
d^2 = (mu_x - mu_y)^2 + Tr(sigma_x + sigma_y - 2 * sqrt(sigma_x*sigma_y))
|
38 |
+
= (mu_x - mu_y)^2 + Tr(sigma_x) + Tr(sigma_y)
|
39 |
+
- 2 * Tr(sqrt(sigma_x*sigma_y)))
|
40 |
+
|
41 |
+
To use this FAD computation metric, you need to have the proper Frechet Audio Distance tool setup
|
42 |
+
from: https://github.com/google-research/google-research/tree/master/frechet_audio_distance
|
43 |
+
We provide the below instructions as reference but we do not guarantee for further support
|
44 |
+
in frechet_audio_distance installation. This was tested with python 3.10, cuda 11.8, tensorflow 2.12.0.
|
45 |
+
|
46 |
+
We recommend installing the frechet_audio_distance library in a dedicated env (e.g. conda).
|
47 |
+
|
48 |
+
1. Get the code and models following the repository instructions. We used the steps below:
|
49 |
+
git clone git@github.com:google-research/google-research.git
|
50 |
+
git clone git@github.com:tensorflow/models.git
|
51 |
+
mkdir google-research/tensorflow_models
|
52 |
+
touch google-research/tensorflow_models/__init__.py
|
53 |
+
cp -r models/research/audioset google-research/tensorflow_models/
|
54 |
+
touch google-research/tensorflow_models/audioset/__init__.py
|
55 |
+
echo "from .vggish import mel_features, vggish_params, vggish_slim" > \
|
56 |
+
google-research/tensorflow_models/audioset/__init__.py
|
57 |
+
# we can now remove the tensorflow models repository
|
58 |
+
# rm -r models
|
59 |
+
cd google-research
|
60 |
+
Follow the instructions to download the vggish checkpoint. AudioCraft base configuration
|
61 |
+
assumes it is placed in the AudioCraft reference dir.
|
62 |
+
|
63 |
+
Note that we operate the following changes for the code to work with TensorFlow 2.X and python 3:
|
64 |
+
- Update xrange for range in:
|
65 |
+
https://github.com/google-research/google-research/blob/master/frechet_audio_distance/audioset_model.py
|
66 |
+
- Update `tf_record = tf.python_io.tf_record_iterator(filename).next()` to
|
67 |
+
`tf_record = tf.python_io.tf_record_iterator(filename).__next__()` in
|
68 |
+
https://github.com/google-research/google-research/blob/master/frechet_audio_distance/fad_utils.py
|
69 |
+
- Update `import vggish_params as params` to `from . import vggish_params as params` in:
|
70 |
+
https://github.com/tensorflow/models/blob/master/research/audioset/vggish/vggish_slim.py
|
71 |
+
- Add flag to provide a given batch size for running the AudioSet model in:
|
72 |
+
https://github.com/google-research/google-research/blob/master/frechet_audio_distance/create_embeddings_main.py
|
73 |
+
```
|
74 |
+
flags.DEFINE_integer('batch_size', 64,
|
75 |
+
'Number of samples in the batch for AudioSet model.')
|
76 |
+
```
|
77 |
+
Ensure you pass the flag to the create_embeddings_beam.create_pipeline function, adding:
|
78 |
+
`batch_size=FLAGS.batch_size` to the provided parameters.
|
79 |
+
|
80 |
+
2. Follow instructions for the library installation and a valid TensorFlow installation
|
81 |
+
```
|
82 |
+
# e.g. instructions from: https://www.tensorflow.org/install/pip
|
83 |
+
conda install -c conda-forge cudatoolkit=11.8.0
|
84 |
+
python3 -m pip install nvidia-cudnn-cu11==8.6.0.163 tensorflow==2.12.*
|
85 |
+
mkdir -p $CONDA_PREFIX/etc/conda/activate.d
|
86 |
+
echo 'CUDNN_PATH=$(dirname $(python -c "import nvidia.cudnn;print(nvidia.cudnn.__file__)"))' \
|
87 |
+
>> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
|
88 |
+
echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/:$CUDNN_PATH/lib' \
|
89 |
+
>> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
|
90 |
+
source $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
|
91 |
+
# Verify install: on a machine with GPU device
|
92 |
+
python3 -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"
|
93 |
+
```
|
94 |
+
|
95 |
+
Now install frechet_audio_distance required dependencies:
|
96 |
+
```
|
97 |
+
# We assume we already have TensorFlow installed from the above steps
|
98 |
+
pip install apache-beam numpy scipy tf_slim
|
99 |
+
```
|
100 |
+
|
101 |
+
Finally, follow remaining library instructions to ensure you have a working frechet_audio_distance setup
|
102 |
+
(you may want to specify --model_ckpt flag pointing to the model's path).
|
103 |
+
|
104 |
+
3. AudioCraft's FrechetAudioDistanceMetric requires 2 environment variables pointing to the python executable
|
105 |
+
and Tensorflow library path from the above installation steps:
|
106 |
+
export TF_PYTHON_EXE="<PATH_TO_THE_ENV_PYTHON_BINARY>"
|
107 |
+
export TF_LIBRARY_PATH="<PATH_TO_THE_ENV_CUDNN_LIBRARY>"
|
108 |
+
|
109 |
+
e.g. assuming we have installed everything in a dedicated conda env
|
110 |
+
with python 3.10 that is currently active:
|
111 |
+
export TF_PYTHON_EXE="$CONDA_PREFIX/bin/python"
|
112 |
+
export TF_LIBRARY_PATH="$CONDA_PREFIX/lib/python3.10/site-packages/nvidia/cudnn/lib"
|
113 |
+
|
114 |
+
Finally you may want to export the following variable:
|
115 |
+
export TF_FORCE_GPU_ALLOW_GROWTH=true
|
116 |
+
See: https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth
|
117 |
+
|
118 |
+
You can save those environment variables in your training conda env, when currently active:
|
119 |
+
`$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh`
|
120 |
+
e.g. assuming the env with TensorFlow and frechet_audio_distance install is named ac_eval,
|
121 |
+
and the training conda env is named audiocraft:
|
122 |
+
```
|
123 |
+
# activate training env
|
124 |
+
conda activate audiocraft
|
125 |
+
# get path to all envs
|
126 |
+
CONDA_ENV_DIR=$(dirname $CONDA_PREFIX)
|
127 |
+
# export pointers to evaluation env for using TensorFlow in FrechetAudioDistanceMetric
|
128 |
+
touch $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
|
129 |
+
echo 'export TF_PYTHON_EXE="$CONDA_ENV_DIR/ac_eval/bin/python"' >> \
|
130 |
+
$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
|
131 |
+
echo 'export TF_LIBRARY_PATH="$CONDA_ENV_DIR/ac_eval/lib/python3.10/site-packages/nvidia/cudnn/lib"' >> \
|
132 |
+
$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
|
133 |
+
# optionally:
|
134 |
+
echo 'export TF_FORCE_GPU_ALLOW_GROWTH=true' >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
|
135 |
+
# you may need to reactivate the audiocraft env for this to take effect
|
136 |
+
```
|
137 |
+
|
138 |
+
Args:
|
139 |
+
bin (Path or str): Path to installed frechet audio distance code.
|
140 |
+
model_path (Path or str): Path to Tensorflow checkpoint for the model
|
141 |
+
used to compute statistics over the embedding beams.
|
142 |
+
format (str): Audio format used to save files.
|
143 |
+
log_folder (Path or str, optional): Path where to write process logs.
|
144 |
+
"""
|
145 |
+
def __init__(self, bin: tp.Union[Path, str], model_path: tp.Union[Path, str],
|
146 |
+
format: str = "wav", batch_size: tp.Optional[int] = None,
|
147 |
+
log_folder: tp.Optional[tp.Union[Path, str]] = None):
|
148 |
+
super().__init__()
|
149 |
+
self.model_sample_rate = VGGISH_SAMPLE_RATE
|
150 |
+
self.model_channels = VGGISH_CHANNELS
|
151 |
+
self.model_path = AudioCraftEnvironment.resolve_reference_path(model_path)
|
152 |
+
assert Path(self.model_path).exists(), f"Could not find provided model checkpoint path at: {self.model_path}"
|
153 |
+
self.format = format
|
154 |
+
self.batch_size = batch_size
|
155 |
+
self.bin = bin
|
156 |
+
self.tf_env = {"PYTHONPATH": str(self.bin)}
|
157 |
+
self.python_path = os.environ.get('TF_PYTHON_EXE') or 'python'
|
158 |
+
logger.info("Python exe for TF is %s", self.python_path)
|
159 |
+
if 'TF_LIBRARY_PATH' in os.environ:
|
160 |
+
self.tf_env['LD_LIBRARY_PATH'] = os.environ['TF_LIBRARY_PATH']
|
161 |
+
if 'TF_FORCE_GPU_ALLOW_GROWTH' in os.environ:
|
162 |
+
self.tf_env['TF_FORCE_GPU_ALLOW_GROWTH'] = os.environ['TF_FORCE_GPU_ALLOW_GROWTH']
|
163 |
+
logger.info("Env for TF is %r", self.tf_env)
|
164 |
+
self.reset(log_folder)
|
165 |
+
self.add_state("total_files", default=torch.tensor(0.), dist_reduce_fx="sum")
|
166 |
+
|
167 |
+
def reset(self, log_folder: tp.Optional[tp.Union[Path, str]] = None):
|
168 |
+
"""Reset torchmetrics.Metrics state."""
|
169 |
+
log_folder = Path(log_folder or tempfile.mkdtemp())
|
170 |
+
self.tmp_dir = log_folder / 'fad'
|
171 |
+
self.tmp_dir.mkdir(exist_ok=True)
|
172 |
+
self.samples_tests_dir = self.tmp_dir / 'tests'
|
173 |
+
self.samples_tests_dir.mkdir(exist_ok=True)
|
174 |
+
self.samples_background_dir = self.tmp_dir / 'background'
|
175 |
+
self.samples_background_dir.mkdir(exist_ok=True)
|
176 |
+
self.manifest_tests = self.tmp_dir / 'files_tests.cvs'
|
177 |
+
self.manifest_background = self.tmp_dir / 'files_background.cvs'
|
178 |
+
self.stats_tests_dir = self.tmp_dir / 'stats_tests'
|
179 |
+
self.stats_background_dir = self.tmp_dir / 'stats_background'
|
180 |
+
self.counter = 0
|
181 |
+
|
182 |
+
def update(self, preds: torch.Tensor, targets: torch.Tensor,
|
183 |
+
sizes: torch.Tensor, sample_rates: torch.Tensor,
|
184 |
+
stems: tp.Optional[tp.List[str]] = None):
|
185 |
+
"""Update torchmetrics.Metrics by saving the audio and updating the manifest file."""
|
186 |
+
assert preds.shape == targets.shape, f"preds={preds.shape} != targets={targets.shape}"
|
187 |
+
num_samples = preds.shape[0]
|
188 |
+
assert num_samples == sizes.size(0) and num_samples == sample_rates.size(0)
|
189 |
+
assert stems is None or num_samples == len(set(stems))
|
190 |
+
for i in range(num_samples):
|
191 |
+
self.total_files += 1 # type: ignore
|
192 |
+
self.counter += 1
|
193 |
+
wav_len = int(sizes[i].item())
|
194 |
+
sample_rate = int(sample_rates[i].item())
|
195 |
+
pred_wav = preds[i]
|
196 |
+
target_wav = targets[i]
|
197 |
+
pred_wav = pred_wav[..., :wav_len]
|
198 |
+
target_wav = target_wav[..., :wav_len]
|
199 |
+
stem_name = stems[i] if stems is not None else f'sample_{self.counter}_{flashy.distrib.rank()}'
|
200 |
+
# dump audio files
|
201 |
+
try:
|
202 |
+
pred_wav = convert_audio(
|
203 |
+
pred_wav.unsqueeze(0), from_rate=sample_rate,
|
204 |
+
to_rate=self.model_sample_rate, to_channels=1).squeeze(0)
|
205 |
+
audio_write(
|
206 |
+
self.samples_tests_dir / stem_name, pred_wav, sample_rate=self.model_sample_rate,
|
207 |
+
format=self.format, strategy="peak")
|
208 |
+
except Exception as e:
|
209 |
+
logger.error(f"Exception occured when saving tests files for FAD computation: {repr(e)} - {e}")
|
210 |
+
try:
|
211 |
+
# for the ground truth audio, we enforce the 'peak' strategy to avoid modifying
|
212 |
+
# the original audio when writing it
|
213 |
+
target_wav = convert_audio(
|
214 |
+
target_wav.unsqueeze(0), from_rate=sample_rate,
|
215 |
+
to_rate=self.model_sample_rate, to_channels=1).squeeze(0)
|
216 |
+
audio_write(
|
217 |
+
self.samples_background_dir / stem_name, target_wav, sample_rate=self.model_sample_rate,
|
218 |
+
format=self.format, strategy="peak")
|
219 |
+
except Exception as e:
|
220 |
+
logger.error(f"Exception occured when saving background files for FAD computation: {repr(e)} - {e}")
|
221 |
+
|
222 |
+
def _get_samples_name(self, is_background: bool):
|
223 |
+
return 'background' if is_background else 'tests'
|
224 |
+
|
225 |
+
def _create_embedding_beams(self, is_background: bool, gpu_index: tp.Optional[int] = None):
|
226 |
+
if is_background:
|
227 |
+
input_samples_dir = self.samples_background_dir
|
228 |
+
input_filename = self.manifest_background
|
229 |
+
stats_name = self.stats_background_dir
|
230 |
+
else:
|
231 |
+
input_samples_dir = self.samples_tests_dir
|
232 |
+
input_filename = self.manifest_tests
|
233 |
+
stats_name = self.stats_tests_dir
|
234 |
+
beams_name = self._get_samples_name(is_background)
|
235 |
+
log_file = self.tmp_dir / f'fad_logs_create_beams_{beams_name}.log'
|
236 |
+
|
237 |
+
logger.info(f"Scanning samples folder to fetch list of files: {input_samples_dir}")
|
238 |
+
with open(input_filename, "w") as fout:
|
239 |
+
for path in Path(input_samples_dir).glob(f"*.{self.format}"):
|
240 |
+
fout.write(f"{str(path)}\n")
|
241 |
+
|
242 |
+
cmd = [
|
243 |
+
self.python_path, "-m",
|
244 |
+
"frechet_audio_distance.create_embeddings_main",
|
245 |
+
"--model_ckpt", f"{self.model_path}",
|
246 |
+
"--input_files", f"{str(input_filename)}",
|
247 |
+
"--stats", f"{str(stats_name)}",
|
248 |
+
]
|
249 |
+
if self.batch_size is not None:
|
250 |
+
cmd += ["--batch_size", str(self.batch_size)]
|
251 |
+
logger.info(f"Launching frechet_audio_distance embeddings main method: {' '.join(cmd)} on {beams_name}")
|
252 |
+
env = os.environ
|
253 |
+
if gpu_index is not None:
|
254 |
+
env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
|
255 |
+
process = subprocess.Popen(
|
256 |
+
cmd, stdout=open(log_file, "w"), env={**env, **self.tf_env}, stderr=subprocess.STDOUT)
|
257 |
+
return process, log_file
|
258 |
+
|
259 |
+
def _compute_fad_score(self, gpu_index: tp.Optional[int] = None):
|
260 |
+
cmd = [
|
261 |
+
self.python_path, "-m", "frechet_audio_distance.compute_fad",
|
262 |
+
"--test_stats", f"{str(self.stats_tests_dir)}",
|
263 |
+
"--background_stats", f"{str(self.stats_background_dir)}",
|
264 |
+
]
|
265 |
+
logger.info(f"Launching frechet_audio_distance compute fad method: {' '.join(cmd)}")
|
266 |
+
env = os.environ
|
267 |
+
if gpu_index is not None:
|
268 |
+
env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
|
269 |
+
result = subprocess.run(cmd, env={**env, **self.tf_env}, capture_output=True)
|
270 |
+
if result.returncode:
|
271 |
+
logger.error(
|
272 |
+
"Error with FAD computation from stats: \n %s \n %s",
|
273 |
+
result.stdout.decode(), result.stderr.decode()
|
274 |
+
)
|
275 |
+
raise RuntimeError("Error while executing FAD computation from stats")
|
276 |
+
try:
|
277 |
+
# result is "FAD: (d+).(d+)" hence we remove the prefix with (d+) being one digit or more
|
278 |
+
fad_score = float(result.stdout[4:])
|
279 |
+
return fad_score
|
280 |
+
except Exception as e:
|
281 |
+
raise RuntimeError(f"Error parsing FAD score from command stdout: {e}")
|
282 |
+
|
283 |
+
def _log_process_result(self, returncode: int, log_file: tp.Union[Path, str], is_background: bool) -> None:
|
284 |
+
beams_name = self._get_samples_name(is_background)
|
285 |
+
if returncode:
|
286 |
+
with open(log_file, "r") as f:
|
287 |
+
error_log = f.read()
|
288 |
+
logger.error(error_log)
|
289 |
+
os._exit(1)
|
290 |
+
else:
|
291 |
+
logger.info(f"Successfully computed embedding beams on {beams_name} samples.")
|
292 |
+
|
293 |
+
def _parallel_create_embedding_beams(self, num_of_gpus: int):
|
294 |
+
assert num_of_gpus > 0
|
295 |
+
logger.info("Creating embeddings beams in a parallel manner on different GPUs")
|
296 |
+
tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False, gpu_index=0)
|
297 |
+
bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True, gpu_index=1)
|
298 |
+
tests_beams_code = tests_beams_process.wait()
|
299 |
+
bg_beams_code = bg_beams_process.wait()
|
300 |
+
self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False)
|
301 |
+
self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True)
|
302 |
+
|
303 |
+
def _sequential_create_embedding_beams(self):
|
304 |
+
logger.info("Creating embeddings beams in a sequential manner")
|
305 |
+
tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False)
|
306 |
+
tests_beams_code = tests_beams_process.wait()
|
307 |
+
self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False)
|
308 |
+
bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True)
|
309 |
+
bg_beams_code = bg_beams_process.wait()
|
310 |
+
self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True)
|
311 |
+
|
312 |
+
@flashy.distrib.rank_zero_only
|
313 |
+
def _local_compute_frechet_audio_distance(self):
|
314 |
+
"""Compute Frechet Audio Distance score calling TensorFlow API."""
|
315 |
+
num_of_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
|
316 |
+
if num_of_gpus > 1:
|
317 |
+
self._parallel_create_embedding_beams(num_of_gpus)
|
318 |
+
else:
|
319 |
+
self._sequential_create_embedding_beams()
|
320 |
+
fad_score = self._compute_fad_score(gpu_index=0)
|
321 |
+
return fad_score
|
322 |
+
|
323 |
+
def compute(self) -> float:
|
324 |
+
"""Compute metrics."""
|
325 |
+
assert self.total_files.item() > 0, "No files dumped for FAD computation!" # type: ignore
|
326 |
+
fad_score = self._local_compute_frechet_audio_distance()
|
327 |
+
logger.warning(f"FAD score = {fad_score}")
|
328 |
+
fad_score = flashy.distrib.broadcast_object(fad_score, src=0)
|
329 |
+
return fad_score
|
audiocraft/metrics/kld.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import contextlib
|
8 |
+
from functools import partial
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import typing as tp
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torchmetrics
|
15 |
+
|
16 |
+
from ..data.audio_utils import convert_audio
|
17 |
+
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
class _patch_passt_stft:
|
23 |
+
"""Decorator to patch torch.stft in PaSST."""
|
24 |
+
def __init__(self):
|
25 |
+
self.old_stft = torch.stft
|
26 |
+
|
27 |
+
def __enter__(self):
|
28 |
+
# return_complex is a mandatory parameter in latest torch versions
|
29 |
+
# torch is throwing RuntimeErrors when not set
|
30 |
+
torch.stft = partial(torch.stft, return_complex=False)
|
31 |
+
|
32 |
+
def __exit__(self, *exc):
|
33 |
+
torch.stft = self.old_stft
|
34 |
+
|
35 |
+
|
36 |
+
def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
|
37 |
+
"""Computes the elementwise KL-Divergence loss between probability distributions
|
38 |
+
from generated samples and target samples.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
pred_probs (torch.Tensor): Probabilities for each label obtained
|
42 |
+
from a classifier on generated audio. Expected shape is [B, num_classes].
|
43 |
+
target_probs (torch.Tensor): Probabilities for each label obtained
|
44 |
+
from a classifier on target audio. Expected shape is [B, num_classes].
|
45 |
+
epsilon (float): Epsilon value.
|
46 |
+
Returns:
|
47 |
+
kld (torch.Tensor): KLD loss between each generated sample and target pair.
|
48 |
+
"""
|
49 |
+
kl_div = torch.nn.functional.kl_div((pred_probs + epsilon).log(), target_probs, reduction="none")
|
50 |
+
return kl_div.sum(-1)
|
51 |
+
|
52 |
+
|
53 |
+
class KLDivergenceMetric(torchmetrics.Metric):
|
54 |
+
"""Base implementation for KL Divergence metric.
|
55 |
+
|
56 |
+
The KL divergence is measured between probability distributions
|
57 |
+
of class predictions returned by a pre-trained audio classification model.
|
58 |
+
When the KL-divergence is low, the generated audio is expected to
|
59 |
+
have similar acoustic characteristics as the reference audio,
|
60 |
+
according to the classifier.
|
61 |
+
"""
|
62 |
+
def __init__(self):
|
63 |
+
super().__init__()
|
64 |
+
self.add_state("kld_pq_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
|
65 |
+
self.add_state("kld_qp_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
|
66 |
+
self.add_state("kld_all_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
|
67 |
+
self.add_state("weight", default=torch.tensor(0), dist_reduce_fx="sum")
|
68 |
+
|
69 |
+
def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
|
70 |
+
sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
|
71 |
+
"""Get model output given provided input tensor.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
x (torch.Tensor): Input audio tensor of shape [B, C, T].
|
75 |
+
sizes (torch.Tensor): Actual audio sample length, of shape [B].
|
76 |
+
sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
|
77 |
+
Returns:
|
78 |
+
probs (torch.Tensor): Probabilities over labels, of shape [B, num_classes].
|
79 |
+
"""
|
80 |
+
raise NotImplementedError("implement method to extract label distributions from the model.")
|
81 |
+
|
82 |
+
def update(self, preds: torch.Tensor, targets: torch.Tensor,
|
83 |
+
sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
|
84 |
+
"""Calculates running KL-Divergence loss between batches of audio
|
85 |
+
preds (generated) and target (ground-truth)
|
86 |
+
Args:
|
87 |
+
preds (torch.Tensor): Audio samples to evaluate, of shape [B, C, T].
|
88 |
+
targets (torch.Tensor): Target samples to compare against, of shape [B, C, T].
|
89 |
+
sizes (torch.Tensor): Actual audio sample length, of shape [B].
|
90 |
+
sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
|
91 |
+
"""
|
92 |
+
assert preds.shape == targets.shape
|
93 |
+
assert preds.size(0) > 0, "Cannot update the loss with empty tensors"
|
94 |
+
preds_probs = self._get_label_distribution(preds, sizes, sample_rates)
|
95 |
+
targets_probs = self._get_label_distribution(targets, sizes, sample_rates)
|
96 |
+
if preds_probs is not None and targets_probs is not None:
|
97 |
+
assert preds_probs.shape == targets_probs.shape
|
98 |
+
kld_scores = kl_divergence(preds_probs, targets_probs)
|
99 |
+
assert not torch.isnan(kld_scores).any(), "kld_scores contains NaN value(s)!"
|
100 |
+
self.kld_pq_sum += torch.sum(kld_scores)
|
101 |
+
kld_qp_scores = kl_divergence(targets_probs, preds_probs)
|
102 |
+
self.kld_qp_sum += torch.sum(kld_qp_scores)
|
103 |
+
self.weight += torch.tensor(kld_scores.size(0))
|
104 |
+
|
105 |
+
def compute(self) -> dict:
|
106 |
+
"""Computes KL-Divergence across all evaluated pred/target pairs."""
|
107 |
+
weight: float = float(self.weight.item()) # type: ignore
|
108 |
+
assert weight > 0, "Unable to compute with total number of comparisons <= 0"
|
109 |
+
logger.info(f"Computing KL divergence on a total of {weight} samples")
|
110 |
+
kld_pq = self.kld_pq_sum.item() / weight # type: ignore
|
111 |
+
kld_qp = self.kld_qp_sum.item() / weight # type: ignore
|
112 |
+
kld_both = kld_pq + kld_qp
|
113 |
+
return {'kld': kld_pq, 'kld_pq': kld_pq, 'kld_qp': kld_qp, 'kld_both': kld_both}
|
114 |
+
|
115 |
+
|
116 |
+
class PasstKLDivergenceMetric(KLDivergenceMetric):
|
117 |
+
"""KL-Divergence metric based on pre-trained PASST classifier on AudioSet.
|
118 |
+
|
119 |
+
From: PaSST: Efficient Training of Audio Transformers with Patchout
|
120 |
+
Paper: https://arxiv.org/abs/2110.05069
|
121 |
+
Implementation: https://github.com/kkoutini/PaSST
|
122 |
+
|
123 |
+
Follow instructions from the github repo:
|
124 |
+
```
|
125 |
+
pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'
|
126 |
+
```
|
127 |
+
|
128 |
+
Args:
|
129 |
+
pretrained_length (float, optional): Audio duration used for the pretrained model.
|
130 |
+
"""
|
131 |
+
def __init__(self, pretrained_length: tp.Optional[float] = None):
|
132 |
+
super().__init__()
|
133 |
+
self._initialize_model(pretrained_length)
|
134 |
+
|
135 |
+
def _initialize_model(self, pretrained_length: tp.Optional[float] = None):
|
136 |
+
"""Initialize underlying PaSST audio classifier."""
|
137 |
+
model, sr, max_frames, min_frames = self._load_base_model(pretrained_length)
|
138 |
+
self.min_input_frames = min_frames
|
139 |
+
self.max_input_frames = max_frames
|
140 |
+
self.model_sample_rate = sr
|
141 |
+
self.model = model
|
142 |
+
self.model.eval()
|
143 |
+
self.model.to(self.device)
|
144 |
+
|
145 |
+
def _load_base_model(self, pretrained_length: tp.Optional[float]):
|
146 |
+
"""Load pretrained model from PaSST."""
|
147 |
+
try:
|
148 |
+
if pretrained_length == 30:
|
149 |
+
from hear21passt.base30sec import get_basic_model # type: ignore
|
150 |
+
max_duration = 30
|
151 |
+
elif pretrained_length == 20:
|
152 |
+
from hear21passt.base20sec import get_basic_model # type: ignore
|
153 |
+
max_duration = 20
|
154 |
+
else:
|
155 |
+
from hear21passt.base import get_basic_model # type: ignore
|
156 |
+
# Original PASST was trained on AudioSet with 10s-long audio samples
|
157 |
+
max_duration = 10
|
158 |
+
min_duration = 0.15
|
159 |
+
min_duration = 0.15
|
160 |
+
except ModuleNotFoundError:
|
161 |
+
raise ModuleNotFoundError(
|
162 |
+
"Please install hear21passt to compute KL divergence: ",
|
163 |
+
"pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'"
|
164 |
+
)
|
165 |
+
model_sample_rate = 32_000
|
166 |
+
max_input_frames = int(max_duration * model_sample_rate)
|
167 |
+
min_input_frames = int(min_duration * model_sample_rate)
|
168 |
+
with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f):
|
169 |
+
model = get_basic_model(mode='logits')
|
170 |
+
return model, model_sample_rate, max_input_frames, min_input_frames
|
171 |
+
|
172 |
+
def _process_audio(self, wav: torch.Tensor, sample_rate: int, wav_len: int) -> tp.List[torch.Tensor]:
|
173 |
+
"""Process audio to feed to the pretrained model."""
|
174 |
+
wav = wav.unsqueeze(0)
|
175 |
+
wav = wav[..., :wav_len]
|
176 |
+
wav = convert_audio(wav, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1)
|
177 |
+
wav = wav.squeeze(0)
|
178 |
+
# we don't pad but return a list of audio segments as this otherwise affects the KLD computation
|
179 |
+
segments = torch.split(wav, self.max_input_frames, dim=-1)
|
180 |
+
valid_segments = []
|
181 |
+
for s in segments:
|
182 |
+
# ignoring too small segments that are breaking the model inference
|
183 |
+
if s.size(-1) > self.min_input_frames:
|
184 |
+
valid_segments.append(s)
|
185 |
+
return [s[None] for s in valid_segments]
|
186 |
+
|
187 |
+
def _get_model_preds(self, wav: torch.Tensor) -> torch.Tensor:
|
188 |
+
"""Run the pretrained model and get the predictions."""
|
189 |
+
assert wav.dim() == 3, f"Unexpected number of dims for preprocessed wav: {wav.shape}"
|
190 |
+
wav = wav.mean(dim=1)
|
191 |
+
# PaSST is printing a lot of garbage that we are not interested in
|
192 |
+
with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
|
193 |
+
with torch.no_grad(), _patch_passt_stft():
|
194 |
+
logits = self.model(wav.to(self.device))
|
195 |
+
probs = torch.softmax(logits, dim=-1)
|
196 |
+
return probs
|
197 |
+
|
198 |
+
def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
|
199 |
+
sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
|
200 |
+
"""Get model output given provided input tensor.
|
201 |
+
|
202 |
+
Args:
|
203 |
+
x (torch.Tensor): Input audio tensor of shape [B, C, T].
|
204 |
+
sizes (torch.Tensor): Actual audio sample length, of shape [B].
|
205 |
+
sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
|
206 |
+
Returns:
|
207 |
+
probs (torch.Tensor, optional): Probabilities over labels, of shape [B, num_classes].
|
208 |
+
"""
|
209 |
+
all_probs: tp.List[torch.Tensor] = []
|
210 |
+
for i, wav in enumerate(x):
|
211 |
+
sample_rate = int(sample_rates[i].item())
|
212 |
+
wav_len = int(sizes[i].item())
|
213 |
+
wav_segments = self._process_audio(wav, sample_rate, wav_len)
|
214 |
+
for segment in wav_segments:
|
215 |
+
probs = self._get_model_preds(segment).mean(dim=0)
|
216 |
+
all_probs.append(probs)
|
217 |
+
if len(all_probs) > 0:
|
218 |
+
return torch.stack(all_probs, dim=0)
|
219 |
+
else:
|
220 |
+
return None
|
audiocraft/metrics/rvm.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import typing as tp
|
8 |
+
import torch
|
9 |
+
from torch import nn
|
10 |
+
import torchaudio
|
11 |
+
|
12 |
+
|
13 |
+
def db_to_scale(volume: tp.Union[float, torch.Tensor]):
|
14 |
+
return 10 ** (volume / 20)
|
15 |
+
|
16 |
+
|
17 |
+
def scale_to_db(scale: torch.Tensor, min_volume: float = -120):
|
18 |
+
min_scale = db_to_scale(min_volume)
|
19 |
+
return 20 * torch.log10(scale.clamp(min=min_scale))
|
20 |
+
|
21 |
+
|
22 |
+
class RelativeVolumeMel(nn.Module):
|
23 |
+
"""Relative volume melspectrogram measure.
|
24 |
+
|
25 |
+
Computes a measure of distance over two mel spectrogram that is interpretable in terms
|
26 |
+
of decibels. Given `x_ref` and `x_est` two waveforms of shape `[*, T]`, it will
|
27 |
+
first renormalize both by the ground truth of `x_ref`.
|
28 |
+
|
29 |
+
..Warning:: This class returns the volume of the distortion at the spectrogram level,
|
30 |
+
e.g. low negative values reflects lower distortion levels. For a SNR (like reported
|
31 |
+
in the MultiBandDiffusion paper), just take `-rvm`.
|
32 |
+
|
33 |
+
Then it computes the mel spectrogram `z_ref` and `z_est` and compute volume of the difference
|
34 |
+
relative to the volume of `z_ref` for each time-frequency bin. It further adds some limits, e.g.
|
35 |
+
clamping the values between -25 and 25 dB (controlled by `min_relative_volume` and `max_relative_volume`)
|
36 |
+
with the goal of avoiding the loss being dominated by parts where the reference is almost silent.
|
37 |
+
Indeed, volumes in dB can take unbounded values both towards -oo and +oo, which can make the final
|
38 |
+
average metric harder to interpret. Besides, anything below -30 dB of attenuation would sound extremely
|
39 |
+
good (for a neural network output, although sound engineers typically aim for much lower attenuations).
|
40 |
+
Similarly, anything above +30 dB would just be completely missing the target, and there is no point
|
41 |
+
in measuring by exactly how much it missed it. -25, 25 is a more conservative range, but also more
|
42 |
+
in line with what neural nets currently can achieve.
|
43 |
+
|
44 |
+
For instance, a Relative Volume Mel (RVM) score of -10 dB means that on average, the delta between
|
45 |
+
the target and reference mel-spec is 10 dB lower than the reference mel-spec value.
|
46 |
+
|
47 |
+
The metric can be aggregated over a given frequency band in order have different insights for
|
48 |
+
different region of the spectrum. `num_aggregated_bands` controls the number of bands.
|
49 |
+
|
50 |
+
..Warning:: While this function is optimized for interpretability, nothing was done to ensure it
|
51 |
+
is numerically stable when computing its gradient. We thus advise against using it as a training loss.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
sample_rate (int): Sample rate of the input audio.
|
55 |
+
n_mels (int): Number of mel bands to use.
|
56 |
+
n_fft (int): Number of frequency bins for the STFT.
|
57 |
+
hop_length (int): Hop length of the STFT and the mel-spectrogram.
|
58 |
+
min_relative_volume (float): The error `z_ref - z_est` volume is given relative to
|
59 |
+
the volume of `z_ref`. If error is smaller than -25 dB of `z_ref`, then it is clamped.
|
60 |
+
max_relative_volume (float): Same as `min_relative_volume` but clamping if the error is larger than that.
|
61 |
+
max_initial_gain (float): When rescaling the audio at the very beginning, we will limit the gain
|
62 |
+
to that amount, to avoid rescaling near silence. Given in dB.
|
63 |
+
min_activity_volume (float): When computing the reference level from `z_ref`, will clamp low volume
|
64 |
+
bins to that amount. This is effectively our "zero" level for the reference mel-spectrogram,
|
65 |
+
and anything below that will be considered equally.
|
66 |
+
num_aggregated_bands (int): Number of bands to keep when computing the average RVM value.
|
67 |
+
For instance, a value of 3 would give 3 scores, roughly for low, mid and high freqs.
|
68 |
+
"""
|
69 |
+
def __init__(self, sample_rate: int = 24000, n_mels: int = 80, n_fft: int = 512,
|
70 |
+
hop_length: int = 128, min_relative_volume: float = -25,
|
71 |
+
max_relative_volume: float = 25, max_initial_gain: float = 25,
|
72 |
+
min_activity_volume: float = -25,
|
73 |
+
num_aggregated_bands: int = 4) -> None:
|
74 |
+
super().__init__()
|
75 |
+
self.melspec = torchaudio.transforms.MelSpectrogram(
|
76 |
+
n_mels=n_mels, n_fft=n_fft, hop_length=hop_length,
|
77 |
+
normalized=True, sample_rate=sample_rate, power=2)
|
78 |
+
self.min_relative_volume = min_relative_volume
|
79 |
+
self.max_relative_volume = max_relative_volume
|
80 |
+
self.max_initial_gain = max_initial_gain
|
81 |
+
self.min_activity_volume = min_activity_volume
|
82 |
+
self.num_aggregated_bands = num_aggregated_bands
|
83 |
+
|
84 |
+
def forward(self, estimate: torch.Tensor, ground_truth: torch.Tensor) -> tp.Dict[str, torch.Tensor]:
|
85 |
+
"""Compute RVM metric between estimate and reference samples.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
estimate (torch.Tensor): Estimate sample.
|
89 |
+
ground_truth (torch.Tensor): Reference sample.
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
dict[str, torch.Tensor]: Metrics with keys `rvm` for the overall average, and `rvm_{k}`
|
93 |
+
for the RVM over the k-th band (k=0..num_aggregated_bands - 1).
|
94 |
+
"""
|
95 |
+
min_scale = db_to_scale(-self.max_initial_gain)
|
96 |
+
std = ground_truth.pow(2).mean().sqrt().clamp(min=min_scale)
|
97 |
+
z_gt = self.melspec(ground_truth / std).sqrt()
|
98 |
+
z_est = self.melspec(estimate / std).sqrt()
|
99 |
+
|
100 |
+
delta = z_gt - z_est
|
101 |
+
ref_db = scale_to_db(z_gt, self.min_activity_volume)
|
102 |
+
delta_db = scale_to_db(delta.abs(), min_volume=-120)
|
103 |
+
relative_db = (delta_db - ref_db).clamp(self.min_relative_volume, self.max_relative_volume)
|
104 |
+
dims = list(range(relative_db.dim()))
|
105 |
+
dims.remove(dims[-2])
|
106 |
+
losses_per_band = relative_db.mean(dim=dims)
|
107 |
+
aggregated = [chunk.mean() for chunk in losses_per_band.chunk(self.num_aggregated_bands, dim=0)]
|
108 |
+
metrics = {f'rvm_{index}': value for index, value in enumerate(aggregated)}
|
109 |
+
metrics['rvm'] = losses_per_band.mean()
|
110 |
+
return metrics
|
audiocraft/metrics/visqol.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import csv
|
8 |
+
import json
|
9 |
+
import logging
|
10 |
+
from pathlib import Path
|
11 |
+
import tempfile
|
12 |
+
import typing as tp
|
13 |
+
import subprocess
|
14 |
+
import shutil
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torchaudio
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
class ViSQOL:
|
23 |
+
"""ViSQOL wrapper to run ViSQOL from Python using a pre-installed binary.
|
24 |
+
|
25 |
+
To learn more about ViSQOL and how to build ViSQOL binary using bazel, please refer to the
|
26 |
+
instructions available in the open source repository: https://github.com/google/visqol
|
27 |
+
|
28 |
+
ViSQOL is capable of running in two modes:
|
29 |
+
|
30 |
+
Audio Mode:
|
31 |
+
When running in audio mode, input signals must have a 48kHz sample rate. Input should be resampled to 48kHz.
|
32 |
+
Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison.
|
33 |
+
Audio mode uses support vector regression, with the maximum range at ~4.75.
|
34 |
+
|
35 |
+
Speech Mode:
|
36 |
+
When running in speech mode, ViSQOL uses a wideband model. It therefore expects input sample rates of 16kHz.
|
37 |
+
Input should be resampled to 16kHz.
|
38 |
+
As part of the speech mode processing, a root mean square implementation for voice activity detection
|
39 |
+
is performed on the reference signal to determine what parts of the signal have voice activity and
|
40 |
+
should therefore be included in the comparison. The signal is normalized before performing the voice
|
41 |
+
activity detection.
|
42 |
+
Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison.
|
43 |
+
Speech mode is scaled to have a maximum MOS of 5.0 to match previous version behavior.
|
44 |
+
|
45 |
+
For more details, check the guidelines: https://github.com/google/visqol#general-guidelines-for-input
|
46 |
+
|
47 |
+
Args:
|
48 |
+
visqol_bin (str): Path to the ViSQOL binary.
|
49 |
+
mode (str): ViSQOL computation mode, expecting "audio" or "speech".
|
50 |
+
model (str): Name of the model to use for similarity to quality model.
|
51 |
+
debug (bool): Whether to also get debug metrics from ViSQOL or not.
|
52 |
+
"""
|
53 |
+
SAMPLE_RATES_MODES = {"audio": 48_000, "speech": 16_000}
|
54 |
+
ALLOWED_SAMPLE_RATES = frozenset(SAMPLE_RATES_MODES.values())
|
55 |
+
|
56 |
+
def __init__(self, bin: tp.Union[Path, str], mode: str = "audio",
|
57 |
+
model: str = "libsvm_nu_svr_model.txt", debug: bool = False):
|
58 |
+
assert bin is not None and Path(bin).exists(), f"Could not find ViSQOL binary in specified path: {bin}"
|
59 |
+
self.visqol_bin = str(bin)
|
60 |
+
self.visqol_mode = mode
|
61 |
+
self.target_sr = self._get_target_sr(self.visqol_mode)
|
62 |
+
self.model = model
|
63 |
+
self.debug = debug
|
64 |
+
assert Path(self.visqol_model).exists(), \
|
65 |
+
f"Could not find the specified model in ViSQOL install: {self.visqol_model}"
|
66 |
+
|
67 |
+
def _get_target_sr(self, mode: str) -> int:
|
68 |
+
# returns target sampling rate for the corresponding ViSQOL mode.
|
69 |
+
if mode not in ViSQOL.SAMPLE_RATES_MODES:
|
70 |
+
raise ValueError(
|
71 |
+
f"Unsupported mode! Allowed are: {', '.join(ViSQOL.SAMPLE_RATES_MODES.keys())}"
|
72 |
+
)
|
73 |
+
return ViSQOL.SAMPLE_RATES_MODES[mode]
|
74 |
+
|
75 |
+
def _prepare_files(
|
76 |
+
self, ref_sig: torch.Tensor, deg_sig: torch.Tensor, sr: int, target_sr: int, pad_with_silence: bool = False
|
77 |
+
):
|
78 |
+
# prepare files for ViSQOL evaluation.
|
79 |
+
assert target_sr in ViSQOL.ALLOWED_SAMPLE_RATES
|
80 |
+
assert len(ref_sig) == len(deg_sig), (
|
81 |
+
"Expects same number of ref and degraded inputs",
|
82 |
+
f" but ref len {len(ref_sig)} != deg len {len(deg_sig)}"
|
83 |
+
)
|
84 |
+
# resample audio if needed
|
85 |
+
if sr != target_sr:
|
86 |
+
transform = torchaudio.transforms.Resample(sr, target_sr)
|
87 |
+
pad = int(0.5 * target_sr)
|
88 |
+
rs_ref = []
|
89 |
+
rs_deg = []
|
90 |
+
for i in range(len(ref_sig)):
|
91 |
+
rs_ref_i = transform(ref_sig[i])
|
92 |
+
rs_deg_i = transform(deg_sig[i])
|
93 |
+
if pad_with_silence:
|
94 |
+
rs_ref_i = torch.nn.functional.pad(rs_ref_i, (pad, pad), mode='constant', value=0)
|
95 |
+
rs_deg_i = torch.nn.functional.pad(rs_deg_i, (pad, pad), mode='constant', value=0)
|
96 |
+
rs_ref.append(rs_ref_i)
|
97 |
+
rs_deg.append(rs_deg_i)
|
98 |
+
ref_sig = torch.stack(rs_ref)
|
99 |
+
deg_sig = torch.stack(rs_deg)
|
100 |
+
# save audio chunks to tmp dir and create csv
|
101 |
+
tmp_dir = Path(tempfile.mkdtemp())
|
102 |
+
try:
|
103 |
+
tmp_input_csv_path = tmp_dir / "input.csv"
|
104 |
+
tmp_results_csv_path = tmp_dir / "results.csv"
|
105 |
+
tmp_debug_json_path = tmp_dir / "debug.json"
|
106 |
+
with open(tmp_input_csv_path, "w") as csv_file:
|
107 |
+
csv_writer = csv.writer(csv_file)
|
108 |
+
csv_writer.writerow(["reference", "degraded"])
|
109 |
+
for i in range(len(ref_sig)):
|
110 |
+
tmp_ref_filename = tmp_dir / f"ref_{i}.wav"
|
111 |
+
tmp_deg_filename = tmp_dir / f"deg_{i}.wav"
|
112 |
+
torchaudio.save(
|
113 |
+
tmp_ref_filename,
|
114 |
+
torch.clamp(ref_sig[i], min=-0.99, max=0.99),
|
115 |
+
sample_rate=target_sr,
|
116 |
+
bits_per_sample=16,
|
117 |
+
encoding="PCM_S"
|
118 |
+
)
|
119 |
+
torchaudio.save(
|
120 |
+
tmp_deg_filename,
|
121 |
+
torch.clamp(deg_sig[i], min=-0.99, max=0.99),
|
122 |
+
sample_rate=target_sr,
|
123 |
+
bits_per_sample=16,
|
124 |
+
encoding="PCM_S"
|
125 |
+
)
|
126 |
+
csv_writer.writerow([str(tmp_ref_filename), str(tmp_deg_filename)])
|
127 |
+
return tmp_dir, tmp_input_csv_path, tmp_results_csv_path, tmp_debug_json_path
|
128 |
+
except Exception as e:
|
129 |
+
logger.error("Exception occurred when preparing files for ViSQOL: %s", e)
|
130 |
+
return tmp_dir, None, None, None
|
131 |
+
|
132 |
+
def _flush_files(self, tmp_dir: tp.Union[Path, str]):
|
133 |
+
# flush tmp files used to compute ViSQOL.
|
134 |
+
shutil.rmtree(str(tmp_dir))
|
135 |
+
|
136 |
+
def _collect_moslqo_score(self, results_csv_path: tp.Union[Path, str]) -> float:
|
137 |
+
# collect results for each evaluated pair and return averaged moslqo score.
|
138 |
+
with open(results_csv_path, "r") as csv_file:
|
139 |
+
reader = csv.DictReader(csv_file)
|
140 |
+
moslqo_scores = [float(row["moslqo"]) for row in reader]
|
141 |
+
if len(moslqo_scores) > 0:
|
142 |
+
return sum(moslqo_scores) / len(moslqo_scores)
|
143 |
+
else:
|
144 |
+
return 0.0
|
145 |
+
|
146 |
+
def _collect_debug_data(self, debug_json_path: tp.Union[Path, str]) -> dict:
|
147 |
+
# collect debug data for the visqol inference.
|
148 |
+
with open(debug_json_path, "r") as f:
|
149 |
+
data = json.load(f)
|
150 |
+
return data
|
151 |
+
|
152 |
+
@property
|
153 |
+
def visqol_model(self):
|
154 |
+
return f'{self.visqol_bin}/model/{self.model}'
|
155 |
+
|
156 |
+
def _run_visqol(
|
157 |
+
self,
|
158 |
+
input_csv_path: tp.Union[Path, str],
|
159 |
+
results_csv_path: tp.Union[Path, str],
|
160 |
+
debug_csv_path: tp.Optional[tp.Union[Path, str]],
|
161 |
+
):
|
162 |
+
input_csv_path = str(input_csv_path)
|
163 |
+
results_csv_path = str(results_csv_path)
|
164 |
+
debug_csv_path = str(debug_csv_path)
|
165 |
+
cmd = [
|
166 |
+
f'{self.visqol_bin}/bazel-bin/visqol',
|
167 |
+
'--batch_input_csv', f'{input_csv_path}',
|
168 |
+
'--results_csv', f'{results_csv_path}'
|
169 |
+
]
|
170 |
+
if debug_csv_path is not None:
|
171 |
+
cmd += ['--output_debug', f'{debug_csv_path}']
|
172 |
+
if self.visqol_mode == "speech":
|
173 |
+
cmd += ['--use_speech_mode']
|
174 |
+
cmd += ['--similarity_to_quality_model', f'{self.visqol_model}']
|
175 |
+
result = subprocess.run(cmd, capture_output=True)
|
176 |
+
if result.returncode:
|
177 |
+
logger.error("Error with visqol: \n %s \n %s", result.stdout.decode(), result.stderr.decode())
|
178 |
+
raise RuntimeError("Error while executing visqol")
|
179 |
+
result.check_returncode()
|
180 |
+
|
181 |
+
def __call__(
|
182 |
+
self,
|
183 |
+
ref_sig: torch.Tensor,
|
184 |
+
deg_sig: torch.Tensor,
|
185 |
+
sr: int,
|
186 |
+
pad_with_silence: bool = False,
|
187 |
+
):
|
188 |
+
"""Calculate the ViSQOL metric for a pair of audio signals at a given sample rate.
|
189 |
+
Args:
|
190 |
+
ref_sig (torch.Tensor): Reference signals as [B, C, T].
|
191 |
+
deg_sig (torch.Tensor): Degraded signals as [B, C, T].
|
192 |
+
sr (int): Sample rate of the two audio signals.
|
193 |
+
pad_with_silence (bool): Whether to pad the file with silences as recommended
|
194 |
+
in visqol guidelines (see: https://github.com/google/visqol#general-guidelines-for-input).
|
195 |
+
Returns:
|
196 |
+
float: The ViSQOL score or mean score for the batch.
|
197 |
+
"""
|
198 |
+
logger.debug(f"Calculating visqol with mode={self.visqol_mode} on {len(ref_sig)} samples")
|
199 |
+
tmp_dir, input_csv, results_csv, debug_json = self._prepare_files(
|
200 |
+
ref_sig, deg_sig, sr, self.target_sr, pad_with_silence
|
201 |
+
)
|
202 |
+
try:
|
203 |
+
if input_csv and results_csv:
|
204 |
+
self._run_visqol(
|
205 |
+
input_csv,
|
206 |
+
results_csv,
|
207 |
+
debug_json if self.debug else None,
|
208 |
+
)
|
209 |
+
mosqol = self._collect_moslqo_score(results_csv)
|
210 |
+
return mosqol
|
211 |
+
else:
|
212 |
+
raise RuntimeError("Something unexpected happened when running VISQOL!")
|
213 |
+
except Exception as e:
|
214 |
+
logger.error("Exception occurred when running ViSQOL: %s", e)
|
215 |
+
finally:
|
216 |
+
self._flush_files(tmp_dir)
|